From b9ff9136e26aeab9c36ab377fc300ac8d9061248 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Fri, 2 Jul 2021 14:00:46 +0200 Subject: [PATCH] refactor(errorsx): directly return ErrWrapper when needed The utility of SafeErrWrapperBuilder is low. Let us instead change the code to always create ErrWrapper when we're in this package. While there, also note some TODO-next items. Part of https://github.com/ooni/probe/issues/1505. --- internal/errorsx/dialer.go | 66 ++++++----- internal/errorsx/dialer_test.go | 193 ++++++++++++++++++++++++-------- internal/errorsx/errorsx.go | 2 + 3 files changed, 190 insertions(+), 71 deletions(-) diff --git a/internal/errorsx/dialer.go b/internal/errorsx/dialer.go index b439296..4a8c733 100644 --- a/internal/errorsx/dialer.go +++ b/internal/errorsx/dialer.go @@ -11,55 +11,67 @@ type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } -// ErrorWrapperDialer is a dialer that performs error wrapping. +// ErrorWrapperDialer is a dialer that performs error wrapping. The connection +// returned by the DialContext function will also perform error wrapping. type ErrorWrapperDialer struct { + // Dialer is the underlying dialer. Dialer } // DialContext implements Dialer.DialContext. func (d *ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) - err = SafeErrWrapperBuilder{ - Error: err, - Operation: ConnectOperation, - }.MaybeBuild() if err != nil { - return nil, err + return nil, &ErrWrapper{ + Failure: toFailureString(err), + Operation: ConnectOperation, + WrappedErr: err, + } } return &errorWrapperConn{Conn: conn}, nil } // errorWrapperConn is a net.Conn that performs error wrapping. type errorWrapperConn struct { + // Conn is the underlying connection. net.Conn } // Read implements net.Conn.Read. -func (c *errorWrapperConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - err = SafeErrWrapperBuilder{ - Error: err, - Operation: ReadOperation, - }.MaybeBuild() - return +func (c *errorWrapperConn) Read(b []byte) (int, error) { + count, err := c.Conn.Read(b) + if err != nil { + return 0, &ErrWrapper{ + Failure: toFailureString(err), + Operation: ReadOperation, + WrappedErr: err, + } + } + return count, nil } // Write implements net.Conn.Write. -func (c *errorWrapperConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - err = SafeErrWrapperBuilder{ - Error: err, - Operation: WriteOperation, - }.MaybeBuild() - return +func (c *errorWrapperConn) Write(b []byte) (int, error) { + count, err := c.Conn.Write(b) + if err != nil { + return 0, &ErrWrapper{ + Failure: toFailureString(err), + Operation: WriteOperation, + WrappedErr: err, + } + } + return count, nil } // Close implements net.Conn.Close. -func (c *errorWrapperConn) Close() (err error) { - err = c.Conn.Close() - err = SafeErrWrapperBuilder{ - Error: err, - Operation: CloseOperation, - }.MaybeBuild() - return +func (c *errorWrapperConn) Close() error { + err := c.Conn.Close() + if err != nil { + return &ErrWrapper{ + Failure: toFailureString(err), + Operation: CloseOperation, + WrappedErr: err, + } + } + return nil } diff --git a/internal/errorsx/dialer_test.go b/internal/errorsx/dialer_test.go index a870c47..4988e7b 100644 --- a/internal/errorsx/dialer_test.go +++ b/internal/errorsx/dialer_test.go @@ -18,66 +18,171 @@ func TestErrorWrapperDialerFailure(t *testing.T) { }, }} conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot convert to ErrWrapper") + } + if ew.Operation != ConnectOperation { + t.Fatal("unexpected operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("unexpected failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("unexpected underlying error", ew.WrappedErr) + } if conn != nil { t.Fatal("expected a nil conn here") } - errorWrapperCheckErr(t, err, ConnectOperation) -} - -func errorWrapperCheckErr(t *testing.T, err error, op string) { - if !errors.Is(err, io.EOF) { - t.Fatal("expected another error here") - } - var errWrapper *ErrWrapper - if !errors.As(err, &errWrapper) { - t.Fatal("cannot cast to ErrWrapper") - } - if errWrapper.Operation != op { - t.Fatal("unexpected Operation") - } - if errWrapper.Failure != FailureEOFError { - t.Fatal("unexpected failure") - } } func TestErrorWrapperDialerSuccess(t *testing.T) { + origConn := &net.TCPConn{} ctx := context.Background() d := &ErrorWrapperDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return &netxmocks.Conn{ - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockClose: func() error { - return io.EOF - }, - MockLocalAddr: func() net.Addr { - return &net.TCPAddr{Port: 12345} - }, - }, nil + return origConn, nil }, }} - conn, err := d.DialContext(ctx, "tcp", "www.google.com") + conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") if err != nil { t.Fatal(err) } - if conn == nil { - t.Fatal("expected non-nil conn here") + ewConn, ok := conn.(*errorWrapperConn) + if !ok { + t.Fatal("cannot cast to errorWrapperConn") + } + if ewConn.Conn != origConn { + t.Fatal("not the connection we expected") } - count, err := conn.Read(nil) - errorWrapperCheckIOResult(t, count, err, ReadOperation) - count, err = conn.Write(nil) - errorWrapperCheckIOResult(t, count, err, WriteOperation) - err = conn.Close() - errorWrapperCheckErr(t, err, CloseOperation) } -func errorWrapperCheckIOResult(t *testing.T, count int, err error, op string) { - if count != 0 { - t.Fatal("expected nil count here") +func TestErrorWrapperConnReadFailure(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Read(buf) + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot cast error to ErrWrapper") + } + if ew.Operation != ReadOperation { + t.Fatal("invalid operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("invalid failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error", ew.WrappedErr) + } + if cnt != 0 { + t.Fatal("expected zero here", cnt) + } +} + +func TestErrorWrapperConnReadSuccess(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockRead: func(b []byte) (int, error) { + return len(b), nil + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Read(buf) + if err != nil { + t.Fatal(err) + } + if cnt != len(buf) { + t.Fatal("expected len(buf) here", cnt) + } +} + +func TestErrorWrapperConnWriteFailure(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Write(buf) + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot cast error to ErrWrapper") + } + if ew.Operation != WriteOperation { + t.Fatal("invalid operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("invalid failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error", ew.WrappedErr) + } + if cnt != 0 { + t.Fatal("expected zero here", cnt) + } +} + +func TestErrorWrapperConnWriteSuccess(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Write(buf) + if err != nil { + t.Fatal(err) + } + if cnt != len(buf) { + t.Fatal("expected len(buf) here", cnt) + } +} + +func TestErrorWrapperConnCloseFailure(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockClose: func() error { + return io.EOF + }, + }, + } + err := c.Close() + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot cast error to ErrWrapper") + } + if ew.Operation != CloseOperation { + t.Fatal("invalid operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("invalid failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error", ew.WrappedErr) + } +} + +func TestErrorWrapperConnCloseSuccess(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockClose: func() error { + return nil + }, + }, + } + err := c.Close() + if err != nil { + t.Fatal(err) } - errorWrapperCheckErr(t, err, op) } diff --git a/internal/errorsx/errorsx.go b/internal/errorsx/errorsx.go index 4a4b78f..d8dd3ad 100644 --- a/internal/errorsx/errorsx.go +++ b/internal/errorsx/errorsx.go @@ -105,6 +105,8 @@ func toFailureString(err error) string { // The list returned here matches the values used by MK unless // explicitly noted otherwise with a comment. + // TODO(bassosimone): we need to always apply this rule not only here + // when we're making the most generic conversion. var errwrapper *ErrWrapper if errors.As(err, &errwrapper) { return errwrapper.Error() // we've already wrapped it