refactor(sessionresolver): adapt to changing network conditions (#238)

* feat(sessionresolver): try many and use what works

* fix(sessionresolver): make sure we can use quic

* fix: the config struct is unnecessary

* fix: make kvstore optional

* feat: write simple integration test

* feat: start adding tests

* feat: continue writing tests

* fix(sessionresolver): add more unit tests

* fix(sessionresolver): finish adding tests

* refactor(sessionresolver): changes after code review
This commit is contained in:
Simone Basso 2021-03-03 11:28:39 +01:00 committed by GitHub
parent 12e1164940
commit 034db78f94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1260 additions and 66 deletions

View File

@ -0,0 +1,27 @@
package sessionresolver
import (
"context"
"time"
)
// childResolver is the DNS client that this package uses
// to perform individual domain name resolutions.
type childResolver interface {
// LookupHost performs a DNS lookup.
LookupHost(ctx context.Context, domain string) ([]string, error)
// CloseIdleConnections closes idle connections.
CloseIdleConnections()
}
// timeLimitedLookup performs a time-limited lookup using the given re.
func (r *Resolver) timeLimitedLookup(ctx context.Context, re childResolver, hostname string) ([]string, error) {
// Algorithm similar to Firefox TRR2 mode. See:
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
// and therefore see to use DoH more often.
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
return re.LookupHost(ctx, hostname)
}

View File

@ -0,0 +1,80 @@
package sessionresolver
import (
"context"
"errors"
"io"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
type FakeResolver struct {
Closed bool
Data []string
Err error
Sleep time.Duration
}
func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
select {
case <-time.After(r.Sleep):
return r.Data, r.Err
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (r *FakeResolver) CloseIdleConnections() {
r.Closed = true
}
func TestTimeLimitedLookupSuccess(t *testing.T) {
reso := &Resolver{}
re := &FakeResolver{
Data: []string{"8.8.8.8", "8.8.4.4"},
}
ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(re.Data, out); diff != "" {
t.Fatal(diff)
}
}
func TestTimeLimitedLookupFailure(t *testing.T) {
reso := &Resolver{}
re := &FakeResolver{
Err: io.EOF,
}
ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
if !errors.Is(err, re.Err) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestTimeLimitedLookupWillTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode")
}
reso := &Resolver{}
re := &FakeResolver{
Err: io.EOF,
Sleep: 20 * time.Second,
}
ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}

View File

@ -0,0 +1,25 @@
package sessionresolver
import "github.com/ooni/probe-cli/v3/internal/engine/netx"
// dnsclientmaker makes a new resolver.
type dnsclientmaker interface {
// Make makes a new resolver.
Make(config netx.Config, URL string) (childResolver, error)
}
// clientmaker returns a valid dnsclientmaker
func (r *Resolver) clientmaker() dnsclientmaker {
if r.dnsClientMaker != nil {
return r.dnsClientMaker
}
return &defaultDNSClientMaker{}
}
// defaultDNSClientMaker is the default dnsclientmaker
type defaultDNSClientMaker struct{}
// Make implements dnsclientmaker.Make.
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) {
return netx.NewDNSClient(config, URL)
}

View File

@ -0,0 +1,52 @@
package sessionresolver
import (
"context"
"errors"
"io"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx"
)
type fakeDNSClientMaker struct {
reso childResolver
err error
savedConfig netx.Config
savedURL string
}
func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) {
c.savedConfig = config
c.savedURL = URL
return c.reso, c.err
}
func TestClientMakerWithOverride(t *testing.T) {
m := &fakeDNSClientMaker{err: io.EOF}
reso := &Resolver{dnsClientMaker: m}
out, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestClientDefaultWithCancelledContext(t *testing.T) {
reso := &Resolver{}
re, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
out, err := re.LookupHost(ctx, "dns.google")
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil output")
}
}

View File

@ -0,0 +1,35 @@
package sessionresolver
import (
"encoding/json"
)
// codec is the codec we use.
type codec interface {
// Encode encodes v as a stream of bytes.
Encode(v interface{}) ([]byte, error)
// Decode decodes b into a stream of bytes.
Decode(b []byte, v interface{}) error
}
// getCodec always returns a valid codec.
func (r *Resolver) getCodec() codec {
if r.codec != nil {
return r.codec
}
return &defaultCodec{}
}
// defaultCodec is the default codec.
type defaultCodec struct{}
// Decode decodes b into v using the default codec.
func (*defaultCodec) Decode(b []byte, v interface{}) error {
return json.Unmarshal(b, v)
}
// Encode encodes v using the default codec.
func (*defaultCodec) Encode(v interface{}) ([]byte, error) {
return json.Marshal(v)
}

