refactor(netxlite): introduce the getaddrinfo transport (#775)

This diff modifies the system resolver to use a getaddrinf transport.

Obviously the transport is a fake, but its existence will allow us
to observe DNS events more naturally.

A lookup using the system resolver would be a ANY lookup that will
contain all the resolved IP addresses into the same response.

This change was also part of websteps-illustrated, albeit the way in
which I did it there was less clean than what we have here.

Ref issue: https://github.com/ooni/probe/issues/2096
This commit is contained in:
Simone Basso 2022-06-01 09:59:44 +02:00 committed by GitHub
parent 7e0b47311d
commit 923d81cdee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 432 additions and 163 deletions

View File

@ -53,7 +53,7 @@ func TestWorkingAsIntended(t *testing.T) {
Client: http.DefaultClient,
Dialer: netxlite.DefaultDialer,
MaxAcceptableBody: 1 << 24,
Resolver: &netxlite.ResolverSystem{},
Resolver: netxlite.NewResolverSystem(),
}
srv := httptest.NewServer(handler)
defer srv.Close()

View File

@ -78,7 +78,7 @@ var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool()
// NewResolver creates a new resolver from the specified config
func NewResolver(config Config) model.Resolver {
if config.BaseResolver == nil {
config.BaseResolver = &netxlite.ResolverSystem{}
config.BaseResolver = netxlite.NewResolverSystem()
}
var r model.Resolver = config.BaseResolver
r = &netxlite.AddressResolver{
@ -260,7 +260,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
}
switch resolverURL.Scheme {
case "system":
return &netxlite.ResolverSystem{}, nil
return netxlite.NewResolverSystem(), nil
case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
httpClient := &http.Client{Transport: NewHTTPTransport(config)}

View File

@ -32,7 +32,7 @@ func TestNewResolverVanilla(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -82,7 +82,7 @@ func TestNewResolverWithBogonFilter(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -111,7 +111,7 @@ func TestNewResolverWithLogging(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -141,7 +141,7 @@ func TestNewResolverWithSaver(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -170,7 +170,7 @@ func TestNewResolverWithReadWriteCache(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -204,7 +204,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystemDoNotInstantiate)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -556,7 +556,7 @@ func TestNewDNSClientSystemResolver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok {
if _, ok := dnsclient.(*netxlite.ResolverSystemDoNotInstantiate); !ok {
t.Fatal("not the resolver we expected")
}
dnsclient.CloseIdleConnections()
@ -568,7 +568,7 @@ func TestNewDNSClientEmpty(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok {
if _, ok := dnsclient.(*netxlite.ResolverSystemDoNotInstantiate); !ok {
t.Fatal("not the resolver we expected")
}
dnsclient.CloseIdleConnections()

View File

@ -64,7 +64,7 @@ func testresolverquickidna(t *testing.T, reso model.Resolver) {
}
func TestNewResolverSystem(t *testing.T) {
reso := &netxlite.ResolverSystem{}
reso := netxlite.NewResolverSystem()
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}

View File

@ -123,7 +123,7 @@ func TestDialerResolver(t *testing.T) {
t.Run("fails without a port", func(t *testing.T) {
d := &dialerResolver{
Dialer: &DialerSystem{},
Resolver: &resolverSystem{},
Resolver: newResolverSystem(),
}
const missingPort = "ooni.nu"
conn, err := d.DialContext(context.Background(), "tcp", missingPort)

View File

@ -0,0 +1,125 @@
package netxlite
//
// DNS over getaddrinfo: fake transport to allow us to observe
// lookups using getaddrinfo as a DNSTransport.
//
import (
"context"
"net"
"time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
// dnsOverGetaddrinfoTransport is a DNSTransport using getaddrinfo.
type dnsOverGetaddrinfoTransport struct {
testableTimeout time.Duration
testableLookupHost func(ctx context.Context, domain string) ([]string, error)
}
var _ model.DNSTransport = &dnsOverGetaddrinfoTransport{}
func (txp *dnsOverGetaddrinfoTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
if query.Type() != dns.TypeANY {
return nil, ErrNoDNSTransport
}
addrs, err := txp.lookup(ctx, query.Domain())
if err != nil {
return nil, err
}
resp := &dnsOverGetaddrinfoResponse{
addrs: addrs,
query: query,
}
return resp, nil
}
type dnsOverGetaddrinfoResponse struct {
addrs []string
query model.DNSQuery
}
func (txp *dnsOverGetaddrinfoTransport) lookup(
ctx context.Context, hostname string) ([]string, error) {
// This code forces adding a shorter timeout to the domain name
// resolutions when using the system resolver. We have seen cases
// in which such a timeout becomes too large. One such case is
// described in https://github.com/ooni/probe/issues/1726.
addrsch, errch := make(chan []string, 1), make(chan error, 1)
ctx, cancel := context.WithTimeout(ctx, txp.timeout())
defer cancel()
go func() {
addrs, err := txp.lookupfn()(ctx, hostname)
if err != nil {
errch <- err
return
}
addrsch <- addrs
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case addrs := <-addrsch:
return addrs, nil
case err := <-errch:
return nil, err
}
}
func (txp *dnsOverGetaddrinfoTransport) timeout() time.Duration {
if txp.testableTimeout > 0 {
return txp.testableTimeout
}
return 15 * time.Second
}
func (txp *dnsOverGetaddrinfoTransport) lookupfn() func(ctx context.Context, domain string) ([]string, error) {
if txp.testableLookupHost != nil {
return txp.testableLookupHost
}
return TProxy.DefaultResolver().LookupHost
}
func (txp *dnsOverGetaddrinfoTransport) RequiresPadding() bool {
return false
}
func (txp *dnsOverGetaddrinfoTransport) Network() string {
return TProxy.DefaultResolver().Network()
}
func (txp *dnsOverGetaddrinfoTransport) Address() string {
return ""
}
func (txp *dnsOverGetaddrinfoTransport) CloseIdleConnections() {
// nothing
}
func (r *dnsOverGetaddrinfoResponse) Query() model.DNSQuery {
return r.query
}
func (r *dnsOverGetaddrinfoResponse) Bytes() []byte {
return nil
}
func (r *dnsOverGetaddrinfoResponse) Rcode() int {
return 0
}
func (r *dnsOverGetaddrinfoResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
return nil, ErrNoDNSTransport
}
func (r *dnsOverGetaddrinfoResponse) DecodeLookupHost() ([]string, error) {
return r.addrs, nil
}
func (r *dnsOverGetaddrinfoResponse) DecodeNS() ([]*net.NS, error) {
return nil, ErrNoDNSTransport
}

View File

@ -0,0 +1,189 @@
package netxlite
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
"github.com/miekg/dns"
)
func TestDNSOverGetaddrinfo(t *testing.T) {
t.Run("RequiresPadding", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.RequiresPadding() {
t.Fatal("expected false")
}
})
t.Run("Network", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.Network() != TProxy.DefaultResolver().Network() {
t.Fatal("unexpected Network")
}
})
t.Run("Address", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.Address() != "" {
t.Fatal("unexpected Address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
txp.CloseIdleConnections() // does not crash
})
t.Run("check default timeout", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.timeout() != 15*time.Second {
t.Fatal("unexpected default timeout")
}
})
t.Run("check default lookup host func not nil", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{}
if txp.lookupfn() == nil {
t.Fatal("expected non-nil func here")
}
})
t.Run("RoundTrip", func(t *testing.T) {
t.Run("with invalid query type", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeA, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp")
}
})
t.Run("with success", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("invalid addrs")
}
if resp.Query() != query {
t.Fatal("invalid query")
}
if len(resp.Bytes()) != 0 {
t.Fatal("invalid response bytes")
}
if resp.Rcode() != 0 {
t.Fatal("invalid rcode")
}
https, err := resp.DecodeHTTPS()
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("expected nil https")
}
ns, err := resp.DecodeNS()
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("unexpected err", err)
}
if len(ns) != 0 {
t.Fatal("expected zero-length ns")
}
})
t.Run("with timeout and success", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
txp := &dnsOverGetaddrinfoTransport{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return []string{"8.8.8.8"}, nil
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("invalid resp")
}
close(done)
wg.Wait()
})
t.Run("with timeout and failure", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
txp := &dnsOverGetaddrinfoTransport{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return nil, errors.New("no such host")
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("invalid resp")
}
close(done)
wg.Wait()
})
t.Run("with NXDOMAIN", func(t *testing.T) {
txp := &dnsOverGetaddrinfoTransport{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, ErrOODNSNoSuchHost
},
}
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google", dns.TypeANY, false)
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("invalid resp")
}
})
})
}

