fix(netxlite): add error wrappers (#480)

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-07 19:56:42 +02:00 committed by GitHub
parent ee78c76085
commit 323266da83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 729 additions and 15 deletions

View File

@ -6,6 +6,8 @@ import (
"net"
"sync"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
)
// Dialer establishes network connections.
@ -22,7 +24,9 @@ func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
return &dialerLogger{
Dialer: &dialerResolver{
Dialer: &dialerLogger{
Dialer: &dialerSystem{},
Dialer: &dialerErrWrapper{
Dialer: &dialerSystem{},
},
Logger: logger,
operationSuffix: "_address",
},
@ -188,3 +192,75 @@ func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr
func (s *dialerSingleUse) CloseIdleConnections() {
// nothing
}
// TODO(bassosimone): introduce factory for creating errors and
// write tests that ensure the factory works correctly.
// dialerErrWrapper is a dialer that performs error wrapping. The connection
// returned by the DialContext function will also perform error wrapping.
type dialerErrWrapper struct {
// Dialer is the underlying dialer.
Dialer
}
var _ Dialer = &dialerErrWrapper{}
// DialContext implements Dialer.DialContext.
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.ConnectOperation,
WrappedErr: err,
}
}
return &dialerErrWrapperConn{Conn: conn}, nil
}
// dialerErrWrapperConn is a net.Conn that performs error wrapping.
type dialerErrWrapperConn struct {
// Conn is the underlying connection.
net.Conn
}
var _ net.Conn = &dialerErrWrapperConn{}
// Read implements net.Conn.Read.
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b)
if err != nil {
return 0, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.ReadOperation,
WrappedErr: err,
}
}
return count, nil
}
// Write implements net.Conn.Write.
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
count, err := c.Conn.Write(b)
if err != nil {
return 0, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.WriteOperation,
WrappedErr: err,
}
}
return count, nil
}
// Close implements net.Conn.Close.
func (c *dialerErrWrapperConn) Close() error {
err := c.Conn.Close()
if err != nil {
return &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.CloseOperation,
WrappedErr: err,
}
}
return nil
}

View File

@ -10,6 +10,7 @@ import (
"time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
@ -231,7 +232,11 @@ func TestNewDialerWithoutResolverChain(t *testing.T) {
if dlog.Logger != log.Log {
t.Fatal("invalid logger")
}
if _, okay := dlog.Dialer.(*dialerSystem); !okay {
dew, okay := dlog.Dialer.(*dialerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
if _, okay := dew.Dialer.(*dialerSystem); !okay {
t.Fatal("invalid type")
}
}
@ -256,3 +261,172 @@ func TestNewSingleUseDialerWorksAsIntended(t *testing.T) {
}
}
}
func TestDialerErrWrapper(t *testing.T) {
t.Run("DialContext on success", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.Conn{}
d := &dialerErrWrapper{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return expectedConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
errWrapperConn := conn.(*dialerErrWrapperConn)
if errWrapperConn.Conn != expectedConn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
d := &dialerErrWrapper{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expectedErr
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
d := &dialerErrWrapper{
Dialer: &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
},
}
d.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}
func TestDialerErrWrapperConn(t *testing.T) {
t.Run("Read", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
b := make([]byte, 128)
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
},
}
count, err := conn.Read(b)
if err != nil {
t.Fatal(err)
}
if count != len(b) {
t.Fatal("unexpected count")
}
})
t.Run("on failure", func(t *testing.T) {
b := make([]byte, 128)
expectedErr := io.EOF
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, expectedErr
},
},
}
count, err := conn.Read(b)
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
})
})
t.Run("Write", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
b := make([]byte, 128)
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
},
}
count, err := conn.Write(b)
if err != nil {
t.Fatal(err)
}
if count != len(b) {
t.Fatal("unexpected count")
}
})
t.Run("on failure", func(t *testing.T) {
b := make([]byte, 128)
expectedErr := io.EOF
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, expectedErr
},
},
}
count, err := conn.Write(b)
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
})
})
t.Run("Close", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockClose: func() error {
return nil
},
},
}
err := conn.Close()
if err != nil {
t.Fatal(err)
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
conn := &dialerErrWrapperConn{
Conn: &mocks.Conn{
MockClose: func() error {
return expectedErr
},
},
}
err := conn.Close()
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
})
})
}

View File

