refactor(netx): move construction logic outside package (#798)
For testability, replace most if-based construction logic with calls to well-tested factories living in other packages. While there, acknowledge that a bunch of types could now be private and make them private, modifying the code to call the public factories allowing to construct said types instead. Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
2d3d5d9cdc
commit
6b85dfce88
|
@ -6,8 +6,8 @@ package bytecounter
|
||||||
|
|
||||||
import "net"
|
import "net"
|
||||||
|
|
||||||
// Conn wraps a network connection and counts bytes.
|
// wrappedConn wraps a network connection and counts bytes.
|
||||||
type Conn struct {
|
type wrappedConn struct {
|
||||||
// net.Conn is the underlying net.Conn.
|
// net.Conn is the underlying net.Conn.
|
||||||
net.Conn
|
net.Conn
|
||||||
|
|
||||||
|
@ -16,28 +16,28 @@ type Conn struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read implements net.Conn.Read.
|
// Read implements net.Conn.Read.
|
||||||
func (c *Conn) Read(p []byte) (int, error) {
|
func (c *wrappedConn) Read(p []byte) (int, error) {
|
||||||
count, err := c.Conn.Read(p)
|
count, err := c.Conn.Read(p)
|
||||||
c.Counter.CountBytesReceived(count)
|
c.Counter.CountBytesReceived(count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write implements net.Conn.Write.
|
// Write implements net.Conn.Write.
|
||||||
func (c *Conn) Write(p []byte) (int, error) {
|
func (c *wrappedConn) Write(p []byte) (int, error) {
|
||||||
count, err := c.Conn.Write(p)
|
count, err := c.Conn.Write(p)
|
||||||
c.Counter.CountBytesSent(count)
|
c.Counter.CountBytesSent(count)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap returns a new conn that uses the given counter.
|
// WrapConn returns a new conn that uses the given counter.
|
||||||
func Wrap(conn net.Conn, counter *Counter) net.Conn {
|
func WrapConn(conn net.Conn, counter *Counter) net.Conn {
|
||||||
return &Conn{Conn: conn, Counter: counter}
|
return &wrappedConn{Conn: conn, Counter: counter}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaybeWrap is like wrap if counter is not nil, otherwise it's a no-op.
|
// MaybeWrapConn is like wrap if counter is not nil, otherwise it's a no-op.
|
||||||
func MaybeWrap(conn net.Conn, counter *Counter) net.Conn {
|
func MaybeWrapConn(conn net.Conn, counter *Counter) net.Conn {
|
||||||
if counter == nil {
|
if counter == nil {
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
return Wrap(conn, counter)
|
return WrapConn(conn, counter)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConnWorksOnSuccess(t *testing.T) {
|
func TestWrappedConnWorksOnSuccess(t *testing.T) {
|
||||||
counter := New()
|
counter := New()
|
||||||
underlying := &mocks.Conn{
|
underlying := &mocks.Conn{
|
||||||
MockRead: func(b []byte) (int, error) {
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
@ -17,7 +17,7 @@ func TestConnWorksOnSuccess(t *testing.T) {
|
||||||
return 4, nil
|
return 4, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn := &Conn{
|
conn := &wrappedConn{
|
||||||
Conn: underlying,
|
Conn: underlying,
|
||||||
Counter: counter,
|
Counter: counter,
|
||||||
}
|
}
|
||||||
|
@ -35,7 +35,7 @@ func TestConnWorksOnSuccess(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConnWorksOnFailure(t *testing.T) {
|
func TestWrappedConnWorksOnFailure(t *testing.T) {
|
||||||
readError := errors.New("read error")
|
readError := errors.New("read error")
|
||||||
writeError := errors.New("write error")
|
writeError := errors.New("write error")
|
||||||
counter := New()
|
counter := New()
|
||||||
|
@ -47,7 +47,7 @@ func TestConnWorksOnFailure(t *testing.T) {
|
||||||
return 0, writeError
|
return 0, writeError
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn := &Conn{
|
conn := &wrappedConn{
|
||||||
Conn: underlying,
|
Conn: underlying,
|
||||||
Counter: counter,
|
Counter: counter,
|
||||||
}
|
}
|
||||||
|
@ -65,20 +65,20 @@ func TestConnWorksOnFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWrap(t *testing.T) {
|
func TestWrapConn(t *testing.T) {
|
||||||
conn := &mocks.Conn{}
|
conn := &mocks.Conn{}
|
||||||
counter := New()
|
counter := New()
|
||||||
nconn := Wrap(conn, counter)
|
nconn := WrapConn(conn, counter)
|
||||||
_, good := nconn.(*Conn)
|
_, good := nconn.(*wrappedConn)
|
||||||
if !good {
|
if !good {
|
||||||
t.Fatal("did not wrap")
|
t.Fatal("did not wrap")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMaybeWrap(t *testing.T) {
|
func TestMaybeWrapConn(t *testing.T) {
|
||||||
t.Run("with nil counter", func(t *testing.T) {
|
t.Run("with nil counter", func(t *testing.T) {
|
||||||
conn := &mocks.Conn{}
|
conn := &mocks.Conn{}
|
||||||
nconn := MaybeWrap(conn, nil)
|
nconn := MaybeWrapConn(conn, nil)
|
||||||
_, good := nconn.(*mocks.Conn)
|
_, good := nconn.(*mocks.Conn)
|
||||||
if !good {
|
if !good {
|
||||||
t.Fatal("did not wrap")
|
t.Fatal("did not wrap")
|
||||||
|
@ -88,8 +88,8 @@ func TestMaybeWrap(t *testing.T) {
|
||||||
t.Run("with legit counter", func(t *testing.T) {
|
t.Run("with legit counter", func(t *testing.T) {
|
||||||
conn := &mocks.Conn{}
|
conn := &mocks.Conn{}
|
||||||
counter := New()
|
counter := New()
|
||||||
nconn := MaybeWrap(conn, counter)
|
nconn := MaybeWrapConn(conn, counter)
|
||||||
_, good := nconn.(*Conn)
|
_, good := nconn.(*wrappedConn)
|
||||||
if !good {
|
if !good {
|
||||||
t.Fatal("did not wrap")
|
t.Fatal("did not wrap")
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ func WithExperimentByteCounter(ctx context.Context, counter *Counter) context.Co
|
||||||
// MaybeWrapWithContextByteCounters wraps a conn with the byte counters
|
// MaybeWrapWithContextByteCounters wraps a conn with the byte counters
|
||||||
// that have previosuly been configured into a context.
|
// that have previosuly been configured into a context.
|
||||||
func MaybeWrapWithContextByteCounters(ctx context.Context, conn net.Conn) net.Conn {
|
func MaybeWrapWithContextByteCounters(ctx context.Context, conn net.Conn) net.Conn {
|
||||||
conn = MaybeWrap(conn, ContextExperimentByteCounter(ctx))
|
conn = MaybeWrapConn(conn, ContextExperimentByteCounter(ctx))
|
||||||
conn = MaybeWrap(conn, ContextSessionByteCounter(ctx))
|
conn = MaybeWrapConn(conn, ContextSessionByteCounter(ctx))
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@ package bytecounter
|
||||||
|
|
||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestGood(t *testing.T) {
|
func TestCounter(t *testing.T) {
|
||||||
counter := New()
|
counter := New()
|
||||||
counter.CountBytesReceived(16384)
|
counter.CountBytesReceived(16384)
|
||||||
counter.CountKibiBytesReceived(10)
|
counter.CountKibiBytesReceived(10)
|
||||||
|
|
|
@ -11,8 +11,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ContextAwareDialer is a model.Dialer that attempts to count bytes using
|
// MaybeWrapWithContextAwareDialer wraps the given dialer with a ContextAwareDialer
|
||||||
// the MaybeWrapWithContextByteCounters function.
|
// if the enabled argument is true and otherwise just returns the given dialer.
|
||||||
//
|
//
|
||||||
// Bug
|
// Bug
|
||||||
//
|
//
|
||||||
|
@ -24,19 +24,29 @@ import (
|
||||||
//
|
//
|
||||||
// For this reason, this implementation may be heavily changed/removed
|
// For this reason, this implementation may be heavily changed/removed
|
||||||
// in the future (<- this message is now ~two years old, though).
|
// in the future (<- this message is now ~two years old, though).
|
||||||
type ContextAwareDialer struct {
|
func MaybeWrapWithContextAwareDialer(enabled bool, dialer model.Dialer) model.Dialer {
|
||||||
|
if !enabled {
|
||||||
|
return dialer
|
||||||
|
}
|
||||||
|
return WrapWithContextAwareDialer(dialer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextAwareDialer is a model.Dialer that attempts to count bytes using
|
||||||
|
// the MaybeWrapWithContextByteCounters function.
|
||||||
|
type contextAwareDialer struct {
|
||||||
Dialer model.Dialer
|
Dialer model.Dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewContextAwareDialer creates a new ContextAwareDialer.
|
// WrapWithContextAwareDialer creates a new ContextAwareDialer. See the docs
|
||||||
func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer {
|
// of MaybeWrapWithContextAwareDialer for a list of caveats.
|
||||||
return &ContextAwareDialer{Dialer: dialer}
|
func WrapWithContextAwareDialer(dialer model.Dialer) *contextAwareDialer {
|
||||||
|
return &contextAwareDialer{Dialer: dialer}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.Dialer = &ContextAwareDialer{}
|
var _ model.Dialer = &contextAwareDialer{}
|
||||||
|
|
||||||
// DialContext implements Dialer.DialContext
|
// DialContext implements Dialer.DialContext
|
||||||
func (d *ContextAwareDialer) DialContext(
|
func (d *contextAwareDialer) DialContext(
|
||||||
ctx context.Context, network, address string) (net.Conn, error) {
|
ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -47,6 +57,6 @@ func (d *ContextAwareDialer) DialContext(
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||||
func (d *ContextAwareDialer) CloseIdleConnections() {
|
func (d *contextAwareDialer) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,25 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMaybeWrapWithContextAwareDialer(t *testing.T) {
|
||||||
|
t.Run("when enabled is true", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Dialer{}
|
||||||
|
dialer := MaybeWrapWithContextAwareDialer(true, underlying)
|
||||||
|
realDialer := dialer.(*contextAwareDialer)
|
||||||
|
if realDialer.Dialer != underlying {
|
||||||
|
t.Fatal("did not wrap correctly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when enabled is false", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Dialer{}
|
||||||
|
dialer := MaybeWrapWithContextAwareDialer(false, underlying)
|
||||||
|
if dialer != underlying {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestContextAwareDialer(t *testing.T) {
|
func TestContextAwareDialer(t *testing.T) {
|
||||||
t.Run("DialContext", func(t *testing.T) {
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
dialAndUseConn := func(ctx context.Context, bufsiz int) error {
|
dialAndUseConn := func(ctx context.Context, bufsiz int) error {
|
||||||
|
@ -26,7 +45,7 @@ func TestContextAwareDialer(t *testing.T) {
|
||||||
return childConn, nil
|
return childConn, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dialer := NewContextAwareDialer(child)
|
dialer := WrapWithContextAwareDialer(child)
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443")
|
conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -68,7 +87,7 @@ func TestContextAwareDialer(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("failure", func(t *testing.T) {
|
t.Run("failure", func(t *testing.T) {
|
||||||
dialer := &ContextAwareDialer{
|
dialer := &contextAwareDialer{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
|
@ -92,7 +111,7 @@ func TestContextAwareDialer(t *testing.T) {
|
||||||
called = true
|
called = true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dialer := NewContextAwareDialer(child)
|
dialer := WrapWithContextAwareDialer(child)
|
||||||
dialer.CloseIdleConnections()
|
dialer.CloseIdleConnections()
|
||||||
if !called {
|
if !called {
|
||||||
t.Fatal("not called")
|
t.Fatal("not called")
|
||||||
|
|
|
@ -7,29 +7,39 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPTransport is a model.HTTPTransport that counts bytes.
|
// MaybeWrapHTTPTransport takes in input an HTTPTransport and either wraps it
|
||||||
type HTTPTransport struct {
|
// to perform byte counting, if this counter is not nil, or just returns to the
|
||||||
|
// caller the original transport, when the counter is nil.
|
||||||
|
func (c *Counter) MaybeWrapHTTPTransport(txp model.HTTPTransport) model.HTTPTransport {
|
||||||
|
if c != nil {
|
||||||
|
txp = WrapHTTPTransport(txp, c)
|
||||||
|
}
|
||||||
|
return txp
|
||||||
|
}
|
||||||
|
|
||||||
|
// httpTransport is a model.HTTPTransport that counts bytes.
|
||||||
|
type httpTransport struct {
|
||||||
HTTPTransport model.HTTPTransport
|
HTTPTransport model.HTTPTransport
|
||||||
Counter *Counter
|
Counter *Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPTransport creates a new byte-counting-aware HTTP transport.
|
// WrapHTTPTransport creates a new byte-counting-aware HTTP transport.
|
||||||
func NewHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport {
|
func WrapHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport {
|
||||||
return &HTTPTransport{
|
return &httpTransport{
|
||||||
HTTPTransport: txp,
|
HTTPTransport: txp,
|
||||||
Counter: counter,
|
Counter: counter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.HTTPTransport = &HTTPTransport{}
|
var _ model.HTTPTransport = &httpTransport{}
|
||||||
|
|
||||||
// CloseIdleConnections implements model.HTTPTransport.CloseIdleConnections.
|
// CloseIdleConnections implements model.HTTPTransport.CloseIdleConnections.
|
||||||
func (txp *HTTPTransport) CloseIdleConnections() {
|
func (txp *httpTransport) CloseIdleConnections() {
|
||||||
txp.HTTPTransport.CloseIdleConnections()
|
txp.HTTPTransport.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements model.HTTPTRansport.RoundTrip
|
// RoundTrip implements model.HTTPTRansport.RoundTrip
|
||||||
func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (txp *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
if req.Body != nil {
|
if req.Body != nil {
|
||||||
req.Body = &httpBodyWrapper{
|
req.Body = &httpBodyWrapper{
|
||||||
account: txp.Counter.CountBytesSent,
|
account: txp.Counter.CountBytesSent,
|
||||||
|
@ -50,11 +60,11 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network implements model.HTTPTransport.Network.
|
// Network implements model.HTTPTransport.Network.
|
||||||
func (txp *HTTPTransport) Network() string {
|
func (txp *httpTransport) Network() string {
|
||||||
return txp.HTTPTransport.Network()
|
return txp.HTTPTransport.Network()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) {
|
func (txp *httpTransport) estimateRequestMetadata(req *http.Request) {
|
||||||
txp.Counter.CountBytesSent(len(req.Method))
|
txp.Counter.CountBytesSent(len(req.Method))
|
||||||
txp.Counter.CountBytesSent(len(req.URL.String()))
|
txp.Counter.CountBytesSent(len(req.URL.String()))
|
||||||
for key, values := range req.Header {
|
for key, values := range req.Header {
|
||||||
|
@ -68,7 +78,7 @@ func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) {
|
||||||
txp.Counter.CountBytesSent(len("\r\n"))
|
txp.Counter.CountBytesSent(len("\r\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *HTTPTransport) estimateResponseMetadata(resp *http.Response) {
|
func (txp *httpTransport) estimateResponseMetadata(resp *http.Response) {
|
||||||
txp.Counter.CountBytesReceived(len(resp.Status))
|
txp.Counter.CountBytesReceived(len(resp.Status))
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
|
|
|
@ -12,11 +12,32 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMaybeWrapHTTPTransport(t *testing.T) {
|
||||||
|
t.Run("when counter is not nil", func(t *testing.T) {
|
||||||
|
underlying := &mocks.HTTPTransport{}
|
||||||
|
counter := &Counter{}
|
||||||
|
txp := counter.MaybeWrapHTTPTransport(underlying)
|
||||||
|
realTxp := txp.(*httpTransport)
|
||||||
|
if realTxp.HTTPTransport != underlying {
|
||||||
|
t.Fatal("did not wrap correctly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when counter is nil", func(t *testing.T) {
|
||||||
|
underlying := &mocks.HTTPTransport{}
|
||||||
|
var counter *Counter
|
||||||
|
txp := counter.MaybeWrapHTTPTransport(underlying)
|
||||||
|
if txp != underlying {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPTransport(t *testing.T) {
|
func TestHTTPTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
t.Run("failure", func(t *testing.T) {
|
t.Run("failure", func(t *testing.T) {
|
||||||
counter := New()
|
counter := New()
|
||||||
txp := &HTTPTransport{
|
txp := &httpTransport{
|
||||||
Counter: counter,
|
Counter: counter,
|
||||||
HTTPTransport: &mocks.HTTPTransport{
|
HTTPTransport: &mocks.HTTPTransport{
|
||||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -47,7 +68,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("success", func(t *testing.T) {
|
t.Run("success", func(t *testing.T) {
|
||||||
counter := New()
|
counter := New()
|
||||||
txp := &HTTPTransport{
|
txp := &httpTransport{
|
||||||
Counter: counter,
|
Counter: counter,
|
||||||
HTTPTransport: &mocks.HTTPTransport{
|
HTTPTransport: &mocks.HTTPTransport{
|
||||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -91,7 +112,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("success with EOF", func(t *testing.T) {
|
t.Run("success with EOF", func(t *testing.T) {
|
||||||
counter := New()
|
counter := New()
|
||||||
txp := &HTTPTransport{
|
txp := &httpTransport{
|
||||||
Counter: counter,
|
Counter: counter,
|
||||||
HTTPTransport: &mocks.HTTPTransport{
|
HTTPTransport: &mocks.HTTPTransport{
|
||||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -139,7 +160,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
counter := New()
|
counter := New()
|
||||||
txp := NewHTTPTransport(child, counter)
|
txp := WrapHTTPTransport(child, counter)
|
||||||
txp.CloseIdleConnections()
|
txp.CloseIdleConnections()
|
||||||
if !called {
|
if !called {
|
||||||
t.Fatal("not called")
|
t.Fatal("not called")
|
||||||
|
@ -154,7 +175,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
counter := New()
|
counter := New()
|
||||||
txp := NewHTTPTransport(child, counter)
|
txp := WrapHTTPTransport(child, counter)
|
||||||
if network := txp.Network(); network != expected {
|
if network := txp.Network(); network != expected {
|
||||||
t.Fatal("unexpected network", network)
|
t.Fatal("unexpected network", network)
|
||||||
}
|
}
|
||||||
|
|
|
@ -285,10 +285,10 @@ func (e *Experiment) OpenReportContext(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
// use custom client to have proper byte accounting
|
// use custom client to have proper byte accounting
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: &bytecounter.HTTPTransport{
|
Transport: bytecounter.WrapHTTPTransport(
|
||||||
HTTPTransport: e.session.httpDefaultTransport, // proxy is OK
|
e.session.httpDefaultTransport, // proxy is OK
|
||||||
Counter: e.byteCounter,
|
e.byteCounter,
|
||||||
},
|
),
|
||||||
}
|
}
|
||||||
client, err := e.session.NewProbeServicesClient(ctx)
|
client, err := e.session.NewProbeServicesClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -32,7 +32,7 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
|
||||||
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
|
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
|
||||||
reso := netxlite.NewResolverStdlib(mgr.logger)
|
reso := netxlite.NewResolverStdlib(mgr.logger)
|
||||||
dlr := netxlite.NewDialerWithResolver(mgr.logger, reso)
|
dlr := netxlite.NewDialerWithResolver(mgr.logger, reso)
|
||||||
dlr = bytecounter.NewContextAwareDialer(dlr)
|
dlr = bytecounter.WrapWithContextAwareDialer(dlr)
|
||||||
// Implements shaping if the user builds using `-tags shaping`
|
// Implements shaping if the user builds using `-tags shaping`
|
||||||
// See https://github.com/ooni/probe/issues/2112
|
// See https://github.com/ooni/probe/issues/2112
|
||||||
dlr = netxlite.NewMaybeShapingDialer(dlr)
|
dlr = netxlite.NewMaybeShapingDialer(dlr)
|
||||||
|
|
|
@ -2,44 +2,90 @@ package netx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CacheResolver is a resolver that caches successful replies.
|
// MaybeWrapWithCachingResolver wraps the provided resolver with a resolver
|
||||||
type CacheResolver struct {
|
// that remembers the result of previous successful resolutions, if the enabled
|
||||||
ReadOnly bool
|
// argument is true. Otherwise, we return the unmodified provided resolver.
|
||||||
model.Resolver
|
//
|
||||||
mu sync.Mutex
|
// Bug: the returned resolver only applies caching to LookupHost and any other
|
||||||
cache map[string][]string
|
// lookup operation returns ErrNoDNSTransport to the caller.
|
||||||
|
func MaybeWrapWithCachingResolver(enabled bool, reso model.Resolver) model.Resolver {
|
||||||
|
if enabled {
|
||||||
|
reso = &cacheResolver{
|
||||||
|
cache: map[string][]string{},
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
readOnly: false,
|
||||||
|
resolver: reso,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reso
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost
|
// MaybeWrapWithStaticDNSCache wraps the provided resolver with a resolver that
|
||||||
func (r *CacheResolver) LookupHost(
|
// checks the given cache before issuing queries to the underlying DNS resolver.
|
||||||
|
//
|
||||||
|
// Bug: the returned resolver only applies caching to LookupHost and any other
|
||||||
|
// lookup operation returns ErrNoDNSTransport to the caller.
|
||||||
|
func MaybeWrapWithStaticDNSCache(cache map[string][]string, reso model.Resolver) model.Resolver {
|
||||||
|
if len(cache) > 0 {
|
||||||
|
reso = &cacheResolver{
|
||||||
|
cache: cache,
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
readOnly: true,
|
||||||
|
resolver: reso,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reso
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheResolver implements CachingResolver and StaticDNSCache.
|
||||||
|
type cacheResolver struct {
|
||||||
|
// cache is the underlying DNS cache.
|
||||||
|
cache map[string][]string
|
||||||
|
|
||||||
|
// mu provides mutual exclusion.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// readOnly means that we won't cache the result of successful resolutions.
|
||||||
|
readOnly bool
|
||||||
|
|
||||||
|
// resolver is the underlying resolver.
|
||||||
|
resolver model.Resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.Resolver = &cacheResolver{}
|
||||||
|
|
||||||
|
// LookupHost implements model.Resolver.LookupHost
|
||||||
|
func (r *cacheResolver) LookupHost(
|
||||||
ctx context.Context, hostname string) ([]string, error) {
|
ctx context.Context, hostname string) ([]string, error) {
|
||||||
if entry := r.Get(hostname); entry != nil {
|
if entry := r.get(hostname); entry != nil {
|
||||||
return entry, nil
|
return entry, nil
|
||||||
}
|
}
|
||||||
entry, err := r.Resolver.LookupHost(ctx, hostname)
|
entry, err := r.resolver.LookupHost(ctx, hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !r.ReadOnly {
|
if !r.readOnly {
|
||||||
r.Set(hostname, entry)
|
r.set(hostname, entry)
|
||||||
}
|
}
|
||||||
return entry, nil
|
return entry, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets the currently configured entry for domain, or nil
|
// get gets the currently configured entry for domain, or nil
|
||||||
func (r *CacheResolver) Get(domain string) []string {
|
func (r *cacheResolver) get(domain string) []string {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
return r.cache[domain]
|
return r.cache[domain]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set allows to pre-populate the cache
|
// set sets a valid inside the cache iff readOnly is false.
|
||||||
func (r *CacheResolver) Set(domain string, addresses []string) {
|
func (r *cacheResolver) set(domain string, addresses []string) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
if r.cache == nil {
|
if r.cache == nil {
|
||||||
r.cache = make(map[string][]string)
|
r.cache = make(map[string][]string)
|
||||||
|
@ -47,3 +93,28 @@ func (r *CacheResolver) Set(domain string, addresses []string) {
|
||||||
r.cache[domain] = addresses
|
r.cache[domain] = addresses
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Address implements model.Resolver.Address.
|
||||||
|
func (r *cacheResolver) Address() string {
|
||||||
|
return r.resolver.Address()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network implements model.Resolver.Network.
|
||||||
|
func (r *cacheResolver) Network() string {
|
||||||
|
return r.resolver.Network()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements model.Resolver.CloseIdleConnections.
|
||||||
|
func (r *cacheResolver) CloseIdleConnections() {
|
||||||
|
r.resolver.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupHTTPS implements model.Resolver.LookupHTTPS.
|
||||||
|
func (r *cacheResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||||
|
return nil, netxlite.ErrNoDNSTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupNS implements model.Resolver.LookupNS.
|
||||||
|
func (r *cacheResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||||
|
return nil, netxlite.ErrNoDNSTransport
|
||||||
|
}
|
||||||
|
|
|
@ -5,17 +5,76 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCacheResolverFailure(t *testing.T) {
|
func TestMaybeWrapWithCachingResolver(t *testing.T) {
|
||||||
|
t.Run("with enable equal to true", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithCachingResolver(true, underlying)
|
||||||
|
cachereso := reso.(*cacheResolver)
|
||||||
|
if cachereso.resolver != underlying {
|
||||||
|
t.Fatal("did not wrap correctly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with enable equal to false", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithCachingResolver(false, underlying)
|
||||||
|
if reso != underlying {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeWrapWithStaticDNSCache(t *testing.T) {
|
||||||
|
t.Run("when the cache is not empty", func(t *testing.T) {
|
||||||
|
cachedDomain := "dns.google"
|
||||||
|
expectedEntry := []string{"8.8.8.8", "8.8.4.4"}
|
||||||
|
underlyingCache := make(map[string][]string)
|
||||||
|
underlyingCache[cachedDomain] = expectedEntry
|
||||||
|
underlyingReso := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso)
|
||||||
|
cachereso := reso.(*cacheResolver)
|
||||||
|
if diff := cmp.Diff(cachereso.cache, underlyingCache); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
if cachereso.resolver != underlyingReso {
|
||||||
|
t.Fatal("unexpected underlying resolver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when the cache is empty", func(t *testing.T) {
|
||||||
|
underlyingCache := make(map[string][]string)
|
||||||
|
underlyingReso := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso)
|
||||||
|
if reso != underlyingReso {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when the cache is nil", func(t *testing.T) {
|
||||||
|
var underlyingCache map[string][]string
|
||||||
|
underlyingReso := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithStaticDNSCache(underlyingCache, underlyingReso)
|
||||||
|
if reso != underlyingReso {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheResolver(t *testing.T) {
|
||||||
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
|
t.Run("cache miss and failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &mocks.Resolver{
|
r := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cache := &CacheResolver{Resolver: r}
|
cache := &cacheResolver{resolver: r}
|
||||||
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
|
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -23,20 +82,20 @@ func TestCacheResolverFailure(t *testing.T) {
|
||||||
if addrs != nil {
|
if addrs != nil {
|
||||||
t.Fatal("expected nil addrs here")
|
t.Fatal("expected nil addrs here")
|
||||||
}
|
}
|
||||||
if cache.Get("www.google.com") != nil {
|
if cache.get("www.google.com") != nil {
|
||||||
t.Fatal("expected empty cache here")
|
t.Fatal("expected empty cache here")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestCacheResolverHitSuccess(t *testing.T) {
|
t.Run("cache hit", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &mocks.Resolver{
|
r := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cache := &CacheResolver{Resolver: r}
|
cache := &cacheResolver{resolver: r}
|
||||||
cache.Set("dns.google.com", []string{"8.8.8.8"})
|
cache.set("dns.google.com", []string{"8.8.8.8"})
|
||||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -44,15 +103,15 @@ func TestCacheResolverHitSuccess(t *testing.T) {
|
||||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||||
t.Fatal("not the result we expected")
|
t.Fatal("not the result we expected")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestCacheResolverMissSuccess(t *testing.T) {
|
t.Run("cache miss and success with readwrite cache", func(t *testing.T) {
|
||||||
r := &mocks.Resolver{
|
r := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return []string{"8.8.8.8"}, nil
|
return []string{"8.8.8.8"}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cache := &CacheResolver{Resolver: r}
|
cache := &cacheResolver{resolver: r}
|
||||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -60,18 +119,18 @@ func TestCacheResolverMissSuccess(t *testing.T) {
|
||||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||||
t.Fatal("not the result we expected")
|
t.Fatal("not the result we expected")
|
||||||
}
|
}
|
||||||
if cache.Get("dns.google.com")[0] != "8.8.8.8" {
|
if cache.get("dns.google.com")[0] != "8.8.8.8" {
|
||||||
t.Fatal("expected full cache here")
|
t.Fatal("expected full cache here")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestCacheResolverReadonlySuccess(t *testing.T) {
|
t.Run("cache miss and success with readonly cache", func(t *testing.T) {
|
||||||
r := &mocks.Resolver{
|
r := &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return []string{"8.8.8.8"}, nil
|
return []string{"8.8.8.8"}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cache := &CacheResolver{Resolver: r, ReadOnly: true}
|
cache := &cacheResolver{resolver: r, readOnly: true}
|
||||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -79,7 +138,69 @@ func TestCacheResolverReadonlySuccess(t *testing.T) {
|
||||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||||
t.Fatal("not the result we expected")
|
t.Fatal("not the result we expected")
|
||||||
}
|
}
|
||||||
if cache.Get("dns.google.com") != nil {
|
if cache.get("dns.google.com") != nil {
|
||||||
t.Fatal("expected empty cache here")
|
t.Fatal("expected empty cache here")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Address", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Resolver{
|
||||||
|
MockAddress: func() string {
|
||||||
|
return "x"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reso := &cacheResolver{resolver: underlying}
|
||||||
|
if reso.Address() != "x" {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Resolver{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "x"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reso := &cacheResolver{resolver: underlying}
|
||||||
|
if reso.Network() != "x" {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
underlying := &mocks.Resolver{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reso := &cacheResolver{resolver: underlying}
|
||||||
|
reso.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||||
|
reso := &cacheResolver{}
|
||||||
|
https, err := reso.LookupHTTPS(context.Background(), "dns.google")
|
||||||
|
if !errors.Is(err, netxlite.ErrNoDNSTransport) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if https != nil {
|
||||||
|
t.Fatal("expected nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LookupNS", func(t *testing.T) {
|
||||||
|
reso := &cacheResolver{}
|
||||||
|
ns, err := reso.LookupNS(context.Background(), "dns.google")
|
||||||
|
if !errors.Is(err, netxlite.ErrNoDNSTransport) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if len(ns) != 0 {
|
||||||
|
t.Fatal("expected zero length slice")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,19 +67,9 @@ func NewResolver(config Config) model.Resolver {
|
||||||
model.ValidLoggerOrDefault(config.Logger),
|
model.ValidLoggerOrDefault(config.Logger),
|
||||||
config.BaseResolver,
|
config.BaseResolver,
|
||||||
)
|
)
|
||||||
if config.CacheResolutions {
|
r = MaybeWrapWithCachingResolver(config.CacheResolutions, r)
|
||||||
r = &CacheResolver{Resolver: r}
|
r = MaybeWrapWithStaticDNSCache(config.DNSCache, r)
|
||||||
}
|
r = netxlite.MaybeWrapWithBogonResolver(config.BogonIsError, r)
|
||||||
if config.DNSCache != nil {
|
|
||||||
cache := &CacheResolver{Resolver: r, ReadOnly: true}
|
|
||||||
for key, values := range config.DNSCache {
|
|
||||||
cache.Set(key, values)
|
|
||||||
}
|
|
||||||
r = cache
|
|
||||||
}
|
|
||||||
if config.BogonIsError {
|
|
||||||
r = &netxlite.BogonResolver{Resolver: r}
|
|
||||||
}
|
|
||||||
return config.Saver.WrapResolver(r) // WAI when config.Saver==nil
|
return config.Saver.WrapResolver(r) // WAI when config.Saver==nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,9 +84,7 @@ func NewDialer(config Config) model.Dialer {
|
||||||
config.ReadWriteSaver.NewReadWriteObserver(),
|
config.ReadWriteSaver.NewReadWriteObserver(),
|
||||||
)
|
)
|
||||||
d = netxlite.NewMaybeProxyDialer(d, config.ProxyURL)
|
d = netxlite.NewMaybeProxyDialer(d, config.ProxyURL)
|
||||||
if config.ContextByteCounting {
|
d = bytecounter.MaybeWrapWithContextAwareDialer(config.ContextByteCounting, d)
|
||||||
d = &bytecounter.ContextAwareDialer{Dialer: d}
|
|
||||||
}
|
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,15 +131,12 @@ func NewHTTPTransport(config Config) model.HTTPTransport {
|
||||||
TLSDialer: config.TLSDialer,
|
TLSDialer: config.TLSDialer,
|
||||||
TLSConfig: config.TLSConfig,
|
TLSConfig: config.TLSConfig,
|
||||||
})
|
})
|
||||||
if config.ByteCounter != nil {
|
// TODO(bassosimone): I am not super convinced by this code because it
|
||||||
txp = &bytecounter.HTTPTransport{
|
// seems we're currently counting bytes twice in some cases. I think we
|
||||||
Counter: config.ByteCounter, HTTPTransport: txp}
|
// should review how we're counting bytes and using netx currently.
|
||||||
}
|
txp = config.ByteCounter.MaybeWrapHTTPTransport(txp) // WAI with ByteCounter == nil
|
||||||
if config.Saver != nil {
|
const defaultSnapshotSize = 0 // means: use the default snapsize
|
||||||
txp = &tracex.HTTPTransportSaver{
|
return config.Saver.MaybeWrapHTTPTransport(txp, defaultSnapshotSize) // WAI with Saver == nil
|
||||||
HTTPTransport: txp, Saver: config.Saver}
|
|
||||||
}
|
|
||||||
return txp
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// httpTransportInfo contains the constructing function as well as the transport name
|
// httpTransportInfo contains the constructing function as well as the transport name
|
||||||
|
|
|
@ -100,21 +100,6 @@ func TestNewWithDialer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewWithByteCounter(t *testing.T) {
|
|
||||||
counter := bytecounter.New()
|
|
||||||
txp := NewHTTPTransport(Config{
|
|
||||||
ByteCounter: counter,
|
|
||||||
})
|
|
||||||
bctxp, ok := txp.(*bytecounter.HTTPTransport)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the transport we expected")
|
|
||||||
}
|
|
||||||
if bctxp.Counter != counter {
|
|
||||||
t.Fatal("not the byte counter we expected")
|
|
||||||
}
|
|
||||||
// We are going to trust the underlying transport returned by netxlite
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewWithSaver(t *testing.T) {
|
func TestNewWithSaver(t *testing.T) {
|
||||||
saver := new(tracex.Saver)
|
saver := new(tracex.Saver)
|
||||||
txp := NewHTTPTransport(Config{
|
txp := NewHTTPTransport(Config{
|
||||||
|
|
|
@ -202,7 +202,7 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) {
|
||||||
handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger)
|
handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger)
|
||||||
tlsDialer := netxlite.NewTLSDialer(dialer, handshaker)
|
tlsDialer := netxlite.NewTLSDialer(dialer, handshaker)
|
||||||
txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer)
|
txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer)
|
||||||
txp = bytecounter.NewHTTPTransport(txp, sess.byteCounter)
|
txp = bytecounter.WrapHTTPTransport(txp, sess.byteCounter)
|
||||||
sess.httpDefaultTransport = txp
|
sess.httpDefaultTransport = txp
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,16 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MaybeWrapWithBogonResolver wraps the given resolver with a BogonResolver
|
||||||
|
// iff the provided boolean flag is true. Otherwise, this factory just returns
|
||||||
|
// the provided resolver to the caller without any wrapping.
|
||||||
|
func MaybeWrapWithBogonResolver(enabled bool, reso model.Resolver) model.Resolver {
|
||||||
|
if enabled {
|
||||||
|
reso = &BogonResolver{Resolver: reso}
|
||||||
|
}
|
||||||
|
return reso
|
||||||
|
}
|
||||||
|
|
||||||
// BogonResolver is a bogon aware resolver. When a bogon is encountered in
|
// BogonResolver is a bogon aware resolver. When a bogon is encountered in
|
||||||
// a reply, this resolver will return ErrDNSBogon.
|
// a reply, this resolver will return ErrDNSBogon.
|
||||||
//
|
//
|
||||||
|
|
|
@ -9,6 +9,25 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMaybeWrapWithBogonResolver(t *testing.T) {
|
||||||
|
t.Run("with enabled equal to true", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithBogonResolver(true, underlying)
|
||||||
|
bogoreso := reso.(*BogonResolver)
|
||||||
|
if bogoreso.Resolver != underlying {
|
||||||
|
t.Fatal("did not wrap")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with enabled equal to false", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Resolver{}
|
||||||
|
reso := MaybeWrapWithBogonResolver(false, underlying)
|
||||||
|
if reso != underlying {
|
||||||
|
t.Fatal("expected unmodified resolver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestBogonResolver(t *testing.T) {
|
func TestBogonResolver(t *testing.T) {
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
t.Run("with failure", func(t *testing.T) {
|
t.Run("with failure", func(t *testing.T) {
|
||||||
|
|
|
@ -163,8 +163,8 @@ func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn
|
||||||
// We _must_ wrap the ptConn. Wrapping the socks conn leads us to
|
// We _must_ wrap the ptConn. Wrapping the socks conn leads us to
|
||||||
// count the sent bytes as received and the received bytes as sent:
|
// count the sent bytes as received and the received bytes as sent:
|
||||||
// bytes flow in the opposite direction there for the socks conn.
|
// bytes flow in the opposite direction there for the socks conn.
|
||||||
ptConn = bytecounter.MaybeWrap(ptConn, lst.SessionByteCounter)
|
ptConn = bytecounter.MaybeWrapConn(ptConn, lst.SessionByteCounter)
|
||||||
ptConn = bytecounter.MaybeWrap(ptConn, lst.ExperimentByteCounter)
|
ptConn = bytecounter.MaybeWrapConn(ptConn, lst.ExperimentByteCounter)
|
||||||
lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership
|
lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership
|
||||||
return nil // used for testing
|
return nil // used for testing
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,20 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MaybeWrapHTTPTransport wraps the HTTPTransport to save events if this Saver
|
||||||
|
// is not nil and otherwise just returns the given HTTPTransport. The snapshotSize
|
||||||
|
// argument is the maximum response body snapshot size to save per response.
|
||||||
|
func (s *Saver) MaybeWrapHTTPTransport(txp model.HTTPTransport, snapshotSize int64) model.HTTPTransport {
|
||||||
|
if s != nil {
|
||||||
|
txp = &HTTPTransportSaver{
|
||||||
|
HTTPTransport: txp,
|
||||||
|
Saver: s,
|
||||||
|
SnapshotSize: snapshotSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return txp
|
||||||
|
}
|
||||||
|
|
||||||
// httpCloneRequestHeaders returns a clone of the headers where we have
|
// httpCloneRequestHeaders returns a clone of the headers where we have
|
||||||
// also set the host header, which normally is not set by
|
// also set the host header, which normally is not set by
|
||||||
// golang until it serializes the request itself.
|
// golang until it serializes the request itself.
|
||||||
|
|
|
@ -15,6 +15,32 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestMaybeWrapHTTPTransport(t *testing.T) {
|
||||||
|
const snapshotSize = 1024
|
||||||
|
|
||||||
|
t.Run("with non-nil saver", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
underlying := &mocks.HTTPTransport{}
|
||||||
|
txp := saver.MaybeWrapHTTPTransport(underlying, snapshotSize)
|
||||||
|
realTxp := txp.(*HTTPTransportSaver)
|
||||||
|
if realTxp.HTTPTransport != underlying {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
if realTxp.SnapshotSize != snapshotSize {
|
||||||
|
t.Fatal("did not set snapshotSize correctly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with nil saver", func(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
underlying := &mocks.HTTPTransport{}
|
||||||
|
txp := saver.MaybeWrapHTTPTransport(underlying, snapshotSize)
|
||||||
|
if txp != underlying {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPTransportSaver(t *testing.T) {
|
func TestHTTPTransportSaver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user