netxlite: improve docs, tests, and code quality (#493)

* netxlite: improve docs, tests, and code quality

* better documentation

* more strict testing of dialer (especially make sure we
document the quirk in https://github.com/ooni/probe/issues/1779
and we have tests to guarantee we don't screw up here)

* introduce NewErrWrapper factory for creating errors so we
have confidence we are creating them correctly

Part of https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-08 21:19:51 +02:00 committed by GitHub
parent e68adec9a5
commit 3cd88debdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 452 additions and 141 deletions

View File

@ -6,8 +6,10 @@ import (
) )
func TestPEMCerts(t *testing.T) { func TestPEMCerts(t *testing.T) {
pool := x509.NewCertPool() t.Run("we can successfully load the bundled certificates", func(t *testing.T) {
if !pool.AppendCertsFromPEM([]byte(pemcerts)) { pool := x509.NewCertPool()
t.Fatal("cannot load pemcerts") if !pool.AppendCertsFromPEM([]byte(pemcerts)) {
} t.Fatal("cannot load pemcerts")
}
})
} }

View File

@ -22,13 +22,25 @@ type Dialer interface {
// NewDialerWithResolver creates a new Dialer. The returned Dialer // NewDialerWithResolver creates a new Dialer. The returned Dialer
// has the following properties: // has the following properties:
// //
// 1. logs events using the given logger // 1. logs events using the given logger;
// //
// 2. resolves domain names using the givern resolver // 2. resolves domain names using the givern resolver;
// //
// 3. wraps errors // 3. when using a resolver, each available enpoint is tried
// sequentially. On error, the code will return what it believes
// to be the most representative error in the pack. Most often,
// such an error is the first one that occurred. Choosing the
// error to return using this logic is a QUIRK that we owe
// to the original implementation of netx. We cannot change
// this behavior until all the legacy code that relies on
// it has been migrated to more sane patterns.
// //
// 4. has a configured connect timeout // Removing this quirk from the codebase is documented as
// TODO(https://github.com/ooni/probe/issues/1779).
//
// 4. wraps errors;
//
// 5. has a configured connect timeout.
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer { func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
return &dialerLogger{ return &dialerLogger{
Dialer: &dialerResolver{ Dialer: &dialerResolver{
@ -51,45 +63,46 @@ func NewDialerWithoutResolver(logger Logger) Dialer {
return NewDialerWithResolver(logger, &nullResolver{}) return NewDialerWithResolver(logger, &nullResolver{})
} }
// dialerSystem dials using Go stdlib. // dialerSystem uses system facilities to perform domain name
// resolution and guarantees we have a dialer timeout.
type dialerSystem struct { type dialerSystem struct {
// timeout is the OPTIONAL timeout used for testing. // timeout is the OPTIONAL timeout used for testing.
timeout time.Duration timeout time.Duration
} }
// newUnderlyingDialer creates a new underlying dialer. var _ Dialer = &dialerSystem{}
const dialerDefaultTimeout = 15 * time.Second
func (d *dialerSystem) newUnderlyingDialer() *net.Dialer { func (d *dialerSystem) newUnderlyingDialer() *net.Dialer {
t := d.timeout t := d.timeout
if t <= 0 { if t <= 0 {
t = 15 * time.Second t = dialerDefaultTimeout
} }
return &net.Dialer{Timeout: t} return &net.Dialer{Timeout: t}
} }
// DialContext implements Dialer.DialContext.
func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.newUnderlyingDialer().DialContext(ctx, network, address) return d.newUnderlyingDialer().DialContext(ctx, network, address)
} }
// CloseIdleConnections implements Dialer.CloseIdleConnections.
func (d *dialerSystem) CloseIdleConnections() { func (d *dialerSystem) CloseIdleConnections() {
// nothing // nothing to do here
} }
// dialerResolver is a dialer that uses the configured Resolver to resolver a // dialerResolver combines dialing with domain name resolution.
// domain name to IP addresses, and the configured Dialer to connect.
type dialerResolver struct { type dialerResolver struct {
// Dialer is the underlying Dialer. Dialer
Dialer Dialer Resolver
// Resolver is the underlying Resolver.
Resolver Resolver
} }
var _ Dialer = &dialerResolver{} var _ Dialer = &dialerResolver{}
// DialContext implements Dialer.DialContext.
func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// QUIRK: this routine and the related routines in quirks.go cannot
// be changed easily until we use events tracing to measure.
//
// Reference issue: TODO(https://github.com/ooni/probe/issues/1779).
onlyhost, onlyport, err := net.SplitHostPort(address) onlyhost, onlyport, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -98,12 +111,6 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO(bassosimone): here we should be using multierror rather
// than just calling ReduceErrors. We are not ready to do that
// yet, though. To do that, we need first to modify nettests so
// that we actually avoid dialing when measuring.
//
// See also the quirks.go file. This is clearly a QUIRK.
addrs = quirkSortIPAddrs(addrs) addrs = quirkSortIPAddrs(addrs)
var errorslist []error var errorslist []error
for _, addr := range addrs { for _, addr := range addrs {
@ -117,7 +124,7 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
return nil, quirkReduceErrors(errorslist) return nil, quirkReduceErrors(errorslist)
} }
// lookupHost performs a domain name resolution. // lookupHost ensures we correctly handle IP addresses.
func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) { func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil { if net.ParseIP(hostname) != nil {
return []string{hostname}, nil return []string{hostname}, nil
@ -125,7 +132,6 @@ func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]str
return d.Resolver.LookupHost(ctx, hostname) return d.Resolver.LookupHost(ctx, hostname)
} }
// CloseIdleConnections implements Dialer.CloseIdleConnections.
func (d *dialerResolver) CloseIdleConnections() { func (d *dialerResolver) CloseIdleConnections() {
d.Dialer.CloseIdleConnections() d.Dialer.CloseIdleConnections()
d.Resolver.CloseIdleConnections() d.Resolver.CloseIdleConnections()
@ -134,10 +140,10 @@ func (d *dialerResolver) CloseIdleConnections() {
// dialerLogger is a Dialer with logging. // dialerLogger is a Dialer with logging.
type dialerLogger struct { type dialerLogger struct {
// Dialer is the underlying dialer. // Dialer is the underlying dialer.
Dialer Dialer Dialer
// Logger is the underlying logger. // Logger is the underlying logger.
Logger Logger Logger
// operationSuffix is appended to the operation name. // operationSuffix is appended to the operation name.
// //
@ -150,7 +156,6 @@ type dialerLogger struct {
var _ Dialer = &dialerLogger{} var _ Dialer = &dialerLogger{}
// DialContext implements Dialer.DialContext
func (d *dialerLogger) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *dialerLogger) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
d.Logger.Debugf("dial%s %s/%s...", d.operationSuffix, address, network) d.Logger.Debugf("dial%s %s/%s...", d.operationSuffix, address, network)
start := time.Now() start := time.Now()
@ -166,7 +171,6 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string)
return conn, nil return conn, nil
} }
// CloseIdleConnections implements Dialer.CloseIdleConnections.
func (d *dialerLogger) CloseIdleConnections() { func (d *dialerLogger) CloseIdleConnections() {
d.Dialer.CloseIdleConnections() d.Dialer.CloseIdleConnections()
} }
@ -189,7 +193,6 @@ type dialerSingleUse struct {
var _ Dialer = &dialerSingleUse{} var _ Dialer = &dialerSingleUse{}
// DialContext implements Dialer.DialContext.
func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
defer s.Unlock() defer s.Unlock()
s.Lock() s.Lock()
@ -201,79 +204,57 @@ func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr
return conn, nil return conn, nil
} }
// CloseIdleConnections closes idle connections.
func (s *dialerSingleUse) CloseIdleConnections() { func (s *dialerSingleUse) CloseIdleConnections() {
// nothing // nothing to do
} }
// TODO(bassosimone): introduce factory for creating errors and
// write tests that ensure the factory works correctly.
// dialerErrWrapper is a dialer that performs error wrapping. The connection // dialerErrWrapper is a dialer that performs error wrapping. The connection
// returned by the DialContext function will also perform error wrapping. // returned by the DialContext function will also perform error wrapping.
type dialerErrWrapper struct { type dialerErrWrapper struct {
// Dialer is the underlying dialer.
Dialer Dialer
} }
var _ Dialer = &dialerErrWrapper{} var _ Dialer = &dialerErrWrapper{}
// DialContext implements Dialer.DialContext.
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, &errorsx.ErrWrapper{ return nil, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.ConnectOperation, err)
Operation: errorsx.ConnectOperation,
WrappedErr: err,
}
} }
return &dialerErrWrapperConn{Conn: conn}, nil return &dialerErrWrapperConn{Conn: conn}, nil
} }
// dialerErrWrapperConn is a net.Conn that performs error wrapping. // dialerErrWrapperConn is a net.Conn that performs error wrapping.
type dialerErrWrapperConn struct { type dialerErrWrapperConn struct {
// Conn is the underlying connection.
net.Conn net.Conn
} }
var _ net.Conn = &dialerErrWrapperConn{} var _ net.Conn = &dialerErrWrapperConn{}
// Read implements net.Conn.Read.
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) { func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b) count, err := c.Conn.Read(b)
if err != nil { if err != nil {
return 0, &errorsx.ErrWrapper{ return 0, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.ReadOperation, err)
Operation: errorsx.ReadOperation,
WrappedErr: err,
}
} }
return count, nil return count, nil
} }
// Write implements net.Conn.Write.
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) { func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
count, err := c.Conn.Write(b) count, err := c.Conn.Write(b)
if err != nil { if err != nil {
return 0, &errorsx.ErrWrapper{ return 0, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.WriteOperation, err)
Operation: errorsx.WriteOperation,
WrappedErr: err,
}
} }
return count, nil return count, nil
} }
// Close implements net.Conn.Close.
func (c *dialerErrWrapperConn) Close() error { func (c *dialerErrWrapperConn) Close() error {
err := c.Conn.Close() err := c.Conn.Close()
if err != nil { if err != nil {
return &errorsx.ErrWrapper{ return errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.CloseOperation, err)
Operation: errorsx.CloseOperation,
WrappedErr: err,
}
} }
return nil return nil
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -16,8 +17,8 @@ import (
func TestNewDialer(t *testing.T) { func TestNewDialer(t *testing.T) {
t.Run("produces a chain with the expected types", func(t *testing.T) { t.Run("produces a chain with the expected types", func(t *testing.T) {
dlr := NewDialerWithoutResolver(log.Log) d := NewDialerWithoutResolver(log.Log)
logger := dlr.(*dialerLogger) logger := d.(*dialerLogger)
if logger.Logger != log.Log { if logger.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
@ -35,54 +36,76 @@ func TestNewDialer(t *testing.T) {
} }
func TestDialerSystem(t *testing.T) { func TestDialerSystem(t *testing.T) {
t.Run("has a default timeout of 15 seconds", func(t *testing.T) { t.Run("has a default timeout", func(t *testing.T) {
d := &dialerSystem{} d := &dialerSystem{}
ud := d.newUnderlyingDialer() ud := d.newUnderlyingDialer()
if ud.Timeout != 15*time.Second { if ud.Timeout != dialerDefaultTimeout {
t.Fatal("invalid default timeout") t.Fatal("unexpected default timeout")
} }
}) })
t.Run("we can change the default timeout for testing", func(t *testing.T) { t.Run("we can change the timeout for testing", func(t *testing.T) {
d := &dialerSystem{timeout: 1 * time.Second} const smaller = 1 * time.Second
d := &dialerSystem{timeout: smaller}
ud := d.newUnderlyingDialer() ud := d.newUnderlyingDialer()
if ud.Timeout != 1*time.Second { if ud.Timeout != smaller {
t.Fatal("invalid default timeout") t.Fatal("unexpected timeout")
} }
}) })
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
d := &dialerSystem{} d := &dialerSystem{}
d.CloseIdleConnections() // should not crash d.CloseIdleConnections() // to avoid missing coverage
}) })
t.Run("DialContext with canceled context", func(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
d := &dialerSystem{} t.Run("with canceled context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) d := &dialerSystem{}
cancel() // immediately! ctx, cancel := context.WithCancel(context.Background())
conn, err := d.DialContext(ctx, "tcp", "dns.google:443") cancel() // immediately!
if err == nil || err.Error() != "dial tcp: operation was canceled" { conn, err := d.DialContext(ctx, "tcp", "dns.google:443")
t.Fatal("unexpected err", err) if err == nil || err.Error() != "dial tcp: operation was canceled" {
} t.Fatal("unexpected err", err)
if conn != nil { }
t.Fatal("unexpected conn") if conn != nil {
} t.Fatal("unexpected conn")
}
})
t.Run("enforces the configured timeout", func(t *testing.T) {
const timeout = 1 * time.Millisecond
d := &dialerSystem{timeout: timeout}
ctx := context.Background()
start := time.Now()
conn, err := d.DialContext(ctx, "tcp", "dns.google:443")
stop := time.Now()
if err == nil || !strings.HasSuffix(err.Error(), "i/o timeout") {
t.Fatal(err)
}
if conn != nil {
t.Fatal("unexpected conn")
}
if stop.Sub(start) > 100*time.Millisecond {
t.Fatal("undable to enforce timeout")
}
})
}) })
} }
func TestDialerResolver(t *testing.T) { func TestDialerResolver(t *testing.T) {
t.Run("DialContext", func(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
t.Run("without a port", func(t *testing.T) { t.Run("fails without a port", func(t *testing.T) {
d := &dialerResolver{ d := &dialerResolver{
Dialer: &dialerSystem{}, Dialer: &dialerSystem{},
Resolver: &resolverSystem{}, Resolver: &resolverSystem{},
} }
conn, err := d.DialContext(context.Background(), "tcp", "ooni.nu") const missingPort = "ooni.nu"
conn, err := d.DialContext(context.Background(), "tcp", missingPort)
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("not the error we expected", err) t.Fatal("unexpected err", err)
} }
if conn != nil { if conn != nil {
t.Fatal("expected a nil conn here") t.Fatal("unexpected conn")
} }
}) })
@ -111,9 +134,13 @@ func TestDialerResolver(t *testing.T) {
return nil, io.EOF return nil, io.EOF
}, },
}, },
Resolver: &nullResolver{}, Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"1.1.1.1", "8.8.8.8"}, nil
},
},
} }
conn, err := d.DialContext(context.Background(), "tcp", "1.1.1.1:853") conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
@ -123,16 +150,18 @@ func TestDialerResolver(t *testing.T) {
}) })
t.Run("handles dialing success correctly for many IP addresses", func(t *testing.T) { t.Run("handles dialing success correctly for many IP addresses", func(t *testing.T) {
expectedConn := &mocks.Conn{
MockClose: func() error {
return nil
},
}
d := &dialerResolver{ d := &dialerResolver{
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return &mocks.Conn{ return expectedConn, nil
MockClose: func() error {
return nil
},
}, nil
}, },
}, Resolver: &mocks.Resolver{ },
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"1.1.1.1", "8.8.8.8"}, nil return []string{"1.1.1.1", "8.8.8.8"}, nil
}, },
@ -142,11 +171,166 @@ func TestDialerResolver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if conn == nil { if conn != expectedConn {
t.Fatal("expected non-nil conn") t.Fatal("unexpected conn")
} }
conn.Close() conn.Close()
}) })
t.Run("calls the underlying dialer sequentially", func(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
// this implementation. When we have parallelism greater
// than one, this code will lock forever and we'll see
// a failed test and see we broke the QUIRK.
defer mu.Unlock()
mu.Lock()
return nil, io.EOF
},
},
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"1.1.1.1", "8.8.8.8"}, nil
},
},
}
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("attempts with IPv4 addresses before IPv6 addresses", func(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
var attempts []string
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
// this implementation. When we have parallelism greater
// than one, this code will lock forever and we'll see
// a failed test and see we broke the QUIRK.
defer mu.Unlock()
attempts = append(attempts, address)
mu.Lock()
return nil, io.EOF
},
},
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"2001:4860:4860::8888", "8.8.8.8"}, nil
},
},
}
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
mu.Lock()
asExpected := (attempts[0] == "8.8.8.8:853" &&
attempts[1] == "[2001:4860:4860::8888]:853")
mu.Unlock()
if !asExpected {
t.Fatal("addresses not reordered")
}
})
t.Run("returns the first meaningful error if there is one", func(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
errorsList := []error{
errors.New("a mocked error"),
errorsx.NewErrWrapper(
errorsx.ClassifyGenericError,
errorsx.CloseOperation,
io.EOF,
),
}
var errorIdx int
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
// this implementation. When we have parallelism greater
// than one, this code will lock forever and we'll see
// a failed test and see we broke the QUIRK.
defer mu.Unlock()
err := errorsList[errorIdx]
errorIdx++
mu.Lock()
return nil, err
},
},
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"2001:4860:4860::8888", "8.8.8.8"}, nil
},
},
}
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if err == nil || err.Error() != errorsx.FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("though ignores the unknown failures", func(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
errorsList := []error{
errors.New("a mocked error"),
errorsx.NewErrWrapper(
errorsx.ClassifyGenericError,
errorsx.CloseOperation,
errors.New("antani"),
),
}
var errorIdx int
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
// this implementation. When we have parallelism greater
// than one, this code will lock forever and we'll see
// a failed test and see we broke the QUIRK.
defer mu.Unlock()
err := errorsList[errorIdx]
errorIdx++
mu.Lock()
return nil, err
},
},
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"2001:4860:4860::8888", "8.8.8.8"}, nil
},
},
}
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if err == nil || err.Error() != "a mocked error" {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
}) })
t.Run("lookupHost", func(t *testing.T) { t.Run("lookupHost", func(t *testing.T) {
@ -207,6 +391,12 @@ func TestDialerResolver(t *testing.T) {
func TestDialerLogger(t *testing.T) { func TestDialerLogger(t *testing.T) {
t.Run("DialContext", func(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
t.Run("handles success correctly", func(t *testing.T) { t.Run("handles success correctly", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
d := &dialerLogger{ d := &dialerLogger{
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
@ -217,7 +407,7 @@ func TestDialerLogger(t *testing.T) {
}, nil }, nil
}, },
}, },
Logger: log.Log, Logger: lo,
} }
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if err != nil { if err != nil {
@ -227,16 +417,25 @@ func TestDialerLogger(t *testing.T) {
t.Fatal("expected non-nil conn here") t.Fatal("expected non-nil conn here")
} }
conn.Close() conn.Close()
if count != 2 {
t.Fatal("not enough log calls")
}
}) })
t.Run("handles failure correctly", func(t *testing.T) { t.Run("handles failure correctly", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
d := &dialerLogger{ d := &dialerLogger{
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF return nil, io.EOF
}, },
}, },
Logger: log.Log, Logger: lo,
} }
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
@ -245,6 +444,9 @@ func TestDialerLogger(t *testing.T) {
if conn != nil { if conn != nil {
t.Fatal("expected nil conn here") t.Fatal("expected nil conn here")
} }
if count != 2 {
t.Fatal("not enough log calls")
}
}) })
}) })
@ -290,7 +492,7 @@ func TestDialerSingleUse(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
d := &dialerSingleUse{} d := &dialerSingleUse{}
d.CloseIdleConnections() // does not crash d.CloseIdleConnections() // to have the coverage
}) })
} }
@ -472,5 +674,5 @@ func TestNewNullDialer(t *testing.T) {
if conn != nil { if conn != nil {
t.Fatal("expected nil conn") t.Fatal("expected nil conn")
} }
dialer.CloseIdleConnections() // does not crash dialer.CloseIdleConnections() // to have coverage
} }

