diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index dd734ab..cba86d7 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -6,6 +6,8 @@ import ( "net" "sync" "time" + + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) // Dialer establishes network connections. @@ -22,7 +24,9 @@ func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer { return &dialerLogger{ Dialer: &dialerResolver{ Dialer: &dialerLogger{ - Dialer: &dialerSystem{}, + Dialer: &dialerErrWrapper{ + Dialer: &dialerSystem{}, + }, Logger: logger, operationSuffix: "_address", }, @@ -188,3 +192,75 @@ func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr func (s *dialerSingleUse) CloseIdleConnections() { // nothing } + +// TODO(bassosimone): introduce factory for creating errors and +// write tests that ensure the factory works correctly. + +// dialerErrWrapper is a dialer that performs error wrapping. The connection +// returned by the DialContext function will also perform error wrapping. +type dialerErrWrapper struct { + // Dialer is the underlying dialer. + Dialer +} + +var _ Dialer = &dialerErrWrapper{} + +// DialContext implements Dialer.DialContext. +func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.ConnectOperation, + WrappedErr: err, + } + } + return &dialerErrWrapperConn{Conn: conn}, nil +} + +// dialerErrWrapperConn is a net.Conn that performs error wrapping. +type dialerErrWrapperConn struct { + // Conn is the underlying connection. + net.Conn +} + +var _ net.Conn = &dialerErrWrapperConn{} + +// Read implements net.Conn.Read. +func (c *dialerErrWrapperConn) Read(b []byte) (int, error) { + count, err := c.Conn.Read(b) + if err != nil { + return 0, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.ReadOperation, + WrappedErr: err, + } + } + return count, nil +} + +// Write implements net.Conn.Write. +func (c *dialerErrWrapperConn) Write(b []byte) (int, error) { + count, err := c.Conn.Write(b) + if err != nil { + return 0, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.WriteOperation, + WrappedErr: err, + } + } + return count, nil +} + +// Close implements net.Conn.Close. +func (c *dialerErrWrapperConn) Close() error { + err := c.Conn.Close() + if err != nil { + return &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.CloseOperation, + WrappedErr: err, + } + } + return nil +} diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 5d89793..bc624a3 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) @@ -231,7 +232,11 @@ func TestNewDialerWithoutResolverChain(t *testing.T) { if dlog.Logger != log.Log { t.Fatal("invalid logger") } - if _, okay := dlog.Dialer.(*dialerSystem); !okay { + dew, okay := dlog.Dialer.(*dialerErrWrapper) + if !okay { + t.Fatal("invalid type") + } + if _, okay := dew.Dialer.(*dialerSystem); !okay { t.Fatal("invalid type") } } @@ -256,3 +261,172 @@ func TestNewSingleUseDialerWorksAsIntended(t *testing.T) { } } } + +func TestDialerErrWrapper(t *testing.T) { + t.Run("DialContext on success", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedConn := &mocks.Conn{} + d := &dialerErrWrapper{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return expectedConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + errWrapperConn := conn.(*dialerErrWrapperConn) + if errWrapperConn.Conn != expectedConn { + t.Fatal("unexpected conn") + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + d := &dialerErrWrapper{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expectedErr + }, + }, + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "", "") + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + d := &dialerErrWrapper{ + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + }, + } + d.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) +} + +func TestDialerErrWrapperConn(t *testing.T) { + t.Run("Read", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + b := make([]byte, 128) + conn := &dialerErrWrapperConn{ + Conn: &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return len(b), nil + }, + }, + } + count, err := conn.Read(b) + if err != nil { + t.Fatal(err) + } + if count != len(b) { + t.Fatal("unexpected count") + } + }) + + t.Run("on failure", func(t *testing.T) { + b := make([]byte, 128) + expectedErr := io.EOF + conn := &dialerErrWrapperConn{ + Conn: &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, expectedErr + }, + }, + } + count, err := conn.Read(b) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + }) + }) + + t.Run("Write", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + b := make([]byte, 128) + conn := &dialerErrWrapperConn{ + Conn: &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + }, + } + count, err := conn.Write(b) + if err != nil { + t.Fatal(err) + } + if count != len(b) { + t.Fatal("unexpected count") + } + }) + + t.Run("on failure", func(t *testing.T) { + b := make([]byte, 128) + expectedErr := io.EOF + conn := &dialerErrWrapperConn{ + Conn: &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, expectedErr + }, + }, + } + count, err := conn.Write(b) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + }) + }) + + t.Run("Close", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + conn := &dialerErrWrapperConn{ + Conn: &mocks.Conn{ + MockClose: func() error { + return nil + }, + }, + } + err := conn.Close() + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + conn := &dialerErrWrapperConn{ + Conn: &mocks.Conn{ + MockClose: func() error { + return expectedErr + }, + }, + } + err := conn.Close() + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + }) + }) +} diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index 9802863..856130c 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/lucas-clemente/quic-go" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx" ) @@ -21,7 +22,7 @@ type QUICListener interface { // NewQUICListener creates a new QUICListener using the standard // library to create listening UDP sockets. func NewQUICListener() QUICListener { - return &quicListenerStdlib{} + return &quicListenerErrWrapper{&quicListenerStdlib{}} } // quicListenerStdlib is a QUICListener using the standard library. @@ -54,9 +55,10 @@ func NewQUICDialerWithResolver(listener QUICListener, return &quicDialerLogger{ Dialer: &quicDialerResolver{ Dialer: &quicDialerLogger{ - Dialer: &quicDialerQUICGo{ - QUICListener: listener, - }, + Dialer: &quicDialerErrWrapper{ + QUICDialer: &quicDialerQUICGo{ + QUICListener: listener, + }}, Logger: logger, operationSuffix: "_address", }, @@ -322,3 +324,78 @@ func (s *quicDialerSingleUse) DialContext( func (s *quicDialerSingleUse) CloseIdleConnections() { // nothing to do } + +// quicListenerErrWrapper is a QUICListener that wraps errors. +type quicListenerErrWrapper struct { + // QUICListener is the underlying listener. + QUICListener +} + +var _ QUICListener = &quicListenerErrWrapper{} + +// Listen implements QUICListener.Listen. +func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (quicx.UDPLikeConn, error) { + pconn, err := qls.QUICListener.Listen(addr) + if err != nil { + return nil, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.QUICListenOperation, + WrappedErr: err, + } + } + return &quicErrWrapperUDPLikeConn{pconn}, nil +} + +// quicErrWrapperUDPLikeConn is a quicx.UDPLikeConn that wraps errors. +type quicErrWrapperUDPLikeConn struct { + // UDPLikeConn is the underlying conn. + quicx.UDPLikeConn +} + +var _ quicx.UDPLikeConn = &quicErrWrapperUDPLikeConn{} + +// WriteTo implements quicx.UDPLikeConn.WriteTo. +func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) { + count, err := c.UDPLikeConn.WriteTo(p, addr) + if err != nil { + return 0, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.WriteToOperation, + WrappedErr: err, + } + } + return count, nil +} + +// ReadFrom implements quicx.UDPLikeConn.ReadFrom. +func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, addr, err := c.UDPLikeConn.ReadFrom(b) + if err != nil { + return 0, nil, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyGenericError(err), + Operation: errorsx.ReadFromOperation, + WrappedErr: err, + } + } + return n, addr, nil +} + +// quicDialerErrWrapper is a dialer that performs quic err wrapping +type quicDialerErrWrapper struct { + QUICDialer +} + +// DialContext implements ContextDialer.DialContext +func (d *quicDialerErrWrapper) DialContext( + ctx context.Context, network string, host string, + tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { + sess, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg) + if err != nil { + return nil, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyQUICHandshakeError(err), + Operation: errorsx.QUICHandshakeOperation, + WrappedErr: err, + } + } + return sess, nil +} diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index 8290a30..ec55b53 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "io" "net" "strings" "testing" @@ -11,6 +12,7 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" "github.com/lucas-clemente/quic-go" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx" ) @@ -452,7 +454,11 @@ func TestNewQUICDialerWithoutResolverChain(t *testing.T) { if dlog.Logger != log.Log { t.Fatal("invalid logger") } - dgo, okay := dlog.Dialer.(*quicDialerQUICGo) + ew, okay := dlog.Dialer.(*quicDialerErrWrapper) + if !okay { + t.Fatal("invalid type") + } + dgo, okay := ew.QUICDialer.(*quicDialerQUICGo) if !okay { t.Fatal("invalid type") } @@ -483,3 +489,188 @@ func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) { } } } + +func TestQUICListenerErrWrapper(t *testing.T) { + t.Run("Listen", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedConn := &mocks.QUICUDPConn{} + ql := &quicListenerErrWrapper{ + QUICListener: &mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) { + return expectedConn, nil + }, + }, + } + conn, err := ql.Listen(&net.UDPAddr{}) + if err != nil { + t.Fatal(err) + } + ewconn := conn.(*quicErrWrapperUDPLikeConn) + if ewconn.UDPLikeConn != expectedConn { + t.Fatal("unexpected conn") + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + ql := &quicListenerErrWrapper{ + QUICListener: &mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) { + return nil, expectedErr + }, + }, + } + conn, err := ql.Listen(&net.UDPAddr{}) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + }) +} + +func TestQUICErrWrapperUDPLikeConn(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedAddr := &net.UDPAddr{} + p := make([]byte, 128) + conn := &quicErrWrapperUDPLikeConn{ + UDPLikeConn: &mocks.QUICUDPConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + return len(p), expectedAddr, nil + }, + }, + } + count, addr, err := conn.ReadFrom(p) + if err != nil { + t.Fatal(err) + } + if count != len(p) { + t.Fatal("unexpected count") + } + if addr != expectedAddr { + t.Fatal("unexpected addr") + } + }) + + t.Run("on failure", func(t *testing.T) { + p := make([]byte, 128) + expectedErr := io.EOF + conn := &quicErrWrapperUDPLikeConn{ + UDPLikeConn: &mocks.QUICUDPConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, expectedErr + }, + }, + } + count, addr, err := conn.ReadFrom(p) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + if addr != nil { + t.Fatal("unexpected addr") + } + }) + }) + + t.Run("WriteTo", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + p := make([]byte, 128) + conn := &quicErrWrapperUDPLikeConn{ + UDPLikeConn: &mocks.QUICUDPConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return len(p), nil + }, + }, + } + count, err := conn.WriteTo(p, &net.UDPAddr{}) + if err != nil { + t.Fatal(err) + } + if count != len(p) { + t.Fatal("unexpected count") + } + }) + + t.Run("on failure", func(t *testing.T) { + p := make([]byte, 128) + expectedErr := io.EOF + conn := &quicErrWrapperUDPLikeConn{ + UDPLikeConn: &mocks.QUICUDPConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 0, expectedErr + }, + }, + } + count, err := conn.WriteTo(p, &net.UDPAddr{}) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + }) + }) +} + +func TestQUICDialerErrWrapper(t *testing.T) { + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + d := &quicDialerErrWrapper{ + QUICDialer: &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + called = true + }, + }, + } + d.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedSess := &mocks.QUICEarlySession{} + d := &quicDialerErrWrapper{ + QUICDialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + return expectedSess, nil + }, + }, + } + ctx := context.Background() + sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + if sess != expectedSess { + t.Fatal("unexpected sess") + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + d := &quicDialerErrWrapper{ + QUICDialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + return nil, expectedErr + }, + }, + } + ctx := context.Background() + sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{}) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if sess != nil { + t.Fatal("unexpected sess") + } + }) + }) +} diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index fbbabb6..785887c 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -6,6 +6,7 @@ import ( "net" "time" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "golang.org/x/net/idna" ) @@ -30,7 +31,9 @@ func NewResolverSystem(logger Logger) Resolver { return &resolverIDNA{ Resolver: &resolverLogger{ Resolver: &resolverShortCircuitIPAddr{ - Resolver: &resolverSystem{}, + Resolver: &resolverErrWrapper{ + Resolver: &resolverSystem{}, + }, }, Logger: logger, }, @@ -182,3 +185,23 @@ func (r *nullResolver) Address() string { func (r *nullResolver) CloseIdleConnections() { // nothing } + +// resolverErrWrapper is a Resolver that knows about wrapping errors. +type resolverErrWrapper struct { + Resolver +} + +var _ Resolver = &resolverErrWrapper{} + +// LookupHost implements Resolver.LookupHost +func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) { + addrs, err := r.Resolver.LookupHost(ctx, hostname) + if err != nil { + return nil, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyResolverError(err), + Operation: errorsx.ResolveOperation, + WrappedErr: err, + } + } + return addrs, nil +} diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index da58322..27e7d1c 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -3,6 +3,7 @@ package netxlite import ( "context" "errors" + "io" "strings" "sync" "testing" @@ -10,6 +11,7 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) @@ -196,7 +198,11 @@ func TestNewResolverTypeChain(t *testing.T) { if !ok { t.Fatal("invalid resolver") } - if _, ok := scia.Resolver.(*resolverSystem); !ok { + ew, ok := scia.Resolver.(*resolverErrWrapper) + if !ok { + t.Fatal("invalid resolver") + } + if _, ok := ew.Resolver.(*resolverSystem); !ok { t.Fatal("invalid resolver") } } @@ -255,3 +261,88 @@ func TestNullResolverWorksAsIntended(t *testing.T) { } r.CloseIdleConnections() // should not crash } + +func TestResolverErrWrapper(t *testing.T) { + t.Run("LookupHost", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return expected, nil + }, + }, + } + ctx := context.Background() + addrs, err := reso.LookupHost(ctx, "") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, addrs); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := io.EOF + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + }, + } + ctx := context.Background() + addrs, err := reso.LookupHost(ctx, "") + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("unexpected addrs") + } + }) + }) + + t.Run("Network", func(t *testing.T) { + expected := "foobar" + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockNetwork: func() string { + return expected + }, + }, + } + if reso.Network() != expected { + t.Fatal("invalid network") + } + }) + + t.Run("Address", func(t *testing.T) { + expected := "foobar" + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockAddress: func() string { + return expected + }, + }, + } + if reso.Address() != expected { + t.Fatal("invalid address") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + reso := &resolverErrWrapper{ + Resolver: &mocks.Resolver{ + MockCloseIdleConnections: func() { + called = true + }, + }, + } + reso.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 2f7d82b..36fb64b 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -10,6 +10,7 @@ import ( "time" oohttp "github.com/ooni/oohttp" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) var ( @@ -125,8 +126,10 @@ type TLSHandshaker interface { // go standard library to create TLS connections. func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker { return &tlsHandshakerLogger{ - TLSHandshaker: &tlsHandshakerConfigurable{}, - Logger: logger, + TLSHandshaker: &tlsHandshakerErrWrapper{ + TLSHandshaker: &tlsHandshakerConfigurable{}, + }, + Logger: logger, } } @@ -319,3 +322,23 @@ var _ TLSDialer = &tlsDialerSingleUseAdapter{} func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { return d.Dialer.DialContext(ctx, network, address) } + +// tlsHandshakerErrWrapper wraps the returned error to be an OONI error +type tlsHandshakerErrWrapper struct { + TLSHandshaker +} + +// Handshake implements TLSHandshaker.Handshake +func (h *tlsHandshakerErrWrapper) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config, +) (net.Conn, tls.ConnectionState, error) { + tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) + if err != nil { + return nil, tls.ConnectionState{}, &errorsx.ErrWrapper{ + Failure: errorsx.ClassifyTLSHandshakeError(err), + Operation: errorsx.TLSHandshakeOperation, + WrappedErr: err, + } + } + return tlsconn, state, nil +} diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index db5fe38..7d40520 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -16,6 +16,7 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) @@ -432,7 +433,11 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) { if thl.Logger != log.Log { t.Fatal("invalid logger") } - thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable) + ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper) + if !okay { + t.Fatal("invalid type") + } + thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable) if !okay { t.Fatal("invalid type") } @@ -480,3 +485,51 @@ func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) { } } } + +func TestTLSHandshakerErrWrapper(t *testing.T) { + t.Run("Handshake", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedConn := &mocks.TLSConn{} + expectedState := tls.ConnectionState{ + Version: tls.VersionTLS12, + } + th := &tlsHandshakerErrWrapper{ + TLSHandshaker: &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return expectedConn, expectedState, nil + }, + }, + } + ctx := context.Background() + conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) + if err != nil { + t.Fatal(err) + } + if expectedState.Version != state.Version { + t.Fatal("unexpected state") + } + if expectedConn != conn { + t.Fatal("unexpected conn") + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + th := &tlsHandshakerErrWrapper{ + TLSHandshaker: &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expectedErr + }, + }, + } + ctx := context.Background() + conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) + if err == nil || err.Error() != errorsx.FailureEOFError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("unexpected conn") + } + }) + }) +} diff --git a/internal/netxlite/utls.go b/internal/netxlite/utls.go index c9968b7..d836a1a 100644 --- a/internal/netxlite/utls.go +++ b/internal/netxlite/utls.go @@ -13,8 +13,10 @@ import ( // gitlab.com/yawning/utls library to create TLS conns. func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker { return &tlsHandshakerLogger{ - TLSHandshaker: &tlsHandshakerConfigurable{ - NewConn: newConnUTLS(id), + TLSHandshaker: &tlsHandshakerErrWrapper{ + TLSHandshaker: &tlsHandshakerConfigurable{ + NewConn: newConnUTLS(id), + }, }, Logger: logger, } diff --git a/internal/netxlite/utls_test.go b/internal/netxlite/utls_test.go index 12388a1..cf2e894 100644 --- a/internal/netxlite/utls_test.go +++ b/internal/netxlite/utls_test.go @@ -40,7 +40,11 @@ func TestNewTLSHandshakerUTLSTypes(t *testing.T) { if thl.Logger != log.Log { t.Fatal("invalid logger") } - thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable) + ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper) + if !okay { + t.Fatal("invalid type") + } + thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable) if !okay { t.Fatal("invalid type") }