refactor(ndt7): use netxlite rather than netx (#768)

This diff required us to move some code around, but no major
change actually happened, except better tests.

While there, I also slightly refactored ndt7's implementation and
removed the ProxyURL setting, which was actually unused.

See https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso 2022-05-30 23:14:07 +02:00 committed by GitHub
parent 314c3c934d
commit 3265bc670a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 189 additions and 159 deletions

View File

@ -0,0 +1,52 @@
package bytecounter
//
// model.Dialer wrappers
//
import (
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/model"
)
// ContextAwareDialer is a model.Dialer that attempts to count bytes using
// the MaybeWrapWithContextByteCounters function.
//
// Bug
//
// This implementation cannot properly account for the bytes that are sent by
// persistent connections, because they stick to the counters set when the
// connection was established. This typically means we miss the bytes sent and
// received when submitting a measurement. Such bytes are specifically not
// seen by the experiment specific byte counter.
//
// For this reason, this implementation may be heavily changed/removed
// in the future (<- this message is now ~two years old, though).
type ContextAwareDialer struct {
Dialer model.Dialer
}
// NewContextAwareDialer creates a new ContextAwareDialer.
func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer {
return &ContextAwareDialer{Dialer: dialer}
}
var _ model.Dialer = &ContextAwareDialer{}
// DialContext implements Dialer.DialContext
func (d *ContextAwareDialer) DialContext(
ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
conn = MaybeWrapWithContextByteCounters(ctx, conn)
return conn, nil
}
// CloseIdleConnections implements Dialer.CloseIdleConnections.
func (d *ContextAwareDialer) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}

View File

@ -0,0 +1,101 @@
package bytecounter
import (
"context"
"errors"
"io"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestContextAwareDialer(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
dialAndUseConn := func(ctx context.Context, bufsiz int) error {
childConn := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
}
child := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return childConn, nil
},
}
dialer := NewContextAwareDialer(child)
conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443")
if err != nil {
return err
}
buffer := make([]byte, bufsiz)
conn.Read(buffer)
conn.Write(buffer)
return nil
}
t.Run("normal usage", func(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
sess := New()
ctx := context.Background()
ctx = WithSessionByteCounter(ctx, sess)
const count = 128
if err := dialAndUseConn(ctx, count); err != nil {
t.Fatal(err)
}
exp := New()
ctx = WithExperimentByteCounter(ctx, exp)
if err := dialAndUseConn(ctx, count); err != nil {
t.Fatal(err)
}
if exp.Received.Load() != count {
t.Fatal("experiment should have received 128 bytes")
}
if sess.Received.Load() != 2*count {
t.Fatal("session should have received 256 bytes")
}
if exp.Sent.Load() != count {
t.Fatal("experiment should have sent 128 bytes")
}
if sess.Sent.Load() != 256 {
t.Fatal("session should have sent 256 bytes")
}
})
t.Run("failure", func(t *testing.T) {
dialer := &ContextAwareDialer{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
},
},
}
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := NewContextAwareDialer(child)
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}

View File

@ -1,10 +0,0 @@
package ndt7
import "time"
func defaultCallbackJSON(data []byte) error {
return nil
}
func defaultCallbackPerformance(elapsed time.Duration, count int64) {
}

View File

