refactor(netx): move construction logic outside package (#798)
For testability, replace most if-based construction logic with calls to well-tested factories living in other packages. While there, acknowledge that a bunch of types could now be private and make them private, modifying the code to call the public factories allowing to construct said types instead. Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
@@ -6,8 +6,8 @@ package bytecounter
|
||||
|
||||
import "net"
|
||||
|
||||
// Conn wraps a network connection and counts bytes.
|
||||
type Conn struct {
|
||||
// wrappedConn wraps a network connection and counts bytes.
|
||||
type wrappedConn struct {
|
||||
// net.Conn is the underlying net.Conn.
|
||||
net.Conn
|
||||
|
||||
@@ -16,28 +16,28 @@ type Conn struct {
|
||||
}
|
||||
|
||||
// Read implements net.Conn.Read.
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
func (c *wrappedConn) Read(p []byte) (int, error) {
|
||||
count, err := c.Conn.Read(p)
|
||||
c.Counter.CountBytesReceived(count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write.
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
func (c *wrappedConn) Write(p []byte) (int, error) {
|
||||
count, err := c.Conn.Write(p)
|
||||
c.Counter.CountBytesSent(count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Wrap returns a new conn that uses the given counter.
|
||||
func Wrap(conn net.Conn, counter *Counter) net.Conn {
|
||||
return &Conn{Conn: conn, Counter: counter}
|
||||
// WrapConn returns a new conn that uses the given counter.
|
||||
func WrapConn(conn net.Conn, counter *Counter) net.Conn {
|
||||
return &wrappedConn{Conn: conn, Counter: counter}
|
||||
}
|
||||
|
||||
// MaybeWrap is like wrap if counter is not nil, otherwise it's a no-op.
|
||||
func MaybeWrap(conn net.Conn, counter *Counter) net.Conn {
|
||||
// MaybeWrapConn is like wrap if counter is not nil, otherwise it's a no-op.
|
||||
func MaybeWrapConn(conn net.Conn, counter *Counter) net.Conn {
|
||||
if counter == nil {
|
||||
return conn
|
||||
}
|
||||
return Wrap(conn, counter)
|
||||
return WrapConn(conn, counter)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestConnWorksOnSuccess(t *testing.T) {
|
||||
func TestWrappedConnWorksOnSuccess(t *testing.T) {
|
||||
counter := New()
|
||||
underlying := &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
@@ -17,7 +17,7 @@ func TestConnWorksOnSuccess(t *testing.T) {
|
||||
return 4, nil
|
||||
},
|
||||
}
|
||||
conn := &Conn{
|
||||
conn := &wrappedConn{
|
||||
Conn: underlying,
|
||||
Counter: counter,
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func TestConnWorksOnSuccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnWorksOnFailure(t *testing.T) {
|
||||
func TestWrappedConnWorksOnFailure(t *testing.T) {
|
||||
readError := errors.New("read error")
|
||||
writeError := errors.New("write error")
|
||||
counter := New()
|
||||
@@ -47,7 +47,7 @@ func TestConnWorksOnFailure(t *testing.T) {
|
||||
return 0, writeError
|
||||
},
|
||||
}
|
||||
conn := &Conn{
|
||||
conn := &wrappedConn{
|
||||
Conn: underlying,
|
||||
Counter: counter,
|
||||
}
|
||||
@@ -65,20 +65,20 @@ func TestConnWorksOnFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrap(t *testing.T) {
|
||||
func TestWrapConn(t *testing.T) {
|
||||
conn := &mocks.Conn{}
|
||||
counter := New()
|
||||
nconn := Wrap(conn, counter)
|
||||
_, good := nconn.(*Conn)
|
||||
nconn := WrapConn(conn, counter)
|
||||
_, good := nconn.(*wrappedConn)
|
||||
if !good {
|
||||
t.Fatal("did not wrap")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaybeWrap(t *testing.T) {
|
||||
func TestMaybeWrapConn(t *testing.T) {
|
||||
t.Run("with nil counter", func(t *testing.T) {
|
||||
conn := &mocks.Conn{}
|
||||
nconn := MaybeWrap(conn, nil)
|
||||
nconn := MaybeWrapConn(conn, nil)
|
||||
_, good := nconn.(*mocks.Conn)
|
||||
if !good {
|
||||
t.Fatal("did not wrap")
|
||||
@@ -88,8 +88,8 @@ func TestMaybeWrap(t *testing.T) {
|
||||
t.Run("with legit counter", func(t *testing.T) {
|
||||
conn := &mocks.Conn{}
|
||||
counter := New()
|
||||
nconn := MaybeWrap(conn, counter)
|
||||
_, good := nconn.(*Conn)
|
||||
nconn := MaybeWrapConn(conn, counter)
|
||||
_, good := nconn.(*wrappedConn)
|
||||
if !good {
|
||||
t.Fatal("did not wrap")
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func WithExperimentByteCounter(ctx context.Context, counter *Counter) context.Co
|
||||
// MaybeWrapWithContextByteCounters wraps a conn with the byte counters
|
||||
// that have previosuly been configured into a context.
|
||||
func MaybeWrapWithContextByteCounters(ctx context.Context, conn net.Conn) net.Conn {
|
||||
conn = MaybeWrap(conn, ContextExperimentByteCounter(ctx))
|
||||
conn = MaybeWrap(conn, ContextSessionByteCounter(ctx))
|
||||
conn = MaybeWrapConn(conn, ContextExperimentByteCounter(ctx))
|
||||
conn = MaybeWrapConn(conn, ContextSessionByteCounter(ctx))
|
||||
return conn
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package bytecounter
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
func TestCounter(t *testing.T) {
|
||||
counter := New()
|
||||
counter.CountBytesReceived(16384)
|
||||
counter.CountKibiBytesReceived(10)
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// ContextAwareDialer is a model.Dialer that attempts to count bytes using
|
||||
// the MaybeWrapWithContextByteCounters function.
|
||||
// MaybeWrapWithContextAwareDialer wraps the given dialer with a ContextAwareDialer
|
||||
// if the enabled argument is true and otherwise just returns the given dialer.
|
||||
//
|
||||
// Bug
|
||||
//
|
||||
@@ -24,19 +24,29 @@ import (
|
||||
//
|
||||
// For this reason, this implementation may be heavily changed/removed
|
||||
// in the future (<- this message is now ~two years old, though).
|
||||
type ContextAwareDialer struct {
|
||||
func MaybeWrapWithContextAwareDialer(enabled bool, dialer model.Dialer) model.Dialer {
|
||||
if !enabled {
|
||||
return dialer
|
||||
}
|
||||
return WrapWithContextAwareDialer(dialer)
|
||||
}
|
||||
|
||||
// contextAwareDialer is a model.Dialer that attempts to count bytes using
|
||||
// the MaybeWrapWithContextByteCounters function.
|
||||
type contextAwareDialer struct {
|
||||
Dialer model.Dialer
|
||||
}
|
||||
|
||||
// NewContextAwareDialer creates a new ContextAwareDialer.
|
||||
func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer {
|
||||
return &ContextAwareDialer{Dialer: dialer}
|
||||
// WrapWithContextAwareDialer creates a new ContextAwareDialer. See the docs
|
||||
// of MaybeWrapWithContextAwareDialer for a list of caveats.
|
||||
func WrapWithContextAwareDialer(dialer model.Dialer) *contextAwareDialer {
|
||||
return &contextAwareDialer{Dialer: dialer}
|
||||
}
|
||||
|
||||
var _ model.Dialer = &ContextAwareDialer{}
|
||||
var _ model.Dialer = &contextAwareDialer{}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d *ContextAwareDialer) DialContext(
|
||||
func (d *contextAwareDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
@@ -47,6 +57,6 @@ func (d *ContextAwareDialer) DialContext(
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||
func (d *ContextAwareDialer) CloseIdleConnections() {
|
||||
func (d *contextAwareDialer) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
}
|
||||
|
||||
@@ -10,6 +10,25 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestMaybeWrapWithContextAwareDialer(t *testing.T) {
|
||||
t.Run("when enabled is true", func(t *testing.T) {
|
||||
underlying := &mocks.Dialer{}
|
||||
dialer := MaybeWrapWithContextAwareDialer(true, underlying)
|
||||
realDialer := dialer.(*contextAwareDialer)
|
||||
if realDialer.Dialer != underlying {
|
||||
t.Fatal("did not wrap correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when enabled is false", func(t *testing.T) {
|
||||
underlying := &mocks.Dialer{}
|
||||
dialer := MaybeWrapWithContextAwareDialer(false, underlying)
|
||||
if dialer != underlying {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextAwareDialer(t *testing.T) {
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
dialAndUseConn := func(ctx context.Context, bufsiz int) error {
|
||||
@@ -26,7 +45,7 @@ func TestContextAwareDialer(t *testing.T) {
|
||||
return childConn, nil
|
||||
},
|
||||
}
|
||||
dialer := NewContextAwareDialer(child)
|
||||
dialer := WrapWithContextAwareDialer(child)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -68,7 +87,7 @@ func TestContextAwareDialer(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("failure", func(t *testing.T) {
|
||||
dialer := &ContextAwareDialer{
|
||||
dialer := &contextAwareDialer{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
@@ -92,7 +111,7 @@ func TestContextAwareDialer(t *testing.T) {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
dialer := NewContextAwareDialer(child)
|
||||
dialer := WrapWithContextAwareDialer(child)
|
||||
dialer.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
|
||||
@@ -7,29 +7,39 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// HTTPTransport is a model.HTTPTransport that counts bytes.
|
||||
type HTTPTransport struct {
|
||||
// MaybeWrapHTTPTransport takes in input an HTTPTransport and either wraps it
|
||||
// to perform byte counting, if this counter is not nil, or just returns to the
|
||||
// caller the original transport, when the counter is nil.
|
||||
func (c *Counter) MaybeWrapHTTPTransport(txp model.HTTPTransport) model.HTTPTransport {
|
||||
if c != nil {
|
||||
txp = WrapHTTPTransport(txp, c)
|
||||
}
|
||||
return txp
|
||||
}
|
||||
|
||||
// httpTransport is a model.HTTPTransport that counts bytes.
|
||||
type httpTransport struct {
|
||||
HTTPTransport model.HTTPTransport
|
||||
Counter *Counter
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new byte-counting-aware HTTP transport.
|
||||
func NewHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport {
|
||||
return &HTTPTransport{
|
||||
// WrapHTTPTransport creates a new byte-counting-aware HTTP transport.
|
||||
func WrapHTTPTransport(txp model.HTTPTransport, counter *Counter) model.HTTPTransport {
|
||||
return &httpTransport{
|
||||
HTTPTransport: txp,
|
||||
Counter: counter,
|
||||
}
|
||||
}
|
||||
|
||||
var _ model.HTTPTransport = &HTTPTransport{}
|
||||
var _ model.HTTPTransport = &httpTransport{}
|
||||
|
||||
// CloseIdleConnections implements model.HTTPTransport.CloseIdleConnections.
|
||||
func (txp *HTTPTransport) CloseIdleConnections() {
|
||||
func (txp *httpTransport) CloseIdleConnections() {
|
||||
txp.HTTPTransport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// RoundTrip implements model.HTTPTRansport.RoundTrip
|
||||
func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (txp *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
req.Body = &httpBodyWrapper{
|
||||
account: txp.Counter.CountBytesSent,
|
||||
@@ -50,11 +60,11 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
// Network implements model.HTTPTransport.Network.
|
||||
func (txp *HTTPTransport) Network() string {
|
||||
func (txp *httpTransport) Network() string {
|
||||
return txp.HTTPTransport.Network()
|
||||
}
|
||||
|
||||
func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) {
|
||||
func (txp *httpTransport) estimateRequestMetadata(req *http.Request) {
|
||||
txp.Counter.CountBytesSent(len(req.Method))
|
||||
txp.Counter.CountBytesSent(len(req.URL.String()))
|
||||
for key, values := range req.Header {
|
||||
@@ -68,7 +78,7 @@ func (txp *HTTPTransport) estimateRequestMetadata(req *http.Request) {
|
||||
txp.Counter.CountBytesSent(len("\r\n"))
|
||||
}
|
||||
|
||||
func (txp *HTTPTransport) estimateResponseMetadata(resp *http.Response) {
|
||||
func (txp *httpTransport) estimateResponseMetadata(resp *http.Response) {
|
||||
txp.Counter.CountBytesReceived(len(resp.Status))
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
|
||||
@@ -12,11 +12,32 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
func TestMaybeWrapHTTPTransport(t *testing.T) {
|
||||
t.Run("when counter is not nil", func(t *testing.T) {
|
||||
underlying := &mocks.HTTPTransport{}
|
||||
counter := &Counter{}
|
||||
txp := counter.MaybeWrapHTTPTransport(underlying)
|
||||
realTxp := txp.(*httpTransport)
|
||||
if realTxp.HTTPTransport != underlying {
|
||||
t.Fatal("did not wrap correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when counter is nil", func(t *testing.T) {
|
||||
underlying := &mocks.HTTPTransport{}
|
||||
var counter *Counter
|
||||
txp := counter.MaybeWrapHTTPTransport(underlying)
|
||||
if txp != underlying {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPTransport(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("failure", func(t *testing.T) {
|
||||
counter := New()
|
||||
txp := &HTTPTransport{
|
||||
txp := &httpTransport{
|
||||
Counter: counter,
|
||||
HTTPTransport: &mocks.HTTPTransport{
|
||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
@@ -47,7 +68,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
counter := New()
|
||||
txp := &HTTPTransport{
|
||||
txp := &httpTransport{
|
||||
Counter: counter,
|
||||
HTTPTransport: &mocks.HTTPTransport{
|
||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
@@ -91,7 +112,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||
|
||||
t.Run("success with EOF", func(t *testing.T) {
|
||||
counter := New()
|
||||
txp := &HTTPTransport{
|
||||
txp := &httpTransport{
|
||||
Counter: counter,
|
||||
HTTPTransport: &mocks.HTTPTransport{
|
||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
@@ -139,7 +160,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||
},
|
||||
}
|
||||
counter := New()
|
||||
txp := NewHTTPTransport(child, counter)
|
||||
txp := WrapHTTPTransport(child, counter)
|
||||
txp.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
@@ -154,7 +175,7 @@ func TestHTTPTransport(t *testing.T) {
|
||||
},
|
||||
}
|
||||
counter := New()
|
||||
txp := NewHTTPTransport(child, counter)
|
||||
txp := WrapHTTPTransport(child, counter)
|
||||
if network := txp.Network(); network != expected {
|
||||
t.Fatal("unexpected network", network)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user