refactor: start refactoring session resolver (#807)

This diff addresses the following points of https://github.com/ooni/probe/issues/2135:

- [x] the `childResolver` type is useless and we can use `model.Resolver` directly;
- [x] we should use `model/mocks` instead of custom fakes;
- [x] we should not use `log.Log` rather we should use `model.DiscardLogger`;
- [x] make `timeLimitedLookup` easier to test with a `-short` tests;
- [x] ensure `timeLimitedLookup` returns as soon as its context expires regardless of the child resolver;

Subsequent diffs will address more points mentioned in there.
This commit is contained in:
Simone Basso 2022-06-08 14:06:22 +02:00 committed by GitHub
parent dea23b49d5
commit fe29b432e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 128 additions and 86 deletions

View File

@ -3,25 +3,46 @@ package sessionresolver
import ( import (
"context" "context"
"time" "time"
"github.com/ooni/probe-cli/v3/internal/model"
) )
// childResolver is the DNS client that this package uses // defaultTimeLimitedLookupTimeout is the default timeout the code should
// to perform individual domain name resolutions. // pass to the timeLimitedLookup function.
type childResolver interface { //
// LookupHost performs a DNS lookup. // This algorithm is similar to Firefox using TRR2 mode. See:
LookupHost(ctx context.Context, domain string) ([]string, error) // https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox
//
// CloseIdleConnections closes idle connections. // We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
CloseIdleConnections() // and therefore see to use DoH more often.
} const defaultTimeLimitedLookupTimeout = 4 * time.Second
// timeLimitedLookup performs a time-limited lookup using the given re. // timeLimitedLookup performs a time-limited lookup using the given re.
func (r *Resolver) timeLimitedLookup(ctx context.Context, re childResolver, hostname string) ([]string, error) { func timeLimitedLookup(ctx context.Context, re model.Resolver, hostname string) ([]string, error) {
// Algorithm similar to Firefox TRR2 mode. See: return timeLimitedLookupWithTimeout(ctx, re, hostname, defaultTimeLimitedLookupTimeout)
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox }
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
// and therefore see to use DoH more often. // timeLimitedLookupResult is the result of a timeLimitedLookup
ctx, cancel := context.WithTimeout(ctx, 4*time.Second) type timeLimitedLookupResult struct {
defer cancel() addrs []string
return re.LookupHost(ctx, hostname) err error
}
// timeLimitedLookupWithTimeout is like timeLimitedLookup but with explicit timeout.
func timeLimitedLookupWithTimeout(ctx context.Context, re model.Resolver,
hostname string, timeout time.Duration) ([]string, error) {
outch := make(chan *timeLimitedLookupResult, 1) // buffer
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
go func() {
out := &timeLimitedLookupResult{}
out.addrs, out.err = re.LookupHost(ctx, hostname)
outch <- out
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case out := <-outch:
return out.addrs, out.err
}
} }

View File

@ -8,51 +8,35 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
type FakeResolver struct {
Closed bool
Data []string
Err error
Sleep time.Duration
}
func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
select {
case <-time.After(r.Sleep):
return r.Data, r.Err
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (r *FakeResolver) CloseIdleConnections() {
r.Closed = true
}
func TestTimeLimitedLookupSuccess(t *testing.T) { func TestTimeLimitedLookupSuccess(t *testing.T) {
reso := &Resolver{} expected := []string{"8.8.8.8", "8.8.4.4"}
re := &FakeResolver{ re := &mocks.Resolver{
Data: []string{"8.8.8.8", "8.8.4.4"}, MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
},
} }
ctx := context.Background() ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google") out, err := timeLimitedLookup(ctx, re, "dns.google")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if diff := cmp.Diff(re.Data, out); diff != "" { if diff := cmp.Diff(expected, out); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} }
func TestTimeLimitedLookupFailure(t *testing.T) { func TestTimeLimitedLookupFailure(t *testing.T) {
reso := &Resolver{} re := &mocks.Resolver{
re := &FakeResolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
Err: io.EOF, return nil, io.EOF
},
} }
ctx := context.Background() ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google") out, err := timeLimitedLookup(ctx, re, "dns.google")
if !errors.Is(err, re.Err) { if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
if out != nil { if out != nil {
@ -61,20 +45,23 @@ func TestTimeLimitedLookupFailure(t *testing.T) {
} }
func TestTimeLimitedLookupWillTimeout(t *testing.T) { func TestTimeLimitedLookupWillTimeout(t *testing.T) {
if testing.Short() { done := make(chan bool)
t.Skip("skip test in short mode") block := make(chan bool)
} re := &mocks.Resolver{
reso := &Resolver{} MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
re := &FakeResolver{ defer close(done)
Err: io.EOF, <-block
Sleep: 20 * time.Second, return nil, io.EOF
},
} }
ctx := context.Background() ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google") out, err := timeLimitedLookupWithTimeout(ctx, re, "dns.google", 10*time.Millisecond)
if !errors.Is(err, context.DeadlineExceeded) { if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
if out != nil { if out != nil {
t.Fatal("expected nil here") t.Fatal("expected nil here")
} }
close(block)
<-done
} }

View File

@ -1,11 +1,14 @@
package sessionresolver package sessionresolver
import "github.com/ooni/probe-cli/v3/internal/engine/netx" import (
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/model"
)
// dnsclientmaker makes a new resolver. // dnsclientmaker makes a new resolver.
type dnsclientmaker interface { type dnsclientmaker interface {
// Make makes a new resolver. // Make makes a new resolver.
Make(config netx.Config, URL string) (childResolver, error) Make(config netx.Config, URL string) (model.Resolver, error)
} }
// clientmaker returns a valid dnsclientmaker // clientmaker returns a valid dnsclientmaker
@ -20,6 +23,6 @@ func (r *Resolver) clientmaker() dnsclientmaker {
type defaultDNSClientMaker struct{} type defaultDNSClientMaker struct{}
// Make implements dnsclientmaker.Make. // Make implements dnsclientmaker.Make.
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) { func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
return netx.NewDNSClient(config, URL) return netx.NewDNSClient(config, URL)
} }

View File

@ -7,16 +7,17 @@ import (
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/model"
) )
type fakeDNSClientMaker struct { type fakeDNSClientMaker struct {
reso childResolver reso model.Resolver
err error err error
savedConfig netx.Config savedConfig netx.Config
savedURL string savedURL string
} }
func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) { func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
c.savedConfig = config c.savedConfig = config
c.savedURL = URL c.savedURL = URL
return c.reso, c.err return c.reso, c.err

View File

@ -5,7 +5,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
@ -71,15 +70,12 @@ func (r *Resolver) byteCounter() *bytecounter.Counter {
// logger returns the configured logger or a default // logger returns the configured logger or a default
func (r *Resolver) logger() model.Logger { func (r *Resolver) logger() model.Logger {
if r.Logger != nil { return model.ValidLoggerOrDefault(r.Logger)
return r.Logger
}
return log.Log
} }
// newresolver creates a new resolver with the given config and URL. This is // newresolver creates a new resolver with the given config and URL. This is
// where we expand http3 to https and set the h3 options. // where we expand http3 to https and set the h3 options.
func (r *Resolver) newresolver(URL string) (childResolver, error) { func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
h3 := strings.HasPrefix(URL, "http3://") h3 := strings.HasPrefix(URL, "http3://")
if h3 { if h3 {
URL = strings.Replace(URL, "http3://", "https://", 1) URL = strings.Replace(URL, "http3://", "https://", 1)
@ -95,7 +91,7 @@ func (r *Resolver) newresolver(URL string) (childResolver, error) {
// getresolver returns a resolver with the given URL. This function caches // getresolver returns a resolver with the given URL. This function caches
// already allocated resolvers so we only allocate them once. // already allocated resolvers so we only allocate them once.
func (r *Resolver) getresolver(URL string) (childResolver, error) { func (r *Resolver) getresolver(URL string) (model.Resolver, error) {
defer r.mu.Unlock() defer r.mu.Unlock()
r.mu.Lock() r.mu.Lock()
if re, found := r.res[URL]; found { if re, found := r.res[URL]; found {
@ -106,7 +102,7 @@ func (r *Resolver) getresolver(URL string) (childResolver, error) {
return nil, err // config err? return nil, err // config err?
} }
if r.res == nil { if r.res == nil {
r.res = make(map[string]childResolver) r.res = make(map[string]model.Resolver)
} }
r.res[URL] = re r.res[URL] = re
return re, nil return re, nil

View File

@ -5,8 +5,9 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestDefaultByteCounter(t *testing.T) { func TestDefaultByteCounter(t *testing.T) {
@ -18,18 +19,33 @@ func TestDefaultByteCounter(t *testing.T) {
} }
func TestDefaultLogger(t *testing.T) { func TestDefaultLogger(t *testing.T) {
logger := &log.Logger{} t.Run("when using a different logger", func(t *testing.T) {
reso := &Resolver{Logger: logger} logger := &mocks.Logger{}
lo := reso.logger() reso := &Resolver{Logger: logger}
if lo != logger { lo := reso.logger()
t.Fatal("expected another logger here counter") if lo != logger {
} t.Fatal("expected another logger here")
}
})
t.Run("when no logger is set", func(t *testing.T) {
reso := &Resolver{Logger: nil}
lo := reso.logger()
if lo != model.DiscardLogger {
t.Fatal("expected another logger here")
}
})
} }
func TestGetResolverHTTPSStandard(t *testing.T) { func TestGetResolverHTTPSStandard(t *testing.T) {
bc := bytecounter.New() bc := bytecounter.New()
URL := "https://dns.google" URL := "https://dns.google"
re := &FakeResolver{} var closed bool
re := &mocks.Resolver{
MockCloseIdleConnections: func() {
closed = true
},
}
cmk := &fakeDNSClientMaker{reso: re} cmk := &fakeDNSClientMaker{reso: re}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL) out, err := reso.getresolver(URL)
@ -47,7 +63,7 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
t.Fatal("not the result we expected") t.Fatal("not the result we expected")
} }
reso.closeall() reso.closeall()
if re.Closed != true { if closed != true {
t.Fatal("was not closed") t.Fatal("was not closed")
} }
if cmk.savedURL != URL { if cmk.savedURL != URL {
@ -62,7 +78,7 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
if cmk.savedConfig.HTTP3Enabled != false { if cmk.savedConfig.HTTP3Enabled != false {
t.Fatal("unexpected HTTP3Enabled") t.Fatal("unexpected HTTP3Enabled")
} }
if cmk.savedConfig.Logger != log.Log { if cmk.savedConfig.Logger != model.DiscardLogger {
t.Fatal("unexpected Log") t.Fatal("unexpected Log")
} }
} }
@ -70,7 +86,12 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
func TestGetResolverHTTP3(t *testing.T) { func TestGetResolverHTTP3(t *testing.T) {
bc := bytecounter.New() bc := bytecounter.New()
URL := "http3://dns.google" URL := "http3://dns.google"
re := &FakeResolver{} var closed bool
re := &mocks.Resolver{
MockCloseIdleConnections: func() {
closed = true
},
}
cmk := &fakeDNSClientMaker{reso: re} cmk := &fakeDNSClientMaker{reso: re}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL) out, err := reso.getresolver(URL)
@ -88,7 +109,7 @@ func TestGetResolverHTTP3(t *testing.T) {
t.Fatal("not the result we expected") t.Fatal("not the result we expected")
} }
reso.closeall() reso.closeall()
if re.Closed != true { if closed != true {
t.Fatal("was not closed") t.Fatal("was not closed")
} }
if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) { if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) {
@ -103,7 +124,7 @@ func TestGetResolverHTTP3(t *testing.T) {
if cmk.savedConfig.HTTP3Enabled != true { if cmk.savedConfig.HTTP3Enabled != true {
t.Fatal("unexpected HTTP3Enabled") t.Fatal("unexpected HTTP3Enabled")
} }
if cmk.savedConfig.Logger != log.Log { if cmk.savedConfig.Logger != model.DiscardLogger {
t.Fatal("unexpected Log") t.Fatal("unexpected Log")
} }
} }

View File

@ -95,7 +95,7 @@ type Resolver struct {
// res maps a URL to a child resolver. We will // res maps a URL to a child resolver. We will
// construct child resolvers just once and we // construct child resolvers just once and we
// will track them into this field. // will track them into this field.
res map[string]childResolver res map[string]model.Resolver
} }
// CloseIdleConnections closes the idle connections, if any. This // CloseIdleConnections closes the idle connections, if any. This
@ -169,7 +169,7 @@ func (r *Resolver) lookupHost(ctx context.Context, ri *resolverinfo, hostname st
ri.Score = 0 // this is a hard error ri.Score = 0 // this is a hard error
return nil, err return nil, err
} }
addrs, err := r.timeLimitedLookup(ctx, re, hostname) addrs, err := timeLimitedLookup(ctx, re, hostname)
if err == nil { if err == nil {
r.logger().Infof("sessionresolver: %s... %v", ri.URL, model.ErrorToStringOrOK(nil)) r.logger().Infof("sessionresolver: %s... %v", ri.URL, model.ErrorToStringOrOK(nil))
ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score

View File

@ -11,6 +11,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/kvstore" "github.com/ooni/probe-cli/v3/internal/kvstore"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/multierror" "github.com/ooni/probe-cli/v3/internal/multierror"
) )
@ -85,7 +86,11 @@ func TestTypicalUsageWithSuccess(t *testing.T) {
reso := &Resolver{ reso := &Resolver{
KVStore: &kvstore.Memory{}, KVStore: &kvstore.Memory{},
dnsClientMaker: &fakeDNSClientMaker{ dnsClientMaker: &fakeDNSClientMaker{
reso: &FakeResolver{Data: expected}, reso: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
},
},
}, },
} }
addrs, err := reso.LookupHost(ctx, "dns.google") addrs, err := reso.LookupHost(ctx, "dns.google")
@ -117,7 +122,11 @@ func TestLittleLLookupHostWithSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"} expected := []string{"8.8.8.8", "8.8.4.4"}
reso := &Resolver{ reso := &Resolver{
dnsClientMaker: &fakeDNSClientMaker{ dnsClientMaker: &fakeDNSClientMaker{
reso: &FakeResolver{Data: expected}, reso: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
},
},
}, },
} }
ctx := context.Background() ctx := context.Background()
@ -138,7 +147,11 @@ func TestLittleLLookupHostWithFailure(t *testing.T) {
errMocked := errors.New("mocked error") errMocked := errors.New("mocked error")
reso := &Resolver{ reso := &Resolver{
dnsClientMaker: &fakeDNSClientMaker{ dnsClientMaker: &fakeDNSClientMaker{
reso: &FakeResolver{Err: errMocked}, reso: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, errMocked
},
},
}, },
} }
ctx := context.Background() ctx := context.Background()