diff --git a/internal/bytecounter/conn.go b/internal/bytecounter/conn.go new file mode 100644 index 0000000..2201686 --- /dev/null +++ b/internal/bytecounter/conn.go @@ -0,0 +1,26 @@ +package bytecounter + +import "net" + +// Conn wraps a network connection and counts bytes. +type Conn struct { + // net.Conn is the underlying net.Conn. + net.Conn + + // Counter is the byte counter. + Counter *Counter +} + +// Read implements net.Conn.Read. +func (c *Conn) Read(p []byte) (int, error) { + count, err := c.Conn.Read(p) + c.Counter.CountBytesReceived(count) + return count, err +} + +// Write implements net.Conn.Write. +func (c *Conn) Write(p []byte) (int, error) { + count, err := c.Conn.Write(p) + c.Counter.CountBytesSent(count) + return count, err +} diff --git a/internal/bytecounter/conn_test.go b/internal/bytecounter/conn_test.go new file mode 100644 index 0000000..4b93050 --- /dev/null +++ b/internal/bytecounter/conn_test.go @@ -0,0 +1,66 @@ +package bytecounter + +import ( + "errors" + "testing" + + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" +) + +func TestConnWorksOnSuccess(t *testing.T) { + counter := New() + underlying := &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 10, nil + }, + MockWrite: func(b []byte) (int, error) { + return 4, nil + }, + } + conn := &Conn{ + Conn: underlying, + Counter: counter, + } + if _, err := conn.Read(make([]byte, 128)); err != nil { + t.Fatal(err) + } + if _, err := conn.Write(make([]byte, 1024)); err != nil { + t.Fatal(err) + } + if counter.BytesReceived() != 10 { + t.Fatal("unexpected number of bytes received") + } + if counter.BytesSent() != 4 { + t.Fatal("unexpected number of bytes sent") + } +} + +func TestConnWorksOnFailure(t *testing.T) { + readError := errors.New("read error") + writeError := errors.New("write error") + counter := New() + underlying := &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, readError + }, + MockWrite: func(b []byte) (int, error) { + return 0, writeError + }, + } + conn := &Conn{ + Conn: underlying, + Counter: counter, + } + if _, err := conn.Read(make([]byte, 128)); !errors.Is(err, readError) { + t.Fatal("not the error we expected", err) + } + if _, err := conn.Write(make([]byte, 1024)); !errors.Is(err, writeError) { + t.Fatal("not the error we expected", err) + } + if counter.BytesReceived() != 0 { + t.Fatal("unexpected number of bytes received") + } + if counter.BytesSent() != 0 { + t.Fatal("unexpected number of bytes sent") + } +} diff --git a/internal/engine/netx/dialer/bytecounter.go b/internal/engine/netx/dialer/bytecounter.go index 1a8ffca..41476f1 100644 --- a/internal/engine/netx/dialer/bytecounter.go +++ b/internal/engine/netx/dialer/bytecounter.go @@ -20,12 +20,13 @@ func (d *byteCounterDialer) DialContext( if err != nil { return nil, err } - exp := contextExperimentByteCounter(ctx) - sess := contextSessionByteCounter(ctx) - if exp == nil && sess == nil { - return conn, nil // no point in wrapping + if exp := contextExperimentByteCounter(ctx); exp != nil { + conn = &bytecounter.Conn{Conn: conn, Counter: exp} } - return &byteCounterConnWrapper{Conn: conn, exp: exp, sess: sess}, nil + if sess := contextSessionByteCounter(ctx); sess != nil { + conn = &bytecounter.Conn{Conn: conn, Counter: sess} + } + return conn, nil } type byteCounterSessionKey struct{} @@ -53,31 +54,3 @@ func contextExperimentByteCounter(ctx context.Context) *bytecounter.Counter { func WithExperimentByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context { return context.WithValue(ctx, byteCounterExperimentKey{}, counter) } - -type byteCounterConnWrapper struct { - net.Conn - exp *bytecounter.Counter - sess *bytecounter.Counter -} - -func (c *byteCounterConnWrapper) Read(p []byte) (int, error) { - count, err := c.Conn.Read(p) - if c.exp != nil { - c.exp.CountBytesReceived(count) - } - if c.sess != nil { - c.sess.CountBytesReceived(count) - } - return count, err -} - -func (c *byteCounterConnWrapper) Write(p []byte) (int, error) { - count, err := c.Conn.Write(p) - if c.exp != nil { - c.exp.CountBytesSent(count) - } - if c.sess != nil { - c.sess.CountBytesSent(count) - } - return count, err -} diff --git a/internal/engine/netx/dialer/bytecounter_test.go b/internal/engine/netx/dialer/bytecounter_test.go index 948d001..c9a64ec 100644 --- a/internal/engine/netx/dialer/bytecounter_test.go +++ b/internal/engine/netx/dialer/bytecounter_test.go @@ -48,9 +48,15 @@ func TestByteCounterNormalUsage(t *testing.T) { if err := dorequest(ctx, "http://facebook.com"); err != nil { t.Fatal(err) } + if exp.Received.Load() <= 0 { + t.Fatal("experiment should have received some bytes") + } if sess.Received.Load() <= exp.Received.Load() { t.Fatal("session should have received more than experiment") } + if exp.Sent.Load() <= 0 { + t.Fatal("experiment should have sent some bytes") + } if sess.Sent.Load() <= exp.Sent.Load() { t.Fatal("session should have sent more than experiment") }