View File

@ -67,3 +67,28 @@ func (e *ErrWrapper) Unwrap() error {
func (e *ErrWrapper) MarshalJSON() ([]byte, error) { func (e *ErrWrapper) MarshalJSON() ([]byte, error) {
return json.Marshal(e.Failure) return json.Marshal(e.Failure)
} }
// Classifier is the type of function that performs classification.
type Classifier func(err error) string
// NewErrWrapper creates a new ErrWrapper using the given
// classifier, operation name, and underlying error.
//
// This function panics if classifier is nil, or operation
// is the empty string or error is nil.
func NewErrWrapper(c Classifier, op string, err error) *ErrWrapper {
if c == nil {
panic("nil classifier")
}
if op == "" {
panic("empty op")
}
if err == nil {
panic("nil err")
}
return &ErrWrapper{
Failure: c(err),
Operation: op,
WrappedErr: err,
}
}

View File

@ -5,6 +5,8 @@ import (
"errors" "errors"
"io" "io"
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/atomicx"
) )
func TestErrWrapper(t *testing.T) { func TestErrWrapper(t *testing.T) {
@ -40,3 +42,63 @@ func TestErrWrapper(t *testing.T) {
} }
}) })
} }
func TestNewErrWrapper(t *testing.T) {
t.Run("panics if the classifier is nil", func(t *testing.T) {
recovered := &atomicx.Int64{}
func() {
defer func() {
if recover() != nil {
recovered.Add(1)
}
}()
NewErrWrapper(nil, CloseOperation, io.EOF)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
}
})
t.Run("panics if the operation is empty", func(t *testing.T) {
recovered := &atomicx.Int64{}
func() {
defer func() {
if recover() != nil {
recovered.Add(1)
}
}()
NewErrWrapper(ClassifyGenericError, "", io.EOF)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
}
})
t.Run("panics if the error is nil", func(t *testing.T) {
recovered := &atomicx.Int64{}
func() {
defer func() {
if recover() != nil {
recovered.Add(1)
}
}()
NewErrWrapper(ClassifyGenericError, CloseOperation, nil)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
}
})
t.Run("otherwise, works as intended", func(t *testing.T) {
ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF)
if ew.Failure != FailureEOFError {
t.Fatal("unexpected failure")
}
if ew.Operation != CloseOperation {
t.Fatal("unexpected operation")
}
if ew.WrappedErr != io.EOF {
t.Fatal("unexpected WrappedErr")
}
})
}

