fix(netxlite): add error wrappers (#480)
See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
ee78c76085
commit
323266da83
|
@ -6,6 +6,8 @@ import (
|
|||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
)
|
||||
|
||||
// Dialer establishes network connections.
|
||||
|
@ -22,7 +24,9 @@ func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
|
|||
return &dialerLogger{
|
||||
Dialer: &dialerResolver{
|
||||
Dialer: &dialerLogger{
|
||||
Dialer: &dialerErrWrapper{
|
||||
Dialer: &dialerSystem{},
|
||||
},
|
||||
Logger: logger,
|
||||
operationSuffix: "_address",
|
||||
},
|
||||
|
@ -188,3 +192,75 @@ func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr
|
|||
func (s *dialerSingleUse) CloseIdleConnections() {
|
||||
// nothing
|
||||
}
|
||||
|
||||
// TODO(bassosimone): introduce factory for creating errors and
|
||||
// write tests that ensure the factory works correctly.
|
||||
|
||||
// dialerErrWrapper is a dialer that performs error wrapping. The connection
|
||||
// returned by the DialContext function will also perform error wrapping.
|
||||
type dialerErrWrapper struct {
|
||||
// Dialer is the underlying dialer.
|
||||
Dialer
|
||||
}
|
||||
|
||||
var _ Dialer = &dialerErrWrapper{}
|
||||
|
||||
// DialContext implements Dialer.DialContext.
|
||||
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.ConnectOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return &dialerErrWrapperConn{Conn: conn}, nil
|
||||
}
|
||||
|
||||
// dialerErrWrapperConn is a net.Conn that performs error wrapping.
|
||||
type dialerErrWrapperConn struct {
|
||||
// Conn is the underlying connection.
|
||||
net.Conn
|
||||
}
|
||||
|
||||
var _ net.Conn = &dialerErrWrapperConn{}
|
||||
|
||||
// Read implements net.Conn.Read.
|
||||
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
||||
count, err := c.Conn.Read(b)
|
||||
if err != nil {
|
||||
return 0, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.ReadOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write.
|
||||
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
||||
count, err := c.Conn.Write(b)
|
||||
if err != nil {
|
||||
return 0, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.WriteOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close.
|
||||
func (c *dialerErrWrapperConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
if err != nil {
|
||||
return &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.CloseOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
|
@ -231,7 +232,11 @@ func TestNewDialerWithoutResolverChain(t *testing.T) {
|
|||
if dlog.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
if _, okay := dlog.Dialer.(*dialerSystem); !okay {
|
||||
dew, okay := dlog.Dialer.(*dialerErrWrapper)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
if _, okay := dew.Dialer.(*dialerSystem); !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
}
|
||||
|
@ -256,3 +261,172 @@ func TestNewSingleUseDialerWorksAsIntended(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerErrWrapper(t *testing.T) {
|
||||
t.Run("DialContext on success", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedConn := &mocks.Conn{}
|
||||
d := &dialerErrWrapper{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return expectedConn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
errWrapperConn := conn.(*dialerErrWrapperConn)
|
||||
if errWrapperConn.Conn != expectedConn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
d := &dialerErrWrapper{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialContext(ctx, "", "")
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
d := &dialerErrWrapper{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
},
|
||||
}
|
||||
d.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialerErrWrapperConn(t *testing.T) {
|
||||
t.Run("Read", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
conn := &dialerErrWrapperConn{
|
||||
Conn: &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
count, err := conn.Read(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != len(b) {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
expectedErr := io.EOF
|
||||
conn := &dialerErrWrapperConn{
|
||||
Conn: &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
count, err := conn.Read(b)
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
conn := &dialerErrWrapperConn{
|
||||
Conn: &mocks.Conn{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
count, err := conn.Write(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != len(b) {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
b := make([]byte, 128)
|
||||
expectedErr := io.EOF
|
||||
conn := &dialerErrWrapperConn{
|
||||
Conn: &mocks.Conn{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
count, err := conn.Write(b)
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Close", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
conn := &dialerErrWrapperConn{
|
||||
Conn: &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
conn := &dialerErrWrapperConn{
|
||||
Conn: &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
return expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
err := conn.Close()
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
||||
)
|
||||
|
||||
|
@ -21,7 +22,7 @@ type QUICListener interface {
|
|||
// NewQUICListener creates a new QUICListener using the standard
|
||||
// library to create listening UDP sockets.
|
||||
func NewQUICListener() QUICListener {
|
||||
return &quicListenerStdlib{}
|
||||
return &quicListenerErrWrapper{&quicListenerStdlib{}}
|
||||
}
|
||||
|
||||
// quicListenerStdlib is a QUICListener using the standard library.
|
||||
|
@ -54,9 +55,10 @@ func NewQUICDialerWithResolver(listener QUICListener,
|
|||
return &quicDialerLogger{
|
||||
Dialer: &quicDialerResolver{
|
||||
Dialer: &quicDialerLogger{
|
||||
Dialer: &quicDialerQUICGo{
|
||||
Dialer: &quicDialerErrWrapper{
|
||||
QUICDialer: &quicDialerQUICGo{
|
||||
QUICListener: listener,
|
||||
},
|
||||
}},
|
||||
Logger: logger,
|
||||
operationSuffix: "_address",
|
||||
},
|
||||
|
@ -322,3 +324,78 @@ func (s *quicDialerSingleUse) DialContext(
|
|||
func (s *quicDialerSingleUse) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
// quicListenerErrWrapper is a QUICListener that wraps errors.
|
||||
type quicListenerErrWrapper struct {
|
||||
// QUICListener is the underlying listener.
|
||||
QUICListener
|
||||
}
|
||||
|
||||
var _ QUICListener = &quicListenerErrWrapper{}
|
||||
|
||||
// Listen implements QUICListener.Listen.
|
||||
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
|
||||
pconn, err := qls.QUICListener.Listen(addr)
|
||||
if err != nil {
|
||||
return nil, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.QUICListenOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return &quicErrWrapperUDPLikeConn{pconn}, nil
|
||||
}
|
||||
|
||||
// quicErrWrapperUDPLikeConn is a quicx.UDPLikeConn that wraps errors.
|
||||
type quicErrWrapperUDPLikeConn struct {
|
||||
// UDPLikeConn is the underlying conn.
|
||||
quicx.UDPLikeConn
|
||||
}
|
||||
|
||||
var _ quicx.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
|
||||
|
||||
// WriteTo implements quicx.UDPLikeConn.WriteTo.
|
||||
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
||||
if err != nil {
|
||||
return 0, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.WriteToOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ReadFrom implements quicx.UDPLikeConn.ReadFrom.
|
||||
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return 0, nil, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyGenericError(err),
|
||||
Operation: errorsx.ReadFromOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
|
||||
// quicDialerErrWrapper is a dialer that performs quic err wrapping
|
||||
type quicDialerErrWrapper struct {
|
||||
QUICDialer
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer.DialContext
|
||||
func (d *quicDialerErrWrapper) DialContext(
|
||||
ctx context.Context, network string, host string,
|
||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
sess, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyQUICHandshakeError(err),
|
||||
Operation: errorsx.QUICHandshakeOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -11,6 +12,7 @@ import (
|
|||
"github.com/apex/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
||||
)
|
||||
|
@ -452,7 +454,11 @@ func TestNewQUICDialerWithoutResolverChain(t *testing.T) {
|
|||
if dlog.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
dgo, okay := dlog.Dialer.(*quicDialerQUICGo)
|
||||
ew, okay := dlog.Dialer.(*quicDialerErrWrapper)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
dgo, okay := ew.QUICDialer.(*quicDialerQUICGo)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
|
@ -483,3 +489,188 @@ func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICListenerErrWrapper(t *testing.T) {
|
||||
t.Run("Listen", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedConn := &mocks.QUICUDPConn{}
|
||||
ql := &quicListenerErrWrapper{
|
||||
QUICListener: &mocks.QUICListener{
|
||||
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
|
||||
return expectedConn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := ql.Listen(&net.UDPAddr{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ewconn := conn.(*quicErrWrapperUDPLikeConn)
|
||||
if ewconn.UDPLikeConn != expectedConn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
ql := &quicListenerErrWrapper{
|
||||
QUICListener: &mocks.QUICListener{
|
||||
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
|
||||
return nil, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := ql.Listen(&net.UDPAddr{})
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestQUICErrWrapperUDPLikeConn(t *testing.T) {
|
||||
t.Run("ReadFrom", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedAddr := &net.UDPAddr{}
|
||||
p := make([]byte, 128)
|
||||
conn := &quicErrWrapperUDPLikeConn{
|
||||
UDPLikeConn: &mocks.QUICUDPConn{
|
||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
||||
return len(p), expectedAddr, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
count, addr, err := conn.ReadFrom(p)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != len(p) {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
if addr != expectedAddr {
|
||||
t.Fatal("unexpected addr")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
p := make([]byte, 128)
|
||||
expectedErr := io.EOF
|
||||
conn := &quicErrWrapperUDPLikeConn{
|
||||
UDPLikeConn: &mocks.QUICUDPConn{
|
||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
||||
return 0, nil, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
count, addr, err := conn.ReadFrom(p)
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
if addr != nil {
|
||||
t.Fatal("unexpected addr")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("WriteTo", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
p := make([]byte, 128)
|
||||
conn := &quicErrWrapperUDPLikeConn{
|
||||
UDPLikeConn: &mocks.QUICUDPConn{
|
||||
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
|
||||
return len(p), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
count, err := conn.WriteTo(p, &net.UDPAddr{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != len(p) {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
p := make([]byte, 128)
|
||||
expectedErr := io.EOF
|
||||
conn := &quicErrWrapperUDPLikeConn{
|
||||
UDPLikeConn: &mocks.QUICUDPConn{
|
||||
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
|
||||
return 0, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
count, err := conn.WriteTo(p, &net.UDPAddr{})
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestQUICDialerErrWrapper(t *testing.T) {
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
d := &quicDialerErrWrapper{
|
||||
QUICDialer: &mocks.QUICDialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
},
|
||||
}
|
||||
d.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedSess := &mocks.QUICEarlySession{}
|
||||
d := &quicDialerErrWrapper{
|
||||
QUICDialer: &mocks.QUICDialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
return expectedSess, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess != expectedSess {
|
||||
t.Fatal("unexpected sess")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
d := &quicDialerErrWrapper{
|
||||
QUICDialer: &mocks.QUICDialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
return nil, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
sess, err := d.DialContext(ctx, "", "", &tls.Config{}, &quic.Config{})
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("unexpected sess")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
|
@ -30,8 +31,10 @@ func NewResolverSystem(logger Logger) Resolver {
|
|||
return &resolverIDNA{
|
||||
Resolver: &resolverLogger{
|
||||
Resolver: &resolverShortCircuitIPAddr{
|
||||
Resolver: &resolverErrWrapper{
|
||||
Resolver: &resolverSystem{},
|
||||
},
|
||||
},
|
||||
Logger: logger,
|
||||
},
|
||||
}
|
||||
|
@ -182,3 +185,23 @@ func (r *nullResolver) Address() string {
|
|||
func (r *nullResolver) CloseIdleConnections() {
|
||||
// nothing
|
||||
}
|
||||
|
||||
// resolverErrWrapper is a Resolver that knows about wrapping errors.
|
||||
type resolverErrWrapper struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
var _ Resolver = &resolverErrWrapper{}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyResolverError(err),
|
||||
Operation: errorsx.ResolveOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package netxlite
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
|
||||
"github.com/apex/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
|
@ -196,7 +198,11 @@ func TestNewResolverTypeChain(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
if _, ok := scia.Resolver.(*resolverSystem); !ok {
|
||||
ew, ok := scia.Resolver.(*resolverErrWrapper)
|
||||
if !ok {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
if _, ok := ew.Resolver.(*resolverSystem); !ok {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
}
|
||||
|
@ -255,3 +261,88 @@ func TestNullResolverWorksAsIntended(t *testing.T) {
|
|||
}
|
||||
r.CloseIdleConnections() // should not crash
|
||||
}
|
||||
|
||||
func TestResolverErrWrapper(t *testing.T) {
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return expected, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
addrs, err := reso.LookupHost(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, addrs); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := io.EOF
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
addrs, err := reso.LookupHost(ctx, "")
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("unexpected addrs")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Network", func(t *testing.T) {
|
||||
expected := "foobar"
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockNetwork: func() string {
|
||||
return expected
|
||||
},
|
||||
},
|
||||
}
|
||||
if reso.Network() != expected {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Address", func(t *testing.T) {
|
||||
expected := "foobar"
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockAddress: func() string {
|
||||
return expected
|
||||
},
|
||||
},
|
||||
}
|
||||
if reso.Address() != expected {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
},
|
||||
}
|
||||
reso.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
oohttp "github.com/ooni/oohttp"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -125,7 +126,9 @@ type TLSHandshaker interface {
|
|||
// go standard library to create TLS connections.
|
||||
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
|
||||
return &tlsHandshakerLogger{
|
||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
||||
},
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
@ -319,3 +322,23 @@ var _ TLSDialer = &tlsDialerSingleUseAdapter{}
|
|||
func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
|
||||
type tlsHandshakerErrWrapper struct {
|
||||
TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements TLSHandshaker.Handshake
|
||||
func (h *tlsHandshakerErrWrapper) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, &errorsx.ErrWrapper{
|
||||
Failure: errorsx.ClassifyTLSHandshakeError(err),
|
||||
Operation: errorsx.TLSHandshakeOperation,
|
||||
WrappedErr: err,
|
||||
}
|
||||
}
|
||||
return tlsconn, state, nil
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
|
||||
"github.com/apex/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
|
@ -432,7 +433,11 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
|
|||
if thl.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
|
@ -480,3 +485,51 @@ func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerErrWrapper(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedConn := &mocks.TLSConn{}
|
||||
expectedState := tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
}
|
||||
th := &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return expectedConn, expectedState, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if expectedState.Version != state.Version {
|
||||
t.Fatal("unexpected state")
|
||||
}
|
||||
if expectedConn != conn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
th := &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return nil, tls.ConnectionState{}, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if err == nil || err.Error() != errorsx.FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
@ -13,9 +13,11 @@ import (
|
|||
// gitlab.com/yawning/utls library to create TLS conns.
|
||||
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
|
||||
return &tlsHandshakerLogger{
|
||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &tlsHandshakerConfigurable{
|
||||
NewConn: newConnUTLS(id),
|
||||
},
|
||||
},
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,11 @@ func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
|
|||
if thl.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
thc, okay := thl.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user