refactor(ndt7): use netxlite rather than netx (#768)
This diff required us to move some code around, but no major change actually happened, except better tests. While there, I also slightly refactored ndt7's implementation and removed the ProxyURL setting, which was actually unused. See https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
314c3c934d
commit
3265bc670a
52
internal/bytecounter/dialer.go
Normal file
52
internal/bytecounter/dialer.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
package bytecounter
|
||||||
|
|
||||||
|
//
|
||||||
|
// model.Dialer wrappers
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextAwareDialer is a model.Dialer that attempts to count bytes using
|
||||||
|
// the MaybeWrapWithContextByteCounters function.
|
||||||
|
//
|
||||||
|
// Bug
|
||||||
|
//
|
||||||
|
// This implementation cannot properly account for the bytes that are sent by
|
||||||
|
// persistent connections, because they stick to the counters set when the
|
||||||
|
// connection was established. This typically means we miss the bytes sent and
|
||||||
|
// received when submitting a measurement. Such bytes are specifically not
|
||||||
|
// seen by the experiment specific byte counter.
|
||||||
|
//
|
||||||
|
// For this reason, this implementation may be heavily changed/removed
|
||||||
|
// in the future (<- this message is now ~two years old, though).
|
||||||
|
type ContextAwareDialer struct {
|
||||||
|
Dialer model.Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContextAwareDialer creates a new ContextAwareDialer.
|
||||||
|
func NewContextAwareDialer(dialer model.Dialer) *ContextAwareDialer {
|
||||||
|
return &ContextAwareDialer{Dialer: dialer}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.Dialer = &ContextAwareDialer{}
|
||||||
|
|
||||||
|
// DialContext implements Dialer.DialContext
|
||||||
|
func (d *ContextAwareDialer) DialContext(
|
||||||
|
ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn = MaybeWrapWithContextByteCounters(ctx, conn)
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||||
|
func (d *ContextAwareDialer) CloseIdleConnections() {
|
||||||
|
d.Dialer.CloseIdleConnections()
|
||||||
|
}
|
101
internal/bytecounter/dialer_test.go
Normal file
101
internal/bytecounter/dialer_test.go
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
package bytecounter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextAwareDialer(t *testing.T) {
|
||||||
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
|
dialAndUseConn := func(ctx context.Context, bufsiz int) error {
|
||||||
|
childConn := &mocks.Conn{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
child := &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return childConn, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := NewContextAwareDialer(child)
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", "10.0.0.1:443")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buffer := make([]byte, bufsiz)
|
||||||
|
conn.Read(buffer)
|
||||||
|
conn.Write(buffer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("normal usage", func(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skip test in short mode")
|
||||||
|
}
|
||||||
|
sess := New()
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = WithSessionByteCounter(ctx, sess)
|
||||||
|
const count = 128
|
||||||
|
if err := dialAndUseConn(ctx, count); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
exp := New()
|
||||||
|
ctx = WithExperimentByteCounter(ctx, exp)
|
||||||
|
if err := dialAndUseConn(ctx, count); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if exp.Received.Load() != count {
|
||||||
|
t.Fatal("experiment should have received 128 bytes")
|
||||||
|
}
|
||||||
|
if sess.Received.Load() != 2*count {
|
||||||
|
t.Fatal("session should have received 256 bytes")
|
||||||
|
}
|
||||||
|
if exp.Sent.Load() != count {
|
||||||
|
t.Fatal("experiment should have sent 128 bytes")
|
||||||
|
}
|
||||||
|
if sess.Sent.Load() != 256 {
|
||||||
|
t.Fatal("session should have sent 256 bytes")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("failure", func(t *testing.T) {
|
||||||
|
dialer := &ContextAwareDialer{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
return nil, io.EOF
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
child := &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := NewContextAwareDialer(child)
|
||||||
|
dialer.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,10 +0,0 @@
|
||||||
package ndt7
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
func defaultCallbackJSON(data []byte) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultCallbackPerformance(elapsed time.Duration, count int64) {
|
|
||||||
}
|
|
|
@ -4,10 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
@ -15,7 +14,6 @@ import (
|
||||||
type dialManager struct {
|
type dialManager struct {
|
||||||
ndt7URL string
|
ndt7URL string
|
||||||
logger model.Logger
|
logger model.Logger
|
||||||
proxyURL *url.URL
|
|
||||||
readBufferSize int
|
readBufferSize int
|
||||||
userAgent string
|
userAgent string
|
||||||
writeBufferSize int
|
writeBufferSize int
|
||||||
|
@ -32,16 +30,9 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
|
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
|
||||||
var reso model.Resolver = &netxlite.ResolverSystem{}
|
reso := netxlite.NewResolverStdlib(mgr.logger)
|
||||||
reso = &netxlite.ResolverLogger{
|
dlr := netxlite.NewDialerWithResolver(mgr.logger, reso)
|
||||||
Resolver: reso,
|
dlr = bytecounter.NewContextAwareDialer(dlr)
|
||||||
Logger: mgr.logger,
|
|
||||||
}
|
|
||||||
dlr := dialer.New(&dialer.Config{
|
|
||||||
ContextByteCounting: true,
|
|
||||||
Logger: mgr.logger,
|
|
||||||
ProxyURL: mgr.proxyURL,
|
|
||||||
}, reso)
|
|
||||||
// Implements shaping if the user builds using `-tags shaping`
|
// Implements shaping if the user builds using `-tags shaping`
|
||||||
// See https://github.com/ooni/probe/issues/2112
|
// See https://github.com/ooni/probe/issues/2112
|
||||||
dlr = netxlite.NewMaybeShapingDialer(dlr)
|
dlr = netxlite.NewMaybeShapingDialer(dlr)
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type downloadManager struct {
|
type downloadManager struct {
|
||||||
conn mockableConn
|
conn wsConn
|
||||||
maxMessageSize int64
|
maxMessageSize int64
|
||||||
maxRuntime time.Duration
|
maxRuntime time.Duration
|
||||||
measureInterval time.Duration
|
measureInterval time.Duration
|
||||||
|
@ -20,7 +20,7 @@ type downloadManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDownloadManager(
|
func newDownloadManager(
|
||||||
conn mockableConn, onPerformance callbackPerformance,
|
conn wsConn, onPerformance callbackPerformance,
|
||||||
onJSON callbackJSON,
|
onJSON callbackJSON,
|
||||||
) downloadManager {
|
) downloadManager {
|
||||||
return downloadManager{
|
return downloadManager{
|
||||||
|
|
|
@ -12,10 +12,17 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func defaultCallbackJSON(data []byte) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultCallbackPerformance(elapsed time.Duration, count int64) {
|
||||||
|
}
|
||||||
|
|
||||||
func TestDownloadSetReadDeadlineFailure(t *testing.T) {
|
func TestDownloadSetReadDeadlineFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newDownloadManager(
|
mgr := newDownloadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
ReadDeadlineErr: expected,
|
ReadDeadlineErr: expected,
|
||||||
},
|
},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
|
@ -30,7 +37,7 @@ func TestDownloadSetReadDeadlineFailure(t *testing.T) {
|
||||||
func TestDownloadNextReaderFailure(t *testing.T) {
|
func TestDownloadNextReaderFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newDownloadManager(
|
mgr := newDownloadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
NextReaderErr: expected,
|
NextReaderErr: expected,
|
||||||
},
|
},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
|
@ -45,7 +52,7 @@ func TestDownloadNextReaderFailure(t *testing.T) {
|
||||||
func TestDownloadTextMessageReadAllFailure(t *testing.T) {
|
func TestDownloadTextMessageReadAllFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newDownloadManager(
|
mgr := newDownloadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
NextReaderMsgType: websocket.TextMessage,
|
NextReaderMsgType: websocket.TextMessage,
|
||||||
NextReaderReader: func() io.Reader {
|
NextReaderReader: func() io.Reader {
|
||||||
return &alwaysFailingReader{
|
return &alwaysFailingReader{
|
||||||
|
@ -73,7 +80,7 @@ func (r *alwaysFailingReader) Read(p []byte) (int, error) {
|
||||||
func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
|
func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newDownloadManager(
|
mgr := newDownloadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
NextReaderMsgType: websocket.BinaryMessage,
|
NextReaderMsgType: websocket.BinaryMessage,
|
||||||
NextReaderReader: func() io.Reader {
|
NextReaderReader: func() io.Reader {
|
||||||
return &alwaysFailingReader{
|
return &alwaysFailingReader{
|
||||||
|
@ -92,7 +99,7 @@ func TestDownloadBinaryMessageReadAllFailure(t *testing.T) {
|
||||||
|
|
||||||
func TestDownloadOnJSONCallbackError(t *testing.T) {
|
func TestDownloadOnJSONCallbackError(t *testing.T) {
|
||||||
mgr := newDownloadManager(
|
mgr := newDownloadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
NextReaderMsgType: websocket.TextMessage,
|
NextReaderMsgType: websocket.TextMessage,
|
||||||
NextReaderReader: func() io.Reader {
|
NextReaderReader: func() io.Reader {
|
||||||
return &invalidJSONReader{}
|
return &invalidJSONReader{}
|
||||||
|
@ -121,7 +128,7 @@ func TestDownloadOnJSONLoop(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
mgr := newDownloadManager(
|
mgr := newDownloadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
NextReaderMsgType: websocket.TextMessage,
|
NextReaderMsgType: websocket.TextMessage,
|
||||||
NextReaderReader: func() io.Reader {
|
NextReaderReader: func() io.Reader {
|
||||||
return &goodJSONReader{}
|
return &goodJSONReader{}
|
||||||
|
|
|
@ -12,7 +12,7 @@ func newMessage(n int) (*websocket.PreparedMessage, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type uploadManager struct {
|
type uploadManager struct {
|
||||||
conn mockableConn
|
conn wsConn
|
||||||
fractionForScaling int64
|
fractionForScaling int64
|
||||||
maxRuntime time.Duration
|
maxRuntime time.Duration
|
||||||
maxMessageSize int
|
maxMessageSize int
|
||||||
|
@ -24,7 +24,7 @@ type uploadManager struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUploadManager(
|
func newUploadManager(
|
||||||
conn mockableConn, onPerformance callbackPerformance,
|
conn wsConn, onPerformance callbackPerformance,
|
||||||
) uploadManager {
|
) uploadManager {
|
||||||
return uploadManager{
|
return uploadManager{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
func TestUploadSetWriteDeadlineFailure(t *testing.T) {
|
func TestUploadSetWriteDeadlineFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newUploadManager(
|
mgr := newUploadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
WriteDeadlineErr: expected,
|
WriteDeadlineErr: expected,
|
||||||
},
|
},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
|
@ -26,7 +26,7 @@ func TestUploadSetWriteDeadlineFailure(t *testing.T) {
|
||||||
func TestUploadNewMessageFailure(t *testing.T) {
|
func TestUploadNewMessageFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newUploadManager(
|
mgr := newUploadManager(
|
||||||
&mockableConnMock{},
|
&mockableWSConn{},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
)
|
)
|
||||||
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
|
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
|
||||||
|
@ -41,7 +41,7 @@ func TestUploadNewMessageFailure(t *testing.T) {
|
||||||
func TestUploadWritePreparedMessageFailure(t *testing.T) {
|
func TestUploadWritePreparedMessageFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newUploadManager(
|
mgr := newUploadManager(
|
||||||
&mockableConnMock{
|
&mockableWSConn{
|
||||||
WritePreparedMessageErr: expected,
|
WritePreparedMessageErr: expected,
|
||||||
},
|
},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
|
@ -55,7 +55,7 @@ func TestUploadWritePreparedMessageFailure(t *testing.T) {
|
||||||
func TestUploadWritePreparedMessageSubsequentFailure(t *testing.T) {
|
func TestUploadWritePreparedMessageSubsequentFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
mgr := newUploadManager(
|
mgr := newUploadManager(
|
||||||
&mockableConnMock{},
|
&mockableWSConn{},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
)
|
)
|
||||||
var already bool
|
var already bool
|
||||||
|
@ -77,7 +77,7 @@ func TestUploadLoop(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
mgr := newUploadManager(
|
mgr := newUploadManager(
|
||||||
&mockableConnMock{},
|
&mockableWSConn{},
|
||||||
defaultCallbackPerformance,
|
defaultCallbackPerformance,
|
||||||
)
|
)
|
||||||
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
|
mgr.newMessage = func(int) (*websocket.PreparedMessage, error) {
|
||||||
|
|
|
@ -7,7 +7,8 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockableConn interface {
|
// weConn is the interface of gorilla/websocket.Conn
|
||||||
|
type wsConn interface {
|
||||||
NextReader() (int, io.Reader, error)
|
NextReader() (int, io.Reader, error)
|
||||||
SetReadDeadline(time.Time) error
|
SetReadDeadline(time.Time) error
|
||||||
SetReadLimit(int64)
|
SetReadLimit(int64)
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockableConnMock struct {
|
type mockableWSConn struct {
|
||||||
NextReaderMsgType int
|
NextReaderMsgType int
|
||||||
NextReaderErr error
|
NextReaderErr error
|
||||||
NextReaderReader func() io.Reader
|
NextReaderReader func() io.Reader
|
||||||
|
@ -16,7 +16,7 @@ type mockableConnMock struct {
|
||||||
WritePreparedMessageErr error
|
WritePreparedMessageErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockableConnMock) NextReader() (int, io.Reader, error) {
|
func (c *mockableWSConn) NextReader() (int, io.Reader, error) {
|
||||||
var reader io.Reader
|
var reader io.Reader
|
||||||
if c.NextReaderReader != nil {
|
if c.NextReaderReader != nil {
|
||||||
reader = c.NextReaderReader()
|
reader = c.NextReaderReader()
|
||||||
|
@ -24,16 +24,16 @@ func (c *mockableConnMock) NextReader() (int, io.Reader, error) {
|
||||||
return c.NextReaderMsgType, reader, c.NextReaderErr
|
return c.NextReaderMsgType, reader, c.NextReaderErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockableConnMock) SetReadDeadline(time.Time) error {
|
func (c *mockableWSConn) SetReadDeadline(time.Time) error {
|
||||||
return c.ReadDeadlineErr
|
return c.ReadDeadlineErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockableConnMock) SetReadLimit(int64) {}
|
func (c *mockableWSConn) SetReadLimit(int64) {}
|
||||||
|
|
||||||
func (c *mockableConnMock) SetWriteDeadline(time.Time) error {
|
func (c *mockableWSConn) SetWriteDeadline(time.Time) error {
|
||||||
return c.WriteDeadlineErr
|
return c.WriteDeadlineErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockableConnMock) WritePreparedMessage(*websocket.PreparedMessage) error {
|
func (c *mockableWSConn) WritePreparedMessage(*websocket.PreparedMessage) error {
|
||||||
return c.WritePreparedMessageErr
|
return c.WritePreparedMessageErr
|
||||||
}
|
}
|
|
@ -1,26 +1,5 @@
|
||||||
package dialer
|
package dialer
|
||||||
|
|
||||||
import (
|
import "github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
type byteCounterDialer = bytecounter.ContextAwareDialer
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// byteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you
|
|
||||||
// should make sure that you insert this dialer in the dialing chain.
|
|
||||||
type byteCounterDialer struct {
|
|
||||||
model.Dialer
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext implements Dialer.DialContext
|
|
||||||
func (d *byteCounterDialer) DialContext(
|
|
||||||
ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn = bytecounter.MaybeWrapWithContextByteCounters(ctx, conn)
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,91 +0,0 @@
|
||||||
package dialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
func dorequest(ctx context.Context, url string) error {
|
|
||||||
txp := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
defer txp.CloseIdleConnections()
|
|
||||||
dialer := &byteCounterDialer{Dialer: netxlite.DefaultDialer}
|
|
||||||
txp.DialContext = dialer.DialContext
|
|
||||||
client := &http.Client{Transport: txp}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := netxlite.CopyContext(ctx, io.Discard, resp.Body); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return resp.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestByteCounterNormalUsage(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skip test in short mode")
|
|
||||||
}
|
|
||||||
sess := bytecounter.New()
|
|
||||||
ctx := context.Background()
|
|
||||||
ctx = bytecounter.WithSessionByteCounter(ctx, sess)
|
|
||||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
exp := bytecounter.New()
|
|
||||||
ctx = bytecounter.WithExperimentByteCounter(ctx, exp)
|
|
||||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if exp.Received.Load() <= 0 {
|
|
||||||
t.Fatal("experiment should have received some bytes")
|
|
||||||
}
|
|
||||||
if sess.Received.Load() <= exp.Received.Load() {
|
|
||||||
t.Fatal("session should have received more than experiment")
|
|
||||||
}
|
|
||||||
if exp.Sent.Load() <= 0 {
|
|
||||||
t.Fatal("experiment should have sent some bytes")
|
|
||||||
}
|
|
||||||
if sess.Sent.Load() <= exp.Sent.Load() {
|
|
||||||
t.Fatal("session should have sent more than experiment")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestByteCounterNoHandlers(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skip test in short mode")
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestByteCounterConnectFailure(t *testing.T) {
|
|
||||||
dialer := &byteCounterDialer{Dialer: &mocks.Dialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
|
||||||
return nil, io.EOF
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
|
|
||||||
if !errors.Is(err, io.EOF) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil conn here")
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user