@ -4,10 +4,9 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"net/url"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
@ -15,7 +14,6 @@ import (
type dialManager struct { type dialManager struct {
ndt7URL string ndt7URL string
logger model.Logger logger model.Logger
proxyURL *url.URL
readBufferSize int readBufferSize int
userAgent string userAgent string
writeBufferSize int writeBufferSize int
@ -32,16 +30,9 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
} }
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
var reso model.Resolver = &netxlite.ResolverSystem{} reso := netxlite.NewResolverStdlib(mgr.logger)
reso = &netxlite.ResolverLogger{ dlr := netxlite.NewDialerWithResolver(mgr.logger, reso)
Resolver: reso, dlr = bytecounter.NewContextAwareDialer(dlr)
Logger: mgr.logger,
}
dlr := dialer.New(&dialer.Config{
ContextByteCounting: true,
Logger: mgr.logger,
ProxyURL: mgr.proxyURL,
}, reso)
// Implements shaping if the user builds using `-tags shaping` // Implements shaping if the user builds using `-tags shaping`
// See https://github.com/ooni/probe/issues/2112 // See https://github.com/ooni/probe/issues/2112
dlr = netxlite.NewMaybeShapingDialer(dlr) dlr = netxlite.NewMaybeShapingDialer(dlr)

View File

@ -11,7 +11,7 @@ import (
) )
type downloadManager struct { type downloadManager struct {
conn mockableConn conn wsConn
maxMessageSize int64 maxMessageSize int64
maxRuntime time.Duration maxRuntime time.Duration
measureInterval time.Duration measureInterval time.Duration
@ -20,7 +20,7 @@ type downloadManager struct {
} }
func newDownloadManager( func newDownloadManager(
conn mockableConn, onPerformance callbackPerformance, conn wsConn, onPerformance callbackPerformance,
onJSON callbackJSON, onJSON callbackJSON,
) downloadManager { ) downloadManager {
return downloadManager{ return downloadManager{

View File

@ -12,10 +12,17 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
func defaultCallbackJSON(data []byte) error {
return nil
}
func defaultCallbackPerformance(elapsed time.Duration, count int64) {
}
func TestDownloadSetReadDeadlineFailure(t *testing.T) { func TestDownloadSetReadDeadlineFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newDownloadManager( mgr := newDownloadManager(
&mockableConnMock{ &mockableWSConn{
ReadDeadlineErr: expected, ReadDeadlineErr: expected,
}, },
defaultCallbackPerformance, defaultCallbackPerformance,
@ -30,7 +37,7 @@ func TestDownloadSetReadDeadlineFailure(t *testing.T) {
func TestDownloadNextReaderFailure(t *testing.T) { func TestDownloadNextReaderFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newDownloadManager( mgr := newDownloadManager(
&mockableConnMock{ &mockableWSConn{
NextReaderErr: expected, NextReaderErr: expected,
}, },
defaultCallbackPerformance, defaultCallbackPerformance,
@ -45,7 +52,7 @@ func TestDownloadNextReaderFailure(t *testing.T) {
func TestDownloadTextMessageReadAllFailure(t *testing.T) { func TestDownloadTextMessageReadAllFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newDownloadManager( mgr := newDownloadManager(
&mockableConnMock{ &mockableWSConn{
NextReaderMsgType: websocket.TextMessage, NextReaderMsgType: websocket.TextMessage,
NextReaderReader: func() io.Reader { NextReaderReader: func() io.Reader {
return &alwaysFailingReader{ return &alwaysFailingReader{
@ -73,7 +80,7 @@ func (r *alwaysFailingReader) Read(p []byte) (int, error) {
func TestDownloadBinaryMessageReadAllFailure(t *testing.T) { func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newDownloadManager( mgr := newDownloadManager(
&mockableConnMock{ &mockableWSConn{
NextReaderMsgType: websocket.BinaryMessage, NextReaderMsgType: websocket.BinaryMessage,
NextReaderReader: func() io.Reader { NextReaderReader: func() io.Reader {
return &alwaysFailingReader{ return &alwaysFailingReader{
@ -92,7 +99,7 @@ func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
func TestDownloadOnJSONCallbackError(t *testing.T) { func TestDownloadOnJSONCallbackError(t *testing.T) {
mgr := newDownloadManager( mgr := newDownloadManager(
&mockableConnMock{ &mockableWSConn{
NextReaderMsgType: websocket.TextMessage, NextReaderMsgType: websocket.TextMessage,
NextReaderReader: func() io.Reader { NextReaderReader: func() io.Reader {
return &invalidJSONReader{} return &invalidJSONReader{}
@ -121,7 +128,7 @@ func TestDownloadOnJSONLoop(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
mgr := newDownloadManager( mgr := newDownloadManager(
&mockableConnMock{ &mockableWSConn{
NextReaderMsgType: websocket.TextMessage, NextReaderMsgType: websocket.TextMessage,
NextReaderReader: func() io.Reader { NextReaderReader: func() io.Reader {
return &goodJSONReader{} return &goodJSONReader{}

View File

@ -12,7 +12,7 @@ func newMessage(n int) (*websocket.PreparedMessage, error) {
} }
type uploadManager struct { type uploadManager struct {
conn mockableConn conn wsConn
fractionForScaling int64 fractionForScaling int64
maxRuntime time.Duration maxRuntime time.Duration
maxMessageSize int maxMessageSize int
@ -24,7 +24,7 @@ type uploadManager struct {
} }
func newUploadManager( func newUploadManager(
conn mockableConn, onPerformance callbackPerformance, conn wsConn, onPerformance callbackPerformance,
) uploadManager { ) uploadManager {
return uploadManager{ return uploadManager{
conn: conn, conn: conn,

View File

@ -12,7 +12,7 @@ import (
func TestUploadSetWriteDeadlineFailure(t *testing.T) { func TestUploadSetWriteDeadlineFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newUploadManager( mgr := newUploadManager(
&mockableConnMock{ &mockableWSConn{
WriteDeadlineErr: expected, WriteDeadlineErr: expected,
}, },
defaultCallbackPerformance, defaultCallbackPerformance,
@ -26,7 +26,7 @@ func TestUploadSetWriteDeadlineFailure(t *testing.T) {
func TestUploadNewMessageFailure(t *testing.T) { func TestUploadNewMessageFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newUploadManager( mgr := newUploadManager(
&mockableConnMock{}, &mockableWSConn{},
defaultCallbackPerformance, defaultCallbackPerformance,
) )
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) { mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
@ -41,7 +41,7 @@ func TestUploadNewMessageFailure(t *testing.T) {
func TestUploadWritePreparedMessageFailure(t *testing.T) { func TestUploadWritePreparedMessageFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newUploadManager( mgr := newUploadManager(
&mockableConnMock{ &mockableWSConn{
WritePreparedMessageErr: expected, WritePreparedMessageErr: expected,
}, },
defaultCallbackPerformance, defaultCallbackPerformance,
@ -55,7 +55,7 @@ func TestUploadWritePreparedMessageFailure(t *testing.T) {
func TestUploadWritePreparedMessageSubsequentFailure(t *testing.T) { func TestUploadWritePreparedMessageSubsequentFailure(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
mgr := newUploadManager( mgr := newUploadManager(
&mockableConnMock{}, &mockableWSConn{},
defaultCallbackPerformance, defaultCallbackPerformance,
) )
var already bool var already bool
@ -77,7 +77,7 @@ func TestUploadLoop(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
mgr := newUploadManager( mgr := newUploadManager(
&mockableConnMock{}, &mockableWSConn{},
defaultCallbackPerformance, defaultCallbackPerformance,
) )
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) { mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {

View File

@ -7,7 +7,8 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type mockableConn interface { // weConn is the interface of gorilla/websocket.Conn
type wsConn interface {
NextReader() (int, io.Reader, error) NextReader() (int, io.Reader, error)
SetReadDeadline(time.Time) error SetReadDeadline(time.Time) error
SetReadLimit(int64) SetReadLimit(int64)

View File

@ -7,7 +7,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type mockableConnMock struct { type mockableWSConn struct {
NextReaderMsgType int NextReaderMsgType int
NextReaderErr error NextReaderErr error
NextReaderReader func() io.Reader NextReaderReader func() io.Reader
@ -16,7 +16,7 @@ type mockableConnMock struct {
WritePreparedMessageErr error WritePreparedMessageErr error
} }
func (c *mockableConnMock) NextReader() (int, io.Reader, error) { func (c *mockableWSConn) NextReader() (int, io.Reader, error) {
var reader io.Reader var reader io.Reader
if c.NextReaderReader != nil { if c.NextReaderReader != nil {
reader = c.NextReaderReader() reader = c.NextReaderReader()
@ -24,16 +24,16 @@ func (c *mockableConnMock) NextReader() (int, io.Reader, error) {
return c.NextReaderMsgType, reader, c.NextReaderErr return c.NextReaderMsgType, reader, c.NextReaderErr
} }
func (c *mockableConnMock) SetReadDeadline(time.Time) error { func (c *mockableWSConn) SetReadDeadline(time.Time) error {
return c.ReadDeadlineErr return c.ReadDeadlineErr
} }
func (c *mockableConnMock) SetReadLimit(int64) {} func (c *mockableWSConn) SetReadLimit(int64) {}
func (c *mockableConnMock) SetWriteDeadline(time.Time) error { func (c *mockableWSConn) SetWriteDeadline(time.Time) error {
return c.WriteDeadlineErr return c.WriteDeadlineErr
} }
func (c *mockableConnMock) WritePreparedMessage(*websocket.PreparedMessage) error { func (c *mockableWSConn) WritePreparedMessage(*websocket.PreparedMessage) error {
return c.WritePreparedMessageErr return c.WritePreparedMessageErr
} }

View File

@ -1,26 +1,5 @@
package dialer package dialer
import ( import "github.com/ooni/probe-cli/v3/internal/bytecounter"
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/bytecounter" type byteCounterDialer = bytecounter.ContextAwareDialer
"github.com/ooni/probe-cli/v3/internal/model"
)
// byteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you
// should make sure that you insert this dialer in the dialing chain.
type byteCounterDialer struct {
model.Dialer
}
// DialContext implements Dialer.DialContext
func (d *byteCounterDialer) DialContext(
ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
conn = bytecounter.MaybeWrapWithContextByteCounters(ctx, conn)
return conn, nil
}

View File

@ -1,91 +0,0 @@
package dialer
import (
"context"
"errors"
"io"
"net"
"net/http"
"testing"
"github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func dorequest(ctx context.Context, url string) error {
txp := http.DefaultTransport.(*http.Transport).Clone()
defer txp.CloseIdleConnections()
dialer := &byteCounterDialer{Dialer: netxlite.DefaultDialer}
txp.DialContext = dialer.DialContext
client := &http.Client{Transport: txp}
req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
if _, err := netxlite.CopyContext(ctx, io.Discard, resp.Body); err != nil {
return err
}
return resp.Body.Close()
}
func TestByteCounterNormalUsage(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
sess := bytecounter.New()
ctx := context.Background()
ctx = bytecounter.WithSessionByteCounter(ctx, sess)
if err := dorequest(ctx, "http://www.google.com"); err != nil {
t.Fatal(err)
}
exp := bytecounter.New()
ctx = bytecounter.WithExperimentByteCounter(ctx, exp)
if err := dorequest(ctx, "http://facebook.com"); err != nil {
t.Fatal(err)
}
if exp.Received.Load() <= 0 {
t.Fatal("experiment should have received some bytes")
}
if sess.Received.Load() <= exp.Received.Load() {
t.Fatal("session should have received more than experiment")
}
if exp.Sent.Load() <= 0 {
t.Fatal("experiment should have sent some bytes")
}
if sess.Sent.Load() <= exp.Sent.Load() {
t.Fatal("session should have sent more than experiment")
}
}
func TestByteCounterNoHandlers(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
ctx := context.Background()
if err := dorequest(ctx, "http://www.google.com"); err != nil {
t.Fatal(err)
}
if err := dorequest(ctx, "http://facebook.com"); err != nil {
t.Fatal(err)
}
}
func TestByteCounterConnectFailure(t *testing.T) {
dialer := &byteCounterDialer{Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
},
}}
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}