View File

@ -0,0 +1,48 @@
package sessionresolver
import (
"testing"
"github.com/google/go-cmp/cmp"
)
type FakeCodec struct {
EncodeData []byte
EncodeErr error
DecodeErr error
}
func (c *FakeCodec) Encode(v interface{}) ([]byte, error) {
return c.EncodeData, c.EncodeErr
}
func (c *FakeCodec) Decode(b []byte, v interface{}) error {
return c.DecodeErr
}
func TestCodecCustom(t *testing.T) {
c := &FakeCodec{}
reso := &Resolver{codec: c}
if r := reso.getCodec(); r != c {
t.Fatal("not the codec we expected")
}
}
func TestCodecDefault(t *testing.T) {
reso := &Resolver{}
in := resolverinfo{
URL: "https://dns.google/dns.query",
Score: 0.99,
}
data, err := reso.getCodec().Encode(in)
if err != nil {
t.Fatal(err)
}
var out resolverinfo
if err := reso.getCodec().Decode(data, &out); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(in, out); diff != "" {
t.Fatal(diff)
}
}

View File

@ -0,0 +1,32 @@
package sessionresolver
// KVStore is a generic key-value store. We use it to store
// on disk persistent state used by this package.
type KVStore interface {
// Get gets the value for the given key.
Get(key string) ([]byte, error)
// Set sets the value of the given key.
Set(key string, value []byte) error
}
// Logger defines the common logger interface.
type Logger interface {
// Debug emits a debug message.
Debug(msg string)
// Debugf formats and emits a debug message.
Debugf(format string, v ...interface{})
// Info emits an informational message.
Info(msg string)
// Infof format and emits an informational message.
Infof(format string, v ...interface{})
// Warn emits a warning message.
Warn(msg string)
// Warnf formats and emits a warning message.
Warnf(format string, v ...interface{})
}

View File

@ -0,0 +1,23 @@
package sessionresolver
import (
"errors"
"fmt"
)
// errwrapper wraps an error to include the URL of the
// resolver that we're currently using.
type errwrapper struct {
error
URL string
}
// Error implements error.Error.
func (ew *errwrapper) Error() string {
return fmt.Sprintf("<%s> %s", ew.URL, ew.error.Error())
}
// Is allows consumers to query for the type of the underlying error.
func (ew *errwrapper) Is(target error) bool {
return errors.Is(ew.error, target)
}

View File

@ -0,0 +1,24 @@
package sessionresolver
import (
"errors"
"io"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestErrWrapper(t *testing.T) {
ew := &errwrapper{
error: io.EOF,
URL: "https://dns.quad9.net/dns-query",
}
o := ew.Error()
expect := "<https://dns.quad9.net/dns-query> EOF"
if diff := cmp.Diff(expect, o); diff != "" {
t.Fatal(diff)
}
if !errors.Is(ew, io.EOF) {
t.Fatal("not the sub-error we expected")
}
}

View File

@ -0,0 +1,29 @@
package sessionresolver_test
import (
"context"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver"
)
func TestSessionResolverGood(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode")
}
reso := &sessionresolver.Resolver{}
defer reso.CloseIdleConnections()
if reso.Network() != "sessionresolver" {
t.Fatal("unexpected Network")
}
if reso.Address() != "" {
t.Fatal("unexpected Address")
}
addrs, err := reso.LookupHost(context.Background(), "google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) < 1 {
t.Fatal("expected some addrs here")
}
}

View File

@ -0,0 +1,43 @@
package sessionresolver
import (
"errors"
"fmt"
"sync"
)
func (r *Resolver) kvstore() KVStore {
defer r.mu.Unlock()
r.mu.Lock()
if r.KVStore == nil {
r.KVStore = &memkvstore{}
}
return r.KVStore
}
var errMemkvstoreNotFound = errors.New("memkvstore: not found")
type memkvstore struct {
m map[string][]byte
mu sync.Mutex
}
func (kvs *memkvstore) Get(key string) ([]byte, error) {
defer kvs.mu.Unlock()
kvs.mu.Lock()
out, good := kvs.m[key]
if !good {
return nil, fmt.Errorf("%w: %s", errMemkvstoreNotFound, key)
}
return out, nil
}
func (kvs *memkvstore) Set(key string, value []byte) error {
defer kvs.mu.Unlock()
kvs.mu.Lock()
if kvs.m == nil {
kvs.m = make(map[string][]byte)
}
kvs.m[key] = value
return nil
}

