diff --git a/internal/engine/internal/sessionresolver/childresolver.go b/internal/engine/internal/sessionresolver/childresolver.go new file mode 100644 index 0000000..8848058 --- /dev/null +++ b/internal/engine/internal/sessionresolver/childresolver.go @@ -0,0 +1,27 @@ +package sessionresolver + +import ( + "context" + "time" +) + +// childResolver is the DNS client that this package uses +// to perform individual domain name resolutions. +type childResolver interface { + // LookupHost performs a DNS lookup. + LookupHost(ctx context.Context, domain string) ([]string, error) + + // CloseIdleConnections closes idle connections. + CloseIdleConnections() +} + +// timeLimitedLookup performs a time-limited lookup using the given re. +func (r *Resolver) timeLimitedLookup(ctx context.Context, re childResolver, hostname string) ([]string, error) { + // Algorithm similar to Firefox TRR2 mode. See: + // https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox + // We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side + // and therefore see to use DoH more often. + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + return re.LookupHost(ctx, hostname) +} diff --git a/internal/engine/internal/sessionresolver/childresolver_test.go b/internal/engine/internal/sessionresolver/childresolver_test.go new file mode 100644 index 0000000..8be08ff --- /dev/null +++ b/internal/engine/internal/sessionresolver/childresolver_test.go @@ -0,0 +1,80 @@ +package sessionresolver + +import ( + "context" + "errors" + "io" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +type FakeResolver struct { + Closed bool + Data []string + Err error + Sleep time.Duration +} + +func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + select { + case <-time.After(r.Sleep): + return r.Data, r.Err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (r *FakeResolver) CloseIdleConnections() { + r.Closed = true +} + +func TestTimeLimitedLookupSuccess(t *testing.T) { + reso := &Resolver{} + re := &FakeResolver{ + Data: []string{"8.8.8.8", "8.8.4.4"}, + } + ctx := context.Background() + out, err := reso.timeLimitedLookup(ctx, re, "dns.google") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(re.Data, out); diff != "" { + t.Fatal(diff) + } +} + +func TestTimeLimitedLookupFailure(t *testing.T) { + reso := &Resolver{} + re := &FakeResolver{ + Err: io.EOF, + } + ctx := context.Background() + out, err := reso.timeLimitedLookup(ctx, re, "dns.google") + if !errors.Is(err, re.Err) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestTimeLimitedLookupWillTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + reso := &Resolver{} + re := &FakeResolver{ + Err: io.EOF, + Sleep: 20 * time.Second, + } + ctx := context.Background() + out, err := reso.timeLimitedLookup(ctx, re, "dns.google") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} diff --git a/internal/engine/internal/sessionresolver/clientmaker.go b/internal/engine/internal/sessionresolver/clientmaker.go new file mode 100644 index 0000000..8ca6f06 --- /dev/null +++ b/internal/engine/internal/sessionresolver/clientmaker.go @@ -0,0 +1,25 @@ +package sessionresolver + +import "github.com/ooni/probe-cli/v3/internal/engine/netx" + +// dnsclientmaker makes a new resolver. +type dnsclientmaker interface { + // Make makes a new resolver. + Make(config netx.Config, URL string) (childResolver, error) +} + +// clientmaker returns a valid dnsclientmaker +func (r *Resolver) clientmaker() dnsclientmaker { + if r.dnsClientMaker != nil { + return r.dnsClientMaker + } + return &defaultDNSClientMaker{} +} + +// defaultDNSClientMaker is the default dnsclientmaker +type defaultDNSClientMaker struct{} + +// Make implements dnsclientmaker.Make. +func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) { + return netx.NewDNSClient(config, URL) +} diff --git a/internal/engine/internal/sessionresolver/clientmaker_test.go b/internal/engine/internal/sessionresolver/clientmaker_test.go new file mode 100644 index 0000000..6db8855 --- /dev/null +++ b/internal/engine/internal/sessionresolver/clientmaker_test.go @@ -0,0 +1,52 @@ +package sessionresolver + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/ooni/probe-cli/v3/internal/engine/netx" +) + +type fakeDNSClientMaker struct { + reso childResolver + err error + savedConfig netx.Config + savedURL string +} + +func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) { + c.savedConfig = config + c.savedURL = URL + return c.reso, c.err +} + +func TestClientMakerWithOverride(t *testing.T) { + m := &fakeDNSClientMaker{err: io.EOF} + reso := &Resolver{dnsClientMaker: m} + out, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestClientDefaultWithCancelledContext(t *testing.T) { + reso := &Resolver{} + re, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query") + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + out, err := re.LookupHost(ctx, "dns.google") + if !errors.Is(err, context.Canceled) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil output") + } +} diff --git a/internal/engine/internal/sessionresolver/codec.go b/internal/engine/internal/sessionresolver/codec.go new file mode 100644 index 0000000..b73f080 --- /dev/null +++ b/internal/engine/internal/sessionresolver/codec.go @@ -0,0 +1,35 @@ +package sessionresolver + +import ( + "encoding/json" +) + +// codec is the codec we use. +type codec interface { + // Encode encodes v as a stream of bytes. + Encode(v interface{}) ([]byte, error) + + // Decode decodes b into a stream of bytes. + Decode(b []byte, v interface{}) error +} + +// getCodec always returns a valid codec. +func (r *Resolver) getCodec() codec { + if r.codec != nil { + return r.codec + } + return &defaultCodec{} +} + +// defaultCodec is the default codec. +type defaultCodec struct{} + +// Decode decodes b into v using the default codec. +func (*defaultCodec) Decode(b []byte, v interface{}) error { + return json.Unmarshal(b, v) +} + +// Encode encodes v using the default codec. +func (*defaultCodec) Encode(v interface{}) ([]byte, error) { + return json.Marshal(v) +} diff --git a/internal/engine/internal/sessionresolver/codec_test.go b/internal/engine/internal/sessionresolver/codec_test.go new file mode 100644 index 0000000..7687770 --- /dev/null +++ b/internal/engine/internal/sessionresolver/codec_test.go @@ -0,0 +1,48 @@ +package sessionresolver + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +type FakeCodec struct { + EncodeData []byte + EncodeErr error + DecodeErr error +} + +func (c *FakeCodec) Encode(v interface{}) ([]byte, error) { + return c.EncodeData, c.EncodeErr +} + +func (c *FakeCodec) Decode(b []byte, v interface{}) error { + return c.DecodeErr +} + +func TestCodecCustom(t *testing.T) { + c := &FakeCodec{} + reso := &Resolver{codec: c} + if r := reso.getCodec(); r != c { + t.Fatal("not the codec we expected") + } +} + +func TestCodecDefault(t *testing.T) { + reso := &Resolver{} + in := resolverinfo{ + URL: "https://dns.google/dns.query", + Score: 0.99, + } + data, err := reso.getCodec().Encode(in) + if err != nil { + t.Fatal(err) + } + var out resolverinfo + if err := reso.getCodec().Decode(data, &out); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(in, out); diff != "" { + t.Fatal(diff) + } +} diff --git a/internal/engine/internal/sessionresolver/dependencies.go b/internal/engine/internal/sessionresolver/dependencies.go new file mode 100644 index 0000000..1dd1515 --- /dev/null +++ b/internal/engine/internal/sessionresolver/dependencies.go @@ -0,0 +1,32 @@ +package sessionresolver + +// KVStore is a generic key-value store. We use it to store +// on disk persistent state used by this package. +type KVStore interface { + // Get gets the value for the given key. + Get(key string) ([]byte, error) + + // Set sets the value of the given key. + Set(key string, value []byte) error +} + +// Logger defines the common logger interface. +type Logger interface { + // Debug emits a debug message. + Debug(msg string) + + // Debugf formats and emits a debug message. + Debugf(format string, v ...interface{}) + + // Info emits an informational message. + Info(msg string) + + // Infof format and emits an informational message. + Infof(format string, v ...interface{}) + + // Warn emits a warning message. + Warn(msg string) + + // Warnf formats and emits a warning message. + Warnf(format string, v ...interface{}) +} diff --git a/internal/engine/internal/sessionresolver/errwrapper.go b/internal/engine/internal/sessionresolver/errwrapper.go new file mode 100644 index 0000000..d343854 --- /dev/null +++ b/internal/engine/internal/sessionresolver/errwrapper.go @@ -0,0 +1,23 @@ +package sessionresolver + +import ( + "errors" + "fmt" +) + +// errwrapper wraps an error to include the URL of the +// resolver that we're currently using. +type errwrapper struct { + error + URL string +} + +// Error implements error.Error. +func (ew *errwrapper) Error() string { + return fmt.Sprintf("<%s> %s", ew.URL, ew.error.Error()) +} + +// Is allows consumers to query for the type of the underlying error. +func (ew *errwrapper) Is(target error) bool { + return errors.Is(ew.error, target) +} diff --git a/internal/engine/internal/sessionresolver/errwrapper_test.go b/internal/engine/internal/sessionresolver/errwrapper_test.go new file mode 100644 index 0000000..44c2e3d --- /dev/null +++ b/internal/engine/internal/sessionresolver/errwrapper_test.go @@ -0,0 +1,24 @@ +package sessionresolver + +import ( + "errors" + "io" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestErrWrapper(t *testing.T) { + ew := &errwrapper{ + error: io.EOF, + URL: "https://dns.quad9.net/dns-query", + } + o := ew.Error() + expect := " EOF" + if diff := cmp.Diff(expect, o); diff != "" { + t.Fatal(diff) + } + if !errors.Is(ew, io.EOF) { + t.Fatal("not the sub-error we expected") + } +} diff --git a/internal/engine/internal/sessionresolver/integration_test.go b/internal/engine/internal/sessionresolver/integration_test.go new file mode 100644 index 0000000..f359aca --- /dev/null +++ b/internal/engine/internal/sessionresolver/integration_test.go @@ -0,0 +1,29 @@ +package sessionresolver_test + +import ( + "context" + "testing" + + "github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver" +) + +func TestSessionResolverGood(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + reso := &sessionresolver.Resolver{} + defer reso.CloseIdleConnections() + if reso.Network() != "sessionresolver" { + t.Fatal("unexpected Network") + } + if reso.Address() != "" { + t.Fatal("unexpected Address") + } + addrs, err := reso.LookupHost(context.Background(), "google.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) < 1 { + t.Fatal("expected some addrs here") + } +} diff --git a/internal/engine/internal/sessionresolver/memkvstore.go b/internal/engine/internal/sessionresolver/memkvstore.go new file mode 100644 index 0000000..daacc1c --- /dev/null +++ b/internal/engine/internal/sessionresolver/memkvstore.go @@ -0,0 +1,43 @@ +package sessionresolver + +import ( + "errors" + "fmt" + "sync" +) + +func (r *Resolver) kvstore() KVStore { + defer r.mu.Unlock() + r.mu.Lock() + if r.KVStore == nil { + r.KVStore = &memkvstore{} + } + return r.KVStore +} + +var errMemkvstoreNotFound = errors.New("memkvstore: not found") + +type memkvstore struct { + m map[string][]byte + mu sync.Mutex +} + +func (kvs *memkvstore) Get(key string) ([]byte, error) { + defer kvs.mu.Unlock() + kvs.mu.Lock() + out, good := kvs.m[key] + if !good { + return nil, fmt.Errorf("%w: %s", errMemkvstoreNotFound, key) + } + return out, nil +} + +func (kvs *memkvstore) Set(key string, value []byte) error { + defer kvs.mu.Unlock() + kvs.mu.Lock() + if kvs.m == nil { + kvs.m = make(map[string][]byte) + } + kvs.m[key] = value + return nil +} diff --git a/internal/engine/internal/sessionresolver/memkvstore_test.go b/internal/engine/internal/sessionresolver/memkvstore_test.go new file mode 100644 index 0000000..b924ef4 --- /dev/null +++ b/internal/engine/internal/sessionresolver/memkvstore_test.go @@ -0,0 +1,47 @@ +package sessionresolver + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestKVStoreCustom(t *testing.T) { + kvs := &memkvstore{} + reso := &Resolver{KVStore: kvs} + o := reso.kvstore() + if o != kvs { + t.Fatal("not the kvstore we expected") + } +} + +func TestMemkvstoreGetNotFound(t *testing.T) { + reso := &Resolver{} + key := "antani" + out, err := reso.kvstore().Get(key) + if !errors.Is(err, errMemkvstoreNotFound) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestMemkvstoreRoundTrip(t *testing.T) { + reso := &Resolver{} + key := []string{"antani", "mascetti"} + value := [][]byte{[]byte(`mascetti`), []byte(`antani`)} + for idx := 0; idx < 2; idx++ { + if err := reso.kvstore().Set(key[idx], value[idx]); err != nil { + t.Fatal(err) + } + out, err := reso.kvstore().Get(key[idx]) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(value[idx], out); diff != "" { + t.Fatal(diff) + } + } +} diff --git a/internal/engine/internal/sessionresolver/resolvermaker.go b/internal/engine/internal/sessionresolver/resolvermaker.go new file mode 100644 index 0000000..95e30f6 --- /dev/null +++ b/internal/engine/internal/sessionresolver/resolvermaker.go @@ -0,0 +1,121 @@ +package sessionresolver + +import ( + "math/rand" + "strings" + "time" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/engine/netx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" +) + +// resolvemaker contains rules for making a resolver. +type resolvermaker struct { + url string + score float64 +} + +// systemResolverURL is the URL of the system resolver. +const systemResolverURL = "system:///" + +// allmakers contains all the makers in a list. We use the http3 +// prefix to indicate we wanna use http3. The code will translate +// this to https and set the proper next options. +var allmakers = []*resolvermaker{{ + url: "https://cloudflare-dns.com/dns-query", +}, { + url: "http3://cloudflare-dns.com/dns-query", +}, { + url: "https://dns.google/dns-query", +}, { + url: "http3://dns.google/dns-query", +}, { + url: "https://dns.quad9.net/dns-query", +}, { + url: "https://doh.powerdns.org/", +}, { + url: systemResolverURL, +}, { + url: "https://mozilla.cloudflare-dns.com/dns-query", +}, { + url: "http3://mozilla.cloudflare-dns.com/dns-query", +}} + +// allbyurl contains all the resolvermakers by URL +var allbyurl map[string]*resolvermaker + +// init fills allbyname and gives a nonzero initial score +// to all resolvers except for the system resolver. We set +// the system resolver score to zero, so that it's less +// likely than other resolvers in this list. +func init() { + allbyurl = make(map[string]*resolvermaker) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + for _, e := range allmakers { + allbyurl[e.url] = e + if e.url != systemResolverURL { + e.score = rng.Float64() + } + } +} + +// byteCounter returns the configured byteCounter or a default +func (r *Resolver) byteCounter() *bytecounter.Counter { + if r.ByteCounter != nil { + return r.ByteCounter + } + return bytecounter.New() +} + +// logger returns the configured logger or a default +func (r *Resolver) logger() Logger { + if r.Logger != nil { + return r.Logger + } + return log.Log +} + +// newresolver creates a new resolver with the given config and URL. This is +// where we expand http3 to https and set the h3 options. +func (r *Resolver) newresolver(URL string) (childResolver, error) { + h3 := strings.HasPrefix(URL, "http3://") + if h3 { + URL = strings.Replace(URL, "http3://", "https://", 1) + } + return r.clientmaker().Make(netx.Config{ + BogonIsError: true, + ByteCounter: r.byteCounter(), + HTTP3Enabled: h3, + Logger: r.logger(), + }, URL) +} + +// getresolver returns a resolver with the given URL. This function caches +// already allocated resolvers so we only allocate them once. +func (r *Resolver) getresolver(URL string) (childResolver, error) { + defer r.mu.Unlock() + r.mu.Lock() + if re, found := r.res[URL]; found == true { + return re, nil // already created + } + re, err := r.newresolver(URL) + if err != nil { + return nil, err // config err? + } + if r.res == nil { + r.res = make(map[string]childResolver) + } + r.res[URL] = re + return re, nil +} + +// closeall closes the cached resolvers. +func (r *Resolver) closeall() { + defer r.mu.Unlock() + r.mu.Lock() + for _, re := range r.res { + re.CloseIdleConnections() + } + r.res = nil +} diff --git a/internal/engine/internal/sessionresolver/resolvermaker_test.go b/internal/engine/internal/sessionresolver/resolvermaker_test.go new file mode 100644 index 0000000..3e58cfc --- /dev/null +++ b/internal/engine/internal/sessionresolver/resolvermaker_test.go @@ -0,0 +1,124 @@ +package sessionresolver + +import ( + "errors" + "strings" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" +) + +func TestDefaultByteCounter(t *testing.T) { + reso := &Resolver{} + bc := reso.byteCounter() + if bc == nil { + t.Fatal("expected non-nil byte counter") + } +} + +func TestDefaultLogger(t *testing.T) { + logger := &log.Logger{} + reso := &Resolver{Logger: logger} + lo := reso.logger() + if lo != logger { + t.Fatal("expected another logger here counter") + } +} + +func TestGetResolverHTTPSStandard(t *testing.T) { + bc := bytecounter.New() + URL := "https://dns.google" + re := &FakeResolver{} + cmk := &fakeDNSClientMaker{reso: re} + reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} + out, err := reso.getresolver(URL) + if err != nil { + t.Fatal(err) + } + if out != re { + t.Fatal("not the result we expected") + } + o2, err := reso.getresolver(URL) + if err != nil { + t.Fatal(err) + } + if out != o2 { + t.Fatal("not the result we expected") + } + reso.closeall() + if re.Closed != true { + t.Fatal("was not closed") + } + if cmk.savedURL != URL { + t.Fatal("not the URL we expected") + } + if cmk.savedConfig.ByteCounter != bc { + t.Fatal("unexpected ByteCounter") + } + if cmk.savedConfig.BogonIsError != true { + t.Fatal("unexpected BogonIsError") + } + if cmk.savedConfig.HTTP3Enabled != false { + t.Fatal("unexpected HTTP3Enabled") + } + if cmk.savedConfig.Logger != log.Log { + t.Fatal("unexpected Log") + } +} + +func TestGetResolverHTTP3(t *testing.T) { + bc := bytecounter.New() + URL := "http3://dns.google" + re := &FakeResolver{} + cmk := &fakeDNSClientMaker{reso: re} + reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} + out, err := reso.getresolver(URL) + if err != nil { + t.Fatal(err) + } + if out != re { + t.Fatal("not the result we expected") + } + o2, err := reso.getresolver(URL) + if err != nil { + t.Fatal(err) + } + if out != o2 { + t.Fatal("not the result we expected") + } + reso.closeall() + if re.Closed != true { + t.Fatal("was not closed") + } + if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) { + t.Fatal("not the URL we expected") + } + if cmk.savedConfig.ByteCounter != bc { + t.Fatal("unexpected ByteCounter") + } + if cmk.savedConfig.BogonIsError != true { + t.Fatal("unexpected BogonIsError") + } + if cmk.savedConfig.HTTP3Enabled != true { + t.Fatal("unexpected HTTP3Enabled") + } + if cmk.savedConfig.Logger != log.Log { + t.Fatal("unexpected Log") + } +} + +func TestGetResolverInvalidURL(t *testing.T) { + bc := bytecounter.New() + URL := "http3://dns.google" + errMocked := errors.New("mocked error") + cmk := &fakeDNSClientMaker{err: errMocked} + reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc} + out, err := reso.getresolver(URL) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("not the result we expected") + } +} diff --git a/internal/engine/internal/sessionresolver/sessionresolver.go b/internal/engine/internal/sessionresolver/sessionresolver.go index 190d5f1..8613a21 100644 --- a/internal/engine/internal/sessionresolver/sessionresolver.go +++ b/internal/engine/internal/sessionresolver/sessionresolver.go @@ -1,85 +1,134 @@ // Package sessionresolver contains the resolver used by the session. This -// resolver uses Powerdns DoH by default and falls back on the system -// provided resolver if Powerdns DoH is not working. +// resolver will try to figure out which is the best service for running +// domain name resolutions and will consistently use it. +// +// Occasionally this code will also swap the best resolver with other +// ~good resolvers to give them a chance to perform. +// +// The penalty/reward mechanism is strongly derivative, so the code should +// adapt ~quickly to changing network conditions. Occasionally, we will +// have longer resolutions when trying out other resolvers. +// +// At the beginning we randomize the known resolvers so that we do not +// have any preferential ordering. The initial resolutions may be slower +// if there are many issues with resolvers. +// +// The system resolver is given the lowest priority at the beginning +// but it will of course be the most popular resolver if anything else +// is failing us. (We will still occasionally probe for other working +// resolvers and increase their score on success.) package sessionresolver import ( "context" + "encoding/json" + "errors" "fmt" + "math/rand" + "sync" "time" - "github.com/ooni/probe-cli/v3/internal/engine/atomicx" - "github.com/ooni/probe-cli/v3/internal/engine/netx" + "github.com/ooni/probe-cli/v3/internal/engine/internal/multierror" + "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/runtimex" ) -// Resolver is the session resolver. +// Resolver is the session resolver. You should create an instance of +// this structure and use it in session.go. type Resolver struct { - Primary netx.DNSClient - PrimaryFailure *atomicx.Int64 - PrimaryQuery *atomicx.Int64 - Fallback netx.DNSClient - FallbackFailure *atomicx.Int64 - FallbackQuery *atomicx.Int64 + ByteCounter *bytecounter.Counter // optional + KVStore KVStore // optional + Logger Logger // optional + codec codec + dnsClientMaker dnsclientmaker + mu sync.Mutex + once sync.Once + res map[string]childResolver } -// New creates a new session resolver. -func New(config netx.Config) *Resolver { - primary, err := netx.NewDNSClientWithOverrides(config, - "https://cloudflare.com/dns-query", "dns.cloudflare.com", "", "") - runtimex.PanicOnError(err, "cannot create dns over https resolver") - fallback, err := netx.NewDNSClient(config, "system:///") - runtimex.PanicOnError(err, "cannot create system resolver") - return &Resolver{ - Primary: primary, - PrimaryFailure: atomicx.NewInt64(), - PrimaryQuery: atomicx.NewInt64(), - Fallback: fallback, - FallbackFailure: atomicx.NewInt64(), - FallbackQuery: atomicx.NewInt64(), - } -} - -// CloseIdleConnections closes the idle connections, if any +// CloseIdleConnections closes the idle connections, if any. This +// function is guaranteed to be idempotent. func (r *Resolver) CloseIdleConnections() { - r.Primary.CloseIdleConnections() - r.Fallback.CloseIdleConnections() + r.once.Do(r.closeall) } // Stats returns stats about the session resolver. func (r *Resolver) Stats() string { - return fmt.Sprintf("sessionresolver: failure rate: primary: %d/%d; fallback: %d/%d", - r.PrimaryFailure.Load(), r.PrimaryQuery.Load(), - r.FallbackFailure.Load(), r.FallbackQuery.Load()) + data, err := json.Marshal(r.readstatedefault()) + runtimex.PanicOnError(err, "json.Marshal should not fail here") + return fmt.Sprintf("sessionresolver: %s", string(data)) } -// LookupHost implements Resolver.LookupHost +// ErrLookupHost indicates that LookupHost failed. +var ErrLookupHost = errors.New("sessionresolver: LookupHost failed") + +// LookupHost implements Resolver.LookupHost. This function returns a +// multierror.Union error on failure, so you can see individual errors +// and get a better picture of what's been going wrong. func (r *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { - // Algorithm similar to Firefox TRR2 mode. See: - // https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox - // We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side - // and therefore see to use DoH more often. - r.PrimaryQuery.Add(1) - trr2, cancel := context.WithTimeout(ctx, 4*time.Second) - defer cancel() - addrs, err := r.Primary.LookupHost(trr2, hostname) - if err != nil { - r.PrimaryFailure.Add(1) - r.FallbackQuery.Add(1) - addrs, err = r.Fallback.LookupHost(ctx, hostname) - if err != nil { - r.FallbackFailure.Add(1) + state := r.readstatedefault() + r.maybeConfusion(state, time.Now().UnixNano()) + defer r.writestate(state) + me := multierror.New(ErrLookupHost) + for _, e := range state { + addrs, err := r.lookupHost(ctx, e, hostname) + if err == nil { + return addrs, nil } + me.Add(&errwrapper{error: err, URL: e.URL}) } - return addrs, err + return nil, me } -// Network implements Resolver.Network +func (r *Resolver) lookupHost(ctx context.Context, ri *resolverinfo, hostname string) ([]string, error) { + const ewma = 0.9 // the last sample is very important + re, err := r.getresolver(ri.URL) + if err != nil { + r.logger().Warnf("sessionresolver: getresolver: %s", err.Error()) + ri.Score = 0 // this is a hard error + return nil, err + } + addrs, err := r.timeLimitedLookup(ctx, re, hostname) + if err == nil { + r.logger().Infof("sessionresolver: %s... %v", ri.URL, nil) + ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score + return addrs, nil + } + r.logger().Warnf("sessionresolver: %s... %s", ri.URL, err.Error()) + ri.Score = ewma*0.0 + (1-ewma)*ri.Score // decrease score + return nil, err +} + +// maybeConfusion will rearrange the first elements of the vector +// with low probability, so giving other resolvers a chance +// to run and show that they are also viable. We do not fully +// reorder the vector because that could lead to long runtimes. +// +// The return value is only meaningful for testing. +func (r *Resolver) maybeConfusion(state []*resolverinfo, seed int64) int { + rng := rand.New(rand.NewSource(seed)) + const confusion = 0.3 + if rng.Float64() >= confusion { + return -1 + } + switch len(state) { + case 0, 1: // nothing to do + return 0 + case 2: + state[0], state[1] = state[1], state[0] + return 2 + default: + state[0], state[2] = state[2], state[0] + return 3 + } +} + +// Network implements Resolver.Network. func (r *Resolver) Network() string { return "sessionresolver" } -// Address implements Resolver.Address +// Address implements Resolver.Address. func (r *Resolver) Address() string { return "" } diff --git a/internal/engine/internal/sessionresolver/sessionresolver_test.go b/internal/engine/internal/sessionresolver/sessionresolver_test.go index 727dd26..b83cde8 100644 --- a/internal/engine/internal/sessionresolver/sessionresolver_test.go +++ b/internal/engine/internal/sessionresolver/sessionresolver_test.go @@ -1,31 +1,249 @@ -package sessionresolver_test +package sessionresolver import ( "context" + "errors" + "net" "strings" "testing" - "github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver" - "github.com/ooni/probe-cli/v3/internal/engine/netx" + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/internal/multierror" ) -func TestFallbackWorks(t *testing.T) { - reso := sessionresolver.New(netx.Config{}) - defer reso.CloseIdleConnections() +func TestNetworkWorks(t *testing.T) { + reso := &Resolver{} if reso.Network() != "sessionresolver" { - t.Fatal("unexpected Network") + t.Fatal("unexpected value returned by Network") } +} + +func TestAddressWorks(t *testing.T) { + reso := &Resolver{} if reso.Address() != "" { - t.Fatal("unexpected Address") + t.Fatal("unexpected value returned by Address") } - addrs, err := reso.LookupHost(context.Background(), "antani.ooni.nu") - if err == nil || !strings.HasSuffix(err.Error(), "no such host") { - t.Fatal("not the error we expected") +} + +func TestTypicalUsageWithFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + reso := &Resolver{} + addrs, err := reso.LookupHost(ctx, "ooni.org") + if !errors.Is(err, ErrLookupHost) { + t.Fatal("not the error we expected", err) + } + var me *multierror.Union + if !errors.As(err, &me) { + t.Fatal("cannot convert error") + } + for _, child := range me.Children { + // net.DNSError does not include the underlying error + // but just a string representing the error. This + // means that we need to go down hunting what's the + // real error that occurred and use more verbose code. + { + var errWrapper *errwrapper + if !errors.As(child, &errWrapper) { + t.Fatal("not an instance of errwrapper") + } + var dnsError *net.DNSError + if errors.As(errWrapper.error, &dnsError) { + if !strings.HasSuffix(dnsError.Err, "operation was canceled") { + t.Fatal("not the error we expected", dnsError.Err) + } + continue + } + } + // otherwise just unwrap and check whether it's + // a real context.Canceled error. + if !errors.Is(child, context.Canceled) { + t.Fatal("unexpected sub-error", child) + } + } + if addrs != nil { + t.Fatal("expected nil here") + } + if len(reso.res) < 1 { + t.Fatal("expected to see some resolvers here") + } + if reso.Stats() == "" { + t.Fatal("expected to see some string returned by stats") + } + reso.CloseIdleConnections() + if len(reso.res) != 0 { + t.Fatal("expected to see no resolvers after CloseIdleConnections") + } +} + +func TestTypicalUsageWithSuccess(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} + ctx := context.Background() + reso := &Resolver{ + dnsClientMaker: &fakeDNSClientMaker{ + reso: &FakeResolver{Data: expected}, + }, + } + addrs, err := reso.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, addrs); diff != "" { + t.Fatal(diff) + } +} + +func TestLittleLLookupHostWithInvalidURL(t *testing.T) { + reso := &Resolver{} + ctx := context.Background() + ri := &resolverinfo{URL: "\t\t\t", Score: 0.99} + addrs, err := reso.lookupHost(ctx, ri, "ooni.org") + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) } if addrs != nil { t.Fatal("expected nil addrs here") } - if reso.PrimaryFailure.Load() != 1 || reso.FallbackFailure.Load() != 1 { - t.Fatal("not the counters we expected to see here") + if ri.Score != 0 { + t.Fatal("unexpected ri.Score", ri.Score) + } +} + +func TestLittleLLookupHostWithSuccess(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} + reso := &Resolver{ + dnsClientMaker: &fakeDNSClientMaker{ + reso: &FakeResolver{Data: expected}, + }, + } + ctx := context.Background() + ri := &resolverinfo{URL: "dot://dns-nonexistent.ooni.org", Score: 0.1} + addrs, err := reso.lookupHost(ctx, ri, "dns.google") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, addrs); diff != "" { + t.Fatal(diff) + } + if ri.Score < 0.88 || ri.Score > 0.92 { + t.Fatal("unexpected score", ri.Score) + } +} + +func TestLittleLLookupHostWithFailure(t *testing.T) { + errMocked := errors.New("mocked error") + reso := &Resolver{ + dnsClientMaker: &fakeDNSClientMaker{ + reso: &FakeResolver{Err: errMocked}, + }, + } + ctx := context.Background() + ri := &resolverinfo{URL: "dot://dns-nonexistent.ooni.org", Score: 0.95} + addrs, err := reso.lookupHost(ctx, ri, "dns.google") + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil addrs here") + } + if ri.Score < 0.094 || ri.Score > 0.096 { + t.Fatal("unexpected score", ri.Score) + } +} + +func TestMaybeConfusionNoConfusion(t *testing.T) { + reso := &Resolver{} + rv := reso.maybeConfusion(nil, 0) + if rv != -1 { + t.Fatal("unexpected return value", rv) + } +} + +func TestMaybeConfusionNoArray(t *testing.T) { + reso := &Resolver{} + rv := reso.maybeConfusion(nil, 11) + if rv != 0 { + t.Fatal("unexpected return value", rv) + } +} + +func TestMaybeConfusionSingleEntry(t *testing.T) { + reso := &Resolver{} + state := []*resolverinfo{{}} + rv := reso.maybeConfusion(state, 11) + if rv != 0 { + t.Fatal("unexpected return value", rv) + } +} + +func TestMaybeConfusionTwoEntries(t *testing.T) { + reso := &Resolver{} + state := []*resolverinfo{{ + Score: 0.8, + URL: "https://dns.google/dns-query", + }, { + Score: 0.4, + URL: "http3://dns.google/dns-query", + }} + rv := reso.maybeConfusion(state, 11) + if rv != 2 { + t.Fatal("unexpected return value", rv) + } + if state[0].Score != 0.4 { + t.Fatal("unexpected state[0].Score") + } + if state[0].URL != "http3://dns.google/dns-query" { + t.Fatal("unexpected state[0].URL") + } + if state[1].Score != 0.8 { + t.Fatal("unexpected state[1].Score") + } + if state[1].URL != "https://dns.google/dns-query" { + t.Fatal("unexpected state[1].URL") + } +} + +func TestMaybeConfusionManyEntries(t *testing.T) { + reso := &Resolver{} + state := []*resolverinfo{{ + Score: 0.8, + URL: "https://dns.google/dns-query", + }, { + Score: 0.4, + URL: "http3://dns.google/dns-query", + }, { + Score: 0.1, + URL: "system:///", + }, { + Score: 0.01, + URL: "dot://dns.google", + }} + rv := reso.maybeConfusion(state, 11) + if rv != 3 { + t.Fatal("unexpected return value", rv) + } + if state[0].Score != 0.1 { + t.Fatal("unexpected state[0].Score") + } + if state[0].URL != "system:///" { + t.Fatal("unexpected state[0].URL") + } + if state[1].Score != 0.4 { + t.Fatal("unexpected state[1].Score") + } + if state[1].URL != "http3://dns.google/dns-query" { + t.Fatal("unexpected state[1].URL") + } + if state[2].Score != 0.8 { + t.Fatal("unexpected state[2].Score") + } + if state[2].URL != "https://dns.google/dns-query" { + t.Fatal("unexpected state[2].URL") + } + if state[3].Score != 0.01 { + t.Fatal("unexpected state[3].Score") + } + if state[3].URL != "dot://dns.google" { + t.Fatal("unexpected state[3].URL") } } diff --git a/internal/engine/internal/sessionresolver/state.go b/internal/engine/internal/sessionresolver/state.go new file mode 100644 index 0000000..b725f9a --- /dev/null +++ b/internal/engine/internal/sessionresolver/state.go @@ -0,0 +1,93 @@ +package sessionresolver + +import ( + "errors" + "sort" +) + +// storekey is the key used by the key value store to store +// the state required by this package. +const storekey = "sessionresolver.state" + +// resolverinfo contains info about a resolver. +type resolverinfo struct { + // URL is the URL of a resolver. + URL string + + // Score is the score of a resolver. + Score float64 +} + +// readstate reads the resolver state from disk +func (r *Resolver) readstate() ([]*resolverinfo, error) { + data, err := r.kvstore().Get(storekey) + if err != nil { + return nil, err + } + var ri []*resolverinfo + if err := r.getCodec().Decode(data, &ri); err != nil { + return nil, err + } + return ri, nil +} + +// errNoEntries indicates that no entry remained after we pruned +// all the available entries in readstateandprune. +var errNoEntries = errors.New("sessionresolver: no available entries") + +// readstateandprune reads the state from disk and removes all the +// entries that we don't actually support. +func (r *Resolver) readstateandprune() ([]*resolverinfo, error) { + ri, err := r.readstate() + if err != nil { + return nil, err + } + var out []*resolverinfo + for _, e := range ri { + if _, found := allbyurl[e.URL]; !found { + continue // we don't support this specific entry + } + out = append(out, e) + } + if len(out) <= 0 { + return nil, errNoEntries + } + return out, nil +} + +// sortstate sorts the state by descending score +func sortstate(ri []*resolverinfo) { + sort.SliceStable(ri, func(i, j int) bool { + return ri[i].Score >= ri[j].Score + }) +} + +// readstatedefault reads the state from disk and merges the state +// so that all supported entries are represented. +func (r *Resolver) readstatedefault() []*resolverinfo { + ri, _ := r.readstateandprune() + here := make(map[string]bool) + for _, e := range ri { + here[e.URL] = true // record what we already have + } + for _, e := range allmakers { + if _, found := here[e.url]; found { + continue // already here so no need to add + } + ri = append(ri, &resolverinfo{ + URL: e.url, + Score: e.score, + }) + } + sortstate(ri) + return ri +} + +// writestate writes the state on the kvstore. +func (r *Resolver) writestate(ri []*resolverinfo) error { + data, err := r.getCodec().Encode(ri) + if err != nil { + return err + } + return r.kvstore().Set(storekey, data) +} diff --git a/internal/engine/internal/sessionresolver/state_test.go b/internal/engine/internal/sessionresolver/state_test.go new file mode 100644 index 0000000..168ee5c --- /dev/null +++ b/internal/engine/internal/sessionresolver/state_test.go @@ -0,0 +1,120 @@ +package sessionresolver + +import ( + "errors" + "testing" +) + +func TestReadStateNothingInKVStore(t *testing.T) { + reso := &Resolver{KVStore: &memkvstore{}} + out, err := reso.readstate() + if !errors.Is(err, errMemkvstoreNotFound) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestReadStateDecodeError(t *testing.T) { + errMocked := errors.New("mocked error") + reso := &Resolver{ + KVStore: &memkvstore{}, + codec: &FakeCodec{DecodeErr: errMocked}, + } + if err := reso.KVStore.Set(storekey, []byte(`[]`)); err != nil { + t.Fatal(err) + } + out, err := reso.readstate() + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestReadStateAndPruneReadStateError(t *testing.T) { + reso := &Resolver{KVStore: &memkvstore{}} + out, err := reso.readstateandprune() + if !errors.Is(err, errMemkvstoreNotFound) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestReadStateAndPruneWithUnsupportedEntries(t *testing.T) { + reso := &Resolver{KVStore: &memkvstore{}} + var in []*resolverinfo + in = append(in, &resolverinfo{}) + if err := reso.writestate(in); err != nil { + t.Fatal(err) + } + out, err := reso.readstateandprune() + if !errors.Is(err, errNoEntries) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestReadStateDefaultWithMissingEntries(t *testing.T) { + reso := &Resolver{KVStore: &memkvstore{}} + // let us simulate that we have just one entry here + existingURL := "https://dns.google/dns-query" + existingScore := 0.88 + var in []*resolverinfo + in = append(in, &resolverinfo{ + URL: existingURL, + Score: existingScore, + }) + if err := reso.writestate(in); err != nil { + t.Fatal(err) + } + // let us seee what we read + out := reso.readstatedefault() + if len(out) < 1 { + t.Fatal("expected non-empty output") + } + keys := make(map[string]bool) + var found bool + for _, e := range out { + keys[e.URL] = true + if e.URL == existingURL { + if e.Score != existingScore { + t.Fatal("the score is not what we expected") + } + found = true + } + } + if !found { + t.Fatal("did not found the pre-loaded URL") + } + for k := range allbyurl { + if _, found := keys[k]; !found { + t.Fatal("missing key", k) + } + } +} + +func TestWriteStateCannotSerialize(t *testing.T) { + errMocked := errors.New("mocked error") + reso := &Resolver{ + codec: &FakeCodec{ + EncodeErr: errMocked, + }, + } + existingURL := "https://dns.google/dns-query" + existingScore := 0.88 + var in []*resolverinfo + in = append(in, &resolverinfo{ + URL: existingURL, + Score: existingScore, + }) + if err := reso.writestate(in); !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } +} diff --git a/internal/engine/session.go b/internal/engine/session.go index b4ad9f8..f6c8e69 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -109,7 +109,11 @@ func NewSession(config SessionConfig) (*Session, error) { BogonIsError: true, Logger: sess.logger, } - sess.resolver = sessionresolver.New(httpConfig) + sess.resolver = &sessionresolver.Resolver{ + ByteCounter: sess.byteCounter, + KVStore: config.KVStore, + Logger: sess.logger, + } httpConfig.FullResolver = sess.resolver httpConfig.ProxyURL = config.ProxyURL // no need to proxy the resolver sess.httpDefaultTransport = netx.NewHTTPTransport(httpConfig)