fix(all): introduce and use iox.CopyContext (#380)

* fix(all): introduce and use iox.CopyContext

This PR is part of https://github.com/ooni/probe/issues/1417.

In https://github.com/ooni/probe-cli/pull/379 we introduced a context
aware wrapper for io.ReadAll (formerly ioutil.ReadAll).

Here we introduce a context aware wrapper for io.Copy.

* fix(humanize): more significant digits

* fix: rename humanize files to follow the common pattern

* fix aligment

* fix test
This commit is contained in:
Simone Basso
2021-06-15 13:44:28 +02:00
committed by GitHub
parent 0fdc9cafb5
commit 721ce95315
15 changed files with 143 additions and 51 deletions
+1 -1
View File
@@ -82,7 +82,7 @@ func (mgr downloadManager) doRun(ctx context.Context) error {
}
continue
}
n, err := io.Copy(io.Discard, reader)
n, err := iox.CopyContext(ctx, io.Discard, reader)
if err != nil {
return err
}
+1 -1
View File
@@ -19,7 +19,7 @@ import (
const (
testName = "ndt"
testVersion = "0.8.0"
testVersion = "0.9.0"
)
// Config contains the experiment settings
+1 -1
View File
@@ -17,7 +17,7 @@ func TestNewExperimentMeasurer(t *testing.T) {
if measurer.ExperimentName() != "ndt" {
t.Fatal("unexpected name")
}
if measurer.ExperimentVersion() != "0.8.0" {
if measurer.ExperimentVersion() != "0.9.0" {
t.Fatal("unexpected version")
}
}
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/cookiejar"
@@ -13,6 +12,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/iox"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
@@ -92,7 +92,7 @@ func (r Runner) httpGet(ctx context.Context, url string) error {
return err
}
defer resp.Body.Close()
if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
if _, err = iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
return err
}
// Implementation note: we shall check for this error once we have read the
@@ -11,6 +11,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex"
"github.com/ooni/probe-cli/v3/internal/iox"
)
func dorequest(ctx context.Context, url string) error {
@@ -27,7 +28,7 @@ func dorequest(ctx context.Context, url string) error {
if err != nil {
return err
}
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
return err
}
return resp.Body.Close()
+3 -2
View File
@@ -3,7 +3,6 @@ package engine
import (
"context"
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
@@ -17,6 +16,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
"github.com/ooni/probe-cli/v3/internal/engine/model"
"github.com/ooni/probe-cli/v3/internal/engine/probeservices"
"github.com/ooni/probe-cli/v3/internal/iox"
"github.com/ooni/probe-cli/v3/internal/version"
)
@@ -31,7 +31,8 @@ func TestSessionByteCounter(t *testing.T) {
t.Fatal(err)
}
defer resp.Body.Close()
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
ctx := context.Background()
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
t.Fatal(err)
}
if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 {
+3 -2
View File
@@ -1,14 +1,15 @@
package fsx_test
import (
"context"
"errors"
"fmt"
"io"
"log"
"path/filepath"
"syscall"
"github.com/ooni/probe-cli/v3/internal/fsx"
"github.com/ooni/probe-cli/v3/internal/iox"
)
func ExampleOpenFile_openingDir() {
@@ -26,7 +27,7 @@ func ExampleOpenFile_openingFile() {
if err != nil {
log.Fatal("unexpected error", err)
}
data, err := io.ReadAll(filep)
data, err := iox.ReadAllContext(context.Background(), filep)
if err != nil {
log.Fatal("unexpected error", err)
}
@@ -7,7 +7,7 @@ import "fmt"
// specially tailored for printing download speeds.
func SI(value float64, unit string) string {
value, prefix := reduce(value)
return fmt.Sprintf("%3.0f %s%s", value, prefix, unit)
return fmt.Sprintf("%6.2f %s%s", value, prefix, unit)
}
// reduce reduces value to a base value and a unit prefix. For
+30
View File
@@ -0,0 +1,30 @@
package humanize
import "testing"
func TestGood(t *testing.T) {
if v := SI(128, "bit/s"); v != "128.00 bit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(1280, "bit/s"); v != " 1.28 kbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(12800, "bit/s"); v != " 12.80 kbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(128000, "bit/s"); v != "128.00 kbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(1280000, "bit/s"); v != " 1.28 Mbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(12800000, "bit/s"); v != " 12.80 Mbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(128000000, "bit/s"); v != "128.00 Mbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(1280000000, "bit/s"); v != " 1.28 Gbit/s" {
t.Fatal("unexpected result", v)
}
}
-30
View File
@@ -1,30 +0,0 @@
package humanize
import "testing"
func TestGood(t *testing.T) {
if SI(128, "bit/s") != "128 bit/s" {
t.Fatal("unexpected result")
}
if SI(1280, "bit/s") != " 1 kbit/s" {
t.Fatal("unexpected result")
}
if SI(12800, "bit/s") != " 13 kbit/s" {
t.Fatal("unexpected result")
}
if SI(128000, "bit/s") != "128 kbit/s" {
t.Fatal("unexpected result")
}
if SI(1280000, "bit/s") != " 1 Mbit/s" {
t.Fatal("unexpected result")
}
if SI(12800000, "bit/s") != " 13 Mbit/s" {
t.Fatal("unexpected result")
}
if SI(128000000, "bit/s") != "128 Mbit/s" {
t.Fatal("unexpected result")
}
if SI(1280000000, "bit/s") != " 1 Gbit/s" {
t.Fatal("unexpected result")
}
}
+25 -1
View File
@@ -6,7 +6,7 @@ import (
"io"
)
// ReadAllContext reads the whole reader r in a
// ReadAllContext is like io.ReadAll but reads r in a
// background goroutine. This function will return
// earlier if the context is cancelled. In which case
// we will continue reading from r in the background
@@ -45,3 +45,27 @@ var _ io.Reader = &MockableReader{}
func (r *MockableReader) Read(b []byte) (int, error) {
return r.MockRead(b)
}
// CopyContext is like io.Copy but may terminate earlier
// when the context expires. This function has the same
// caveats of ReadAllContext regarding the temporary leaking
// of the background goroutine used to do I/O.
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
countch, errch := make(chan int64, 1), make(chan error, 1) // buffers
go func() {
count, err := io.Copy(dst, src)
if err != nil {
errch <- err
return
}
countch <- count
}()
select {
case count := <-countch:
return count, nil
case <-ctx.Done():
return 0, ctx.Err()
case err := <-errch:
return 0, err
}
}
+62
View File
@@ -3,6 +3,7 @@ package iox
import (
"context"
"errors"
"io"
"strings"
"testing"
"time"
@@ -68,3 +69,64 @@ func TestReadAllContextWithErrorAndCancelledContext(t *testing.T) {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextCommonCase(t *testing.T) {
r := strings.NewReader("deadbeef")
ctx := context.Background()
out, err := CopyContext(ctx, io.Discard, r)
if err != nil {
t.Fatal(err)
}
if out != 8 {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextWithError(t *testing.T) {
expected := errors.New("mocked error")
r := &MockableReader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}
ctx := context.Background()
out, err := CopyContext(ctx, io.Discard, r)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if out != 0 {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextWithCancelledContext(t *testing.T) {
r := strings.NewReader("deadbeef")
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
out, err := CopyContext(ctx, io.Discard, r)
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected", err)
}
if out != 0 {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextWithErrorAndCancelledContext(t *testing.T) {
expected := errors.New("mocked error")
r := &MockableReader{
MockRead: func(b []byte) (int, error) {
time.Sleep(time.Millisecond)
return 0, expected
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
out, err := CopyContext(ctx, io.Discard, r)
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected", err)
}
if out != 0 {
t.Fatal("not the expected number of bytes")
}
}
+6 -6
View File
@@ -40,12 +40,12 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import (
"context"
"fmt"
"io"
"net"
"strings"
"sync"
pt "git.torproject.org/pluggable-transports/goptlib.git"
"github.com/ooni/probe-cli/v3/internal/iox"
)
// PTDialer is a generic pluggable transports dialer.
@@ -104,17 +104,17 @@ func (lst *Listener) logger() Logger {
// forward forwards the traffic from left to right and from right to left
// and closes the done channel when it is done. This function DOES NOT
// take ownership of the left, right net.Conn arguments.
func (lst *Listener) forward(left, right net.Conn, done chan struct{}) {
func (lst *Listener) forward(ctx context.Context, left, right net.Conn, done chan struct{}) {
defer close(done) // signal termination
wg := new(sync.WaitGroup)
wg.Add(2)
go func() {
defer wg.Done()
io.Copy(left, right)
iox.CopyContext(ctx, left, right)
}()
go func() {
defer wg.Done()
io.Copy(right, left)
iox.CopyContext(ctx, right, left)
}()
wg.Wait()
}
@@ -127,7 +127,7 @@ func (lst *Listener) forwardWithContext(ctx context.Context, left, right net.Con
defer left.Close()
defer right.Close()
done := make(chan struct{})
go lst.forward(left, right, done)
go lst.forward(ctx, left, right, done)
select {
case <-ctx.Done():
case <-done:
@@ -150,7 +150,7 @@ func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn
return err // used for testing
}
lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership
return nil // used for testing
return nil // used for testing
}
// ptxSocksListener is a pt.SocksListener-like structure.