netxlite: code quality, improve tests, docs (#494)

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-08 22:48:10 +02:00 committed by GitHub
parent 3cd88debdc
commit 50b58672c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 318 additions and 154 deletions

View File

@ -40,7 +40,11 @@ type Dialer interface {
// //
// 4. wraps errors; // 4. wraps errors;
// //
// 5. has a configured connect timeout. // 5. has a configured connect timeout;
//
// 6. if a dialer wraps a resolver, the dialer will forward
// the CloseIdleConnection call to its resolver (which is
// instrumental to manage a DoH resolver connections properly).
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer { func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
return &dialerLogger{ return &dialerLogger{
Dialer: &dialerResolver{ Dialer: &dialerResolver{

View File

@ -24,6 +24,9 @@
// //
// We also want to mock any underlying dependency for testing. // We also want to mock any underlying dependency for testing.
// //
// We also want to map errors to OONI failures, which are described by
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
//
// Operations // Operations
// //
// This package implements the following operations: // This package implements the following operations:

View File

@ -30,7 +30,6 @@ type httpTransportLogger struct {
var _ HTTPTransport = &httpTransportLogger{} var _ HTTPTransport = &httpTransportLogger{}
// RoundTrip implements HTTPTransport.RoundTrip.
func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) { func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) {
host := req.Host host := req.Host
if host == "" { if host == "" {
@ -64,13 +63,12 @@ func (txp *httpTransportLogger) logTrip(req *http.Request) (*http.Response, erro
return resp, nil return resp, nil
} }
// CloseIdleConnections implement HTTPTransport.CloseIdleConnections.
func (txp *httpTransportLogger) CloseIdleConnections() { func (txp *httpTransportLogger) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections() txp.HTTPTransport.CloseIdleConnections()
} }
// httpTransportConnectionsCloser is an HTTPTransport that // httpTransportConnectionsCloser is an HTTPTransport that
// correctly forwards CloseIdleConnections. // correctly forwards CloseIdleConnections calls.
type httpTransportConnectionsCloser struct { type httpTransportConnectionsCloser struct {
HTTPTransport HTTPTransport
Dialer Dialer
@ -98,6 +96,16 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
// The returned transport will disable transparent decompression // The returned transport will disable transparent decompression
// of compressed response bodies (and will not automatically // of compressed response bodies (and will not automatically
// ask for such compression, though you can always do that manually). // ask for such compression, though you can always do that manually).
//
// The returned transport will configure TCP and TLS connections
// created using its dialer and TLS dialer to always have a
// read watchdog timeout to address https://github.com/ooni/probe/issues/1609.
//
// The returned transport will always enforce 1 connection per host
// and we cannot get rid of this QUIRK requirement because it is
// necessary to perform sane measurements with tracing. We will be
// able to possibly relax this requirement after we change the
// way in which we perform measurements.
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport { func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
// Using oohttp to support any TLS library. // Using oohttp to support any TLS library.
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone() txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
@ -115,8 +123,6 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTra
// Better for Cloudflare DNS and also better because we have less // Better for Cloudflare DNS and also better because we have less
// noisy events and we can better understand what happened. // noisy events and we can better understand what happened.
//
// UNDOCUMENTED: I am wondering whether we can relax this constraint.
txp.MaxConnsPerHost = 1 txp.MaxConnsPerHost = 1
// The following (1) reduces the number of headers that Go will // The following (1) reduces the number of headers that Go will
@ -175,7 +181,7 @@ func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
if err != nil { if err != nil {
return nil, err return nil, err
} }
tconn, okay := conn.(TLSConn) tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful
if !okay { if !okay {
conn.Close() // we own the conn here conn.Close() // we own the conn here
return nil, ErrNotTLSConn return nil, ErrNotTLSConn

View File

@ -33,7 +33,7 @@ func TestHTTP3Dialer(t *testing.T) {
}) })
} }
func TestHTTP3TransportClosesIdleConnections(t *testing.T) { func TestHTTP3Transport(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
var ( var (
calledHTTP3 bool calledHTTP3 bool

View File

@ -21,8 +21,17 @@ import (
func TestHTTPTransportLogger(t *testing.T) { func TestHTTPTransportLogger(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) { t.Run("with failure", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebug: func(message string) {
count++
},
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
txp := &httpTransportLogger{ txp := &httpTransportLogger{
Logger: log.Log, Logger: lo,
HTTPTransport: &mocks.HTTPTransport{ HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) { MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, io.EOF return nil, io.EOF
@ -37,6 +46,9 @@ func TestHTTPTransportLogger(t *testing.T) {
if resp != nil { if resp != nil {
t.Fatal("expected nil response here") t.Fatal("expected nil response here")
} }
if count < 1 {
t.Fatal("no logs?!")
}
}) })
t.Run("we add the host header", func(t *testing.T) { t.Run("we add the host header", func(t *testing.T) {
@ -73,8 +85,17 @@ func TestHTTPTransportLogger(t *testing.T) {
}) })
t.Run("with success", func(t *testing.T) { t.Run("with success", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebug: func(message string) {
count++
},
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
txp := &httpTransportLogger{ txp := &httpTransportLogger{
Logger: log.Log, Logger: lo,
HTTPTransport: &mocks.HTTPTransport{ HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) { MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{ return &http.Response{
@ -94,6 +115,9 @@ func TestHTTPTransportLogger(t *testing.T) {
} }
iox.ReadAllContext(context.Background(), resp.Body) iox.ReadAllContext(context.Background(), resp.Body)
resp.Body.Close() resp.Body.Close()
if count < 1 {
t.Fatal("no logs?!")
}
}) })
}) })

View File

@ -12,6 +12,26 @@ import (
utls "gitlab.com/yawning/utls.git" utls "gitlab.com/yawning/utls.git"
) )
func TestResolver(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
t.Run("works as intended", func(t *testing.T) {
// TODO(bassosimone): this is actually an integration
// test but how to test this case?
r := netxlite.NewResolverSystem(log.Log)
defer r.CloseIdleConnections()
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil result here")
}
})
}
func TestHTTPTransport(t *testing.T) { func TestHTTPTransport(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")

View File

@ -13,6 +13,7 @@ var (
DefaultDialer = &dialerSystem{} DefaultDialer = &dialerSystem{}
DefaultTLSHandshaker = defaultTLSHandshaker DefaultTLSHandshaker = defaultTLSHandshaker
NewConnUTLS = newConnUTLS NewConnUTLS = newConnUTLS
DefaultResolver = &resolverSystem{}
) )
// These types export internal names to legacy ooni/probe-cli code. // These types export internal names to legacy ooni/probe-cli code.

View File

@ -50,6 +50,24 @@ type QUICDialer interface {
// NewQUICDialerWithResolver returns a QUICDialer using the given // NewQUICDialerWithResolver returns a QUICDialer using the given
// QUICListener to create listening connections and the given Resolver // QUICListener to create listening connections and the given Resolver
// to resolve domain names (if needed). // to resolve domain names (if needed).
//
// Properties of the dialer:
//
// 1. logs events using the given logger;
//
// 2. resolves domain names using the givern resolver;
//
// 3. when using a resolver, _may_ attempt multiple dials
// in parallel (happy eyeballs) and _may_ return an aggregate
// error to the caller;
//
// 4. wraps errors;
//
// 5. has a configured connect timeout;
//
// 6. if a dialer wraps a resolver, the dialer will forward
// the CloseIdleConnection call to its resolver (which is
// instrumental to manage a DoH resolver connections properly).
func NewQUICDialerWithResolver(listener QUICListener, func NewQUICDialerWithResolver(listener QUICListener,
logger Logger, resolver Resolver) QUICDialer { logger Logger, resolver Resolver) QUICDialer {
return &quicDialerLogger{ return &quicDialerLogger{
@ -210,12 +228,9 @@ func (d *quicDialerResolver) DialContext(
return nil, err return nil, err
} }
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost) tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
// TODO(bassosimone): here we should be using multierror rather // See TODO(https://github.com/ooni/probe/issues/1779) however
// than just calling ReduceErrors. We are not ready to do that // this is less of a problem for QUIC because so far we have been
// yet, though. To do that, we need first to modify nettests so // using it to perform research only (i.e., urlgetter).
// that we actually avoid dialing when measuring.
//
// See also the quirks.go file. This is clearly a QUIRK.
addrs = quirkSortIPAddrs(addrs) addrs = quirkSortIPAddrs(addrs)
var errorslist []error var errorslist []error
for _, addr := range addrs { for _, addr := range addrs {

View File

@ -17,6 +17,34 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
) )
func TestNewQUICListener(t *testing.T) {
ql := NewQUICListener()
qew := ql.(*quicListenerErrWrapper)
_ = qew.QUICListener.(*quicListenerStdlib)
}
func TestNewQUICDialer(t *testing.T) {
ql := NewQUICListener()
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
logger := dlr.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
resolver := logger.Dialer.(*quicDialerResolver)
if _, okay := resolver.Resolver.(*nullResolver); !okay {
t.Fatal("invalid resolver type")
}
logger = resolver.Dialer.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
if base.QUICListener != ql {
t.Fatal("invalid quic listener")
}
}
func TestQUICDialerQUICGo(t *testing.T) { func TestQUICDialerQUICGo(t *testing.T) {
t.Run("DialContext", func(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
t.Run("cannot split host port", func(t *testing.T) { t.Run("cannot split host port", func(t *testing.T) {
@ -223,7 +251,6 @@ func TestQUICDialerQUICGo(t *testing.T) {
} }
func TestQUICDialerResolver(t *testing.T) { func TestQUICDialerResolver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
var ( var (
forDialer bool forDialer bool
@ -302,7 +329,7 @@ func TestQUICDialerResolver(t *testing.T) {
} }
}) })
t.Run("with invalid port", func(t *testing.T) { t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) {
// This test allows us to check for the case where every attempt // This test allows us to check for the case where every attempt
// to establish a connection leads to a failure // to establish a connection leads to a failure
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
@ -376,7 +403,6 @@ func TestQUICDialerResolver(t *testing.T) {
} }
func TestQUICLoggerDialer(t *testing.T) { func TestQUICLoggerDialer(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
var forDialer bool var forDialer bool
d := &quicDialerLogger{ d := &quicDialerLogger{
@ -394,6 +420,12 @@ func TestQUICLoggerDialer(t *testing.T) {
t.Run("DialContext", func(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) { t.Run("on success", func(t *testing.T) {
var called int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
called++
},
}
d := &quicDialerLogger{ d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{ Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network string, MockDialContext: func(ctx context.Context, network string,
@ -407,7 +439,7 @@ func TestQUICLoggerDialer(t *testing.T) {
}, nil }, nil
}, },
}, },
Logger: log.Log, Logger: lo,
} }
ctx := context.Background() ctx := context.Background()
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
@ -419,9 +451,18 @@ func TestQUICLoggerDialer(t *testing.T) {
if err := sess.CloseWithError(0, ""); err != nil { if err := sess.CloseWithError(0, ""); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if called != 2 {
t.Fatal("invalid number of calls")
}
}) })
t.Run("on failure", func(t *testing.T) { t.Run("on failure", func(t *testing.T) {
var called int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
called++
},
}
expected := errors.New("mocked error") expected := errors.New("mocked error")
d := &quicDialerLogger{ d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{ Dialer: &mocks.QUICDialer{
@ -431,7 +472,7 @@ func TestQUICLoggerDialer(t *testing.T) {
return nil, expected return nil, expected
}, },
}, },
Logger: log.Log, Logger: lo,
} }
ctx := context.Background() ctx := context.Background()
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
@ -443,32 +484,13 @@ func TestQUICLoggerDialer(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil session") t.Fatal("expected nil session")
} }
if called != 2 {
t.Fatal("invalid number of calls")
}
}) })
}) })
} }
func TestNewQUICDialer(t *testing.T) {
ql := NewQUICListener()
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
logger := dlr.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
resolver := logger.Dialer.(*quicDialerResolver)
if _, okay := resolver.Resolver.(*nullResolver); !okay {
t.Fatal("invalid resolver type")
}
logger = resolver.Dialer.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
if base.QUICListener != ql {
t.Fatal("invalid quic listener")
}
}
func TestNewSingleUseQUICDialer(t *testing.T) { func TestNewSingleUseQUICDialer(t *testing.T) {
sess := &mocks.QUICEarlySession{} sess := &mocks.QUICEarlySession{}
qd := NewSingleUseQUICDialer(sess) qd := NewSingleUseQUICDialer(sess)

View File

@ -54,26 +54,34 @@ func TestQuirkReduceErrors(t *testing.T) {
} }
func TestQuirkSortIPAddrs(t *testing.T) { func TestQuirkSortIPAddrs(t *testing.T) {
addrs := []string{ t.Run("with some addrs", func(t *testing.T) {
"::1", addrs := []string{
"192.168.1.2", "::1",
"2a00:1450:4002:404::2004", "192.168.1.2",
"142.250.184.36", "2a00:1450:4002:404::2004",
"2604:8800:5000:82:466:38ff:fecb:d46e", "142.250.184.36",
"198.145.29.83", "2604:8800:5000:82:466:38ff:fecb:d46e",
"95.216.163.36", "198.145.29.83",
} "95.216.163.36",
expected := []string{ }
"192.168.1.2", expected := []string{
"142.250.184.36", "192.168.1.2",
"198.145.29.83", "142.250.184.36",
"95.216.163.36", "198.145.29.83",
"::1", "95.216.163.36",
"2a00:1450:4002:404::2004", "::1",
"2604:8800:5000:82:466:38ff:fecb:d46e", "2a00:1450:4002:404::2004",
} "2604:8800:5000:82:466:38ff:fecb:d46e",
out := quirkSortIPAddrs(addrs) }
if diff := cmp.Diff(expected, out); diff != "" { out := quirkSortIPAddrs(addrs)
t.Fatal(diff) if diff := cmp.Diff(expected, out); diff != "" {
} t.Fatal(diff)
}
})
t.Run("with an empty list", func(t *testing.T) {
if quirkSortIPAddrs(nil) != nil {
t.Fatal("expected nil output")
}
})
} }

View File

@ -27,6 +27,20 @@ type Resolver interface {
// NewResolverSystem creates a new resolver using system // NewResolverSystem creates a new resolver using system
// facilities for resolving domain names (e.g., getaddrinfo). // facilities for resolving domain names (e.g., getaddrinfo).
//
// The resolver will provide the following guarantees:
//
// 1. handles IDNA;
//
// 2. performs logging;
//
// 3. short-circuits IP addresses like getaddrinfo does (i.e.,
// resolving "1.1.1.1" yields []string{"1.1.1.1"};
//
// 4. wraps errors;
//
// 5. enforces reasonable timeouts (
// see https://github.com/ooni/probe/issues/1726).
func NewResolverSystem(logger Logger) Resolver { func NewResolverSystem(logger Logger) Resolver {
return &resolverIDNA{ return &resolverIDNA{
Resolver: &resolverLogger{ Resolver: &resolverLogger{
@ -48,7 +62,6 @@ type resolverSystem struct {
var _ Resolver = &resolverSystem{} var _ Resolver = &resolverSystem{}
// LookupHost implements Resolver.LookupHost.
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
// This code forces adding a shorter timeout to the domain name // This code forces adding a shorter timeout to the domain name
// resolutions when using the system resolver. We have seen cases // resolutions when using the system resolver. We have seen cases
@ -89,24 +102,18 @@ func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) (
return net.DefaultResolver.LookupHost return net.DefaultResolver.LookupHost
} }
// Network implements Resolver.Network.
func (r *resolverSystem) Network() string { func (r *resolverSystem) Network() string {
return "system" return "system"
} }
// Address implements Resolver.Address.
func (r *resolverSystem) Address() string { func (r *resolverSystem) Address() string {
return "" return ""
} }
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *resolverSystem) CloseIdleConnections() { func (r *resolverSystem) CloseIdleConnections() {
// nothing // nothing to do
} }
// DefaultResolver is the resolver we use by default.
var DefaultResolver = &resolverSystem{}
// resolverLogger is a resolver that emits events // resolverLogger is a resolver that emits events
type resolverLogger struct { type resolverLogger struct {
Resolver Resolver
@ -115,7 +122,6 @@ type resolverLogger struct {
var _ Resolver = &resolverLogger{} var _ Resolver = &resolverLogger{}
// LookupHost returns the IP addresses of a host
func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
r.Logger.Debugf("resolve %s...", hostname) r.Logger.Debugf("resolve %s...", hostname)
start := time.Now() start := time.Now()
@ -136,7 +142,6 @@ type resolverIDNA struct {
Resolver Resolver
} }
// LookupHost implements Resolver.LookupHost.
func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) {
host, err := idna.ToASCII(hostname) host, err := idna.ToASCII(hostname)
if err != nil { if err != nil {
@ -151,7 +156,6 @@ type resolverShortCircuitIPAddr struct {
Resolver Resolver
} }
// LookupHost implements Resolver.LookupHost.
func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil { if net.ParseIP(hostname) != nil {
return []string{hostname}, nil return []string{hostname}, nil
@ -166,24 +170,20 @@ var ErrNoResolver = errors.New("no configured resolver")
// domain names to IP addresses and always returns ErrNoResolver. // domain names to IP addresses and always returns ErrNoResolver.
type nullResolver struct{} type nullResolver struct{}
// LookupHost implements Resolver.LookupHost.
func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) {
return nil, ErrNoResolver return nil, ErrNoResolver
} }
// Network implements Resolver.Network.
func (r *nullResolver) Network() string { func (r *nullResolver) Network() string {
return "null" return "null"
} }
// Address implements Resolver.Address.
func (r *nullResolver) Address() string { func (r *nullResolver) Address() string {
return "" return ""
} }
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *nullResolver) CloseIdleConnections() { func (r *nullResolver) CloseIdleConnections() {
// nothing // nothing to do
} }
// resolverErrWrapper is a Resolver that knows about wrapping errors. // resolverErrWrapper is a Resolver that knows about wrapping errors.
@ -193,7 +193,6 @@ type resolverErrWrapper struct {
var _ Resolver = &resolverErrWrapper{} var _ Resolver = &resolverErrWrapper{}
// LookupHost implements Resolver.LookupHost
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname) addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil { if err != nil {

View File

@ -15,6 +15,18 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestNewResolverSystem(t *testing.T) {
resolver := NewResolverSystem(log.Log)
idna := resolver.(*resolverIDNA)
logger := idna.Resolver.(*resolverLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
_ = errWrapper.Resolver.(*resolverSystem)
}
func TestResolverSystem(t *testing.T) { func TestResolverSystem(t *testing.T) {
t.Run("Network and Address", func(t *testing.T) { t.Run("Network and Address", func(t *testing.T) {
r := &resolverSystem{} r := &resolverSystem{}
@ -26,16 +38,9 @@ func TestResolverSystem(t *testing.T) {
} }
}) })
t.Run("works as intended", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
r := &resolverSystem{} r := &resolverSystem{}
defer r.CloseIdleConnections() r.CloseIdleConnections() // to cover it
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil result here")
}
}) })
t.Run("check default timeout", func(t *testing.T) { t.Run("check default timeout", func(t *testing.T) {
@ -45,7 +50,30 @@ func TestResolverSystem(t *testing.T) {
} }
}) })
t.Run("check default lookup host func not nil", func(t *testing.T) {
r := &resolverSystem{}
if r.lookupHost() == nil {
t.Fatal("expected non-nil func here")
}
})
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) {
r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("invalid addrs")
}
})
t.Run("with timeout and success", func(t *testing.T) { t.Run("with timeout and success", func(t *testing.T) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@ -111,9 +139,15 @@ func TestResolverSystem(t *testing.T) {
func TestResolverLogger(t *testing.T) { func TestResolverLogger(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) { t.Run("with success", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
expected := []string{"1.1.1.1"} expected := []string{"1.1.1.1"}
r := resolverLogger{ r := resolverLogger{
Logger: log.Log, Logger: lo,
Resolver: &mocks.Resolver{ Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil return expected, nil
@ -127,12 +161,21 @@ func TestResolverLogger(t *testing.T) {
if diff := cmp.Diff(expected, addrs); diff != "" { if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
if count != 2 {
t.Fatal("unexpected count")
}
}) })
t.Run("with failure", func(t *testing.T) { t.Run("with failure", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := resolverLogger{ r := resolverLogger{
Logger: log.Log, Logger: lo,
Resolver: &mocks.Resolver{ Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, expected return nil, expected
@ -146,6 +189,9 @@ func TestResolverLogger(t *testing.T) {
if addrs != nil { if addrs != nil {
t.Fatal("expected nil addr here") t.Fatal("expected nil addr here")
} }
if count != 2 {
t.Fatal("unexpected count")
}
}) })
}) })
} }
@ -193,18 +239,6 @@ func TestResolverIDNA(t *testing.T) {
}) })
} }
func TestNewResolverSystem(t *testing.T) {
resolver := NewResolverSystem(log.Log)
idna := resolver.(*resolverIDNA)
logger := idna.Resolver.(*resolverLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
_ = errWrapper.Resolver.(*resolverSystem)
}
func TestResolverShortCircuitIPAddr(t *testing.T) { func TestResolverShortCircuitIPAddr(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("with IP addr", func(t *testing.T) { t.Run("with IP addr", func(t *testing.T) {
@ -261,7 +295,7 @@ func TestNullResolver(t *testing.T) {
if r.Address() != "" { if r.Address() != "" {
t.Fatal("invalid address") t.Fatal("invalid address")
} }
r.CloseIdleConnections() // should not crash r.CloseIdleConnections() // for coverage
} }
func TestResolverErrWrapper(t *testing.T) { func TestResolverErrWrapper(t *testing.T) {

View File

@ -99,7 +99,7 @@ func ConfigureTLSVersion(config *tls.Config, version string) error {
config.MinVersion = tls.VersionTLS10 config.MinVersion = tls.VersionTLS10
config.MaxVersion = tls.VersionTLS10 config.MaxVersion = tls.VersionTLS10
case "": case "":
// nothing // nothing to do
default: default:
return ErrInvalidTLSVersion return ErrInvalidTLSVersion
} }
@ -119,7 +119,7 @@ type TLSHandshaker interface {
// the given config. This function DOES NOT take ownership of the connection // the given config. This function DOES NOT take ownership of the connection
// and it's your responsibility to close it on failure. // and it's your responsibility to close it on failure.
// //
// The returned connection will always implement the TLSConn interface // QUIRK: The returned connection will always implement the TLSConn interface
// exposed by this package. A future version of this interface will instead // exposed by this package. A future version of this interface will instead
// return directly a TLSConn and remove the ConnectionState param. // return directly a TLSConn and remove the ConnectionState param.
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
@ -128,10 +128,21 @@ type TLSHandshaker interface {
// NewTLSHandshakerStdlib creates a new TLS handshaker using the // NewTLSHandshakerStdlib creates a new TLS handshaker using the
// go standard library to create TLS connections. // go standard library to create TLS connections.
//
// The handshaker guarantees:
//
// 1. logging
//
// 2. error wrapping
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker { func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
return newTLSHandshaker(&tlsHandshakerConfigurable{}, logger)
}
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
func newTLSHandshaker(th TLSHandshaker, logger Logger) TLSHandshaker {
return &tlsHandshakerLogger{ return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerErrWrapper{ TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: &tlsHandshakerConfigurable{}, TLSHandshaker: th,
}, },
Logger: logger, Logger: logger,
} }
@ -191,11 +202,8 @@ var defaultTLSHandshaker = &tlsHandshakerConfigurable{}
// tlsHandshakerLogger is a TLSHandshaker with logging. // tlsHandshakerLogger is a TLSHandshaker with logging.
type tlsHandshakerLogger struct { type tlsHandshakerLogger struct {
// TLSHandshaker is the underlying handshaker. TLSHandshaker
TLSHandshaker TLSHandshaker Logger
// Logger is the underlying logger.
Logger Logger
} }
var _ TLSHandshaker = &tlsHandshakerLogger{} var _ TLSHandshaker = &tlsHandshakerLogger{}

View File

@ -118,10 +118,22 @@ func TestConfigureTLSVersion(t *testing.T) {
} }
} }
func TestNewTLSHandshakerStdlib(t *testing.T) {
th := NewTLSHandshakerStdlib(log.Log)
logger := th.(*tlsHandshakerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn != nil {
t.Fatal("expected nil NewConn")
}
}
func TestTLSHandshakerConfigurable(t *testing.T) { func TestTLSHandshakerConfigurable(t *testing.T) {
t.Run("Handshake", func(t *testing.T) { t.Run("Handshake", func(t *testing.T) {
t.Run("with error", func(t *testing.T) { t.Run("with error", func(t *testing.T) {
var times []time.Time var times []time.Time
h := &tlsHandshakerConfigurable{} h := &tlsHandshakerConfigurable{}
tcpConn := &mocks.Conn{ tcpConn := &mocks.Conn{
@ -230,13 +242,19 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
func TestTLSHandshakerLogger(t *testing.T) { func TestTLSHandshakerLogger(t *testing.T) {
t.Run("Handshake", func(t *testing.T) { t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) { t.Run("on success", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
th := &tlsHandshakerLogger{ th := &tlsHandshakerLogger{
TLSHandshaker: &mocks.TLSHandshaker{ TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return tls.Client(conn, config), tls.ConnectionState{}, nil return tls.Client(conn, config), tls.ConnectionState{}, nil
}, },
}, },
Logger: log.Log, Logger: lo,
} }
conn := &mocks.Conn{ conn := &mocks.Conn{
MockClose: func() error { MockClose: func() error {
@ -255,9 +273,18 @@ func TestTLSHandshakerLogger(t *testing.T) {
if !reflect.ValueOf(connState).IsZero() { if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here") t.Fatal("expected zero ConnectionState here")
} }
if count != 2 {
t.Fatal("invalid count")
}
}) })
t.Run("on failure", func(t *testing.T) { t.Run("on failure", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
expected := errors.New("mocked error") expected := errors.New("mocked error")
th := &tlsHandshakerLogger{ th := &tlsHandshakerLogger{
TLSHandshaker: &mocks.TLSHandshaker{ TLSHandshaker: &mocks.TLSHandshaker{
@ -265,7 +292,7 @@ func TestTLSHandshakerLogger(t *testing.T) {
return nil, tls.ConnectionState{}, expected return nil, tls.ConnectionState{}, expected
}, },
}, },
Logger: log.Log, Logger: lo,
} }
conn := &mocks.Conn{ conn := &mocks.Conn{
MockClose: func() error { MockClose: func() error {
@ -284,10 +311,29 @@ func TestTLSHandshakerLogger(t *testing.T) {
if !reflect.ValueOf(connState).IsZero() { if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here") t.Fatal("expected zero ConnectionState here")
} }
if count != 2 {
t.Fatal("invalid count")
}
}) })
}) })
} }
func TestNewTLSDialer(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
dialer := NewTLSDialer(d, th)
tlsd := dialer.(*tlsDialer)
if tlsd.Config == nil {
t.Fatal("unexpected config")
}
if tlsd.Dialer != d {
t.Fatal("unexpected dialer")
}
if tlsd.TLSHandshaker != th {
t.Fatal("invalid handshaker")
}
}
func TestTLSDialer(t *testing.T) { func TestTLSDialer(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool var called bool
@ -439,35 +485,6 @@ func TestTLSDialer(t *testing.T) {
}) })
} }
func TestNewTLSHandshakerStdlib(t *testing.T) {
th := NewTLSHandshakerStdlib(log.Log)
logger := th.(*tlsHandshakerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn != nil {
t.Fatal("expected nil NewConn")
}
}
func TestNewTLSDialer(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
dialer := NewTLSDialer(d, th)
tlsd := dialer.(*tlsDialer)
if tlsd.Config == nil {
t.Fatal("unexpected config")
}
if tlsd.Dialer != d {
t.Fatal("unexpected dialer")
}
if tlsd.TLSHandshaker != th {
t.Fatal("invalid handshaker")
}
}
func TestNewSingleUseTLSDialer(t *testing.T) { func TestNewSingleUseTLSDialer(t *testing.T) {
conn := &mocks.TLSConn{} conn := &mocks.TLSConn{}
d := NewSingleUseTLSDialer(conn) d := NewSingleUseTLSDialer(conn)

View File

@ -11,15 +11,18 @@ import (
// NewTLSHandshakerUTLS creates a new TLS handshaker using the // NewTLSHandshakerUTLS creates a new TLS handshaker using the
// gitlab.com/yawning/utls library to create TLS conns. // gitlab.com/yawning/utls library to create TLS conns.
//
// The handshaker guarantees:
//
// 1. logging
//
// 2. error wrapping
//
// Passing a nil `id` will make this function panic.
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker { func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
return &tlsHandshakerLogger{ return newTLSHandshaker(&tlsHandshakerConfigurable{
TLSHandshaker: &tlsHandshakerErrWrapper{ NewConn: newConnUTLS(id),
TLSHandshaker: &tlsHandshakerConfigurable{ }, logger)
NewConn: newConnUTLS(id),
},
},
Logger: logger,
}
} }
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection // utlsConn implements TLSConn and uses a utls UConn as its underlying connection