diff --git a/internal/bytecounter/conn.go b/internal/bytecounter/conn.go index 3f37dff..65533fa 100644 --- a/internal/bytecounter/conn.go +++ b/internal/bytecounter/conn.go @@ -6,8 +6,8 @@ package bytecounter import "net" -// Conn wraps a network connection and counts bytes. -type Conn struct { +// wrappedConn wraps a network connection and counts bytes. +type wrappedConn struct { // net.Conn is the underlying net.Conn. net.Conn @@ -16,28 +16,28 @@ type Conn struct { } // Read implements net.Conn.Read. -func (c *Conn) Read(p []byte) (int, error) { +func (c *wrappedConn) Read(p []byte) (int, error) { count, err := c.Conn.Read(p) c.Counter.CountBytesReceived(count) return count, err } // Write implements net.Conn.Write. -func (c *Conn) Write(p []byte) (int, error) { +func (c *wrappedConn) Write(p []byte) (int, error) { count, err := c.Conn.Write(p) c.Counter.CountBytesSent(count) return count, err } -// Wrap returns a new conn that uses the given counter. -func Wrap(conn net.Conn, counter *Counter) net.Conn { - return &Conn{Conn: conn, Counter: counter} +// WrapConn returns a new conn that uses the given counter. +func WrapConn(conn net.Conn, counter *Counter) net.Conn { + return &wrappedConn{Conn: conn, Counter: counter} } -// MaybeWrap is like wrap if counter is not nil, otherwise it's a no-op. -func MaybeWrap(conn net.Conn, counter *Counter) net.Conn { +// MaybeWrapConn is like wrap if counter is not nil, otherwise it's a no-op. +func MaybeWrapConn(conn net.Conn, counter *Counter) net.Conn { if counter == nil { return conn } - return Wrap(conn, counter) + return WrapConn(conn, counter) } diff --git a/internal/bytecounter/conn_test.go b/internal/bytecounter/conn_test.go index 30b1b9a..774e35e 100644 --- a/internal/bytecounter/conn_test.go +++ b/internal/bytecounter/conn_test.go @@ -7,7 +7,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" ) -func TestConnWorksOnSuccess(t *testing.T) { +func TestWrappedConnWorksOnSuccess(t *testing.T) { counter := New() underlying := &mocks.Conn{ MockRead: func(b []byte) (int, error) { @@ -17,7 +17,7 @@ func TestConnWorksOnSuccess(t *testing.T) { return 4, nil }, } - conn := &Conn{ + conn := &wrappedConn{ Conn: underlying, Counter: counter, } @@ -35,7 +35,7 @@ func TestConnWorksOnSuccess(t *testing.T) { } } -func TestConnWorksOnFailure(t *testing.T) { +func TestWrappedConnWorksOnFailure(t *testing.T) { readError := errors.New("read error") writeError := errors.New("write error") counter := New() @@ -47,7 +47,7 @@ func TestConnWorksOnFailure(t *testing.T) { return 0, writeError }, } - conn := &Conn{ + conn := &wrappedConn{ Conn: underlying, Counter: counter, } @@ -65,20 +65,20 @@ func TestConnWorksOnFailure(t *testing.T) { } } -func TestWrap(t *testing.T) { +func TestWrapConn(t *testing.T) { conn := &mocks.Conn{} counter := New() - nconn := Wrap(conn, counter) - _, good := nconn.(*Conn) + nconn := WrapConn(conn, counter) + _, good := nconn.(*wrappedConn) if !good { t.Fatal("did not wrap") } } -func TestMaybeWrap(t *testing.T) { +func TestMaybeWrapConn(t *testing.T) { t.Run("with nil counter", func(t *testing.T) { conn := &mocks.Conn{} - nconn := MaybeWrap(conn, nil) + nconn := MaybeWrapConn(conn, nil) _, good := nconn.(*mocks.Conn) if !good { t.Fatal("did not wrap") @@ -88,8 +88,8 @@ func TestMaybeWrap(t *testing.T) { t.Run("with legit counter", func(t *testing.T) { conn := &mocks.Conn{} counter := New() - nconn := MaybeWrap(conn, counter) - _, good := nconn.(*Conn) + nconn := MaybeWrapConn(conn, counter) + _, good := nconn.(*wrappedConn) if !good { t.Fatal("did not wrap") } diff --git a/internal/bytecounter/context.go b/internal/bytecounter/context.go index 64b9d2a..3817250 100644 --- a/internal/bytecounter/context.go +++ b/internal/bytecounter/context.go @@ -38,7 +38,7 @@ func WithExperimentByteCounter(ctx context.Context, counter *Counter) context.Co // MaybeWrapWithContextByteCounters wraps a conn with the byte counters // that have previosuly been configured into a context. func MaybeWrapWithContextByteCounters(ctx context.Context, conn net.Conn) net.Conn { - conn = MaybeWrap(conn, ContextExperimentByteCounter(ctx)) - conn = MaybeWrap(conn, ContextSessionByteCounter(ctx)) + conn = MaybeWrapConn(conn, ContextExperimentByteCounter(ctx)) + conn = MaybeWrapConn(conn, ContextSessionByteCounter(ctx)) return conn } diff --git a/internal/bytecounter/counter_test.go b/internal/bytecounter/counter_test.go index dd0d9d9..b60494c 100644 --- a/internal/bytecounter/counter_test.go +++ b/internal/bytecounter/counter_test.go @@ -2,7 +2,7 @@ package bytecounter import "testing" -func TestGood(t *testing.T) { +func TestCounter(t *testing.T) { counter := New() counter.CountBytesReceived(16384) counter.CountKibiBytesReceived(10) diff --git a/internal/bytecounter/dialer.go b/internal/bytecounter/dialer.go index d4246e5..7826549 100644 --- a/internal/bytecounter/dialer.go +++ b/internal/bytecounter/dialer.go @@ -11,8 +11,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// ContextAwareDialer is a model.Dialer that attempts to count bytes using -// the MaybeWrapWithContextByteCounters function. +// MaybeWrapWithContextAwareDialer wraps the given dialer with a ContextAwareDialer +// if the enabled argument is true and otherwise just returns the given dialer. // // Bug // @@ -24,19 +24,29 @@ import ( // // For this reason, this implementation may be heavily changed/removed // in the future (<- this message is now ~two years old, though). -type ContextAwareDialer struct { +func MaybeWrapWithContextAwareDialer(enabled bool, dialer model.Dialer) model.Dialer { + if !enabled { + return dialer + } + return WrapWithContextAwareDialer(dialer) +} + +// contextAwareDialer is a model.Dialer that attempts to count bytes using +// the MaybeWrapWithContextByteCounters function. +type contextAwareDialer struct { Dialer model.Dialer } -// NewContextAwareDialer creates a new ContextAwareDialer. -func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer { - return &ContextAwareDialer{Dialer: dialer} +// WrapWithContextAwareDialer creates a new ContextAwareDialer. See the docs +// of MaybeWrapWithContextAwareDialer for a list of caveats. +func WrapWithContextAwareDialer(dialer model.Dialer) *contextAwareDialer { + return &contextAwareDialer{Dialer: dialer} } -var _ model.Dialer = &ContextAwareDialer{} +var _ model.Dialer = &contextAwareDialer{} // DialContext implements Dialer.DialContext -func (d *ContextAwareDialer) DialContext( +func (d *contextAwareDialer) DialContext( ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { @@ -47,6 +57,6 @@ func (d *ContextAwareDialer) DialContext( } // CloseIdleConnections implements Dialer.CloseIdleConnections. -func (d *ContextAwareDialer) CloseIdleConnections() { +func (d *contextAwareDialer) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } diff --git a/internal/bytecounter/dialer_test.go b/internal/bytecounter/dialer_test.go index 35bdd94..953c66b 100644 --- a/internal/bytecounter/dialer_test.go +++ b/internal/bytecounter/dialer_test.go @@ -10,6 +10,25 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +func TestMaybeWrapWithContextAwareDialer(t *testing.T) { + t.Run("when enabled is true", func(t *testing.T) { + underlying := &mocks.Dialer{} + dialer := MaybeWrapWithContextAwareDialer(true, underlying) + realDialer := dialer.(*contextAwareDialer) + if realDialer.Dialer != underlying { + t.Fatal("did not wrap correctly") + } + }) + + t.Run("when enabled is false", func(t *testing.T) { + underlying := &mocks.Dialer{} + dialer := MaybeWrapWithContextAwareDialer(false, underlying) + if dialer != underlying { + t.Fatal("unexpected result") + } + }) +} + func TestContextAwareDialer(t *testing.T) { t.Run("DialContext", func(t *testing.T) { dialAndUseConn := func(ctx context.Context, bufsiz int) error { @@ -26,7 +45,7 @@ func TestContextAwareDialer(t *testing.T) { return childConn, nil }, } - dialer := NewContextAwareDialer(child) + dialer := WrapWithContextAwareDialer(child) conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443") if err != nil { return err @@ -68,7 +87,7 @@ func TestContextAwareDialer(t *testing.T) { }) t.Run("failure", func(t *testing.T) { - dialer := &ContextAwareDialer{ + dialer := &contextAwareDialer{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF @@ -92,7 +111,7 @@ func TestContextAwareDialer(t *testing.T) { called = true }, } - dialer := NewContextAwareDialer(child) + dialer := WrapWithContextAwareDialer(child) dialer.CloseIdleConnections() if !called { t.Fatal("not called") diff --git a/internal/bytecounter/http.go b/internal/bytecounter/http.go index 0032ae0..f892623 100644 --- a/internal/bytecounter/http.go +++ b/internal/bytecounter/http.go @@ -7,29 +7,39 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// HTTPTransport is a model.HTTPTransport that counts bytes. -type HTTPTransport struct { +// MaybeWrapHTTPTransport takes in input an HTTPTransport and either wraps it +// to perform byte counting, if this counter is not nil, or just returns to the +// caller the original transport, when the counter is nil. +func (c *Counter) MaybeWrapHTTPTransport(txp model.HTTPTransport) model.HTTPTransport { + if c != nil { + txp = WrapHTTPTransport(txp, c) + } + return txp +} + +// httpTransport is a model.HTTPTransport that counts bytes. +type httpTransport struct { HTTPTransport model.HTTPTransport Counter *Counter } -// NewHTTPTransport creates a new byte-counting-aware HTTP transport. -func NewHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport { - return &HTTPTransport{ +// WrapHTTPTransport creates a new byte-counting-aware HTTP transport. +func WrapHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport { + return &httpTransport{ HTTPTransport: txp, Counter: counter, } } -var _ model.HTTPTransport = &HTTPTransport{} +var _ model.HTTPTransport = &httpTransport{} // CloseIdleConnections implements model.HTTPTransport.CloseIdleConnections. -func (txp *HTTPTransport) CloseIdleConnections() { +func (txp *httpTransport) CloseIdleConnections() { txp.HTTPTransport.CloseIdleConnections() } // RoundTrip implements model.HTTPTRansport.RoundTrip -func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (txp *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) { if req.Body != nil { req.Body = &httpBodyWrapper{ account: txp.Counter.CountBytesSent, @@ -50,11 +60,11 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { } // Network implements model.HTTPTransport.Network. -func (txp *HTTPTransport) Network() string { +func (txp *httpTransport) Network() string { return txp.HTTPTransport.Network() } -func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) { +func (txp *httpTransport) estimateRequestMetadata(req *http.Request) { txp.Counter.CountBytesSent(len(req.Method)) txp.Counter.CountBytesSent(len(req.URL.String())) for key, values := range req.Header { @@ -68,7 +78,7 @@ func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) { txp.Counter.CountBytesSent(len("\r\n")) } -func (txp *HTTPTransport) estimateResponseMetadata(resp *http.Response) { +func (txp *httpTransport) estimateResponseMetadata(resp *http.Response) { txp.Counter.CountBytesReceived(len(resp.Status)) for key, values := range resp.Header { for _, value := range values { diff --git a/internal/bytecounter/http_test.go b/internal/bytecounter/http_test.go index d83eaa0..6616c01 100644 --- a/internal/bytecounter/http_test.go +++ b/internal/bytecounter/http_test.go @@ -12,11 +12,32 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) +func TestMaybeWrapHTTPTransport(t *testing.T) { + t.Run("when counter is not nil", func(t *testing.T) { + underlying := &mocks.HTTPTransport{} + counter := &Counter{} + txp := counter.MaybeWrapHTTPTransport(underlying) + realTxp := txp.(*httpTransport) + if realTxp.HTTPTransport != underlying { + t.Fatal("did not wrap correctly") + } + }) + + t.Run("when counter is nil", func(t *testing.T) { + underlying := &mocks.HTTPTransport{} + var counter *Counter + txp := counter.MaybeWrapHTTPTransport(underlying) + if txp != underlying { + t.Fatal("unexpected result") + } + }) +} + func TestHTTPTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { t.Run("failure", func(t *testing.T) { counter := New() - txp := &HTTPTransport{ + txp := &httpTransport{ Counter: counter, HTTPTransport: &mocks.HTTPTransport{ MockRoundTrip: func(req *http.Request) (*http.Response, error) { @@ -47,7 +68,7 @@ func TestHTTPTransport(t *testing.T) { t.Run("success", func(t *testing.T) { counter := New() - txp := &HTTPTransport{ + txp := &httpTransport{ Counter: counter, HTTPTransport: &mocks.HTTPTransport{ MockRoundTrip: func(req *http.Request) (*http.Response, error) { @@ -91,7 +112,7 @@ func TestHTTPTransport(t *testing.T) { t.Run("success with EOF", func(t *testing.T) { counter := New() - txp := &HTTPTransport{ + txp := &httpTransport{ Counter: counter, HTTPTransport: &mocks.HTTPTransport{ MockRoundTrip: func(req *http.Request) (*http.Response, error) { @@ -139,7 +160,7 @@ func TestHTTPTransport(t *testing.T) { }, } counter := New() - txp := NewHTTPTransport(child, counter) + txp := WrapHTTPTransport(child, counter) txp.CloseIdleConnections() if !called { t.Fatal("not called") @@ -154,7 +175,7 @@ func TestHTTPTransport(t *testing.T) { }, } counter := New() - txp := NewHTTPTransport(child, counter) + txp := WrapHTTPTransport(child, counter) if network := txp.Network(); network != expected { t.Fatal("unexpected network", network) } diff --git a/internal/engine/experiment.go b/internal/engine/experiment.go index 59562e7..e271075 100644 --- a/internal/engine/experiment.go +++ b/internal/engine/experiment.go @@ -285,10 +285,10 @@ func (e *Experiment) OpenReportContext(ctx context.Context) error { } // use custom client to have proper byte accounting httpClient := &http.Client{ - Transport: &bytecounter.HTTPTransport{ - HTTPTransport: e.session.httpDefaultTransport, // proxy is OK - Counter: e.byteCounter, - }, + Transport: bytecounter.WrapHTTPTransport( + e.session.httpDefaultTransport, // proxy is OK + e.byteCounter, + ), } client, err := e.session.NewProbeServicesClient(ctx) if err != nil { diff --git a/internal/engine/experiment/ndt7/dial.go b/internal/engine/experiment/ndt7/dial.go index 0e02ce9..e81e941 100644 --- a/internal/engine/experiment/ndt7/dial.go +++ b/internal/engine/experiment/ndt7/dial.go @@ -32,7 +32,7 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { reso := netxlite.NewResolverStdlib(mgr.logger) dlr := netxlite.NewDialerWithResolver(mgr.logger, reso) - dlr = bytecounter.NewContextAwareDialer(dlr) + dlr = bytecounter.WrapWithContextAwareDialer(dlr) // Implements shaping if the user builds using `-tags shaping` // See https://github.com/ooni/probe/issues/2112 dlr = netxlite.NewMaybeShapingDialer(dlr) diff --git a/internal/engine/netx/cacheresolver.go b/internal/engine/netx/cacheresolver.go index b00a6b8..0979c4e 100644 --- a/internal/engine/netx/cacheresolver.go +++ b/internal/engine/netx/cacheresolver.go @@ -2,44 +2,90 @@ package netx import ( "context" + "net" "sync" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// CacheResolver is a resolver that caches successful replies. -type CacheResolver struct { - ReadOnly bool - model.Resolver - mu sync.Mutex - cache map[string][]string +// MaybeWrapWithCachingResolver wraps the provided resolver with a resolver +// that remembers the result of previous successful resolutions, if the enabled +// argument is true. Otherwise, we return the unmodified provided resolver. +// +// Bug: the returned resolver only applies caching to LookupHost and any other +// lookup operation returns ErrNoDNSTransport to the caller. +func MaybeWrapWithCachingResolver(enabled bool, reso model.Resolver) model.Resolver { + if enabled { + reso = &cacheResolver{ + cache: map[string][]string{}, + mu: sync.Mutex{}, + readOnly: false, + resolver: reso, + } + } + return reso } -// LookupHost implements Resolver.LookupHost -func (r *CacheResolver) LookupHost( +// MaybeWrapWithStaticDNSCache wraps the provided resolver with a resolver that +// checks the given cache before issuing queries to the underlying DNS resolver. +// +// Bug: the returned resolver only applies caching to LookupHost and any other +// lookup operation returns ErrNoDNSTransport to the caller. +func MaybeWrapWithStaticDNSCache(cache map[string][]string, reso model.Resolver) model.Resolver { + if len(cache) > 0 { + reso = &cacheResolver{ + cache: cache, + mu: sync.Mutex{}, + readOnly: true, + resolver: reso, + } + } + return reso +} + +// cacheResolver implements CachingResolver and StaticDNSCache. +type cacheResolver struct { + // cache is the underlying DNS cache. + cache map[string][]string + + // mu provides mutual exclusion. + mu sync.Mutex + + // readOnly means that we won't cache the result of successful resolutions. + readOnly bool + + // resolver is the underlying resolver. + resolver model.Resolver +} + +var _ model.Resolver = &cacheResolver{} + +// LookupHost implements model.Resolver.LookupHost +func (r *cacheResolver) LookupHost( ctx context.Context, hostname string) ([]string, error) { - if entry := r.Get(hostname); entry != nil { + if entry := r.get(hostname); entry != nil { return entry, nil } - entry, err := r.Resolver.LookupHost(ctx, hostname) + entry, err := r.resolver.LookupHost(ctx, hostname) if err != nil { return nil, err } - if !r.ReadOnly { - r.Set(hostname, entry) + if !r.readOnly { + r.set(hostname, entry) } return entry, nil } -// Get gets the currently configured entry for domain, or nil -func (r *CacheResolver) Get(domain string) []string { +// get gets the currently configured entry for domain, or nil +func (r *cacheResolver) get(domain string) []string { r.mu.Lock() defer r.mu.Unlock() return r.cache[domain] } -// Set allows to pre-populate the cache -func (r *CacheResolver) Set(domain string, addresses []string) { +// set sets a valid inside the cache iff readOnly is false. +func (r *cacheResolver) set(domain string, addresses []string) { r.mu.Lock() if r.cache == nil { r.cache = make(map[string][]string) @@ -47,3 +93,28 @@ func (r *CacheResolver) Set(domain string, addresses []string) { r.cache[domain] = addresses r.mu.Unlock() } + +// Address implements model.Resolver.Address. +func (r *cacheResolver) Address() string { + return r.resolver.Address() +} + +// Network implements model.Resolver.Network. +func (r *cacheResolver) Network() string { + return r.resolver.Network() +} + +// CloseIdleConnections implements model.Resolver.CloseIdleConnections. +func (r *cacheResolver) CloseIdleConnections() { + r.resolver.CloseIdleConnections() +} + +// LookupHTTPS implements model.Resolver.LookupHTTPS. +func (r *cacheResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return nil, netxlite.ErrNoDNSTransport +} + +// LookupNS implements model.Resolver.LookupNS. +func (r *cacheResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { + return nil, netxlite.ErrNoDNSTransport +} diff --git a/internal/engine/netx/cacheresolver_test.go b/internal/engine/netx/cacheresolver_test.go index d0ab603..4bda96c 100644 --- a/internal/engine/netx/cacheresolver_test.go +++ b/internal/engine/netx/cacheresolver_test.go @@ -5,81 +5,202 @@ import ( "errors" "testing" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestCacheResolverFailure(t *testing.T) { - expected := errors.New("mocked error") - r := &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, expected - }, - } - cache := &CacheResolver{Resolver: r} - addrs, err := cache.LookupHost(context.Background(), "www.google.com") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil addrs here") - } - if cache.Get("www.google.com") != nil { - t.Fatal("expected empty cache here") - } +func TestMaybeWrapWithCachingResolver(t *testing.T) { + t.Run("with enable equal to true", func(t *testing.T) { + underlying := &mocks.Resolver{} + reso := MaybeWrapWithCachingResolver(true, underlying) + cachereso := reso.(*cacheResolver) + if cachereso.resolver != underlying { + t.Fatal("did not wrap correctly") + } + }) + + t.Run("with enable equal to false", func(t *testing.T) { + underlying := &mocks.Resolver{} + reso := MaybeWrapWithCachingResolver(false, underlying) + if reso != underlying { + t.Fatal("unexpected result") + } + }) } -func TestCacheResolverHitSuccess(t *testing.T) { - expected := errors.New("mocked error") - r := &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, expected - }, - } - cache := &CacheResolver{Resolver: r} - cache.Set("dns.google.com", []string{"8.8.8.8"}) - addrs, err := cache.LookupHost(context.Background(), "dns.google.com") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "8.8.8.8" { - t.Fatal("not the result we expected") - } +func TestMaybeWrapWithStaticDNSCache(t *testing.T) { + t.Run("when the cache is not empty", func(t *testing.T) { + cachedDomain := "dns.google" + expectedEntry := []string{"8.8.8.8", "8.8.4.4"} + underlyingCache := make(map[string][]string) + underlyingCache[cachedDomain] = expectedEntry + underlyingReso := &mocks.Resolver{} + reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso) + cachereso := reso.(*cacheResolver) + if diff := cmp.Diff(cachereso.cache, underlyingCache); diff != "" { + t.Fatal(diff) + } + if cachereso.resolver != underlyingReso { + t.Fatal("unexpected underlying resolver") + } + }) + + t.Run("when the cache is empty", func(t *testing.T) { + underlyingCache := make(map[string][]string) + underlyingReso := &mocks.Resolver{} + reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso) + if reso != underlyingReso { + t.Fatal("unexpected result") + } + }) + + t.Run("when the cache is nil", func(t *testing.T) { + var underlyingCache map[string][]string + underlyingReso := &mocks.Resolver{} + reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso) + if reso != underlyingReso { + t.Fatal("unexpected result") + } + }) } -func TestCacheResolverMissSuccess(t *testing.T) { - r := &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"8.8.8.8"}, nil - }, - } - cache := &CacheResolver{Resolver: r} - addrs, err := cache.LookupHost(context.Background(), "dns.google.com") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "8.8.8.8" { - t.Fatal("not the result we expected") - } - if cache.Get("dns.google.com")[0] != "8.8.8.8" { - t.Fatal("expected full cache here") - } -} +func TestCacheResolver(t *testing.T) { + t.Run("LookupHost", func(t *testing.T) { + t.Run("cache miss and failure", func(t *testing.T) { + expected := errors.New("mocked error") + r := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + } + cache := &cacheResolver{resolver: r} + addrs, err := cache.LookupHost(context.Background(), "www.google.com") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil addrs here") + } + if cache.get("www.google.com") != nil { + t.Fatal("expected empty cache here") + } + }) -func TestCacheResolverReadonlySuccess(t *testing.T) { - r := &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"8.8.8.8"}, nil - }, - } - cache := &CacheResolver{Resolver: r, ReadOnly: true} - addrs, err := cache.LookupHost(context.Background(), "dns.google.com") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "8.8.8.8" { - t.Fatal("not the result we expected") - } - if cache.Get("dns.google.com") != nil { - t.Fatal("expected empty cache here") - } + t.Run("cache hit", func(t *testing.T) { + expected := errors.New("mocked error") + r := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + } + cache := &cacheResolver{resolver: r} + cache.set("dns.google.com", []string{"8.8.8.8"}) + addrs, err := cache.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } + }) + + t.Run("cache miss and success with readwrite cache", func(t *testing.T) { + r := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"8.8.8.8"}, nil + }, + } + cache := &cacheResolver{resolver: r} + addrs, err := cache.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } + if cache.get("dns.google.com")[0] != "8.8.8.8" { + t.Fatal("expected full cache here") + } + }) + + t.Run("cache miss and success with readonly cache", func(t *testing.T) { + r := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"8.8.8.8"}, nil + }, + } + cache := &cacheResolver{resolver: r, readOnly: true} + addrs, err := cache.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } + if cache.get("dns.google.com") != nil { + t.Fatal("expected empty cache here") + } + }) + + t.Run("Address", func(t *testing.T) { + underlying := &mocks.Resolver{ + MockAddress: func() string { + return "x" + }, + } + reso := &cacheResolver{resolver: underlying} + if reso.Address() != "x" { + t.Fatal("unexpected result") + } + }) + + t.Run("Network", func(t *testing.T) { + underlying := &mocks.Resolver{ + MockNetwork: func() string { + return "x" + }, + } + reso := &cacheResolver{resolver: underlying} + if reso.Network() != "x" { + t.Fatal("unexpected result") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + underlying := &mocks.Resolver{ + MockCloseIdleConnections: func() { + called = true + }, + } + reso := &cacheResolver{resolver: underlying} + reso.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("LookupHTTPS", func(t *testing.T) { + reso := &cacheResolver{} + https, err := reso.LookupHTTPS(context.Background(), "dns.google") + if !errors.Is(err, netxlite.ErrNoDNSTransport) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("expected nil") + } + }) + + t.Run("LookupNS", func(t *testing.T) { + reso := &cacheResolver{} + ns, err := reso.LookupNS(context.Background(), "dns.google") + if !errors.Is(err, netxlite.ErrNoDNSTransport) { + t.Fatal("unexpected err", err) + } + if len(ns) != 0 { + t.Fatal("expected zero length slice") + } + }) + }) } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index c12f1bd..7301cc6 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -67,19 +67,9 @@ func NewResolver(config Config) model.Resolver { model.ValidLoggerOrDefault(config.Logger), config.BaseResolver, ) - if config.CacheResolutions { - r = &CacheResolver{Resolver: r} - } - if config.DNSCache != nil { - cache := &CacheResolver{Resolver: r, ReadOnly: true} - for key, values := range config.DNSCache { - cache.Set(key, values) - } - r = cache - } - if config.BogonIsError { - r = &netxlite.BogonResolver{Resolver: r} - } + r = MaybeWrapWithCachingResolver(config.CacheResolutions, r) + r = MaybeWrapWithStaticDNSCache(config.DNSCache, r) + r = netxlite.MaybeWrapWithBogonResolver(config.BogonIsError, r) return config.Saver.WrapResolver(r) // WAI when config.Saver==nil } @@ -94,9 +84,7 @@ func NewDialer(config Config) model.Dialer { config.ReadWriteSaver.NewReadWriteObserver(), ) d = netxlite.NewMaybeProxyDialer(d, config.ProxyURL) - if config.ContextByteCounting { - d = &bytecounter.ContextAwareDialer{Dialer: d} - } + d = bytecounter.MaybeWrapWithContextAwareDialer(config.ContextByteCounting, d) return d } @@ -143,15 +131,12 @@ func NewHTTPTransport(config Config) model.HTTPTransport { TLSDialer: config.TLSDialer, TLSConfig: config.TLSConfig, }) - if config.ByteCounter != nil { - txp = &bytecounter.HTTPTransport{ - Counter: config.ByteCounter, HTTPTransport: txp} - } - if config.Saver != nil { - txp = &tracex.HTTPTransportSaver{ - HTTPTransport: txp, Saver: config.Saver} - } - return txp + // TODO(bassosimone): I am not super convinced by this code because it + // seems we're currently counting bytes twice in some cases. I think we + // should review how we're counting bytes and using netx currently. + txp = config.ByteCounter.MaybeWrapHTTPTransport(txp) // WAI with ByteCounter == nil + const defaultSnapshotSize = 0 // means: use the default snapsize + return config.Saver.MaybeWrapHTTPTransport(txp, defaultSnapshotSize) // WAI with Saver == nil } // httpTransportInfo contains the constructing function as well as the transport name diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 7a812a2..078edbd 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -100,21 +100,6 @@ func TestNewWithDialer(t *testing.T) { } } -func TestNewWithByteCounter(t *testing.T) { - counter := bytecounter.New() - txp := NewHTTPTransport(Config{ - ByteCounter: counter, - }) - bctxp, ok := txp.(*bytecounter.HTTPTransport) - if !ok { - t.Fatal("not the transport we expected") - } - if bctxp.Counter != counter { - t.Fatal("not the byte counter we expected") - } - // We are going to trust the underlying transport returned by netxlite -} - func TestNewWithSaver(t *testing.T) { saver := new(tracex.Saver) txp := NewHTTPTransport(Config{ diff --git a/internal/engine/session.go b/internal/engine/session.go index 9a4cc95..a6ab2c9 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -202,7 +202,7 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) { handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger) tlsDialer := netxlite.NewTLSDialer(dialer, handshaker) txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer) - txp = bytecounter.NewHTTPTransport(txp, sess.byteCounter) + txp = bytecounter.WrapHTTPTransport(txp, sess.byteCounter) sess.httpDefaultTransport = txp return sess, nil } diff --git a/internal/netxlite/bogon.go b/internal/netxlite/bogon.go index a905382..2e23e70 100644 --- a/internal/netxlite/bogon.go +++ b/internal/netxlite/bogon.go @@ -14,6 +14,16 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) +// MaybeWrapWithBogonResolver wraps the given resolver with a BogonResolver +// iff the provided boolean flag is true. Otherwise, this factory just returns +// the provided resolver to the caller without any wrapping. +func MaybeWrapWithBogonResolver(enabled bool, reso model.Resolver) model.Resolver { + if enabled { + reso = &BogonResolver{Resolver: reso} + } + return reso +} + // BogonResolver is a bogon aware resolver. When a bogon is encountered in // a reply, this resolver will return ErrDNSBogon. // diff --git a/internal/netxlite/bogon_test.go b/internal/netxlite/bogon_test.go index 139d8bc..8884afd 100644 --- a/internal/netxlite/bogon_test.go +++ b/internal/netxlite/bogon_test.go @@ -9,6 +9,25 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +func TestMaybeWrapWithBogonResolver(t *testing.T) { + t.Run("with enabled equal to true", func(t *testing.T) { + underlying := &mocks.Resolver{} + reso := MaybeWrapWithBogonResolver(true, underlying) + bogoreso := reso.(*BogonResolver) + if bogoreso.Resolver != underlying { + t.Fatal("did not wrap") + } + }) + + t.Run("with enabled equal to false", func(t *testing.T) { + underlying := &mocks.Resolver{} + reso := MaybeWrapWithBogonResolver(false, underlying) + if reso != underlying { + t.Fatal("expected unmodified resolver") + } + }) +} + func TestBogonResolver(t *testing.T) { t.Run("LookupHost", func(t *testing.T) { t.Run("with failure", func(t *testing.T) { diff --git a/internal/ptx/ptx.go b/internal/ptx/ptx.go index 51daa41..09b9113 100644 --- a/internal/ptx/ptx.go +++ b/internal/ptx/ptx.go @@ -163,8 +163,8 @@ func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn // We _must_ wrap the ptConn. Wrapping the socks conn leads us to // count the sent bytes as received and the received bytes as sent: // bytes flow in the opposite direction there for the socks conn. - ptConn = bytecounter.MaybeWrap(ptConn, lst.SessionByteCounter) - ptConn = bytecounter.MaybeWrap(ptConn, lst.ExperimentByteCounter) + ptConn = bytecounter.MaybeWrapConn(ptConn, lst.SessionByteCounter) + ptConn = bytecounter.MaybeWrapConn(ptConn, lst.ExperimentByteCounter) lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership return nil // used for testing } diff --git a/internal/tracex/http.go b/internal/tracex/http.go index 58d1205..6f4d62b 100644 --- a/internal/tracex/http.go +++ b/internal/tracex/http.go @@ -14,6 +14,20 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// MaybeWrapHTTPTransport wraps the HTTPTransport to save events if this Saver +// is not nil and otherwise just returns the given HTTPTransport. The snapshotSize +// argument is the maximum response body snapshot size to save per response. +func (s *Saver) MaybeWrapHTTPTransport(txp model.HTTPTransport, snapshotSize int64) model.HTTPTransport { + if s != nil { + txp = &HTTPTransportSaver{ + HTTPTransport: txp, + Saver: s, + SnapshotSize: snapshotSize, + } + } + return txp +} + // httpCloneRequestHeaders returns a clone of the headers where we have // also set the host header, which normally is not set by // golang until it serializes the request itself. diff --git a/internal/tracex/http_test.go b/internal/tracex/http_test.go index 87b5f02..26b8d8f 100644 --- a/internal/tracex/http_test.go +++ b/internal/tracex/http_test.go @@ -15,6 +15,32 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" ) +func TestMaybeWrapHTTPTransport(t *testing.T) { + const snapshotSize = 1024 + + t.Run("with non-nil saver", func(t *testing.T) { + saver := &Saver{} + underlying := &mocks.HTTPTransport{} + txp := saver.MaybeWrapHTTPTransport(underlying, snapshotSize) + realTxp := txp.(*HTTPTransportSaver) + if realTxp.HTTPTransport != underlying { + t.Fatal("unexpected result") + } + if realTxp.SnapshotSize != snapshotSize { + t.Fatal("did not set snapshotSize correctly") + } + }) + + t.Run("with nil saver", func(t *testing.T) { + var saver *Saver + underlying := &mocks.HTTPTransport{} + txp := saver.MaybeWrapHTTPTransport(underlying, snapshotSize) + if txp != underlying { + t.Fatal("unexpected result") + } + }) +} + func TestHTTPTransportSaver(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {