fix(all): introduce and use iox.CopyContext (#380)

* fix(all): introduce and use iox.CopyContext

This PR is part of https://github.com/ooni/probe/issues/1417.

In https://github.com/ooni/probe-cli/pull/379 we introduced a context
aware wrapper for io.ReadAll (formerly ioutil.ReadAll).

Here we introduce a context aware wrapper for io.Copy.

* fix(humanize): more significant digits

* fix: rename humanize files to follow the common pattern

* fix aligment

* fix test
This commit is contained in:
Simone Basso 2021-06-15 13:44:28 +02:00 committed by GitHub
parent 0fdc9cafb5
commit 721ce95315
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 143 additions and 51 deletions

View File

@ -63,7 +63,8 @@ run `go mod tidy` to minimize such changes.
- use `./internal/fsx.OpenFile` when you need to open a file - use `./internal/fsx.OpenFile` when you need to open a file
- use `./internal/iox.ReadAllContext` instead of `ioutil.ReadAll` - use `./internal/iox.ReadAllContext` instead of `io.ReadAll`
and `./internal/iox.CopyContext` instead of `io.Copy`
## Code testing requirements ## Code testing requirements

View File

@ -1,12 +1,14 @@
package config package config
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"io"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/iox"
) )
func getShasum(path string) (string, error) { func getShasum(path string) (string, error) {
@ -17,7 +19,7 @@ func getShasum(path string) (string, error) {
return "", err return "", err
} }
defer f.Close() defer f.Close()
if _, err := io.Copy(hasher, f); err != nil { if _, err := iox.CopyContext(context.Background(), hasher, f); err != nil {
return "", err return "", err
} }
return hex.EncodeToString(hasher.Sum(nil)), nil return hex.EncodeToString(hasher.Sum(nil)), nil

View File

@ -82,7 +82,7 @@ func (mgr downloadManager) doRun(ctx context.Context) error {
} }
continue continue
} }
n, err := io.Copy(io.Discard, reader) n, err := iox.CopyContext(ctx, io.Discard, reader)
if err != nil { if err != nil {
return err return err
} }

View File

@ -19,7 +19,7 @@ import (
const ( const (
testName = "ndt" testName = "ndt"
testVersion = "0.8.0" testVersion = "0.9.0"
) )
// Config contains the experiment settings // Config contains the experiment settings

View File

@ -17,7 +17,7 @@ func TestNewExperimentMeasurer(t *testing.T) {
if measurer.ExperimentName() != "ndt" { if measurer.ExperimentName() != "ndt" {
t.Fatal("unexpected name") t.Fatal("unexpected name")
} }
if measurer.ExperimentVersion() != "0.8.0" { if measurer.ExperimentVersion() != "0.9.0" {
t.Fatal("unexpected version") t.Fatal("unexpected version")
} }
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
@ -13,6 +12,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/httpheader" "github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/iox"
"github.com/ooni/probe-cli/v3/internal/runtimex" "github.com/ooni/probe-cli/v3/internal/runtimex"
) )
@ -92,7 +92,7 @@ func (r Runner) httpGet(ctx context.Context, url string) error {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil { if _, err = iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
return err return err
} }
// Implementation note: we shall check for this error once we have read the // Implementation note: we shall check for this error once we have read the

View File

