refactor(sessionresolver): adapt to changing network conditions (#238)
* feat(sessionresolver): try many and use what works * fix(sessionresolver): make sure we can use quic * fix: the config struct is unnecessary * fix: make kvstore optional * feat: write simple integration test * feat: start adding tests * feat: continue writing tests * fix(sessionresolver): add more unit tests * fix(sessionresolver): finish adding tests * refactor(sessionresolver): changes after code review
This commit is contained in:
parent
12e1164940
commit
034db78f94
27
internal/engine/internal/sessionresolver/childresolver.go
Normal file
27
internal/engine/internal/sessionresolver/childresolver.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// childResolver is the DNS client that this package uses
|
||||
// to perform individual domain name resolutions.
|
||||
type childResolver interface {
|
||||
// LookupHost performs a DNS lookup.
|
||||
LookupHost(ctx context.Context, domain string) ([]string, error)
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// timeLimitedLookup performs a time-limited lookup using the given re.
|
||||
func (r *Resolver) timeLimitedLookup(ctx context.Context, re childResolver, hostname string) ([]string, error) {
|
||||
// Algorithm similar to Firefox TRR2 mode. See:
|
||||
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox
|
||||
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
|
||||
// and therefore see to use DoH more often.
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
return re.LookupHost(ctx, hostname)
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
type FakeResolver struct {
|
||||
Closed bool
|
||||
Data []string
|
||||
Err error
|
||||
Sleep time.Duration
|
||||
}
|
||||
|
||||
func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
select {
|
||||
case <-time.After(r.Sleep):
|
||||
return r.Data, r.Err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *FakeResolver) CloseIdleConnections() {
|
||||
r.Closed = true
|
||||
}
|
||||
|
||||
func TestTimeLimitedLookupSuccess(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
re := &FakeResolver{
|
||||
Data: []string{"8.8.8.8", "8.8.4.4"},
|
||||
}
|
||||
ctx := context.Background()
|
||||
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(re.Data, out); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeLimitedLookupFailure(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
re := &FakeResolver{
|
||||
Err: io.EOF,
|
||||
}
|
||||
ctx := context.Background()
|
||||
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
|
||||
if !errors.Is(err, re.Err) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeLimitedLookupWillTimeout(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode")
|
||||
}
|
||||
reso := &Resolver{}
|
||||
re := &FakeResolver{
|
||||
Err: io.EOF,
|
||||
Sleep: 20 * time.Second,
|
||||
}
|
||||
ctx := context.Background()
|
||||
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
25
internal/engine/internal/sessionresolver/clientmaker.go
Normal file
25
internal/engine/internal/sessionresolver/clientmaker.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package sessionresolver
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
|
||||
// dnsclientmaker makes a new resolver.
|
||||
type dnsclientmaker interface {
|
||||
// Make makes a new resolver.
|
||||
Make(config netx.Config, URL string) (childResolver, error)
|
||||
}
|
||||
|
||||
// clientmaker returns a valid dnsclientmaker
|
||||
func (r *Resolver) clientmaker() dnsclientmaker {
|
||||
if r.dnsClientMaker != nil {
|
||||
return r.dnsClientMaker
|
||||
}
|
||||
return &defaultDNSClientMaker{}
|
||||
}
|
||||
|
||||
// defaultDNSClientMaker is the default dnsclientmaker
|
||||
type defaultDNSClientMaker struct{}
|
||||
|
||||
// Make implements dnsclientmaker.Make.
|
||||
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) {
|
||||
return netx.NewDNSClient(config, URL)
|
||||
}
|
52
internal/engine/internal/sessionresolver/clientmaker_test.go
Normal file
52
internal/engine/internal/sessionresolver/clientmaker_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
)
|
||||
|
||||
type fakeDNSClientMaker struct {
|
||||
reso childResolver
|
||||
err error
|
||||
savedConfig netx.Config
|
||||
savedURL string
|
||||
}
|
||||
|
||||
func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) {
|
||||
c.savedConfig = config
|
||||
c.savedURL = URL
|
||||
return c.reso, c.err
|
||||
}
|
||||
|
||||
func TestClientMakerWithOverride(t *testing.T) {
|
||||
m := &fakeDNSClientMaker{err: io.EOF}
|
||||
reso := &Resolver{dnsClientMaker: m}
|
||||
out, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDefaultWithCancelledContext(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
re, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // fail immediately
|
||||
out, err := re.LookupHost(ctx, "dns.google")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil output")
|
||||
}
|
||||
}
|
35
internal/engine/internal/sessionresolver/codec.go
Normal file
35
internal/engine/internal/sessionresolver/codec.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// codec is the codec we use.
|
||||
type codec interface {
|
||||
// Encode encodes v as a stream of bytes.
|
||||
Encode(v interface{}) ([]byte, error)
|
||||
|
||||
// Decode decodes b into a stream of bytes.
|
||||
Decode(b []byte, v interface{}) error
|
||||
}
|
||||
|
||||
// getCodec always returns a valid codec.
|
||||
func (r *Resolver) getCodec() codec {
|
||||
if r.codec != nil {
|
||||
return r.codec
|
||||
}
|
||||
return &defaultCodec{}
|
||||
}
|
||||
|
||||
// defaultCodec is the default codec.
|
||||
type defaultCodec struct{}
|
||||
|
||||
// Decode decodes b into v using the default codec.
|
||||
func (*defaultCodec) Decode(b []byte, v interface{}) error {
|
||||
return json.Unmarshal(b, v)
|
||||
}
|
||||
|
||||
// Encode encodes v using the default codec.
|
||||
func (*defaultCodec) Encode(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
48
internal/engine/internal/sessionresolver/codec_test.go
Normal file
48
internal/engine/internal/sessionresolver/codec_test.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
type FakeCodec struct {
|
||||
EncodeData []byte
|
||||
EncodeErr error
|
||||
DecodeErr error
|
||||
}
|
||||
|
||||
func (c *FakeCodec) Encode(v interface{}) ([]byte, error) {
|
||||
return c.EncodeData, c.EncodeErr
|
||||
}
|
||||
|
||||
func (c *FakeCodec) Decode(b []byte, v interface{}) error {
|
||||
return c.DecodeErr
|
||||
}
|
||||
|
||||
func TestCodecCustom(t *testing.T) {
|
||||
c := &FakeCodec{}
|
||||
reso := &Resolver{codec: c}
|
||||
if r := reso.getCodec(); r != c {
|
||||
t.Fatal("not the codec we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodecDefault(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
in := resolverinfo{
|
||||
URL: "https://dns.google/dns.query",
|
||||
Score: 0.99,
|
||||
}
|
||||
data, err := reso.getCodec().Encode(in)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var out resolverinfo
|
||||
if err := reso.getCodec().Decode(data, &out); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(in, out); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
32
internal/engine/internal/sessionresolver/dependencies.go
Normal file
32
internal/engine/internal/sessionresolver/dependencies.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package sessionresolver
|
||||
|
||||
// KVStore is a generic key-value store. We use it to store
|
||||
// on disk persistent state used by this package.
|
||||
type KVStore interface {
|
||||
// Get gets the value for the given key.
|
||||
Get(key string) ([]byte, error)
|
||||
|
||||
// Set sets the value of the given key.
|
||||
Set(key string, value []byte) error
|
||||
}
|
||||
|
||||
// Logger defines the common logger interface.
|
||||
type Logger interface {
|
||||
// Debug emits a debug message.
|
||||
Debug(msg string)
|
||||
|
||||
// Debugf formats and emits a debug message.
|
||||
Debugf(format string, v ...interface{})
|
||||
|
||||
// Info emits an informational message.
|
||||
Info(msg string)
|
||||
|
||||
// Infof format and emits an informational message.
|
||||
Infof(format string, v ...interface{})
|
||||
|
||||
// Warn emits a warning message.
|
||||
Warn(msg string)
|
||||
|
||||
// Warnf formats and emits a warning message.
|
||||
Warnf(format string, v ...interface{})
|
||||
}
|
23
internal/engine/internal/sessionresolver/errwrapper.go
Normal file
23
internal/engine/internal/sessionresolver/errwrapper.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// errwrapper wraps an error to include the URL of the
|
||||
// resolver that we're currently using.
|
||||
type errwrapper struct {
|
||||
error
|
||||
URL string
|
||||
}
|
||||
|
||||
// Error implements error.Error.
|
||||
func (ew *errwrapper) Error() string {
|
||||
return fmt.Sprintf("<%s> %s", ew.URL, ew.error.Error())
|
||||
}
|
||||
|
||||
// Is allows consumers to query for the type of the underlying error.
|
||||
func (ew *errwrapper) Is(target error) bool {
|
||||
return errors.Is(ew.error, target)
|
||||
}
|
24
internal/engine/internal/sessionresolver/errwrapper_test.go
Normal file
24
internal/engine/internal/sessionresolver/errwrapper_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestErrWrapper(t *testing.T) {
|
||||
ew := &errwrapper{
|
||||
error: io.EOF,
|
||||
URL: "https://dns.quad9.net/dns-query",
|
||||
}
|
||||
o := ew.Error()
|
||||
expect := "<https://dns.quad9.net/dns-query> EOF"
|
||||
if diff := cmp.Diff(expect, o); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if !errors.Is(ew, io.EOF) {
|
||||
t.Fatal("not the sub-error we expected")
|
||||
}
|
||||
}
|
29
internal/engine/internal/sessionresolver/integration_test.go
Normal file
29
internal/engine/internal/sessionresolver/integration_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package sessionresolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver"
|
||||
)
|
||||
|
||||
func TestSessionResolverGood(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in short mode")
|
||||
}
|
||||
reso := &sessionresolver.Resolver{}
|
||||
defer reso.CloseIdleConnections()
|
||||
if reso.Network() != "sessionresolver" {
|
||||
t.Fatal("unexpected Network")
|
||||
}
|
||||
if reso.Address() != "" {
|
||||
t.Fatal("unexpected Address")
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) < 1 {
|
||||
t.Fatal("expected some addrs here")
|
||||
}
|
||||
}
|
43
internal/engine/internal/sessionresolver/memkvstore.go
Normal file
43
internal/engine/internal/sessionresolver/memkvstore.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func (r *Resolver) kvstore() KVStore {
|
||||
defer r.mu.Unlock()
|
||||
r.mu.Lock()
|
||||
if r.KVStore == nil {
|
||||
r.KVStore = &memkvstore{}
|
||||
}
|
||||
return r.KVStore
|
||||
}
|
||||
|
||||
var errMemkvstoreNotFound = errors.New("memkvstore: not found")
|
||||
|
||||
type memkvstore struct {
|
||||
m map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (kvs *memkvstore) Get(key string) ([]byte, error) {
|
||||
defer kvs.mu.Unlock()
|
||||
kvs.mu.Lock()
|
||||
out, good := kvs.m[key]
|
||||
if !good {
|
||||
return nil, fmt.Errorf("%w: %s", errMemkvstoreNotFound, key)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (kvs *memkvstore) Set(key string, value []byte) error {
|
||||
defer kvs.mu.Unlock()
|
||||
kvs.mu.Lock()
|
||||
if kvs.m == nil {
|
||||
kvs.m = make(map[string][]byte)
|
||||
}
|
||||
kvs.m[key] = value
|
||||
return nil
|
||||
}
|
47
internal/engine/internal/sessionresolver/memkvstore_test.go
Normal file
47
internal/engine/internal/sessionresolver/memkvstore_test.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestKVStoreCustom(t *testing.T) {
|
||||
kvs := &memkvstore{}
|
||||
reso := &Resolver{KVStore: kvs}
|
||||
o := reso.kvstore()
|
||||
if o != kvs {
|
||||
t.Fatal("not the kvstore we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemkvstoreGetNotFound(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
key := "antani"
|
||||
out, err := reso.kvstore().Get(key)
|
||||
if !errors.Is(err, errMemkvstoreNotFound) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemkvstoreRoundTrip(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
key := []string{"antani", "mascetti"}
|
||||
value := [][]byte{[]byte(`mascetti`), []byte(`antani`)}
|
||||
for idx := 0; idx < 2; idx++ {
|
||||
if err := reso.kvstore().Set(key[idx], value[idx]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out, err := reso.kvstore().Get(key[idx])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(value[idx], out); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
}
|
121
internal/engine/internal/sessionresolver/resolvermaker.go
Normal file
121
internal/engine/internal/sessionresolver/resolvermaker.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
)
|
||||
|
||||
// resolvemaker contains rules for making a resolver.
|
||||
type resolvermaker struct {
|
||||
url string
|
||||
score float64
|
||||
}
|
||||
|
||||
// systemResolverURL is the URL of the system resolver.
|
||||
const systemResolverURL = "system:///"
|
||||
|
||||
// allmakers contains all the makers in a list. We use the http3
|
||||
// prefix to indicate we wanna use http3. The code will translate
|
||||
// this to https and set the proper next options.
|
||||
var allmakers = []*resolvermaker{{
|
||||
url: "https://cloudflare-dns.com/dns-query",
|
||||
}, {
|
||||
url: "http3://cloudflare-dns.com/dns-query",
|
||||
}, {
|
||||
url: "https://dns.google/dns-query",
|
||||
}, {
|
||||
url: "http3://dns.google/dns-query",
|
||||
}, {
|
||||
url: "https://dns.quad9.net/dns-query",
|
||||
}, {
|
||||
url: "https://doh.powerdns.org/",
|
||||
}, {
|
||||
url: systemResolverURL,
|
||||
}, {
|
||||
url: "https://mozilla.cloudflare-dns.com/dns-query",
|
||||
}, {
|
||||
url: "http3://mozilla.cloudflare-dns.com/dns-query",
|
||||
}}
|
||||
|
||||
// allbyurl contains all the resolvermakers by URL
|
||||
var allbyurl map[string]*resolvermaker
|
||||
|
||||
// init fills allbyname and gives a nonzero initial score
|
||||
// to all resolvers except for the system resolver. We set
|
||||
// the system resolver score to zero, so that it's less
|
||||
// likely than other resolvers in this list.
|
||||
func init() {
|
||||
allbyurl = make(map[string]*resolvermaker)
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for _, e := range allmakers {
|
||||
allbyurl[e.url] = e
|
||||
if e.url != systemResolverURL {
|
||||
e.score = rng.Float64()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// byteCounter returns the configured byteCounter or a default
|
||||
func (r *Resolver) byteCounter() *bytecounter.Counter {
|
||||
if r.ByteCounter != nil {
|
||||
return r.ByteCounter
|
||||
}
|
||||
return bytecounter.New()
|
||||
}
|
||||
|
||||
// logger returns the configured logger or a default
|
||||
func (r *Resolver) logger() Logger {
|
||||
if r.Logger != nil {
|
||||
return r.Logger
|
||||
}
|
||||
return log.Log
|
||||
}
|
||||
|
||||
// newresolver creates a new resolver with the given config and URL. This is
|
||||
// where we expand http3 to https and set the h3 options.
|
||||
func (r *Resolver) newresolver(URL string) (childResolver, error) {
|
||||
h3 := strings.HasPrefix(URL, "http3://")
|
||||
if h3 {
|
||||
URL = strings.Replace(URL, "http3://", "https://", 1)
|
||||
}
|
||||
return r.clientmaker().Make(netx.Config{
|
||||
BogonIsError: true,
|
||||
ByteCounter: r.byteCounter(),
|
||||
HTTP3Enabled: h3,
|
||||
Logger: r.logger(),
|
||||
}, URL)
|
||||
}
|
||||
|
||||
// getresolver returns a resolver with the given URL. This function caches
|
||||
// already allocated resolvers so we only allocate them once.
|
||||
func (r *Resolver) getresolver(URL string) (childResolver, error) {
|
||||
defer r.mu.Unlock()
|
||||
r.mu.Lock()
|
||||
if re, found := r.res[URL]; found == true {
|
||||
return re, nil // already created
|
||||
}
|
||||
re, err := r.newresolver(URL)
|
||||
if err != nil {
|
||||
return nil, err // config err?
|
||||
}
|
||||
if r.res == nil {
|
||||
r.res = make(map[string]childResolver)
|
||||
}
|
||||
r.res[URL] = re
|
||||
return re, nil
|
||||
}
|
||||
|
||||
// closeall closes the cached resolvers.
|
||||
func (r *Resolver) closeall() {
|
||||
defer r.mu.Unlock()
|
||||
r.mu.Lock()
|
||||
for _, re := range r.res {
|
||||
re.CloseIdleConnections()
|
||||
}
|
||||
r.res = nil
|
||||
}
|
124
internal/engine/internal/sessionresolver/resolvermaker_test.go
Normal file
124
internal/engine/internal/sessionresolver/resolvermaker_test.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
)
|
||||
|
||||
func TestDefaultByteCounter(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
bc := reso.byteCounter()
|
||||
if bc == nil {
|
||||
t.Fatal("expected non-nil byte counter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLogger(t *testing.T) {
|
||||
logger := &log.Logger{}
|
||||
reso := &Resolver{Logger: logger}
|
||||
lo := reso.logger()
|
||||
if lo != logger {
|
||||
t.Fatal("expected another logger here counter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResolverHTTPSStandard(t *testing.T) {
|
||||
bc := bytecounter.New()
|
||||
URL := "https://dns.google"
|
||||
re := &FakeResolver{}
|
||||
cmk := &fakeDNSClientMaker{reso: re}
|
||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
||||
out, err := reso.getresolver(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != re {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
o2, err := reso.getresolver(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != o2 {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
reso.closeall()
|
||||
if re.Closed != true {
|
||||
t.Fatal("was not closed")
|
||||
}
|
||||
if cmk.savedURL != URL {
|
||||
t.Fatal("not the URL we expected")
|
||||
}
|
||||
if cmk.savedConfig.ByteCounter != bc {
|
||||
t.Fatal("unexpected ByteCounter")
|
||||
}
|
||||
if cmk.savedConfig.BogonIsError != true {
|
||||
t.Fatal("unexpected BogonIsError")
|
||||
}
|
||||
if cmk.savedConfig.HTTP3Enabled != false {
|
||||
t.Fatal("unexpected HTTP3Enabled")
|
||||
}
|
||||
if cmk.savedConfig.Logger != log.Log {
|
||||
t.Fatal("unexpected Log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResolverHTTP3(t *testing.T) {
|
||||
bc := bytecounter.New()
|
||||
URL := "http3://dns.google"
|
||||
re := &FakeResolver{}
|
||||
cmk := &fakeDNSClientMaker{reso: re}
|
||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
||||
out, err := reso.getresolver(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != re {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
o2, err := reso.getresolver(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != o2 {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
reso.closeall()
|
||||
if re.Closed != true {
|
||||
t.Fatal("was not closed")
|
||||
}
|
||||
if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) {
|
||||
t.Fatal("not the URL we expected")
|
||||
}
|
||||
if cmk.savedConfig.ByteCounter != bc {
|
||||
t.Fatal("unexpected ByteCounter")
|
||||
}
|
||||
if cmk.savedConfig.BogonIsError != true {
|
||||
t.Fatal("unexpected BogonIsError")
|
||||
}
|
||||
if cmk.savedConfig.HTTP3Enabled != true {
|
||||
t.Fatal("unexpected HTTP3Enabled")
|
||||
}
|
||||
if cmk.savedConfig.Logger != log.Log {
|
||||
t.Fatal("unexpected Log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResolverInvalidURL(t *testing.T) {
|
||||
bc := bytecounter.New()
|
||||
URL := "http3://dns.google"
|
||||
errMocked := errors.New("mocked error")
|
||||
cmk := &fakeDNSClientMaker{err: errMocked}
|
||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
||||
out, err := reso.getresolver(URL)
|
||||
if !errors.Is(err, errMocked) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
|
@ -1,85 +1,134 @@
|
|||
// Package sessionresolver contains the resolver used by the session. This
|
||||
// resolver uses Powerdns DoH by default and falls back on the system
|
||||
// provided resolver if Powerdns DoH is not working.
|
||||
// resolver will try to figure out which is the best service for running
|
||||
// domain name resolutions and will consistently use it.
|
||||
//
|
||||
// Occasionally this code will also swap the best resolver with other
|
||||
// ~good resolvers to give them a chance to perform.
|
||||
//
|
||||
// The penalty/reward mechanism is strongly derivative, so the code should
|
||||
// adapt ~quickly to changing network conditions. Occasionally, we will
|
||||
// have longer resolutions when trying out other resolvers.
|
||||
//
|
||||
// At the beginning we randomize the known resolvers so that we do not
|
||||
// have any preferential ordering. The initial resolutions may be slower
|
||||
// if there are many issues with resolvers.
|
||||
//
|
||||
// The system resolver is given the lowest priority at the beginning
|
||||
// but it will of course be the most popular resolver if anything else
|
||||
// is failing us. (We will still occasionally probe for other working
|
||||
// resolvers and increase their score on success.)
|
||||
package sessionresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/multierror"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/runtimex"
|
||||
)
|
||||
|
||||
// Resolver is the session resolver.
|
||||
// Resolver is the session resolver. You should create an instance of
|
||||
// this structure and use it in session.go.
|
||||
type Resolver struct {
|
||||
Primary netx.DNSClient
|
||||
PrimaryFailure *atomicx.Int64
|
||||
PrimaryQuery *atomicx.Int64
|
||||
Fallback netx.DNSClient
|
||||
FallbackFailure *atomicx.Int64
|
||||
FallbackQuery *atomicx.Int64
|
||||
ByteCounter *bytecounter.Counter // optional
|
||||
KVStore KVStore // optional
|
||||
Logger Logger // optional
|
||||
codec codec
|
||||
dnsClientMaker dnsclientmaker
|
||||
mu sync.Mutex
|
||||
once sync.Once
|
||||
res map[string]childResolver
|
||||
}
|
||||
|
||||
// New creates a new session resolver.
|
||||
func New(config netx.Config) *Resolver {
|
||||
primary, err := netx.NewDNSClientWithOverrides(config,
|
||||
"https://cloudflare.com/dns-query", "dns.cloudflare.com", "", "")
|
||||
runtimex.PanicOnError(err, "cannot create dns over https resolver")
|
||||
fallback, err := netx.NewDNSClient(config, "system:///")
|
||||
runtimex.PanicOnError(err, "cannot create system resolver")
|
||||
return &Resolver{
|
||||
Primary: primary,
|
||||
PrimaryFailure: atomicx.NewInt64(),
|
||||
PrimaryQuery: atomicx.NewInt64(),
|
||||
Fallback: fallback,
|
||||
FallbackFailure: atomicx.NewInt64(),
|
||||
FallbackQuery: atomicx.NewInt64(),
|
||||
}
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections, if any
|
||||
// CloseIdleConnections closes the idle connections, if any. This
|
||||
// function is guaranteed to be idempotent.
|
||||
func (r *Resolver) CloseIdleConnections() {
|
||||
r.Primary.CloseIdleConnections()
|
||||
r.Fallback.CloseIdleConnections()
|
||||
r.once.Do(r.closeall)
|
||||
}
|
||||
|
||||
// Stats returns stats about the session resolver.
|
||||
func (r *Resolver) Stats() string {
|
||||
return fmt.Sprintf("sessionresolver: failure rate: primary: %d/%d; fallback: %d/%d",
|
||||
r.PrimaryFailure.Load(), r.PrimaryQuery.Load(),
|
||||
r.FallbackFailure.Load(), r.FallbackQuery.Load())
|
||||
data, err := json.Marshal(r.readstatedefault())
|
||||
runtimex.PanicOnError(err, "json.Marshal should not fail here")
|
||||
return fmt.Sprintf("sessionresolver: %s", string(data))
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
// ErrLookupHost indicates that LookupHost failed.
|
||||
var ErrLookupHost = errors.New("sessionresolver: LookupHost failed")
|
||||
|
||||
// LookupHost implements Resolver.LookupHost. This function returns a
|
||||
// multierror.Union error on failure, so you can see individual errors
|
||||
// and get a better picture of what's been going wrong.
|
||||
func (r *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
// Algorithm similar to Firefox TRR2 mode. See:
|
||||
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox
|
||||
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
|
||||
// and therefore see to use DoH more often.
|
||||
r.PrimaryQuery.Add(1)
|
||||
trr2, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
addrs, err := r.Primary.LookupHost(trr2, hostname)
|
||||
if err != nil {
|
||||
r.PrimaryFailure.Add(1)
|
||||
r.FallbackQuery.Add(1)
|
||||
addrs, err = r.Fallback.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
r.FallbackFailure.Add(1)
|
||||
state := r.readstatedefault()
|
||||
r.maybeConfusion(state, time.Now().UnixNano())
|
||||
defer r.writestate(state)
|
||||
me := multierror.New(ErrLookupHost)
|
||||
for _, e := range state {
|
||||
addrs, err := r.lookupHost(ctx, e, hostname)
|
||||
if err == nil {
|
||||
return addrs, nil
|
||||
}
|
||||
me.Add(&errwrapper{error: err, URL: e.URL})
|
||||
}
|
||||
return addrs, err
|
||||
return nil, me
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r *Resolver) lookupHost(ctx context.Context, ri *resolverinfo, hostname string) ([]string, error) {
|
||||
const ewma = 0.9 // the last sample is very important
|
||||
re, err := r.getresolver(ri.URL)
|
||||
if err != nil {
|
||||
r.logger().Warnf("sessionresolver: getresolver: %s", err.Error())
|
||||
ri.Score = 0 // this is a hard error
|
||||
return nil, err
|
||||
}
|
||||
addrs, err := r.timeLimitedLookup(ctx, re, hostname)
|
||||
if err == nil {
|
||||
r.logger().Infof("sessionresolver: %s... %v", ri.URL, nil)
|
||||
ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score
|
||||
return addrs, nil
|
||||
}
|
||||
r.logger().Warnf("sessionresolver: %s... %s", ri.URL, err.Error())
|
||||
ri.Score = ewma*0.0 + (1-ewma)*ri.Score // decrease score
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// maybeConfusion will rearrange the first elements of the vector
|
||||
// with low probability, so giving other resolvers a chance
|
||||
// to run and show that they are also viable. We do not fully
|
||||
// reorder the vector because that could lead to long runtimes.
|
||||
//
|
||||
// The return value is only meaningful for testing.
|
||||
func (r *Resolver) maybeConfusion(state []*resolverinfo, seed int64) int {
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
const confusion = 0.3
|
||||
if rng.Float64() >= confusion {
|
||||
return -1
|
||||
}
|
||||
switch len(state) {
|
||||
case 0, 1: // nothing to do
|
||||
return 0
|
||||
case 2:
|
||||
state[0], state[1] = state[1], state[0]
|
||||
return 2
|
||||
default:
|
||||
state[0], state[2] = state[2], state[0]
|
||||
return 3
|
||||
}
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r *Resolver) Network() string {
|
||||
return "sessionresolver"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
// Address implements Resolver.Address.
|
||||
func (r *Resolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -1,31 +1,249 @@
|
|||
package sessionresolver_test
|
||||
package sessionresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/sessionresolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/multierror"
|
||||
)
|
||||
|
||||
func TestFallbackWorks(t *testing.T) {
|
||||
reso := sessionresolver.New(netx.Config{})
|
||||
defer reso.CloseIdleConnections()
|
||||
func TestNetworkWorks(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
if reso.Network() != "sessionresolver" {
|
||||
t.Fatal("unexpected Network")
|
||||
t.Fatal("unexpected value returned by Network")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddressWorks(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
if reso.Address() != "" {
|
||||
t.Fatal("unexpected Address")
|
||||
t.Fatal("unexpected value returned by Address")
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "antani.ooni.nu")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
|
||||
func TestTypicalUsageWithFailure(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // fail immediately
|
||||
reso := &Resolver{}
|
||||
addrs, err := reso.LookupHost(ctx, "ooni.org")
|
||||
if !errors.Is(err, ErrLookupHost) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
var me *multierror.Union
|
||||
if !errors.As(err, &me) {
|
||||
t.Fatal("cannot convert error")
|
||||
}
|
||||
for _, child := range me.Children {
|
||||
// net.DNSError does not include the underlying error
|
||||
// but just a string representing the error. This
|
||||
// means that we need to go down hunting what's the
|
||||
// real error that occurred and use more verbose code.
|
||||
{
|
||||
var errWrapper *errwrapper
|
||||
if !errors.As(child, &errWrapper) {
|
||||
t.Fatal("not an instance of errwrapper")
|
||||
}
|
||||
var dnsError *net.DNSError
|
||||
if errors.As(errWrapper.error, &dnsError) {
|
||||
if !strings.HasSuffix(dnsError.Err, "operation was canceled") {
|
||||
t.Fatal("not the error we expected", dnsError.Err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
// otherwise just unwrap and check whether it's
|
||||
// a real context.Canceled error.
|
||||
if !errors.Is(child, context.Canceled) {
|
||||
t.Fatal("unexpected sub-error", child)
|
||||
}
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
if len(reso.res) < 1 {
|
||||
t.Fatal("expected to see some resolvers here")
|
||||
}
|
||||
if reso.Stats() == "" {
|
||||
t.Fatal("expected to see some string returned by stats")
|
||||
}
|
||||
reso.CloseIdleConnections()
|
||||
if len(reso.res) != 0 {
|
||||
t.Fatal("expected to see no resolvers after CloseIdleConnections")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypicalUsageWithSuccess(t *testing.T) {
|
||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||
ctx := context.Background()
|
||||
reso := &Resolver{
|
||||
dnsClientMaker: &fakeDNSClientMaker{
|
||||
reso: &FakeResolver{Data: expected},
|
||||
},
|
||||
}
|
||||
addrs, err := reso.LookupHost(ctx, "dns.google")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, addrs); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLittleLLookupHostWithInvalidURL(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
ctx := context.Background()
|
||||
ri := &resolverinfo{URL: "\t\t\t", Score: 0.99}
|
||||
addrs, err := reso.lookupHost(ctx, ri, "ooni.org")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs here")
|
||||
}
|
||||
if reso.PrimaryFailure.Load() != 1 || reso.FallbackFailure.Load() != 1 {
|
||||
t.Fatal("not the counters we expected to see here")
|
||||
if ri.Score != 0 {
|
||||
t.Fatal("unexpected ri.Score", ri.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLittleLLookupHostWithSuccess(t *testing.T) {
|
||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||
reso := &Resolver{
|
||||
dnsClientMaker: &fakeDNSClientMaker{
|
||||
reso: &FakeResolver{Data: expected},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ri := &resolverinfo{URL: "dot://dns-nonexistent.ooni.org", Score: 0.1}
|
||||
addrs, err := reso.lookupHost(ctx, ri, "dns.google")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, addrs); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if ri.Score < 0.88 || ri.Score > 0.92 {
|
||||
t.Fatal("unexpected score", ri.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLittleLLookupHostWithFailure(t *testing.T) {
|
||||
errMocked := errors.New("mocked error")
|
||||
reso := &Resolver{
|
||||
dnsClientMaker: &fakeDNSClientMaker{
|
||||
reso: &FakeResolver{Err: errMocked},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ri := &resolverinfo{URL: "dot://dns-nonexistent.ooni.org", Score: 0.95}
|
||||
addrs, err := reso.lookupHost(ctx, ri, "dns.google")
|
||||
if !errors.Is(err, errMocked) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs here")
|
||||
}
|
||||
if ri.Score < 0.094 || ri.Score > 0.096 {
|
||||
t.Fatal("unexpected score", ri.Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeConfusionNoConfusion(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
rv := reso.maybeConfusion(nil, 0)
|
||||
if rv != -1 {
|
||||
t.Fatal("unexpected return value", rv)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeConfusionNoArray(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
rv := reso.maybeConfusion(nil, 11)
|
||||
if rv != 0 {
|
||||
t.Fatal("unexpected return value", rv)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeConfusionSingleEntry(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
state := []*resolverinfo{{}}
|
||||
rv := reso.maybeConfusion(state, 11)
|
||||
if rv != 0 {
|
||||
t.Fatal("unexpected return value", rv)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeConfusionTwoEntries(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
state := []*resolverinfo{{
|
||||
Score: 0.8,
|
||||
URL: "https://dns.google/dns-query",
|
||||
}, {
|
||||
Score: 0.4,
|
||||
URL: "http3://dns.google/dns-query",
|
||||
}}
|
||||
rv := reso.maybeConfusion(state, 11)
|
||||
if rv != 2 {
|
||||
t.Fatal("unexpected return value", rv)
|
||||
}
|
||||
if state[0].Score != 0.4 {
|
||||
t.Fatal("unexpected state[0].Score")
|
||||
}
|
||||
if state[0].URL != "http3://dns.google/dns-query" {
|
||||
t.Fatal("unexpected state[0].URL")
|
||||
}
|
||||
if state[1].Score != 0.8 {
|
||||
t.Fatal("unexpected state[1].Score")
|
||||
}
|
||||
if state[1].URL != "https://dns.google/dns-query" {
|
||||
t.Fatal("unexpected state[1].URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeConfusionManyEntries(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
state := []*resolverinfo{{
|
||||
Score: 0.8,
|
||||
URL: "https://dns.google/dns-query",
|
||||
}, {
|
||||
Score: 0.4,
|
||||
URL: "http3://dns.google/dns-query",
|
||||
}, {
|
||||
Score: 0.1,
|
||||
URL: "system:///",
|
||||
}, {
|
||||
Score: 0.01,
|
||||
URL: "dot://dns.google",
|
||||
}}
|
||||
rv := reso.maybeConfusion(state, 11)
|
||||
if rv != 3 {
|
||||
t.Fatal("unexpected return value", rv)
|
||||
}
|
||||
if state[0].Score != 0.1 {
|
||||
t.Fatal("unexpected state[0].Score")
|
||||
}
|
||||
if state[0].URL != "system:///" {
|
||||
t.Fatal("unexpected state[0].URL")
|
||||
}
|
||||
if state[1].Score != 0.4 {
|
||||
t.Fatal("unexpected state[1].Score")
|
||||
}
|
||||
if state[1].URL != "http3://dns.google/dns-query" {
|
||||
t.Fatal("unexpected state[1].URL")
|
||||
}
|
||||
if state[2].Score != 0.8 {
|
||||
t.Fatal("unexpected state[2].Score")
|
||||
}
|
||||
if state[2].URL != "https://dns.google/dns-query" {
|
||||
t.Fatal("unexpected state[2].URL")
|
||||
}
|
||||
if state[3].Score != 0.01 {
|
||||
t.Fatal("unexpected state[3].Score")
|
||||
}
|
||||
if state[3].URL != "dot://dns.google" {
|
||||
t.Fatal("unexpected state[3].URL")
|
||||
}
|
||||
}
|
||||
|
|
93
internal/engine/internal/sessionresolver/state.go
Normal file
93
internal/engine/internal/sessionresolver/state.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// storekey is the key used by the key value store to store
|
||||
// the state required by this package.
|
||||
const storekey = "sessionresolver.state"
|
||||
|
||||
// resolverinfo contains info about a resolver.
|
||||
type resolverinfo struct {
|
||||
// URL is the URL of a resolver.
|
||||
URL string
|
||||
|
||||
// Score is the score of a resolver.
|
||||
Score float64
|
||||
}
|
||||
|
||||
// readstate reads the resolver state from disk
|
||||
func (r *Resolver) readstate() ([]*resolverinfo, error) {
|
||||
data, err := r.kvstore().Get(storekey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ri []*resolverinfo
|
||||
if err := r.getCodec().Decode(data, &ri); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
// errNoEntries indicates that no entry remained after we pruned
|
||||
// all the available entries in readstateandprune.
|
||||
var errNoEntries = errors.New("sessionresolver: no available entries")
|
||||
|
||||
// readstateandprune reads the state from disk and removes all the
|
||||
// entries that we don't actually support.
|
||||
func (r *Resolver) readstateandprune() ([]*resolverinfo, error) {
|
||||
ri, err := r.readstate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var out []*resolverinfo
|
||||
for _, e := range ri {
|
||||
if _, found := allbyurl[e.URL]; !found {
|
||||
continue // we don't support this specific entry
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
if len(out) <= 0 {
|
||||
return nil, errNoEntries
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// sortstate sorts the state by descending score
|
||||
func sortstate(ri []*resolverinfo) {
|
||||
sort.SliceStable(ri, func(i, j int) bool {
|
||||
return ri[i].Score >= ri[j].Score
|
||||
})
|
||||
}
|
||||
|
||||
// readstatedefault reads the state from disk and merges the state
|
||||
// so that all supported entries are represented.
|
||||
func (r *Resolver) readstatedefault() []*resolverinfo {
|
||||
ri, _ := r.readstateandprune()
|
||||
here := make(map[string]bool)
|
||||
for _, e := range ri {
|
||||
here[e.URL] = true // record what we already have
|
||||
}
|
||||
for _, e := range allmakers {
|
||||
if _, found := here[e.url]; found {
|
||||
continue // already here so no need to add
|
||||
}
|
||||
ri = append(ri, &resolverinfo{
|
||||
URL: e.url,
|
||||
Score: e.score,
|
||||
})
|
||||
}
|
||||
sortstate(ri)
|
||||
return ri
|
||||
}
|
||||
|
||||
// writestate writes the state on the kvstore.
|
||||
func (r *Resolver) writestate(ri []*resolverinfo) error {
|
||||
data, err := r.getCodec().Encode(ri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.kvstore().Set(storekey, data)
|
||||
}
|
120
internal/engine/internal/sessionresolver/state_test.go
Normal file
120
internal/engine/internal/sessionresolver/state_test.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadStateNothingInKVStore(t *testing.T) {
|
||||
reso := &Resolver{KVStore: &memkvstore{}}
|
||||
out, err := reso.readstate()
|
||||
if !errors.Is(err, errMemkvstoreNotFound) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStateDecodeError(t *testing.T) {
|
||||
errMocked := errors.New("mocked error")
|
||||
reso := &Resolver{
|
||||
KVStore: &memkvstore{},
|
||||
codec: &FakeCodec{DecodeErr: errMocked},
|
||||
}
|
||||
if err := reso.KVStore.Set(storekey, []byte(`[]`)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out, err := reso.readstate()
|
||||
if !errors.Is(err, errMocked) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStateAndPruneReadStateError(t *testing.T) {
|
||||
reso := &Resolver{KVStore: &memkvstore{}}
|
||||
out, err := reso.readstateandprune()
|
||||
if !errors.Is(err, errMemkvstoreNotFound) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStateAndPruneWithUnsupportedEntries(t *testing.T) {
|
||||
reso := &Resolver{KVStore: &memkvstore{}}
|
||||
var in []*resolverinfo
|
||||
in = append(in, &resolverinfo{})
|
||||
if err := reso.writestate(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
out, err := reso.readstateandprune()
|
||||
if !errors.Is(err, errNoEntries) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStateDefaultWithMissingEntries(t *testing.T) {
|
||||
reso := &Resolver{KVStore: &memkvstore{}}
|
||||
// let us simulate that we have just one entry here
|
||||
existingURL := "https://dns.google/dns-query"
|
||||
existingScore := 0.88
|
||||
var in []*resolverinfo
|
||||
in = append(in, &resolverinfo{
|
||||
URL: existingURL,
|
||||
Score: existingScore,
|
||||
})
|
||||
if err := reso.writestate(in); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// let us seee what we read
|
||||
out := reso.readstatedefault()
|
||||
if len(out) < 1 {
|
||||
t.Fatal("expected non-empty output")
|
||||
}
|
||||
keys := make(map[string]bool)
|
||||
var found bool
|
||||
for _, e := range out {
|
||||
keys[e.URL] = true
|
||||
if e.URL == existingURL {
|
||||
if e.Score != existingScore {
|
||||
t.Fatal("the score is not what we expected")
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("did not found the pre-loaded URL")
|
||||
}
|
||||
for k := range allbyurl {
|
||||
if _, found := keys[k]; !found {
|
||||
t.Fatal("missing key", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteStateCannotSerialize(t *testing.T) {
|
||||
errMocked := errors.New("mocked error")
|
||||
reso := &Resolver{
|
||||
codec: &FakeCodec{
|
||||
EncodeErr: errMocked,
|
||||
},
|
||||
}
|
||||
existingURL := "https://dns.google/dns-query"
|
||||
existingScore := 0.88
|
||||
var in []*resolverinfo
|
||||
in = append(in, &resolverinfo{
|
||||
URL: existingURL,
|
||||
Score: existingScore,
|
||||
})
|
||||
if err := reso.writestate(in); !errors.Is(err, errMocked) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
}
|
|
@ -109,7 +109,11 @@ func NewSession(config SessionConfig) (*Session, error) {
|
|||
BogonIsError: true,
|
||||
Logger: sess.logger,
|
||||
}
|
||||
sess.resolver = sessionresolver.New(httpConfig)
|
||||
sess.resolver = &sessionresolver.Resolver{
|
||||
ByteCounter: sess.byteCounter,
|
||||
KVStore: config.KVStore,
|
||||
Logger: sess.logger,
|
||||
}
|
||||
httpConfig.FullResolver = sess.resolver
|
||||
httpConfig.ProxyURL = config.ProxyURL // no need to proxy the resolver
|
||||
sess.httpDefaultTransport = netx.NewHTTPTransport(httpConfig)
|
||||
|
|
Loading…
Reference in New Issue
Block a user