netxlite: code quality, improve tests, docs (#494)

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-08 22:48:10 +02:00 committed by GitHub
parent 3cd88debdc
commit 50b58672c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 318 additions and 154 deletions

View File

@ -40,7 +40,11 @@ type Dialer interface {
//
// 4. wraps errors;
//
// 5. has a configured connect timeout.
// 5. has a configured connect timeout;
//
// 6. if a dialer wraps a resolver, the dialer will forward
// the CloseIdleConnection call to its resolver (which is
// instrumental to manage a DoH resolver connections properly).
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
return &dialerLogger{
Dialer: &dialerResolver{

View File

@ -24,6 +24,9 @@
//
// We also want to mock any underlying dependency for testing.
//
// We also want to map errors to OONI failures, which are described by
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
//
// Operations
//
// This package implements the following operations:

View File

@ -30,7 +30,6 @@ type httpTransportLogger struct {
var _ HTTPTransport = &httpTransportLogger{}
// RoundTrip implements HTTPTransport.RoundTrip.
func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) {
host := req.Host
if host == "" {
@ -64,13 +63,12 @@ func (txp *httpTransportLogger) logTrip(req *http.Request) (*http.Response, erro
return resp, nil
}
// CloseIdleConnections implement HTTPTransport.CloseIdleConnections.
func (txp *httpTransportLogger) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
}
// httpTransportConnectionsCloser is an HTTPTransport that
// correctly forwards CloseIdleConnections.
// correctly forwards CloseIdleConnections calls.
type httpTransportConnectionsCloser struct {
HTTPTransport
Dialer
@ -98,6 +96,16 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
// The returned transport will disable transparent decompression
// of compressed response bodies (and will not automatically
// ask for such compression, though you can always do that manually).
//
// The returned transport will configure TCP and TLS connections
// created using its dialer and TLS dialer to always have a
// read watchdog timeout to address https://github.com/ooni/probe/issues/1609.
//
// The returned transport will always enforce 1 connection per host
// and we cannot get rid of this QUIRK requirement because it is
// necessary to perform sane measurements with tracing. We will be
// able to possibly relax this requirement after we change the
// way in which we perform measurements.
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
// Using oohttp to support any TLS library.
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
@ -115,8 +123,6 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTra
// Better for Cloudflare DNS and also better because we have less
// noisy events and we can better understand what happened.
//
// UNDOCUMENTED: I am wondering whether we can relax this constraint.
txp.MaxConnsPerHost = 1
// The following (1) reduces the number of headers that Go will
@ -175,7 +181,7 @@ func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
if err != nil {
return nil, err
}
tconn, okay := conn.(TLSConn)
tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful
if !okay {
conn.Close() // we own the conn here
return nil, ErrNotTLSConn

View File

@ -33,7 +33,7 @@ func TestHTTP3Dialer(t *testing.T) {
})
}
func TestHTTP3TransportClosesIdleConnections(t *testing.T) {
func TestHTTP3Transport(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var (
calledHTTP3 bool

View File

@ -21,8 +21,17 @@ import (
func TestHTTPTransportLogger(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebug: func(message string) {
count++
},
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
txp := &httpTransportLogger{
Logger: log.Log,
Logger: lo,
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, io.EOF
@ -37,6 +46,9 @@ func TestHTTPTransportLogger(t *testing.T) {
if resp != nil {
t.Fatal("expected nil response here")
}
if count < 1 {
t.Fatal("no logs?!")
}
})
t.Run("we add the host header", func(t *testing.T) {
@ -73,8 +85,17 @@ func TestHTTPTransportLogger(t *testing.T) {
})
t.Run("with success", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebug: func(message string) {
count++
},
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
txp := &httpTransportLogger{
Logger: log.Log,
Logger: lo,
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{
@ -94,6 +115,9 @@ func TestHTTPTransportLogger(t *testing.T) {
}
iox.ReadAllContext(context.Background(), resp.Body)
resp.Body.Close()
if count < 1 {
t.Fatal("no logs?!")
}
})
})

View File

@ -12,6 +12,26 @@ import (
utls "gitlab.com/yawning/utls.git"
)
func TestResolver(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
t.Run("works as intended", func(t *testing.T) {
// TODO(bassosimone): this is actually an integration
// test but how to test this case?
r := netxlite.NewResolverSystem(log.Log)
defer r.CloseIdleConnections()
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil result here")
}
})
}
func TestHTTPTransport(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")

View File

@ -13,6 +13,7 @@ var (
DefaultDialer = &dialerSystem{}
DefaultTLSHandshaker = defaultTLSHandshaker
NewConnUTLS = newConnUTLS
DefaultResolver = &resolverSystem{}
)
// These types export internal names to legacy ooni/probe-cli code.

View File

@ -50,6 +50,24 @@ type QUICDialer interface {
// NewQUICDialerWithResolver returns a QUICDialer using the given
// QUICListener to create listening connections and the given Resolver
// to resolve domain names (if needed).
//
// Properties of the dialer:
//
// 1. logs events using the given logger;
//
// 2. resolves domain names using the givern resolver;
//
// 3. when using a resolver, _may_ attempt multiple dials
// in parallel (happy eyeballs) and _may_ return an aggregate
// error to the caller;
//
// 4. wraps errors;
//
// 5. has a configured connect timeout;
//
// 6. if a dialer wraps a resolver, the dialer will forward
// the CloseIdleConnection call to its resolver (which is
// instrumental to manage a DoH resolver connections properly).
func NewQUICDialerWithResolver(listener QUICListener,
logger Logger, resolver Resolver) QUICDialer {
return &quicDialerLogger{
@ -210,12 +228,9 @@ func (d *quicDialerResolver) DialContext(
return nil, err
}
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
// TODO(bassosimone): here we should be using multierror rather
// than just calling ReduceErrors. We are not ready to do that
// yet, though. To do that, we need first to modify nettests so
// that we actually avoid dialing when measuring.
//
// See also the quirks.go file. This is clearly a QUIRK.
// See TODO(https://github.com/ooni/probe/issues/1779) however
// this is less of a problem for QUIC because so far we have been
// using it to perform research only (i.e., urlgetter).
addrs = quirkSortIPAddrs(addrs)
var errorslist []error
for _, addr := range addrs {

View File

@ -17,6 +17,34 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
)
func TestNewQUICListener(t *testing.T) {
ql := NewQUICListener()
qew := ql.(*quicListenerErrWrapper)
_ = qew.QUICListener.(*quicListenerStdlib)
}
func TestNewQUICDialer(t *testing.T) {
ql := NewQUICListener()
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
logger := dlr.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
resolver := logger.Dialer.(*quicDialerResolver)
if _, okay := resolver.Resolver.(*nullResolver); !okay {
t.Fatal("invalid resolver type")
}
logger = resolver.Dialer.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
if base.QUICListener != ql {
t.Fatal("invalid quic listener")
}
}
func TestQUICDialerQUICGo(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("cannot split host port", func(t *testing.T) {
@ -223,7 +251,6 @@ func TestQUICDialerQUICGo(t *testing.T) {
}
func TestQUICDialerResolver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var (
forDialer bool
@ -302,7 +329,7 @@ func TestQUICDialerResolver(t *testing.T) {
}
})
t.Run("with invalid port", func(t *testing.T) {
t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) {
// This test allows us to check for the case where every attempt
// to establish a connection leads to a failure
tlsConf := &tls.Config{}
@ -376,7 +403,6 @@ func TestQUICDialerResolver(t *testing.T) {
}
func TestQUICLoggerDialer(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var forDialer bool
d := &quicDialerLogger{
@ -394,6 +420,12 @@ func TestQUICLoggerDialer(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var called int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
called++
},
}
d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network string,
@ -407,7 +439,7 @@ func TestQUICLoggerDialer(t *testing.T) {
}, nil
},
},
Logger: log.Log,
Logger: lo,
}
ctx := context.Background()
tlsConfig := &tls.Config{}
@ -419,9 +451,18 @@ func TestQUICLoggerDialer(t *testing.T) {
if err := sess.CloseWithError(0, ""); err != nil {
t.Fatal(err)
}
if called != 2 {
t.Fatal("invalid number of calls")
}
})
t.Run("on failure", func(t *testing.T) {
var called int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
called++
},
}
expected := errors.New("mocked error")
d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{
@ -431,7 +472,7 @@ func TestQUICLoggerDialer(t *testing.T) {
return nil, expected
},
},
Logger: log.Log,
Logger: lo,
}
ctx := context.Background()
tlsConfig := &tls.Config{}
@ -443,32 +484,13 @@ func TestQUICLoggerDialer(t *testing.T) {
if sess != nil {
t.Fatal("expected nil session")
}
if called != 2 {
t.Fatal("invalid number of calls")
}
})
})
}
func TestNewQUICDialer(t *testing.T) {
ql := NewQUICListener()
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
logger := dlr.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
resolver := logger.Dialer.(*quicDialerResolver)
if _, okay := resolver.Resolver.(*nullResolver); !okay {
t.Fatal("invalid resolver type")
}
logger = resolver.Dialer.(*quicDialerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
if base.QUICListener != ql {
t.Fatal("invalid quic listener")
}
}
func TestNewSingleUseQUICDialer(t *testing.T) {
sess := &mocks.QUICEarlySession{}
qd := NewSingleUseQUICDialer(sess)

View File

@ -54,26 +54,34 @@ func TestQuirkReduceErrors(t *testing.T) {
}
func TestQuirkSortIPAddrs(t *testing.T) {
addrs := []string{
"::1",
"192.168.1.2",
"2a00:1450:4002:404::2004",
"142.250.184.36",
"2604:8800:5000:82:466:38ff:fecb:d46e",
"198.145.29.83",
"95.216.163.36",
}
expected := []string{
"192.168.1.2",
"142.250.184.36",
"198.145.29.83",
"95.216.163.36",
"::1",
"2a00:1450:4002:404::2004",
"2604:8800:5000:82:466:38ff:fecb:d46e",
}
out := quirkSortIPAddrs(addrs)
if diff := cmp.Diff(expected, out); diff != "" {
t.Fatal(diff)
}
t.Run("with some addrs", func(t *testing.T) {
addrs := []string{
"::1",
"192.168.1.2",
"2a00:1450:4002:404::2004",
"142.250.184.36",
"2604:8800:5000:82:466:38ff:fecb:d46e",
"198.145.29.83",
"95.216.163.36",
}
expected := []string{
"192.168.1.2",
"142.250.184.36",
"198.145.29.83",
"95.216.163.36",
"::1",
"2a00:1450:4002:404::2004",
"2604:8800:5000:82:466:38ff:fecb:d46e",
}
out := quirkSortIPAddrs(addrs)
if diff := cmp.Diff(expected, out); diff != "" {
t.Fatal(diff)
}
})
t.Run("with an empty list", func(t *testing.T) {
if quirkSortIPAddrs(nil) != nil {
t.Fatal("expected nil output")
}
})
}

View File

@ -27,6 +27,20 @@ type Resolver interface {
// NewResolverSystem creates a new resolver using system
// facilities for resolving domain names (e.g., getaddrinfo).
//
// The resolver will provide the following guarantees:
//
// 1. handles IDNA;
//
// 2. performs logging;
//
// 3. short-circuits IP addresses like getaddrinfo does (i.e.,
// resolving "1.1.1.1" yields []string{"1.1.1.1"};
//
// 4. wraps errors;
//
// 5. enforces reasonable timeouts (
// see https://github.com/ooni/probe/issues/1726).
func NewResolverSystem(logger Logger) Resolver {
return &resolverIDNA{
Resolver: &resolverLogger{
@ -48,7 +62,6 @@ type resolverSystem struct {
var _ Resolver = &resolverSystem{}
// LookupHost implements Resolver.LookupHost.
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
// This code forces adding a shorter timeout to the domain name
// resolutions when using the system resolver. We have seen cases
@ -89,24 +102,18 @@ func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) (
return net.DefaultResolver.LookupHost
}
// Network implements Resolver.Network.
func (r *resolverSystem) Network() string {
return "system"
}
// Address implements Resolver.Address.
func (r *resolverSystem) Address() string {
return ""
}
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *resolverSystem) CloseIdleConnections() {
// nothing
// nothing to do
}
// DefaultResolver is the resolver we use by default.
var DefaultResolver = &resolverSystem{}
// resolverLogger is a resolver that emits events
type resolverLogger struct {
Resolver
@ -115,7 +122,6 @@ type resolverLogger struct {
var _ Resolver = &resolverLogger{}
// LookupHost returns the IP addresses of a host
func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
r.Logger.Debugf("resolve %s...", hostname)
start := time.Now()
@ -136,7 +142,6 @@ type resolverIDNA struct {
Resolver
}
// LookupHost implements Resolver.LookupHost.
func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) {
host, err := idna.ToASCII(hostname)
if err != nil {
@ -151,7 +156,6 @@ type resolverShortCircuitIPAddr struct {
Resolver
}
// LookupHost implements Resolver.LookupHost.
func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
@ -166,24 +170,20 @@ var ErrNoResolver = errors.New("no configured resolver")
// domain names to IP addresses and always returns ErrNoResolver.
type nullResolver struct{}
// LookupHost implements Resolver.LookupHost.
func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) {
return nil, ErrNoResolver
}
// Network implements Resolver.Network.
func (r *nullResolver) Network() string {
return "null"
}
// Address implements Resolver.Address.
func (r *nullResolver) Address() string {
return ""
}
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *nullResolver) CloseIdleConnections() {
// nothing
// nothing to do
}
// resolverErrWrapper is a Resolver that knows about wrapping errors.
@ -193,7 +193,6 @@ type resolverErrWrapper struct {
var _ Resolver = &resolverErrWrapper{}
// LookupHost implements Resolver.LookupHost
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {

View File

@ -15,6 +15,18 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
func TestNewResolverSystem(t *testing.T) {
resolver := NewResolverSystem(log.Log)
idna := resolver.(*resolverIDNA)
logger := idna.Resolver.(*resolverLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
_ = errWrapper.Resolver.(*resolverSystem)
}
func TestResolverSystem(t *testing.T) {
t.Run("Network and Address", func(t *testing.T) {
r := &resolverSystem{}
@ -26,16 +38,9 @@ func TestResolverSystem(t *testing.T) {
}
})
t.Run("works as intended", func(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
r := &resolverSystem{}
defer r.CloseIdleConnections()
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil result here")
}
r.CloseIdleConnections() // to cover it
})
t.Run("check default timeout", func(t *testing.T) {
@ -45,7 +50,30 @@ func TestResolverSystem(t *testing.T) {
}
})
t.Run("check default lookup host func not nil", func(t *testing.T) {
r := &resolverSystem{}
if r.lookupHost() == nil {
t.Fatal("expected non-nil func here")
}
})
t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) {
r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"8.8.8.8"}, nil
},
}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "example.antani")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("invalid addrs")
}
})
t.Run("with timeout and success", func(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(1)
@ -111,9 +139,15 @@ func TestResolverSystem(t *testing.T) {
func TestResolverLogger(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
expected := []string{"1.1.1.1"}
r := resolverLogger{
Logger: log.Log,
Logger: lo,
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
@ -127,12 +161,21 @@ func TestResolverLogger(t *testing.T) {
if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff)
}
if count != 2 {
t.Fatal("unexpected count")
}
})
t.Run("with failure", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
expected := errors.New("mocked error")
r := resolverLogger{
Logger: log.Log,
Logger: lo,
Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, expected
@ -146,6 +189,9 @@ func TestResolverLogger(t *testing.T) {
if addrs != nil {
t.Fatal("expected nil addr here")
}
if count != 2 {
t.Fatal("unexpected count")
}
})
})
}
@ -193,18 +239,6 @@ func TestResolverIDNA(t *testing.T) {
})
}
func TestNewResolverSystem(t *testing.T) {
resolver := NewResolverSystem(log.Log)
idna := resolver.(*resolverIDNA)
logger := idna.Resolver.(*resolverLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
_ = errWrapper.Resolver.(*resolverSystem)
}
func TestResolverShortCircuitIPAddr(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("with IP addr", func(t *testing.T) {
@ -261,7 +295,7 @@ func TestNullResolver(t *testing.T) {
if r.Address() != "" {
t.Fatal("invalid address")
}
r.CloseIdleConnections() // should not crash
r.CloseIdleConnections() // for coverage
}
func TestResolverErrWrapper(t *testing.T) {

View File

@ -99,7 +99,7 @@ func ConfigureTLSVersion(config *tls.Config, version string) error {
config.MinVersion = tls.VersionTLS10
config.MaxVersion = tls.VersionTLS10
case "":
// nothing
// nothing to do
default:
return ErrInvalidTLSVersion
}
@ -119,7 +119,7 @@ type TLSHandshaker interface {
// the given config. This function DOES NOT take ownership of the connection
// and it's your responsibility to close it on failure.
//
// The returned connection will always implement the TLSConn interface
// QUIRK: The returned connection will always implement the TLSConn interface
// exposed by this package. A future version of this interface will instead
// return directly a TLSConn and remove the ConnectionState param.
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
@ -128,10 +128,21 @@ type TLSHandshaker interface {
// NewTLSHandshakerStdlib creates a new TLS handshaker using the
// go standard library to create TLS connections.
//
// The handshaker guarantees:
//
// 1. logging
//
// 2. error wrapping
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
return newTLSHandshaker(&tlsHandshakerConfigurable{}, logger)
}
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
func newTLSHandshaker(th TLSHandshaker, logger Logger) TLSHandshaker {
return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: &tlsHandshakerConfigurable{},
TLSHandshaker: th,
},
Logger: logger,
}
@ -191,11 +202,8 @@ var defaultTLSHandshaker = &tlsHandshakerConfigurable{}
// tlsHandshakerLogger is a TLSHandshaker with logging.
type tlsHandshakerLogger struct {
// TLSHandshaker is the underlying handshaker.
TLSHandshaker TLSHandshaker
// Logger is the underlying logger.
Logger Logger
TLSHandshaker
Logger
}
var _ TLSHandshaker = &tlsHandshakerLogger{}

View File

@ -118,10 +118,22 @@ func TestConfigureTLSVersion(t *testing.T) {
}
}
func TestNewTLSHandshakerStdlib(t *testing.T) {
th := NewTLSHandshakerStdlib(log.Log)
logger := th.(*tlsHandshakerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn != nil {
t.Fatal("expected nil NewConn")
}
}
func TestTLSHandshakerConfigurable(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("with error", func(t *testing.T) {
var times []time.Time
h := &tlsHandshakerConfigurable{}
tcpConn := &mocks.Conn{
@ -230,13 +242,19 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
func TestTLSHandshakerLogger(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
th := &tlsHandshakerLogger{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return tls.Client(conn, config), tls.ConnectionState{}, nil
},
},
Logger: log.Log,
Logger: lo,
}
conn := &mocks.Conn{
MockClose: func() error {
@ -255,9 +273,18 @@ func TestTLSHandshakerLogger(t *testing.T) {
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here")
}
if count != 2 {
t.Fatal("invalid count")
}
})
t.Run("on failure", func(t *testing.T) {
var count int
lo := &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
count++
},
}
expected := errors.New("mocked error")
th := &tlsHandshakerLogger{
TLSHandshaker: &mocks.TLSHandshaker{
@ -265,7 +292,7 @@ func TestTLSHandshakerLogger(t *testing.T) {
return nil, tls.ConnectionState{}, expected
},
},
Logger: log.Log,
Logger: lo,
}
conn := &mocks.Conn{
MockClose: func() error {
@ -284,10 +311,29 @@ func TestTLSHandshakerLogger(t *testing.T) {
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here")
}
if count != 2 {
t.Fatal("invalid count")
}
})
})
}
func TestNewTLSDialer(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
dialer := NewTLSDialer(d, th)
tlsd := dialer.(*tlsDialer)
if tlsd.Config == nil {
t.Fatal("unexpected config")
}
if tlsd.Dialer != d {
t.Fatal("unexpected dialer")
}
if tlsd.TLSHandshaker != th {
t.Fatal("invalid handshaker")
}
}
func TestTLSDialer(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
@ -439,35 +485,6 @@ func TestTLSDialer(t *testing.T) {
})
}
func TestNewTLSHandshakerStdlib(t *testing.T) {
th := NewTLSHandshakerStdlib(log.Log)
logger := th.(*tlsHandshakerLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn != nil {
t.Fatal("expected nil NewConn")
}
}
func TestNewTLSDialer(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
dialer := NewTLSDialer(d, th)
tlsd := dialer.(*tlsDialer)
if tlsd.Config == nil {
t.Fatal("unexpected config")
}
if tlsd.Dialer != d {
t.Fatal("unexpected dialer")
}
if tlsd.TLSHandshaker != th {
t.Fatal("invalid handshaker")
}
}
func TestNewSingleUseTLSDialer(t *testing.T) {
conn := &mocks.TLSConn{}
d := NewSingleUseTLSDialer(conn)

View File

@ -11,15 +11,18 @@ import (
// NewTLSHandshakerUTLS creates a new TLS handshaker using the
// gitlab.com/yawning/utls library to create TLS conns.
//
// The handshaker guarantees:
//
// 1. logging
//
// 2. error wrapping
//
// Passing a nil `id` will make this function panic.
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: &tlsHandshakerConfigurable{
NewConn: newConnUTLS(id),
},
},
Logger: logger,
}
return newTLSHandshaker(&tlsHandshakerConfigurable{
NewConn: newConnUTLS(id),
}, logger)
}
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection