refactor(sessionresolver): replace dnsclientmaker with function (#811)
See https://github.com/ooni/probe/issues/2135
This commit is contained in:
parent
a02cc6100b
commit
bf7ea423d3
|
@ -1,32 +0,0 @@
|
|||
package sessionresolver
|
||||
|
||||
//
|
||||
// Code for mocking the creation of a client.
|
||||
//
|
||||
|
||||
import (
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// dnsclientmaker makes a new resolver.
|
||||
type dnsclientmaker interface {
|
||||
// Make makes a new resolver.
|
||||
Make(config netx.Config, URL string) (model.Resolver, error)
|
||||
}
|
||||
|
||||
// clientmaker returns a valid dnsclientmaker
|
||||
func (r *Resolver) clientmaker() dnsclientmaker {
|
||||
if r.dnsClientMaker != nil {
|
||||
return r.dnsClientMaker
|
||||
}
|
||||
return &defaultDNSClientMaker{}
|
||||
}
|
||||
|
||||
// defaultDNSClientMaker is the default dnsclientmaker
|
||||
type defaultDNSClientMaker struct{}
|
||||
|
||||
// Make implements dnsclientmaker.Make.
|
||||
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
|
||||
return netx.NewDNSClient(config, URL)
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
type fakeDNSClientMaker struct {
|
||||
reso model.Resolver
|
||||
err error
|
||||
savedConfig netx.Config
|
||||
savedURL string
|
||||
}
|
||||
|
||||
func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
|
||||
c.savedConfig = config
|
||||
c.savedURL = URL
|
||||
return c.reso, c.err
|
||||
}
|
||||
|
||||
func TestClientMakerWithOverride(t *testing.T) {
|
||||
m := &fakeDNSClientMaker{err: io.EOF}
|
||||
reso := &Resolver{dnsClientMaker: m}
|
||||
out, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDefaultWithCancelledContext(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
re, err := reso.clientmaker().Make(netx.Config{}, "https://dns.google/dns-query")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // fail immediately
|
||||
out, err := re.LookupHost(ctx, "dns.google")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("expected nil output")
|
||||
}
|
||||
}
|
|
@ -59,13 +59,13 @@ type Resolver struct {
|
|||
// we will construct a default codec.
|
||||
jsonCodec jsonCodec
|
||||
|
||||
// dnsClientMaker is the OPTIONAL dnsclientmaker to
|
||||
// use. If not set, we will use the default.
|
||||
dnsClientMaker dnsclientmaker
|
||||
|
||||
// mu provides synchronisation of internal fields.
|
||||
mu sync.Mutex
|
||||
|
||||
// newChildResolverFn is the OPTIONAL function to override
|
||||
// the construction of a new resolver in unit tests
|
||||
newChildResolverFn func(h3 bool, URL string) (model.Resolver, error)
|
||||
|
||||
// once ensures that CloseIdleConnection is
|
||||
// run just once.
|
||||
once sync.Once
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/kvstore"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/multierror"
|
||||
)
|
||||
|
@ -85,12 +86,13 @@ func TestTypicalUsageWithSuccess(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
reso := &Resolver{
|
||||
KVStore: &kvstore.Memory{},
|
||||
dnsClientMaker: &fakeDNSClientMaker{
|
||||
reso: &mocks.Resolver{
|
||||
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||
reso := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return expected, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
return reso, nil
|
||||
},
|
||||
}
|
||||
addrs, err := reso.LookupHost(ctx, "dns.google")
|
||||
|
@ -121,12 +123,13 @@ func TestLittleLLookupHostWithInvalidURL(t *testing.T) {
|
|||
func TestLittleLLookupHostWithSuccess(t *testing.T) {
|
||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||
reso := &Resolver{
|
||||
dnsClientMaker: &fakeDNSClientMaker{
|
||||
reso: &mocks.Resolver{
|
||||
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||
reso := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return expected, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
return reso, nil
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
@ -146,12 +149,13 @@ func TestLittleLLookupHostWithSuccess(t *testing.T) {
|
|||
func TestLittleLLookupHostWithFailure(t *testing.T) {
|
||||
errMocked := errors.New("mocked error")
|
||||
reso := &Resolver{
|
||||
dnsClientMaker: &fakeDNSClientMaker{
|
||||
reso: &mocks.Resolver{
|
||||
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||
reso := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, errMocked
|
||||
},
|
||||
},
|
||||
}
|
||||
return reso, nil
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
@ -64,19 +63,25 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
// byteCounter returns the configured byteCounter or a default
|
||||
func (r *Resolver) byteCounter() *bytecounter.Counter {
|
||||
if r.ByteCounter != nil {
|
||||
return r.ByteCounter
|
||||
}
|
||||
return bytecounter.New()
|
||||
}
|
||||
|
||||
// logger returns the configured logger or a default
|
||||
func (r *Resolver) logger() model.Logger {
|
||||
return model.ValidLoggerOrDefault(r.Logger)
|
||||
}
|
||||
|
||||
// newChildResolver creates a new child model.Resolver.
|
||||
func (r *Resolver) newChildResolver(h3 bool, URL string) (model.Resolver, error) {
|
||||
if r.newChildResolverFn != nil {
|
||||
return r.newChildResolverFn(h3, URL)
|
||||
}
|
||||
return netx.NewDNSClient(netx.Config{
|
||||
BogonIsError: true,
|
||||
ByteCounter: r.ByteCounter, // nil is handled by netx
|
||||
HTTP3Enabled: h3,
|
||||
Logger: r.logger(),
|
||||
ProxyURL: r.ProxyURL,
|
||||
}, URL)
|
||||
}
|
||||
|
||||
// newresolver creates a new resolver with the given config and URL. This is
|
||||
// where we expand http3 to https and set the h3 options.
|
||||
func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
|
||||
|
@ -84,13 +89,7 @@ func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
|
|||
if h3 {
|
||||
URL = strings.Replace(URL, "http3://", "https://", 1)
|
||||
}
|
||||
return r.clientmaker().Make(netx.Config{
|
||||
BogonIsError: true,
|
||||
ByteCounter: r.byteCounter(),
|
||||
HTTP3Enabled: h3,
|
||||
Logger: r.logger(),
|
||||
ProxyURL: r.ProxyURL,
|
||||
}, URL)
|
||||
return r.newChildResolver(h3, URL)
|
||||
}
|
||||
|
||||
// getresolver returns a resolver with the given URL. This function caches
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package sessionresolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -10,14 +9,6 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestDefaultByteCounter(t *testing.T) {
|
||||
reso := &Resolver{}
|
||||
bc := reso.byteCounter()
|
||||
if bc == nil {
|
||||
t.Fatal("expected non-nil byte counter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLogger(t *testing.T) {
|
||||
t.Run("when using a different logger", func(t *testing.T) {
|
||||
logger := &mocks.Logger{}
|
||||
|
@ -46,8 +37,18 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
|
|||
closed = true
|
||||
},
|
||||
}
|
||||
cmk := &fakeDNSClientMaker{reso: re}
|
||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
||||
var (
|
||||
savedURL string
|
||||
savedH3 bool
|
||||
)
|
||||
reso := &Resolver{
|
||||
ByteCounter: bc,
|
||||
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||
savedURL = URL
|
||||
savedH3 = h3
|
||||
return re, nil
|
||||
},
|
||||
}
|
||||
out, err := reso.getresolver(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -66,20 +67,11 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
|
|||
if closed != true {
|
||||
t.Fatal("was not closed")
|
||||
}
|
||||
if cmk.savedURL != URL {
|
||||
if savedURL != URL {
|
||||
t.Fatal("not the URL we expected")
|
||||
}
|
||||
if cmk.savedConfig.ByteCounter != bc {
|
||||
t.Fatal("unexpected ByteCounter")
|
||||
}
|
||||
if cmk.savedConfig.BogonIsError != true {
|
||||
t.Fatal("unexpected BogonIsError")
|
||||
}
|
||||
if cmk.savedConfig.HTTP3Enabled != false {
|
||||
t.Fatal("unexpected HTTP3Enabled")
|
||||
}
|
||||
if cmk.savedConfig.Logger != model.DiscardLogger {
|
||||
t.Fatal("unexpected Log")
|
||||
if savedH3 {
|
||||
t.Fatal("expected false")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,8 +84,18 @@ func TestGetResolverHTTP3(t *testing.T) {
|
|||
closed = true
|
||||
},
|
||||
}
|
||||
cmk := &fakeDNSClientMaker{reso: re}
|
||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
||||
var (
|
||||
savedURL string
|
||||
savedH3 bool
|
||||
)
|
||||
reso := &Resolver{
|
||||
ByteCounter: bc,
|
||||
newChildResolverFn: func(h3 bool, URL string) (model.Resolver, error) {
|
||||
savedURL = URL
|
||||
savedH3 = h3
|
||||
return re, nil
|
||||
},
|
||||
}
|
||||
out, err := reso.getresolver(URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -112,34 +114,10 @@ func TestGetResolverHTTP3(t *testing.T) {
|
|||
if closed != true {
|
||||
t.Fatal("was not closed")
|
||||
}
|
||||
if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) {
|
||||
if savedURL != strings.Replace(URL, "http3://", "https://", 1) {
|
||||
t.Fatal("not the URL we expected")
|
||||
}
|
||||
if cmk.savedConfig.ByteCounter != bc {
|
||||
t.Fatal("unexpected ByteCounter")
|
||||
}
|
||||
if cmk.savedConfig.BogonIsError != true {
|
||||
t.Fatal("unexpected BogonIsError")
|
||||
}
|
||||
if cmk.savedConfig.HTTP3Enabled != true {
|
||||
t.Fatal("unexpected HTTP3Enabled")
|
||||
}
|
||||
if cmk.savedConfig.Logger != model.DiscardLogger {
|
||||
t.Fatal("unexpected Log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetResolverInvalidURL(t *testing.T) {
|
||||
bc := bytecounter.New()
|
||||
URL := "http3://dns.google"
|
||||
errMocked := errors.New("mocked error")
|
||||
cmk := &fakeDNSClientMaker{err: errMocked}
|
||||
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
|
||||
out, err := reso.getresolver(URL)
|
||||
if !errors.Is(err, errMocked) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("not the result we expected")
|
||||
if !savedH3 {
|
||||
t.Fatal("expected true")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user