View File

@ -0,0 +1,47 @@
package sessionresolver
import (
"errors"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestKVStoreCustom(t *testing.T) {
kvs := &memkvstore{}
reso := &Resolver{KVStore: kvs}
o := reso.kvstore()
if o != kvs {
t.Fatal("not the kvstore we expected")
}
}
func TestMemkvstoreGetNotFound(t *testing.T) {
reso := &Resolver{}
key := "antani"
out, err := reso.kvstore().Get(key)
if !errors.Is(err, errMemkvstoreNotFound) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestMemkvstoreRoundTrip(t *testing.T) {
reso := &Resolver{}
key := []string{"antani", "mascetti"}
value := [][]byte{[]byte(`mascetti`), []byte(`antani`)}
for idx := 0; idx < 2; idx++ {
if err := reso.kvstore().Set(key[idx], value[idx]); err != nil {
t.Fatal(err)
}
out, err := reso.kvstore().Get(key[idx])
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(value[idx], out); diff != "" {
t.Fatal(diff)
}
}
}

View File

@ -0,0 +1,121 @@
package sessionresolver
import (
"math/rand"
"strings"
"time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
)
// resolvemaker contains rules for making a resolver.
type resolvermaker struct {
url string
score float64
}
// systemResolverURL is the URL of the system resolver.
const systemResolverURL = "system:///"
// allmakers contains all the makers in a list. We use the http3
// prefix to indicate we wanna use http3. The code will translate
// this to https and set the proper next options.
var allmakers = []*resolvermaker{{
url: "https://cloudflare-dns.com/dns-query",
}, {
url: "http3://cloudflare-dns.com/dns-query",
}, {
url: "https://dns.google/dns-query",
}, {
url: "http3://dns.google/dns-query",
}, {
url: "https://dns.quad9.net/dns-query",
}, {
url: "https://doh.powerdns.org/",
}, {
url: systemResolverURL,
}, {
url: "https://mozilla.cloudflare-dns.com/dns-query",
}, {
url: "http3://mozilla.cloudflare-dns.com/dns-query",
}}
// allbyurl contains all the resolvermakers by URL
var allbyurl map[string]*resolvermaker
// init fills allbyname and gives a nonzero initial score
// to all resolvers except for the system resolver. We set
// the system resolver score to zero, so that it's less
// likely than other resolvers in this list.
func init() {
allbyurl = make(map[string]*resolvermaker)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for _, e := range allmakers {
allbyurl[e.url] = e
if e.url != systemResolverURL {
e.score = rng.Float64()
}
}
}
// byteCounter returns the configured byteCounter or a default
func (r *Resolver) byteCounter() *bytecounter.Counter {
if r.ByteCounter != nil {
return r.ByteCounter
}
return bytecounter.New()
}
// logger returns the configured logger or a default
func (r *Resolver) logger() Logger {
if r.Logger != nil {
return r.Logger
}
return log.Log
}
// newresolver creates a new resolver with the given config and URL. This is
// where we expand http3 to https and set the h3 options.
func (r *Resolver) newresolver(URL string) (childResolver, error) {
h3 := strings.HasPrefix(URL, "http3://")
if h3 {
URL = strings.Replace(URL, "http3://", "https://", 1)
}
return r.clientmaker().Make(netx.Config{
BogonIsError: true,
ByteCounter: r.byteCounter(),
HTTP3Enabled: h3,
Logger: r.logger(),
}, URL)
}
// getresolver returns a resolver with the given URL. This function caches
// already allocated resolvers so we only allocate them once.
func (r *Resolver) getresolver(URL string) (childResolver, error) {
defer r.mu.Unlock()
r.mu.Lock()
if re, found := r.res[URL]; found == true {
return re, nil // already created
}
re, err := r.newresolver(URL)
if err != nil {
return nil, err // config err?
}
if r.res == nil {
r.res = make(map[string]childResolver)
}
r.res[URL] = re
return re, nil
}
// closeall closes the cached resolvers.
func (r *Resolver) closeall() {
defer r.mu.Unlock()
r.mu.Lock()
for _, re := range r.res {
re.CloseIdleConnections()
}
r.res = nil
}

View File

@ -0,0 +1,124 @@
package sessionresolver
import (
"errors"
"strings"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
)
func TestDefaultByteCounter(t *testing.T) {
reso := &Resolver{}
bc := reso.byteCounter()
if bc == nil {
t.Fatal("expected non-nil byte counter")
}
}
func TestDefaultLogger(t *testing.T) {
logger := &log.Logger{}
reso := &Resolver{Logger: logger}
lo := reso.logger()
if lo != logger {
t.Fatal("expected another logger here counter")
}
}
func TestGetResolverHTTPSStandard(t *testing.T) {
bc := bytecounter.New()
URL := "https://dns.google"
re := &FakeResolver{}
cmk := &fakeDNSClientMaker{reso: re}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL)
if err != nil {
t.Fatal(err)
}
if out != re {
t.Fatal("not the result we expected")
}
o2, err := reso.getresolver(URL)
if err != nil {
t.Fatal(err)
}
if out != o2 {
t.Fatal("not the result we expected")
}
reso.closeall()
if re.Closed != true {
t.Fatal("was not closed")
}
if cmk.savedURL != URL {
t.Fatal("not the URL we expected")
}
if cmk.savedConfig.ByteCounter != bc {
t.Fatal("unexpected ByteCounter")
}
if cmk.savedConfig.BogonIsError != true {
t.Fatal("unexpected BogonIsError")
}
if cmk.savedConfig.HTTP3Enabled != false {
t.Fatal("unexpected HTTP3Enabled")
}
if cmk.savedConfig.Logger != log.Log {
t.Fatal("unexpected Log")
}
}
func TestGetResolverHTTP3(t *testing.T) {
bc := bytecounter.New()
URL := "http3://dns.google"
re := &FakeResolver{}
cmk := &fakeDNSClientMaker{reso: re}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL)
if err != nil {
t.Fatal(err)
}
if out != re {
t.Fatal("not the result we expected")
}
o2, err := reso.getresolver(URL)
if err != nil {
t.Fatal(err)
}
if out != o2 {
t.Fatal("not the result we expected")
}
reso.closeall()
if re.Closed != true {
t.Fatal("was not closed")
}
if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) {
t.Fatal("not the URL we expected")
}
if cmk.savedConfig.ByteCounter != bc {
t.Fatal("unexpected ByteCounter")
}
if cmk.savedConfig.BogonIsError != true {
t.Fatal("unexpected BogonIsError")
}
if cmk.savedConfig.HTTP3Enabled != true {
t.Fatal("unexpected HTTP3Enabled")
}
if cmk.savedConfig.Logger != log.Log {
t.Fatal("unexpected Log")
}
}
func TestGetResolverInvalidURL(t *testing.T) {
bc := bytecounter.New()
URL := "http3://dns.google"
errMocked := errors.New("mocked error")
cmk := &fakeDNSClientMaker{err: errMocked}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL)
if !errors.Is(err, errMocked) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("not the result we expected")
}
}

