From 50b58672c6b44bdf27c8be34435a6794e3a11e9e Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 8 Sep 2021 22:48:10 +0200 Subject: [PATCH] netxlite: code quality, improve tests, docs (#494) See https://github.com/ooni/probe/issues/1591 --- internal/netxlite/dialer.go | 6 +- internal/netxlite/doc.go | 3 + internal/netxlite/http.go | 18 ++++-- internal/netxlite/http3_test.go | 2 +- internal/netxlite/http_test.go | 28 ++++++++- internal/netxlite/integration_test.go | 20 +++++++ internal/netxlite/legacy.go | 1 + internal/netxlite/quic.go | 27 +++++++-- internal/netxlite/quic_test.go | 76 ++++++++++++++++--------- internal/netxlite/quirks_test.go | 52 ++++++++++------- internal/netxlite/resolver.go | 33 ++++++----- internal/netxlite/resolver_test.go | 82 +++++++++++++++++++-------- internal/netxlite/tls.go | 24 +++++--- internal/netxlite/tls_test.go | 81 +++++++++++++++----------- internal/netxlite/utls.go | 19 ++++--- 15 files changed, 318 insertions(+), 154 deletions(-) diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index b78894e..df61e6d 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -40,7 +40,11 @@ type Dialer interface { // // 4. wraps errors; // -// 5. has a configured connect timeout. +// 5. has a configured connect timeout; +// +// 6. if a dialer wraps a resolver, the dialer will forward +// the CloseIdleConnection call to its resolver (which is +// instrumental to manage a DoH resolver connections properly). func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer { return &dialerLogger{ Dialer: &dialerResolver{ diff --git a/internal/netxlite/doc.go b/internal/netxlite/doc.go index 59dafb3..db6feb6 100644 --- a/internal/netxlite/doc.go +++ b/internal/netxlite/doc.go @@ -24,6 +24,9 @@ // // We also want to mock any underlying dependency for testing. // +// We also want to map errors to OONI failures, which are described by +// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md. +// // Operations // // This package implements the following operations: diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index bca6f4a..450c972 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -30,7 +30,6 @@ type httpTransportLogger struct { var _ HTTPTransport = &httpTransportLogger{} -// RoundTrip implements HTTPTransport.RoundTrip. func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) { host := req.Host if host == "" { @@ -64,13 +63,12 @@ func (txp *httpTransportLogger) logTrip(req *http.Request) (*http.Response, erro return resp, nil } -// CloseIdleConnections implement HTTPTransport.CloseIdleConnections. func (txp *httpTransportLogger) CloseIdleConnections() { txp.HTTPTransport.CloseIdleConnections() } // httpTransportConnectionsCloser is an HTTPTransport that -// correctly forwards CloseIdleConnections. +// correctly forwards CloseIdleConnections calls. type httpTransportConnectionsCloser struct { HTTPTransport Dialer @@ -98,6 +96,16 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { // The returned transport will disable transparent decompression // of compressed response bodies (and will not automatically // ask for such compression, though you can always do that manually). +// +// The returned transport will configure TCP and TLS connections +// created using its dialer and TLS dialer to always have a +// read watchdog timeout to address https://github.com/ooni/probe/issues/1609. +// +// The returned transport will always enforce 1 connection per host +// and we cannot get rid of this QUIRK requirement because it is +// necessary to perform sane measurements with tracing. We will be +// able to possibly relax this requirement after we change the +// way in which we perform measurements. func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport { // Using oohttp to support any TLS library. txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone() @@ -115,8 +123,6 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTra // Better for Cloudflare DNS and also better because we have less // noisy events and we can better understand what happened. - // - // UNDOCUMENTED: I am wondering whether we can relax this constraint. txp.MaxConnsPerHost = 1 // The following (1) reduces the number of headers that Go will @@ -175,7 +181,7 @@ func (d *httpTLSDialerWithReadTimeout) DialTLSContext( if err != nil { return nil, err } - tconn, okay := conn.(TLSConn) + tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful if !okay { conn.Close() // we own the conn here return nil, ErrNotTLSConn diff --git a/internal/netxlite/http3_test.go b/internal/netxlite/http3_test.go index dfcdf2f..c4e0bd2 100644 --- a/internal/netxlite/http3_test.go +++ b/internal/netxlite/http3_test.go @@ -33,7 +33,7 @@ func TestHTTP3Dialer(t *testing.T) { }) } -func TestHTTP3TransportClosesIdleConnections(t *testing.T) { +func TestHTTP3Transport(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) { var ( calledHTTP3 bool diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 131817b..bf0fbd7 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -21,8 +21,17 @@ import ( func TestHTTPTransportLogger(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { t.Run("with failure", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebug: func(message string) { + count++ + }, + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } txp := &httpTransportLogger{ - Logger: log.Log, + Logger: lo, HTTPTransport: &mocks.HTTPTransport{ MockRoundTrip: func(req *http.Request) (*http.Response, error) { return nil, io.EOF @@ -37,6 +46,9 @@ func TestHTTPTransportLogger(t *testing.T) { if resp != nil { t.Fatal("expected nil response here") } + if count < 1 { + t.Fatal("no logs?!") + } }) t.Run("we add the host header", func(t *testing.T) { @@ -73,8 +85,17 @@ func TestHTTPTransportLogger(t *testing.T) { }) t.Run("with success", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebug: func(message string) { + count++ + }, + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } txp := &httpTransportLogger{ - Logger: log.Log, + Logger: lo, HTTPTransport: &mocks.HTTPTransport{ MockRoundTrip: func(req *http.Request) (*http.Response, error) { return &http.Response{ @@ -94,6 +115,9 @@ func TestHTTPTransportLogger(t *testing.T) { } iox.ReadAllContext(context.Background(), resp.Body) resp.Body.Close() + if count < 1 { + t.Fatal("no logs?!") + } }) }) diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 795e94f..36e6a37 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -12,6 +12,26 @@ import ( utls "gitlab.com/yawning/utls.git" ) +func TestResolver(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + t.Run("works as intended", func(t *testing.T) { + // TODO(bassosimone): this is actually an integration + // test but how to test this case? + r := netxlite.NewResolverSystem(log.Log) + defer r.CloseIdleConnections() + addrs, err := r.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil result here") + } + }) +} + func TestHTTPTransport(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 2f304b6..83e0b8b 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -13,6 +13,7 @@ var ( DefaultDialer = &dialerSystem{} DefaultTLSHandshaker = defaultTLSHandshaker NewConnUTLS = newConnUTLS + DefaultResolver = &resolverSystem{} ) // These types export internal names to legacy ooni/probe-cli code. diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index 5fe5e28..862bcf1 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -50,6 +50,24 @@ type QUICDialer interface { // NewQUICDialerWithResolver returns a QUICDialer using the given // QUICListener to create listening connections and the given Resolver // to resolve domain names (if needed). +// +// Properties of the dialer: +// +// 1. logs events using the given logger; +// +// 2. resolves domain names using the givern resolver; +// +// 3. when using a resolver, _may_ attempt multiple dials +// in parallel (happy eyeballs) and _may_ return an aggregate +// error to the caller; +// +// 4. wraps errors; +// +// 5. has a configured connect timeout; +// +// 6. if a dialer wraps a resolver, the dialer will forward +// the CloseIdleConnection call to its resolver (which is +// instrumental to manage a DoH resolver connections properly). func NewQUICDialerWithResolver(listener QUICListener, logger Logger, resolver Resolver) QUICDialer { return &quicDialerLogger{ @@ -210,12 +228,9 @@ func (d *quicDialerResolver) DialContext( return nil, err } tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost) - // TODO(bassosimone): here we should be using multierror rather - // than just calling ReduceErrors. We are not ready to do that - // yet, though. To do that, we need first to modify nettests so - // that we actually avoid dialing when measuring. - // - // See also the quirks.go file. This is clearly a QUIRK. + // See TODO(https://github.com/ooni/probe/issues/1779) however + // this is less of a problem for QUIC because so far we have been + // using it to perform research only (i.e., urlgetter). addrs = quirkSortIPAddrs(addrs) var errorslist []error for _, addr := range addrs { diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index c68fd09..5fa349e 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -17,6 +17,34 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/quicx" ) +func TestNewQUICListener(t *testing.T) { + ql := NewQUICListener() + qew := ql.(*quicListenerErrWrapper) + _ = qew.QUICListener.(*quicListenerStdlib) +} + +func TestNewQUICDialer(t *testing.T) { + ql := NewQUICListener() + dlr := NewQUICDialerWithoutResolver(ql, log.Log) + logger := dlr.(*quicDialerLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + resolver := logger.Dialer.(*quicDialerResolver) + if _, okay := resolver.Resolver.(*nullResolver); !okay { + t.Fatal("invalid resolver type") + } + logger = resolver.Dialer.(*quicDialerLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + errWrapper := logger.Dialer.(*quicDialerErrWrapper) + base := errWrapper.QUICDialer.(*quicDialerQUICGo) + if base.QUICListener != ql { + t.Fatal("invalid quic listener") + } +} + func TestQUICDialerQUICGo(t *testing.T) { t.Run("DialContext", func(t *testing.T) { t.Run("cannot split host port", func(t *testing.T) { @@ -223,7 +251,6 @@ func TestQUICDialerQUICGo(t *testing.T) { } func TestQUICDialerResolver(t *testing.T) { - t.Run("CloseIdleConnections", func(t *testing.T) { var ( forDialer bool @@ -302,7 +329,7 @@ func TestQUICDialerResolver(t *testing.T) { } }) - t.Run("with invalid port", func(t *testing.T) { + t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) { // This test allows us to check for the case where every attempt // to establish a connection leads to a failure tlsConf := &tls.Config{} @@ -376,7 +403,6 @@ func TestQUICDialerResolver(t *testing.T) { } func TestQUICLoggerDialer(t *testing.T) { - t.Run("CloseIdleConnections", func(t *testing.T) { var forDialer bool d := &quicDialerLogger{ @@ -394,6 +420,12 @@ func TestQUICLoggerDialer(t *testing.T) { t.Run("DialContext", func(t *testing.T) { t.Run("on success", func(t *testing.T) { + var called int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + called++ + }, + } d := &quicDialerLogger{ Dialer: &mocks.QUICDialer{ MockDialContext: func(ctx context.Context, network string, @@ -407,7 +439,7 @@ func TestQUICLoggerDialer(t *testing.T) { }, nil }, }, - Logger: log.Log, + Logger: lo, } ctx := context.Background() tlsConfig := &tls.Config{} @@ -419,9 +451,18 @@ func TestQUICLoggerDialer(t *testing.T) { if err := sess.CloseWithError(0, ""); err != nil { t.Fatal(err) } + if called != 2 { + t.Fatal("invalid number of calls") + } }) t.Run("on failure", func(t *testing.T) { + var called int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + called++ + }, + } expected := errors.New("mocked error") d := &quicDialerLogger{ Dialer: &mocks.QUICDialer{ @@ -431,7 +472,7 @@ func TestQUICLoggerDialer(t *testing.T) { return nil, expected }, }, - Logger: log.Log, + Logger: lo, } ctx := context.Background() tlsConfig := &tls.Config{} @@ -443,32 +484,13 @@ func TestQUICLoggerDialer(t *testing.T) { if sess != nil { t.Fatal("expected nil session") } + if called != 2 { + t.Fatal("invalid number of calls") + } }) }) } -func TestNewQUICDialer(t *testing.T) { - ql := NewQUICListener() - dlr := NewQUICDialerWithoutResolver(ql, log.Log) - logger := dlr.(*quicDialerLogger) - if logger.Logger != log.Log { - t.Fatal("invalid logger") - } - resolver := logger.Dialer.(*quicDialerResolver) - if _, okay := resolver.Resolver.(*nullResolver); !okay { - t.Fatal("invalid resolver type") - } - logger = resolver.Dialer.(*quicDialerLogger) - if logger.Logger != log.Log { - t.Fatal("invalid logger") - } - errWrapper := logger.Dialer.(*quicDialerErrWrapper) - base := errWrapper.QUICDialer.(*quicDialerQUICGo) - if base.QUICListener != ql { - t.Fatal("invalid quic listener") - } -} - func TestNewSingleUseQUICDialer(t *testing.T) { sess := &mocks.QUICEarlySession{} qd := NewSingleUseQUICDialer(sess) diff --git a/internal/netxlite/quirks_test.go b/internal/netxlite/quirks_test.go index f12d20b..189c6a8 100644 --- a/internal/netxlite/quirks_test.go +++ b/internal/netxlite/quirks_test.go @@ -54,26 +54,34 @@ func TestQuirkReduceErrors(t *testing.T) { } func TestQuirkSortIPAddrs(t *testing.T) { - addrs := []string{ - "::1", - "192.168.1.2", - "2a00:1450:4002:404::2004", - "142.250.184.36", - "2604:8800:5000:82:466:38ff:fecb:d46e", - "198.145.29.83", - "95.216.163.36", - } - expected := []string{ - "192.168.1.2", - "142.250.184.36", - "198.145.29.83", - "95.216.163.36", - "::1", - "2a00:1450:4002:404::2004", - "2604:8800:5000:82:466:38ff:fecb:d46e", - } - out := quirkSortIPAddrs(addrs) - if diff := cmp.Diff(expected, out); diff != "" { - t.Fatal(diff) - } + t.Run("with some addrs", func(t *testing.T) { + addrs := []string{ + "::1", + "192.168.1.2", + "2a00:1450:4002:404::2004", + "142.250.184.36", + "2604:8800:5000:82:466:38ff:fecb:d46e", + "198.145.29.83", + "95.216.163.36", + } + expected := []string{ + "192.168.1.2", + "142.250.184.36", + "198.145.29.83", + "95.216.163.36", + "::1", + "2a00:1450:4002:404::2004", + "2604:8800:5000:82:466:38ff:fecb:d46e", + } + out := quirkSortIPAddrs(addrs) + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("with an empty list", func(t *testing.T) { + if quirkSortIPAddrs(nil) != nil { + t.Fatal("expected nil output") + } + }) } diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index d8a3e49..3239582 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -27,6 +27,20 @@ type Resolver interface { // NewResolverSystem creates a new resolver using system // facilities for resolving domain names (e.g., getaddrinfo). +// +// The resolver will provide the following guarantees: +// +// 1. handles IDNA; +// +// 2. performs logging; +// +// 3. short-circuits IP addresses like getaddrinfo does (i.e., +// resolving "1.1.1.1" yields []string{"1.1.1.1"}; +// +// 4. wraps errors; +// +// 5. enforces reasonable timeouts ( +// see https://github.com/ooni/probe/issues/1726). func NewResolverSystem(logger Logger) Resolver { return &resolverIDNA{ Resolver: &resolverLogger{ @@ -48,7 +62,6 @@ type resolverSystem struct { var _ Resolver = &resolverSystem{} -// LookupHost implements Resolver.LookupHost. func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { // This code forces adding a shorter timeout to the domain name // resolutions when using the system resolver. We have seen cases @@ -89,24 +102,18 @@ func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) ( return net.DefaultResolver.LookupHost } -// Network implements Resolver.Network. func (r *resolverSystem) Network() string { return "system" } -// Address implements Resolver.Address. func (r *resolverSystem) Address() string { return "" } -// CloseIdleConnections implements Resolver.CloseIdleConnections. func (r *resolverSystem) CloseIdleConnections() { - // nothing + // nothing to do } -// DefaultResolver is the resolver we use by default. -var DefaultResolver = &resolverSystem{} - // resolverLogger is a resolver that emits events type resolverLogger struct { Resolver @@ -115,7 +122,6 @@ type resolverLogger struct { var _ Resolver = &resolverLogger{} -// LookupHost returns the IP addresses of a host func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) { r.Logger.Debugf("resolve %s...", hostname) start := time.Now() @@ -136,7 +142,6 @@ type resolverIDNA struct { Resolver } -// LookupHost implements Resolver.LookupHost. func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) { host, err := idna.ToASCII(hostname) if err != nil { @@ -151,7 +156,6 @@ type resolverShortCircuitIPAddr struct { Resolver } -// LookupHost implements Resolver.LookupHost. func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) { if net.ParseIP(hostname) != nil { return []string{hostname}, nil @@ -166,24 +170,20 @@ var ErrNoResolver = errors.New("no configured resolver") // domain names to IP addresses and always returns ErrNoResolver. type nullResolver struct{} -// LookupHost implements Resolver.LookupHost. func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { return nil, ErrNoResolver } -// Network implements Resolver.Network. func (r *nullResolver) Network() string { return "null" } -// Address implements Resolver.Address. func (r *nullResolver) Address() string { return "" } -// CloseIdleConnections implements Resolver.CloseIdleConnections. func (r *nullResolver) CloseIdleConnections() { - // nothing + // nothing to do } // resolverErrWrapper is a Resolver that knows about wrapping errors. @@ -193,7 +193,6 @@ type resolverErrWrapper struct { var _ Resolver = &resolverErrWrapper{} -// LookupHost implements Resolver.LookupHost func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) { addrs, err := r.Resolver.LookupHost(ctx, hostname) if err != nil { diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index 93b0771..5eaa6b4 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -15,6 +15,18 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) +func TestNewResolverSystem(t *testing.T) { + resolver := NewResolverSystem(log.Log) + idna := resolver.(*resolverIDNA) + logger := idna.Resolver.(*resolverLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) + errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) + _ = errWrapper.Resolver.(*resolverSystem) +} + func TestResolverSystem(t *testing.T) { t.Run("Network and Address", func(t *testing.T) { r := &resolverSystem{} @@ -26,16 +38,9 @@ func TestResolverSystem(t *testing.T) { } }) - t.Run("works as intended", func(t *testing.T) { + t.Run("CloseIdleConnections", func(t *testing.T) { r := &resolverSystem{} - defer r.CloseIdleConnections() - addrs, err := r.LookupHost(context.Background(), "dns.google.com") - if err != nil { - t.Fatal(err) - } - if addrs == nil { - t.Fatal("expected non-nil result here") - } + r.CloseIdleConnections() // to cover it }) t.Run("check default timeout", func(t *testing.T) { @@ -45,7 +50,30 @@ func TestResolverSystem(t *testing.T) { } }) + t.Run("check default lookup host func not nil", func(t *testing.T) { + r := &resolverSystem{} + if r.lookupHost() == nil { + t.Fatal("expected non-nil func here") + } + }) + t.Run("LookupHost", func(t *testing.T) { + t.Run("with success", func(t *testing.T) { + r := &resolverSystem{ + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"8.8.8.8"}, nil + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "example.antani") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("invalid addrs") + } + }) + t.Run("with timeout and success", func(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(1) @@ -111,9 +139,15 @@ func TestResolverSystem(t *testing.T) { func TestResolverLogger(t *testing.T) { t.Run("LookupHost", func(t *testing.T) { t.Run("with success", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } expected := []string{"1.1.1.1"} r := resolverLogger{ - Logger: log.Log, + Logger: lo, Resolver: &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return expected, nil @@ -127,12 +161,21 @@ func TestResolverLogger(t *testing.T) { if diff := cmp.Diff(expected, addrs); diff != "" { t.Fatal(diff) } + if count != 2 { + t.Fatal("unexpected count") + } }) t.Run("with failure", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } expected := errors.New("mocked error") r := resolverLogger{ - Logger: log.Log, + Logger: lo, Resolver: &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, expected @@ -146,6 +189,9 @@ func TestResolverLogger(t *testing.T) { if addrs != nil { t.Fatal("expected nil addr here") } + if count != 2 { + t.Fatal("unexpected count") + } }) }) } @@ -193,18 +239,6 @@ func TestResolverIDNA(t *testing.T) { }) } -func TestNewResolverSystem(t *testing.T) { - resolver := NewResolverSystem(log.Log) - idna := resolver.(*resolverIDNA) - logger := idna.Resolver.(*resolverLogger) - if logger.Logger != log.Log { - t.Fatal("invalid logger") - } - shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) - errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) - _ = errWrapper.Resolver.(*resolverSystem) -} - func TestResolverShortCircuitIPAddr(t *testing.T) { t.Run("LookupHost", func(t *testing.T) { t.Run("with IP addr", func(t *testing.T) { @@ -261,7 +295,7 @@ func TestNullResolver(t *testing.T) { if r.Address() != "" { t.Fatal("invalid address") } - r.CloseIdleConnections() // should not crash + r.CloseIdleConnections() // for coverage } func TestResolverErrWrapper(t *testing.T) { diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index c112945..ac0b042 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -99,7 +99,7 @@ func ConfigureTLSVersion(config *tls.Config, version string) error { config.MinVersion = tls.VersionTLS10 config.MaxVersion = tls.VersionTLS10 case "": - // nothing + // nothing to do default: return ErrInvalidTLSVersion } @@ -119,7 +119,7 @@ type TLSHandshaker interface { // the given config. This function DOES NOT take ownership of the connection // and it's your responsibility to close it on failure. // - // The returned connection will always implement the TLSConn interface + // QUIRK: The returned connection will always implement the TLSConn interface // exposed by this package. A future version of this interface will instead // return directly a TLSConn and remove the ConnectionState param. Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( @@ -128,10 +128,21 @@ type TLSHandshaker interface { // NewTLSHandshakerStdlib creates a new TLS handshaker using the // go standard library to create TLS connections. +// +// The handshaker guarantees: +// +// 1. logging +// +// 2. error wrapping func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker { + return newTLSHandshaker(&tlsHandshakerConfigurable{}, logger) +} + +// newTLSHandshaker is the common factory for creating a new TLSHandshaker +func newTLSHandshaker(th TLSHandshaker, logger Logger) TLSHandshaker { return &tlsHandshakerLogger{ TLSHandshaker: &tlsHandshakerErrWrapper{ - TLSHandshaker: &tlsHandshakerConfigurable{}, + TLSHandshaker: th, }, Logger: logger, } @@ -191,11 +202,8 @@ var defaultTLSHandshaker = &tlsHandshakerConfigurable{} // tlsHandshakerLogger is a TLSHandshaker with logging. type tlsHandshakerLogger struct { - // TLSHandshaker is the underlying handshaker. - TLSHandshaker TLSHandshaker - - // Logger is the underlying logger. - Logger Logger + TLSHandshaker + Logger } var _ TLSHandshaker = &tlsHandshakerLogger{} diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index ae97c78..cec398d 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -118,10 +118,22 @@ func TestConfigureTLSVersion(t *testing.T) { } } +func TestNewTLSHandshakerStdlib(t *testing.T) { + th := NewTLSHandshakerStdlib(log.Log) + logger := th.(*tlsHandshakerLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper) + configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable) + if configurable.NewConn != nil { + t.Fatal("expected nil NewConn") + } +} + func TestTLSHandshakerConfigurable(t *testing.T) { t.Run("Handshake", func(t *testing.T) { t.Run("with error", func(t *testing.T) { - var times []time.Time h := &tlsHandshakerConfigurable{} tcpConn := &mocks.Conn{ @@ -230,13 +242,19 @@ func TestTLSHandshakerConfigurable(t *testing.T) { func TestTLSHandshakerLogger(t *testing.T) { t.Run("Handshake", func(t *testing.T) { t.Run("on success", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } th := &tlsHandshakerLogger{ TLSHandshaker: &mocks.TLSHandshaker{ MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { return tls.Client(conn, config), tls.ConnectionState{}, nil }, }, - Logger: log.Log, + Logger: lo, } conn := &mocks.Conn{ MockClose: func() error { @@ -255,9 +273,18 @@ func TestTLSHandshakerLogger(t *testing.T) { if !reflect.ValueOf(connState).IsZero() { t.Fatal("expected zero ConnectionState here") } + if count != 2 { + t.Fatal("invalid count") + } }) t.Run("on failure", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } expected := errors.New("mocked error") th := &tlsHandshakerLogger{ TLSHandshaker: &mocks.TLSHandshaker{ @@ -265,7 +292,7 @@ func TestTLSHandshakerLogger(t *testing.T) { return nil, tls.ConnectionState{}, expected }, }, - Logger: log.Log, + Logger: lo, } conn := &mocks.Conn{ MockClose: func() error { @@ -284,10 +311,29 @@ func TestTLSHandshakerLogger(t *testing.T) { if !reflect.ValueOf(connState).IsZero() { t.Fatal("expected zero ConnectionState here") } + if count != 2 { + t.Fatal("invalid count") + } }) }) } +func TestNewTLSDialer(t *testing.T) { + d := &mocks.Dialer{} + th := &mocks.TLSHandshaker{} + dialer := NewTLSDialer(d, th) + tlsd := dialer.(*tlsDialer) + if tlsd.Config == nil { + t.Fatal("unexpected config") + } + if tlsd.Dialer != d { + t.Fatal("unexpected dialer") + } + if tlsd.TLSHandshaker != th { + t.Fatal("invalid handshaker") + } +} + func TestTLSDialer(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) { var called bool @@ -439,35 +485,6 @@ func TestTLSDialer(t *testing.T) { }) } -func TestNewTLSHandshakerStdlib(t *testing.T) { - th := NewTLSHandshakerStdlib(log.Log) - logger := th.(*tlsHandshakerLogger) - if logger.Logger != log.Log { - t.Fatal("invalid logger") - } - errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper) - configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable) - if configurable.NewConn != nil { - t.Fatal("expected nil NewConn") - } -} - -func TestNewTLSDialer(t *testing.T) { - d := &mocks.Dialer{} - th := &mocks.TLSHandshaker{} - dialer := NewTLSDialer(d, th) - tlsd := dialer.(*tlsDialer) - if tlsd.Config == nil { - t.Fatal("unexpected config") - } - if tlsd.Dialer != d { - t.Fatal("unexpected dialer") - } - if tlsd.TLSHandshaker != th { - t.Fatal("invalid handshaker") - } -} - func TestNewSingleUseTLSDialer(t *testing.T) { conn := &mocks.TLSConn{} d := NewSingleUseTLSDialer(conn) diff --git a/internal/netxlite/utls.go b/internal/netxlite/utls.go index d836a1a..995465c 100644 --- a/internal/netxlite/utls.go +++ b/internal/netxlite/utls.go @@ -11,15 +11,18 @@ import ( // NewTLSHandshakerUTLS creates a new TLS handshaker using the // gitlab.com/yawning/utls library to create TLS conns. +// +// The handshaker guarantees: +// +// 1. logging +// +// 2. error wrapping +// +// Passing a nil `id` will make this function panic. func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker { - return &tlsHandshakerLogger{ - TLSHandshaker: &tlsHandshakerErrWrapper{ - TLSHandshaker: &tlsHandshakerConfigurable{ - NewConn: newConnUTLS(id), - }, - }, - Logger: logger, - } + return newTLSHandshaker(&tlsHandshakerConfigurable{ + NewConn: newConnUTLS(id), + }, logger) } // utlsConn implements TLSConn and uses a utls UConn as its underlying connection