diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index effd47f..7a26a4e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,7 +63,8 @@ run `go mod tidy` to minimize such changes. - use `./internal/fsx.OpenFile` when you need to open a file -- use `./internal/iox.ReadAllContext` instead of `ioutil.ReadAll` +- use `./internal/iox.ReadAllContext` instead of `io.ReadAll` +and `./internal/iox.CopyContext` instead of `io.Copy` ## Code testing requirements diff --git a/cmd/ooniprobe/internal/config/parser_test.go b/cmd/ooniprobe/internal/config/parser_test.go index 255cc9c..a468dbf 100644 --- a/cmd/ooniprobe/internal/config/parser_test.go +++ b/cmd/ooniprobe/internal/config/parser_test.go @@ -1,12 +1,14 @@ package config import ( + "context" "crypto/sha256" "encoding/hex" - "io" "io/ioutil" "os" "testing" + + "github.com/ooni/probe-cli/v3/internal/iox" ) func getShasum(path string) (string, error) { @@ -17,7 +19,7 @@ func getShasum(path string) (string, error) { return "", err } defer f.Close() - if _, err := io.Copy(hasher, f); err != nil { + if _, err := iox.CopyContext(context.Background(), hasher, f); err != nil { return "", err } return hex.EncodeToString(hasher.Sum(nil)), nil diff --git a/internal/engine/experiment/ndt7/download.go b/internal/engine/experiment/ndt7/download.go index 447cde2..1789aa2 100644 --- a/internal/engine/experiment/ndt7/download.go +++ b/internal/engine/experiment/ndt7/download.go @@ -82,7 +82,7 @@ func (mgr downloadManager) doRun(ctx context.Context) error { } continue } - n, err := io.Copy(io.Discard, reader) + n, err := iox.CopyContext(ctx, io.Discard, reader) if err != nil { return err } diff --git a/internal/engine/experiment/ndt7/ndt7.go b/internal/engine/experiment/ndt7/ndt7.go index d5b7f12..7d0070a 100644 --- a/internal/engine/experiment/ndt7/ndt7.go +++ b/internal/engine/experiment/ndt7/ndt7.go @@ -19,7 +19,7 @@ import ( const ( testName = "ndt" - testVersion = "0.8.0" + testVersion = "0.9.0" ) // Config contains the experiment settings diff --git a/internal/engine/experiment/ndt7/ndt7_test.go b/internal/engine/experiment/ndt7/ndt7_test.go index 1e38b69..68e0bb3 100644 --- a/internal/engine/experiment/ndt7/ndt7_test.go +++ b/internal/engine/experiment/ndt7/ndt7_test.go @@ -17,7 +17,7 @@ func TestNewExperimentMeasurer(t *testing.T) { if measurer.ExperimentName() != "ndt" { t.Fatal("unexpected name") } - if measurer.ExperimentVersion() != "0.8.0" { + if measurer.ExperimentVersion() != "0.9.0" { t.Fatal("unexpected version") } } diff --git a/internal/engine/experiment/urlgetter/runner.go b/internal/engine/experiment/urlgetter/runner.go index 3d7522d..fd86945 100644 --- a/internal/engine/experiment/urlgetter/runner.go +++ b/internal/engine/experiment/urlgetter/runner.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "io/ioutil" "net/http" "net/http/cookiejar" @@ -13,6 +12,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/httpheader" "github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" + "github.com/ooni/probe-cli/v3/internal/iox" "github.com/ooni/probe-cli/v3/internal/runtimex" ) @@ -92,7 +92,7 @@ func (r Runner) httpGet(ctx context.Context, url string) error { return err } defer resp.Body.Close() - if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil { + if _, err = iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil { return err } // Implementation note: we shall check for this error once we have read the diff --git a/internal/engine/netx/dialer/bytecounter_test.go b/internal/engine/netx/dialer/bytecounter_test.go index 2e9ff65..2f03804 100644 --- a/internal/engine/netx/dialer/bytecounter_test.go +++ b/internal/engine/netx/dialer/bytecounter_test.go @@ -11,6 +11,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" + "github.com/ooni/probe-cli/v3/internal/iox" ) func dorequest(ctx context.Context, url string) error { @@ -27,7 +28,7 @@ func dorequest(ctx context.Context, url string) error { if err != nil { return err } - if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil { + if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil { return err } return resp.Body.Close() diff --git a/internal/engine/session_integration_test.go b/internal/engine/session_integration_test.go index 87674cd..fc8e051 100644 --- a/internal/engine/session_integration_test.go +++ b/internal/engine/session_integration_test.go @@ -3,7 +3,6 @@ package engine import ( "context" "errors" - "io" "io/ioutil" "net/http" "net/http/httptest" @@ -17,6 +16,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/geolocate" "github.com/ooni/probe-cli/v3/internal/engine/model" "github.com/ooni/probe-cli/v3/internal/engine/probeservices" + "github.com/ooni/probe-cli/v3/internal/iox" "github.com/ooni/probe-cli/v3/internal/version" ) @@ -31,7 +31,8 @@ func TestSessionByteCounter(t *testing.T) { t.Fatal(err) } defer resp.Body.Close() - if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil { + ctx := context.Background() + if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil { t.Fatal(err) } if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 { diff --git a/internal/fsx/example_test.go b/internal/fsx/example_test.go index 8457858..853947c 100644 --- a/internal/fsx/example_test.go +++ b/internal/fsx/example_test.go @@ -1,14 +1,15 @@ package fsx_test import ( + "context" "errors" "fmt" - "io" "log" "path/filepath" "syscall" "github.com/ooni/probe-cli/v3/internal/fsx" + "github.com/ooni/probe-cli/v3/internal/iox" ) func ExampleOpenFile_openingDir() { @@ -26,7 +27,7 @@ func ExampleOpenFile_openingFile() { if err != nil { log.Fatal("unexpected error", err) } - data, err := io.ReadAll(filep) + data, err := iox.ReadAllContext(context.Background(), filep) if err != nil { log.Fatal("unexpected error", err) } diff --git a/internal/humanize/humanizex.go b/internal/humanize/humanize.go similarity index 91% rename from internal/humanize/humanizex.go rename to internal/humanize/humanize.go index 5a5934b..0f31f7e 100644 --- a/internal/humanize/humanizex.go +++ b/internal/humanize/humanize.go @@ -7,7 +7,7 @@ import "fmt" // specially tailored for printing download speeds. func SI(value float64, unit string) string { value, prefix := reduce(value) - return fmt.Sprintf("%3.0f %s%s", value, prefix, unit) + return fmt.Sprintf("%6.2f %s%s", value, prefix, unit) } // reduce reduces value to a base value and a unit prefix. For diff --git a/internal/humanize/humanize_test.go b/internal/humanize/humanize_test.go new file mode 100644 index 0000000..20cbb74 --- /dev/null +++ b/internal/humanize/humanize_test.go @@ -0,0 +1,30 @@ +package humanize + +import "testing" + +func TestGood(t *testing.T) { + if v := SI(128, "bit/s"); v != "128.00 bit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(1280, "bit/s"); v != " 1.28 kbit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(12800, "bit/s"); v != " 12.80 kbit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(128000, "bit/s"); v != "128.00 kbit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(1280000, "bit/s"); v != " 1.28 Mbit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(12800000, "bit/s"); v != " 12.80 Mbit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(128000000, "bit/s"); v != "128.00 Mbit/s" { + t.Fatal("unexpected result", v) + } + if v := SI(1280000000, "bit/s"); v != " 1.28 Gbit/s" { + t.Fatal("unexpected result", v) + } +} diff --git a/internal/humanize/humanizex_test.go b/internal/humanize/humanizex_test.go deleted file mode 100644 index c7286ac..0000000 --- a/internal/humanize/humanizex_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package humanize - -import "testing" - -func TestGood(t *testing.T) { - if SI(128, "bit/s") != "128 bit/s" { - t.Fatal("unexpected result") - } - if SI(1280, "bit/s") != " 1 kbit/s" { - t.Fatal("unexpected result") - } - if SI(12800, "bit/s") != " 13 kbit/s" { - t.Fatal("unexpected result") - } - if SI(128000, "bit/s") != "128 kbit/s" { - t.Fatal("unexpected result") - } - if SI(1280000, "bit/s") != " 1 Mbit/s" { - t.Fatal("unexpected result") - } - if SI(12800000, "bit/s") != " 13 Mbit/s" { - t.Fatal("unexpected result") - } - if SI(128000000, "bit/s") != "128 Mbit/s" { - t.Fatal("unexpected result") - } - if SI(1280000000, "bit/s") != " 1 Gbit/s" { - t.Fatal("unexpected result") - } -} diff --git a/internal/iox/iox.go b/internal/iox/iox.go index f083d53..bf16cd5 100644 --- a/internal/iox/iox.go +++ b/internal/iox/iox.go @@ -6,7 +6,7 @@ import ( "io" ) -// ReadAllContext reads the whole reader r in a +// ReadAllContext is like io.ReadAll but reads r in a // background goroutine. This function will return // earlier if the context is cancelled. In which case // we will continue reading from r in the background @@ -45,3 +45,27 @@ var _ io.Reader = &MockableReader{} func (r *MockableReader) Read(b []byte) (int, error) { return r.MockRead(b) } + +// CopyContext is like io.Copy but may terminate earlier +// when the context expires. This function has the same +// caveats of ReadAllContext regarding the temporary leaking +// of the background goroutine used to do I/O. +func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { + countch, errch := make(chan int64, 1), make(chan error, 1) // buffers + go func() { + count, err := io.Copy(dst, src) + if err != nil { + errch <- err + return + } + countch <- count + }() + select { + case count := <-countch: + return count, nil + case <-ctx.Done(): + return 0, ctx.Err() + case err := <-errch: + return 0, err + } +} diff --git a/internal/iox/iox_test.go b/internal/iox/iox_test.go index ccd2d0b..0bb31df 100644 --- a/internal/iox/iox_test.go +++ b/internal/iox/iox_test.go @@ -3,6 +3,7 @@ package iox import ( "context" "errors" + "io" "strings" "testing" "time" @@ -68,3 +69,64 @@ func TestReadAllContextWithErrorAndCancelledContext(t *testing.T) { t.Fatal("not the expected number of bytes") } } + +func TestCopyContextCommonCase(t *testing.T) { + r := strings.NewReader("deadbeef") + ctx := context.Background() + out, err := CopyContext(ctx, io.Discard, r) + if err != nil { + t.Fatal(err) + } + if out != 8 { + t.Fatal("not the expected number of bytes") + } +} + +func TestCopyContextWithError(t *testing.T) { + expected := errors.New("mocked error") + r := &MockableReader{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + } + ctx := context.Background() + out, err := CopyContext(ctx, io.Discard, r) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if out != 0 { + t.Fatal("not the expected number of bytes") + } +} + +func TestCopyContextWithCancelledContext(t *testing.T) { + r := strings.NewReader("deadbeef") + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + out, err := CopyContext(ctx, io.Discard, r) + if !errors.Is(err, context.Canceled) { + t.Fatal("not the error we expected", err) + } + if out != 0 { + t.Fatal("not the expected number of bytes") + } +} + +func TestCopyContextWithErrorAndCancelledContext(t *testing.T) { + expected := errors.New("mocked error") + r := &MockableReader{ + MockRead: func(b []byte) (int, error) { + time.Sleep(time.Millisecond) + return 0, expected + }, + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + out, err := CopyContext(ctx, io.Discard, r) + if !errors.Is(err, context.Canceled) { + t.Fatal("not the error we expected", err) + } + if out != 0 { + t.Fatal("not the expected number of bytes") + } +} diff --git a/internal/ptx/ptx.go b/internal/ptx/ptx.go index 760ffd1..0094141 100644 --- a/internal/ptx/ptx.go +++ b/internal/ptx/ptx.go @@ -40,12 +40,12 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import ( "context" "fmt" - "io" "net" "strings" "sync" pt "git.torproject.org/pluggable-transports/goptlib.git" + "github.com/ooni/probe-cli/v3/internal/iox" ) // PTDialer is a generic pluggable transports dialer. @@ -104,17 +104,17 @@ func (lst *Listener) logger() Logger { // forward forwards the traffic from left to right and from right to left // and closes the done channel when it is done. This function DOES NOT // take ownership of the left, right net.Conn arguments. -func (lst *Listener) forward(left, right net.Conn, done chan struct{}) { +func (lst *Listener) forward(ctx context.Context, left, right net.Conn, done chan struct{}) { defer close(done) // signal termination wg := new(sync.WaitGroup) wg.Add(2) go func() { defer wg.Done() - io.Copy(left, right) + iox.CopyContext(ctx, left, right) }() go func() { defer wg.Done() - io.Copy(right, left) + iox.CopyContext(ctx, right, left) }() wg.Wait() } @@ -127,7 +127,7 @@ func (lst *Listener) forwardWithContext(ctx context.Context, left, right net.Con defer left.Close() defer right.Close() done := make(chan struct{}) - go lst.forward(left, right, done) + go lst.forward(ctx, left, right, done) select { case <-ctx.Done(): case <-done: @@ -150,7 +150,7 @@ func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn return err // used for testing } lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership - return nil // used for testing + return nil // used for testing } // ptxSocksListener is a pt.SocksListener-like structure.