@ -9,6 +9,7 @@ import (
"sync"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
)
@ -21,7 +22,7 @@ type QUICListener interface {
// NewQUICListener creates a new QUICListener using the standard
// library to create listening UDP sockets.
func NewQUICListener() QUICListener {
return &quicListenerStdlib{}
return &quicListenerErrWrapper{&quicListenerStdlib{}}
}
// quicListenerStdlib is a QUICListener using the standard library.
@ -54,9 +55,10 @@ func NewQUICDialerWithResolver(listener QUICListener,
return &quicDialerLogger{
Dialer: &quicDialerResolver{
Dialer: &quicDialerLogger{
Dialer: &quicDialerQUICGo{
QUICListener: listener,
},
Dialer: &quicDialerErrWrapper{
QUICDialer: &quicDialerQUICGo{
QUICListener: listener,
}},
Logger: logger,
operationSuffix: "_address",
},
@ -322,3 +324,78 @@ func (s *quicDialerSingleUse) DialContext(
func (s *quicDialerSingleUse) CloseIdleConnections() {
// nothing to do
}
// quicListenerErrWrapper is a QUICListener that wraps errors.
type quicListenerErrWrapper struct {
// QUICListener is the underlying listener.
QUICListener
}
var _ QUICListener = &quicListenerErrWrapper{}
// Listen implements QUICListener.Listen.
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.QUICListenOperation,
WrappedErr: err,
}
}
return &quicErrWrapperUDPLikeConn{pconn}, nil
}
// quicErrWrapperUDPLikeConn is a quicx.UDPLikeConn that wraps errors.
type quicErrWrapperUDPLikeConn struct {
// UDPLikeConn is the underlying conn.
quicx.UDPLikeConn
}
var _ quicx.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
// WriteTo implements quicx.UDPLikeConn.WriteTo.
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
count, err := c.UDPLikeConn.WriteTo(p, addr)
if err != nil {
return 0, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.WriteToOperation,
WrappedErr: err,
}
}
return count, nil
}
// ReadFrom implements quicx.UDPLikeConn.ReadFrom.
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPLikeConn.ReadFrom(b)
if err != nil {
return 0, nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyGenericError(err),
Operation: errorsx.ReadFromOperation,
WrappedErr: err,
}
}
return n, addr, nil
}
// quicDialerErrWrapper is a dialer that performs quic err wrapping
type quicDialerErrWrapper struct {
QUICDialer
}
// DialContext implements ContextDialer.DialContext
func (d *quicDialerErrWrapper) DialContext(
ctx context.Context, network string, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
sess, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyQUICHandshakeError(err),
Operation: errorsx.QUICHandshakeOperation,
WrappedErr: err,
}
}
return sess, nil
}

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"strings"
"testing"
@ -11,6 +12,7 @@ import (
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
)
@ -452,7 +454,11 @@ func TestNewQUICDialerWithoutResolverChain(t *testing.T) {
if dlog.Logger != log.Log {
t.Fatal("invalid logger")
}
dgo, okay := dlog.Dialer.(*quicDialerQUICGo)
ew, okay := dlog.Dialer.(*quicDialerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
dgo, okay := ew.QUICDialer.(*quicDialerQUICGo)
if !okay {
t.Fatal("invalid type")
}
@ -483,3 +489,188 @@ func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) {
}
}
}
func TestQUICListenerErrWrapper(t *testing.T) {
t.Run("Listen", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.QUICUDPConn{}
ql := &quicListenerErrWrapper{
QUICListener: &mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
return expectedConn, nil
},
},
}
conn, err := ql.Listen(&net.UDPAddr{})
if err != nil {
t.Fatal(err)
}
ewconn := conn.(*quicErrWrapperUDPLikeConn)
if ewconn.UDPLikeConn != expectedConn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
ql := &quicListenerErrWrapper{
QUICListener: &mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
return nil, expectedErr
},
},
}
conn, err := ql.Listen(&net.UDPAddr{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
}
func TestQUICErrWrapperUDPLikeConn(t *testing.T) {
t.Run("ReadFrom", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedAddr := &net.UDPAddr{}
p := make([]byte, 128)
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
return len(p), expectedAddr, nil
},
},
}
count, addr, err := conn.ReadFrom(p)
if err != nil {
t.Fatal(err)
}
if count != len(p) {
t.Fatal("unexpected count")
}
if addr != expectedAddr {
t.Fatal("unexpected addr")
}
})
t.Run("on failure", func(t *testing.T) {
p := make([]byte, 128)
expectedErr := io.EOF
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, expectedErr
},
},
}
count, addr, err := conn.ReadFrom(p)
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
if addr != nil {
t.Fatal("unexpected addr")
}
})
})
t.Run("WriteTo", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
p := make([]byte, 128)
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return len(p), nil
},
},
}
count, err := conn.WriteTo(p, &net.UDPAddr{})
if err != nil {
t.Fatal(err)
}
if count != len(p) {
t.Fatal("unexpected count")
}
})
t.Run("on failure", func(t *testing.T) {
p := make([]byte, 128)
expectedErr := io.EOF
conn := &quicErrWrapperUDPLikeConn{
UDPLikeConn: &mocks.QUICUDPConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 0, expectedErr
},
},
}
count, err := conn.WriteTo(p, &net.UDPAddr{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
})
})
}
func TestQUICDialerErrWrapper(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
d := &quicDialerErrWrapper{
QUICDialer: &mocks.QUICDialer{
MockCloseIdleConnections: func() {
called = true
},
},
}
d.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedSess := &mocks.QUICEarlySession{}
d := &quicDialerErrWrapper{
QUICDialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return expectedSess, nil
},
},
}
ctx := context.Background()
sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{})
if err != nil {
t.Fatal(err)
}
if sess != expectedSess {
t.Fatal("unexpected sess")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
d := &quicDialerErrWrapper{
QUICDialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return nil, expectedErr
},
},
}
ctx := context.Background()
sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if sess != nil {
t.Fatal("unexpected sess")
}
})
})
}