View File

@ -10,32 +10,33 @@ package netxlite
var (
DefaultDialer = &DialerSystem{}
DefaultTLSHandshaker = defaultTLSHandshaker
NewResolverSystem = newResolverSystem
NewConnUTLS = newConnUTLS
DefaultResolver = &resolverSystem{}
DefaultResolver = newResolverSystem()
)
// These types export internal names to legacy ooni/probe-cli code.
//
// Deprecated: do not use these names in new code.
type (
DialerResolver = dialerResolver
DialerLogger = dialerLogger
HTTPTransportWrapper = httpTransportConnectionsCloser
HTTPTransportLogger = httpTransportLogger
ErrorWrapperDialer = dialerErrWrapper
ErrorWrapperQUICListener = quicListenerErrWrapper
ErrorWrapperQUICDialer = quicDialerErrWrapper
ErrorWrapperResolver = resolverErrWrapper
ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper
QUICListenerStdlib = quicListenerStdlib
QUICDialerQUICGo = quicDialerQUICGo
QUICDialerResolver = quicDialerResolver
QUICDialerLogger = quicDialerLogger
ResolverSystem = resolverSystem
ResolverLogger = resolverLogger
ResolverIDNA = resolverIDNA
TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger
TLSDialerLegacy = tlsDialer
AddressResolver = resolverShortCircuitIPAddr
DialerResolver = dialerResolver
DialerLogger = dialerLogger
HTTPTransportWrapper = httpTransportConnectionsCloser
HTTPTransportLogger = httpTransportLogger
ErrorWrapperDialer = dialerErrWrapper
ErrorWrapperQUICListener = quicListenerErrWrapper
ErrorWrapperQUICDialer = quicDialerErrWrapper
ErrorWrapperResolver = resolverErrWrapper
ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper
QUICListenerStdlib = quicListenerStdlib
QUICDialerQUICGo = quicDialerQUICGo
QUICDialerResolver = quicDialerResolver
QUICDialerLogger = quicDialerLogger
ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
ResolverLogger = resolverLogger
ResolverIDNA = resolverIDNA
TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger
TLSDialerLegacy = tlsDialer
AddressResolver = resolverShortCircuitIPAddr
)

