diff --git a/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go b/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go index 8adc62f..b82e07f 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go @@ -53,7 +53,7 @@ func TestWorkingAsIntended(t *testing.T) { Client: http.DefaultClient, Dialer: netxlite.DefaultDialer, MaxAcceptableBody: 1 << 24, - Resolver: &netxlite.ResolverSystem{}, + Resolver: netxlite.NewResolverSystem(), } srv := httptest.NewServer(handler) defer srv.Close() diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index bd4715d..20df40a 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -78,7 +78,7 @@ var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool() // NewResolver creates a new resolver from the specified config func NewResolver(config Config) model.Resolver { if config.BaseResolver == nil { - config.BaseResolver = &netxlite.ResolverSystem{} + config.BaseResolver = netxlite.NewResolverSystem() } var r model.Resolver = config.BaseResolver r = &netxlite.AddressResolver{ @@ -260,7 +260,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, } switch resolverURL.Scheme { case "system": - return &netxlite.ResolverSystem{}, nil + return netxlite.NewResolverSystem(), nil case "https": config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} httpClient := &http.Client{Transport: NewHTTPTransport(config)} diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 3cf2ea5..3224857 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -32,7 +32,7 @@ func TestNewResolverVanilla(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(*netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate) if !ok { t.Fatal("not the resolver we expected") } @@ -82,7 +82,7 @@ func TestNewResolverWithBogonFilter(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(*netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate) if !ok { t.Fatal("not the resolver we expected") } @@ -111,7 +111,7 @@ func TestNewResolverWithLogging(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(*netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate) if !ok { t.Fatal("not the resolver we expected") } @@ -141,7 +141,7 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(*netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate) if !ok { t.Fatal("not the resolver we expected") } @@ -170,7 +170,7 @@ func TestNewResolverWithReadWriteCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(*netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate) if !ok { t.Fatal("not the resolver we expected") } @@ -204,7 +204,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(*netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate) if !ok { t.Fatal("not the resolver we expected") } @@ -556,7 +556,7 @@ func TestNewDNSClientSystemResolver(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok { + if _, ok := dnsclient.(*netxlite.ResolverSystemDoNotInstantiate); !ok { t.Fatal("not the resolver we expected") } dnsclient.CloseIdleConnections() @@ -568,7 +568,7 @@ func TestNewDNSClientEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok { + if _, ok := dnsclient.(*netxlite.ResolverSystemDoNotInstantiate); !ok { t.Fatal("not the resolver we expected") } dnsclient.CloseIdleConnections() diff --git a/internal/engine/netx/resolver/integration_test.go b/internal/engine/netx/resolver/integration_test.go index 545f208..1bb9606 100644 --- a/internal/engine/netx/resolver/integration_test.go +++ b/internal/engine/netx/resolver/integration_test.go @@ -64,7 +64,7 @@ func testresolverquickidna(t *testing.T, reso model.Resolver) { } func TestNewResolverSystem(t *testing.T) { - reso := &netxlite.ResolverSystem{} + reso := netxlite.NewResolverSystem() testresolverquick(t, reso) testresolverquickidna(t, reso) } diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index b8736cd..688ba2b 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -123,7 +123,7 @@ func TestDialerResolver(t *testing.T) { t.Run("fails without a port", func(t *testing.T) { d := &dialerResolver{ Dialer: &DialerSystem{}, - Resolver: &resolverSystem{}, + Resolver: newResolverSystem(), } const missingPort = "ooni.nu" conn, err := d.DialContext(context.Background(), "tcp", missingPort) diff --git a/internal/netxlite/dnsovergetaddrinfo.go b/internal/netxlite/dnsovergetaddrinfo.go new file mode 100644 index 0000000..0aca6ed --- /dev/null +++ b/internal/netxlite/dnsovergetaddrinfo.go @@ -0,0 +1,125 @@ +package netxlite + +// +// DNS over getaddrinfo: fake transport to allow us to observe +// lookups using getaddrinfo as a DNSTransport. +// + +import ( + "context" + "net" + "time" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/model" +) + +// dnsOverGetaddrinfoTransport is a DNSTransport using getaddrinfo. +type dnsOverGetaddrinfoTransport struct { + testableTimeout time.Duration + testableLookupHost func(ctx context.Context, domain string) ([]string, error) +} + +var _ model.DNSTransport = &dnsOverGetaddrinfoTransport{} + +func (txp *dnsOverGetaddrinfoTransport) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() != dns.TypeANY { + return nil, ErrNoDNSTransport + } + addrs, err := txp.lookup(ctx, query.Domain()) + if err != nil { + return nil, err + } + resp := &dnsOverGetaddrinfoResponse{ + addrs: addrs, + query: query, + } + return resp, nil +} + +type dnsOverGetaddrinfoResponse struct { + addrs []string + query model.DNSQuery +} + +func (txp *dnsOverGetaddrinfoTransport) lookup( + ctx context.Context, hostname string) ([]string, error) { + // This code forces adding a shorter timeout to the domain name + // resolutions when using the system resolver. We have seen cases + // in which such a timeout becomes too large. One such case is + // described in https://github.com/ooni/probe/issues/1726. + addrsch, errch := make(chan []string, 1), make(chan error, 1) + ctx, cancel := context.WithTimeout(ctx, txp.timeout()) + defer cancel() + go func() { + addrs, err := txp.lookupfn()(ctx, hostname) + if err != nil { + errch <- err + return + } + addrsch <- addrs + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case addrs := <-addrsch: + return addrs, nil + case err := <-errch: + return nil, err + } +} + +func (txp *dnsOverGetaddrinfoTransport) timeout() time.Duration { + if txp.testableTimeout > 0 { + return txp.testableTimeout + } + return 15 * time.Second +} + +func (txp *dnsOverGetaddrinfoTransport) lookupfn() func(ctx context.Context, domain string) ([]string, error) { + if txp.testableLookupHost != nil { + return txp.testableLookupHost + } + return TProxy.DefaultResolver().LookupHost +} + +func (txp *dnsOverGetaddrinfoTransport) RequiresPadding() bool { + return false +} + +func (txp *dnsOverGetaddrinfoTransport) Network() string { + return TProxy.DefaultResolver().Network() +} + +func (txp *dnsOverGetaddrinfoTransport) Address() string { + return "" +} + +func (txp *dnsOverGetaddrinfoTransport) CloseIdleConnections() { + // nothing +} + +func (r *dnsOverGetaddrinfoResponse) Query() model.DNSQuery { + return r.query +} + +func (r *dnsOverGetaddrinfoResponse) Bytes() []byte { + return nil +} + +func (r *dnsOverGetaddrinfoResponse) Rcode() int { + return 0 +} + +func (r *dnsOverGetaddrinfoResponse) DecodeHTTPS() (*model.HTTPSSvc, error) { + return nil, ErrNoDNSTransport +} + +func (r *dnsOverGetaddrinfoResponse) DecodeLookupHost() ([]string, error) { + return r.addrs, nil +} + +func (r *dnsOverGetaddrinfoResponse) DecodeNS() ([]*net.NS, error) { + return nil, ErrNoDNSTransport +} diff --git a/internal/netxlite/dnsovergetaddrinfo_test.go b/internal/netxlite/dnsovergetaddrinfo_test.go new file mode 100644 index 0000000..78ca5dc --- /dev/null +++ b/internal/netxlite/dnsovergetaddrinfo_test.go @@ -0,0 +1,189 @@ +package netxlite + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestDNSOverGetaddrinfo(t *testing.T) { + t.Run("RequiresPadding", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{} + if txp.RequiresPadding() { + t.Fatal("expected false") + } + }) + + t.Run("Network", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{} + if txp.Network() != TProxy.DefaultResolver().Network() { + t.Fatal("unexpected Network") + } + }) + + t.Run("Address", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{} + if txp.Address() != "" { + t.Fatal("unexpected Address") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{} + txp.CloseIdleConnections() // does not crash + }) + + t.Run("check default timeout", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{} + if txp.timeout() != 15*time.Second { + t.Fatal("unexpected default timeout") + } + }) + + t.Run("check default lookup host func not nil", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{} + if txp.lookupfn() == nil { + t.Fatal("expected non-nil func here") + } + }) + + t.Run("RoundTrip", func(t *testing.T) { + t.Run("with invalid query type", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{ + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"8.8.8.8"}, nil + }, + } + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google", dns.TypeA, false) + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if !errors.Is(err, ErrNoDNSTransport) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } + }) + + t.Run("with success", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{ + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"8.8.8.8"}, nil + }, + } + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google", dns.TypeANY, false) + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("invalid addrs") + } + if resp.Query() != query { + t.Fatal("invalid query") + } + if len(resp.Bytes()) != 0 { + t.Fatal("invalid response bytes") + } + if resp.Rcode() != 0 { + t.Fatal("invalid rcode") + } + https, err := resp.DecodeHTTPS() + if !errors.Is(err, ErrNoDNSTransport) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("expected nil https") + } + ns, err := resp.DecodeNS() + if !errors.Is(err, ErrNoDNSTransport) { + t.Fatal("unexpected err", err) + } + if len(ns) != 0 { + t.Fatal("expected zero-length ns") + } + }) + + t.Run("with timeout and success", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + done := make(chan interface{}) + txp := &dnsOverGetaddrinfoTransport{ + testableTimeout: 1 * time.Microsecond, + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + defer wg.Done() + <-done + return []string{"8.8.8.8"}, nil + }, + } + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google", dns.TypeANY, false) + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("invalid resp") + } + close(done) + wg.Wait() + }) + + t.Run("with timeout and failure", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + done := make(chan interface{}) + txp := &dnsOverGetaddrinfoTransport{ + testableTimeout: 1 * time.Microsecond, + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + defer wg.Done() + <-done + return nil, errors.New("no such host") + }, + } + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google", dns.TypeANY, false) + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("invalid resp") + } + close(done) + wg.Wait() + }) + + t.Run("with NXDOMAIN", func(t *testing.T) { + txp := &dnsOverGetaddrinfoTransport{ + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, ErrOODNSNoSuchHost + }, + } + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google", dns.TypeANY, false) + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if err == nil || !strings.HasSuffix(err.Error(), "no such host") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("invalid resp") + } + }) + }) +} diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 0e72d09..b30ad57 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -10,32 +10,33 @@ package netxlite var ( DefaultDialer = &DialerSystem{} DefaultTLSHandshaker = defaultTLSHandshaker + NewResolverSystem = newResolverSystem NewConnUTLS = newConnUTLS - DefaultResolver = &resolverSystem{} + DefaultResolver = newResolverSystem() ) // These types export internal names to legacy ooni/probe-cli code. // // Deprecated: do not use these names in new code. type ( - DialerResolver = dialerResolver - DialerLogger = dialerLogger - HTTPTransportWrapper = httpTransportConnectionsCloser - HTTPTransportLogger = httpTransportLogger - ErrorWrapperDialer = dialerErrWrapper - ErrorWrapperQUICListener = quicListenerErrWrapper - ErrorWrapperQUICDialer = quicDialerErrWrapper - ErrorWrapperResolver = resolverErrWrapper - ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper - QUICListenerStdlib = quicListenerStdlib - QUICDialerQUICGo = quicDialerQUICGo - QUICDialerResolver = quicDialerResolver - QUICDialerLogger = quicDialerLogger - ResolverSystem = resolverSystem - ResolverLogger = resolverLogger - ResolverIDNA = resolverIDNA - TLSHandshakerConfigurable = tlsHandshakerConfigurable - TLSHandshakerLogger = tlsHandshakerLogger - TLSDialerLegacy = tlsDialer - AddressResolver = resolverShortCircuitIPAddr + DialerResolver = dialerResolver + DialerLogger = dialerLogger + HTTPTransportWrapper = httpTransportConnectionsCloser + HTTPTransportLogger = httpTransportLogger + ErrorWrapperDialer = dialerErrWrapper + ErrorWrapperQUICListener = quicListenerErrWrapper + ErrorWrapperQUICDialer = quicDialerErrWrapper + ErrorWrapperResolver = resolverErrWrapper + ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper + QUICListenerStdlib = quicListenerStdlib + QUICDialerQUICGo = quicDialerQUICGo + QUICDialerResolver = quicDialerResolver + QUICDialerLogger = quicDialerLogger + ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport + ResolverLogger = resolverLogger + ResolverIDNA = resolverIDNA + TLSHandshakerConfigurable = tlsHandshakerConfigurable + TLSHandshakerLogger = tlsHandshakerLogger + TLSDialerLegacy = tlsDialer + AddressResolver = resolverShortCircuitIPAddr ) diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolvercore.go similarity index 86% rename from internal/netxlite/resolver.go rename to internal/netxlite/resolvercore.go index ac97c22..c4d0009 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolvercore.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "golang.org/x/net/idna" ) @@ -24,7 +25,13 @@ var ErrNoDNSTransport = errors.New("operation requires a DNS transport") // NewResolverStdlib creates a new Resolver by combining WrapResolver // with an internal "system" resolver type. func NewResolverStdlib(logger model.DebugLogger) model.Resolver { - return WrapResolver(logger, &resolverSystem{}) + return WrapResolver(logger, newResolverSystem()) +} + +func newResolverSystem() *resolverSystem { + return &resolverSystem{ + t: &dnsOverGetaddrinfoTransport{}, + } } // NewResolverUDP creates a new Resolver using DNS-over-UDP. @@ -73,62 +80,31 @@ func WrapResolver(logger model.DebugLogger, resolver model.Resolver) model.Resol // resolverSystem is the system resolver. type resolverSystem struct { - testableTimeout time.Duration - testableLookupHost func(ctx context.Context, domain string) ([]string, error) + t model.DNSTransport } var _ model.Resolver = &resolverSystem{} func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { - // This code forces adding a shorter timeout to the domain name - // resolutions when using the system resolver. We have seen cases - // in which such a timeout becomes too large. One such case is - // described in https://github.com/ooni/probe/issues/1726. - addrsch, errch := make(chan []string, 1), make(chan error, 1) - ctx, cancel := context.WithTimeout(ctx, r.timeout()) - defer cancel() - go func() { - addrs, err := r.lookupHost()(ctx, hostname) - if err != nil { - errch <- err - return - } - addrsch <- addrs - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case addrs := <-addrsch: - return addrs, nil - case err := <-errch: + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, dns.TypeANY, false) + resp, err := r.t.RoundTrip(ctx, query) + if err != nil { return nil, err } -} - -func (r *resolverSystem) timeout() time.Duration { - if r.testableTimeout > 0 { - return r.testableTimeout - } - return 15 * time.Second -} - -func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) ([]string, error) { - if r.testableLookupHost != nil { - return r.testableLookupHost - } - return TProxy.DefaultResolver().LookupHost + return resp.DecodeLookupHost() } func (r *resolverSystem) Network() string { - return TProxy.DefaultResolver().Network() + return r.t.Network() } func (r *resolverSystem) Address() string { - return "" + return r.t.Address() } func (r *resolverSystem) CloseIdleConnections() { - // nothing to do + r.t.CloseIdleConnections() } func (r *resolverSystem) LookupHTTPS( @@ -138,11 +114,6 @@ func (r *resolverSystem) LookupHTTPS( func (r *resolverSystem) LookupNS( ctx context.Context, domain string) ([]*net.NS, error) { - // TODO(bassosimone): figure out in which context it makes sense - // to issue this query. How is this implemented under the hood by - // the stdlib? Is it using /etc/resolve.conf on Unix? Until we - // known all these details, let's pretend this functionality does - // not exist in the stdlib and focus on custom resolvers. return nil, ErrNoDNSTransport } diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolvercore_test.go similarity index 89% rename from internal/netxlite/resolver_test.go rename to internal/netxlite/resolvercore_test.go index b0f265d..a4ee4db 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolvercore_test.go @@ -6,12 +6,11 @@ import ( "io" "net" "strings" - "sync" "testing" - "time" "github.com/apex/log" "github.com/google/go-cmp/cmp" + "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" ) @@ -25,7 +24,8 @@ func TestNewResolverSystem(t *testing.T) { } shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) - _ = errWrapper.Resolver.(*resolverSystem) + reso := errWrapper.Resolver.(*resolverSystem) + _ = reso.t.(*dnsOverGetaddrinfoTransport) } func TestNewResolverUDP(t *testing.T) { @@ -46,112 +46,95 @@ func TestNewResolverUDP(t *testing.T) { } func TestResolverSystem(t *testing.T) { - t.Run("Network and Address", func(t *testing.T) { - r := &resolverSystem{} - if r.Network() != getaddrinfoResolverNetwork() { + t.Run("Network", func(t *testing.T) { + expected := "antani" + r := &resolverSystem{ + t: &mocks.DNSTransport{ + MockNetwork: func() string { + return expected + }, + }, + } + if r.Network() != expected { t.Fatal("invalid Network") } - if r.Address() != "" { + }) + + t.Run("Address", func(t *testing.T) { + expected := "address" + r := &resolverSystem{ + t: &mocks.DNSTransport{ + MockAddress: func() string { + return expected + }, + }, + } + if r.Address() != expected { t.Fatal("invalid Address") } }) t.Run("CloseIdleConnections", func(t *testing.T) { - r := &resolverSystem{} - r.CloseIdleConnections() // to cover it - }) - - t.Run("check default timeout", func(t *testing.T) { - r := &resolverSystem{} - if r.timeout() != 15*time.Second { - t.Fatal("unexpected default timeout") + var called bool + r := &resolverSystem{ + t: &mocks.DNSTransport{ + MockCloseIdleConnections: func() { + called = true + }, + }, } - }) - - t.Run("check default lookup host func not nil", func(t *testing.T) { - r := &resolverSystem{} - if r.lookupHost() == nil { - t.Fatal("expected non-nil func here") + r.CloseIdleConnections() + if !called { + t.Fatal("not called") } }) t.Run("LookupHost", func(t *testing.T) { t.Run("with success", func(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} r := &resolverSystem{ - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"8.8.8.8"}, nil + t: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() != dns.TypeANY { + return nil, errors.New("unexpected lookup type") + } + resp := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return expected, nil + }, + } + return resp, nil + }, }, } ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") + addrs, err := r.LookupHost(ctx, "dns.google") if err != nil { t.Fatal(err) } - if len(addrs) != 1 || addrs[0] != "8.8.8.8" { - t.Fatal("invalid addrs") + if diff := cmp.Diff(expected, addrs); diff != "" { + t.Fatal(diff) } }) - t.Run("with timeout and success", func(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(1) - done := make(chan interface{}) + t.Run("with failure", func(t *testing.T) { + expected := errors.New("mocked") r := &resolverSystem{ - testableTimeout: 1 * time.Microsecond, - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - defer wg.Done() - <-done - return []string{"8.8.8.8"}, nil + t: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() != dns.TypeANY { + return nil, errors.New("unexpected lookup type") + } + return nil, expected + }, }, } ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("not the error we expected", err) + addrs, err := r.LookupHost(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) } - if addrs != nil { - t.Fatal("invalid addrs") - } - close(done) - wg.Wait() - }) - - t.Run("with timeout and failure", func(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(1) - done := make(chan interface{}) - r := &resolverSystem{ - testableTimeout: 1 * time.Microsecond, - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - defer wg.Done() - <-done - return nil, errors.New("no such host") - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("invalid addrs") - } - close(done) - wg.Wait() - }) - - t.Run("with NXDOMAIN", func(t *testing.T) { - r := &resolverSystem{ - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, errors.New("no such host") - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") - if err == nil || !strings.HasSuffix(err.Error(), "no such host") { - t.Fatal("not the error we expected", err) - } - if addrs != nil { + if len(addrs) != 0 { t.Fatal("invalid addrs") } }) @@ -174,8 +157,8 @@ func TestResolverSystem(t *testing.T) { if !errors.Is(err, ErrNoDNSTransport) { t.Fatal("not the error we expected") } - if ns != nil { - t.Fatal("expected nil result") + if len(ns) != 0 { + t.Fatal("expected no results") } }) } diff --git a/internal/netxlite/parallelresolver.go b/internal/netxlite/resolverparallel.go similarity index 100% rename from internal/netxlite/parallelresolver.go rename to internal/netxlite/resolverparallel.go diff --git a/internal/netxlite/parallelresolver_test.go b/internal/netxlite/resolverparallel_test.go similarity index 100% rename from internal/netxlite/parallelresolver_test.go rename to internal/netxlite/resolverparallel_test.go diff --git a/internal/netxlite/serialresolver.go b/internal/netxlite/resolverserial.go similarity index 100% rename from internal/netxlite/serialresolver.go rename to internal/netxlite/resolverserial.go diff --git a/internal/netxlite/serialresolver_test.go b/internal/netxlite/resolverserial_test.go similarity index 100% rename from internal/netxlite/serialresolver_test.go rename to internal/netxlite/resolverserial_test.go