refactor(netxlite): introduce the getaddrinfo transport (#775)

This diff modifies the system resolver to use a getaddrinf transport.

Obviously the transport is a fake, but its existence will allow us
to observe DNS events more naturally.

A lookup using the system resolver would be a ANY lookup that will
contain all the resolved IP addresses into the same response.

This change was also part of websteps-illustrated, albeit the way in
which I did it there was less clean than what we have here.

Ref issue: https://github.com/ooni/probe/issues/2096
This commit is contained in:
Simone Basso 2022-06-01 09:59:44 +02:00 committed by GitHub
parent 7e0b47311d
commit 923d81cdee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 432 additions and 163 deletions

View File

@ -53,7 +53,7 @@ func TestWorkingAsIntended(t *testing.T) {
Client: http.DefaultClient, Client: http.DefaultClient,
Dialer: netxlite.DefaultDialer, Dialer: netxlite.DefaultDialer,
MaxAcceptableBody: 1 << 24, MaxAcceptableBody: 1 << 24,
Resolver: &netxlite.ResolverSystem{}, Resolver: netxlite.NewResolverSystem(),
} }
srv := httptest.NewServer(handler) srv := httptest.NewServer(handler)
defer srv.Close() defer srv.Close()

View File

@ -78,7 +78,7 @@ var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool()
// NewResolver creates a new resolver from the specified config // NewResolver creates a new resolver from the specified config
func NewResolver(config Config) model.Resolver { func NewResolver(config Config) model.Resolver {
if config.BaseResolver == nil { if config.BaseResolver == nil {
config.BaseResolver = &netxlite.ResolverSystem{} config.BaseResolver = netxlite.NewResolverSystem()
} }
var r model.Resolver = config.BaseResolver var r model.Resolver = config.BaseResolver
r = &netxlite.AddressResolver{ r = &netxlite.AddressResolver{
@ -260,7 +260,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
} }
switch resolverURL.Scheme { switch resolverURL.Scheme {
case "system": case "system":
return &netxlite.ResolverSystem{}, nil return netxlite.NewResolverSystem(), nil
case "https": case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
httpClient := &http.Client{Transport: NewHTTPTransport(config)} httpClient := &http.Client{Transport: NewHTTPTransport(config)}

View File

@ -32,7 +32,7 @@ func TestNewResolverVanilla(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
_, ok = ar.Resolver.(*netxlite.ResolverSystem) _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -82,7 +82,7 @@ func TestNewResolverWithBogonFilter(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
_, ok = ar.Resolver.(*netxlite.ResolverSystem) _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -111,7 +111,7 @@ func TestNewResolverWithLogging(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
_, ok = ar.Resolver.(*netxlite.ResolverSystem) _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -141,7 +141,7 @@ func TestNewResolverWithSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
_, ok = ar.Resolver.(*netxlite.ResolverSystem) _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -170,7 +170,7 @@ func TestNewResolverWithReadWriteCache(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
_, ok = ar.Resolver.(*netxlite.ResolverSystem) _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -204,7 +204,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
_, ok = ar.Resolver.(*netxlite.ResolverSystem) _, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -556,7 +556,7 @@ func TestNewDNSClientSystemResolver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok { if _, ok := dnsclient.(*netxlite.ResolverSystemDoNotInstantiate); !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -568,7 +568,7 @@ func TestNewDNSClientEmpty(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok { if _, ok := dnsclient.(*netxlite.ResolverSystemDoNotInstantiate); !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()

View File

@ -64,7 +64,7 @@ func testresolverquickidna(t *testing.T, reso model.Resolver) {
} }
func TestNewResolverSystem(t *testing.T) { func TestNewResolverSystem(t *testing.T) {
reso := &netxlite.ResolverSystem{} reso := netxlite.NewResolverSystem()
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }

View File

@ -123,7 +123,7 @@ func TestDialerResolver(t *testing.T) {
t.Run("fails without a port", func(t *testing.T) { t.Run("fails without a port", func(t *testing.T) {
d := &dialerResolver{ d := &dialerResolver{
Dialer: &DialerSystem{}, Dialer: &DialerSystem{},
Resolver: &resolverSystem{}, Resolver: newResolverSystem(),
} }
const missingPort = "ooni.nu" const missingPort = "ooni.nu"
conn, err := d.DialContext(context.Background(), "tcp", missingPort) conn, err := d.DialContext(context.Background(), "tcp", missingPort)

View File

@ -0,0 +1,125 @@
package netxlite
//
// DNS over getaddrinfo: fake transport to allow us to observe
// lookups using getaddrinfo as a DNSTransport.
//
import (
"context"
"net"
"time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
// dnsOverGetaddrinfoTransport is a DNSTransport using getaddrinfo.
type dnsOverGetaddrinfoTransport struct {
testableTimeout time.Duration
testableLookupHost func(ctx context.Context, domain string) ([]string, error)
}
var _ model.DNSTransport = &dnsOverGetaddrinfoTransport{}
func (txp *dnsOverGetaddrinfoTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
if query.Type() != dns.TypeANY {
return nil, ErrNoDNSTransport
}
addrs, err := txp.lookup(ctx, query.Domain())
if err != nil {
return nil, err
}
resp := &dnsOverGetaddrinfoResponse{
addrs: addrs,
query: query,
}
return resp, nil
}
type dnsOverGetaddrinfoResponse struct {
addrs []string
query model.DNSQuery
}
func (txp *dnsOverGetaddrinfoTransport) lookup(
ctx context.Context, hostname string) ([]string, error) {
// This code forces adding a shorter timeout to the domain name
// resolutions when using the system resolver. We have seen cases
// in which such a timeout becomes too large. One such case is
// described in https://github.com/ooni/probe/issues/1726.
addrsch, errch := make(chan []string, 1), make(chan error, 1)
ctx, cancel := context.WithTimeout(ctx, txp.timeout())
defer cancel()
go func() {
addrs, err := txp.lookupfn()(ctx, hostname)
if err != nil {
errch <- err
return
}
addrsch <- addrs
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case addrs := <-addrsch:
return addrs, nil
case err := <-errch:
return nil, err
}
}
func (txp *dnsOverGetaddrinfoTransport) timeout() time.Duration {
if txp.testableTimeout > 0 {
return txp.testableTimeout
}
return 15 * time.Second
}
func (txp *dnsOverGetaddrinfoTransport) lookupfn() func(ctx context.Context, domain string) ([]string, error) {
if txp.testableLookupHost != nil {
return txp.testableLookupHost
}
return TProxy.DefaultResolver().LookupHost
}
func (txp *dnsOverGetaddrinfoTransport) RequiresPadding() bool {
return false
}
func (txp *dnsOverGetaddrinfoTransport) Network() string {
return TProxy.DefaultResolver().Network()
}
func (txp *dnsOverGetaddrinfoTransport) Address() string {
return ""
}
func (txp *dnsOverGetaddrinfoTransport) CloseIdleConnections() {
// nothing
}
func (r *dnsOverGetaddrinfoResponse) Query() model.DNSQuery {
return r.query
}
func (r *dnsOverGetaddrinfoResponse) Bytes() []byte {
return nil
}
func (r *dnsOverGetaddrinfoResponse) Rcode() int {
return 0
}
func (r *dnsOverGetaddrinfoResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
return nil, ErrNoDNSTransport
}
func (r *dnsOverGetaddrinfoResponse) DecodeLookupHost() ([]string, error) {
return r.addrs, nil
}
func (r *dnsOverGetaddrinfoResponse) DecodeNS() ([]*net.NS, error) {
return nil, ErrNoDNSTransport
}

View File

@ -0,0 +1,189 @@
package netxlite
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
)
func TestDNSOverGetaddrinfo(t *testing.T) {
t.Run("RequiresPadding", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.RequiresPadding() {
t.Fatal("expected false")
}
})
t.Run("Network", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.Network() != TProxy.DefaultResolver().Network() {
t.Fatal("unexpected Network")
}
})
t.Run("Address", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.Address() != "" {
t.Fatal("unexpected Address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
txp.CloseIdleConnections() // does not crash
})
t.Run("check default timeout", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.timeout() != 15*time.Second {
t.Fatal("unexpected default timeout")
}
})
t.Run("check default lookup host func not nil", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.lookupfn() == nil {
t.Fatal("expected non-nil func here")
}
})
t.Run("RoundTrip", func(t *testing.T) {
t.Run("with invalid query type", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeA, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp")
}
})
t.Run("with success", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("invalid addrs")
}
if resp.Query() != query {
t.Fatal("invalid query")
}
if len(resp.Bytes()) != 0 {
t.Fatal("invalid response bytes")
}
if resp.Rcode() != 0 {
t.Fatal("invalid rcode")
}
https, err := resp.DecodeHTTPS()
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("expected nil https")
}
ns, err := resp.DecodeNS()
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("unexpected err", err)
}
if len(ns) != 0 {
t.Fatal("expected zero-length ns")
}
})
t.Run("with timeout and success", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
txp := &dnsOverGetaddrinfoTransport{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return []string{"8.8.8.8"}, nil
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("invalid resp")
}
close(done)
wg.Wait()
})
t.Run("with timeout and failure", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
txp := &dnsOverGetaddrinfoTransport{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return nil, errors.New("no such host")
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("invalid resp")
}
close(done)
wg.Wait()
})
t.Run("with NXDOMAIN", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, ErrOODNSNoSuchHost
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("invalid resp")
}
})
})
}

View File

@ -10,8 +10,9 @@ package netxlite
var ( var (
DefaultDialer = &DialerSystem{} DefaultDialer = &DialerSystem{}
DefaultTLSHandshaker = defaultTLSHandshaker DefaultTLSHandshaker = defaultTLSHandshaker
NewResolverSystem = newResolverSystem
NewConnUTLS = newConnUTLS NewConnUTLS = newConnUTLS
DefaultResolver = &resolverSystem{} DefaultResolver = newResolverSystem()
) )
// These types export internal names to legacy ooni/probe-cli code. // These types export internal names to legacy ooni/probe-cli code.
@ -31,7 +32,7 @@ type (
QUICDialerQUICGo = quicDialerQUICGo QUICDialerQUICGo = quicDialerQUICGo
QUICDialerResolver = quicDialerResolver QUICDialerResolver = quicDialerResolver
QUICDialerLogger = quicDialerLogger QUICDialerLogger = quicDialerLogger
ResolverSystem = resolverSystem ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
ResolverLogger = resolverLogger ResolverLogger = resolverLogger
ResolverIDNA = resolverIDNA ResolverIDNA = resolverIDNA
TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerConfigurable = tlsHandshakerConfigurable

View File

@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -24,7 +25,13 @@ var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
// NewResolverStdlib creates a new Resolver by combining WrapResolver // NewResolverStdlib creates a new Resolver by combining WrapResolver
// with an internal "system" resolver type. // with an internal "system" resolver type.
func NewResolverStdlib(logger model.DebugLogger) model.Resolver { func NewResolverStdlib(logger model.DebugLogger) model.Resolver {
return WrapResolver(logger, &resolverSystem{}) return WrapResolver(logger, newResolverSystem())
}
func newResolverSystem() *resolverSystem {
return &resolverSystem{
t: &dnsOverGetaddrinfoTransport{},
}
} }
// NewResolverUDP creates a new Resolver using DNS-over-UDP. // NewResolverUDP creates a new Resolver using DNS-over-UDP.
@ -73,62 +80,31 @@ func WrapResolver(logger model.DebugLogger, resolver model.Resolver) model.Resol
// resolverSystem is the system resolver. // resolverSystem is the system resolver.
type resolverSystem struct { type resolverSystem struct {
testableTimeout time.Duration t model.DNSTransport
testableLookupHost func(ctx context.Context, domain string) ([]string, error)
} }
var _ model.Resolver = &resolverSystem{} var _ model.Resolver = &resolverSystem{}
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
// This code forces adding a shorter timeout to the domain name encoder := &DNSEncoderMiekg{}
// resolutions when using the system resolver. We have seen cases query := encoder.Encode(hostname, dns.TypeANY, false)
// in which such a timeout becomes too large. One such case is resp, err := r.t.RoundTrip(ctx, query)
// described in https://github.com/ooni/probe/issues/1726.
addrsch, errch := make(chan []string, 1), make(chan error, 1)
ctx, cancel := context.WithTimeout(ctx, r.timeout())
defer cancel()
go func() {
addrs, err := r.lookupHost()(ctx, hostname)
if err != nil { if err != nil {
errch <- err
return
}
addrsch <- addrs
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case addrs := <-addrsch:
return addrs, nil
case err := <-errch:
return nil, err return nil, err
} }
} return resp.DecodeLookupHost()
func (r *resolverSystem) timeout() time.Duration {
if r.testableTimeout > 0 {
return r.testableTimeout
}
return 15 * time.Second
}
func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) ([]string, error) {
if r.testableLookupHost != nil {
return r.testableLookupHost
}
return TProxy.DefaultResolver().LookupHost
} }
func (r *resolverSystem) Network() string { func (r *resolverSystem) Network() string {
return TProxy.DefaultResolver().Network() return r.t.Network()
} }
func (r *resolverSystem) Address() string { func (r *resolverSystem) Address() string {
return "" return r.t.Address()
} }
func (r *resolverSystem) CloseIdleConnections() { func (r *resolverSystem) CloseIdleConnections() {
// nothing to do r.t.CloseIdleConnections()
} }
func (r *resolverSystem) LookupHTTPS( func (r *resolverSystem) LookupHTTPS(
@ -138,11 +114,6 @@ func (r *resolverSystem) LookupHTTPS(
func (r *resolverSystem) LookupNS( func (r *resolverSystem) LookupNS(
ctx context.Context, domain string) ([]*net.NS, error) { ctx context.Context, domain string) ([]*net.NS, error) {
// TODO(bassosimone): figure out in which context it makes sense
// to issue this query. How is this implemented under the hood by
// the stdlib? Is it using /etc/resolve.conf on Unix? Until we
// known all these details, let's pretend this functionality does
// not exist in the stdlib and focus on custom resolvers.
return nil, ErrNoDNSTransport return nil, ErrNoDNSTransport
} }

View File

@ -6,12 +6,11 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/apex/log" "github.com/apex/log"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
@ -25,7 +24,8 @@ func TestNewResolverSystem(t *testing.T) {
} }
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
_ = errWrapper.Resolver.(*resolverSystem) reso := errWrapper.Resolver.(*resolverSystem)
_ = reso.t.(*dnsOverGetaddrinfoTransport)
} }
func TestNewResolverUDP(t *testing.T) { func TestNewResolverUDP(t *testing.T) {
@ -46,112 +46,95 @@ func TestNewResolverUDP(t *testing.T) {
} }
func TestResolverSystem(t *testing.T) { func TestResolverSystem(t *testing.T) {
t.Run("Network and Address", func(t *testing.T) { t.Run("Network", func(t *testing.T) {
r := &resolverSystem{} expected := "antani"
if r.Network() != getaddrinfoResolverNetwork() { r := &resolverSystem{
t: &mocks.DNSTransport{
MockNetwork: func() string {
return expected
},
},
}
if r.Network() != expected {
t.Fatal("invalid Network") t.Fatal("invalid Network")
} }
if r.Address() != "" { })
t.Run("Address", func(t *testing.T) {
expected := "address"
r := &resolverSystem{
t: &mocks.DNSTransport{
MockAddress: func() string {
return expected
},
},
}
if r.Address() != expected {
t.Fatal("invalid Address") t.Fatal("invalid Address")
} }
}) })
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
r := &resolverSystem{} var called bool
r.CloseIdleConnections() // to cover it r := &resolverSystem{
}) t: &mocks.DNSTransport{
MockCloseIdleConnections: func() {
t.Run("check default timeout", func(t *testing.T) { called = true
r := &resolverSystem{} },
if r.timeout() != 15*time.Second { },
t.Fatal("unexpected default timeout")
} }
}) r.CloseIdleConnections()
if !called {
t.Run("check default lookup host func not nil", func(t *testing.T) { t.Fatal("not called")
r := &resolverSystem{}
if r.lookupHost() == nil {
t.Fatal("expected non-nil func here")
} }
}) })
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) { t.Run("with success", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
r := &resolverSystem{ r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { t: &mocks.DNSTransport{
return []string{"8.8.8.8"}, nil MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
if query.Type() != dns.TypeANY {
return nil, errors.New("unexpected lookup type")
}
resp := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return expected, nil
},
}
return resp, nil
},
}, },
} }
ctx := context.Background() ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani") addrs, err := r.LookupHost(ctx, "dns.google")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(addrs) != 1 || addrs[0] != "8.8.8.8" { if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal("invalid addrs") t.Fatal(diff)
} }
}) })
t.Run("with timeout and success", func(t *testing.T) { t.Run("with failure", func(t *testing.T) {
wg := &sync.WaitGroup{} expected := errors.New("mocked")
wg.Add(1)
done := make(chan interface{})
r := &resolverSystem{ r := &resolverSystem{
testableTimeout: 1 * time.Microsecond, t: &mocks.DNSTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
defer wg.Done() if query.Type() != dns.TypeANY {
<-done return nil, errors.New("unexpected lookup type")
return []string{"8.8.8.8"}, nil }
return nil, expected
},
}, },
} }
ctx := context.Background() ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani") addrs, err := r.LookupHost(ctx, "dns.google")
if !errors.Is(err, context.DeadlineExceeded) { if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err) t.Fatal("unexpected err", err)
} }
if addrs != nil { if len(addrs) != 0 {
t.Fatal("invalid addrs")
}
close(done)
wg.Wait()
})
t.Run("with timeout and failure", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
r := &resolverSystem{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return nil, errors.New("no such host")
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
t.Fatal("invalid addrs")
}
close(done)
wg.Wait()
})
t.Run("with NXDOMAIN", func(t *testing.T) {
r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, errors.New("no such host")
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
t.Fatal("invalid addrs") t.Fatal("invalid addrs")
} }
}) })
@ -174,8 +157,8 @@ func TestResolverSystem(t *testing.T) {
if !errors.Is(err, ErrNoDNSTransport) { if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if ns != nil { if len(ns) != 0 {
t.Fatal("expected nil result") t.Fatal("expected no results")
} }
}) })
} }