fix(netxlite): add error wrappers (#480)

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-07 19:56:42 +02:00 committed by GitHub
parent ee78c76085
commit 323266da83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 729 additions and 15 deletions

View File

@ -6,6 +6,8 @@ import (
"net" "net"
"sync" "sync"
"time" "time"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
) )
// Dialer establishes network connections. // Dialer establishes network connections.
@ -22,7 +24,9 @@ func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
return &dialerLogger{ return &dialerLogger{
Dialer: &dialerResolver{ Dialer: &dialerResolver{
Dialer: &dialerLogger{ Dialer: &dialerLogger{
Dialer: &dialerSystem{}, Dialer: &dialerErrWrapper{
Dialer: &dialerSystem{},
},
Logger: logger, Logger: logger,
operationSuffix: "_address", operationSuffix: "_address",
}, },
@ -188,3 +192,75 @@ func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr
func (s *dialerSingleUse) CloseIdleConnections() { func (s *dialerSingleUse) CloseIdleConnections() {
// nothing // nothing
} }
// TODO(bassosimone): introduce factory for creating errors and
// write tests that ensure the factory works correctly.
// dialerErrWrapper is a dialer that performs error wrapping. The connection
// returned by the DialContext function will also perform error wrapping.
type dialerErrWrapper struct {
// Dialer is the underlying dialer.
Dialer
}
var _ Dialer = &dialerErrWrapper{}
// DialContext implements Dialer.DialContext.
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.ConnectOperation,
WrappedErr: err,
}
}
return &dialerErrWrapperConn{Conn: conn}, nil
}
// dialerErrWrapperConn is a net.Conn that performs error wrapping.
type dialerErrWrapperConn struct {
// Conn is the underlying connection.
net.Conn
}
var _ net.Conn = &dialerErrWrapperConn{}
// Read implements net.Conn.Read.
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b)
if err != nil {
return 0, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.ReadOperation,
WrappedErr: err,
}
}
return count, nil
}
// Write implements net.Conn.Write.
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
count, err := c.Conn.Write(b)
if err != nil {
return 0, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.WriteOperation,
WrappedErr: err,
}
}
return count, nil
}
// Close implements net.Conn.Close.
func (c *dialerErrWrapperConn) Close() error {
err := c.Conn.Close()
if err != nil {
return &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.CloseOperation,
WrappedErr: err,
}
}
return nil
}

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
@ -231,7 +232,11 @@ func TestNewDialerWithoutResolverChain(t *testing.T) {
if dlog.Logger != log.Log { if dlog.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
if _, okay := dlog.Dialer.(*dialerSystem); !okay { dew, okay := dlog.Dialer.(*dialerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
if _, okay := dew.Dialer.(*dialerSystem); !okay {
t.Fatal("invalid type") t.Fatal("invalid type")
} }
} }
@ -256,3 +261,172 @@ func TestNewSingleUseDialerWorksAsIntended(t *testing.T) {
} }
} }
} }
func TestDialerErrWrapper(t *testing.T) {
t.Run("DialContext on success", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.Conn{}
d := &dialerErrWrapper{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return expectedConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
errWrapperConn := conn.(*dialerErrWrapperConn)
if errWrapperConn.Conn != expectedConn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
d := &dialerErrWrapper{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expectedErr
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
d := &dialerErrWrapper{
Dialer: &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
},
}
d.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}
func TestDialerErrWrapperConn(t *testing.T) {
t.Run("Read", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
b := make([]byte, 128)
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
},
}
count, err := conn.Read(b)
if err != nil {
t.Fatal(err)
}
if count != len(b) {
t.Fatal("unexpected count")
}
})
t.Run("on failure", func(t *testing.T) {
b := make([]byte, 128)
expectedErr := io.EOF
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, expectedErr
},
},
}
count, err := conn.Read(b)
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
})
})
t.Run("Write", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
b := make([]byte, 128)
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
},
}
count, err := conn.Write(b)
if err != nil {
t.Fatal(err)
}
if count != len(b) {
t.Fatal("unexpected count")
}
})
t.Run("on failure", func(t *testing.T) {
b := make([]byte, 128)
expectedErr := io.EOF
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, expectedErr
},
},
}
count, err := conn.Write(b)
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
})
})
t.Run("Close", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockClose: func() error {
return nil
},
},
}
err := conn.Close()
if err != nil {
t.Fatal(err)
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockClose: func() error {
return expectedErr
},
},
}
err := conn.Close()
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
})
})
}

