diff --git a/internal/netxlite/dnsx/decoder.go b/internal/netxlite/dnsx/decoder.go index eaf0032..9e539d3 100644 --- a/internal/netxlite/dnsx/decoder.go +++ b/internal/netxlite/dnsx/decoder.go @@ -2,13 +2,20 @@ package dnsx import ( "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model" "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) +// HTTPSSvc is an HTTPSSvc reply. +type HTTPSSvc = model.HTTPSSvc + // The Decoder decodes DNS replies. type Decoder interface { // DecodeLookupHost decodes an A or AAAA reply. DecodeLookupHost(qtype uint16, data []byte) ([]string, error) + + // DecodeHTTPS decodes an HTTPS reply. + DecodeHTTPS(data []byte) (*HTTPSSvc, error) } // MiekgDecoder uses github.com/miekg/dns to implement the Decoder. @@ -33,6 +40,37 @@ func (d *MiekgDecoder) parseReply(data []byte) (*dns.Msg, error) { } } +func (d *MiekgDecoder) DecodeHTTPS(data []byte) (*HTTPSSvc, error) { + reply, err := d.parseReply(data) + if err != nil { + return nil, err + } + out := &HTTPSSvc{} + for _, answer := range reply.Answer { + switch avalue := answer.(type) { + case *dns.HTTPS: + for _, v := range avalue.Value { + switch extv := v.(type) { + case *dns.SVCBAlpn: + out.ALPN = extv.Alpn + case *dns.SVCBIPv4Hint: + for _, ip := range extv.Hint { + out.IPv4 = append(out.IPv4, ip.String()) + } + case *dns.SVCBIPv6Hint: + for _, ip := range extv.Hint { + out.IPv6 = append(out.IPv6, ip.String()) + } + } + } + } + } + if len(out.ALPN) <= 0 { + return nil, errorsx.ErrOODNSNoAnswer + } + return out, nil +} + func (d *MiekgDecoder) DecodeLookupHost(qtype uint16, data []byte) ([]string, error) { reply, err := d.parseReply(data) if err != nil { diff --git a/internal/netxlite/dnsx/decoder_test.go b/internal/netxlite/dnsx/decoder_test.go index 6392a6b..2269b38 100644 --- a/internal/netxlite/dnsx/decoder_test.go +++ b/internal/netxlite/dnsx/decoder_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) @@ -198,3 +199,112 @@ func TestParseReply(t *testing.T) { t.Fatal("expected nil reply") } } + +func genReplyHTTPS(t *testing.T, alpns, ipv4, ipv6 []string) []byte { + question := dns.Question{ + Name: dns.Fqdn("x.org"), + Qtype: dns.TypeHTTPS, + Qclass: dns.ClassINET, + } + query := new(dns.Msg) + query.Id = dns.Id() + query.RecursionDesired = true + query.Question = make([]dns.Question, 1) + query.Question[0] = question + reply := new(dns.Msg) + reply.Compress = true + reply.MsgHdr.RecursionAvailable = true + reply.SetReply(query) + answer := &dns.HTTPS{ + SVCB: dns.SVCB{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn("x.org"), + Rrtype: dns.TypeHTTPS, + Class: dns.ClassINET, + Ttl: 100, + Rdlength: 0, + }, + Priority: 5, + Target: dns.Fqdn("x.org"), + Value: []dns.SVCBKeyValue{}, + }, + } + reply.Answer = append(reply.Answer, answer) + if len(alpns) > 0 { + answer.Value = append(answer.Value, &dns.SVCBAlpn{ + Alpn: alpns, + }) + answer.Hdr.Rdlength++ + } + if len(ipv4) > 0 { + var addrs []net.IP + for _, addr := range ipv4 { + addrs = append(addrs, net.ParseIP(addr)) + } + answer.Value = append(answer.Value, &dns.SVCBIPv4Hint{ + Hint: addrs, + }) + answer.Hdr.Rdlength++ + } + if len(ipv6) > 0 { + var addrs []net.IP + for _, addr := range ipv6 { + addrs = append(addrs, net.ParseIP(addr)) + } + answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{ + Hint: addrs, + }) + answer.Hdr.Rdlength++ + } + data, err := reply.Pack() + if err != nil { + t.Fatal(err) + } + return data +} + +func TestDecodeHTTPS(t *testing.T) { + t.Run("with nil data", func(t *testing.T) { + d := &MiekgDecoder{} + reply, err := d.DecodeHTTPS(nil) + if err == nil || err.Error() != "dns: overflow unpacking uint16" { + t.Fatal("not the error we expected", err) + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("with empty answer", func(t *testing.T) { + data := genReplyHTTPS(t, nil, nil, nil) + d := &MiekgDecoder{} + reply, err := d.DecodeHTTPS(data) + if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { + t.Fatal("unexpected err", err) + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("with full answer", func(t *testing.T) { + alpn := []string{"h3"} + v4 := []string{"1.1.1.1"} + v6 := []string{"::1"} + data := genReplyHTTPS(t, alpn, v4, v6) + d := &MiekgDecoder{} + reply, err := d.DecodeHTTPS(data) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(alpn, reply.ALPN); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(v4, reply.IPv4); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(v6, reply.IPv6); diff != "" { + t.Fatal(diff) + } + }) +} diff --git a/internal/netxlite/dnsx/mocks/decoder.go b/internal/netxlite/dnsx/mocks/decoder.go index 55da00c..6b79394 100644 --- a/internal/netxlite/dnsx/mocks/decoder.go +++ b/internal/netxlite/dnsx/mocks/decoder.go @@ -1,11 +1,23 @@ package mocks +import "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model" + +// HTTPSSvc is the result of HTTPS queries. +type HTTPSSvc = model.HTTPSSvc + // Decoder allows mocking dnsx.Decoder. type Decoder struct { - MockDecode func(qtype uint16, reply []byte) ([]string, error) + MockDecodeLookupHost func(qtype uint16, reply []byte) ([]string, error) + + MockDecodeHTTPS func(reply []byte) (*HTTPSSvc, error) } -// Decode calls MockDecode. -func (e *Decoder) Decode(qtype uint16, reply []byte) ([]string, error) { - return e.MockDecode(qtype, reply) +// DecodeLookupHost calls MockDecodeLookupHost. +func (e *Decoder) DecodeLookupHost(qtype uint16, reply []byte) ([]string, error) { + return e.MockDecodeLookupHost(qtype, reply) +} + +// DecodeHTTPS calls MockDecodeHTTPS. +func (e *Decoder) DecodeHTTPS(reply []byte) (*HTTPSSvc, error) { + return e.MockDecodeHTTPS(reply) } diff --git a/internal/netxlite/dnsx/mocks/decoder_test.go b/internal/netxlite/dnsx/mocks/decoder_test.go index ed6c13c..67437fa 100644 --- a/internal/netxlite/dnsx/mocks/decoder_test.go +++ b/internal/netxlite/dnsx/mocks/decoder_test.go @@ -8,14 +8,30 @@ import ( ) func TestDecoder(t *testing.T) { - t.Run("Decode", func(t *testing.T) { + t.Run("DecodeLookupHost", func(t *testing.T) { expected := errors.New("mocked error") e := &Decoder{ - MockDecode: func(qtype uint16, reply []byte) ([]string, error) { + MockDecodeLookupHost: func(qtype uint16, reply []byte) ([]string, error) { return nil, expected }, } - out, err := e.Decode(dns.TypeA, make([]byte, 17)) + out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17)) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if out != nil { + t.Fatal("unexpected out") + } + }) + + t.Run("DecodeHTTPS", func(t *testing.T) { + expected := errors.New("mocked error") + e := &Decoder{ + MockDecodeHTTPS: func(reply []byte) (*HTTPSSvc, error) { + return nil, expected + }, + } + out, err := e.DecodeHTTPS(make([]byte, 17)) if !errors.Is(err, expected) { t.Fatal("unexpected err", err) } diff --git a/internal/netxlite/dnsx/model/model.go b/internal/netxlite/dnsx/model/model.go new file mode 100644 index 0000000..75e9055 --- /dev/null +++ b/internal/netxlite/dnsx/model/model.go @@ -0,0 +1,14 @@ +// Package model contains the dnsx model. +package model + +// HTTPSSvc is an HTTPSSvc reply. +type HTTPSSvc struct { + // ALPN contains the ALPNs inside the HTTPS reply + ALPN []string + + // IPv4 contains the IPv4 hints. + IPv4 []string + + // IPv6 contains the IPv6 hints. + IPv6 []string +} diff --git a/internal/netxlite/dnsx/serial.go b/internal/netxlite/dnsx/serial.go index aa4e3ff..3f8eac2 100644 --- a/internal/netxlite/dnsx/serial.go +++ b/internal/netxlite/dnsx/serial.go @@ -61,6 +61,21 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str return addrs, nil } +// LookupHTTPS issues an HTTPS query without retrying on failure. +func (r *SerialResolver) LookupHTTPS( + ctx context.Context, hostname string) (*HTTPSSvc, error) { + querydata, err := r.Encoder.Encode( + hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) + if err != nil { + return nil, err + } + replydata, err := r.Txp.RoundTrip(ctx, querydata) + if err != nil { + return nil, err + } + return r.Decoder.DecodeHTTPS(replydata) +} + func (r *SerialResolver) lookupHostWithRetry( ctx context.Context, hostname string, qtype uint16) ([]string, error) { var errorslist []error diff --git a/internal/netxlite/dnsx/serial_test.go b/internal/netxlite/dnsx/serial_test.go index 6b944c8..afcdc84 100644 --- a/internal/netxlite/dnsx/serial_test.go +++ b/internal/netxlite/dnsx/serial_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) @@ -160,3 +161,93 @@ func TestSerialResolverCloseIdleConnections(t *testing.T) { t.Fatal("not called") } } + +func TestSerialResolverLookupHTTPS(t *testing.T) { + t.Run("for encoding error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &SerialResolver{ + Encoder: &mocks.Encoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return nil, expected + }, + }, + Decoder: nil, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.RoundTripper{ + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + + t.Run("for round-trip error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &SerialResolver{ + Encoder: &mocks.Encoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return make([]byte, 64), nil + }, + }, + Decoder: nil, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.RoundTripper{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, expected + }, + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + + t.Run("for decode error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &SerialResolver{ + Encoder: &mocks.Encoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return make([]byte, 64), nil + }, + }, + Decoder: &mocks.Decoder{ + MockDecodeHTTPS: func(reply []byte) (*mocks.HTTPSSvc, error) { + return nil, expected + }, + }, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.RoundTripper{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return make([]byte, 128), nil + }, + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) +} diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index bd41fe6..9af9092 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -97,6 +97,11 @@ func (r *ResolverLegacyAdapter) CloseIdleConnections() { } } +func (r *ResolverLegacyAdapter) LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, ErrNoDNSTransport +} + // DialerLegacy establishes network connections. // // This definition is DEPRECATED. Please, use Dialer. diff --git a/internal/netxlite/legacy_test.go b/internal/netxlite/legacy_test.go index af132cd..550d1d5 100644 --- a/internal/netxlite/legacy_test.go +++ b/internal/netxlite/legacy_test.go @@ -1,6 +1,8 @@ package netxlite import ( + "context" + "errors" "net" "testing" @@ -43,6 +45,17 @@ func TestResolverLegacyAdapter(t *testing.T) { } r.CloseIdleConnections() // does not crash }) + + t.Run("for LookupHTTPS", func(t *testing.T) { + r := NewResolverLegacyAdapter(&net.Resolver{}) + https, err := r.LookupHTTPS(context.Background(), "x.org") + if !errors.Is(err, ErrNoDNSTransport) { + t.Fatal("not the error we expected") + } + if https != nil { + t.Fatal("expected nil result") + } + }) } func TestDialerLegacyAdapter(t *testing.T) { diff --git a/internal/netxlite/mocks/resolver.go b/internal/netxlite/mocks/resolver.go index 3abf749..d75fbd4 100644 --- a/internal/netxlite/mocks/resolver.go +++ b/internal/netxlite/mocks/resolver.go @@ -1,6 +1,10 @@ package mocks -import "context" +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model" +) // Resolver is a mockable Resolver. type Resolver struct { @@ -8,6 +12,7 @@ type Resolver struct { MockNetwork func() string MockAddress func() string MockCloseIdleConnections func() + MockLookupHTTPS func(ctx context.Context, domain string) (*HTTPSSvc, error) } // LookupHost calls MockLookupHost. @@ -29,3 +34,11 @@ func (r *Resolver) Network() string { func (r *Resolver) CloseIdleConnections() { r.MockCloseIdleConnections() } + +// HTTPSSvc is an HTTPSSvc reply. +type HTTPSSvc = model.HTTPSSvc + +// LookupHTTPS calls MockLookupHTTPS. +func (r *Resolver) LookupHTTPS(ctx context.Context, domain string) (*HTTPSSvc, error) { + return r.MockLookupHTTPS(ctx, domain) +} diff --git a/internal/netxlite/mocks/resolver_test.go b/internal/netxlite/mocks/resolver_test.go index b696201..5b6c144 100644 --- a/internal/netxlite/mocks/resolver_test.go +++ b/internal/netxlite/mocks/resolver_test.go @@ -58,4 +58,21 @@ func TestResolver(t *testing.T) { t.Fatal("not called") } }) + + t.Run("LookupHTTPS", func(t *testing.T) { + expected := errors.New("mocked error") + r := &Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, expected + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if https != nil { + t.Fatal("expected nil addr") + } + }) } diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index ffaabba..c915a18 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -7,10 +7,14 @@ import ( "net" "time" + "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx" "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "golang.org/x/net/idna" ) +// HTTPSSvc is the type returned for HTTPSSvc queries. +type HTTPSSvc = dnsx.HTTPSSvc + // Resolver performs domain name resolutions. type Resolver interface { // LookupHost behaves like net.Resolver.LookupHost. @@ -24,8 +28,17 @@ type Resolver interface { // CloseIdleConnections closes idle connections, if any. CloseIdleConnections() + + // LookupHTTPS issues a single HTTPS query for + // a domain without any retry mechanism whatsoever. + LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) } +// ErrNoDNSTransport indicates that the requested Resolver operation +// cannot be performed because we're using the "system" resolver. +var ErrNoDNSTransport = errors.New("operation requires a DNS transport") + // NewResolverStdlib creates a new Resolver by combining // WrapResolver with an internal "system" resolver type that // adds extra functionality to net.Resolver. @@ -120,6 +133,11 @@ func (r *resolverSystem) CloseIdleConnections() { // nothing to do } +func (r *resolverSystem) LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, ErrNoDNSTransport +} + // resolverLogger is a resolver that emits events type resolverLogger struct { Resolver @@ -142,6 +160,24 @@ func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]str return addrs, nil } +func (r *resolverLogger) LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) { + prefix := fmt.Sprintf("resolve[HTTPS] %s with %s (%s)", domain, r.Network(), r.Address()) + r.Logger.Debugf("%s...", prefix) + start := time.Now() + https, err := r.Resolver.LookupHTTPS(ctx, domain) + elapsed := time.Since(start) + if err != nil { + r.Logger.Debugf("%s... %s in %s", prefix, err, elapsed) + return nil, err + } + alpn := https.ALPN + a := https.IPv4 + aaaa := https.IPv6 + r.Logger.Debugf("%s... %+v %+v %+v in %s", prefix, alpn, a, aaaa, elapsed) + return https, nil +} + // resolverIDNA supports resolving Internationalized Domain Names. // // See RFC3492 for more information. @@ -157,6 +193,15 @@ func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]strin return r.Resolver.LookupHost(ctx, host) } +func (r *resolverIDNA) LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) { + host, err := idna.ToASCII(domain) + if err != nil { + return nil, err + } + return r.Resolver.LookupHTTPS(ctx, host) +} + // resolverShortCircuitIPAddr recognizes when the input hostname is an // IP address and returns it immediately to the caller. type resolverShortCircuitIPAddr struct { @@ -193,6 +238,11 @@ func (r *nullResolver) CloseIdleConnections() { // nothing to do } +func (r *nullResolver) LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, ErrNoResolver +} + // resolverErrWrapper is a Resolver that knows about wrapping errors. type resolverErrWrapper struct { Resolver @@ -208,3 +258,13 @@ func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([ } return addrs, nil } + +func (r *resolverErrWrapper) LookupHTTPS( + ctx context.Context, domain string) (*HTTPSSvc, error) { + out, err := r.Resolver.LookupHTTPS(ctx, domain) + if err != nil { + return nil, errorsx.NewErrWrapper( + errorsx.ClassifyResolverError, errorsx.ResolveOperation, err) + } + return out, nil +} diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index e34ad55..b0b3569 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -134,6 +134,17 @@ func TestResolverSystem(t *testing.T) { } }) }) + + t.Run("LookupHTTPS", func(t *testing.T) { + r := &resolverSystem{} + https, err := r.LookupHTTPS(context.Background(), "x.org") + if !errors.Is(err, ErrNoDNSTransport) { + t.Fatal("not the error we expected") + } + if https != nil { + t.Fatal("expected nil result") + } + }) } func TestResolverLogger(t *testing.T) { @@ -206,6 +217,80 @@ func TestResolverLogger(t *testing.T) { } }) }) + + t.Run("LookupHTTPS", func(t *testing.T) { + t.Run("with success", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } + expected := &HTTPSSvc{ + ALPN: []string{"h3"}, + IPv4: []string{"1.1.1.1"}, + } + r := &resolverLogger{ + Logger: lo, + Resolver: &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + return expected, nil + }, + MockNetwork: func() string { + return "system" + }, + MockAddress: func() string { + return "" + }, + }, + } + https, err := r.LookupHTTPS(context.Background(), "dns.google") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, https); diff != "" { + t.Fatal(diff) + } + if count != 2 { + t.Fatal("unexpected count") + } + }) + + t.Run("with failure", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } + expected := errors.New("mocked error") + r := &resolverLogger{ + Logger: lo, + Resolver: &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, expected + }, + MockNetwork: func() string { + return "system" + }, + MockAddress: func() string { + return "" + }, + }, + } + https, err := r.LookupHTTPS(context.Background(), "dns.google") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if https != nil { + t.Fatal("expected nil addr here") + } + if count != 2 { + t.Fatal("unexpected count") + } + }) + }) + } func TestResolverIDNA(t *testing.T) { @@ -249,6 +334,51 @@ func TestResolverIDNA(t *testing.T) { } }) }) + + t.Run("LookupHTTPS", func(t *testing.T) { + t.Run("with valid IDNA in input", func(t *testing.T) { + expected := &HTTPSSvc{ + ALPN: []string{"h3"}, + IPv4: []string{"1.1.1.1"}, + IPv6: []string{}, + } + r := &resolverIDNA{ + Resolver: &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + if domain != "xn--d1acpjx3f.xn--p1ai" { + return nil, errors.New("passed invalid domain") + } + return expected, nil + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "яндекс.рф") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, https); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("with invalid punycode", func(t *testing.T) { + r := &resolverIDNA{Resolver: &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, errors.New("should not happen") + }, + }} + // See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/ + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "xn--0000h") + if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") { + t.Fatal("not the error we expected") + } + if https != nil { + t.Fatal("expected no response here") + } + }) + }) } func TestResolverShortCircuitIPAddr(t *testing.T) { @@ -292,22 +422,43 @@ func TestResolverShortCircuitIPAddr(t *testing.T) { } func TestNullResolver(t *testing.T) { - r := &nullResolver{} - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "dns.google") - if !errors.Is(err, ErrNoResolver) { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("expected nil addr") - } - if r.Network() != "null" { - t.Fatal("invalid network") - } - if r.Address() != "" { - t.Fatal("invalid address") - } - r.CloseIdleConnections() // for coverage + t.Run("LookupHost", func(t *testing.T) { + r := &nullResolver{} + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google") + if !errors.Is(err, ErrNoResolver) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil addr") + } + if r.Network() != "null" { + t.Fatal("invalid network") + } + if r.Address() != "" { + t.Fatal("invalid address") + } + r.CloseIdleConnections() // for coverage + }) + + t.Run("LookupHTTPS", func(t *testing.T) { + r := &nullResolver{} + ctx := context.Background() + addrs, err := r.LookupHTTPS(ctx, "dns.google") + if !errors.Is(err, ErrNoResolver) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil addr") + } + if r.Network() != "null" { + t.Fatal("invalid network") + } + if r.Address() != "" { + t.Fatal("invalid address") + } + r.CloseIdleConnections() // for coverage + }) } func TestResolverErrWrapper(t *testing.T) { @@ -393,4 +544,46 @@ func TestResolverErrWrapper(t *testing.T) { t.Fatal("not called") } }) + + t.Run("LookupHTTPS", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expected := &HTTPSSvc{ + ALPN: []string{"h3"}, + } + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + return expected, nil + }, + }, + } + ctx := context.Background() + https, err := reso.LookupHTTPS(ctx, "") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, https); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := io.EOF + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*HTTPSSvc, error) { + return nil, expected + }, + }, + } + ctx := context.Background() + https, err := reso.LookupHTTPS(ctx, "") + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected addrs") + } + }) + }) }