View File

@ -6,6 +6,7 @@ import (
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"golang.org/x/net/idna"
)
@ -30,7 +31,9 @@ func NewResolverSystem(logger Logger) Resolver {
return &resolverIDNA{
Resolver: &resolverLogger{
Resolver: &resolverShortCircuitIPAddr{
Resolver: &resolverSystem{},
Resolver: &resolverErrWrapper{
Resolver: &resolverSystem{},
},
},
Logger: logger,
},
@ -182,3 +185,23 @@ func (r *nullResolver) Address() string {
func (r *nullResolver) CloseIdleConnections() {
// nothing
}
// resolverErrWrapper is a Resolver that knows about wrapping errors.
type resolverErrWrapper struct {
Resolver
}
var _ Resolver = &resolverErrWrapper{}
// LookupHost implements Resolver.LookupHost
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {
return nil, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyResolverError(err),
Operation: errorsx.ResolveOperation,
WrappedErr: err,
}
}
return addrs, nil
}

View File

@ -3,6 +3,7 @@ package netxlite
import (
"context"
"errors"
"io"
"strings"
"sync"
"testing"
@ -10,6 +11,7 @@ import (
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
@ -196,7 +198,11 @@ func TestNewResolverTypeChain(t *testing.T) {
if !ok {
t.Fatal("invalid resolver")
}
if _, ok := scia.Resolver.(*resolverSystem); !ok {
ew, ok := scia.Resolver.(*resolverErrWrapper)
if !ok {
t.Fatal("invalid resolver")
}
if _, ok := ew.Resolver.(*resolverSystem); !ok {
t.Fatal("invalid resolver")
}
}
@ -255,3 +261,88 @@ func TestNullResolverWorksAsIntended(t *testing.T) {
}
r.CloseIdleConnections() // should not crash
}
func TestResolverErrWrapper(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
},
},
}
ctx := context.Background()
addrs, err := reso.LookupHost(ctx, "")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff)
}
})
t.Run("on failure", func(t *testing.T) {
expected := io.EOF
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, expected
},
},
}
ctx := context.Background()
addrs, err := reso.LookupHost(ctx, "")
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if addrs != nil {
t.Fatal("unexpected addrs")
}
})
})
t.Run("Network", func(t *testing.T) {
expected := "foobar"
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockNetwork: func() string {
return expected
},
},
}
if reso.Network() != expected {
t.Fatal("invalid network")
}
})
t.Run("Address", func(t *testing.T) {
expected := "foobar"
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockAddress: func() string {
return expected
},
},
}
if reso.Address() != expected {
t.Fatal("invalid address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
reso := &resolverErrWrapper{
Resolver: &mocks.Resolver{
MockCloseIdleConnections: func() {
called = true
},
},
}
reso.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}

View File

@ -10,6 +10,7 @@ import (
"time"
oohttp "github.com/ooni/oohttp"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
)
var (
@ -125,8 +126,10 @@ type TLSHandshaker interface {
// go standard library to create TLS connections.
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerConfigurable{},
Logger: logger,
TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: &tlsHandshakerConfigurable{},
},
Logger: logger,
}
}
@ -319,3 +322,23 @@ var _ TLSDialer = &tlsDialerSingleUseAdapter{}
func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address)
}
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
type tlsHandshakerErrWrapper struct {
TLSHandshaker
}
// Handshake implements TLSHandshaker.Handshake
func (h *tlsHandshakerErrWrapper) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
return nil, tls.ConnectionState{}, &errorsx.ErrWrapper{
Failure: errorsx.ClassifyTLSHandshakeError(err),
Operation: errorsx.TLSHandshakeOperation,
WrappedErr: err,
}
}
return tlsconn, state, nil
}

View File

@ -16,6 +16,7 @@ import (
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
@ -432,7 +433,11 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
if thl.Logger != log.Log {
t.Fatal("invalid logger")
}
thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable)
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
if !okay {
t.Fatal("invalid type")
}
@ -480,3 +485,51 @@ func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
}
}
}
func TestTLSHandshakerErrWrapper(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.TLSConn{}
expectedState := tls.ConnectionState{
Version: tls.VersionTLS12,
}
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return expectedConn, expectedState, nil
},
},
}
ctx := context.Background()
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err != nil {
t.Fatal(err)
}
if expectedState.Version != state.Version {
t.Fatal("unexpected state")
}
if expectedConn != conn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expectedErr
},
},
}
ctx := context.Background()
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("unexpected conn")
}
})
})
}

View File

@ -13,8 +13,10 @@ import (
// gitlab.com/yawning/utls library to create TLS conns.
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerConfigurable{
NewConn: newConnUTLS(id),
TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: &tlsHandshakerConfigurable{
NewConn: newConnUTLS(id),
},
},
Logger: logger,
}

View File

@ -40,7 +40,11 @@ func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
if thl.Logger != log.Log {
t.Fatal("invalid logger")
}
thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable)
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
if !okay {
t.Fatal("invalid type")
}
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
if !okay {
t.Fatal("invalid type")
}