refactor(sessionresolver): replace dnsclientmaker with function (#811)
See https://github.com/ooni/probe/issues/2135
This commit is contained in:
parent
a02cc6100b
commit
bf7ea423d3
|
@ -1,32 +0,0 @@
|
||||||
package sessionresolver
|
|
||||||
|
|
||||||
//
|
|
||||||
// Code for mocking the creation of a client.
|
|
||||||
//
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// dnsclientmaker makes a new resolver.
|
|
||||||
type dnsclientmaker interface {
|
|
||||||
// Make makes a new resolver.
|
|
||||||
Make(config netx.Config, URL string) (model.Resolver, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientmaker returns a valid dnsclientmaker
|
|
||||||
func (r *Resolver) clientmaker() dnsclientmaker {
|
|
||||||
if r.dnsClientMaker != nil {
|
|
||||||
return r.dnsClientMaker
|
|
||||||
}
|
|
||||||
return &defaultDNSClientMaker{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultDNSClientMaker is the default dnsclientmaker
|
|
||||||
type defaultDNSClientMaker struct{}
|
|
||||||
|
|
||||||
// Make implements dnsclientmaker.Make.
|
|
||||||
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
|
|
||||||
return netx.NewDNSClient(config, URL)
|
|
||||||
}
|
|
|
@ -1,53 +0,0 @@
|
||||||
package sessionresolver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fakeDNSClientMaker struct {
|
|
||||||
reso model.Resolver
|
|
||||||
err error
|
|
||||||
savedConfig netx.Config
|
|
||||||
savedURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
|
|
||||||
c.savedConfig = config
|
|
||||||
c.savedURL = URL
|
|
||||||
return c.reso, c.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientMakerWithOverride(t *testing.T) {
|
|
||||||
m := &fakeDNSClientMaker{err: io.EOF}
|
|
||||||
reso := &Resolver{dnsClientMaker: m}
|
|
||||||
out, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
|
|
||||||
if !errors.Is(err, io.EOF) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
t.Fatal("expected nil here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientDefaultWithCancelledContext(t *testing.T) {
|
|
||||||
reso := &Resolver{}
|
|
||||||
re, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // fail immediately
|
|
||||||
out, err := re.LookupHost(ctx, "dns.google")
|
|
||||||
if !errors.Is(err, context.Canceled) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
t.Fatal("expected nil output")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -59,13 +59,13 @@ type Resolver struct {
|
||||||
// we will construct a default codec.
|
// we will construct a default codec.
|
||||||
jsonCodec jsonCodec
|
jsonCodec jsonCodec
|
||||||
|
|
||||||
// dnsClientMaker is the OPTIONAL dnsclientmaker to
|
|
||||||
// use. If not set, we will use the default.
|
|
||||||
dnsClientMaker dnsclientmaker
|
|
||||||
|
|
||||||
// mu provides synchronisation of internal fields.
|
// mu provides synchronisation of internal fields.
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// newChildResolverFn is the OPTIONAL function to override
|
||||||
|
// the construction of a new resolver in unit tests
|
||||||
|
newChildResolverFn func(h3 bool, URL string) (model.Resolver, error)
|
||||||
|
|
||||||
// once ensures that CloseIdleConnection is
|
// once ensures that CloseIdleConnection is
|
||||||
// run just once.
|
// run just once.
|
||||||
once sync.Once
|
once sync.Once
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/kvstore"
|
"github.com/ooni/probe-cli/v3/internal/kvstore"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/multierror"
|
"github.com/ooni/probe-cli/v3/internal/multierror"
|
||||||
)
|
)
|
||||||
|
@ -85,12 +86,13 @@ func TestTypicalUsageWithSuccess(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
reso := &Resolver{
|
reso := &Resolver{
|
||||||
KVStore: &kvstore.Memory{},
|
KVStore: &kvstore.Memory{},
|
||||||
dnsClientMaker: &fakeDNSClientMaker{
|
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||||
reso: &mocks.Resolver{
|
reso := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return expected, nil
|
return expected, nil
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
|
return reso, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
addrs, err := reso.LookupHost(ctx, "dns.google")
|
addrs, err := reso.LookupHost(ctx, "dns.google")
|
||||||
|
@ -121,12 +123,13 @@ func TestLittleLLookupHostWithInvalidURL(t *testing.T) {
|
||||||
func TestLittleLLookupHostWithSuccess(t *testing.T) {
|
func TestLittleLLookupHostWithSuccess(t *testing.T) {
|
||||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||||
reso := &Resolver{
|
reso := &Resolver{
|
||||||
dnsClientMaker: &fakeDNSClientMaker{
|
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||||
reso: &mocks.Resolver{
|
reso := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return expected, nil
|
return expected, nil
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
|
return reso, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -146,12 +149,13 @@ func TestLittleLLookupHostWithSuccess(t *testing.T) {
|
||||||
func TestLittleLLookupHostWithFailure(t *testing.T) {
|
func TestLittleLLookupHostWithFailure(t *testing.T) {
|
||||||
errMocked := errors.New("mocked error")
|
errMocked := errors.New("mocked error")
|
||||||
reso := &Resolver{
|
reso := &Resolver{
|
||||||
dnsClientMaker: &fakeDNSClientMaker{
|
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||||
reso: &mocks.Resolver{
|
reso := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return nil, errMocked
|
return nil, errMocked
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
|
return reso, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
@ -64,19 +63,25 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// byteCounter returns the configured byteCounter or a default
|
|
||||||
func (r *Resolver) byteCounter() *bytecounter.Counter {
|
|
||||||
if r.ByteCounter != nil {
|
|
||||||
return r.ByteCounter
|
|
||||||
}
|
|
||||||
return bytecounter.New()
|
|
||||||
}
|
|
||||||
|
|
||||||
// logger returns the configured logger or a default
|
// logger returns the configured logger or a default
|
||||||
func (r *Resolver) logger() model.Logger {
|
func (r *Resolver) logger() model.Logger {
|
||||||
return model.ValidLoggerOrDefault(r.Logger)
|
return model.ValidLoggerOrDefault(r.Logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newChildResolver creates a new child model.Resolver.
|
||||||
|
func (r *Resolver) newChildResolver(h3 bool, URL string) (model.Resolver, error) {
|
||||||
|
if r.newChildResolverFn != nil {
|
||||||
|
return r.newChildResolverFn(h3, URL)
|
||||||
|
}
|
||||||
|
return netx.NewDNSClient(netx.Config{
|
||||||
|
BogonIsError: true,
|
||||||
|
ByteCounter: r.ByteCounter, // nil is handled by netx
|
||||||
|
HTTP3Enabled: h3,
|
||||||
|
Logger: r.logger(),
|
||||||
|
ProxyURL: r.ProxyURL,
|
||||||
|
}, URL)
|
||||||
|
}
|
||||||
|
|
||||||
// newresolver creates a new resolver with the given config and URL. This is
|
// newresolver creates a new resolver with the given config and URL. This is
|
||||||
// where we expand http3 to https and set the h3 options.
|
// where we expand http3 to https and set the h3 options.
|
||||||
func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
|
func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
|
||||||
|
@ -84,13 +89,7 @@ func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
|
||||||
if h3 {
|
if h3 {
|
||||||
URL = strings.Replace(URL, "http3://", "https://", 1)
|
URL = strings.Replace(URL, "http3://", "https://", 1)
|
||||||
}
|
}
|
||||||
return r.clientmaker().Make(netx.Config{
|
return r.newChildResolver(h3, URL)
|
||||||
BogonIsError: true,
|
|
||||||
ByteCounter: r.byteCounter(),
|
|
||||||
HTTP3Enabled: h3,
|
|
||||||
Logger: r.logger(),
|
|
||||||
ProxyURL: r.ProxyURL,
|
|
||||||
}, URL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getresolver returns a resolver with the given URL. This function caches
|
// getresolver returns a resolver with the given URL. This function caches
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package sessionresolver
|
package sessionresolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -10,14 +9,6 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDefaultByteCounter(t *testing.T) {
|
|
||||||
reso := &Resolver{}
|
|
||||||
bc := reso.byteCounter()
|
|
||||||
if bc == nil {
|
|
||||||
t.Fatal("expected non-nil byte counter")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultLogger(t *testing.T) {
|
func TestDefaultLogger(t *testing.T) {
|
||||||
t.Run("when using a different logger", func(t *testing.T) {
|
t.Run("when using a different logger", func(t *testing.T) {
|
||||||
logger := &mocks.Logger{}
|
logger := &mocks.Logger{}
|
||||||
|
@ -46,8 +37,18 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
|
||||||
closed = true
|
closed = true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cmk := &fakeDNSClientMaker{reso: re}
|
var (
|
||||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
savedURL string
|
||||||
|
savedH3 bool
|
||||||
|
)
|
||||||
|
reso := &Resolver{
|
||||||
|
ByteCounter: bc,
|
||||||
|
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||||
|
savedURL = URL
|
||||||
|
savedH3 = h3
|
||||||
|
return re, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
out, err := reso.getresolver(URL)
|
out, err := reso.getresolver(URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -66,20 +67,11 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
|
||||||
if closed != true {
|
if closed != true {
|
||||||
t.Fatal("was not closed")
|
t.Fatal("was not closed")
|
||||||
}
|
}
|
||||||
if cmk.savedURL != URL {
|
if savedURL != URL {
|
||||||
t.Fatal("not the URL we expected")
|
t.Fatal("not the URL we expected")
|
||||||
}
|
}
|
||||||
if cmk.savedConfig.ByteCounter != bc {
|
if savedH3 {
|
||||||
t.Fatal("unexpected ByteCounter")
|
t.Fatal("expected false")
|
||||||
}
|
|
||||||
if cmk.savedConfig.BogonIsError != true {
|
|
||||||
t.Fatal("unexpected BogonIsError")
|
|
||||||
}
|
|
||||||
if cmk.savedConfig.HTTP3Enabled != false {
|
|
||||||
t.Fatal("unexpected HTTP3Enabled")
|
|
||||||
}
|
|
||||||
if cmk.savedConfig.Logger != model.DiscardLogger {
|
|
||||||
t.Fatal("unexpected Log")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,8 +84,18 @@ func TestGetResolverHTTP3(t *testing.T) {
|
||||||
closed = true
|
closed = true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cmk := &fakeDNSClientMaker{reso: re}
|
var (
|
||||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
savedURL string
|
||||||
|
savedH3 bool
|
||||||
|
)
|
||||||
|
reso := &Resolver{
|
||||||
|
ByteCounter: bc,
|
||||||
|
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||||
|
savedURL = URL
|
||||||
|
savedH3 = h3
|
||||||
|
return re, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
out, err := reso.getresolver(URL)
|
out, err := reso.getresolver(URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -112,34 +114,10 @@ func TestGetResolverHTTP3(t *testing.T) {
|
||||||
if closed != true {
|
if closed != true {
|
||||||
t.Fatal("was not closed")
|
t.Fatal("was not closed")
|
||||||
}
|
}
|
||||||
if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) {
|
if savedURL != strings.Replace(URL, "http3://", "https://", 1) {
|
||||||
t.Fatal("not the URL we expected")
|
t.Fatal("not the URL we expected")
|
||||||
}
|
}
|
||||||
if cmk.savedConfig.ByteCounter != bc {
|
if !savedH3 {
|
||||||
t.Fatal("unexpected ByteCounter")
|
t.Fatal("expected true")
|
||||||
}
|
|
||||||
if cmk.savedConfig.BogonIsError != true {
|
|
||||||
t.Fatal("unexpected BogonIsError")
|
|
||||||
}
|
|
||||||
if cmk.savedConfig.HTTP3Enabled != true {
|
|
||||||
t.Fatal("unexpected HTTP3Enabled")
|
|
||||||
}
|
|
||||||
if cmk.savedConfig.Logger != model.DiscardLogger {
|
|
||||||
t.Fatal("unexpected Log")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetResolverInvalidURL(t *testing.T) {
|
|
||||||
bc := bytecounter.New()
|
|
||||||
URL := "http3://dns.google"
|
|
||||||
errMocked := errors.New("mocked error")
|
|
||||||
cmk := &fakeDNSClientMaker{err: errMocked}
|
|
||||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
|
||||||
out, err := reso.getresolver(URL)
|
|
||||||
if !errors.Is(err, errMocked) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
t.Fatal("not the result we expected")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user