feat: refactor dns implementation in measurexlite (#857)

* refactor: remove query-based mapping and introducing resolver wrapper

* refactor dnsping to adapt to measurexlite

* dnsping: extra comments

* Apply suggestions from code review

* Update internal/measurexlite/dns_test.go

See https://github.com/ooni/probe/issues/2208

Co-authored-by: decfox <decfox@github.com>
Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
DecFox 2022-08-11 19:30:37 +05:30 committed by GitHub
parent 576b52b1e3
commit fc51590a67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 221 deletions

View File

@ -12,7 +12,6 @@ import (
"sync" "sync"
"time" "time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
@ -148,19 +147,17 @@ func (m *Measurer) dnsRoundTrip(ctx context.Context, index int64, zeroTime time.
resolver := trace.NewParallelUDPResolver(logger, dialer, address) resolver := trace.NewParallelUDPResolver(logger, dialer, address)
_, err := resolver.LookupHost(ctx, domain) _, err := resolver.LookupHost(ctx, domain)
ol.Stop(err) ol.Stop(err)
// Add the dns.TypeA ping for _, lookup := range trace.DNSLookupsFromRoundTrip() {
pings = append(pings, m.makePingFromLookup(<-trace.DNSLookup[dns.TypeA])) // make sure we only include the query types we care about (in principle, there
// Add the dns.TypeAAAA ping // should be no other query, so we're doing this just for robustness).
pings = append(pings, m.makePingFromLookup(<-trace.DNSLookup[dns.TypeAAAA])) if lookup.QueryType == "A" || lookup.QueryType == "AAAA" {
tk.addPings(pings) pings = append(pings, &SinglePing{
}
// makePingfromLookup returns a SinglePing from the result of a single query
func (m *Measurer) makePingFromLookup(lookup *model.ArchivalDNSLookupResult) (pings *SinglePing) {
return &SinglePing{
Query: lookup, Query: lookup,
})
} }
} }
tk.addPings(pings)
}
// NewExperimentMeasurer creates a new ExperimentMeasurer. // NewExperimentMeasurer creates a new ExperimentMeasurer.
func NewExperimentMeasurer(config Config) model.ExperimentMeasurer { func NewExperimentMeasurer(config Config) model.ExperimentMeasurer {

View File

@ -17,11 +17,10 @@ import (
"github.com/ooni/probe-cli/v3/internal/tracex" "github.com/ooni/probe-cli/v3/internal/tracex"
) )
// newParallelResolverTrace is equivalent to netxlite.NewParallelResolver // wrapResolver resolver wraps the passed resolver to save data into the trace
// except that it returns a model.Resolver that uses this trace. func (tx *Trace) wrapResolver(resolver model.Resolver) model.Resolver {
func (tx *Trace) newParallelResolverTrace(newResolver func() model.Resolver) model.Resolver {
return &resolverTrace{ return &resolverTrace{
r: tx.newParallelResolver(newResolver), r: resolver,
tx: tx, tx: tx,
} }
} }
@ -66,29 +65,20 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS,
// NewParallelUDPResolver returns a trace-ware parallel UDP resolver // NewParallelUDPResolver returns a trace-ware parallel UDP resolver
func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
return tx.newParallelResolverTrace(func() model.Resolver { return tx.wrapResolver(tx.newParallelUDPResolver(logger, dialer, address))
return netxlite.NewParallelUDPResolver(logger, dialer, address)
})
} }
// NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver // NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver
func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver {
return tx.newParallelResolverTrace(func() model.Resolver { return tx.wrapResolver(tx.newParallelDNSOverHTTPSResolver(logger, URL))
return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL)
})
} }
// OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost
func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) { response model.DNSResponse, addrs []string, err error, finished time.Time) {
ch := tx.DNSLookup[query.Type()] t := finished.Sub(tx.ZeroTime)
if ch == nil {
// Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms.
log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query.Type()])
return
}
select { select {
case ch <- NewArchivalDNSLookupResultFromRoundTrip( case tx.DNSLookup <- NewArchivalDNSLookupResultFromRoundTrip(
tx.Index, tx.Index,
started.Sub(tx.ZeroTime), started.Sub(tx.ZeroTime),
reso, reso,
@ -96,7 +86,7 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol
response, response,
addrs, addrs,
err, err,
finished.Sub(tx.ZeroTime), t,
): ):
default: default:
} }
@ -151,17 +141,11 @@ func archivalAnswersFromAddrs(addrs []string) (out []model.ArchivalDNSAnswer) {
return return
} }
// DNSLookupsFromRoundTrip drains the network events buffered inside the corresponding query channel // DNSLookupsFromRoundTrip drains the network events buffered inside the DNSLookup channel
func (tx *Trace) DNSLookupsFromRoundTrip(query uint16) (out []*model.ArchivalDNSLookupResult) { func (tx *Trace) DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) {
ch := tx.DNSLookup[query]
if ch == nil {
// Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms.
log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query])
return
}
for { for {
select { select {
case ev := <-ch: case ev := <-tx.DNSLookup:
out = append(out, ev) out = append(out, ev)
default: default:
return return

View File

@ -13,17 +13,11 @@ import (
) )
func TestNewUnwrappedParallelResolver(t *testing.T) { func TestNewUnwrappedParallelResolver(t *testing.T) {
t.Run("NewUnwrappedParallelResolver creates an UnwrappedParallelResolver with Trace", func(t *testing.T) { t.Run("WrapResolver creates a wrapped resolver with Trace", func(t *testing.T) {
underlying := &mocks.Resolver{} underlying := &mocks.Resolver{}
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.NewParallelResolverFn = func() model.Resolver { resolvert := trace.wrapResolver(underlying).(*resolverTrace)
return underlying
}
resolver := trace.newParallelResolverTrace(func() model.Resolver {
return nil
})
resolvert := resolver.(*resolverTrace)
if resolvert.r != underlying { if resolvert.r != underlying {
t.Fatal("invalid parallel resolver") t.Fatal("invalid parallel resolver")
} }
@ -36,8 +30,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
var called bool var called bool
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
newMockResolver := func() model.Resolver { mockResolver := &mocks.Resolver{
return &mocks.Resolver{
MockAddress: func() string { MockAddress: func() string {
return "dns.google" return "dns.google"
}, },
@ -48,8 +41,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
called = true called = true
}, },
} }
} resolver := trace.wrapResolver(mockResolver)
resolver := trace.newParallelResolver(newMockResolver)
t.Run("Address is correctly forwarded", func(t *testing.T) { t.Run("Address is correctly forwarded", func(t *testing.T) {
got := resolver.Address() got := resolver.Address()
@ -94,16 +86,14 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
return true return true
}, },
MockNetwork: func() string { MockNetwork: func() string {
return "" return "mocked"
}, },
MockAddress: func() string { MockAddress: func() string {
return "dns.google" return "dns.google"
}, },
} }
newResolver := func() model.Resolver { r := netxlite.NewUnwrappedParallelResolver(txp)
return netxlite.NewUnwrappedParallelResolver(txp) resolver := trace.wrapResolver(r)
}
resolver := trace.newParallelResolverTrace(newResolver)
ctx := context.Background() ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com") addrs, err := resolver.LookupHost(ctx, "example.com")
if err != nil { if err != nil {
@ -119,45 +109,27 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
t.Fatal("unexpected array output", addrs) t.Fatal("unexpected array output", addrs)
} }
t.Run("DNSLookups QueryType A", func(t *testing.T) { t.Run("DNSLookup events", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip(dns.TypeA) events := trace.DNSLookupsFromRoundTrip()
if len(events) != 1 { if len(events) != 2 {
t.Fatal("expected to see single DNSLookup event") t.Fatal("unexpected DNS events")
} }
lookup := events[0] for _, ev := range events {
answers := lookup.Answers if ev.ResolverAddress != "dns.google" {
if lookup.Failure != nil { t.Fatal("unexpected resolver address")
t.Fatal("unexpected err", *(lookup.Failure))
} }
if lookup.ResolverAddress != "dns.google" { if ev.Engine != "mocked" {
t.Fatal("unexpected address field") t.Fatal("unexpected engine")
} }
if len(answers) != 1 { if len(ev.Answers) != 1 {
t.Fatal("expected 1 DNS answer, got", len(answers)) t.Fatal("expected single answer in DNSLookup event")
} }
if answers[0].AnswerType != "A" || answers[0].IPv4 != "1.1.1.1" { if ev.QueryType == "A" && ev.Answers[0].IPv4 != "1.1.1.1" {
t.Fatal("unexpected DNS answer", answers) t.Fatal("unexpected A query result")
} }
}) if ev.QueryType == "AAAA" && ev.Answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" {
t.Fatal("unexpected AAAA query result")
t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
if len(events) != 1 {
t.Fatal("expected to see single DNSLookup event")
} }
lookup := events[0]
answers := lookup.Answers
if lookup.Failure != nil {
t.Fatal("unexpected err", *(lookup.Failure))
}
if lookup.ResolverAddress != "dns.google" {
t.Fatal("unexpected address field")
}
if len(answers) != 1 {
t.Fatal("expected 1 DNS answer, got", len(answers))
}
if answers[0].AnswerType != "AAAA" || answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" {
t.Fatal("unexpected DNS answer", answers)
} }
}) })
}) })
@ -166,10 +138,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
zeroTime := time.Now() zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime) td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{ trace.DNSLookup = make(chan *model.ArchivalDNSLookupResult) // no buffer
dns.TypeA: make(chan *model.ArchivalDNSLookupResult), // no buffer
dns.TypeAAAA: make(chan *model.ArchivalDNSLookupResult), // no buffer
}
trace.TimeNowFn = td.Now trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
@ -193,10 +162,8 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
return "dns.google" return "dns.google"
}, },
} }
newResolver := func() model.Resolver { r := netxlite.NewUnwrappedParallelResolver(txp)
return netxlite.NewUnwrappedParallelResolver(txp) resolver := trace.wrapResolver(r)
}
resolver := trace.newParallelResolverTrace(newResolver)
ctx := context.Background() ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com") addrs, err := resolver.LookupHost(ctx, "example.com")
if err != nil { if err != nil {
@ -205,17 +172,17 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
if len(addrs) != 2 { if len(addrs) != 2 {
t.Fatal("unexpected array output", addrs) t.Fatal("unexpected array output", addrs)
} }
if addrs[0] != "1.1.1.1" && addrs[1] != "1.1.1.1" {
t.Run("DNSLookups QueryType A", func(t *testing.T) { t.Fatal("unexpected array output", addrs)
events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
if len(events) != 0 {
t.Fatal("expected to see no DNSLookup")
} }
}) if addrs[0] != "fe80::a00:20ff:feb9:4c54" && addrs[1] != "fe80::a00:20ff:feb9:4c54" {
t.Run("DNSLookups QueryType AAAA", func(t *testing.T) { t.Fatal("unexpected array output", addrs)
events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA) }
t.Run("DNSLookup Events", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip()
if len(events) != 0 { if len(events) != 0 {
t.Fatal("expected to see no DNSLookup") t.Fatal("expected to see no DNSLookup events")
} }
}) })
}) })
@ -271,26 +238,3 @@ func TestAnswersFromAddrs(t *testing.T) {
}) })
} }
} }
func TestDNSLookupsFromRoundTrips(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
checkPanic := func(query uint16, f func(uint16) []*model.ArchivalDNSLookupResult) {
defer func() {
if r := recover(); r != nil {
t.Fatal("unexpected panic encoutered")
}
}()
f(query)
}
t.Run("DNSLookup is nil", func(t *testing.T) {
trace.DNSLookup = nil
checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip)
})
t.Run("Query has nil channel", func(t *testing.T) {
trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{
dns.TypeA: nil,
}
checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip)
})
}

