refactor(netx): move construction logic outside package (#798)
For testability, replace most if-based construction logic with calls to well-tested factories living in other packages. While there, acknowledge that a bunch of types could now be private and make them private, modifying the code to call the public factories allowing to construct said types instead. Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
@@ -285,10 +285,10 @@ func (e *Experiment) OpenReportContext(ctx context.Context) error {
|
||||
}
|
||||
// use custom client to have proper byte accounting
|
||||
httpClient := &http.Client{
|
||||
Transport: &bytecounter.HTTPTransport{
|
||||
HTTPTransport: e.session.httpDefaultTransport, // proxy is OK
|
||||
Counter: e.byteCounter,
|
||||
},
|
||||
Transport: bytecounter.WrapHTTPTransport(
|
||||
e.session.httpDefaultTransport, // proxy is OK
|
||||
e.byteCounter,
|
||||
),
|
||||
}
|
||||
client, err := e.session.NewProbeServicesClient(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,7 +32,7 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
|
||||
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
|
||||
reso := netxlite.NewResolverStdlib(mgr.logger)
|
||||
dlr := netxlite.NewDialerWithResolver(mgr.logger, reso)
|
||||
dlr = bytecounter.NewContextAwareDialer(dlr)
|
||||
dlr = bytecounter.WrapWithContextAwareDialer(dlr)
|
||||
// Implements shaping if the user builds using `-tags shaping`
|
||||
// See https://github.com/ooni/probe/issues/2112
|
||||
dlr = netxlite.NewMaybeShapingDialer(dlr)
|
||||
|
||||
@@ -2,44 +2,90 @@ package netx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
// CacheResolver is a resolver that caches successful replies.
|
||||
type CacheResolver struct {
|
||||
ReadOnly bool
|
||||
model.Resolver
|
||||
mu sync.Mutex
|
||||
cache map[string][]string
|
||||
// MaybeWrapWithCachingResolver wraps the provided resolver with a resolver
|
||||
// that remembers the result of previous successful resolutions, if the enabled
|
||||
// argument is true. Otherwise, we return the unmodified provided resolver.
|
||||
//
|
||||
// Bug: the returned resolver only applies caching to LookupHost and any other
|
||||
// lookup operation returns ErrNoDNSTransport to the caller.
|
||||
func MaybeWrapWithCachingResolver(enabled bool, reso model.Resolver) model.Resolver {
|
||||
if enabled {
|
||||
reso = &cacheResolver{
|
||||
cache: map[string][]string{},
|
||||
mu: sync.Mutex{},
|
||||
readOnly: false,
|
||||
resolver: reso,
|
||||
}
|
||||
}
|
||||
return reso
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r *CacheResolver) LookupHost(
|
||||
// MaybeWrapWithStaticDNSCache wraps the provided resolver with a resolver that
|
||||
// checks the given cache before issuing queries to the underlying DNS resolver.
|
||||
//
|
||||
// Bug: the returned resolver only applies caching to LookupHost and any other
|
||||
// lookup operation returns ErrNoDNSTransport to the caller.
|
||||
func MaybeWrapWithStaticDNSCache(cache map[string][]string, reso model.Resolver) model.Resolver {
|
||||
if len(cache) > 0 {
|
||||
reso = &cacheResolver{
|
||||
cache: cache,
|
||||
mu: sync.Mutex{},
|
||||
readOnly: true,
|
||||
resolver: reso,
|
||||
}
|
||||
}
|
||||
return reso
|
||||
}
|
||||
|
||||
// cacheResolver implements CachingResolver and StaticDNSCache.
|
||||
type cacheResolver struct {
|
||||
// cache is the underlying DNS cache.
|
||||
cache map[string][]string
|
||||
|
||||
// mu provides mutual exclusion.
|
||||
mu sync.Mutex
|
||||
|
||||
// readOnly means that we won't cache the result of successful resolutions.
|
||||
readOnly bool
|
||||
|
||||
// resolver is the underlying resolver.
|
||||
resolver model.Resolver
|
||||
}
|
||||
|
||||
var _ model.Resolver = &cacheResolver{}
|
||||
|
||||
// LookupHost implements model.Resolver.LookupHost
|
||||
func (r *cacheResolver) LookupHost(
|
||||
ctx context.Context, hostname string) ([]string, error) {
|
||||
if entry := r.Get(hostname); entry != nil {
|
||||
if entry := r.get(hostname); entry != nil {
|
||||
return entry, nil
|
||||
}
|
||||
entry, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
entry, err := r.resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !r.ReadOnly {
|
||||
r.Set(hostname, entry)
|
||||
if !r.readOnly {
|
||||
r.set(hostname, entry)
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Get gets the currently configured entry for domain, or nil
|
||||
func (r *CacheResolver) Get(domain string) []string {
|
||||
// get gets the currently configured entry for domain, or nil
|
||||
func (r *cacheResolver) get(domain string) []string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.cache[domain]
|
||||
}
|
||||
|
||||
// Set allows to pre-populate the cache
|
||||
func (r *CacheResolver) Set(domain string, addresses []string) {
|
||||
// set sets a valid inside the cache iff readOnly is false.
|
||||
func (r *cacheResolver) set(domain string, addresses []string) {
|
||||
r.mu.Lock()
|
||||
if r.cache == nil {
|
||||
r.cache = make(map[string][]string)
|
||||
@@ -47,3 +93,28 @@ func (r *CacheResolver) Set(domain string, addresses []string) {
|
||||
r.cache[domain] = addresses
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// Address implements model.Resolver.Address.
|
||||
func (r *cacheResolver) Address() string {
|
||||
return r.resolver.Address()
|
||||
}
|
||||
|
||||
// Network implements model.Resolver.Network.
|
||||
func (r *cacheResolver) Network() string {
|
||||
return r.resolver.Network()
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements model.Resolver.CloseIdleConnections.
|
||||
func (r *cacheResolver) CloseIdleConnections() {
|
||||
r.resolver.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// LookupHTTPS implements model.Resolver.LookupHTTPS.
|
||||
func (r *cacheResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||
return nil, netxlite.ErrNoDNSTransport
|
||||
}
|
||||
|
||||
// LookupNS implements model.Resolver.LookupNS.
|
||||
func (r *cacheResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, netxlite.ErrNoDNSTransport
|
||||
}
|
||||
|
||||
@@ -5,81 +5,202 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
func TestCacheResolverFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
cache := &CacheResolver{Resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs here")
|
||||
}
|
||||
if cache.Get("www.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
func TestMaybeWrapWithCachingResolver(t *testing.T) {
|
||||
t.Run("with enable equal to true", func(t *testing.T) {
|
||||
underlying := &mocks.Resolver{}
|
||||
reso := MaybeWrapWithCachingResolver(true, underlying)
|
||||
cachereso := reso.(*cacheResolver)
|
||||
if cachereso.resolver != underlying {
|
||||
t.Fatal("did not wrap correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with enable equal to false", func(t *testing.T) {
|
||||
underlying := &mocks.Resolver{}
|
||||
reso := MaybeWrapWithCachingResolver(false, underlying)
|
||||
if reso != underlying {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheResolverHitSuccess(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
cache := &CacheResolver{Resolver: r}
|
||||
cache.Set("dns.google.com", []string{"8.8.8.8"})
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
func TestMaybeWrapWithStaticDNSCache(t *testing.T) {
|
||||
t.Run("when the cache is not empty", func(t *testing.T) {
|
||||
cachedDomain := "dns.google"
|
||||
expectedEntry := []string{"8.8.8.8", "8.8.4.4"}
|
||||
underlyingCache := make(map[string][]string)
|
||||
underlyingCache[cachedDomain] = expectedEntry
|
||||
underlyingReso := &mocks.Resolver{}
|
||||
reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso)
|
||||
cachereso := reso.(*cacheResolver)
|
||||
if diff := cmp.Diff(cachereso.cache, underlyingCache); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if cachereso.resolver != underlyingReso {
|
||||
t.Fatal("unexpected underlying resolver")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when the cache is empty", func(t *testing.T) {
|
||||
underlyingCache := make(map[string][]string)
|
||||
underlyingReso := &mocks.Resolver{}
|
||||
reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso)
|
||||
if reso != underlyingReso {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when the cache is nil", func(t *testing.T) {
|
||||
var underlyingCache map[string][]string
|
||||
underlyingReso := &mocks.Resolver{}
|
||||
reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso)
|
||||
if reso != underlyingReso {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheResolverMissSuccess(t *testing.T) {
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
cache := &CacheResolver{Resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.Get("dns.google.com")[0] != "8.8.8.8" {
|
||||
t.Fatal("expected full cache here")
|
||||
}
|
||||
}
|
||||
func TestCacheResolver(t *testing.T) {
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("cache miss and failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
cache := &cacheResolver{resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs here")
|
||||
}
|
||||
if cache.get("www.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
})
|
||||
|
||||
func TestCacheResolverReadonlySuccess(t *testing.T) {
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
cache := &CacheResolver{Resolver: r, ReadOnly: true}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.Get("dns.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
t.Run("cache hit", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
cache := &cacheResolver{resolver: r}
|
||||
cache.set("dns.google.com", []string{"8.8.8.8"})
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cache miss and success with readwrite cache", func(t *testing.T) {
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
cache := &cacheResolver{resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.get("dns.google.com")[0] != "8.8.8.8" {
|
||||
t.Fatal("expected full cache here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cache miss and success with readonly cache", func(t *testing.T) {
|
||||
r := &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
cache := &cacheResolver{resolver: r, readOnly: true}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.get("dns.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Address", func(t *testing.T) {
|
||||
underlying := &mocks.Resolver{
|
||||
MockAddress: func() string {
|
||||
return "x"
|
||||
},
|
||||
}
|
||||
reso := &cacheResolver{resolver: underlying}
|
||||
if reso.Address() != "x" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network", func(t *testing.T) {
|
||||
underlying := &mocks.Resolver{
|
||||
MockNetwork: func() string {
|
||||
return "x"
|
||||
},
|
||||
}
|
||||
reso := &cacheResolver{resolver: underlying}
|
||||
if reso.Network() != "x" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
underlying := &mocks.Resolver{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
reso := &cacheResolver{resolver: underlying}
|
||||
reso.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||
reso := &cacheResolver{}
|
||||
https, err := reso.LookupHTTPS(context.Background(), "dns.google")
|
||||
if !errors.Is(err, netxlite.ErrNoDNSTransport) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
reso := &cacheResolver{}
|
||||
ns, err := reso.LookupNS(context.Background(), "dns.google")
|
||||
if !errors.Is(err, netxlite.ErrNoDNSTransport) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(ns) != 0 {
|
||||
t.Fatal("expected zero length slice")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -67,19 +67,9 @@ func NewResolver(config Config) model.Resolver {
|
||||
model.ValidLoggerOrDefault(config.Logger),
|
||||
config.BaseResolver,
|
||||
)
|
||||
if config.CacheResolutions {
|
||||
r = &CacheResolver{Resolver: r}
|
||||
}
|
||||
if config.DNSCache != nil {
|
||||
cache := &CacheResolver{Resolver: r, ReadOnly: true}
|
||||
for key, values := range config.DNSCache {
|
||||
cache.Set(key, values)
|
||||
}
|
||||
r = cache
|
||||
}
|
||||
if config.BogonIsError {
|
||||
r = &netxlite.BogonResolver{Resolver: r}
|
||||
}
|
||||
r = MaybeWrapWithCachingResolver(config.CacheResolutions, r)
|
||||
r = MaybeWrapWithStaticDNSCache(config.DNSCache, r)
|
||||
r = netxlite.MaybeWrapWithBogonResolver(config.BogonIsError, r)
|
||||
return config.Saver.WrapResolver(r) // WAI when config.Saver==nil
|
||||
}
|
||||
|
||||
@@ -94,9 +84,7 @@ func NewDialer(config Config) model.Dialer {
|
||||
config.ReadWriteSaver.NewReadWriteObserver(),
|
||||
)
|
||||
d = netxlite.NewMaybeProxyDialer(d, config.ProxyURL)
|
||||
if config.ContextByteCounting {
|
||||
d = &bytecounter.ContextAwareDialer{Dialer: d}
|
||||
}
|
||||
d = bytecounter.MaybeWrapWithContextAwareDialer(config.ContextByteCounting, d)
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -143,15 +131,12 @@ func NewHTTPTransport(config Config) model.HTTPTransport {
|
||||
TLSDialer: config.TLSDialer,
|
||||
TLSConfig: config.TLSConfig,
|
||||
})
|
||||
if config.ByteCounter != nil {
|
||||
txp = &bytecounter.HTTPTransport{
|
||||
Counter: config.ByteCounter, HTTPTransport: txp}
|
||||
}
|
||||
if config.Saver != nil {
|
||||
txp = &tracex.HTTPTransportSaver{
|
||||
HTTPTransport: txp, Saver: config.Saver}
|
||||
}
|
||||
return txp
|
||||
// TODO(bassosimone): I am not super convinced by this code because it
|
||||
// seems we're currently counting bytes twice in some cases. I think we
|
||||
// should review how we're counting bytes and using netx currently.
|
||||
txp = config.ByteCounter.MaybeWrapHTTPTransport(txp) // WAI with ByteCounter == nil
|
||||
const defaultSnapshotSize = 0 // means: use the default snapsize
|
||||
return config.Saver.MaybeWrapHTTPTransport(txp, defaultSnapshotSize) // WAI with Saver == nil
|
||||
}
|
||||
|
||||
// httpTransportInfo contains the constructing function as well as the transport name
|
||||
|
||||
@@ -100,21 +100,6 @@ func TestNewWithDialer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithByteCounter(t *testing.T) {
|
||||
counter := bytecounter.New()
|
||||
txp := NewHTTPTransport(Config{
|
||||
ByteCounter: counter,
|
||||
})
|
||||
bctxp, ok := txp.(*bytecounter.HTTPTransport)
|
||||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if bctxp.Counter != counter {
|
||||
t.Fatal("not the byte counter we expected")
|
||||
}
|
||||
// We are going to trust the underlying transport returned by netxlite
|
||||
}
|
||||
|
||||
func TestNewWithSaver(t *testing.T) {
|
||||
saver := new(tracex.Saver)
|
||||
txp := NewHTTPTransport(Config{
|
||||
|
||||
@@ -202,7 +202,7 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) {
|
||||
handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger)
|
||||
tlsDialer := netxlite.NewTLSDialer(dialer, handshaker)
|
||||
txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer)
|
||||
txp = bytecounter.NewHTTPTransport(txp, sess.byteCounter)
|
||||
txp = bytecounter.WrapHTTPTransport(txp, sess.byteCounter)
|
||||
sess.httpDefaultTransport = txp
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user