View File

@ -0,0 +1,18 @@
package mocks
// Logger allows mocking a logger.
type Logger struct {
MockDebug func(message string)
MockDebugf func(format string, v ...interface{})
}
// Debug calls MockDebug.
func (lo *Logger) Debug(message string) {
lo.MockDebug(message)
}
// Debugf calls MockDebugf.
func (lo *Logger) Debugf(format string, v ...interface{}) {
lo.MockDebugf(format, v...)
}

View File

@ -0,0 +1,31 @@
package mocks
import "testing"
func TestLogger(t *testing.T) {
t.Run("Debug", func(t *testing.T) {
var called bool
lo := &Logger{
MockDebug: func(message string) {
called = true
},
}
lo.Debug("antani")
if !called {
t.Fatal("not called")
}
})
t.Run("Debugf", func(t *testing.T) {
var called bool
lo := &Logger{
MockDebugf: func(message string, v ...interface{}) {
called = true
},
}
lo.Debugf("antani", 1, 2, 3, 4)
if !called {
t.Fatal("not called")
}
})
}

View File

@ -337,11 +337,8 @@ var _ QUICListener = &quicListenerErrWrapper{}
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (quicx.UDPLikeConn, error) { func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr) pconn, err := qls.QUICListener.Listen(addr)
if err != nil { if err != nil {
return nil, &errorsx.ErrWrapper{ return nil, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.QUICListenOperation, err)
Operation: errorsx.QUICListenOperation,
WrappedErr: err,
}
} }
return &quicErrWrapperUDPLikeConn{pconn}, nil return &quicErrWrapperUDPLikeConn{pconn}, nil
} }
@ -358,11 +355,8 @@ var _ quicx.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) { func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
count, err := c.UDPLikeConn.WriteTo(p, addr) count, err := c.UDPLikeConn.WriteTo(p, addr)
if err != nil { if err != nil {
return 0, &errorsx.ErrWrapper{ return 0, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.WriteToOperation, err)
Operation: errorsx.WriteToOperation,
WrappedErr: err,
}
} }
return count, nil return count, nil
} }
@ -371,11 +365,8 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPLikeConn.ReadFrom(b) n, addr, err := c.UDPLikeConn.ReadFrom(b)
if err != nil { if err != nil {
return 0, nil, &errorsx.ErrWrapper{ return 0, nil, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyGenericError(err), errorsx.ClassifyGenericError, errorsx.ReadFromOperation, err)
Operation: errorsx.ReadFromOperation,
WrappedErr: err,
}
} }
return n, addr, nil return n, addr, nil
} }
@ -391,11 +382,8 @@ func (d *quicDialerErrWrapper) DialContext(
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
sess, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg) sess, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
if err != nil { if err != nil {
return nil, &errorsx.ErrWrapper{ return nil, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyQUICHandshakeError(err), errorsx.ClassifyQUICHandshakeError, errorsx.QUICHandshakeOperation, err)
Operation: errorsx.QUICHandshakeOperation,
WrappedErr: err,
}
} }
return sess, nil return sess, nil
} }