@ -11,6 +11,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex"
"github.com/ooni/probe-cli/v3/internal/iox"
) )
func dorequest(ctx context.Context, url string) error { func dorequest(ctx context.Context, url string) error {
@ -27,7 +28,7 @@ func dorequest(ctx context.Context, url string) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil { if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
return err return err
} }
return resp.Body.Close() return resp.Body.Close()

View File

@ -3,7 +3,6 @@ package engine
import ( import (
"context" "context"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -17,6 +16,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/geolocate" "github.com/ooni/probe-cli/v3/internal/engine/geolocate"
"github.com/ooni/probe-cli/v3/internal/engine/model" "github.com/ooni/probe-cli/v3/internal/engine/model"
"github.com/ooni/probe-cli/v3/internal/engine/probeservices" "github.com/ooni/probe-cli/v3/internal/engine/probeservices"
"github.com/ooni/probe-cli/v3/internal/iox"
"github.com/ooni/probe-cli/v3/internal/version" "github.com/ooni/probe-cli/v3/internal/version"
) )
@ -31,7 +31,8 @@ func TestSessionByteCounter(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil { ctx := context.Background()
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 { if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 {

View File

@ -1,14 +1,15 @@
package fsx_test package fsx_test
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"path/filepath" "path/filepath"
"syscall" "syscall"
"github.com/ooni/probe-cli/v3/internal/fsx" "github.com/ooni/probe-cli/v3/internal/fsx"
"github.com/ooni/probe-cli/v3/internal/iox"
) )
func ExampleOpenFile_openingDir() { func ExampleOpenFile_openingDir() {
@ -26,7 +27,7 @@ func ExampleOpenFile_openingFile() {
if err != nil { if err != nil {
log.Fatal("unexpected error", err) log.Fatal("unexpected error", err)
} }
data, err := io.ReadAll(filep) data, err := iox.ReadAllContext(context.Background(), filep)
if err != nil { if err != nil {
log.Fatal("unexpected error", err) log.Fatal("unexpected error", err)
} }

View File

@ -7,7 +7,7 @@ import "fmt"
// specially tailored for printing download speeds. // specially tailored for printing download speeds.
func SI(value float64, unit string) string { func SI(value float64, unit string) string {
value, prefix := reduce(value) value, prefix := reduce(value)
return fmt.Sprintf("%3.0f %s%s", value, prefix, unit) return fmt.Sprintf("%6.2f %s%s", value, prefix, unit)
} }
// reduce reduces value to a base value and a unit prefix. For // reduce reduces value to a base value and a unit prefix. For

View File

@ -0,0 +1,30 @@
package humanize
import "testing"
func TestGood(t *testing.T) {
if v := SI(128, "bit/s"); v != "128.00 bit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(1280, "bit/s"); v != " 1.28 kbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(12800, "bit/s"); v != " 12.80 kbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(128000, "bit/s"); v != "128.00 kbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(1280000, "bit/s"); v != " 1.28 Mbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(12800000, "bit/s"); v != " 12.80 Mbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(128000000, "bit/s"); v != "128.00 Mbit/s" {
t.Fatal("unexpected result", v)
}
if v := SI(1280000000, "bit/s"); v != " 1.28 Gbit/s" {
t.Fatal("unexpected result", v)
}
}

View File

@ -1,30 +0,0 @@
package humanize
import "testing"
func TestGood(t *testing.T) {
if SI(128, "bit/s") != "128 bit/s" {
t.Fatal("unexpected result")
}
if SI(1280, "bit/s") != " 1 kbit/s" {
t.Fatal("unexpected result")
}
if SI(12800, "bit/s") != " 13 kbit/s" {
t.Fatal("unexpected result")
}
if SI(128000, "bit/s") != "128 kbit/s" {
t.Fatal("unexpected result")
}
if SI(1280000, "bit/s") != " 1 Mbit/s" {
t.Fatal("unexpected result")
}
if SI(12800000, "bit/s") != " 13 Mbit/s" {
t.Fatal("unexpected result")
}
if SI(128000000, "bit/s") != "128 Mbit/s" {
t.Fatal("unexpected result")
}
if SI(1280000000, "bit/s") != " 1 Gbit/s" {
t.Fatal("unexpected result")
}
}

View File

@ -6,7 +6,7 @@ import (
"io" "io"
) )
// ReadAllContext reads the whole reader r in a // ReadAllContext is like io.ReadAll but reads r in a
// background goroutine. This function will return // background goroutine. This function will return
// earlier if the context is cancelled. In which case // earlier if the context is cancelled. In which case
// we will continue reading from r in the background // we will continue reading from r in the background
@ -45,3 +45,27 @@ var _ io.Reader = &MockableReader{}
func (r *MockableReader) Read(b []byte) (int, error) { func (r *MockableReader) Read(b []byte) (int, error) {
return r.MockRead(b) return r.MockRead(b)
} }
// CopyContext is like io.Copy but may terminate earlier
// when the context expires. This function has the same
// caveats of ReadAllContext regarding the temporary leaking
// of the background goroutine used to do I/O.
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
countch, errch := make(chan int64, 1), make(chan error, 1) // buffers
go func() {
count, err := io.Copy(dst, src)
if err != nil {
errch <- err
return
}
countch <- count
}()
select {
case count := <-countch:
return count, nil
case <-ctx.Done():
return 0, ctx.Err()
case err := <-errch:
return 0, err
}
}

View File

@ -3,6 +3,7 @@ package iox
import ( import (
"context" "context"
"errors" "errors"
"io"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -68,3 +69,64 @@ func TestReadAllContextWithErrorAndCancelledContext(t *testing.T) {
t.Fatal("not the expected number of bytes") t.Fatal("not the expected number of bytes")
} }
} }
func TestCopyContextCommonCase(t *testing.T) {
r := strings.NewReader("deadbeef")
ctx := context.Background()
out, err := CopyContext(ctx, io.Discard, r)
if err != nil {
t.Fatal(err)
}
if out != 8 {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextWithError(t *testing.T) {
expected := errors.New("mocked error")
r := &MockableReader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}
ctx := context.Background()
out, err := CopyContext(ctx, io.Discard, r)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if out != 0 {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextWithCancelledContext(t *testing.T) {
r := strings.NewReader("deadbeef")
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
out, err := CopyContext(ctx, io.Discard, r)
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected", err)
}
if out != 0 {
t.Fatal("not the expected number of bytes")
}
}
func TestCopyContextWithErrorAndCancelledContext(t *testing.T) {
expected := errors.New("mocked error")
r := &MockableReader{
MockRead: func(b []byte) (int, error) {
time.Sleep(time.Millisecond)
return 0, expected
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
out, err := CopyContext(ctx, io.Discard, r)
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected", err)
}
if out != 0 {
t.Fatal("not the expected number of bytes")
}
}

View File

@ -40,12 +40,12 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"strings" "strings"
"sync" "sync"
pt "git.torproject.org/pluggable-transports/goptlib.git" pt "git.torproject.org/pluggable-transports/goptlib.git"
"github.com/ooni/probe-cli/v3/internal/iox"
) )
// PTDialer is a generic pluggable transports dialer. // PTDialer is a generic pluggable transports dialer.
@ -104,17 +104,17 @@ func (lst *Listener) logger() Logger {
// forward forwards the traffic from left to right and from right to left // forward forwards the traffic from left to right and from right to left
// and closes the done channel when it is done. This function DOES NOT // and closes the done channel when it is done. This function DOES NOT
// take ownership of the left, right net.Conn arguments. // take ownership of the left, right net.Conn arguments.
func (lst *Listener) forward(left, right net.Conn, done chan struct{}) { func (lst *Listener) forward(ctx context.Context, left, right net.Conn, done chan struct{}) {
defer close(done) // signal termination defer close(done) // signal termination
wg := new(sync.WaitGroup) wg := new(sync.WaitGroup)
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
io.Copy(left, right) iox.CopyContext(ctx, left, right)
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
io.Copy(right, left) iox.CopyContext(ctx, right, left)
}() }()
wg.Wait() wg.Wait()
} }
@ -127,7 +127,7 @@ func (lst *Listener) forwardWithContext(ctx context.Context, left, right net.Con
defer left.Close() defer left.Close()
defer right.Close() defer right.Close()
done := make(chan struct{}) done := make(chan struct{})
go lst.forward(left, right, done) go lst.forward(ctx, left, right, done)
select { select {
case <-ctx.Done(): case <-ctx.Done():
case <-done: case <-done:
@ -150,7 +150,7 @@ func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn
return err // used for testing return err // used for testing
} }
lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership
return nil // used for testing return nil // used for testing
} }
// ptxSocksListener is a pt.SocksListener-like structure. // ptxSocksListener is a pt.SocksListener-like structure.