View File

@ -1,85 +1,134 @@
// Package sessionresolver contains the resolver used by the session. This // Package sessionresolver contains the resolver used by the session. This
// resolver uses Powerdns DoH by default and falls back on the system // resolver will try to figure out which is the best service for running
// provided resolver if Powerdns DoH is not working. // domain name resolutions and will consistently use it.
//
// Occasionally this code will also swap the best resolver with other
// ~good resolvers to give them a chance to perform.
//
// The penalty/reward mechanism is strongly derivative, so the code should
// adapt ~quickly to changing network conditions. Occasionally, we will
// have longer resolutions when trying out other resolvers.
//
// At the beginning we randomize the known resolvers so that we do not
// have any preferential ordering. The initial resolutions may be slower
// if there are many issues with resolvers.
//
// The system resolver is given the lowest priority at the beginning
// but it will of course be the most popular resolver if anything else
// is failing us. (We will still occasionally probe for other working
// resolvers and increase their score on success.)
package sessionresolver package sessionresolver
import ( import (
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"math/rand"
"sync"
"time" "time"
"github.com/ooni/probe-cli/v3/internal/engine/atomicx" "github.com/ooni/probe-cli/v3/internal/engine/internal/multierror"
"github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/runtimex" "github.com/ooni/probe-cli/v3/internal/engine/runtimex"
) )
// Resolver is the session resolver. // Resolver is the session resolver. You should create an instance of
// this structure and use it in session.go.
type Resolver struct { type Resolver struct {
Primary netx.DNSClient ByteCounter *bytecounter.Counter // optional
PrimaryFailure *atomicx.Int64 KVStore KVStore // optional
PrimaryQuery *atomicx.Int64 Logger Logger // optional
Fallback netx.DNSClient codec codec
FallbackFailure *atomicx.Int64 dnsClientMaker dnsclientmaker
FallbackQuery *atomicx.Int64 mu sync.Mutex
once sync.Once
res map[string]childResolver
} }
// New creates a new session resolver. // CloseIdleConnections closes the idle connections, if any. This
func New(config netx.Config) *Resolver { // function is guaranteed to be idempotent.
primary, err := netx.NewDNSClientWithOverrides(config,
"https://cloudflare.com/dns-query", "dns.cloudflare.com", "", "")
runtimex.PanicOnError(err, "cannot create dns over https resolver")
fallback, err := netx.NewDNSClient(config, "system:///")
runtimex.PanicOnError(err, "cannot create system resolver")
return &Resolver{
Primary: primary,
PrimaryFailure: atomicx.NewInt64(),
PrimaryQuery: atomicx.NewInt64(),
Fallback: fallback,
FallbackFailure: atomicx.NewInt64(),
FallbackQuery: atomicx.NewInt64(),
}
}
// CloseIdleConnections closes the idle connections, if any
func (r *Resolver) CloseIdleConnections() { func (r *Resolver) CloseIdleConnections() {
r.Primary.CloseIdleConnections() r.once.Do(r.closeall)
r.Fallback.CloseIdleConnections()
} }
// Stats returns stats about the session resolver. // Stats returns stats about the session resolver.
func (r *Resolver) Stats() string { func (r *Resolver) Stats() string {
return fmt.Sprintf("sessionresolver: failure rate: primary: %d/%d; fallback: %d/%d", data, err := json.Marshal(r.readstatedefault())
r.PrimaryFailure.Load(), r.PrimaryQuery.Load(), runtimex.PanicOnError(err, "json.Marshal should not fail here")
r.FallbackFailure.Load(), r.FallbackQuery.Load()) return fmt.Sprintf("sessionresolver: %s", string(data))
} }
// LookupHost implements Resolver.LookupHost // ErrLookupHost indicates that LookupHost failed.
var ErrLookupHost = errors.New("sessionresolver: LookupHost failed")
// LookupHost implements Resolver.LookupHost. This function returns a
// multierror.Union error on failure, so you can see individual errors
// and get a better picture of what's been going wrong.
func (r *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
// Algorithm similar to Firefox TRR2 mode. See: state := r.readstatedefault()
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox r.maybeConfusion(state, time.Now().UnixNano())
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side defer r.writestate(state)
// and therefore see to use DoH more often. me := multierror.New(ErrLookupHost)
r.PrimaryQuery.Add(1) for _, e := range state {
trr2, cancel := context.WithTimeout(ctx, 4*time.Second) addrs, err := r.lookupHost(ctx, e, hostname)
defer cancel() if err == nil {
addrs, err := r.Primary.LookupHost(trr2, hostname) return addrs, nil
if err != nil {
r.PrimaryFailure.Add(1)
r.FallbackQuery.Add(1)
addrs, err = r.Fallback.LookupHost(ctx, hostname)
if err != nil {
r.FallbackFailure.Add(1)
} }
me.Add(&errwrapper{error: err, URL: e.URL})
} }
return addrs, err return nil, me
} }
// Network implements Resolver.Network func (r *Resolver) lookupHost(ctx context.Context, ri *resolverinfo, hostname string) ([]string, error) {
const ewma = 0.9 // the last sample is very important
re, err := r.getresolver(ri.URL)
if err != nil {
r.logger().Warnf("sessionresolver: getresolver: %s", err.Error())
ri.Score = 0 // this is a hard error
return nil, err
}
addrs, err := r.timeLimitedLookup(ctx, re, hostname)
if err == nil {
r.logger().Infof("sessionresolver: %s... %v", ri.URL, nil)
ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score
return addrs, nil
}
r.logger().Warnf("sessionresolver: %s... %s", ri.URL, err.Error())
ri.Score = ewma*0.0 + (1-ewma)*ri.Score // decrease score
return nil, err
}
// maybeConfusion will rearrange the first elements of the vector
// with low probability, so giving other resolvers a chance
// to run and show that they are also viable. We do not fully
// reorder the vector because that could lead to long runtimes.
//
// The return value is only meaningful for testing.
func (r *Resolver) maybeConfusion(state []*resolverinfo, seed int64) int {
rng := rand.New(rand.NewSource(seed))
const confusion = 0.3
if rng.Float64() >= confusion {
return -1
}
switch len(state) {
case 0, 1: // nothing to do
return 0
case 2:
state[0], state[1] = state[1], state[0]
return 2
default:
state[0], state[2] = state[2], state[0]
return 3
}
}
// Network implements Resolver.Network.
func (r *Resolver) Network() string { func (r *Resolver) Network() string {
return "sessionresolver" return "sessionresolver"
} }
// Address implements Resolver.Address // Address implements Resolver.Address.
func (r *Resolver) Address() string { func (r *Resolver) Address() string {
return "" return ""
} }