View File

@ -29,6 +29,8 @@ import (
// //
// This is CLEARLY a QUIRK anyway. There may code depending on how // This is CLEARLY a QUIRK anyway. There may code depending on how
// we do things here and it's tricky to remove this behavior. // we do things here and it's tricky to remove this behavior.
//
// See TODO(https://github.com/ooni/probe/issues/1779).
func quirkReduceErrors(errorslist []error) error { func quirkReduceErrors(errorslist []error) error {
if len(errorslist) == 0 { if len(errorslist) == 0 {
return nil return nil
@ -49,6 +51,8 @@ func quirkReduceErrors(errorslist []error) error {
// //
// It saddens me to have this quirk, but it is here to pair // It saddens me to have this quirk, but it is here to pair
// with quirkReduceErrors, which assumes that <facepalm>. // with quirkReduceErrors, which assumes that <facepalm>.
//
// See TODO(https://github.com/ooni/probe/issues/1779).
func quirkSortIPAddrs(addrs []string) (out []string) { func quirkSortIPAddrs(addrs []string) (out []string) {
isIPv6 := func(x string) bool { isIPv6 := func(x string) bool {
// This check for identifying IPv6 is discussed // This check for identifying IPv6 is discussed

View File

@ -35,12 +35,16 @@ func TestQuirkReduceErrors(t *testing.T) {
t.Run("multiple errors with meaningful ones", func(t *testing.T) { t.Run("multiple errors with meaningful ones", func(t *testing.T) {
err1 := errors.New("mocked error #1") err1 := errors.New("mocked error #1")
err2 := &errorsx.ErrWrapper{ err2 := errorsx.NewErrWrapper(
Failure: "unknown_failure: antani", errorsx.ClassifyGenericError,
} errorsx.CloseOperation,
err3 := &errorsx.ErrWrapper{ errors.New("antani"),
Failure: errorsx.FailureConnectionRefused, )
} err3 := errorsx.NewErrWrapper(
errorsx.ClassifyGenericError,
errorsx.CloseOperation,
errorsx.ECONNREFUSED,
)
err4 := errors.New("mocked error #3") err4 := errors.New("mocked error #3")
result := quirkReduceErrors([]error{err1, err2, err3, err4}) result := quirkReduceErrors([]error{err1, err2, err3, err4})
if result.Error() != errorsx.FailureConnectionRefused { if result.Error() != errorsx.FailureConnectionRefused {

View File

@ -197,11 +197,8 @@ var _ Resolver = &resolverErrWrapper{}
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname) addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil { if err != nil {
return nil, &errorsx.ErrWrapper{ return nil, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyResolverError(err), errorsx.ClassifyResolverError, errorsx.ResolveOperation, err)
Operation: errorsx.ResolveOperation,
WrappedErr: err,
}
} }
return addrs, nil return addrs, nil
} }

View File

@ -338,11 +338,8 @@ func (h *tlsHandshakerErrWrapper) Handshake(
) (net.Conn, tls.ConnectionState, error) { ) (net.Conn, tls.ConnectionState, error) {
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil { if err != nil {
return nil, tls.ConnectionState{}, &errorsx.ErrWrapper{ return nil, tls.ConnectionState{}, errorsx.NewErrWrapper(
Failure: errorsx.ClassifyTLSHandshakeError(err), errorsx.ClassifyTLSHandshakeError, errorsx.TLSHandshakeOperation, err)
Operation: errorsx.TLSHandshakeOperation,
WrappedErr: err,
}
} }
return tlsconn, state, nil return tlsconn, state, nil
} }