diff --git a/internal/bytecounter/dialer.go b/internal/bytecounter/dialer.go new file mode 100644 index 0000000..d4246e5 --- /dev/null +++ b/internal/bytecounter/dialer.go @@ -0,0 +1,52 @@ +package bytecounter + +// +// model.Dialer wrappers +// + +import ( + "context" + "net" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// ContextAwareDialer is a model.Dialer that attempts to count bytes using +// the MaybeWrapWithContextByteCounters function. +// +// Bug +// +// This implementation cannot properly account for the bytes that are sent by +// persistent connections, because they stick to the counters set when the +// connection was established. This typically means we miss the bytes sent and +// received when submitting a measurement. Such bytes are specifically not +// seen by the experiment specific byte counter. +// +// For this reason, this implementation may be heavily changed/removed +// in the future (<- this message is now ~two years old, though). +type ContextAwareDialer struct { + Dialer model.Dialer +} + +// NewContextAwareDialer creates a new ContextAwareDialer. +func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer { + return &ContextAwareDialer{Dialer: dialer} +} + +var _ model.Dialer = &ContextAwareDialer{} + +// DialContext implements Dialer.DialContext +func (d *ContextAwareDialer) DialContext( + ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + conn = MaybeWrapWithContextByteCounters(ctx, conn) + return conn, nil +} + +// CloseIdleConnections implements Dialer.CloseIdleConnections. +func (d *ContextAwareDialer) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} diff --git a/internal/bytecounter/dialer_test.go b/internal/bytecounter/dialer_test.go new file mode 100644 index 0000000..35bdd94 --- /dev/null +++ b/internal/bytecounter/dialer_test.go @@ -0,0 +1,101 @@ +package bytecounter + +import ( + "context" + "errors" + "io" + "net" + "testing" + + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestContextAwareDialer(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + dialAndUseConn := func(ctx context.Context, bufsiz int) error { + childConn := &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return len(b), nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + } + child := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return childConn, nil + }, + } + dialer := NewContextAwareDialer(child) + conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443") + if err != nil { + return err + } + buffer := make([]byte, bufsiz) + conn.Read(buffer) + conn.Write(buffer) + return nil + } + + t.Run("normal usage", func(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + sess := New() + ctx := context.Background() + ctx = WithSessionByteCounter(ctx, sess) + const count = 128 + if err := dialAndUseConn(ctx, count); err != nil { + t.Fatal(err) + } + exp := New() + ctx = WithExperimentByteCounter(ctx, exp) + if err := dialAndUseConn(ctx, count); err != nil { + t.Fatal(err) + } + if exp.Received.Load() != count { + t.Fatal("experiment should have received 128 bytes") + } + if sess.Received.Load() != 2*count { + t.Fatal("session should have received 256 bytes") + } + if exp.Sent.Load() != count { + t.Fatal("experiment should have sent 128 bytes") + } + if sess.Sent.Load() != 256 { + t.Fatal("session should have sent 256 bytes") + } + }) + + t.Run("failure", func(t *testing.T) { + dialer := &ContextAwareDialer{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }, + } + conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := NewContextAwareDialer(child) + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/engine/experiment/ndt7/callback_test.go b/internal/engine/experiment/ndt7/callback_test.go deleted file mode 100644 index 1a78613..0000000 --- a/internal/engine/experiment/ndt7/callback_test.go +++ /dev/null @@ -1,10 +0,0 @@ -package ndt7 - -import "time" - -func defaultCallbackJSON(data []byte) error { - return nil -} - -func defaultCallbackPerformance(elapsed time.Duration, count int64) { -} diff --git a/internal/engine/experiment/ndt7/dial.go b/internal/engine/experiment/ndt7/dial.go index 9966382..0e02ce9 100644 --- a/internal/engine/experiment/ndt7/dial.go +++ b/internal/engine/experiment/ndt7/dial.go @@ -4,10 +4,9 @@ import ( "context" "crypto/tls" "net/http" - "net/url" "github.com/gorilla/websocket" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -15,7 +14,6 @@ import ( type dialManager struct { ndt7URL string logger model.Logger - proxyURL *url.URL readBufferSize int userAgent string writeBufferSize int @@ -32,16 +30,9 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM } func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { - var reso model.Resolver = &netxlite.ResolverSystem{} - reso = &netxlite.ResolverLogger{ - Resolver: reso, - Logger: mgr.logger, - } - dlr := dialer.New(&dialer.Config{ - ContextByteCounting: true, - Logger: mgr.logger, - ProxyURL: mgr.proxyURL, - }, reso) + reso := netxlite.NewResolverStdlib(mgr.logger) + dlr := netxlite.NewDialerWithResolver(mgr.logger, reso) + dlr = bytecounter.NewContextAwareDialer(dlr) // Implements shaping if the user builds using `-tags shaping` // See https://github.com/ooni/probe/issues/2112 dlr = netxlite.NewMaybeShapingDialer(dlr) diff --git a/internal/engine/experiment/ndt7/download.go b/internal/engine/experiment/ndt7/download.go index 0804de5..bd18adf 100644 --- a/internal/engine/experiment/ndt7/download.go +++ b/internal/engine/experiment/ndt7/download.go @@ -11,7 +11,7 @@ import ( ) type downloadManager struct { - conn mockableConn + conn wsConn maxMessageSize int64 maxRuntime time.Duration measureInterval time.Duration @@ -20,7 +20,7 @@ type downloadManager struct { } func newDownloadManager( - conn mockableConn, onPerformance callbackPerformance, + conn wsConn, onPerformance callbackPerformance, onJSON callbackJSON, ) downloadManager { return downloadManager{ diff --git a/internal/engine/experiment/ndt7/download_test.go b/internal/engine/experiment/ndt7/download_test.go index 8e87ab7..84a690e 100644 --- a/internal/engine/experiment/ndt7/download_test.go +++ b/internal/engine/experiment/ndt7/download_test.go @@ -12,10 +12,17 @@ import ( "github.com/gorilla/websocket" ) +func defaultCallbackJSON(data []byte) error { + return nil +} + +func defaultCallbackPerformance(elapsed time.Duration, count int64) { +} + func TestDownloadSetReadDeadlineFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newDownloadManager( - &mockableConnMock{ + &mockableWSConn{ ReadDeadlineErr: expected, }, defaultCallbackPerformance, @@ -30,7 +37,7 @@ func TestDownloadSetReadDeadlineFailure(t *testing.T) { func TestDownloadNextReaderFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newDownloadManager( - &mockableConnMock{ + &mockableWSConn{ NextReaderErr: expected, }, defaultCallbackPerformance, @@ -45,7 +52,7 @@ func TestDownloadNextReaderFailure(t *testing.T) { func TestDownloadTextMessageReadAllFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newDownloadManager( - &mockableConnMock{ + &mockableWSConn{ NextReaderMsgType: websocket.TextMessage, NextReaderReader: func() io.Reader { return &alwaysFailingReader{ @@ -73,7 +80,7 @@ func (r *alwaysFailingReader) Read(p []byte) (int, error) { func TestDownloadBinaryMessageReadAllFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newDownloadManager( - &mockableConnMock{ + &mockableWSConn{ NextReaderMsgType: websocket.BinaryMessage, NextReaderReader: func() io.Reader { return &alwaysFailingReader{ @@ -92,7 +99,7 @@ func TestDownloadBinaryMessageReadAllFailure(t *testing.T) { func TestDownloadOnJSONCallbackError(t *testing.T) { mgr := newDownloadManager( - &mockableConnMock{ + &mockableWSConn{ NextReaderMsgType: websocket.TextMessage, NextReaderReader: func() io.Reader { return &invalidJSONReader{} @@ -121,7 +128,7 @@ func TestDownloadOnJSONLoop(t *testing.T) { t.Skip("skip test in short mode") } mgr := newDownloadManager( - &mockableConnMock{ + &mockableWSConn{ NextReaderMsgType: websocket.TextMessage, NextReaderReader: func() io.Reader { return &goodJSONReader{} diff --git a/internal/engine/experiment/ndt7/upload.go b/internal/engine/experiment/ndt7/upload.go index f45b3fb..8717ade 100644 --- a/internal/engine/experiment/ndt7/upload.go +++ b/internal/engine/experiment/ndt7/upload.go @@ -12,7 +12,7 @@ func newMessage(n int) (*websocket.PreparedMessage, error) { } type uploadManager struct { - conn mockableConn + conn wsConn fractionForScaling int64 maxRuntime time.Duration maxMessageSize int @@ -24,7 +24,7 @@ type uploadManager struct { } func newUploadManager( - conn mockableConn, onPerformance callbackPerformance, + conn wsConn, onPerformance callbackPerformance, ) uploadManager { return uploadManager{ conn: conn, diff --git a/internal/engine/experiment/ndt7/upload_test.go b/internal/engine/experiment/ndt7/upload_test.go index b97deda..11b4998 100644 --- a/internal/engine/experiment/ndt7/upload_test.go +++ b/internal/engine/experiment/ndt7/upload_test.go @@ -12,7 +12,7 @@ import ( func TestUploadSetWriteDeadlineFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newUploadManager( - &mockableConnMock{ + &mockableWSConn{ WriteDeadlineErr: expected, }, defaultCallbackPerformance, @@ -26,7 +26,7 @@ func TestUploadSetWriteDeadlineFailure(t *testing.T) { func TestUploadNewMessageFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newUploadManager( - &mockableConnMock{}, + &mockableWSConn{}, defaultCallbackPerformance, ) mgr.newMessage = func(int) (*websocket.PreparedMessage, error) { @@ -41,7 +41,7 @@ func TestUploadNewMessageFailure(t *testing.T) { func TestUploadWritePreparedMessageFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newUploadManager( - &mockableConnMock{ + &mockableWSConn{ WritePreparedMessageErr: expected, }, defaultCallbackPerformance, @@ -55,7 +55,7 @@ func TestUploadWritePreparedMessageFailure(t *testing.T) { func TestUploadWritePreparedMessageSubsequentFailure(t *testing.T) { expected := errors.New("mocked error") mgr := newUploadManager( - &mockableConnMock{}, + &mockableWSConn{}, defaultCallbackPerformance, ) var already bool @@ -77,7 +77,7 @@ func TestUploadLoop(t *testing.T) { t.Skip("skip test in short mode") } mgr := newUploadManager( - &mockableConnMock{}, + &mockableWSConn{}, defaultCallbackPerformance, ) mgr.newMessage = func(int) (*websocket.PreparedMessage, error) { diff --git a/internal/engine/experiment/ndt7/mockable.go b/internal/engine/experiment/ndt7/wsconn.go similarity index 77% rename from internal/engine/experiment/ndt7/mockable.go rename to internal/engine/experiment/ndt7/wsconn.go index 4c0d6d7..5426a83 100644 --- a/internal/engine/experiment/ndt7/mockable.go +++ b/internal/engine/experiment/ndt7/wsconn.go @@ -7,7 +7,8 @@ import ( "github.com/gorilla/websocket" ) -type mockableConn interface { +// weConn is the interface of gorilla/websocket.Conn +type wsConn interface { NextReader() (int, io.Reader, error) SetReadDeadline(time.Time) error SetReadLimit(int64) diff --git a/internal/engine/experiment/ndt7/mockable_test.go b/internal/engine/experiment/ndt7/wsconn_test.go similarity index 58% rename from internal/engine/experiment/ndt7/mockable_test.go rename to internal/engine/experiment/ndt7/wsconn_test.go index 7e0d61c..01e7d5d 100644 --- a/internal/engine/experiment/ndt7/mockable_test.go +++ b/internal/engine/experiment/ndt7/wsconn_test.go @@ -7,7 +7,7 @@ import ( "github.com/gorilla/websocket" ) -type mockableConnMock struct { +type mockableWSConn struct { NextReaderMsgType int NextReaderErr error NextReaderReader func() io.Reader @@ -16,7 +16,7 @@ type mockableConnMock struct { WritePreparedMessageErr error } -func (c *mockableConnMock) NextReader() (int, io.Reader, error) { +func (c *mockableWSConn) NextReader() (int, io.Reader, error) { var reader io.Reader if c.NextReaderReader != nil { reader = c.NextReaderReader() @@ -24,16 +24,16 @@ func (c *mockableConnMock) NextReader() (int, io.Reader, error) { return c.NextReaderMsgType, reader, c.NextReaderErr } -func (c *mockableConnMock) SetReadDeadline(time.Time) error { +func (c *mockableWSConn) SetReadDeadline(time.Time) error { return c.ReadDeadlineErr } -func (c *mockableConnMock) SetReadLimit(int64) {} +func (c *mockableWSConn) SetReadLimit(int64) {} -func (c *mockableConnMock) SetWriteDeadline(time.Time) error { +func (c *mockableWSConn) SetWriteDeadline(time.Time) error { return c.WriteDeadlineErr } -func (c *mockableConnMock) WritePreparedMessage(*websocket.PreparedMessage) error { +func (c *mockableWSConn) WritePreparedMessage(*websocket.PreparedMessage) error { return c.WritePreparedMessageErr } diff --git a/internal/engine/netx/dialer/bytecounter.go b/internal/engine/netx/dialer/bytecounter.go index 447fb4b..08aae8b 100644 --- a/internal/engine/netx/dialer/bytecounter.go +++ b/internal/engine/netx/dialer/bytecounter.go @@ -1,26 +1,5 @@ package dialer -import ( - "context" - "net" +import "github.com/ooni/probe-cli/v3/internal/bytecounter" - "github.com/ooni/probe-cli/v3/internal/bytecounter" - "github.com/ooni/probe-cli/v3/internal/model" -) - -// byteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you -// should make sure that you insert this dialer in the dialing chain. -type byteCounterDialer struct { - model.Dialer -} - -// DialContext implements Dialer.DialContext -func (d *byteCounterDialer) DialContext( - ctx context.Context, network, address string) (net.Conn, error) { - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, err - } - conn = bytecounter.MaybeWrapWithContextByteCounters(ctx, conn) - return conn, nil -} +type byteCounterDialer = bytecounter.ContextAwareDialer diff --git a/internal/engine/netx/dialer/bytecounter_test.go b/internal/engine/netx/dialer/bytecounter_test.go deleted file mode 100644 index 48b7b76..0000000 --- a/internal/engine/netx/dialer/bytecounter_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package dialer - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "testing" - - "github.com/ooni/probe-cli/v3/internal/bytecounter" - "github.com/ooni/probe-cli/v3/internal/model/mocks" - "github.com/ooni/probe-cli/v3/internal/netxlite" -) - -func dorequest(ctx context.Context, url string) error { - txp := http.DefaultTransport.(*http.Transport).Clone() - defer txp.CloseIdleConnections() - dialer := &byteCounterDialer{Dialer: netxlite.DefaultDialer} - txp.DialContext = dialer.DialContext - client := &http.Client{Transport: txp} - req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil) - if err != nil { - return err - } - resp, err := client.Do(req) - if err != nil { - return err - } - if _, err := netxlite.CopyContext(ctx, io.Discard, resp.Body); err != nil { - return err - } - return resp.Body.Close() -} - -func TestByteCounterNormalUsage(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - sess := bytecounter.New() - ctx := context.Background() - ctx = bytecounter.WithSessionByteCounter(ctx, sess) - if err := dorequest(ctx, "http://www.google.com"); err != nil { - t.Fatal(err) - } - exp := bytecounter.New() - ctx = bytecounter.WithExperimentByteCounter(ctx, exp) - if err := dorequest(ctx, "http://facebook.com"); err != nil { - t.Fatal(err) - } - if exp.Received.Load() <= 0 { - t.Fatal("experiment should have received some bytes") - } - if sess.Received.Load() <= exp.Received.Load() { - t.Fatal("session should have received more than experiment") - } - if exp.Sent.Load() <= 0 { - t.Fatal("experiment should have sent some bytes") - } - if sess.Sent.Load() <= exp.Sent.Load() { - t.Fatal("session should have sent more than experiment") - } -} - -func TestByteCounterNoHandlers(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - ctx := context.Background() - if err := dorequest(ctx, "http://www.google.com"); err != nil { - t.Fatal(err) - } - if err := dorequest(ctx, "http://facebook.com"); err != nil { - t.Fatal(err) - } -} - -func TestByteCounterConnectFailure(t *testing.T) { - dialer := &byteCounterDialer{Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, io.EOF - }, - }} - conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80") - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } -}