refactor(errorsx): directly return ErrWrapper when needed

The utility of SafeErrWrapperBuilder is low. Let us instead change the
code to always create ErrWrapper when we're in this package.

While there, also note some TODO-next items.

Part of https://github.com/ooni/probe/issues/1505.
This commit is contained in:
Simone Basso 2021-07-02 14:00:46 +02:00
parent 362ece04c4
commit b9ff9136e2
3 changed files with 190 additions and 71 deletions

View File

@ -11,55 +11,67 @@ type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
// ErrorWrapperDialer is a dialer that performs error wrapping. // ErrorWrapperDialer is a dialer that performs error wrapping. The connection
// returned by the DialContext function will also perform error wrapping.
type ErrorWrapperDialer struct { type ErrorWrapperDialer struct {
// Dialer is the underlying dialer.
Dialer Dialer
} }
// DialContext implements Dialer.DialContext. // DialContext implements Dialer.DialContext.
func (d *ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
err = SafeErrWrapperBuilder{
Error: err,
Operation: ConnectOperation,
}.MaybeBuild()
if err != nil { if err != nil {
return nil, err return nil, &ErrWrapper{
Failure: toFailureString(err),
Operation: ConnectOperation,
WrappedErr: err,
}
} }
return &errorWrapperConn{Conn: conn}, nil return &errorWrapperConn{Conn: conn}, nil
} }
// errorWrapperConn is a net.Conn that performs error wrapping. // errorWrapperConn is a net.Conn that performs error wrapping.
type errorWrapperConn struct { type errorWrapperConn struct {
// Conn is the underlying connection.
net.Conn net.Conn
} }
// Read implements net.Conn.Read. // Read implements net.Conn.Read.
func (c *errorWrapperConn) Read(b []byte) (n int, err error) { func (c *errorWrapperConn) Read(b []byte) (int, error) {
n, err = c.Conn.Read(b) count, err := c.Conn.Read(b)
err = SafeErrWrapperBuilder{ if err != nil {
Error: err, return 0, &ErrWrapper{
Operation: ReadOperation, Failure: toFailureString(err),
}.MaybeBuild() Operation: ReadOperation,
return WrappedErr: err,
}
}
return count, nil
} }
// Write implements net.Conn.Write. // Write implements net.Conn.Write.
func (c *errorWrapperConn) Write(b []byte) (n int, err error) { func (c *errorWrapperConn) Write(b []byte) (int, error) {
n, err = c.Conn.Write(b) count, err := c.Conn.Write(b)
err = SafeErrWrapperBuilder{ if err != nil {
Error: err, return 0, &ErrWrapper{
Operation: WriteOperation, Failure: toFailureString(err),
}.MaybeBuild() Operation: WriteOperation,
return WrappedErr: err,
}
}
return count, nil
} }
// Close implements net.Conn.Close. // Close implements net.Conn.Close.
func (c *errorWrapperConn) Close() (err error) { func (c *errorWrapperConn) Close() error {
err = c.Conn.Close() err := c.Conn.Close()
err = SafeErrWrapperBuilder{ if err != nil {
Error: err, return &ErrWrapper{
Operation: CloseOperation, Failure: toFailureString(err),
}.MaybeBuild() Operation: CloseOperation,
return WrappedErr: err,
}
}
return nil
} }

View File

@ -18,66 +18,171 @@ func TestErrorWrapperDialerFailure(t *testing.T) {
}, },
}} }}
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
var ew *ErrWrapper
if !errors.As(err, &ew) {
t.Fatal("cannot convert to ErrWrapper")
}
if ew.Operation != ConnectOperation {
t.Fatal("unexpected operation", ew.Operation)
}
if ew.Failure != FailureEOFError {
t.Fatal("unexpected failure", ew.Failure)
}
if !errors.Is(ew.WrappedErr, io.EOF) {
t.Fatal("unexpected underlying error", ew.WrappedErr)
}
if conn != nil { if conn != nil {
t.Fatal("expected a nil conn here") t.Fatal("expected a nil conn here")
} }
errorWrapperCheckErr(t, err, ConnectOperation)
}
func errorWrapperCheckErr(t *testing.T, err error, op string) {
if !errors.Is(err, io.EOF) {
t.Fatal("expected another error here")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot cast to ErrWrapper")
}
if errWrapper.Operation != op {
t.Fatal("unexpected Operation")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("unexpected failure")
}
} }
func TestErrorWrapperDialerSuccess(t *testing.T) { func TestErrorWrapperDialerSuccess(t *testing.T) {
origConn := &net.TCPConn{}
ctx := context.Background() ctx := context.Background()
d := &ErrorWrapperDialer{Dialer: &netxmocks.Dialer{ d := &ErrorWrapperDialer{Dialer: &netxmocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return &netxmocks.Conn{ return origConn, nil
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockClose: func() error {
return io.EOF
},
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{Port: 12345}
},
}, nil
}, },
}} }}
conn, err := d.DialContext(ctx, "tcp", "www.google.com") conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if conn == nil { ewConn, ok := conn.(*errorWrapperConn)
t.Fatal("expected non-nil conn here") if !ok {
t.Fatal("cannot cast to errorWrapperConn")
}
if ewConn.Conn != origConn {
t.Fatal("not the connection we expected")
} }
count, err := conn.Read(nil)
errorWrapperCheckIOResult(t, count, err, ReadOperation)
count, err = conn.Write(nil)
errorWrapperCheckIOResult(t, count, err, WriteOperation)
err = conn.Close()
errorWrapperCheckErr(t, err, CloseOperation)
} }
func errorWrapperCheckIOResult(t *testing.T, count int, err error, op string) { func TestErrorWrapperConnReadFailure(t *testing.T) {
if count != 0 { c := &errorWrapperConn{
t.Fatal("expected nil count here") Conn: &netxmocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
},
}
buf := make([]byte, 1024)
cnt, err := c.Read(buf)
var ew *ErrWrapper
if !errors.As(err, &ew) {
t.Fatal("cannot cast error to ErrWrapper")
}
if ew.Operation != ReadOperation {
t.Fatal("invalid operation", ew.Operation)
}
if ew.Failure != FailureEOFError {
t.Fatal("invalid failure", ew.Failure)
}
if !errors.Is(ew.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error", ew.WrappedErr)
}
if cnt != 0 {
t.Fatal("expected zero here", cnt)
}
}
func TestErrorWrapperConnReadSuccess(t *testing.T) {
c := &errorWrapperConn{
Conn: &netxmocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
},
}
buf := make([]byte, 1024)
cnt, err := c.Read(buf)
if err != nil {
t.Fatal(err)
}
if cnt != len(buf) {
t.Fatal("expected len(buf) here", cnt)
}
}
func TestErrorWrapperConnWriteFailure(t *testing.T) {
c := &errorWrapperConn{
Conn: &netxmocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
},
}
buf := make([]byte, 1024)
cnt, err := c.Write(buf)
var ew *ErrWrapper
if !errors.As(err, &ew) {
t.Fatal("cannot cast error to ErrWrapper")
}
if ew.Operation != WriteOperation {
t.Fatal("invalid operation", ew.Operation)
}
if ew.Failure != FailureEOFError {
t.Fatal("invalid failure", ew.Failure)
}
if !errors.Is(ew.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error", ew.WrappedErr)
}
if cnt != 0 {
t.Fatal("expected zero here", cnt)
}
}
func TestErrorWrapperConnWriteSuccess(t *testing.T) {
c := &errorWrapperConn{
Conn: &netxmocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
},
}
buf := make([]byte, 1024)
cnt, err := c.Write(buf)
if err != nil {
t.Fatal(err)
}
if cnt != len(buf) {
t.Fatal("expected len(buf) here", cnt)
}
}
func TestErrorWrapperConnCloseFailure(t *testing.T) {
c := &errorWrapperConn{
Conn: &netxmocks.Conn{
MockClose: func() error {
return io.EOF
},
},
}
err := c.Close()
var ew *ErrWrapper
if !errors.As(err, &ew) {
t.Fatal("cannot cast error to ErrWrapper")
}
if ew.Operation != CloseOperation {
t.Fatal("invalid operation", ew.Operation)
}
if ew.Failure != FailureEOFError {
t.Fatal("invalid failure", ew.Failure)
}
if !errors.Is(ew.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error", ew.WrappedErr)
}
}
func TestErrorWrapperConnCloseSuccess(t *testing.T) {
c := &errorWrapperConn{
Conn: &netxmocks.Conn{
MockClose: func() error {
return nil
},
},
}
err := c.Close()
if err != nil {
t.Fatal(err)
} }
errorWrapperCheckErr(t, err, op)
} }

View File

@ -105,6 +105,8 @@ func toFailureString(err error) string {
// The list returned here matches the values used by MK unless // The list returned here matches the values used by MK unless
// explicitly noted otherwise with a comment. // explicitly noted otherwise with a comment.
// TODO(bassosimone): we need to always apply this rule not only here
// when we're making the most generic conversion.
var errwrapper *ErrWrapper var errwrapper *ErrWrapper
if errors.As(err, &errwrapper) { if errors.As(err, &errwrapper) {
return errwrapper.Error() // we've already wrapped it return errwrapper.Error() // we've already wrapped it