From fc51590a67c65a8e721322454e95c2290fed1b5f Mon Sep 17 00:00:00 2001 From: DecFox <33030671+DecFox@users.noreply.github.com> Date: Thu, 11 Aug 2022 19:30:37 +0530 Subject: [PATCH] feat: refactor dns implementation in measurexlite (#857) * refactor: remove query-based mapping and introducing resolver wrapper * refactor dnsping to adapt to measurexlite * dnsping: extra comments * Apply suggestions from code review * Update internal/measurexlite/dns_test.go See https://github.com/ooni/probe/issues/2208 Co-authored-by: decfox Co-authored-by: Simone Basso --- internal/engine/experiment/dnsping/dnsping.go | 21 ++- internal/measurexlite/dns.go | 38 ++--- internal/measurexlite/dns_test.go | 152 ++++++------------ internal/measurexlite/trace.go | 72 ++++----- internal/measurexlite/trace_test.go | 111 ++++++++----- 5 files changed, 173 insertions(+), 221 deletions(-) diff --git a/internal/engine/experiment/dnsping/dnsping.go b/internal/engine/experiment/dnsping/dnsping.go index b03f005..ae835c1 100644 --- a/internal/engine/experiment/dnsping/dnsping.go +++ b/internal/engine/experiment/dnsping/dnsping.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" @@ -148,18 +147,16 @@ func (m *Measurer) dnsRoundTrip(ctx context.Context, index int64, zeroTime time. resolver := trace.NewParallelUDPResolver(logger, dialer, address) _, err := resolver.LookupHost(ctx, domain) ol.Stop(err) - // Add the dns.TypeA ping - pings = append(pings, m.makePingFromLookup(<-trace.DNSLookup[dns.TypeA])) - // Add the dns.TypeAAAA ping - pings = append(pings, m.makePingFromLookup(<-trace.DNSLookup[dns.TypeAAAA])) - tk.addPings(pings) -} - -// makePingfromLookup returns a SinglePing from the result of a single query -func (m *Measurer) makePingFromLookup(lookup *model.ArchivalDNSLookupResult) (pings *SinglePing) { - return &SinglePing{ - Query: lookup, + for _, lookup := range trace.DNSLookupsFromRoundTrip() { + // make sure we only include the query types we care about (in principle, there + // should be no other query, so we're doing this just for robustness). + if lookup.QueryType == "A" || lookup.QueryType == "AAAA" { + pings = append(pings, &SinglePing{ + Query: lookup, + }) + } } + tk.addPings(pings) } // NewExperimentMeasurer creates a new ExperimentMeasurer. diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 37ab72b..14614c1 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -17,11 +17,10 @@ import ( "github.com/ooni/probe-cli/v3/internal/tracex" ) -// newParallelResolverTrace is equivalent to netxlite.NewParallelResolver -// except that it returns a model.Resolver that uses this trace. -func (tx *Trace) newParallelResolverTrace(newResolver func() model.Resolver) model.Resolver { +// wrapResolver resolver wraps the passed resolver to save data into the trace +func (tx *Trace) wrapResolver(resolver model.Resolver) model.Resolver { return &resolverTrace{ - r: tx.newParallelResolver(newResolver), + r: resolver, tx: tx, } } @@ -66,29 +65,20 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS, // NewParallelUDPResolver returns a trace-ware parallel UDP resolver func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - return tx.newParallelResolverTrace(func() model.Resolver { - return netxlite.NewParallelUDPResolver(logger, dialer, address) - }) + return tx.wrapResolver(tx.newParallelUDPResolver(logger, dialer, address)) } // NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { - return tx.newParallelResolverTrace(func() model.Resolver { - return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL) - }) + return tx.wrapResolver(tx.newParallelDNSOverHTTPSResolver(logger, URL)) } // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) { - ch := tx.DNSLookup[query.Type()] - if ch == nil { - // Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms. - log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query.Type()]) - return - } + t := finished.Sub(tx.ZeroTime) select { - case ch <- NewArchivalDNSLookupResultFromRoundTrip( + case tx.DNSLookup <- NewArchivalDNSLookupResultFromRoundTrip( tx.Index, started.Sub(tx.ZeroTime), reso, @@ -96,7 +86,7 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol response, addrs, err, - finished.Sub(tx.ZeroTime), + t, ): default: } @@ -151,17 +141,11 @@ func archivalAnswersFromAddrs(addrs []string) (out []model.ArchivalDNSAnswer) { return } -// DNSLookupsFromRoundTrip drains the network events buffered inside the corresponding query channel -func (tx *Trace) DNSLookupsFromRoundTrip(query uint16) (out []*model.ArchivalDNSLookupResult) { - ch := tx.DNSLookup[query] - if ch == nil { - // Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms. - log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query]) - return - } +// DNSLookupsFromRoundTrip drains the network events buffered inside the DNSLookup channel +func (tx *Trace) DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) { for { select { - case ev := <-ch: + case ev := <-tx.DNSLookup: out = append(out, ev) default: return diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index f3c9c6c..8b6e7ee 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -13,17 +13,11 @@ import ( ) func TestNewUnwrappedParallelResolver(t *testing.T) { - t.Run("NewUnwrappedParallelResolver creates an UnwrappedParallelResolver with Trace", func(t *testing.T) { + t.Run("WrapResolver creates a wrapped resolver with Trace", func(t *testing.T) { underlying := &mocks.Resolver{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.NewParallelResolverFn = func() model.Resolver { - return underlying - } - resolver := trace.newParallelResolverTrace(func() model.Resolver { - return nil - }) - resolvert := resolver.(*resolverTrace) + resolvert := trace.wrapResolver(underlying).(*resolverTrace) if resolvert.r != underlying { t.Fatal("invalid parallel resolver") } @@ -36,20 +30,18 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { var called bool zeroTime := time.Now() trace := NewTrace(0, zeroTime) - newMockResolver := func() model.Resolver { - return &mocks.Resolver{ - MockAddress: func() string { - return "dns.google" - }, - MockNetwork: func() string { - return "udp" - }, - MockCloseIdleConnections: func() { - called = true - }, - } + mockResolver := &mocks.Resolver{ + MockAddress: func() string { + return "dns.google" + }, + MockNetwork: func() string { + return "udp" + }, + MockCloseIdleConnections: func() { + called = true + }, } - resolver := trace.newParallelResolver(newMockResolver) + resolver := trace.wrapResolver(mockResolver) t.Run("Address is correctly forwarded", func(t *testing.T) { got := resolver.Address() @@ -94,16 +86,14 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { return true }, MockNetwork: func() string { - return "" + return "mocked" }, MockAddress: func() string { return "dns.google" }, } - newResolver := func() model.Resolver { - return netxlite.NewUnwrappedParallelResolver(txp) - } - resolver := trace.newParallelResolverTrace(newResolver) + r := netxlite.NewUnwrappedParallelResolver(txp) + resolver := trace.wrapResolver(r) ctx := context.Background() addrs, err := resolver.LookupHost(ctx, "example.com") if err != nil { @@ -119,45 +109,27 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { t.Fatal("unexpected array output", addrs) } - t.Run("DNSLookups QueryType A", func(t *testing.T) { - events := trace.DNSLookupsFromRoundTrip(dns.TypeA) - if len(events) != 1 { - t.Fatal("expected to see single DNSLookup event") + t.Run("DNSLookup events", func(t *testing.T) { + events := trace.DNSLookupsFromRoundTrip() + if len(events) != 2 { + t.Fatal("unexpected DNS events") } - lookup := events[0] - answers := lookup.Answers - if lookup.Failure != nil { - t.Fatal("unexpected err", *(lookup.Failure)) - } - if lookup.ResolverAddress != "dns.google" { - t.Fatal("unexpected address field") - } - if len(answers) != 1 { - t.Fatal("expected 1 DNS answer, got", len(answers)) - } - if answers[0].AnswerType != "A" || answers[0].IPv4 != "1.1.1.1" { - t.Fatal("unexpected DNS answer", answers) - } - }) - - t.Run("DNSLookups QueryType AAAA", func(t *testing.T) { - events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA) - if len(events) != 1 { - t.Fatal("expected to see single DNSLookup event") - } - lookup := events[0] - answers := lookup.Answers - if lookup.Failure != nil { - t.Fatal("unexpected err", *(lookup.Failure)) - } - if lookup.ResolverAddress != "dns.google" { - t.Fatal("unexpected address field") - } - if len(answers) != 1 { - t.Fatal("expected 1 DNS answer, got", len(answers)) - } - if answers[0].AnswerType != "AAAA" || answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" { - t.Fatal("unexpected DNS answer", answers) + for _, ev := range events { + if ev.ResolverAddress != "dns.google" { + t.Fatal("unexpected resolver address") + } + if ev.Engine != "mocked" { + t.Fatal("unexpected engine") + } + if len(ev.Answers) != 1 { + t.Fatal("expected single answer in DNSLookup event") + } + if ev.QueryType == "A" && ev.Answers[0].IPv4 != "1.1.1.1" { + t.Fatal("unexpected A query result") + } + if ev.QueryType == "AAAA" && ev.Answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" { + t.Fatal("unexpected AAAA query result") + } } }) }) @@ -166,10 +138,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { zeroTime := time.Now() td := testingx.NewTimeDeterministic(zeroTime) trace := NewTrace(0, zeroTime) - trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{ - dns.TypeA: make(chan *model.ArchivalDNSLookupResult), // no buffer - dns.TypeAAAA: make(chan *model.ArchivalDNSLookupResult), // no buffer - } + trace.DNSLookup = make(chan *model.ArchivalDNSLookupResult) // no buffer trace.TimeNowFn = td.Now txp := &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { @@ -193,10 +162,8 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { return "dns.google" }, } - newResolver := func() model.Resolver { - return netxlite.NewUnwrappedParallelResolver(txp) - } - resolver := trace.newParallelResolverTrace(newResolver) + r := netxlite.NewUnwrappedParallelResolver(txp) + resolver := trace.wrapResolver(r) ctx := context.Background() addrs, err := resolver.LookupHost(ctx, "example.com") if err != nil { @@ -205,17 +172,17 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { if len(addrs) != 2 { t.Fatal("unexpected array output", addrs) } + if addrs[0] != "1.1.1.1" && addrs[1] != "1.1.1.1" { + t.Fatal("unexpected array output", addrs) + } + if addrs[0] != "fe80::a00:20ff:feb9:4c54" && addrs[1] != "fe80::a00:20ff:feb9:4c54" { + t.Fatal("unexpected array output", addrs) + } - t.Run("DNSLookups QueryType A", func(t *testing.T) { - events := trace.DNSLookupsFromRoundTrip(dns.TypeA) + t.Run("DNSLookup Events", func(t *testing.T) { + events := trace.DNSLookupsFromRoundTrip() if len(events) != 0 { - t.Fatal("expected to see no DNSLookup") - } - }) - t.Run("DNSLookups QueryType AAAA", func(t *testing.T) { - events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA) - if len(events) != 0 { - t.Fatal("expected to see no DNSLookup") + t.Fatal("expected to see no DNSLookup events") } }) }) @@ -271,26 +238,3 @@ func TestAnswersFromAddrs(t *testing.T) { }) } } - -func TestDNSLookupsFromRoundTrips(t *testing.T) { - zeroTime := time.Now() - trace := NewTrace(0, zeroTime) - checkPanic := func(query uint16, f func(uint16) []*model.ArchivalDNSLookupResult) { - defer func() { - if r := recover(); r != nil { - t.Fatal("unexpected panic encoutered") - } - }() - f(query) - } - t.Run("DNSLookup is nil", func(t *testing.T) { - trace.DNSLookup = nil - checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip) - }) - t.Run("Query has nil channel", func(t *testing.T) { - trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{ - dns.TypeA: nil, - } - checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip) - }) -} diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 092b2a4..475b48d 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -7,7 +7,6 @@ package measurexlite import ( "time" - "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -39,9 +38,13 @@ type Trace struct { // this channel manually, ensure it has some buffer. NetworkEvent chan *model.ArchivalNetworkEvent - // NewParallelResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewParallelResolver factory. - NewParallelResolverFn func() model.Resolver + // NewParallelUDPResolverFn is OPTIONAL and can be used to overide + // calls to the netxlite.NewParallelUDPResolver factory. + NewParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver + + // NewParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide + // calls to the netxlite.NewParallelDNSOverHTTPSUDPResolver factory. + NewParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver // NewDialerWithoutResolverFn is OPTIONAL and can be used to override // calls to the netxlite.NewDialerWithoutResolver factory. @@ -51,13 +54,9 @@ type Trace struct { // calls to the netxlite.NewTLSHandshakerStdlib factory. NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker - // DNSLookup is MANDATORY and buffers DNSLookup results based on the - // query type. When we create this map using NewTrace, we will create - // an entry for each dns.Type in DNSQueryTypes. If you create this channel - // manually, you probably want to to the same (and most likely you also - // want to create buffered channels). Note that the code will print a - // warning and otherwise ignore all the query types not included in this map. - DNSLookup map[uint16]chan *model.ArchivalDNSLookupResult + // DNSLookup is MANDATORY and buffers DNS Lookup observations. If you create + // this channel manually, ensure it has some buffer. + DNSLookup chan *model.ArchivalDNSLookupResult // TCPConnect is MANDATORY and buffers TCP connect observations. If you create // this channel manually, ensure it has some buffer. @@ -93,25 +92,6 @@ const ( TLSHandshakeBufferSize = 8 ) -// DNSQueryTypes contains the list of DNS query types for which -// NewTrace create entries in Trace.DNSLookup. -var DNSQueryTypes = []uint16{ - dns.TypeANY, - dns.TypeA, - dns.TypeAAAA, - dns.TypeCNAME, - dns.TypeNS, -} - -// newDefaultDNSLookupMap is a convenience factory for creating Trace.DNSLookup -func newDefaultDNSLookupMap() map[uint16]chan *model.ArchivalDNSLookupResult { - out := make(map[uint16]chan *model.ArchivalDNSLookupResult) - for _, qtype := range DNSQueryTypes { - out[qtype] = make(chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize) - } - return out -} - // NewTrace creates a new instance of Trace using default settings. // // We create buffered channels using as buffer sizes the constants that @@ -132,7 +112,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace { ), NewDialerWithoutResolverFn: nil, // use default NewTLSHandshakerStdlibFn: nil, // use default - DNSLookup: newDefaultDNSLookupMap(), + DNSLookup: make( + chan *model.ArchivalDNSLookupResult, + DNSLookupBufferSize, + ), TCPConnect: make( chan *model.ArchivalTCPConnectResult, TCPConnectBufferSize, @@ -146,6 +129,24 @@ func NewTrace(index int64, zeroTime time.Time) *Trace { } } +// newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver +// thus allowing us to mock this function for testing +func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { + if tx.NewParallelUDPResolverFn != nil { + return tx.NewParallelUDPResolverFn(logger, dialer, address) + } + return netxlite.NewParallelUDPResolver(logger, dialer, address) +} + +// newParallelDNSOverHTTPSResolver indirectly calls the passed netxlite.NewParallerDNSOverHTTPSResolver +// thus allowing us to mock this function for testing +func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { + if tx.NewParallelDNSOverHTTPSResolverFn != nil { + return tx.NewParallelDNSOverHTTPSResolverFn(logger, URL) + } + return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL) +} + // newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver // thus allowing us to mock this func for testing. func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { @@ -155,15 +156,6 @@ func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { return netxlite.NewDialerWithoutResolver(dl) } -// newParallelResolver indirectly calls the passed netxlite.NewParallerResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newParallelResolver(newResolver func() model.Resolver) model.Resolver { - if tx.NewParallelResolverFn != nil { - return tx.NewParallelResolverFn() - } - return newResolver() -} - // newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib // thus allowing us to mock this func for testing. func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index 54fae61..0bd83eb 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -46,9 +46,15 @@ func TestNewTrace(t *testing.T) { } }) - t.Run("NewParallelResolverFn is nil", func(t *testing.T) { - if trace.NewParallelResolverFn != nil { - t.Fatal("expected nil NewUnwrappedParallelResolverFn") + t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) { + if trace.NewParallelUDPResolverFn != nil { + t.Fatal("expected nil NewParallelUDPResolverFn") + } + }) + + t.Run("NewParallelDNSOverHTTPSResolverFn is nil", func(t *testing.T) { + if trace.NewParallelDNSOverHTTPSResolverFn != nil { + t.Fatal("expected nil NewParallelDNSOverHTTPSResolverFn") } }) @@ -66,23 +72,21 @@ func TestNewTrace(t *testing.T) { t.Run("DNSLookup has the expected buffer size", func(t *testing.T) { ff := &testingx.FakeFiller{} - for _, qtype := range DNSQueryTypes { - var count int - Loop: - for { - ev := &model.ArchivalDNSLookupResult{} - ff.Fill(ev) - select { - case trace.DNSLookup[qtype] <- ev: - count++ - default: - break Loop - } - } - if count != DNSLookupBufferSize { - t.Fatal("invalid DNSLookup A channel buffer size") + var idx int + Loop: + for { + ev := &model.ArchivalDNSLookupResult{} + ff.Fill(ev) + select { + case trace.DNSLookup <- ev: + idx++ + default: + break Loop } } + if idx != DNSLookupBufferSize { + t.Fatal("invalid DNSLookup channel buffer size") + } }) t.Run("TCPConnect has the expected buffer size", func(t *testing.T) { @@ -138,11 +142,11 @@ func TestNewTrace(t *testing.T) { } func TestTrace(t *testing.T) { - t.Run("NewParallelResolverFn works as intended", func(t *testing.T) { + t.Run("NewParallelUDPResolverFn works as intended", func(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") tx := &Trace{ - NewParallelResolverFn: func() model.Resolver { + NewParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{}, mockedErr @@ -150,9 +154,8 @@ func TestTrace(t *testing.T) { } }, } - resolver := tx.newParallelResolver(func() model.Resolver { - return nil - }) + dialer := &mocks.Dialer{} + resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") ctx := context.Background() addrs, err := resolver.LookupHost(ctx, "example.com") if !errors.Is(err, mockedErr) { @@ -165,26 +168,58 @@ func TestTrace(t *testing.T) { t.Run("when nil", func(t *testing.T) { tx := &Trace{ - NewParallelResolverFn: nil, + NewParallelUDPResolverFn: nil, } - newResolver := func() model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"1.1.1.1"}, nil - }, - } - } - resolver := tx.newParallelResolver(newResolver) - ctx := context.Background() + dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) + resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") + ctx, cancel := context.WithCancel(context.Background()) + cancel() addrs, err := resolver.LookupHost(ctx, "example.com") - if err != nil { + if err == nil || err.Error() != netxlite.FailureInterrupted { t.Fatal("unexpected err", err) } - if len(addrs) != 1 { - t.Fatal("expected array of size 1") + if len(addrs) != 0 { + t.Fatal("expected array of size 0") } - if addrs[0] != "1.1.1.1" { - t.Fatal("unexpected array output", addrs) + }) + }) + + t.Run("NewParallelDNSOverHTTPSResolverFn works as intended", func(t *testing.T) { + t.Run("when not nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{}, mockedErr + }, + } + }, + } + resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "dns.google.com") + ctx := context.Background() + addrs, err := resolver.LookupHost(ctx, "example.com") + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } + }) + + t.Run("when nil", func(t *testing.T) { + tx := &Trace{ + NewParallelDNSOverHTTPSResolverFn: nil, + } + resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + addrs, err := resolver.LookupHost(ctx, "example.com") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") } }) })