View File

@ -9,6 +9,7 @@ import (
"sync" "sync"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
) )
@ -21,7 +22,7 @@ type QUICListener interface {
// NewQUICListener creates a new QUICListener using the standard // NewQUICListener creates a new QUICListener using the standard
// library to create listening UDP sockets. // library to create listening UDP sockets.
func NewQUICListener() QUICListener { func NewQUICListener() QUICListener {
return &quicListenerStdlib{} return &quicListenerErrWrapper{&quicListenerStdlib{}}
} }
// quicListenerStdlib is a QUICListener using the standard library. // quicListenerStdlib is a QUICListener using the standard library.
@ -54,9 +55,10 @@ func NewQUICDialerWithResolver(listener QUICListener,
return &quicDialerLogger{ return &quicDialerLogger{
Dialer: &quicDialerResolver{ Dialer: &quicDialerResolver{
Dialer: &quicDialerLogger{ Dialer: &quicDialerLogger{
Dialer: &quicDialerQUICGo{ Dialer: &quicDialerErrWrapper{
QUICListener: listener, QUICDialer: &quicDialerQUICGo{
}, QUICListener: listener,
}},
Logger: logger, Logger: logger,
operationSuffix: "_address", operationSuffix: "_address",
}, },
@ -322,3 +324,78 @@ func (s *quicDialerSingleUse) DialContext(
func (s *quicDialerSingleUse) CloseIdleConnections() { func (s *quicDialerSingleUse) CloseIdleConnections() {
// nothing to do // nothing to do
} }
// quicListenerErrWrapper is a QUICListener that wraps errors.
type quicListenerErrWrapper struct {
// QUICListener is the underlying listener.
QUICListener
}
var _ QUICListener = &quicListenerErrWrapper{}
// Listen implements QUICListener.Listen.
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.QUICListenOperation,
WrappedErr: err,
}
}
return &quicErrWrapperUDPLikeConn{pconn}, nil
}
// quicErrWrapperUDPLikeConn is a quicx.UDPLikeConn that wraps errors.
type quicErrWrapperUDPLikeConn struct {
// UDPLikeConn is the underlying conn.
quicx.UDPLikeConn
}
var _ quicx.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
// WriteTo implements quicx.UDPLikeConn.WriteTo.
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
count, err := c.UDPLikeConn.WriteTo(p, addr)
if err != nil {
return 0, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.WriteToOperation,
WrappedErr: err,
}
}
return count, nil
}
// ReadFrom implements quicx.UDPLikeConn.ReadFrom.
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPLikeConn.ReadFrom(b)
if err != nil {
return 0, nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.ReadFromOperation,
WrappedErr: err,
}
}
return n, addr, nil
}
// quicDialerErrWrapper is a dialer that performs quic err wrapping
type quicDialerErrWrapper struct {
QUICDialer
}
// DialContext implements ContextDialer.DialContext
func (d *quicDialerErrWrapper) DialContext(
ctx context.Context, network string, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
sess, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyQUICHandshakeError(err),
Operation: errorsx.QUICHandshakeOperation,
WrappedErr: err,
}
}
return sess, nil
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io"
"net" "net"
"strings" "strings"
"testing" "testing"
@ -11,6 +12,7 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
) )
@ -452,7 +454,11 @@ func TestNewQUICDialerWithoutResolverChain(t *testing.T) {
if dlog.Logger != log.Log { if dlog.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
dgo, okay := dlog.Dialer.(*quicDialerQUICGo) ew, okay := dlog.Dialer.(*quicDialerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
dgo, okay := ew.QUICDialer.(*quicDialerQUICGo)
if !okay { if !okay {
t.Fatal("invalid type") t.Fatal("invalid type")
} }
@ -483,3 +489,188 @@ func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) {
} }
} }
} }
func TestQUICListenerErrWrapper(t *testing.T) {
t.Run("Listen", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.QUICUDPConn{}
ql := &quicListenerErrWrapper{
QUICListener: &mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
return expectedConn, nil
},
},
}
conn, err := ql.Listen(&net.UDPAddr{})
if err != nil {
t.Fatal(err)
}
ewconn := conn.(*quicErrWrapperUDPLikeConn)
if ewconn.UDPLikeConn != expectedConn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
ql := &quicListenerErrWrapper{
QUICListener: &mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
return nil, expectedErr
},
},
}
conn, err := ql.Listen(&net.UDPAddr{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
}
func TestQUICErrWrapperUDPLikeConn(t *testing.T) {
t.Run("ReadFrom", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedAddr := &net.UDPAddr{}
p := make([]byte, 128)
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
return len(p), expectedAddr, nil
},
},
}
count, addr, err := conn.ReadFrom(p)
if err != nil {
t.Fatal(err)
}
if count != len(p) {
t.Fatal("unexpected count")
}
if addr != expectedAddr {
t.Fatal("unexpected addr")
}
})
t.Run("on failure", func(t *testing.T) {
p := make([]byte, 128)
expectedErr := io.EOF
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, expectedErr
},
},
}
count, addr, err := conn.ReadFrom(p)
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
if addr != nil {
t.Fatal("unexpected addr")
}
})
})
t.Run("WriteTo", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
p := make([]byte, 128)
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return len(p), nil
},
},
}
count, err := conn.WriteTo(p, &net.UDPAddr{})
if err != nil {
t.Fatal(err)
}
if count != len(p) {
t.Fatal("unexpected count")
}
})
t.Run("on failure", func(t *testing.T) {
p := make([]byte, 128)
expectedErr := io.EOF
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 0, expectedErr
},
},
}
count, err := conn.WriteTo(p, &net.UDPAddr{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
})
})
}
func TestQUICDialerErrWrapper(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
d := &quicDialerErrWrapper{
QUICDialer: &mocks.QUICDialer{
MockCloseIdleConnections: func() {
called = true
},
},
}
d.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedSess := &mocks.QUICEarlySession{}
d := &quicDialerErrWrapper{
QUICDialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return expectedSess, nil
},
},
}
ctx := context.Background()
sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{})
if err != nil {
t.Fatal(err)
}
if sess != expectedSess {
t.Fatal("unexpected sess")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
d := &quicDialerErrWrapper{
QUICDialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return nil, expectedErr
},
},
}
ctx := context.Background()
sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if sess != nil {
t.Fatal("unexpected sess")
}
})
})
}

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"time" "time"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -30,7 +31,9 @@ func NewResolverSystem(logger Logger) Resolver {
return &resolverIDNA{ return &resolverIDNA{
Resolver: &resolverLogger{ Resolver: &resolverLogger{
Resolver: &resolverShortCircuitIPAddr{ Resolver: &resolverShortCircuitIPAddr{
Resolver: &resolverSystem{}, Resolver: &resolverErrWrapper{
Resolver: &resolverSystem{},
},
}, },
Logger: logger, Logger: logger,
}, },
@ -182,3 +185,23 @@ func (r *nullResolver) Address() string {
func (r *nullResolver) CloseIdleConnections() { func (r *nullResolver) CloseIdleConnections() {
// nothing // nothing
} }
// resolverErrWrapper is a Resolver that knows about wrapping errors.
type resolverErrWrapper struct {
Resolver
}
var _ Resolver = &resolverErrWrapper{}
// LookupHost implements Resolver.LookupHost
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyResolverError(err),
Operation: errorsx.ResolveOperation,
WrappedErr: err,
}
}
return addrs, nil
}

View File

@ -3,6 +3,7 @@ package netxlite
import ( import (
"context" "context"
"errors" "errors"
"io"
"strings" "strings"
"sync" "sync"
"testing" "testing"
@ -10,6 +11,7 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
@ -196,7 +198,11 @@ func TestNewResolverTypeChain(t *testing.T) {
if !ok { if !ok {
t.Fatal("invalid resolver") t.Fatal("invalid resolver")
} }
if _, ok := scia.Resolver.(*resolverSystem); !ok { ew, ok := scia.Resolver.(*resolverErrWrapper)
if !ok {
t.Fatal("invalid resolver")
}
if _, ok := ew.Resolver.(*resolverSystem); !ok {
t.Fatal("invalid resolver") t.Fatal("invalid resolver")
} }
} }
@ -255,3 +261,88 @@ func TestNullResolverWorksAsIntended(t *testing.T) {
} }
r.CloseIdleConnections() // should not crash r.CloseIdleConnections() // should not crash
} }
func TestResolverErrWrapper(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
},
},
}
ctx := context.Background()
addrs, err := reso.LookupHost(ctx, "")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff)
}
})
t.Run("on failure", func(t *testing.T) {
expected := io.EOF
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, expected
},
},
}
ctx := context.Background()
addrs, err := reso.LookupHost(ctx, "")
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if addrs != nil {
t.Fatal("unexpected addrs")
}
})
})
t.Run("Network", func(t *testing.T) {
expected := "foobar"
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockNetwork: func() string {
return expected
},
},
}
if reso.Network() != expected {
t.Fatal("invalid network")
}
})
t.Run("Address", func(t *testing.T) {
expected := "foobar"
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockAddress: func() string {
return expected
},
},
}
if reso.Address() != expected {
t.Fatal("invalid address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockCloseIdleConnections: func() {
called = true
},
},
}
reso.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}

