feat: refactor dns implementation in measurexlite (#857)
* refactor: remove query-based mapping and introducing resolver wrapper * refactor dnsping to adapt to measurexlite * dnsping: extra comments * Apply suggestions from code review * Update internal/measurexlite/dns_test.go See https://github.com/ooni/probe/issues/2208 Co-authored-by: decfox <decfox@github.com> Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
		
							parent
							
								
									576b52b1e3
								
							
						
					
					
						commit
						fc51590a67
					
				@ -12,7 +12,6 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/ooni/probe-cli/v3/internal/measurexlite"
 | 
			
		||||
	"github.com/ooni/probe-cli/v3/internal/model"
 | 
			
		||||
	"github.com/ooni/probe-cli/v3/internal/netxlite"
 | 
			
		||||
@ -148,19 +147,17 @@ func (m *Measurer) dnsRoundTrip(ctx context.Context, index int64, zeroTime time.
 | 
			
		||||
	resolver := trace.NewParallelUDPResolver(logger, dialer, address)
 | 
			
		||||
	_, err := resolver.LookupHost(ctx, domain)
 | 
			
		||||
	ol.Stop(err)
 | 
			
		||||
	// Add the dns.TypeA ping
 | 
			
		||||
	pings = append(pings, m.makePingFromLookup(<-trace.DNSLookup[dns.TypeA]))
 | 
			
		||||
	// Add the dns.TypeAAAA ping
 | 
			
		||||
	pings = append(pings, m.makePingFromLookup(<-trace.DNSLookup[dns.TypeAAAA]))
 | 
			
		||||
	tk.addPings(pings)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// makePingfromLookup returns a SinglePing from the result of a single query
 | 
			
		||||
func (m *Measurer) makePingFromLookup(lookup *model.ArchivalDNSLookupResult) (pings *SinglePing) {
 | 
			
		||||
	return &SinglePing{
 | 
			
		||||
	for _, lookup := range trace.DNSLookupsFromRoundTrip() {
 | 
			
		||||
		// make sure we only include the query types we care about (in principle, there
 | 
			
		||||
		// should be no other query, so we're doing this just for robustness).
 | 
			
		||||
		if lookup.QueryType == "A" || lookup.QueryType == "AAAA" {
 | 
			
		||||
			pings = append(pings, &SinglePing{
 | 
			
		||||
				Query: lookup,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	tk.addPings(pings)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewExperimentMeasurer creates a new ExperimentMeasurer.
 | 
			
		||||
func NewExperimentMeasurer(config Config) model.ExperimentMeasurer {
 | 
			
		||||
 | 
			
		||||
@ -17,11 +17,10 @@ import (
 | 
			
		||||
	"github.com/ooni/probe-cli/v3/internal/tracex"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// newParallelResolverTrace is equivalent to netxlite.NewParallelResolver
 | 
			
		||||
// except that it returns a model.Resolver that uses this trace.
 | 
			
		||||
func (tx *Trace) newParallelResolverTrace(newResolver func() model.Resolver) model.Resolver {
 | 
			
		||||
// wrapResolver resolver wraps the passed resolver to save data into the trace
 | 
			
		||||
func (tx *Trace) wrapResolver(resolver model.Resolver) model.Resolver {
 | 
			
		||||
	return &resolverTrace{
 | 
			
		||||
		r:  tx.newParallelResolver(newResolver),
 | 
			
		||||
		r:  resolver,
 | 
			
		||||
		tx: tx,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -66,29 +65,20 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS,
 | 
			
		||||
 | 
			
		||||
// NewParallelUDPResolver returns a trace-ware parallel UDP resolver
 | 
			
		||||
func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
 | 
			
		||||
	return tx.newParallelResolverTrace(func() model.Resolver {
 | 
			
		||||
		return netxlite.NewParallelUDPResolver(logger, dialer, address)
 | 
			
		||||
	})
 | 
			
		||||
	return tx.wrapResolver(tx.newParallelUDPResolver(logger, dialer, address))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver
 | 
			
		||||
func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver {
 | 
			
		||||
	return tx.newParallelResolverTrace(func() model.Resolver {
 | 
			
		||||
		return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL)
 | 
			
		||||
	})
 | 
			
		||||
	return tx.wrapResolver(tx.newParallelDNSOverHTTPSResolver(logger, URL))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost
 | 
			
		||||
func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery,
 | 
			
		||||
	response model.DNSResponse, addrs []string, err error, finished time.Time) {
 | 
			
		||||
	ch := tx.DNSLookup[query.Type()]
 | 
			
		||||
	if ch == nil {
 | 
			
		||||
		// Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms.
 | 
			
		||||
		log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query.Type()])
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	t := finished.Sub(tx.ZeroTime)
 | 
			
		||||
	select {
 | 
			
		||||
	case ch <- NewArchivalDNSLookupResultFromRoundTrip(
 | 
			
		||||
	case tx.DNSLookup <- NewArchivalDNSLookupResultFromRoundTrip(
 | 
			
		||||
		tx.Index,
 | 
			
		||||
		started.Sub(tx.ZeroTime),
 | 
			
		||||
		reso,
 | 
			
		||||
@ -96,7 +86,7 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol
 | 
			
		||||
		response,
 | 
			
		||||
		addrs,
 | 
			
		||||
		err,
 | 
			
		||||
		finished.Sub(tx.ZeroTime),
 | 
			
		||||
		t,
 | 
			
		||||
	):
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
@ -151,17 +141,11 @@ func archivalAnswersFromAddrs(addrs []string) (out []model.ArchivalDNSAnswer) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DNSLookupsFromRoundTrip drains the network events buffered inside the corresponding query channel
 | 
			
		||||
func (tx *Trace) DNSLookupsFromRoundTrip(query uint16) (out []*model.ArchivalDNSLookupResult) {
 | 
			
		||||
	ch := tx.DNSLookup[query]
 | 
			
		||||
	if ch == nil {
 | 
			
		||||
		// Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms.
 | 
			
		||||
		log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query])
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
// DNSLookupsFromRoundTrip drains the network events buffered inside the DNSLookup channel
 | 
			
		||||
func (tx *Trace) DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) {
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case ev := <-ch:
 | 
			
		||||
		case ev := <-tx.DNSLookup:
 | 
			
		||||
			out = append(out, ev)
 | 
			
		||||
		default:
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
@ -13,17 +13,11 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
	t.Run("NewUnwrappedParallelResolver creates an UnwrappedParallelResolver with Trace", func(t *testing.T) {
 | 
			
		||||
	t.Run("WrapResolver creates a wrapped resolver with Trace", func(t *testing.T) {
 | 
			
		||||
		underlying := &mocks.Resolver{}
 | 
			
		||||
		zeroTime := time.Now()
 | 
			
		||||
		trace := NewTrace(0, zeroTime)
 | 
			
		||||
		trace.NewParallelResolverFn = func() model.Resolver {
 | 
			
		||||
			return underlying
 | 
			
		||||
		}
 | 
			
		||||
		resolver := trace.newParallelResolverTrace(func() model.Resolver {
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
		resolvert := resolver.(*resolverTrace)
 | 
			
		||||
		resolvert := trace.wrapResolver(underlying).(*resolverTrace)
 | 
			
		||||
		if resolvert.r != underlying {
 | 
			
		||||
			t.Fatal("invalid parallel resolver")
 | 
			
		||||
		}
 | 
			
		||||
@ -36,8 +30,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
		var called bool
 | 
			
		||||
		zeroTime := time.Now()
 | 
			
		||||
		trace := NewTrace(0, zeroTime)
 | 
			
		||||
		newMockResolver := func() model.Resolver {
 | 
			
		||||
			return &mocks.Resolver{
 | 
			
		||||
		mockResolver := &mocks.Resolver{
 | 
			
		||||
			MockAddress: func() string {
 | 
			
		||||
				return "dns.google"
 | 
			
		||||
			},
 | 
			
		||||
@ -48,8 +41,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
				called = true
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		}
 | 
			
		||||
		resolver := trace.newParallelResolver(newMockResolver)
 | 
			
		||||
		resolver := trace.wrapResolver(mockResolver)
 | 
			
		||||
 | 
			
		||||
		t.Run("Address is correctly forwarded", func(t *testing.T) {
 | 
			
		||||
			got := resolver.Address()
 | 
			
		||||
@ -94,16 +86,14 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
				return true
 | 
			
		||||
			},
 | 
			
		||||
			MockNetwork: func() string {
 | 
			
		||||
				return ""
 | 
			
		||||
				return "mocked"
 | 
			
		||||
			},
 | 
			
		||||
			MockAddress: func() string {
 | 
			
		||||
				return "dns.google"
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		newResolver := func() model.Resolver {
 | 
			
		||||
			return netxlite.NewUnwrappedParallelResolver(txp)
 | 
			
		||||
		}
 | 
			
		||||
		resolver := trace.newParallelResolverTrace(newResolver)
 | 
			
		||||
		r := netxlite.NewUnwrappedParallelResolver(txp)
 | 
			
		||||
		resolver := trace.wrapResolver(r)
 | 
			
		||||
		ctx := context.Background()
 | 
			
		||||
		addrs, err := resolver.LookupHost(ctx, "example.com")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@ -119,45 +109,27 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
			t.Fatal("unexpected array output", addrs)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		t.Run("DNSLookups QueryType A", func(t *testing.T) {
 | 
			
		||||
			events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
 | 
			
		||||
			if len(events) != 1 {
 | 
			
		||||
				t.Fatal("expected to see single DNSLookup event")
 | 
			
		||||
		t.Run("DNSLookup events", func(t *testing.T) {
 | 
			
		||||
			events := trace.DNSLookupsFromRoundTrip()
 | 
			
		||||
			if len(events) != 2 {
 | 
			
		||||
				t.Fatal("unexpected DNS events")
 | 
			
		||||
			}
 | 
			
		||||
			lookup := events[0]
 | 
			
		||||
			answers := lookup.Answers
 | 
			
		||||
			if lookup.Failure != nil {
 | 
			
		||||
				t.Fatal("unexpected err", *(lookup.Failure))
 | 
			
		||||
			for _, ev := range events {
 | 
			
		||||
				if ev.ResolverAddress != "dns.google" {
 | 
			
		||||
					t.Fatal("unexpected resolver address")
 | 
			
		||||
				}
 | 
			
		||||
			if lookup.ResolverAddress != "dns.google" {
 | 
			
		||||
				t.Fatal("unexpected address field")
 | 
			
		||||
				if ev.Engine != "mocked" {
 | 
			
		||||
					t.Fatal("unexpected engine")
 | 
			
		||||
				}
 | 
			
		||||
			if len(answers) != 1 {
 | 
			
		||||
				t.Fatal("expected 1 DNS answer, got", len(answers))
 | 
			
		||||
				if len(ev.Answers) != 1 {
 | 
			
		||||
					t.Fatal("expected single answer in DNSLookup event")
 | 
			
		||||
				}
 | 
			
		||||
			if answers[0].AnswerType != "A" || answers[0].IPv4 != "1.1.1.1" {
 | 
			
		||||
				t.Fatal("unexpected DNS answer", answers)
 | 
			
		||||
				if ev.QueryType == "A" && ev.Answers[0].IPv4 != "1.1.1.1" {
 | 
			
		||||
					t.Fatal("unexpected A query result")
 | 
			
		||||
				}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
 | 
			
		||||
			events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
 | 
			
		||||
			if len(events) != 1 {
 | 
			
		||||
				t.Fatal("expected to see single DNSLookup event")
 | 
			
		||||
				if ev.QueryType == "AAAA" && ev.Answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" {
 | 
			
		||||
					t.Fatal("unexpected AAAA query result")
 | 
			
		||||
				}
 | 
			
		||||
			lookup := events[0]
 | 
			
		||||
			answers := lookup.Answers
 | 
			
		||||
			if lookup.Failure != nil {
 | 
			
		||||
				t.Fatal("unexpected err", *(lookup.Failure))
 | 
			
		||||
			}
 | 
			
		||||
			if lookup.ResolverAddress != "dns.google" {
 | 
			
		||||
				t.Fatal("unexpected address field")
 | 
			
		||||
			}
 | 
			
		||||
			if len(answers) != 1 {
 | 
			
		||||
				t.Fatal("expected 1 DNS answer, got", len(answers))
 | 
			
		||||
			}
 | 
			
		||||
			if answers[0].AnswerType != "AAAA" || answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" {
 | 
			
		||||
				t.Fatal("unexpected DNS answer", answers)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
@ -166,10 +138,7 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
		zeroTime := time.Now()
 | 
			
		||||
		td := testingx.NewTimeDeterministic(zeroTime)
 | 
			
		||||
		trace := NewTrace(0, zeroTime)
 | 
			
		||||
		trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{
 | 
			
		||||
			dns.TypeA:    make(chan *model.ArchivalDNSLookupResult), // no buffer
 | 
			
		||||
			dns.TypeAAAA: make(chan *model.ArchivalDNSLookupResult), // no buffer
 | 
			
		||||
		}
 | 
			
		||||
		trace.DNSLookup = make(chan *model.ArchivalDNSLookupResult) // no buffer
 | 
			
		||||
		trace.TimeNowFn = td.Now
 | 
			
		||||
		txp := &mocks.DNSTransport{
 | 
			
		||||
			MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
 | 
			
		||||
@ -193,10 +162,8 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
				return "dns.google"
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		newResolver := func() model.Resolver {
 | 
			
		||||
			return netxlite.NewUnwrappedParallelResolver(txp)
 | 
			
		||||
		}
 | 
			
		||||
		resolver := trace.newParallelResolverTrace(newResolver)
 | 
			
		||||
		r := netxlite.NewUnwrappedParallelResolver(txp)
 | 
			
		||||
		resolver := trace.wrapResolver(r)
 | 
			
		||||
		ctx := context.Background()
 | 
			
		||||
		addrs, err := resolver.LookupHost(ctx, "example.com")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@ -205,17 +172,17 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
 | 
			
		||||
		if len(addrs) != 2 {
 | 
			
		||||
			t.Fatal("unexpected array output", addrs)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		t.Run("DNSLookups QueryType A", func(t *testing.T) {
 | 
			
		||||
			events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
 | 
			
		||||
			if len(events) != 0 {
 | 
			
		||||
				t.Fatal("expected to see no DNSLookup")
 | 
			
		||||
		if addrs[0] != "1.1.1.1" && addrs[1] != "1.1.1.1" {
 | 
			
		||||
			t.Fatal("unexpected array output", addrs)
 | 
			
		||||
		}
 | 
			
		||||
		})
 | 
			
		||||
		t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
 | 
			
		||||
			events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
 | 
			
		||||
		if addrs[0] != "fe80::a00:20ff:feb9:4c54" && addrs[1] != "fe80::a00:20ff:feb9:4c54" {
 | 
			
		||||
			t.Fatal("unexpected array output", addrs)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		t.Run("DNSLookup Events", func(t *testing.T) {
 | 
			
		||||
			events := trace.DNSLookupsFromRoundTrip()
 | 
			
		||||
			if len(events) != 0 {
 | 
			
		||||
				t.Fatal("expected to see no DNSLookup")
 | 
			
		||||
				t.Fatal("expected to see no DNSLookup events")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
@ -271,26 +238,3 @@ func TestAnswersFromAddrs(t *testing.T) {
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDNSLookupsFromRoundTrips(t *testing.T) {
 | 
			
		||||
	zeroTime := time.Now()
 | 
			
		||||
	trace := NewTrace(0, zeroTime)
 | 
			
		||||
	checkPanic := func(query uint16, f func(uint16) []*model.ArchivalDNSLookupResult) {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if r := recover(); r != nil {
 | 
			
		||||
				t.Fatal("unexpected panic encoutered")
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		f(query)
 | 
			
		||||
	}
 | 
			
		||||
	t.Run("DNSLookup is nil", func(t *testing.T) {
 | 
			
		||||
		trace.DNSLookup = nil
 | 
			
		||||
		checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip)
 | 
			
		||||
	})
 | 
			
		||||
	t.Run("Query has nil channel", func(t *testing.T) {
 | 
			
		||||
		trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{
 | 
			
		||||
			dns.TypeA: nil,
 | 
			
		||||
		}
 | 
			
		||||
		checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,6 @@ package measurexlite
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/ooni/probe-cli/v3/internal/model"
 | 
			
		||||
	"github.com/ooni/probe-cli/v3/internal/netxlite"
 | 
			
		||||
)
 | 
			
		||||
@ -39,9 +38,13 @@ type Trace struct {
 | 
			
		||||
	// this channel manually, ensure it has some buffer.
 | 
			
		||||
	NetworkEvent chan *model.ArchivalNetworkEvent
 | 
			
		||||
 | 
			
		||||
	// NewParallelResolverFn is OPTIONAL and can be used to overide
 | 
			
		||||
	// calls to the netxlite.NewParallelResolver factory.
 | 
			
		||||
	NewParallelResolverFn func() model.Resolver
 | 
			
		||||
	// NewParallelUDPResolverFn is OPTIONAL and can be used to overide
 | 
			
		||||
	// calls to the netxlite.NewParallelUDPResolver factory.
 | 
			
		||||
	NewParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver
 | 
			
		||||
 | 
			
		||||
	// NewParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide
 | 
			
		||||
	// calls to the netxlite.NewParallelDNSOverHTTPSUDPResolver factory.
 | 
			
		||||
	NewParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver
 | 
			
		||||
 | 
			
		||||
	// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
 | 
			
		||||
	// calls to the netxlite.NewDialerWithoutResolver factory.
 | 
			
		||||
@ -51,13 +54,9 @@ type Trace struct {
 | 
			
		||||
	// calls to the netxlite.NewTLSHandshakerStdlib factory.
 | 
			
		||||
	NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
 | 
			
		||||
 | 
			
		||||
	// DNSLookup is MANDATORY and buffers DNSLookup results based on the
 | 
			
		||||
	// query type. When we create this map using NewTrace, we will create
 | 
			
		||||
	// an entry for each dns.Type in DNSQueryTypes. If you create this channel
 | 
			
		||||
	// manually, you probably want to to the same (and most likely you also
 | 
			
		||||
	// want to create buffered channels). Note that the code will print a
 | 
			
		||||
	// warning and otherwise ignore all the query types not included in this map.
 | 
			
		||||
	DNSLookup map[uint16]chan *model.ArchivalDNSLookupResult
 | 
			
		||||
	// DNSLookup is MANDATORY and buffers DNS Lookup observations. If you create
 | 
			
		||||
	// this channel manually, ensure it has some buffer.
 | 
			
		||||
	DNSLookup chan *model.ArchivalDNSLookupResult
 | 
			
		||||
 | 
			
		||||
	// TCPConnect is MANDATORY and buffers TCP connect observations. If you create
 | 
			
		||||
	// this channel manually, ensure it has some buffer.
 | 
			
		||||
@ -93,25 +92,6 @@ const (
 | 
			
		||||
	TLSHandshakeBufferSize = 8
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// DNSQueryTypes contains the list of DNS query types for which
 | 
			
		||||
// NewTrace create entries in Trace.DNSLookup.
 | 
			
		||||
var DNSQueryTypes = []uint16{
 | 
			
		||||
	dns.TypeANY,
 | 
			
		||||
	dns.TypeA,
 | 
			
		||||
	dns.TypeAAAA,
 | 
			
		||||
	dns.TypeCNAME,
 | 
			
		||||
	dns.TypeNS,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newDefaultDNSLookupMap is a convenience factory for creating Trace.DNSLookup
 | 
			
		||||
func newDefaultDNSLookupMap() map[uint16]chan *model.ArchivalDNSLookupResult {
 | 
			
		||||
	out := make(map[uint16]chan *model.ArchivalDNSLookupResult)
 | 
			
		||||
	for _, qtype := range DNSQueryTypes {
 | 
			
		||||
		out[qtype] = make(chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize)
 | 
			
		||||
	}
 | 
			
		||||
	return out
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTrace creates a new instance of Trace using default settings.
 | 
			
		||||
//
 | 
			
		||||
// We create buffered channels using as buffer sizes the constants that
 | 
			
		||||
@ -132,7 +112,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
 | 
			
		||||
		),
 | 
			
		||||
		NewDialerWithoutResolverFn: nil, // use default
 | 
			
		||||
		NewTLSHandshakerStdlibFn:   nil, // use default
 | 
			
		||||
		DNSLookup:                  newDefaultDNSLookupMap(),
 | 
			
		||||
		DNSLookup: make(
 | 
			
		||||
			chan *model.ArchivalDNSLookupResult,
 | 
			
		||||
			DNSLookupBufferSize,
 | 
			
		||||
		),
 | 
			
		||||
		TCPConnect: make(
 | 
			
		||||
			chan *model.ArchivalTCPConnectResult,
 | 
			
		||||
			TCPConnectBufferSize,
 | 
			
		||||
@ -146,6 +129,24 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver
 | 
			
		||||
// thus allowing us to mock this function for testing
 | 
			
		||||
func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
 | 
			
		||||
	if tx.NewParallelUDPResolverFn != nil {
 | 
			
		||||
		return tx.NewParallelUDPResolverFn(logger, dialer, address)
 | 
			
		||||
	}
 | 
			
		||||
	return netxlite.NewParallelUDPResolver(logger, dialer, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newParallelDNSOverHTTPSResolver indirectly calls the passed netxlite.NewParallerDNSOverHTTPSResolver
 | 
			
		||||
// thus allowing us to mock this function for testing
 | 
			
		||||
func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver {
 | 
			
		||||
	if tx.NewParallelDNSOverHTTPSResolverFn != nil {
 | 
			
		||||
		return tx.NewParallelDNSOverHTTPSResolverFn(logger, URL)
 | 
			
		||||
	}
 | 
			
		||||
	return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
 | 
			
		||||
// thus allowing us to mock this func for testing.
 | 
			
		||||
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
 | 
			
		||||
@ -155,15 +156,6 @@ func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
 | 
			
		||||
	return netxlite.NewDialerWithoutResolver(dl)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newParallelResolver indirectly calls the passed netxlite.NewParallerResolver
 | 
			
		||||
// thus allowing us to mock this function for testing
 | 
			
		||||
func (tx *Trace) newParallelResolver(newResolver func() model.Resolver) model.Resolver {
 | 
			
		||||
	if tx.NewParallelResolverFn != nil {
 | 
			
		||||
		return tx.NewParallelResolverFn()
 | 
			
		||||
	}
 | 
			
		||||
	return newResolver()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
 | 
			
		||||
// thus allowing us to mock this func for testing.
 | 
			
		||||
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
 | 
			
		||||
 | 
			
		||||
@ -46,9 +46,15 @@ func TestNewTrace(t *testing.T) {
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		t.Run("NewParallelResolverFn is nil", func(t *testing.T) {
 | 
			
		||||
			if trace.NewParallelResolverFn != nil {
 | 
			
		||||
				t.Fatal("expected nil NewUnwrappedParallelResolverFn")
 | 
			
		||||
		t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) {
 | 
			
		||||
			if trace.NewParallelUDPResolverFn != nil {
 | 
			
		||||
				t.Fatal("expected nil NewParallelUDPResolverFn")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		t.Run("NewParallelDNSOverHTTPSResolverFn is nil", func(t *testing.T) {
 | 
			
		||||
			if trace.NewParallelDNSOverHTTPSResolverFn != nil {
 | 
			
		||||
				t.Fatal("expected nil NewParallelDNSOverHTTPSResolverFn")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
@ -66,22 +72,20 @@ func TestNewTrace(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
		t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
 | 
			
		||||
			ff := &testingx.FakeFiller{}
 | 
			
		||||
			for _, qtype := range DNSQueryTypes {
 | 
			
		||||
				var count int
 | 
			
		||||
			var idx int
 | 
			
		||||
		Loop:
 | 
			
		||||
			for {
 | 
			
		||||
				ev := &model.ArchivalDNSLookupResult{}
 | 
			
		||||
				ff.Fill(ev)
 | 
			
		||||
				select {
 | 
			
		||||
					case trace.DNSLookup[qtype] <- ev:
 | 
			
		||||
						count++
 | 
			
		||||
				case trace.DNSLookup <- ev:
 | 
			
		||||
					idx++
 | 
			
		||||
				default:
 | 
			
		||||
					break Loop
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
				if count != DNSLookupBufferSize {
 | 
			
		||||
					t.Fatal("invalid DNSLookup A channel buffer size")
 | 
			
		||||
				}
 | 
			
		||||
			if idx != DNSLookupBufferSize {
 | 
			
		||||
				t.Fatal("invalid DNSLookup channel buffer size")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
@ -138,11 +142,11 @@ func TestNewTrace(t *testing.T) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTrace(t *testing.T) {
 | 
			
		||||
	t.Run("NewParallelResolverFn works as intended", func(t *testing.T) {
 | 
			
		||||
	t.Run("NewParallelUDPResolverFn works as intended", func(t *testing.T) {
 | 
			
		||||
		t.Run("when not nil", func(t *testing.T) {
 | 
			
		||||
			mockedErr := errors.New("mocked")
 | 
			
		||||
			tx := &Trace{
 | 
			
		||||
				NewParallelResolverFn: func() model.Resolver {
 | 
			
		||||
				NewParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
 | 
			
		||||
					return &mocks.Resolver{
 | 
			
		||||
						MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
 | 
			
		||||
							return []string{}, mockedErr
 | 
			
		||||
@ -150,9 +154,8 @@ func TestTrace(t *testing.T) {
 | 
			
		||||
					}
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			resolver := tx.newParallelResolver(func() model.Resolver {
 | 
			
		||||
				return nil
 | 
			
		||||
			})
 | 
			
		||||
			dialer := &mocks.Dialer{}
 | 
			
		||||
			resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53")
 | 
			
		||||
			ctx := context.Background()
 | 
			
		||||
			addrs, err := resolver.LookupHost(ctx, "example.com")
 | 
			
		||||
			if !errors.Is(err, mockedErr) {
 | 
			
		||||
@ -165,26 +168,58 @@ func TestTrace(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
		t.Run("when nil", func(t *testing.T) {
 | 
			
		||||
			tx := &Trace{
 | 
			
		||||
				NewParallelResolverFn: nil,
 | 
			
		||||
				NewParallelUDPResolverFn: nil,
 | 
			
		||||
			}
 | 
			
		||||
			newResolver := func() model.Resolver {
 | 
			
		||||
				return &mocks.Resolver{
 | 
			
		||||
					MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
 | 
			
		||||
						return []string{"1.1.1.1"}, nil
 | 
			
		||||
					},
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			resolver := tx.newParallelResolver(newResolver)
 | 
			
		||||
			ctx := context.Background()
 | 
			
		||||
			dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
 | 
			
		||||
			resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53")
 | 
			
		||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
			cancel()
 | 
			
		||||
			addrs, err := resolver.LookupHost(ctx, "example.com")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
			if err == nil || err.Error() != netxlite.FailureInterrupted {
 | 
			
		||||
				t.Fatal("unexpected err", err)
 | 
			
		||||
			}
 | 
			
		||||
			if len(addrs) != 1 {
 | 
			
		||||
				t.Fatal("expected array of size 1")
 | 
			
		||||
			if len(addrs) != 0 {
 | 
			
		||||
				t.Fatal("expected array of size 0")
 | 
			
		||||
			}
 | 
			
		||||
			if addrs[0] != "1.1.1.1" {
 | 
			
		||||
				t.Fatal("unexpected array output", addrs)
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("NewParallelDNSOverHTTPSResolverFn works as intended", func(t *testing.T) {
 | 
			
		||||
		t.Run("when not nil", func(t *testing.T) {
 | 
			
		||||
			mockedErr := errors.New("mocked")
 | 
			
		||||
			tx := &Trace{
 | 
			
		||||
				NewParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver {
 | 
			
		||||
					return &mocks.Resolver{
 | 
			
		||||
						MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
 | 
			
		||||
							return []string{}, mockedErr
 | 
			
		||||
						},
 | 
			
		||||
					}
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
			resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "dns.google.com")
 | 
			
		||||
			ctx := context.Background()
 | 
			
		||||
			addrs, err := resolver.LookupHost(ctx, "example.com")
 | 
			
		||||
			if !errors.Is(err, mockedErr) {
 | 
			
		||||
				t.Fatal("unexpected err", err)
 | 
			
		||||
			}
 | 
			
		||||
			if len(addrs) != 0 {
 | 
			
		||||
				t.Fatal("expected array of size 0")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		t.Run("when nil", func(t *testing.T) {
 | 
			
		||||
			tx := &Trace{
 | 
			
		||||
				NewParallelDNSOverHTTPSResolverFn: nil,
 | 
			
		||||
			}
 | 
			
		||||
			resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com")
 | 
			
		||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
			cancel()
 | 
			
		||||
			addrs, err := resolver.LookupHost(ctx, "example.com")
 | 
			
		||||
			if err == nil || err.Error() != netxlite.FailureInterrupted {
 | 
			
		||||
				t.Fatal("unexpected err", err)
 | 
			
		||||
			}
 | 
			
		||||
			if len(addrs) != 0 {
 | 
			
		||||
				t.Fatal("expected array of size 0")
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user