fix(all): introduce and use iox.CopyContext (#380)
* fix(all): introduce and use iox.CopyContext This PR is part of https://github.com/ooni/probe/issues/1417. In https://github.com/ooni/probe-cli/pull/379 we introduced a context aware wrapper for io.ReadAll (formerly ioutil.ReadAll). Here we introduce a context aware wrapper for io.Copy. * fix(humanize): more significant digits * fix: rename humanize files to follow the common pattern * fix aligment * fix test
This commit is contained in:
parent
0fdc9cafb5
commit
721ce95315
|
@ -63,7 +63,8 @@ run `go mod tidy` to minimize such changes.
|
||||||
|
|
||||||
- use `./internal/fsx.OpenFile` when you need to open a file
|
- use `./internal/fsx.OpenFile` when you need to open a file
|
||||||
|
|
||||||
- use `./internal/iox.ReadAllContext` instead of `ioutil.ReadAll`
|
- use `./internal/iox.ReadAllContext` instead of `io.ReadAll`
|
||||||
|
and `./internal/iox.CopyContext` instead of `io.Copy`
|
||||||
|
|
||||||
## Code testing requirements
|
## Code testing requirements
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getShasum(path string) (string, error) {
|
func getShasum(path string) (string, error) {
|
||||||
|
@ -17,7 +19,7 @@ func getShasum(path string) (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
if _, err := io.Copy(hasher, f); err != nil {
|
if _, err := iox.CopyContext(context.Background(), hasher, f); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return hex.EncodeToString(hasher.Sum(nil)), nil
|
return hex.EncodeToString(hasher.Sum(nil)), nil
|
||||||
|
|
|
@ -82,7 +82,7 @@ func (mgr downloadManager) doRun(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
n, err := io.Copy(io.Discard, reader)
|
n, err := iox.CopyContext(ctx, io.Discard, reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testName = "ndt"
|
testName = "ndt"
|
||||||
testVersion = "0.8.0"
|
testVersion = "0.9.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config contains the experiment settings
|
// Config contains the experiment settings
|
||||||
|
|
|
@ -17,7 +17,7 @@ func TestNewExperimentMeasurer(t *testing.T) {
|
||||||
if measurer.ExperimentName() != "ndt" {
|
if measurer.ExperimentName() != "ndt" {
|
||||||
t.Fatal("unexpected name")
|
t.Fatal("unexpected name")
|
||||||
}
|
}
|
||||||
if measurer.ExperimentVersion() != "0.8.0" {
|
if measurer.ExperimentVersion() != "0.9.0" {
|
||||||
t.Fatal("unexpected version")
|
t.Fatal("unexpected version")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
|
@ -13,6 +12,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ func (r Runner) httpGet(ctx context.Context, url string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if _, err = io.Copy(ioutil.Discard, resp.Body); err != nil {
|
if _, err = iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Implementation note: we shall check for this error once we have read the
|
// Implementation note: we shall check for this error once we have read the
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||||
)
|
)
|
||||||
|
|
||||||
func dorequest(ctx context.Context, url string) error {
|
func dorequest(ctx context.Context, url string) error {
|
||||||
|
@ -27,7 +28,7 @@ func dorequest(ctx context.Context, url string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return resp.Body.Close()
|
return resp.Body.Close()
|
||||||
|
|
|
@ -3,7 +3,6 @@ package engine
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -17,6 +16,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
|
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/model"
|
"github.com/ooni/probe-cli/v3/internal/engine/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/probeservices"
|
"github.com/ooni/probe-cli/v3/internal/engine/probeservices"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||||
"github.com/ooni/probe-cli/v3/internal/version"
|
"github.com/ooni/probe-cli/v3/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,7 +31,8 @@ func TestSessionByteCounter(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
ctx := context.Background()
|
||||||
|
if _, err := iox.CopyContext(ctx, ioutil.Discard, resp.Body); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 {
|
if s.KibiBytesSent() <= 0 || s.KibiBytesReceived() <= 0 {
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
package fsx_test
|
package fsx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/fsx"
|
"github.com/ooni/probe-cli/v3/internal/fsx"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ExampleOpenFile_openingDir() {
|
func ExampleOpenFile_openingDir() {
|
||||||
|
@ -26,7 +27,7 @@ func ExampleOpenFile_openingFile() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("unexpected error", err)
|
log.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(filep)
|
data, err := iox.ReadAllContext(context.Background(), filep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("unexpected error", err)
|
log.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import "fmt"
|
||||||
// specially tailored for printing download speeds.
|
// specially tailored for printing download speeds.
|
||||||
func SI(value float64, unit string) string {
|
func SI(value float64, unit string) string {
|
||||||
value, prefix := reduce(value)
|
value, prefix := reduce(value)
|
||||||
return fmt.Sprintf("%3.0f %s%s", value, prefix, unit)
|
return fmt.Sprintf("%6.2f %s%s", value, prefix, unit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// reduce reduces value to a base value and a unit prefix. For
|
// reduce reduces value to a base value and a unit prefix. For
|
30
internal/humanize/humanize_test.go
Normal file
30
internal/humanize/humanize_test.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package humanize
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGood(t *testing.T) {
|
||||||
|
if v := SI(128, "bit/s"); v != "128.00 bit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(1280, "bit/s"); v != " 1.28 kbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(12800, "bit/s"); v != " 12.80 kbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(128000, "bit/s"); v != "128.00 kbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(1280000, "bit/s"); v != " 1.28 Mbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(12800000, "bit/s"); v != " 12.80 Mbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(128000000, "bit/s"); v != "128.00 Mbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
if v := SI(1280000000, "bit/s"); v != " 1.28 Gbit/s" {
|
||||||
|
t.Fatal("unexpected result", v)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,30 +0,0 @@
|
||||||
package humanize
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestGood(t *testing.T) {
|
|
||||||
if SI(128, "bit/s") != "128 bit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(1280, "bit/s") != " 1 kbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(12800, "bit/s") != " 13 kbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(128000, "bit/s") != "128 kbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(1280000, "bit/s") != " 1 Mbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(12800000, "bit/s") != " 13 Mbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(128000000, "bit/s") != "128 Mbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
if SI(1280000000, "bit/s") != " 1 Gbit/s" {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadAllContext reads the whole reader r in a
|
// ReadAllContext is like io.ReadAll but reads r in a
|
||||||
// background goroutine. This function will return
|
// background goroutine. This function will return
|
||||||
// earlier if the context is cancelled. In which case
|
// earlier if the context is cancelled. In which case
|
||||||
// we will continue reading from r in the background
|
// we will continue reading from r in the background
|
||||||
|
@ -45,3 +45,27 @@ var _ io.Reader = &MockableReader{}
|
||||||
func (r *MockableReader) Read(b []byte) (int, error) {
|
func (r *MockableReader) Read(b []byte) (int, error) {
|
||||||
return r.MockRead(b)
|
return r.MockRead(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CopyContext is like io.Copy but may terminate earlier
|
||||||
|
// when the context expires. This function has the same
|
||||||
|
// caveats of ReadAllContext regarding the temporary leaking
|
||||||
|
// of the background goroutine used to do I/O.
|
||||||
|
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||||
|
countch, errch := make(chan int64, 1), make(chan error, 1) // buffers
|
||||||
|
go func() {
|
||||||
|
count, err := io.Copy(dst, src)
|
||||||
|
if err != nil {
|
||||||
|
errch <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
countch <- count
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case count := <-countch:
|
||||||
|
return count, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
case err := <-errch:
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package iox
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -68,3 +69,64 @@ func TestReadAllContextWithErrorAndCancelledContext(t *testing.T) {
|
||||||
t.Fatal("not the expected number of bytes")
|
t.Fatal("not the expected number of bytes")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCopyContextCommonCase(t *testing.T) {
|
||||||
|
r := strings.NewReader("deadbeef")
|
||||||
|
ctx := context.Background()
|
||||||
|
out, err := CopyContext(ctx, io.Discard, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if out != 8 {
|
||||||
|
t.Fatal("not the expected number of bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyContextWithError(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &MockableReader{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
out, err := CopyContext(ctx, io.Discard, r)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if out != 0 {
|
||||||
|
t.Fatal("not the expected number of bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyContextWithCancelledContext(t *testing.T) {
|
||||||
|
r := strings.NewReader("deadbeef")
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // fail immediately
|
||||||
|
out, err := CopyContext(ctx, io.Discard, r)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if out != 0 {
|
||||||
|
t.Fatal("not the expected number of bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyContextWithErrorAndCancelledContext(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &MockableReader{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
return 0, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // fail immediately
|
||||||
|
out, err := CopyContext(ctx, io.Discard, r)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if out != 0 {
|
||||||
|
t.Fatal("not the expected number of bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -40,12 +40,12 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
pt "git.torproject.org/pluggable-transports/goptlib.git"
|
pt "git.torproject.org/pluggable-transports/goptlib.git"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/iox"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PTDialer is a generic pluggable transports dialer.
|
// PTDialer is a generic pluggable transports dialer.
|
||||||
|
@ -104,17 +104,17 @@ func (lst *Listener) logger() Logger {
|
||||||
// forward forwards the traffic from left to right and from right to left
|
// forward forwards the traffic from left to right and from right to left
|
||||||
// and closes the done channel when it is done. This function DOES NOT
|
// and closes the done channel when it is done. This function DOES NOT
|
||||||
// take ownership of the left, right net.Conn arguments.
|
// take ownership of the left, right net.Conn arguments.
|
||||||
func (lst *Listener) forward(left, right net.Conn, done chan struct{}) {
|
func (lst *Listener) forward(ctx context.Context, left, right net.Conn, done chan struct{}) {
|
||||||
defer close(done) // signal termination
|
defer close(done) // signal termination
|
||||||
wg := new(sync.WaitGroup)
|
wg := new(sync.WaitGroup)
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
io.Copy(left, right)
|
iox.CopyContext(ctx, left, right)
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
io.Copy(right, left)
|
iox.CopyContext(ctx, right, left)
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ func (lst *Listener) forwardWithContext(ctx context.Context, left, right net.Con
|
||||||
defer left.Close()
|
defer left.Close()
|
||||||
defer right.Close()
|
defer right.Close()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go lst.forward(left, right, done)
|
go lst.forward(ctx, left, right, done)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
case <-done:
|
case <-done:
|
||||||
|
@ -150,7 +150,7 @@ func (lst *Listener) handleSocksConn(ctx context.Context, socksConn ptxSocksConn
|
||||||
return err // used for testing
|
return err // used for testing
|
||||||
}
|
}
|
||||||
lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership
|
lst.forwardWithContext(ctx, socksConn, ptConn) // transfer ownership
|
||||||
return nil // used for testing
|
return nil // used for testing
|
||||||
}
|
}
|
||||||
|
|
||||||
// ptxSocksListener is a pt.SocksListener-like structure.
|
// ptxSocksListener is a pt.SocksListener-like structure.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user