View File

@ -12,6 +12,7 @@ import (
"strings"
"time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
"golang.org/x/net/idna"
)
@ -24,7 +25,13 @@ var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
// NewResolverStdlib creates a new Resolver by combining WrapResolver
// with an internal "system" resolver type.
func NewResolverStdlib(logger model.DebugLogger) model.Resolver {
return WrapResolver(logger, &resolverSystem{})
return WrapResolver(logger, newResolverSystem())
}
func newResolverSystem() *resolverSystem {
return &resolverSystem{
t: &dnsOverGetaddrinfoTransport{},
}
}
// NewResolverUDP creates a new Resolver using DNS-over-UDP.
@ -73,62 +80,31 @@ func WrapResolver(logger model.DebugLogger, resolver model.Resolver) model.Resol
// resolverSystem is the system resolver.
type resolverSystem struct {
testableTimeout time.Duration
testableLookupHost func(ctx context.Context, domain string) ([]string, error)
t model.DNSTransport
}
var _ model.Resolver = &resolverSystem{}
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
// This code forces adding a shorter timeout to the domain name
// resolutions when using the system resolver. We have seen cases
// in which such a timeout becomes too large. One such case is
// described in https://github.com/ooni/probe/issues/1726.
addrsch, errch := make(chan []string, 1), make(chan error, 1)
ctx, cancel := context.WithTimeout(ctx, r.timeout())
defer cancel()
go func() {
addrs, err := r.lookupHost()(ctx, hostname)
if err != nil {
errch <- err
return
}
addrsch <- addrs
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case addrs := <-addrsch:
return addrs, nil
case err := <-errch:
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, dns.TypeANY, false)
resp, err := r.t.RoundTrip(ctx, query)
if err != nil {
return nil, err
}
}
func (r *resolverSystem) timeout() time.Duration {
if r.testableTimeout > 0 {
return r.testableTimeout
}
return 15 * time.Second
}
func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) ([]string, error) {
if r.testableLookupHost != nil {
return r.testableLookupHost
}
return TProxy.DefaultResolver().LookupHost
return resp.DecodeLookupHost()
}
func (r *resolverSystem) Network() string {
return TProxy.DefaultResolver().Network()
return r.t.Network()
}
func (r *resolverSystem) Address() string {
return ""
return r.t.Address()
}
func (r *resolverSystem) CloseIdleConnections() {
// nothing to do
r.t.CloseIdleConnections()
}
func (r *resolverSystem) LookupHTTPS(
@ -138,11 +114,6 @@ func (r *resolverSystem) LookupHTTPS(
func (r *resolverSystem) LookupNS(
ctx context.Context, domain string) ([]*net.NS, error) {
// TODO(bassosimone): figure out in which context it makes sense
// to issue this query. How is this implemented under the hood by
// the stdlib? Is it using /etc/resolve.conf on Unix? Until we
// known all these details, let's pretend this functionality does
// not exist in the stdlib and focus on custom resolvers.
return nil, ErrNoDNSTransport
}

View File

@ -6,12 +6,11 @@ import (
"io"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
@ -25,7 +24,8 @@ func TestNewResolverSystem(t *testing.T) {
}
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
_ = errWrapper.Resolver.(*resolverSystem)
reso := errWrapper.Resolver.(*resolverSystem)
_ = reso.t.(*dnsOverGetaddrinfoTransport)
}
func TestNewResolverUDP(t *testing.T) {
@ -46,112 +46,95 @@ func TestNewResolverUDP(t *testing.T) {
}
func TestResolverSystem(t *testing.T) {
t.Run("Network and Address", func(t *testing.T) {
r := &resolverSystem{}
if r.Network() != getaddrinfoResolverNetwork() {
t.Run("Network", func(t *testing.T) {
expected := "antani"
r := &resolverSystem{
t: &mocks.DNSTransport{
MockNetwork: func() string {
return expected
},
},
}
if r.Network() != expected {
t.Fatal("invalid Network")
}
if r.Address() != "" {
})
t.Run("Address", func(t *testing.T) {
expected := "address"
r := &resolverSystem{
t: &mocks.DNSTransport{
MockAddress: func() string {
return expected
},
},
}
if r.Address() != expected {
t.Fatal("invalid Address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
r := &resolverSystem{}
r.CloseIdleConnections() // to cover it
})
t.Run("check default timeout", func(t *testing.T) {
r := &resolverSystem{}
if r.timeout() != 15*time.Second {
t.Fatal("unexpected default timeout")
var called bool
r := &resolverSystem{
t: &mocks.DNSTransport{
MockCloseIdleConnections: func() {
called = true
},
},
}
})
t.Run("check default lookup host func not nil", func(t *testing.T) {
r := &resolverSystem{}
if r.lookupHost() == nil {
t.Fatal("expected non-nil func here")
r.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
t: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
if query.Type() != dns.TypeANY {
return nil, errors.New("unexpected lookup type")
}
resp := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return expected, nil
},
}
return resp, nil
},
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
addrs, err := r.LookupHost(ctx, "dns.google")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("invalid addrs")
if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff)
}
})
t.Run("with timeout and success", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
t.Run("with failure", func(t *testing.T) {
expected := errors.New("mocked")
r := &resolverSystem{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return []string{"8.8.8.8"}, nil
t: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
if query.Type() != dns.TypeANY {
return nil, errors.New("unexpected lookup type")
}
return nil, expected
},
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
addrs, err := r.LookupHost(ctx, "dns.google")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if addrs != nil {
t.Fatal("invalid addrs")
}
close(done)
wg.Wait()
})
t.Run("with timeout and failure", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan interface{})
r := &resolverSystem{
testableTimeout: 1 * time.Microsecond,
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer wg.Done()
<-done
return nil, errors.New("no such host")
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
t.Fatal("invalid addrs")
}
close(done)
wg.Wait()
})
t.Run("with NXDOMAIN", func(t *testing.T) {
r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, errors.New("no such host")
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
if len(addrs) != 0 {
t.Fatal("invalid addrs")
}
})
@ -174,8 +157,8 @@ func TestResolverSystem(t *testing.T) {
if !errors.Is(err, ErrNoDNSTransport) {
t.Fatal("not the error we expected")
}
if ns != nil {
t.Fatal("expected nil result")
if len(ns) != 0 {
t.Fatal("expected no results")
}
})
}