View File

@ -7,7 +7,6 @@ package measurexlite
import ( import (
"time" "time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
@ -39,9 +38,13 @@ type Trace struct {
// this channel manually, ensure it has some buffer. // this channel manually, ensure it has some buffer.
NetworkEvent chan *model.ArchivalNetworkEvent NetworkEvent chan *model.ArchivalNetworkEvent
// NewParallelResolverFn is OPTIONAL and can be used to overide // NewParallelUDPResolverFn is OPTIONAL and can be used to overide
// calls to the netxlite.NewParallelResolver factory. // calls to the netxlite.NewParallelUDPResolver factory.
NewParallelResolverFn func() model.Resolver NewParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver
// NewParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide
// calls to the netxlite.NewParallelDNSOverHTTPSUDPResolver factory.
NewParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver
// NewDialerWithoutResolverFn is OPTIONAL and can be used to override // NewDialerWithoutResolverFn is OPTIONAL and can be used to override
// calls to the netxlite.NewDialerWithoutResolver factory. // calls to the netxlite.NewDialerWithoutResolver factory.
@ -51,13 +54,9 @@ type Trace struct {
// calls to the netxlite.NewTLSHandshakerStdlib factory. // calls to the netxlite.NewTLSHandshakerStdlib factory.
NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
// DNSLookup is MANDATORY and buffers DNSLookup results based on the // DNSLookup is MANDATORY and buffers DNS Lookup observations. If you create
// query type. When we create this map using NewTrace, we will create // this channel manually, ensure it has some buffer.
// an entry for each dns.Type in DNSQueryTypes. If you create this channel DNSLookup chan *model.ArchivalDNSLookupResult
// manually, you probably want to to the same (and most likely you also
// want to create buffered channels). Note that the code will print a
// warning and otherwise ignore all the query types not included in this map.
DNSLookup map[uint16]chan *model.ArchivalDNSLookupResult
// TCPConnect is MANDATORY and buffers TCP connect observations. If you create // TCPConnect is MANDATORY and buffers TCP connect observations. If you create
// this channel manually, ensure it has some buffer. // this channel manually, ensure it has some buffer.
@ -93,25 +92,6 @@ const (
TLSHandshakeBufferSize = 8 TLSHandshakeBufferSize = 8
) )
// DNSQueryTypes contains the list of DNS query types for which
// NewTrace create entries in Trace.DNSLookup.
var DNSQueryTypes = []uint16{
dns.TypeANY,
dns.TypeA,
dns.TypeAAAA,
dns.TypeCNAME,
dns.TypeNS,
}
// newDefaultDNSLookupMap is a convenience factory for creating Trace.DNSLookup
func newDefaultDNSLookupMap() map[uint16]chan *model.ArchivalDNSLookupResult {
out := make(map[uint16]chan *model.ArchivalDNSLookupResult)
for _, qtype := range DNSQueryTypes {
out[qtype] = make(chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize)
}
return out
}
// NewTrace creates a new instance of Trace using default settings. // NewTrace creates a new instance of Trace using default settings.
// //
// We create buffered channels using as buffer sizes the constants that // We create buffered channels using as buffer sizes the constants that
@ -132,7 +112,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
), ),
NewDialerWithoutResolverFn: nil, // use default NewDialerWithoutResolverFn: nil, // use default
NewTLSHandshakerStdlibFn: nil, // use default NewTLSHandshakerStdlibFn: nil, // use default
DNSLookup: newDefaultDNSLookupMap(), DNSLookup: make(
chan *model.ArchivalDNSLookupResult,
DNSLookupBufferSize,
),
TCPConnect: make( TCPConnect: make(
chan *model.ArchivalTCPConnectResult, chan *model.ArchivalTCPConnectResult,
TCPConnectBufferSize, TCPConnectBufferSize,
@ -146,6 +129,24 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
} }
} }
// newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver
// thus allowing us to mock this function for testing
func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
if tx.NewParallelUDPResolverFn != nil {
return tx.NewParallelUDPResolverFn(logger, dialer, address)
}
return netxlite.NewParallelUDPResolver(logger, dialer, address)
}
// newParallelDNSOverHTTPSResolver indirectly calls the passed netxlite.NewParallerDNSOverHTTPSResolver
// thus allowing us to mock this function for testing
func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver {
if tx.NewParallelDNSOverHTTPSResolverFn != nil {
return tx.NewParallelDNSOverHTTPSResolverFn(logger, URL)
}
return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL)
}
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver // newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
// thus allowing us to mock this func for testing. // thus allowing us to mock this func for testing.
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
@ -155,15 +156,6 @@ func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
return netxlite.NewDialerWithoutResolver(dl) return netxlite.NewDialerWithoutResolver(dl)
} }
// newParallelResolver indirectly calls the passed netxlite.NewParallerResolver
// thus allowing us to mock this function for testing
func (tx *Trace) newParallelResolver(newResolver func() model.Resolver) model.Resolver {
if tx.NewParallelResolverFn != nil {
return tx.NewParallelResolverFn()
}
return newResolver()
}
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib // newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
// thus allowing us to mock this func for testing. // thus allowing us to mock this func for testing.
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {

View File

@ -46,9 +46,15 @@ func TestNewTrace(t *testing.T) {
} }
}) })
t.Run("NewParallelResolverFn is nil", func(t *testing.T) { t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) {
if trace.NewParallelResolverFn != nil { if trace.NewParallelUDPResolverFn != nil {
t.Fatal("expected nil NewUnwrappedParallelResolverFn") t.Fatal("expected nil NewParallelUDPResolverFn")
}
})
t.Run("NewParallelDNSOverHTTPSResolverFn is nil", func(t *testing.T) {
if trace.NewParallelDNSOverHTTPSResolverFn != nil {
t.Fatal("expected nil NewParallelDNSOverHTTPSResolverFn")
} }
}) })
@ -66,22 +72,20 @@ func TestNewTrace(t *testing.T) {
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) { t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{} ff := &testingx.FakeFiller{}
for _, qtype := range DNSQueryTypes { var idx int
var count int
Loop: Loop:
for { for {
ev := &model.ArchivalDNSLookupResult{} ev := &model.ArchivalDNSLookupResult{}
ff.Fill(ev) ff.Fill(ev)
select { select {
case trace.DNSLookup[qtype] <- ev: case trace.DNSLookup <- ev:
count++ idx++
default: default:
break Loop break Loop
} }
} }
if count != DNSLookupBufferSize { if idx != DNSLookupBufferSize {
t.Fatal("invalid DNSLookup A channel buffer size") t.Fatal("invalid DNSLookup channel buffer size")
}
} }
}) })
@ -138,11 +142,11 @@ func TestNewTrace(t *testing.T) {
} }
func TestTrace(t *testing.T) { func TestTrace(t *testing.T) {
t.Run("NewParallelResolverFn works as intended", func(t *testing.T) { t.Run("NewParallelUDPResolverFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) { t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked") mockedErr := errors.New("mocked")
tx := &Trace{ tx := &Trace{
NewParallelResolverFn: func() model.Resolver { NewParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
return &mocks.Resolver{ return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{}, mockedErr return []string{}, mockedErr
@ -150,9 +154,8 @@ func TestTrace(t *testing.T) {
} }
}, },
} }
resolver := tx.newParallelResolver(func() model.Resolver { dialer := &mocks.Dialer{}
return nil resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53")
})
ctx := context.Background() ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com") addrs, err := resolver.LookupHost(ctx, "example.com")
if !errors.Is(err, mockedErr) { if !errors.Is(err, mockedErr) {
@ -165,26 +168,58 @@ func TestTrace(t *testing.T) {
t.Run("when nil", func(t *testing.T) { t.Run("when nil", func(t *testing.T) {
tx := &Trace{ tx := &Trace{
NewParallelResolverFn: nil, NewParallelUDPResolverFn: nil,
} }
newResolver := func() model.Resolver { dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
return &mocks.Resolver{ resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53")
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { ctx, cancel := context.WithCancel(context.Background())
return []string{"1.1.1.1"}, nil cancel()
},
}
}
resolver := tx.newParallelResolver(newResolver)
ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com") addrs, err := resolver.LookupHost(ctx, "example.com")
if err != nil { if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if len(addrs) != 1 { if len(addrs) != 0 {
t.Fatal("expected array of size 1") t.Fatal("expected array of size 0")
} }
if addrs[0] != "1.1.1.1" { })
t.Fatal("unexpected array output", addrs) })
t.Run("NewParallelDNSOverHTTPSResolverFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{}, mockedErr
},
}
},
}
resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "dns.google.com")
ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com")
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if len(addrs) != 0 {
t.Fatal("expected array of size 0")
}
})
t.Run("when nil", func(t *testing.T) {
tx := &Trace{
NewParallelDNSOverHTTPSResolverFn: nil,
}
resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com")
ctx, cancel := context.WithCancel(context.Background())
cancel()
addrs, err := resolver.LookupHost(ctx, "example.com")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if len(addrs) != 0 {
t.Fatal("expected array of size 0")
} }
}) })
}) })