diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 14614c1..7a5b0b2 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -63,6 +63,11 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS, return r.r.LookupNS(netxlite.ContextWithTrace(ctx, r.tx), domain) } +// NewStdlibResolver returns a trace-ware system resolver +func (tx *Trace) NewStdlibResolver(logger model.Logger) model.Resolver { + return tx.wrapResolver(tx.newStdlibResolver(logger)) +} + // NewParallelUDPResolver returns a trace-ware parallel UDP resolver func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { return tx.wrapResolver(tx.newParallelUDPResolver(logger, dialer, address)) diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index 8b6e7ee..6fa092f 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -2,9 +2,11 @@ package measurexlite import ( "context" + "net" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" @@ -12,7 +14,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/testingx" ) -func TestNewUnwrappedParallelResolver(t *testing.T) { +func TestNewResolver(t *testing.T) { t.Run("WrapResolver creates a wrapped resolver with Trace", func(t *testing.T) { underlying := &mocks.Resolver{} zeroTime := time.Now() @@ -37,6 +39,19 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { MockNetwork: func() string { return "udp" }, + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"1.1.1.1"}, nil + }, + MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return &model.HTTPSSvc{ + IPv4: []string{"1.1.1.1"}, + }, nil + }, + MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) { + return []*net.NS{{ + Host: "1.1.1.1", + }}, nil + }, MockCloseIdleConnections: func() { called = true }, @@ -57,6 +72,46 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { } }) + t.Run("LookupHost is correctly forwarded", func(t *testing.T) { + want := []string{"1.1.1.1"} + ctx := context.Background() + got, err := resolver.LookupHost(ctx, "example.com") + if err != nil { + t.Fatal("expected nil error") + } + if diff := cmp.Diff(want, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("LookupHTTPS is correctly forwarded", func(t *testing.T) { + want := &model.HTTPSSvc{ + IPv4: []string{"1.1.1.1"}, + } + ctx := context.Background() + got, err := resolver.LookupHTTPS(ctx, "example.com") + if err != nil { + t.Fatal("expected nil error") + } + if diff := cmp.Diff(want, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("LookupNS is correctly forwarded", func(t *testing.T) { + want := []*net.NS{{ + Host: "1.1.1.1", + }} + ctx := context.Background() + got, err := resolver.LookupNS(ctx, "example.com") + if err != nil { + t.Fatal("expected nil error") + } + if diff := cmp.Diff(want, got); diff != "" { + t.Fatal(diff) + } + }) + t.Run("CloseIdleConnections is correctly forwarded", func(t *testing.T) { resolver.CloseIdleConnections() if !called { @@ -188,6 +243,48 @@ func TestNewUnwrappedParallelResolver(t *testing.T) { }) } +func TestNewWrappedResolvers(t *testing.T) { + t.Run("NewParallelDNSOverHTTPSResolver works as intended", func(t *testing.T) { + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + resolver := trace.NewParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") + resolvert := resolver.(*resolverTrace) + if resolvert.tx != trace { + t.Fatal("invalid trace") + } + if resolver.Network() != "doh" { + t.Fatal("unexpected resolver network") + } + }) + + t.Run("NewParallelUDPResolver works as intended", func(t *testing.T) { + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + dialer := netxlite.NewDialerWithStdlibResolver(model.DiscardLogger) + resolver := trace.NewParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") + resolvert := resolver.(*resolverTrace) + if resolvert.tx != trace { + t.Fatal("invalid trace") + } + if resolver.Network() != "udp" { + t.Fatal("unexpected resolver network") + } + }) + + t.Run("NewStdlibResolver works as intended", func(t *testing.T) { + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + resolver := trace.NewStdlibResolver(model.DiscardLogger) + resolvert := resolver.(*resolverTrace) + if resolvert.tx != trace { + t.Fatal("invalid trace") + } + if resolver.Network() != "system" { + t.Fatal("unexpected resolver network") + } + }) +} + func TestAnswersFromAddrs(t *testing.T) { tests := []struct { name string diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 475b48d..946e4e2 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -38,6 +38,10 @@ type Trace struct { // this channel manually, ensure it has some buffer. NetworkEvent chan *model.ArchivalNetworkEvent + // NewStdlibResolverFn is OPTIONAL and can be used to overide + // calls to the netxlite.NewStdlibResolver factory. + NewStdlibResolverFn func(logger model.Logger) model.Resolver + // NewParallelUDPResolverFn is OPTIONAL and can be used to overide // calls to the netxlite.NewParallelUDPResolver factory. NewParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver @@ -129,6 +133,15 @@ func NewTrace(index int64, zeroTime time.Time) *Trace { } } +// newStdlibResolver indirectly calls the passed netxlite.NewStdlibResolver +// thus allowing us to mock this function for testing +func (tx *Trace) newStdlibResolver(logger model.Logger) model.Resolver { + if tx.NewStdlibResolverFn != nil { + return tx.NewStdlibResolverFn(logger) + } + return netxlite.NewStdlibResolver(logger) +} + // newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver // thus allowing us to mock this function for testing func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index 0bd83eb..854e625 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -46,6 +46,12 @@ func TestNewTrace(t *testing.T) { } }) + t.Run("NewStdlibResolverFn is nil", func(t *testing.T) { + if trace.NewStdlibResolverFn != nil { + t.Fatal("expected nil NewStdlibResolverFn") + } + }) + t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) { if trace.NewParallelUDPResolverFn != nil { t.Fatal("expected nil NewParallelUDPResolverFn") @@ -142,6 +148,46 @@ func TestNewTrace(t *testing.T) { } func TestTrace(t *testing.T) { + t.Run("NewStdlibResolverFn works as intended", func(t *testing.T) { + t.Run("when not nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewStdlibResolverFn: func(logger model.Logger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{}, mockedErr + }, + } + }, + } + resolver := tx.newStdlibResolver(model.DiscardLogger) + ctx := context.Background() + addrs, err := resolver.LookupHost(ctx, "example.com") + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } + }) + + t.Run("when nil", func(t *testing.T) { + tx := &Trace{ + NewParallelUDPResolverFn: nil, + } + resolver := tx.newStdlibResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + addrs, err := resolver.LookupHost(ctx, "example.com") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } + }) + }) + t.Run("NewParallelUDPResolverFn works as intended", func(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") @@ -196,7 +242,7 @@ func TestTrace(t *testing.T) { } }, } - resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "dns.google.com") + resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") ctx := context.Background() addrs, err := resolver.LookupHost(ctx, "example.com") if !errors.Is(err, mockedErr) {