View File

@ -1,31 +1,249 @@
package sessionresolver_test package sessionresolver
import ( import (
"context" "context"
"errors"
"net"
"strings" "strings"
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver" "github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/internal/multierror"
) )
func TestFallbackWorks(t *testing.T) { func TestNetworkWorks(t *testing.T) {
reso := sessionresolver.New(netx.Config{}) reso := &Resolver{}
defer reso.CloseIdleConnections()
if reso.Network() != "sessionresolver" { if reso.Network() != "sessionresolver" {
t.Fatal("unexpected Network") t.Fatal("unexpected value returned by Network")
} }
}
func TestAddressWorks(t *testing.T) {
reso := &Resolver{}
if reso.Address() != "" { if reso.Address() != "" {
t.Fatal("unexpected Address") t.Fatal("unexpected value returned by Address")
} }
addrs, err := reso.LookupHost(context.Background(), "antani.ooni.nu") }
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected") func TestTypicalUsageWithFailure(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
reso := &Resolver{}
addrs, err := reso.LookupHost(ctx, "ooni.org")
if !errors.Is(err, ErrLookupHost) {
t.Fatal("not the error we expected", err)
}
var me *multierror.Union
if !errors.As(err, &me) {
t.Fatal("cannot convert error")
}
for _, child := range me.Children {
// net.DNSError does not include the underlying error
// but just a string representing the error. This
// means that we need to go down hunting what's the
// real error that occurred and use more verbose code.
{
var errWrapper *errwrapper
if !errors.As(child, &errWrapper) {
t.Fatal("not an instance of errwrapper")
}
var dnsError *net.DNSError
if errors.As(errWrapper.error, &dnsError) {
if !strings.HasSuffix(dnsError.Err, "operation was canceled") {
t.Fatal("not the error we expected", dnsError.Err)
}
continue
}
}
// otherwise just unwrap and check whether it's
// a real context.Canceled error.
if !errors.Is(child, context.Canceled) {
t.Fatal("unexpected sub-error", child)
}
}
if addrs != nil {
t.Fatal("expected nil here")
}
if len(reso.res) < 1 {
t.Fatal("expected to see some resolvers here")
}
if reso.Stats() == "" {
t.Fatal("expected to see some string returned by stats")
}
reso.CloseIdleConnections()
if len(reso.res) != 0 {
t.Fatal("expected to see no resolvers after CloseIdleConnections")
}
}
func TestTypicalUsageWithSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
ctx := context.Background()
reso := &Resolver{
dnsClientMaker: &fakeDNSClientMaker{
reso: &FakeResolver{Data: expected},
},
}
addrs, err := reso.LookupHost(ctx, "dns.google")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff)
}
}
func TestLittleLLookupHostWithInvalidURL(t *testing.T) {
reso := &Resolver{}
ctx := context.Background()
ri := &resolverinfo{URL: "\t\t\t", Score: 0.99}
addrs, err := reso.lookupHost(ctx, ri, "ooni.org")
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("not the error we expected", err)
} }
if addrs != nil { if addrs != nil {
t.Fatal("expected nil addrs here") t.Fatal("expected nil addrs here")
} }
if reso.PrimaryFailure.Load() != 1 || reso.FallbackFailure.Load() != 1 { if ri.Score != 0 {
t.Fatal("not the counters we expected to see here") t.Fatal("unexpected ri.Score", ri.Score)
}
}
func TestLittleLLookupHostWithSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
reso := &Resolver{
dnsClientMaker: &fakeDNSClientMaker{
reso: &FakeResolver{Data: expected},
},
}
ctx := context.Background()
ri := &resolverinfo{URL: "dot://dns-nonexistent.ooni.org", Score: 0.1}
addrs, err := reso.lookupHost(ctx, ri, "dns.google")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff)
}
if ri.Score < 0.88 || ri.Score > 0.92 {
t.Fatal("unexpected score", ri.Score)
}
}
func TestLittleLLookupHostWithFailure(t *testing.T) {
errMocked := errors.New("mocked error")
reso := &Resolver{
dnsClientMaker: &fakeDNSClientMaker{
reso: &FakeResolver{Err: errMocked},
},
}
ctx := context.Background()
ri := &resolverinfo{URL: "dot://dns-nonexistent.ooni.org", Score: 0.95}
addrs, err := reso.lookupHost(ctx, ri, "dns.google")
if !errors.Is(err, errMocked) {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
t.Fatal("expected nil addrs here")
}
if ri.Score < 0.094 || ri.Score > 0.096 {
t.Fatal("unexpected score", ri.Score)
}
}
func TestMaybeConfusionNoConfusion(t *testing.T) {
reso := &Resolver{}
rv := reso.maybeConfusion(nil, 0)
if rv != -1 {
t.Fatal("unexpected return value", rv)
}
}
func TestMaybeConfusionNoArray(t *testing.T) {
reso := &Resolver{}
rv := reso.maybeConfusion(nil, 11)
if rv != 0 {
t.Fatal("unexpected return value", rv)
}
}
func TestMaybeConfusionSingleEntry(t *testing.T) {
reso := &Resolver{}
state := []*resolverinfo{{}}
rv := reso.maybeConfusion(state, 11)
if rv != 0 {
t.Fatal("unexpected return value", rv)
}
}
func TestMaybeConfusionTwoEntries(t *testing.T) {
reso := &Resolver{}
state := []*resolverinfo{{
Score: 0.8,
URL: "https://dns.google/dns-query",
}, {
Score: 0.4,
URL: "http3://dns.google/dns-query",
}}
rv := reso.maybeConfusion(state, 11)
if rv != 2 {
t.Fatal("unexpected return value", rv)
}
if state[0].Score != 0.4 {
t.Fatal("unexpected state[0].Score")
}
if state[0].URL != "http3://dns.google/dns-query" {
t.Fatal("unexpected state[0].URL")
}
if state[1].Score != 0.8 {
t.Fatal("unexpected state[1].Score")
}
if state[1].URL != "https://dns.google/dns-query" {
t.Fatal("unexpected state[1].URL")
}
}
func TestMaybeConfusionManyEntries(t *testing.T) {
reso := &Resolver{}
state := []*resolverinfo{{
Score: 0.8,
URL: "https://dns.google/dns-query",
}, {
Score: 0.4,
URL: "http3://dns.google/dns-query",
}, {
Score: 0.1,
URL: "system:///",
}, {
Score: 0.01,
URL: "dot://dns.google",
}}
rv := reso.maybeConfusion(state, 11)
if rv != 3 {
t.Fatal("unexpected return value", rv)
}
if state[0].Score != 0.1 {
t.Fatal("unexpected state[0].Score")
}
if state[0].URL != "system:///" {
t.Fatal("unexpected state[0].URL")
}
if state[1].Score != 0.4 {
t.Fatal("unexpected state[1].Score")
}
if state[1].URL != "http3://dns.google/dns-query" {
t.Fatal("unexpected state[1].URL")
}
if state[2].Score != 0.8 {
t.Fatal("unexpected state[2].Score")
}
if state[2].URL != "https://dns.google/dns-query" {
t.Fatal("unexpected state[2].URL")
}
if state[3].Score != 0.01 {
t.Fatal("unexpected state[3].Score")
}
if state[3].URL != "dot://dns.google" {
t.Fatal("unexpected state[3].URL")
} }
} }

