fix(all): introduce and use iox.CopyContext (#380)
* fix(all): introduce and use iox.CopyContext This PR is part of https://github.com/ooni/probe/issues/1417. In https://github.com/ooni/probe-cli/pull/379 we introduced a context aware wrapper for io.ReadAll (formerly ioutil.ReadAll). Here we introduce a context aware wrapper for io.Copy. * fix(humanize): more significant digits * fix: rename humanize files to follow the common pattern * fix aligment * fix test
This commit is contained in:
parent
0fdc9cafb5
commit
721ce95315
|
@ -63,7 +63,8 @@ run `go mod tidy` to minimize such changes.
|
|||
|
||||
- use `./internal/fsx.OpenFile` when you need to open a file
|
||||
|
||||
- use `./internal/iox.ReadAllContext` instead of `ioutil.ReadAll`
|
||||
- use `./internal/iox.ReadAllContext` instead of `io.ReadAll`
|
||||
and `./internal/iox.CopyContext` instead of `io.Copy`
|
||||
|
||||
## Code testing requirements
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||
)
|
||||
|
||||
func getShasum(path string) (string, error) {
|
||||
|
@ -17,7 +19,7 @@ func getShasum(path string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
if _, err := io.Copy(hasher, f); err != nil {
|
||||
if _, err := iox.CopyContext(context.Background(), hasher, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||
|
|
|
@ -82,7 +82,7 @@ func (mgr downloadManager) doRun(ctx context.Context) error {
|
|||
}
|
||||
continue
|
||||
}
|
||||
n, err := io.Copy(io.Discard, reader)
|
||||
n, err := iox.CopyContext(ctx, io.Discard, reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
|
||||
const (
|
||||
testName = "ndt"
|
||||
testVersion = "0.8.0"
|
||||
testVersion = "0.9.0"
|
||||
)
|
||||
|
||||
// Config contains the experiment settings
|
||||
|
|
|
@ -17,7 +17,7 @@ func TestNewExperimentMeasurer(t *testing.T) {
|
|||
if measurer.ExperimentName() != "ndt" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if measurer.ExperimentVersion() != "0.8.0" {
|
||||
if measurer.ExperimentVersion() != "0.9.0" {
|
||||
t.Fatal("unexpected version")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
|
@ -13,6 +12,7 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
|
@ -92,7 +92,7 @@ func (r Runner) httpGet(ctx context.Context, url string) error {
|
|||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||
if _, err = iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
// Implementation note: we shall check for this error once we have read the
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex"
|
||||
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||
)
|
||||
|
||||
func dorequest(ctx context.Context, url string) error {
|
||||
|
@ -27,7 +28,7 @@ func dorequest(ctx context.Context, url string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Body.Close()
|
||||
|
|
|
@ -3,7 +3,6 @@ package engine
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -17,6 +16,7 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/probeservices"
|
||||
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||
"github.com/ooni/probe-cli/v3/internal/version"
|
||||
)
|
||||
|
||||
|
@ -31,7 +31,8 @@ func TestSessionByteCounter(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||
ctx := context.Background()
|
||||
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 {
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package fsx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/fsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||
)
|
||||
|
||||
func ExampleOpenFile_openingDir() {
|
||||
|
@ -26,7 +27,7 @@ func ExampleOpenFile_openingFile() {
|
|||
if err != nil {
|
||||
log.Fatal("unexpected error", err)
|
||||
}
|
||||
data, err := io.ReadAll(filep)
|
||||
data, err := iox.ReadAllContext(context.Background(), filep)
|
||||
if err != nil {
|
||||
log.Fatal("unexpected error", err)
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import "fmt"
|
|||
// specially tailored for printing download speeds.
|
||||
func SI(value float64, unit string) string {
|
||||
value, prefix := reduce(value)
|
||||
return fmt.Sprintf("%3.0f %s%s", value, prefix, unit)
|
||||
return fmt.Sprintf("%6.2f %s%s", value, prefix, unit)
|
||||
}
|
||||
|
||||
// reduce reduces value to a base value and a unit prefix. For
|
30
internal/humanize/humanize_test.go
Normal file
30
internal/humanize/humanize_test.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package humanize
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
if v := SI(128, "bit/s"); v != "128.00 bit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(1280, "bit/s"); v != " 1.28 kbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(12800, "bit/s"); v != " 12.80 kbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(128000, "bit/s"); v != "128.00 kbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(1280000, "bit/s"); v != " 1.28 Mbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(12800000, "bit/s"); v != " 12.80 Mbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(128000000, "bit/s"); v != "128.00 Mbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
if v := SI(1280000000, "bit/s"); v != " 1.28 Gbit/s" {
|
||||
t.Fatal("unexpected result", v)
|
||||
}
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
package humanize
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
if SI(128, "bit/s") != "128 bit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(1280, "bit/s") != " 1 kbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(12800, "bit/s") != " 13 kbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(128000, "bit/s") != "128 kbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(1280000, "bit/s") != " 1 Mbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(12800000, "bit/s") != " 13 Mbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(128000000, "bit/s") != "128 Mbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if SI(1280000000, "bit/s") != " 1 Gbit/s" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
|
@ -6,7 +6,7 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
// ReadAllContext reads the whole reader r in a
|
||||
// ReadAllContext is like io.ReadAll but reads r in a
|
||||
// background goroutine. This function will return
|
||||
// earlier if the context is cancelled. In which case
|
||||
// we will continue reading from r in the background
|
||||
|
@ -45,3 +45,27 @@ var _ io.Reader = &MockableReader{}
|
|||
func (r *MockableReader) Read(b []byte) (int, error) {
|
||||
return r.MockRead(b)
|
||||
}
|
||||
|
||||
// CopyContext is like io.Copy but may terminate earlier
|
||||
// when the context expires. This function has the same
|
||||
// caveats of ReadAllContext regarding the temporary leaking
|
||||
// of the background goroutine used to do I/O.
|
||||
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
countch, errch := make(chan int64, 1), make(chan error, 1) // buffers
|
||||
go func() {
|
||||
count, err := io.Copy(dst, src)
|
||||
if err != nil {
|
||||
errch <- err
|
||||
return
|
||||
}
|
||||
countch <- count
|
||||
}()
|
||||
select {
|
||||
case count := <-countch:
|
||||
return count, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
case err := <-errch:
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package iox
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -68,3 +69,64 @@ func TestReadAllContextWithErrorAndCancelledContext(t *testing.T) {
|
|||
t.Fatal("not the expected number of bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyContextCommonCase(t *testing.T) {
|
||||
r := strings.NewReader("deadbeef")
|
||||
ctx := context.Background()
|
||||
out, err := CopyContext(ctx, io.Discard, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != 8 {
|
||||
t.Fatal("not the expected number of bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyContextWithError(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &MockableReader{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, expected
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
out, err := CopyContext(ctx, io.Discard, r)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != 0 {
|
||||
t.Fatal("not the expected number of bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyContextWithCancelledContext(t *testing.T) {
|
||||
r := strings.NewReader("deadbeef")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // fail immediately
|
||||
out, err := CopyContext(ctx, io.Discard, r)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != 0 {
|
||||
t.Fatal("not the expected number of bytes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyContextWithErrorAndCancelledContext(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &MockableReader{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
time.Sleep(time.Millisecond)
|
||||
return 0, expected
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // fail immediately
|
||||
out, err := CopyContext(ctx, io.Discard, r)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if out != 0 {
|
||||
t.Fatal("not the expected number of bytes")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,12 +40,12 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
pt "git.torproject.org/pluggable-transports/goptlib.git"
|
||||
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||
)
|
||||
|
||||
// PTDialer is a generic pluggable transports dialer.
|
||||
|
@ -104,17 +104,17 @@ func (lst *Listener) logger() Logger {
|
|||
// forward forwards the traffic from left to right and from right to left
|
||||
// and closes the done channel when it is done. This function DOES NOT
|
||||
// take ownership of the left, right net.Conn arguments.
|
||||
func (lst *Listener) forward(left, right net.Conn, done chan struct{}) {
|
||||
func (lst *Listener) forward(ctx context.Context, left, right net.Conn, done chan struct{}) {
|
||||
defer close(done) // signal termination
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(left, right)
|
||||
iox.CopyContext(ctx, left, right)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(right, left)
|
||||
iox.CopyContext(ctx, right, left)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ func (lst *Listener) forwardWithContext(ctx context.Context, left, right net.Con
|
|||
defer left.Close()
|
||||
defer right.Close()
|
||||
done := make(chan struct{})
|
||||
go lst.forward(left, right, done)
|
||||
go lst.forward(ctx, left, right, done)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
|
|
Loading…
Reference in New Issue
Block a user