View File

@ -10,6 +10,7 @@ import (
"time" "time"
oohttp "github.com/ooni/oohttp" oohttp "github.com/ooni/oohttp"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
) )
var ( var (
@ -125,8 +126,10 @@ type TLSHandshaker interface {
// go standard library to create TLS connections. // go standard library to create TLS connections.
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker { func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
return &tlsHandshakerLogger{ return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerConfigurable{}, TLSHandshaker: &tlsHandshakerErrWrapper{
Logger: logger, TLSHandshaker: &tlsHandshakerConfigurable{},
},
Logger: logger,
} }
} }
@ -319,3 +322,23 @@ var _ TLSDialer = &tlsDialerSingleUseAdapter{}
func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address) return d.Dialer.DialContext(ctx, network, address)
} }
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
type tlsHandshakerErrWrapper struct {
TLSHandshaker
}
// Handshake implements TLSHandshaker.Handshake
func (h *tlsHandshakerErrWrapper) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
return nil, tls.ConnectionState{}, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyTLSHandshakeError(err),
Operation: errorsx.TLSHandshakeOperation,
WrappedErr: err,
}
}
return tlsconn, state, nil
}

View File

@ -16,6 +16,7 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
@ -432,7 +433,11 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
if thl.Logger != log.Log { if thl.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable) ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
if !okay { if !okay {
t.Fatal("invalid type") t.Fatal("invalid type")
} }
@ -480,3 +485,51 @@ func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
} }
} }
} }
func TestTLSHandshakerErrWrapper(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.TLSConn{}
expectedState := tls.ConnectionState{
Version: tls.VersionTLS12,
}
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return expectedConn, expectedState, nil
},
},
}
ctx := context.Background()
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err != nil {
t.Fatal(err)
}
if expectedState.Version != state.Version {
t.Fatal("unexpected state")
}
if expectedConn != conn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expectedErr
},
},
}
ctx := context.Background()
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("unexpected conn")
}
})
})
}

View File

@ -13,8 +13,10 @@ import (
// gitlab.com/yawning/utls library to create TLS conns. // gitlab.com/yawning/utls library to create TLS conns.
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker { func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
return &tlsHandshakerLogger{ return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerConfigurable{ TLSHandshaker: &tlsHandshakerErrWrapper{
NewConn: newConnUTLS(id), TLSHandshaker: &tlsHandshakerConfigurable{
NewConn: newConnUTLS(id),
},
}, },
Logger: logger, Logger: logger,
} }

View File

@ -40,7 +40,11 @@ func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
if thl.Logger != log.Log { if thl.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable) ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
if !okay { if !okay {
t.Fatal("invalid type") t.Fatal("invalid type")
} }