View File

@ -0,0 +1,93 @@
package sessionresolver
import (
"errors"
"sort"
)
// storekey is the key used by the key value store to store
// the state required by this package.
const storekey = "sessionresolver.state"
// resolverinfo contains info about a resolver.
type resolverinfo struct {
// URL is the URL of a resolver.
URL string
// Score is the score of a resolver.
Score float64
}
// readstate reads the resolver state from disk
func (r *Resolver) readstate() ([]*resolverinfo, error) {
data, err := r.kvstore().Get(storekey)
if err != nil {
return nil, err
}
var ri []*resolverinfo
if err := r.getCodec().Decode(data, &ri); err != nil {
return nil, err
}
return ri, nil
}
// errNoEntries indicates that no entry remained after we pruned
// all the available entries in readstateandprune.
var errNoEntries = errors.New("sessionresolver: no available entries")
// readstateandprune reads the state from disk and removes all the
// entries that we don't actually support.
func (r *Resolver) readstateandprune() ([]*resolverinfo, error) {
ri, err := r.readstate()
if err != nil {
return nil, err
}
var out []*resolverinfo
for _, e := range ri {
if _, found := allbyurl[e.URL]; !found {
continue // we don't support this specific entry
}
out = append(out, e)
}
if len(out) <= 0 {
return nil, errNoEntries
}
return out, nil
}
// sortstate sorts the state by descending score
func sortstate(ri []*resolverinfo) {
sort.SliceStable(ri, func(i, j int) bool {
return ri[i].Score >= ri[j].Score
})
}
// readstatedefault reads the state from disk and merges the state
// so that all supported entries are represented.
func (r *Resolver) readstatedefault() []*resolverinfo {
ri, _ := r.readstateandprune()
here := make(map[string]bool)
for _, e := range ri {
here[e.URL] = true // record what we already have
}
for _, e := range allmakers {
if _, found := here[e.url]; found {
continue // already here so no need to add
}
ri = append(ri, &resolverinfo{
URL: e.url,
Score: e.score,
})
}
sortstate(ri)
return ri
}
// writestate writes the state on the kvstore.
func (r *Resolver) writestate(ri []*resolverinfo) error {
data, err := r.getCodec().Encode(ri)
if err != nil {
return err
}
return r.kvstore().Set(storekey, data)
}

