refactor(ndt7): use netxlite rather than netx (#768)
This diff required us to move some code around, but no major change actually happened, except better tests. While there, I also slightly refactored ndt7's implementation and removed the ProxyURL setting, which was actually unused. See https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
314c3c934d
commit
3265bc670a
52
internal/bytecounter/dialer.go
Normal file
52
internal/bytecounter/dialer.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package bytecounter
|
||||
|
||||
//
|
||||
// model.Dialer wrappers
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// ContextAwareDialer is a model.Dialer that attempts to count bytes using
|
||||
// the MaybeWrapWithContextByteCounters function.
|
||||
//
|
||||
// Bug
|
||||
//
|
||||
// This implementation cannot properly account for the bytes that are sent by
|
||||
// persistent connections, because they stick to the counters set when the
|
||||
// connection was established. This typically means we miss the bytes sent and
|
||||
// received when submitting a measurement. Such bytes are specifically not
|
||||
// seen by the experiment specific byte counter.
|
||||
//
|
||||
// For this reason, this implementation may be heavily changed/removed
|
||||
// in the future (<- this message is now ~two years old, though).
|
||||
type ContextAwareDialer struct {
|
||||
Dialer model.Dialer
|
||||
}
|
||||
|
||||
// NewContextAwareDialer creates a new ContextAwareDialer.
|
||||
func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer {
|
||||
return &ContextAwareDialer{Dialer: dialer}
|
||||
}
|
||||
|
||||
var _ model.Dialer = &ContextAwareDialer{}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d *ContextAwareDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = MaybeWrapWithContextByteCounters(ctx, conn)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||
func (d *ContextAwareDialer) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
}
|
101
internal/bytecounter/dialer_test.go
Normal file
101
internal/bytecounter/dialer_test.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
package bytecounter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestContextAwareDialer(t *testing.T) {
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
dialAndUseConn := func(ctx context.Context, bufsiz int) error {
|
||||
childConn := &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
}
|
||||
child := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return childConn, nil
|
||||
},
|
||||
}
|
||||
dialer := NewContextAwareDialer(child)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer := make([]byte, bufsiz)
|
||||
conn.Read(buffer)
|
||||
conn.Write(buffer)
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("normal usage", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
sess := New()
|
||||
ctx := context.Background()
|
||||
ctx = WithSessionByteCounter(ctx, sess)
|
||||
const count = 128
|
||||
if err := dialAndUseConn(ctx, count); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
exp := New()
|
||||
ctx = WithExperimentByteCounter(ctx, exp)
|
||||
if err := dialAndUseConn(ctx, count); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if exp.Received.Load() != count {
|
||||
t.Fatal("experiment should have received 128 bytes")
|
||||
}
|
||||
if sess.Received.Load() != 2*count {
|
||||
t.Fatal("session should have received 256 bytes")
|
||||
}
|
||||
if exp.Sent.Load() != count {
|
||||
t.Fatal("experiment should have sent 128 bytes")
|
||||
}
|
||||
if sess.Sent.Load() != 256 {
|
||||
t.Fatal("session should have sent 256 bytes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failure", func(t *testing.T) {
|
||||
dialer := &ContextAwareDialer{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
child := &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
dialer := NewContextAwareDialer(child)
|
||||
dialer.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
package ndt7
|
||||
|
||||
import "time"
|
||||
|
||||
func defaultCallbackJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultCallbackPerformance(elapsed time.Duration, count int64) {
|
||||
}
|
|
@ -4,10 +4,9 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
@ -15,7 +14,6 @@ import (
|
|||
type dialManager struct {
|
||||
ndt7URL string
|
||||
logger model.Logger
|
||||
proxyURL *url.URL
|
||||
readBufferSize int
|
||||
userAgent string
|
||||
writeBufferSize int
|
||||
|
@ -32,16 +30,9 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
|
|||
}
|
||||
|
||||
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
|
||||
var reso model.Resolver = &netxlite.ResolverSystem{}
|
||||
reso = &netxlite.ResolverLogger{
|
||||
Resolver: reso,
|
||||
Logger: mgr.logger,
|
||||
}
|
||||
dlr := dialer.New(&dialer.Config{
|
||||
ContextByteCounting: true,
|
||||
Logger: mgr.logger,
|
||||
ProxyURL: mgr.proxyURL,
|
||||
}, reso)
|
||||
reso := netxlite.NewResolverStdlib(mgr.logger)
|
||||
dlr := netxlite.NewDialerWithResolver(mgr.logger, reso)
|
||||
dlr = bytecounter.NewContextAwareDialer(dlr)
|
||||
// Implements shaping if the user builds using `-tags shaping`
|
||||
// See https://github.com/ooni/probe/issues/2112
|
||||
dlr = netxlite.NewMaybeShapingDialer(dlr)
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
type downloadManager struct {
|
||||
conn mockableConn
|
||||
conn wsConn
|
||||
maxMessageSize int64
|
||||
maxRuntime time.Duration
|
||||
measureInterval time.Duration
|
||||
|
@ -20,7 +20,7 @@ type downloadManager struct {
|
|||
}
|
||||
|
||||
func newDownloadManager(
|
||||
conn mockableConn, onPerformance callbackPerformance,
|
||||
conn wsConn, onPerformance callbackPerformance,
|
||||
onJSON callbackJSON,
|
||||
) downloadManager {
|
||||
return downloadManager{
|
||||
|
|
|
@ -12,10 +12,17 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func defaultCallbackJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultCallbackPerformance(elapsed time.Duration, count int64) {
|
||||
}
|
||||
|
||||
func TestDownloadSetReadDeadlineFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newDownloadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
ReadDeadlineErr: expected,
|
||||
},
|
||||
defaultCallbackPerformance,
|
||||
|
@ -30,7 +37,7 @@ func TestDownloadSetReadDeadlineFailure(t *testing.T) {
|
|||
func TestDownloadNextReaderFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newDownloadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
NextReaderErr: expected,
|
||||
},
|
||||
defaultCallbackPerformance,
|
||||
|
@ -45,7 +52,7 @@ func TestDownloadNextReaderFailure(t *testing.T) {
|
|||
func TestDownloadTextMessageReadAllFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newDownloadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
NextReaderMsgType: websocket.TextMessage,
|
||||
NextReaderReader: func() io.Reader {
|
||||
return &alwaysFailingReader{
|
||||
|
@ -73,7 +80,7 @@ func (r *alwaysFailingReader) Read(p []byte) (int, error) {
|
|||
func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newDownloadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
NextReaderMsgType: websocket.BinaryMessage,
|
||||
NextReaderReader: func() io.Reader {
|
||||
return &alwaysFailingReader{
|
||||
|
@ -92,7 +99,7 @@ func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
|
|||
|
||||
func TestDownloadOnJSONCallbackError(t *testing.T) {
|
||||
mgr := newDownloadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
NextReaderMsgType: websocket.TextMessage,
|
||||
NextReaderReader: func() io.Reader {
|
||||
return &invalidJSONReader{}
|
||||
|
@ -121,7 +128,7 @@ func TestDownloadOnJSONLoop(t *testing.T) {
|
|||
t.Skip("skip test in short mode")
|
||||
}
|
||||
mgr := newDownloadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
NextReaderMsgType: websocket.TextMessage,
|
||||
NextReaderReader: func() io.Reader {
|
||||
return &goodJSONReader{}
|
||||
|
|
|
@ -12,7 +12,7 @@ func newMessage(n int) (*websocket.PreparedMessage, error) {
|
|||
}
|
||||
|
||||
type uploadManager struct {
|
||||
conn mockableConn
|
||||
conn wsConn
|
||||
fractionForScaling int64
|
||||
maxRuntime time.Duration
|
||||
maxMessageSize int
|
||||
|
@ -24,7 +24,7 @@ type uploadManager struct {
|
|||
}
|
||||
|
||||
func newUploadManager(
|
||||
conn mockableConn, onPerformance callbackPerformance,
|
||||
conn wsConn, onPerformance callbackPerformance,
|
||||
) uploadManager {
|
||||
return uploadManager{
|
||||
conn: conn,
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
func TestUploadSetWriteDeadlineFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newUploadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
WriteDeadlineErr: expected,
|
||||
},
|
||||
defaultCallbackPerformance,
|
||||
|
@ -26,7 +26,7 @@ func TestUploadSetWriteDeadlineFailure(t *testing.T) {
|
|||
func TestUploadNewMessageFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newUploadManager(
|
||||
&mockableConnMock{},
|
||||
&mockableWSConn{},
|
||||
defaultCallbackPerformance,
|
||||
)
|
||||
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
|
||||
|
@ -41,7 +41,7 @@ func TestUploadNewMessageFailure(t *testing.T) {
|
|||
func TestUploadWritePreparedMessageFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newUploadManager(
|
||||
&mockableConnMock{
|
||||
&mockableWSConn{
|
||||
WritePreparedMessageErr: expected,
|
||||
},
|
||||
defaultCallbackPerformance,
|
||||
|
@ -55,7 +55,7 @@ func TestUploadWritePreparedMessageFailure(t *testing.T) {
|
|||
func TestUploadWritePreparedMessageSubsequentFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
mgr := newUploadManager(
|
||||
&mockableConnMock{},
|
||||
&mockableWSConn{},
|
||||
defaultCallbackPerformance,
|
||||
)
|
||||
var already bool
|
||||
|
@ -77,7 +77,7 @@ func TestUploadLoop(t *testing.T) {
|
|||
t.Skip("skip test in short mode")
|
||||
}
|
||||
mgr := newUploadManager(
|
||||
&mockableConnMock{},
|
||||
&mockableWSConn{},
|
||||
defaultCallbackPerformance,
|
||||
)
|
||||
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
|
||||
|
|
|
@ -7,7 +7,8 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type mockableConn interface {
|
||||
// weConn is the interface of gorilla/websocket.Conn
|
||||
type wsConn interface {
|
||||
NextReader() (int, io.Reader, error)
|
||||
SetReadDeadline(time.Time) error
|
||||
SetReadLimit(int64)
|
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type mockableConnMock struct {
|
||||
type mockableWSConn struct {
|
||||
NextReaderMsgType int
|
||||
NextReaderErr error
|
||||
NextReaderReader func() io.Reader
|
||||
|
@ -16,7 +16,7 @@ type mockableConnMock struct {
|
|||
WritePreparedMessageErr error
|
||||
}
|
||||
|
||||
func (c *mockableConnMock) NextReader() (int, io.Reader, error) {
|
||||
func (c *mockableWSConn) NextReader() (int, io.Reader, error) {
|
||||
var reader io.Reader
|
||||
if c.NextReaderReader != nil {
|
||||
reader = c.NextReaderReader()
|
||||
|
@ -24,16 +24,16 @@ func (c *mockableConnMock) NextReader() (int, io.Reader, error) {
|
|||
return c.NextReaderMsgType, reader, c.NextReaderErr
|
||||
}
|
||||
|
||||
func (c *mockableConnMock) SetReadDeadline(time.Time) error {
|
||||
func (c *mockableWSConn) SetReadDeadline(time.Time) error {
|
||||
return c.ReadDeadlineErr
|
||||
}
|
||||
|
||||
func (c *mockableConnMock) SetReadLimit(int64) {}
|
||||
func (c *mockableWSConn) SetReadLimit(int64) {}
|
||||
|
||||
func (c *mockableConnMock) SetWriteDeadline(time.Time) error {
|
||||
func (c *mockableWSConn) SetWriteDeadline(time.Time) error {
|
||||
return c.WriteDeadlineErr
|
||||
}
|
||||
|
||||
func (c *mockableConnMock) WritePreparedMessage(*websocket.PreparedMessage) error {
|
||||
func (c *mockableWSConn) WritePreparedMessage(*websocket.PreparedMessage) error {
|
||||
return c.WritePreparedMessageErr
|
||||
}
|
|
@ -1,26 +1,5 @@
|
|||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
import "github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// byteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you
|
||||
// should make sure that you insert this dialer in the dialing chain.
|
||||
type byteCounterDialer struct {
|
||||
model.Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d *byteCounterDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn = bytecounter.MaybeWrapWithContextByteCounters(ctx, conn)
|
||||
return conn, nil
|
||||
}
|
||||
type byteCounterDialer = bytecounter.ContextAwareDialer
|
||||
|
|
|
@ -1,91 +0,0 @@
|
|||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
func dorequest(ctx context.Context, url string) error {
|
||||
txp := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defer txp.CloseIdleConnections()
|
||||
dialer := &byteCounterDialer{Dialer: netxlite.DefaultDialer}
|
||||
txp.DialContext = dialer.DialContext
|
||||
client := &http.Client{Transport: txp}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := netxlite.CopyContext(ctx, io.Discard, resp.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestByteCounterNormalUsage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
sess := bytecounter.New()
|
||||
ctx := context.Background()
|
||||
ctx = bytecounter.WithSessionByteCounter(ctx, sess)
|
||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
exp := bytecounter.New()
|
||||
ctx = bytecounter.WithExperimentByteCounter(ctx, exp)
|
||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if exp.Received.Load() <= 0 {
|
||||
t.Fatal("experiment should have received some bytes")
|
||||
}
|
||||
if sess.Received.Load() <= exp.Received.Load() {
|
||||
t.Fatal("session should have received more than experiment")
|
||||
}
|
||||
if exp.Sent.Load() <= 0 {
|
||||
t.Fatal("experiment should have sent some bytes")
|
||||
}
|
||||
if sess.Sent.Load() <= exp.Sent.Load() {
|
||||
t.Fatal("session should have sent more than experiment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterNoHandlers(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterConnectFailure(t *testing.T) {
|
||||
dialer := &byteCounterDialer{Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user