diff --git a/internal/engine/internal/sessionresolver/childresolver.go b/internal/engine/internal/sessionresolver/childresolver.go index 8848058..f589722 100644 --- a/internal/engine/internal/sessionresolver/childresolver.go +++ b/internal/engine/internal/sessionresolver/childresolver.go @@ -3,25 +3,46 @@ package sessionresolver import ( "context" "time" + + "github.com/ooni/probe-cli/v3/internal/model" ) -// childResolver is the DNS client that this package uses -// to perform individual domain name resolutions. -type childResolver interface { - // LookupHost performs a DNS lookup. - LookupHost(ctx context.Context, domain string) ([]string, error) - - // CloseIdleConnections closes idle connections. - CloseIdleConnections() -} +// defaultTimeLimitedLookupTimeout is the default timeout the code should +// pass to the timeLimitedLookup function. +// +// This algorithm is similar to Firefox using TRR2 mode. See: +// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox +// +// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side +// and therefore see to use DoH more often. +const defaultTimeLimitedLookupTimeout = 4 * time.Second // timeLimitedLookup performs a time-limited lookup using the given re. -func (r *Resolver) timeLimitedLookup(ctx context.Context, re childResolver, hostname string) ([]string, error) { - // Algorithm similar to Firefox TRR2 mode. See: - // https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox - // We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side - // and therefore see to use DoH more often. - ctx, cancel := context.WithTimeout(ctx, 4*time.Second) - defer cancel() - return re.LookupHost(ctx, hostname) +func timeLimitedLookup(ctx context.Context, re model.Resolver, hostname string) ([]string, error) { + return timeLimitedLookupWithTimeout(ctx, re, hostname, defaultTimeLimitedLookupTimeout) +} + +// timeLimitedLookupResult is the result of a timeLimitedLookup +type timeLimitedLookupResult struct { + addrs []string + err error +} + +// timeLimitedLookupWithTimeout is like timeLimitedLookup but with explicit timeout. +func timeLimitedLookupWithTimeout(ctx context.Context, re model.Resolver, + hostname string, timeout time.Duration) ([]string, error) { + outch := make(chan *timeLimitedLookupResult, 1) // buffer + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + go func() { + out := &timeLimitedLookupResult{} + out.addrs, out.err = re.LookupHost(ctx, hostname) + outch <- out + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case out := <-outch: + return out.addrs, out.err + } } diff --git a/internal/engine/internal/sessionresolver/childresolver_test.go b/internal/engine/internal/sessionresolver/childresolver_test.go index 888e636..e6a2bb7 100644 --- a/internal/engine/internal/sessionresolver/childresolver_test.go +++ b/internal/engine/internal/sessionresolver/childresolver_test.go @@ -8,51 +8,35 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) -type FakeResolver struct { - Closed bool - Data []string - Err error - Sleep time.Duration -} - -func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { - select { - case <-time.After(r.Sleep): - return r.Data, r.Err - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -func (r *FakeResolver) CloseIdleConnections() { - r.Closed = true -} - func TestTimeLimitedLookupSuccess(t *testing.T) { - reso := &Resolver{} - re := &FakeResolver{ - Data: []string{"8.8.8.8", "8.8.4.4"}, + expected := []string{"8.8.8.8", "8.8.4.4"} + re := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return expected, nil + }, } ctx := context.Background() - out, err := reso.timeLimitedLookup(ctx, re, "dns.google") + out, err := timeLimitedLookup(ctx, re, "dns.google") if err != nil { t.Fatal(err) } - if diff := cmp.Diff(re.Data, out); diff != "" { + if diff := cmp.Diff(expected, out); diff != "" { t.Fatal(diff) } } func TestTimeLimitedLookupFailure(t *testing.T) { - reso := &Resolver{} - re := &FakeResolver{ - Err: io.EOF, + re := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, io.EOF + }, } ctx := context.Background() - out, err := reso.timeLimitedLookup(ctx, re, "dns.google") - if !errors.Is(err, re.Err) { + out, err := timeLimitedLookup(ctx, re, "dns.google") + if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected", err) } if out != nil { @@ -61,20 +45,23 @@ func TestTimeLimitedLookupFailure(t *testing.T) { } func TestTimeLimitedLookupWillTimeout(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - reso := &Resolver{} - re := &FakeResolver{ - Err: io.EOF, - Sleep: 20 * time.Second, + done := make(chan bool) + block := make(chan bool) + re := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + defer close(done) + <-block + return nil, io.EOF + }, } ctx := context.Background() - out, err := reso.timeLimitedLookup(ctx, re, "dns.google") + out, err := timeLimitedLookupWithTimeout(ctx, re, "dns.google", 10*time.Millisecond) if !errors.Is(err, context.DeadlineExceeded) { t.Fatal("not the error we expected", err) } if out != nil { t.Fatal("expected nil here") } + close(block) + <-done } diff --git a/internal/engine/internal/sessionresolver/clientmaker.go b/internal/engine/internal/sessionresolver/clientmaker.go index 8ca6f06..b2ce862 100644 --- a/internal/engine/internal/sessionresolver/clientmaker.go +++ b/internal/engine/internal/sessionresolver/clientmaker.go @@ -1,11 +1,14 @@ package sessionresolver -import "github.com/ooni/probe-cli/v3/internal/engine/netx" +import ( + "github.com/ooni/probe-cli/v3/internal/engine/netx" + "github.com/ooni/probe-cli/v3/internal/model" +) // dnsclientmaker makes a new resolver. type dnsclientmaker interface { // Make makes a new resolver. - Make(config netx.Config, URL string) (childResolver, error) + Make(config netx.Config, URL string) (model.Resolver, error) } // clientmaker returns a valid dnsclientmaker @@ -20,6 +23,6 @@ func (r *Resolver) clientmaker() dnsclientmaker { type defaultDNSClientMaker struct{} // Make implements dnsclientmaker.Make. -func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) { +func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) { return netx.NewDNSClient(config, URL) } diff --git a/internal/engine/internal/sessionresolver/clientmaker_test.go b/internal/engine/internal/sessionresolver/clientmaker_test.go index 6db8855..eae04b3 100644 --- a/internal/engine/internal/sessionresolver/clientmaker_test.go +++ b/internal/engine/internal/sessionresolver/clientmaker_test.go @@ -7,16 +7,17 @@ import ( "testing" "github.com/ooni/probe-cli/v3/internal/engine/netx" + "github.com/ooni/probe-cli/v3/internal/model" ) type fakeDNSClientMaker struct { - reso childResolver + reso model.Resolver err error savedConfig netx.Config savedURL string } -func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) { +func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) { c.savedConfig = config c.savedURL = URL return c.reso, c.err diff --git a/internal/engine/internal/sessionresolver/resolvermaker.go b/internal/engine/internal/sessionresolver/resolvermaker.go index 373d778..2da4411 100644 --- a/internal/engine/internal/sessionresolver/resolvermaker.go +++ b/internal/engine/internal/sessionresolver/resolvermaker.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/model" @@ -71,15 +70,12 @@ func (r *Resolver) byteCounter() *bytecounter.Counter { // logger returns the configured logger or a default func (r *Resolver) logger() model.Logger { - if r.Logger != nil { - return r.Logger - } - return log.Log + return model.ValidLoggerOrDefault(r.Logger) } // newresolver creates a new resolver with the given config and URL. This is // where we expand http3 to https and set the h3 options. -func (r *Resolver) newresolver(URL string) (childResolver, error) { +func (r *Resolver) newresolver(URL string) (model.Resolver, error) { h3 := strings.HasPrefix(URL, "http3://") if h3 { URL = strings.Replace(URL, "http3://", "https://", 1) @@ -95,7 +91,7 @@ func (r *Resolver) newresolver(URL string) (childResolver, error) { // getresolver returns a resolver with the given URL. This function caches // already allocated resolvers so we only allocate them once. -func (r *Resolver) getresolver(URL string) (childResolver, error) { +func (r *Resolver) getresolver(URL string) (model.Resolver, error) { defer r.mu.Unlock() r.mu.Lock() if re, found := r.res[URL]; found { @@ -106,7 +102,7 @@ func (r *Resolver) getresolver(URL string) (childResolver, error) { return nil, err // config err? } if r.res == nil { - r.res = make(map[string]childResolver) + r.res = make(map[string]model.Resolver) } r.res[URL] = re return re, nil diff --git a/internal/engine/internal/sessionresolver/resolvermaker_test.go b/internal/engine/internal/sessionresolver/resolvermaker_test.go index 42ab705..7ebf433 100644 --- a/internal/engine/internal/sessionresolver/resolvermaker_test.go +++ b/internal/engine/internal/sessionresolver/resolvermaker_test.go @@ -5,8 +5,9 @@ import ( "strings" "testing" - "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/bytecounter" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestDefaultByteCounter(t *testing.T) { @@ -18,18 +19,33 @@ func TestDefaultByteCounter(t *testing.T) { } func TestDefaultLogger(t *testing.T) { - logger := &log.Logger{} - reso := &Resolver{Logger: logger} - lo := reso.logger() - if lo != logger { - t.Fatal("expected another logger here counter") - } + t.Run("when using a different logger", func(t *testing.T) { + logger := &mocks.Logger{} + reso := &Resolver{Logger: logger} + lo := reso.logger() + if lo != logger { + t.Fatal("expected another logger here") + } + }) + + t.Run("when no logger is set", func(t *testing.T) { + reso := &Resolver{Logger: nil} + lo := reso.logger() + if lo != model.DiscardLogger { + t.Fatal("expected another logger here") + } + }) } func TestGetResolverHTTPSStandard(t *testing.T) { bc := bytecounter.New() URL := "https://dns.google" - re := &FakeResolver{} + var closed bool + re := &mocks.Resolver{ + MockCloseIdleConnections: func() { + closed = true + }, + } cmk := &fakeDNSClientMaker{reso: re} reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} out, err := reso.getresolver(URL) @@ -47,7 +63,7 @@ func TestGetResolverHTTPSStandard(t *testing.T) { t.Fatal("not the result we expected") } reso.closeall() - if re.Closed != true { + if closed != true { t.Fatal("was not closed") } if cmk.savedURL != URL { @@ -62,7 +78,7 @@ func TestGetResolverHTTPSStandard(t *testing.T) { if cmk.savedConfig.HTTP3Enabled != false { t.Fatal("unexpected HTTP3Enabled") } - if cmk.savedConfig.Logger != log.Log { + if cmk.savedConfig.Logger != model.DiscardLogger { t.Fatal("unexpected Log") } } @@ -70,7 +86,12 @@ func TestGetResolverHTTPSStandard(t *testing.T) { func TestGetResolverHTTP3(t *testing.T) { bc := bytecounter.New() URL := "http3://dns.google" - re := &FakeResolver{} + var closed bool + re := &mocks.Resolver{ + MockCloseIdleConnections: func() { + closed = true + }, + } cmk := &fakeDNSClientMaker{reso: re} reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} out, err := reso.getresolver(URL) @@ -88,7 +109,7 @@ func TestGetResolverHTTP3(t *testing.T) { t.Fatal("not the result we expected") } reso.closeall() - if re.Closed != true { + if closed != true { t.Fatal("was not closed") } if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) { @@ -103,7 +124,7 @@ func TestGetResolverHTTP3(t *testing.T) { if cmk.savedConfig.HTTP3Enabled != true { t.Fatal("unexpected HTTP3Enabled") } - if cmk.savedConfig.Logger != log.Log { + if cmk.savedConfig.Logger != model.DiscardLogger { t.Fatal("unexpected Log") } } diff --git a/internal/engine/internal/sessionresolver/sessionresolver.go b/internal/engine/internal/sessionresolver/sessionresolver.go index 254a7a0..782298d 100644 --- a/internal/engine/internal/sessionresolver/sessionresolver.go +++ b/internal/engine/internal/sessionresolver/sessionresolver.go @@ -95,7 +95,7 @@ type Resolver struct { // res maps a URL to a child resolver. We will // construct child resolvers just once and we // will track them into this field. - res map[string]childResolver + res map[string]model.Resolver } // CloseIdleConnections closes the idle connections, if any. This @@ -169,7 +169,7 @@ func (r *Resolver) lookupHost(ctx context.Context, ri *resolverinfo, hostname st ri.Score = 0 // this is a hard error return nil, err } - addrs, err := r.timeLimitedLookup(ctx, re, hostname) + addrs, err := timeLimitedLookup(ctx, re, hostname) if err == nil { r.logger().Infof("sessionresolver: %s... %v", ri.URL, model.ErrorToStringOrOK(nil)) ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score diff --git a/internal/engine/internal/sessionresolver/sessionresolver_test.go b/internal/engine/internal/sessionresolver/sessionresolver_test.go index 51f75bc..d1cf3f9 100644 --- a/internal/engine/internal/sessionresolver/sessionresolver_test.go +++ b/internal/engine/internal/sessionresolver/sessionresolver_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/kvstore" + "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/multierror" ) @@ -85,7 +86,11 @@ func TestTypicalUsageWithSuccess(t *testing.T) { reso := &Resolver{ KVStore: &kvstore.Memory{}, dnsClientMaker: &fakeDNSClientMaker{ - reso: &FakeResolver{Data: expected}, + reso: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return expected, nil + }, + }, }, } addrs, err := reso.LookupHost(ctx, "dns.google") @@ -117,7 +122,11 @@ func TestLittleLLookupHostWithSuccess(t *testing.T) { expected := []string{"8.8.8.8", "8.8.4.4"} reso := &Resolver{ dnsClientMaker: &fakeDNSClientMaker{ - reso: &FakeResolver{Data: expected}, + reso: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return expected, nil + }, + }, }, } ctx := context.Background() @@ -138,7 +147,11 @@ func TestLittleLLookupHostWithFailure(t *testing.T) { errMocked := errors.New("mocked error") reso := &Resolver{ dnsClientMaker: &fakeDNSClientMaker{ - reso: &FakeResolver{Err: errMocked}, + reso: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, errMocked + }, + }, }, } ctx := context.Background()