View File

@ -0,0 +1,120 @@
package sessionresolver
import (
"errors"
"testing"
)
func TestReadStateNothingInKVStore(t *testing.T) {
reso := &Resolver{KVStore: &memkvstore{}}
out, err := reso.readstate()
if !errors.Is(err, errMemkvstoreNotFound) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestReadStateDecodeError(t *testing.T) {
errMocked := errors.New("mocked error")
reso := &Resolver{
KVStore: &memkvstore{},
codec: &FakeCodec{DecodeErr: errMocked},
}
if err := reso.KVStore.Set(storekey, []byte(`[]`)); err != nil {
t.Fatal(err)
}
out, err := reso.readstate()
if !errors.Is(err, errMocked) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestReadStateAndPruneReadStateError(t *testing.T) {
reso := &Resolver{KVStore: &memkvstore{}}
out, err := reso.readstateandprune()
if !errors.Is(err, errMemkvstoreNotFound) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestReadStateAndPruneWithUnsupportedEntries(t *testing.T) {
reso := &Resolver{KVStore: &memkvstore{}}
var in []*resolverinfo
in = append(in, &resolverinfo{})
if err := reso.writestate(in); err != nil {
t.Fatal(err)
}
out, err := reso.readstateandprune()
if !errors.Is(err, errNoEntries) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
}
func TestReadStateDefaultWithMissingEntries(t *testing.T) {
reso := &Resolver{KVStore: &memkvstore{}}
// let us simulate that we have just one entry here
existingURL := "https://dns.google/dns-query"
existingScore := 0.88
var in []*resolverinfo
in = append(in, &resolverinfo{
URL: existingURL,
Score: existingScore,
})
if err := reso.writestate(in); err != nil {
t.Fatal(err)
}
// let us seee what we read
out := reso.readstatedefault()
if len(out) < 1 {
t.Fatal("expected non-empty output")
}
keys := make(map[string]bool)
var found bool
for _, e := range out {
keys[e.URL] = true
if e.URL == existingURL {
if e.Score != existingScore {
t.Fatal("the score is not what we expected")
}
found = true
}
}
if !found {
t.Fatal("did not found the pre-loaded URL")
}
for k := range allbyurl {
if _, found := keys[k]; !found {
t.Fatal("missing key", k)
}
}
}
func TestWriteStateCannotSerialize(t *testing.T) {
errMocked := errors.New("mocked error")
reso := &Resolver{
codec: &FakeCodec{
EncodeErr: errMocked,
},
}
existingURL := "https://dns.google/dns-query"
existingScore := 0.88
var in []*resolverinfo
in = append(in, &resolverinfo{
URL: existingURL,
Score: existingScore,
})
if err := reso.writestate(in); !errors.Is(err, errMocked) {
t.Fatal("not the error we expected", err)
}
}

View File

@ -109,7 +109,11 @@ func NewSession(config SessionConfig) (*Session, error) {
BogonIsError: true, BogonIsError: true,
Logger: sess.logger, Logger: sess.logger,
} }
sess.resolver = sessionresolver.New(httpConfig) sess.resolver = &sessionresolver.Resolver{
ByteCounter: sess.byteCounter,
KVStore: config.KVStore,
Logger: sess.logger,
}
httpConfig.FullResolver = sess.resolver httpConfig.FullResolver = sess.resolver
httpConfig.ProxyURL = config.ProxyURL // no need to proxy the resolver httpConfig.ProxyURL = config.ProxyURL // no need to proxy the resolver
sess.httpDefaultTransport = netx.NewHTTPTransport(httpConfig) sess.httpDefaultTransport = netx.NewHTTPTransport(httpConfig)