diff --git a/internal/netxlite/resolvercore.go b/internal/netxlite/resolvercore.go index ccc7a78..8b4243e 100644 --- a/internal/netxlite/resolvercore.go +++ b/internal/netxlite/resolvercore.go @@ -130,11 +130,17 @@ var _ model.Resolver = &resolverSystem{} func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { encoder := &DNSEncoderMiekg{} query := encoder.Encode(hostname, dns.TypeANY, false) + trace := ContextTraceOrDefault(ctx) + start := trace.TimeNow() resp, err := r.t.RoundTrip(ctx, query) + end := trace.TimeNow() if err != nil { - return nil, err + trace.OnDNSRoundTripForLookupHost(start, r, query, resp, []string{}, err, end) + return []string{}, err } - return resp.DecodeLookupHost() + addrs, err := resp.DecodeLookupHost() + trace.OnDNSRoundTripForLookupHost(start, r, query, resp, addrs, err, end) + return addrs, err } func (r *resolverSystem) Network() string { diff --git a/internal/netxlite/resolvercore_test.go b/internal/netxlite/resolvercore_test.go index eca36fc..334f97f 100644 --- a/internal/netxlite/resolvercore_test.go +++ b/internal/netxlite/resolvercore_test.go @@ -7,12 +7,14 @@ import ( "net" "strings" "testing" + "time" "github.com/apex/log" "github.com/google/go-cmp/cmp" "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func typecheckForSystemResolver(t *testing.T, resolver model.Resolver, logger model.DebugLogger) { @@ -202,6 +204,130 @@ func TestResolverSystem(t *testing.T) { t.Fatal("expected no results") } }) + + t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) { + var ( + onLookupCalled bool + goodQueryType bool + goodLookupAddrs bool + goodLookupError bool + goodLookupResolver bool + ) + expected := []string{"1.1.1.1"} + r := &resolverSystem{ + t: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() != dns.TypeANY { + return nil, errors.New("unexpected query type") + } + return &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return expected, nil + }, + }, nil + }, + MockNetwork: func() string { + return "mocked" + }, + }, + } + zeroTime := time.Now() + deteterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deteterministicTime.Now, + MockOnDNSRoundTripForLookupHost: func(started time.Time, reso model.Resolver, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) { + onLookupCalled = true + goodQueryType = (query.Type() == dns.TypeANY) + goodLookupAddrs = (cmp.Diff(addrs, expected) == "") + goodLookupError = (err == nil) + goodLookupResolver = (reso.Network() == "mocked") + }, + } + ctx := ContextWithTrace(context.Background(), tx) + addrs, err := r.LookupHost(ctx, "example.com") + if err != nil { + t.Fatal("unexpected error", err) + } + if diff := cmp.Diff(expected, addrs); diff != "" { + t.Fatal("unexpected addresses") + } + if !onLookupCalled { + t.Fatal("onLookupCalled not called") + } + if !goodQueryType { + t.Fatal("unexpected query type in system resolver") + } + if !goodLookupAddrs { + t.Fatal("unexpected addresses in LookupHost") + } + if !goodLookupError { + t.Fatal("unexpected error in trace") + } + if !goodLookupResolver { + t.Fatal("unexpected resolver network encountered") + } + }) + + t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) { + var ( + onLookupCalled bool + goodQueryType bool + goodLookupAddrs bool + goodLookupError bool + goodLookupResolver bool + ) + expected := errors.New("mocked") + r := &resolverSystem{ + t: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() != dns.TypeANY { + return nil, errors.New("unexpected query type") + } + return nil, expected + }, + MockNetwork: func() string { + return "mocked" + }, + }, + } + zeroTime := time.Now() + deteterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deteterministicTime.Now, + MockOnDNSRoundTripForLookupHost: func(started time.Time, reso model.Resolver, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) { + onLookupCalled = true + goodQueryType = (query.Type() == dns.TypeANY) + goodLookupAddrs = (len(addrs) == 0) + goodLookupError = errors.Is(err, expected) + goodLookupResolver = (reso.Network() == "mocked") + }, + } + ctx := ContextWithTrace(context.Background(), tx) + addrs, err := r.LookupHost(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if len(addrs) != 0 { + t.Fatal("unexpected addresses") + } + if !onLookupCalled { + t.Fatal("onLookupCalled not called") + } + if !goodQueryType { + t.Fatal("unexpected query type in system resolver") + } + if !goodLookupAddrs { + t.Fatal("unexpected addresses in LookupHost") + } + if !goodLookupError { + t.Fatal("unexpected error in trace") + } + if !goodLookupResolver { + t.Fatal("unexpected resolver network encountered") + } + }) } func TestResolverLogger(t *testing.T) { diff --git a/internal/netxlite/resolverparallel_test.go b/internal/netxlite/resolverparallel_test.go index 73504b4..fa8d30a 100644 --- a/internal/netxlite/resolverparallel_test.go +++ b/internal/netxlite/resolverparallel_test.go @@ -6,10 +6,13 @@ import ( "errors" "net" "testing" + "time" + "github.com/google/go-cmp/cmp" "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func TestParallelResolver(t *testing.T) { @@ -272,4 +275,214 @@ func TestParallelResolver(t *testing.T) { } }) }) + + t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) { + var ( + onLookupACalled bool + onLookupAAAACalled bool + goodQueryTypeA bool + goodQueryTypeAAAA bool + goodLookupAddrsA bool + goodLookupAddrsAAAA bool + goodLookupErrorA bool + goodLookupErrorAAAA bool + goodLookupResolverA bool + goodLookupResolverAAAA bool + ) + expectedA := []string{"1.1.1.1"} + expectedAAAA := []string{"::1"} + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() == dns.TypeA { + return &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return expectedA, nil + }, + }, nil + } + if query.Type() == dns.TypeAAAA { + return &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return expectedAAAA, nil + }, + }, nil + } + return nil, errors.New("unexpected query type") + }, + MockNetwork: func() string { + return "mocked" + }, + MockRequiresPadding: func() bool { + return false + }, + } + r := NewUnwrappedParallelResolver(txp) + zeroTime := time.Now() + deteterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deteterministicTime.Now, + MockOnDNSRoundTripForLookupHost: func(started time.Time, reso model.Resolver, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) { + if query.Type() == dns.TypeA { + onLookupACalled = true + goodQueryTypeA = (query.Type() == dns.TypeA) + goodLookupAddrsA = (cmp.Diff(expectedA, addrs) == "") + goodLookupErrorA = (err == nil) + goodLookupResolverA = (reso.Network() == "mocked") + } + if query.Type() == dns.TypeAAAA { + onLookupAAAACalled = true + goodQueryTypeAAAA = (query.Type() == dns.TypeAAAA) + goodLookupAddrsAAAA = (cmp.Diff(expectedAAAA, addrs) == "") + goodLookupErrorAAAA = (err == nil) + goodLookupResolverAAAA = (reso.Network() == "mocked") + } + }, + } + want := []string{"1.1.1.1", "::1"} + ctx := ContextWithTrace(context.Background(), tx) + addrs, err := r.LookupHost(ctx, "example.com") + if err != nil { + t.Fatal("unexpected error", err) + } + // Note: the implementation always puts IPv4 addrs before IPv6 addrs + if diff := cmp.Diff(want, addrs); diff != "" { + t.Fatal("unexpected addresses") + } + + t.Run("with A reply", func(t *testing.T) { + if !onLookupACalled { + t.Fatal("onLookupACalled not called") + } + if !goodQueryTypeA { + t.Fatal("unexpected query type in parallel resolver") + } + if !goodLookupAddrsA { + t.Fatal("unexpected addresses in LookupHost") + } + if !goodLookupErrorA { + t.Fatal("unexpected error in trace") + } + if !goodLookupResolverA { + t.Fatal("unexpected resolver network encountered") + } + }) + + t.Run("with AAAA reply", func(t *testing.T) { + if !onLookupAAAACalled { + t.Fatal("onLookupAAAACalled not called") + } + if !goodQueryTypeAAAA { + t.Fatal("unexpected query type in parallel resolver") + } + if !goodLookupAddrsAAAA { + t.Fatal("unexpected addresses in LookupHost") + } + if !goodLookupErrorAAAA { + t.Fatal("unexpected error in trace") + } + if !goodLookupResolverAAAA { + t.Fatal("unexpected resolver network encountered") + } + }) + }) + + t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) { + var ( + onLookupACalled bool + onLookupAAAACalled bool + goodQueryTypeA bool + goodQueryTypeAAAA bool + goodLookupAddrsA bool + goodLookupAddrsAAAA bool + goodLookupErrorA bool + goodLookupErrorAAAA bool + goodLookupResolverA bool + goodLookupResolverAAAA bool + ) + expected := errors.New("mocked") + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + if query.Type() == dns.TypeAAAA || query.Type() == dns.TypeA { + return nil, expected + } + return nil, errors.New("unexpected query type") + }, + MockNetwork: func() string { + return "mocked" + }, + MockRequiresPadding: func() bool { + return false + }, + } + r := NewUnwrappedParallelResolver(txp) + zeroTime := time.Now() + deteterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deteterministicTime.Now, + MockOnDNSRoundTripForLookupHost: func(started time.Time, reso model.Resolver, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) { + if query.Type() == dns.TypeA { + onLookupACalled = true + goodQueryTypeA = (query.Type() == dns.TypeA) + goodLookupAddrsA = (len(addrs) == 0) + goodLookupErrorA = errors.Is(expected, err) + goodLookupResolverA = (reso.Network() == "mocked") + return + } + if query.Type() == dns.TypeAAAA { + onLookupAAAACalled = true + goodQueryTypeAAAA = (query.Type() == dns.TypeAAAA) + goodLookupAddrsAAAA = (len(addrs) == 0) + goodLookupErrorAAAA = errors.Is(expected, err) + goodLookupResolverAAAA = (reso.Network() == "mocked") + return + } + }, + } + ctx := ContextWithTrace(context.Background(), tx) + addrs, err := r.LookupHost(ctx, "example.com") + if !errors.Is(expected, err) { + t.Fatal("unexpected error", err) + } + if len(addrs) != 0 { + t.Fatal("unexpected addresses") + } + + t.Run("with A reply", func(t *testing.T) { + if !onLookupACalled { + t.Fatal("onLookupACalled not called") + } + if !goodQueryTypeA { + t.Fatal("unexpected query type in parallel resolver") + } + if !goodLookupAddrsA { + t.Fatal("unexpected addresses in LookupHost") + } + if !goodLookupErrorA { + t.Fatal("unexpected error in trace") + } + if !goodLookupResolverA { + t.Fatal("unexpected resolver network encountered") + } + }) + + t.Run("with AAAA reply", func(t *testing.T) { + if !onLookupAAAACalled { + t.Fatal("onLookupAAAACalled not called") + } + if !goodQueryTypeAAAA { + t.Fatal("unexpected query type in parallel resolver") + } + if !goodLookupAddrsAAAA { + t.Fatal("unexpected addresses in LookupHost") + } + if !goodLookupErrorAAAA { + t.Fatal("unexpected error in trace") + } + if !goodLookupResolverAAAA { + t.Fatal("unexpected resolver network encountered") + } + }) + }) }