diff --git a/internal/errorsx/dialer.go b/internal/errorsx/dialer.go index b439296..4a8c733 100644 --- a/internal/errorsx/dialer.go +++ b/internal/errorsx/dialer.go @@ -11,55 +11,67 @@ type Dialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } -// ErrorWrapperDialer is a dialer that performs error wrapping. +// ErrorWrapperDialer is a dialer that performs error wrapping. The connection +// returned by the DialContext function will also perform error wrapping. type ErrorWrapperDialer struct { + // Dialer is the underlying dialer. Dialer } // DialContext implements Dialer.DialContext. func (d *ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) - err = SafeErrWrapperBuilder{ - Error: err, - Operation: ConnectOperation, - }.MaybeBuild() if err != nil { - return nil, err + return nil, &ErrWrapper{ + Failure: toFailureString(err), + Operation: ConnectOperation, + WrappedErr: err, + } } return &errorWrapperConn{Conn: conn}, nil } // errorWrapperConn is a net.Conn that performs error wrapping. type errorWrapperConn struct { + // Conn is the underlying connection. net.Conn } // Read implements net.Conn.Read. -func (c *errorWrapperConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - err = SafeErrWrapperBuilder{ - Error: err, - Operation: ReadOperation, - }.MaybeBuild() - return +func (c *errorWrapperConn) Read(b []byte) (int, error) { + count, err := c.Conn.Read(b) + if err != nil { + return 0, &ErrWrapper{ + Failure: toFailureString(err), + Operation: ReadOperation, + WrappedErr: err, + } + } + return count, nil } // Write implements net.Conn.Write. -func (c *errorWrapperConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - err = SafeErrWrapperBuilder{ - Error: err, - Operation: WriteOperation, - }.MaybeBuild() - return +func (c *errorWrapperConn) Write(b []byte) (int, error) { + count, err := c.Conn.Write(b) + if err != nil { + return 0, &ErrWrapper{ + Failure: toFailureString(err), + Operation: WriteOperation, + WrappedErr: err, + } + } + return count, nil } // Close implements net.Conn.Close. -func (c *errorWrapperConn) Close() (err error) { - err = c.Conn.Close() - err = SafeErrWrapperBuilder{ - Error: err, - Operation: CloseOperation, - }.MaybeBuild() - return +func (c *errorWrapperConn) Close() error { + err := c.Conn.Close() + if err != nil { + return &ErrWrapper{ + Failure: toFailureString(err), + Operation: CloseOperation, + WrappedErr: err, + } + } + return nil } diff --git a/internal/errorsx/dialer_test.go b/internal/errorsx/dialer_test.go index a870c47..4988e7b 100644 --- a/internal/errorsx/dialer_test.go +++ b/internal/errorsx/dialer_test.go @@ -18,66 +18,171 @@ func TestErrorWrapperDialerFailure(t *testing.T) { }, }} conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot convert to ErrWrapper") + } + if ew.Operation != ConnectOperation { + t.Fatal("unexpected operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("unexpected failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("unexpected underlying error", ew.WrappedErr) + } if conn != nil { t.Fatal("expected a nil conn here") } - errorWrapperCheckErr(t, err, ConnectOperation) -} - -func errorWrapperCheckErr(t *testing.T, err error, op string) { - if !errors.Is(err, io.EOF) { - t.Fatal("expected another error here") - } - var errWrapper *ErrWrapper - if !errors.As(err, &errWrapper) { - t.Fatal("cannot cast to ErrWrapper") - } - if errWrapper.Operation != op { - t.Fatal("unexpected Operation") - } - if errWrapper.Failure != FailureEOFError { - t.Fatal("unexpected failure") - } } func TestErrorWrapperDialerSuccess(t *testing.T) { + origConn := &net.TCPConn{} ctx := context.Background() d := &ErrorWrapperDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return &netxmocks.Conn{ - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockClose: func() error { - return io.EOF - }, - MockLocalAddr: func() net.Addr { - return &net.TCPAddr{Port: 12345} - }, - }, nil + return origConn, nil }, }} - conn, err := d.DialContext(ctx, "tcp", "www.google.com") + conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") if err != nil { t.Fatal(err) } - if conn == nil { - t.Fatal("expected non-nil conn here") + ewConn, ok := conn.(*errorWrapperConn) + if !ok { + t.Fatal("cannot cast to errorWrapperConn") + } + if ewConn.Conn != origConn { + t.Fatal("not the connection we expected") } - count, err := conn.Read(nil) - errorWrapperCheckIOResult(t, count, err, ReadOperation) - count, err = conn.Write(nil) - errorWrapperCheckIOResult(t, count, err, WriteOperation) - err = conn.Close() - errorWrapperCheckErr(t, err, CloseOperation) } -func errorWrapperCheckIOResult(t *testing.T, count int, err error, op string) { - if count != 0 { - t.Fatal("expected nil count here") +func TestErrorWrapperConnReadFailure(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Read(buf) + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot cast error to ErrWrapper") + } + if ew.Operation != ReadOperation { + t.Fatal("invalid operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("invalid failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error", ew.WrappedErr) + } + if cnt != 0 { + t.Fatal("expected zero here", cnt) + } +} + +func TestErrorWrapperConnReadSuccess(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockRead: func(b []byte) (int, error) { + return len(b), nil + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Read(buf) + if err != nil { + t.Fatal(err) + } + if cnt != len(buf) { + t.Fatal("expected len(buf) here", cnt) + } +} + +func TestErrorWrapperConnWriteFailure(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Write(buf) + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot cast error to ErrWrapper") + } + if ew.Operation != WriteOperation { + t.Fatal("invalid operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("invalid failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error", ew.WrappedErr) + } + if cnt != 0 { + t.Fatal("expected zero here", cnt) + } +} + +func TestErrorWrapperConnWriteSuccess(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + }, + } + buf := make([]byte, 1024) + cnt, err := c.Write(buf) + if err != nil { + t.Fatal(err) + } + if cnt != len(buf) { + t.Fatal("expected len(buf) here", cnt) + } +} + +func TestErrorWrapperConnCloseFailure(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockClose: func() error { + return io.EOF + }, + }, + } + err := c.Close() + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("cannot cast error to ErrWrapper") + } + if ew.Operation != CloseOperation { + t.Fatal("invalid operation", ew.Operation) + } + if ew.Failure != FailureEOFError { + t.Fatal("invalid failure", ew.Failure) + } + if !errors.Is(ew.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error", ew.WrappedErr) + } +} + +func TestErrorWrapperConnCloseSuccess(t *testing.T) { + c := &errorWrapperConn{ + Conn: &netxmocks.Conn{ + MockClose: func() error { + return nil + }, + }, + } + err := c.Close() + if err != nil { + t.Fatal(err) } - errorWrapperCheckErr(t, err, op) } diff --git a/internal/errorsx/errorsx.go b/internal/errorsx/errorsx.go index 4a4b78f..d8dd3ad 100644 --- a/internal/errorsx/errorsx.go +++ b/internal/errorsx/errorsx.go @@ -105,6 +105,8 @@ func toFailureString(err error) string { // The list returned here matches the values used by MK unless // explicitly noted otherwise with a comment. + // TODO(bassosimone): we need to always apply this rule not only here + // when we're making the most generic conversion. var errwrapper *ErrWrapper if errors.As(err, &errwrapper) { return errwrapper.Error() // we've already wrapped it