From f2e3e5cc08fd1b8a4ded0044ed3c24f1038c67d7 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 8 Sep 2021 11:39:27 +0200 Subject: [PATCH] refactor(netxlite): finish grouping tests (#488) They are now more readable. I'll do another pass and start separating integration testing from unit testing. I think we need to have some always on integration testing for netxlite that runs on macOS, linux, and windows. See https://github.com/ooni/probe/issues/1591 --- internal/netxlite/integration_test.go | 21 + internal/netxlite/legacy_test.go | 130 ++-- internal/netxlite/quic_test.go | 862 +++++++++++++------------- internal/netxlite/quirks_test.go | 3 + internal/netxlite/resolver_test.go | 436 ++++++------- internal/netxlite/tls_test.go | 651 +++++++++---------- internal/netxlite/utls_test.go | 165 +++-- 7 files changed, 1140 insertions(+), 1128 deletions(-) diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 4740145..a803423 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -1,12 +1,15 @@ package netxlite_test import ( + "context" "crypto/tls" + "net" "net/http" "testing" "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/netxlite" + utls "gitlab.com/yawning/utls.git" ) func TestHTTPTransport(t *testing.T) { @@ -49,3 +52,21 @@ func TestHTTP3Transport(t *testing.T) { txp.CloseIdleConnections() }) } + +func TestUTLSHandshaker(t *testing.T) { + t.Run("with chrome fingerprint", func(t *testing.T) { + h := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_Auto) + cfg := &tls.Config{ServerName: "google.com"} + conn, err := net.Dial("tcp", "google.com:443") + if err != nil { + t.Fatal("unexpected error", err) + } + conn, _, err = h.Handshake(context.Background(), conn, cfg) + if err != nil { + t.Fatal("unexpected error", err) + } + if conn == nil { + t.Fatal("nil connection") + } + }) +} diff --git a/internal/netxlite/legacy_test.go b/internal/netxlite/legacy_test.go index d994f40..af132cd 100644 --- a/internal/netxlite/legacy_test.go +++ b/internal/netxlite/legacy_test.go @@ -7,74 +7,80 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestResolverLegacyAdapterWithCompatibleType(t *testing.T) { - var called bool - r := NewResolverLegacyAdapter(&mocks.Resolver{ - MockNetwork: func() string { - return "network" - }, - MockAddress: func() string { - return "address" - }, - MockCloseIdleConnections: func() { - called = true - }, +func TestResolverLegacyAdapter(t *testing.T) { + t.Run("with compatible type", func(t *testing.T) { + var called bool + r := NewResolverLegacyAdapter(&mocks.Resolver{ + MockNetwork: func() string { + return "network" + }, + MockAddress: func() string { + return "address" + }, + MockCloseIdleConnections: func() { + called = true + }, + }) + if r.Network() != "network" { + t.Fatal("invalid Network") + } + if r.Address() != "address" { + t.Fatal("invalid Address") + } + r.CloseIdleConnections() + if !called { + t.Fatal("not called") + } }) - if r.Network() != "network" { - t.Fatal("invalid Network") - } - if r.Address() != "address" { - t.Fatal("invalid Address") - } - r.CloseIdleConnections() - if !called { - t.Fatal("not called") - } -} -func TestResolverLegacyAdapterDefaults(t *testing.T) { - r := NewResolverLegacyAdapter(&net.Resolver{}) - if r.Network() != "adapter" { - t.Fatal("invalid Network") - } - if r.Address() != "" { - t.Fatal("invalid Address") - } - r.CloseIdleConnections() // does not crash -} - -func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) { - var called bool - r := NewDialerLegacyAdapter(&mocks.Dialer{ - MockCloseIdleConnections: func() { - called = true - }, + t.Run("with incompatible type", func(t *testing.T) { + r := NewResolverLegacyAdapter(&net.Resolver{}) + if r.Network() != "adapter" { + t.Fatal("invalid Network") + } + if r.Address() != "" { + t.Fatal("invalid Address") + } + r.CloseIdleConnections() // does not crash }) - r.CloseIdleConnections() - if !called { - t.Fatal("not called") - } } -func TestDialerLegacyAdapterDefaults(t *testing.T) { - r := NewDialerLegacyAdapter(&net.Dialer{}) - r.CloseIdleConnections() // does not crash -} - -func TestQUICContextDialerAdapterWithCompatibleType(t *testing.T) { - var called bool - d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICDialer{ - MockCloseIdleConnections: func() { - called = true - }, +func TestDialerLegacyAdapter(t *testing.T) { + t.Run("with compatible type", func(t *testing.T) { + var called bool + r := NewDialerLegacyAdapter(&mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + }) + r.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("with incompatible type", func(t *testing.T) { + r := NewDialerLegacyAdapter(&net.Dialer{}) + r.CloseIdleConnections() // does not crash }) - d.CloseIdleConnections() - if !called { - t.Fatal("not called") - } } -func TestQUICContextDialerAdapterDefaults(t *testing.T) { - d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICContextDialer{}) - d.CloseIdleConnections() // does not crash +func TestQUICContextDialerAdapter(t *testing.T) { + t.Run("with compatible type", func(t *testing.T) { + var called bool + d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICDialer{ + MockCloseIdleConnections: func() { + called = true + }, + }) + d.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("with incompatible type", func(t *testing.T) { + d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICContextDialer{}) + d.CloseIdleConnections() // does not crash + }) } diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index 59728f6..c68fd09 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -17,457 +17,459 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/quicx" ) -func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "www.google.com", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - } - defer systemdialer.CloseIdleConnections() // just to see it running - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "a.b.c.d", tlsConfig, &quic.Config{}) - if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil sess here") - } -} +func TestQUICDialerQUICGo(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("cannot split host port", func(t *testing.T) { + tlsConfig := &tls.Config{ + ServerName: "www.google.com", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + } + defer systemdialer.CloseIdleConnections() // just to see it running + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "a.b.c.d", tlsConfig, &quic.Config{}) + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } + }) -func TestQUICDialerQUICGoInvalidPort(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "www.google.com", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - } - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "8.8.4.4:xyz", tlsConfig, &quic.Config{}) - if err == nil || !strings.HasSuffix(err.Error(), "invalid syntax") { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil sess here") - } -} + t.Run("with invalid port", func(t *testing.T) { + tlsConfig := &tls.Config{ + ServerName: "www.google.com", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.4.4:xyz", tlsConfig, &quic.Config{}) + if err == nil || !strings.HasSuffix(err.Error(), "invalid syntax") { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } + }) -func TestQUICDialerQUICGoInvalidIP(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "www.google.com", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - } - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "a.b.c.d:0", tlsConfig, &quic.Config{}) - if !errors.Is(err, errInvalidIP) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil sess here") - } -} + t.Run("with invalid IP", func(t *testing.T) { + tlsConfig := &tls.Config{ + ServerName: "www.google.com", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "a.b.c.d:0", tlsConfig, &quic.Config{}) + if !errors.Is(err, errInvalidIP) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } + }) -func TestQUICDialerQUICGoCannotListen(t *testing.T) { - expected := errors.New("mocked error") - tlsConfig := &tls.Config{ - ServerName: "www.google.com", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &mocks.QUICListener{ - MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) { - return nil, expected - }, - }, - } - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil sess here") - } -} - -func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "dns.google", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - } - ctx, cancel := context.WithCancel(context.Background()) - cancel() // fail immediately - sess, err := systemdialer.DialContext( - ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) - if !errors.Is(err, context.Canceled) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - log.Fatal("expected nil session here") - } -} - -func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "dns.google", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - } - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) - if err != nil { - t.Fatal("not the error we expected", err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } -} - -func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) { - expected := errors.New("mocked error") - var gotTLSConfig *tls.Config - tlsConfig := &tls.Config{ - ServerName: "dns.google", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, - remoteAddr net.Addr, host string, tlsConfig *tls.Config, - quicConfig *quic.Config) (quic.EarlySession, error) { - gotTLSConfig = tlsConfig - return nil, expected - }, - } - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil session here") - } - if tlsConfig.RootCAs != nil { - t.Fatal("tlsConfig.RootCAs should not have been changed") - } - if gotTLSConfig.RootCAs != defaultCertPool { - t.Fatal("invalid gotTLSConfig.RootCAs") - } - if tlsConfig.NextProtos != nil { - t.Fatal("tlsConfig.NextProtos should not have been changed") - } - if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" { - t.Fatal("invalid gotTLSConfig.NextProtos", diff) - } - if tlsConfig.ServerName != gotTLSConfig.ServerName { - t.Fatal("the ServerName field must match") - } -} - -func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) { - expected := errors.New("mocked error") - var gotTLSConfig *tls.Config - tlsConfig := &tls.Config{ - ServerName: "dns.google", - } - systemdialer := quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, - remoteAddr net.Addr, host string, tlsConfig *tls.Config, - quicConfig *quic.Config) (quic.EarlySession, error) { - gotTLSConfig = tlsConfig - return nil, expected - }, - } - ctx := context.Background() - sess, err := systemdialer.DialContext( - ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{}) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil session here") - } - if tlsConfig.RootCAs != nil { - t.Fatal("tlsConfig.RootCAs should not have been changed") - } - if gotTLSConfig.RootCAs != defaultCertPool { - t.Fatal("invalid gotTLSConfig.RootCAs") - } - if tlsConfig.NextProtos != nil { - t.Fatal("tlsConfig.NextProtos should not have been changed") - } - if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" { - t.Fatal("invalid gotTLSConfig.NextProtos", diff) - } - if tlsConfig.ServerName != gotTLSConfig.ServerName { - t.Fatal("the ServerName field must match") - } -} - -func TestQUICDialerResolverCloseIdleConnections(t *testing.T) { - var ( - forDialer bool - forResolver bool - ) - d := &quicDialerResolver{ - Dialer: &mocks.QUICDialer{ - MockCloseIdleConnections: func() { - forDialer = true - }, - }, - Resolver: &mocks.Resolver{ - MockCloseIdleConnections: func() { - forResolver = true - }, - }, - } - d.CloseIdleConnections() - if !forDialer || !forResolver { - t.Fatal("not called") - } -} - -func TestQUICDialerResolverSuccess(t *testing.T) { - tlsConfig := &tls.Config{} - dialer := &quicDialerResolver{ - Resolver: NewResolverSystem(log.Log), - Dialer: &quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - }} - sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:443", - tlsConfig, &quic.Config{}) - if err != nil { - t.Fatal(err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } -} - -func TestQUICDialerResolverNoPort(t *testing.T) { - tlsConfig := &tls.Config{} - dialer := &quicDialerResolver{ - Resolver: NewResolverSystem(log.Log), - Dialer: &quicDialerQUICGo{}} - sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com", - tlsConfig, &quic.Config{}) - if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { - t.Fatal("not the error we expected") - } - if sess != nil { - t.Fatal("expected a nil sess here") - } -} - -func TestQUICDialerResolverLookupHostAddress(t *testing.T) { - dialer := &quicDialerResolver{Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - // We should not arrive here and call this function but if we do then - // there is going to be an error that fails this test. - return nil, errors.New("mocked error") - }, - }} - addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "1.1.1.1" { - t.Fatal("not the result we expected") - } -} - -func TestQUICDialerResolverLookupHostFailure(t *testing.T) { - tlsConfig := &tls.Config{} - expected := errors.New("mocked error") - dialer := &quicDialerResolver{Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, expected - }, - }} - sess, err := dialer.DialContext( - context.Background(), "udp", "dns.google.com:853", - tlsConfig, &quic.Config{}) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if sess != nil { - t.Fatal("expected nil sess") - } -} - -func TestQUICDialerResolverInvalidPort(t *testing.T) { - // This test allows us to check for the case where every attempt - // to establish a connection leads to a failure - tlsConf := &tls.Config{} - dialer := &quicDialerResolver{ - Resolver: NewResolverSystem(log.Log), - Dialer: &quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - }} - sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:0", - tlsConf, &quic.Config{}) - if err == nil { - t.Fatal("expected an error here") - } - if !strings.HasSuffix(err.Error(), "sendto: invalid argument") && - !strings.HasSuffix(err.Error(), "sendto: can't assign requested address") { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil sess") - } -} - -func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) { - expected := errors.New("mocked error") - var gotTLSConfig *tls.Config - tlsConfig := &tls.Config{} - dialer := &quicDialerResolver{ - Resolver: NewResolverSystem(log.Log), - Dialer: &mocks.QUICDialer{ - MockDialContext: func(ctx context.Context, network, address string, - tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { - gotTLSConfig = tlsConfig - return nil, expected - }, - }} - sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:443", - tlsConfig, &quic.Config{}) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil session here") - } - if tlsConfig.ServerName != "" { - t.Fatal("should not have changed tlsConfig.ServerName") - } - if gotTLSConfig.ServerName != "www.google.com" { - t.Fatal("gotTLSConfig.ServerName has not been set") - } -} - -func TestQUICDialerLoggerCloseIdleConnections(t *testing.T) { - var forDialer bool - d := &quicDialerLogger{ - Dialer: &mocks.QUICDialer{ - MockCloseIdleConnections: func() { - forDialer = true - }, - }, - } - d.CloseIdleConnections() - if !forDialer { - t.Fatal("not called") - } -} - -func TestQUICDialerLoggerSuccess(t *testing.T) { - d := &quicDialerLogger{ - Dialer: &mocks.QUICDialer{ - MockDialContext: func(ctx context.Context, network string, - address string, tlsConfig *tls.Config, - quicConfig *quic.Config) (quic.EarlySession, error) { - return &mocks.QUICEarlySession{ - MockCloseWithError: func( - code quic.ApplicationErrorCode, reason string) error { - return nil + t.Run("with listen error", func(t *testing.T) { + expected := errors.New("mocked error") + tlsConfig := &tls.Config{ + ServerName: "www.google.com", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) { + return nil, expected }, - }, nil - }, - }, - Logger: log.Log, - } - ctx := context.Background() - tlsConfig := &tls.Config{} - quicConfig := &quic.Config{} - sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) - if err != nil { - t.Fatal(err) - } - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } + }, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } + }) + + t.Run("with handshake failure", func(t *testing.T) { + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, context.Canceled) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + log.Fatal("expected nil session here") + } + }) + + t.Run("works as intended", func(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if err != nil { + t.Fatal("not the error we expected", err) + } + <-sess.HandshakeComplete().Done() + if err := sess.CloseWithError(0, ""); err != nil { + t.Fatal(err) + } + }) + + t.Run("TLS defaults for web", func(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, + remoteAddr net.Addr, host string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) { + gotTLSConfig = tlsConfig + return nil, expected + }, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session here") + } + if tlsConfig.RootCAs != nil { + t.Fatal("tlsConfig.RootCAs should not have been changed") + } + if gotTLSConfig.RootCAs != defaultCertPool { + t.Fatal("invalid gotTLSConfig.RootCAs") + } + if tlsConfig.NextProtos != nil { + t.Fatal("tlsConfig.NextProtos should not have been changed") + } + if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" { + t.Fatal("invalid gotTLSConfig.NextProtos", diff) + } + if tlsConfig.ServerName != gotTLSConfig.ServerName { + t.Fatal("the ServerName field must match") + } + }) + + t.Run("TLS defaults for DoQ", func(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, + remoteAddr net.Addr, host string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) { + gotTLSConfig = tlsConfig + return nil, expected + }, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session here") + } + if tlsConfig.RootCAs != nil { + t.Fatal("tlsConfig.RootCAs should not have been changed") + } + if gotTLSConfig.RootCAs != defaultCertPool { + t.Fatal("invalid gotTLSConfig.RootCAs") + } + if tlsConfig.NextProtos != nil { + t.Fatal("tlsConfig.NextProtos should not have been changed") + } + if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" { + t.Fatal("invalid gotTLSConfig.NextProtos", diff) + } + if tlsConfig.ServerName != gotTLSConfig.ServerName { + t.Fatal("the ServerName field must match") + } + }) + }) } -func TestQUICDialerLoggerFailure(t *testing.T) { - expected := errors.New("mocked error") - d := &quicDialerLogger{ - Dialer: &mocks.QUICDialer{ - MockDialContext: func(ctx context.Context, network string, - address string, tlsConfig *tls.Config, - quicConfig *quic.Config) (quic.EarlySession, error) { - return nil, expected +func TestQUICDialerResolver(t *testing.T) { + + t.Run("CloseIdleConnections", func(t *testing.T) { + var ( + forDialer bool + forResolver bool + ) + d := &quicDialerResolver{ + Dialer: &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + forDialer = true + }, }, - }, - Logger: log.Log, - } - ctx := context.Background() - tlsConfig := &tls.Config{} - quicConfig := &quic.Config{} - sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if sess != nil { - t.Fatal("expected nil session") - } + Resolver: &mocks.Resolver{ + MockCloseIdleConnections: func() { + forResolver = true + }, + }, + } + d.CloseIdleConnections() + if !forDialer || !forResolver { + t.Fatal("not called") + } + }) + + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + tlsConfig := &tls.Config{} + dialer := &quicDialerResolver{ + Resolver: NewResolverSystem(log.Log), + Dialer: &quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + }} + sess, err := dialer.DialContext( + context.Background(), "udp", "www.google.com:443", + tlsConfig, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + <-sess.HandshakeComplete().Done() + if err := sess.CloseWithError(0, ""); err != nil { + t.Fatal(err) + } + }) + + t.Run("with missing port", func(t *testing.T) { + tlsConfig := &tls.Config{} + dialer := &quicDialerResolver{ + Resolver: NewResolverSystem(log.Log), + Dialer: &quicDialerQUICGo{}} + sess, err := dialer.DialContext( + context.Background(), "udp", "www.google.com", + tlsConfig, &quic.Config{}) + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("not the error we expected") + } + if sess != nil { + t.Fatal("expected a nil sess here") + } + }) + + t.Run("with lookup host failure", func(t *testing.T) { + tlsConfig := &tls.Config{} + expected := errors.New("mocked error") + dialer := &quicDialerResolver{Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + }} + sess, err := dialer.DialContext( + context.Background(), "udp", "dns.google.com:853", + tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if sess != nil { + t.Fatal("expected nil sess") + } + }) + + t.Run("with invalid port", func(t *testing.T) { + // This test allows us to check for the case where every attempt + // to establish a connection leads to a failure + tlsConf := &tls.Config{} + dialer := &quicDialerResolver{ + Resolver: NewResolverSystem(log.Log), + Dialer: &quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + }} + sess, err := dialer.DialContext( + context.Background(), "udp", "www.google.com:0", + tlsConf, &quic.Config{}) + if err == nil { + t.Fatal("expected an error here") + } + if !strings.HasSuffix(err.Error(), "sendto: invalid argument") && + !strings.HasSuffix(err.Error(), "sendto: can't assign requested address") { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess") + } + }) + + t.Run("we apply TLS defaults", func(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + tlsConfig := &tls.Config{} + dialer := &quicDialerResolver{ + Resolver: NewResolverSystem(log.Log), + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + gotTLSConfig = tlsConfig + return nil, expected + }, + }} + sess, err := dialer.DialContext( + context.Background(), "udp", "www.google.com:443", + tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session here") + } + if tlsConfig.ServerName != "" { + t.Fatal("should not have changed tlsConfig.ServerName") + } + if gotTLSConfig.ServerName != "www.google.com" { + t.Fatal("gotTLSConfig.ServerName has not been set") + } + }) + }) + + t.Run("lookup host with address", func(t *testing.T) { + dialer := &quicDialerResolver{Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + // We should not arrive here and call this function but if we do then + // there is going to be an error that fails this test. + return nil, errors.New("mocked error") + }, + }} + addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "1.1.1.1" { + t.Fatal("not the result we expected") + } + }) } -func TestNewQUICDialerWithoutResolverChain(t *testing.T) { +func TestQUICLoggerDialer(t *testing.T) { + + t.Run("CloseIdleConnections", func(t *testing.T) { + var forDialer bool + d := &quicDialerLogger{ + Dialer: &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + forDialer = true + }, + }, + } + d.CloseIdleConnections() + if !forDialer { + t.Fatal("not called") + } + }) + + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + d := &quicDialerLogger{ + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network string, + address string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) { + return &mocks.QUICEarlySession{ + MockCloseWithError: func( + code quic.ApplicationErrorCode, reason string) error { + return nil + }, + }, nil + }, + }, + Logger: log.Log, + } + ctx := context.Background() + tlsConfig := &tls.Config{} + quicConfig := &quic.Config{} + sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) + if err != nil { + t.Fatal(err) + } + if err := sess.CloseWithError(0, ""); err != nil { + t.Fatal(err) + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &quicDialerLogger{ + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network string, + address string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) { + return nil, expected + }, + }, + Logger: log.Log, + } + ctx := context.Background() + tlsConfig := &tls.Config{} + quicConfig := &quic.Config{} + sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session") + } + }) + }) +} + +func TestNewQUICDialer(t *testing.T) { ql := NewQUICListener() dlr := NewQUICDialerWithoutResolver(ql, log.Log) - dlog, okay := dlr.(*quicDialerLogger) - if !okay { - t.Fatal("invalid type") - } - if dlog.Logger != log.Log { + logger := dlr.(*quicDialerLogger) + if logger.Logger != log.Log { t.Fatal("invalid logger") } - dr, okay := dlog.Dialer.(*quicDialerResolver) - if !okay { - t.Fatal("invalid type") - } - if _, okay := dr.Resolver.(*nullResolver); !okay { + resolver := logger.Dialer.(*quicDialerResolver) + if _, okay := resolver.Resolver.(*nullResolver); !okay { t.Fatal("invalid resolver type") } - dlog, okay = dr.Dialer.(*quicDialerLogger) - if !okay { - t.Fatal("invalid type") - } - if dlog.Logger != log.Log { + logger = resolver.Dialer.(*quicDialerLogger) + if logger.Logger != log.Log { t.Fatal("invalid logger") } - ew, okay := dlog.Dialer.(*quicDialerErrWrapper) - if !okay { - t.Fatal("invalid type") - } - dgo, okay := ew.QUICDialer.(*quicDialerQUICGo) - if !okay { - t.Fatal("invalid type") - } - if dgo.QUICListener != ql { + errWrapper := logger.Dialer.(*quicDialerErrWrapper) + base := errWrapper.QUICDialer.(*quicDialerQUICGo) + if base.QUICListener != ql { t.Fatal("invalid quic listener") } } -func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) { +func TestNewSingleUseQUICDialer(t *testing.T) { sess := &mocks.QUICEarlySession{} qd := NewSingleUseQUICDialer(sess) outsess, err := qd.DialContext( diff --git a/internal/netxlite/quirks_test.go b/internal/netxlite/quirks_test.go index 47c88fe..4f60345 100644 --- a/internal/netxlite/quirks_test.go +++ b/internal/netxlite/quirks_test.go @@ -15,6 +15,7 @@ func TestQuirkReduceErrors(t *testing.T) { t.Fatal("wrong result") } }) + t.Run("single error", func(t *testing.T) { err := errors.New("mocked error") result := quirkReduceErrors([]error{err}) @@ -22,6 +23,7 @@ func TestQuirkReduceErrors(t *testing.T) { t.Fatal("wrong result") } }) + t.Run("multiple errors", func(t *testing.T) { err1 := errors.New("mocked error #1") err2 := errors.New("mocked error #2") @@ -30,6 +32,7 @@ func TestQuirkReduceErrors(t *testing.T) { t.Fatal("wrong result") } }) + t.Run("multiple errors with meaningful ones", func(t *testing.T) { err1 := errors.New("mocked error #1") err2 := &errorsx.ErrWrapper{ diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index 27e7d1c..93b0771 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -15,235 +15,237 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestResolverSystemNetworkAddress(t *testing.T) { - r := &resolverSystem{} - if r.Network() != "system" { - t.Fatal("invalid Network") - } - if r.Address() != "" { - t.Fatal("invalid Address") - } +func TestResolverSystem(t *testing.T) { + t.Run("Network and Address", func(t *testing.T) { + r := &resolverSystem{} + if r.Network() != "system" { + t.Fatal("invalid Network") + } + if r.Address() != "" { + t.Fatal("invalid Address") + } + }) + + t.Run("works as intended", func(t *testing.T) { + r := &resolverSystem{} + defer r.CloseIdleConnections() + addrs, err := r.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil result here") + } + }) + + t.Run("check default timeout", func(t *testing.T) { + r := &resolverSystem{} + if r.timeout() != 15*time.Second { + t.Fatal("unexpected default timeout") + } + }) + + t.Run("LookupHost", func(t *testing.T) { + t.Run("with timeout and success", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + r := &resolverSystem{ + testableTimeout: 1 * time.Microsecond, + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + defer wg.Done() + time.Sleep(1 * time.Millisecond) + return []string{"8.8.8.8"}, nil + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "example.antani") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("invalid addrs") + } + wg.Wait() + }) + + t.Run("with timeout and failure", func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + r := &resolverSystem{ + testableTimeout: 1 * time.Microsecond, + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + defer wg.Done() + time.Sleep(1 * time.Millisecond) + return nil, errors.New("no such host") + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "example.antani") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("invalid addrs") + } + wg.Wait() + }) + + t.Run("with NXDOMAIN", func(t *testing.T) { + r := &resolverSystem{ + testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, errors.New("no such host") + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "example.antani") + if err == nil || !strings.HasSuffix(err.Error(), "no such host") { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("invalid addrs") + } + }) + }) } -func TestResolverSystemWorksAsIntended(t *testing.T) { - r := &resolverSystem{} - defer r.CloseIdleConnections() - addrs, err := r.LookupHost(context.Background(), "dns.google.com") - if err != nil { - t.Fatal(err) - } - if addrs == nil { - t.Fatal("expected non-nil result here") - } +func TestResolverLogger(t *testing.T) { + t.Run("LookupHost", func(t *testing.T) { + t.Run("with success", func(t *testing.T) { + expected := []string{"1.1.1.1"} + r := resolverLogger{ + Logger: log.Log, + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return expected, nil + }, + }, + } + addrs, err := r.LookupHost(context.Background(), "dns.google") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expected, addrs); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("with failure", func(t *testing.T) { + expected := errors.New("mocked error") + r := resolverLogger{ + Logger: log.Log, + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + }, + } + addrs, err := r.LookupHost(context.Background(), "dns.google") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil addr here") + } + }) + }) } -func TestResolverSystemDefaultTimeout(t *testing.T) { - r := &resolverSystem{} - if r.timeout() != 15*time.Second { - t.Fatal("unexpected default timeout") - } +func TestResolverIDNA(t *testing.T) { + t.Run("LookupHost", func(t *testing.T) { + t.Run("with valid IDNA in input", func(t *testing.T) { + expectedIPs := []string{"77.88.55.66"} + r := &resolverIDNA{ + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + if domain != "xn--d1acpjx3f.xn--p1ai" { + return nil, errors.New("passed invalid domain") + } + return expectedIPs, nil + }, + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "яндекс.рф") + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(expectedIPs, addrs); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("with invalid punycode", func(t *testing.T) { + r := &resolverIDNA{Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, errors.New("should not happen") + }, + }} + // See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/ + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "xn--0000h") + if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected no response here") + } + }) + }) } -func TestResolverSystemWithTimeoutAndSuccess(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(1) - r := &resolverSystem{ - testableTimeout: 1 * time.Microsecond, - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - defer wg.Done() - time.Sleep(1 * time.Millisecond) - return []string{"8.8.8.8"}, nil - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("invalid addrs") - } - wg.Wait() -} - -func TestResolverSystemWithTimeoutAndFailure(t *testing.T) { - wg := &sync.WaitGroup{} - wg.Add(1) - r := &resolverSystem{ - testableTimeout: 1 * time.Microsecond, - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - defer wg.Done() - time.Sleep(1 * time.Millisecond) - return nil, errors.New("no such host") - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("invalid addrs") - } - wg.Wait() -} - -func TestResolverSystemWithNXDOMAIN(t *testing.T) { - r := &resolverSystem{ - testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, errors.New("no such host") - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "example.antani") - if err == nil || !strings.HasSuffix(err.Error(), "no such host") { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("invalid addrs") - } -} - -func TestResolverLoggerWithSuccess(t *testing.T) { - expected := []string{"1.1.1.1"} - r := resolverLogger{ - Logger: log.Log, - Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return expected, nil - }, - }, - } - addrs, err := r.LookupHost(context.Background(), "dns.google") - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(expected, addrs); diff != "" { - t.Fatal(diff) - } -} - -func TestResolverLoggerWithFailure(t *testing.T) { - expected := errors.New("mocked error") - r := resolverLogger{ - Logger: log.Log, - Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, expected - }, - }, - } - addrs, err := r.LookupHost(context.Background(), "dns.google") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("expected nil addr here") - } -} - -func TestResolverIDNAWorksAsIntended(t *testing.T) { - expectedIPs := []string{"77.88.55.66"} - r := &resolverIDNA{ - Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - if domain != "xn--d1acpjx3f.xn--p1ai" { - return nil, errors.New("passed invalid domain") - } - return expectedIPs, nil - }, - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "яндекс.рф") - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(expectedIPs, addrs); diff != "" { - t.Fatal(diff) - } -} - -func TestResolverIDNAWithInvalidPunycode(t *testing.T) { - r := &resolverIDNA{Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, errors.New("should not happen") - }, - }} - // See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/ - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "xn--0000h") - if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected no response here") - } -} - -func TestNewResolverTypeChain(t *testing.T) { - r := NewResolverSystem(log.Log) - ridna, ok := r.(*resolverIDNA) - if !ok { - t.Fatal("invalid resolver") - } - rl, ok := ridna.Resolver.(*resolverLogger) - if !ok { - t.Fatal("invalid resolver") - } - if rl.Logger != log.Log { +func TestNewResolverSystem(t *testing.T) { + resolver := NewResolverSystem(log.Log) + idna := resolver.(*resolverIDNA) + logger := idna.Resolver.(*resolverLogger) + if logger.Logger != log.Log { t.Fatal("invalid logger") } - scia, ok := rl.Resolver.(*resolverShortCircuitIPAddr) - if !ok { - t.Fatal("invalid resolver") - } - ew, ok := scia.Resolver.(*resolverErrWrapper) - if !ok { - t.Fatal("invalid resolver") - } - if _, ok := ew.Resolver.(*resolverSystem); !ok { - t.Fatal("invalid resolver") - } + shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) + errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) + _ = errWrapper.Resolver.(*resolverSystem) } -func TestResolverShortCircuitIPAddrWithIPAddr(t *testing.T) { - r := &resolverShortCircuitIPAddr{ - Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, errors.New("mocked error") - }, - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "8.8.8.8") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "8.8.8.8" { - t.Fatal("invalid result") - } +func TestResolverShortCircuitIPAddr(t *testing.T) { + t.Run("LookupHost", func(t *testing.T) { + t.Run("with IP addr", func(t *testing.T) { + r := &resolverShortCircuitIPAddr{ + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, errors.New("mocked error") + }, + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "8.8.8.8") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("invalid result") + } + }) + + t.Run("with domain", func(t *testing.T) { + r := &resolverShortCircuitIPAddr{ + Resolver: &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, errors.New("mocked error") + }, + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != "mocked error" { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("invalid result") + } + }) + }) } -func TestResolverShortCircuitIPAddrWithDomain(t *testing.T) { - r := &resolverShortCircuitIPAddr{ - Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, errors.New("mocked error") - }, - }, - } - ctx := context.Background() - addrs, err := r.LookupHost(ctx, "dns.google") - if err == nil || err.Error() != "mocked error" { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("invalid result") - } -} - -func TestNullResolverWorksAsIntended(t *testing.T) { +func TestNullResolver(t *testing.T) { r := &nullResolver{} ctx := context.Background() addrs, err := r.LookupHost(ctx, "dns.google") diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 8d18d3d..9a85dbf 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -118,354 +118,357 @@ func TestConfigureTLSVersion(t *testing.T) { } } -func TestTLSHandshakerConfigurableWithError(t *testing.T) { - var times []time.Time - h := &tlsHandshakerConfigurable{} - tcpConn := &mocks.Conn{ - MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockSetDeadline: func(t time.Time) error { - times = append(times, t) - return nil - }, - } - ctx := context.Background() - conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{ - ServerName: "x.org", - }) - if err != io.EOF { - t.Fatal("not the error that we expected") - } - if conn != nil { - t.Fatal("expected nil con here") - } - if len(times) != 2 { - t.Fatal("expected two time entries") - } - if !times[0].After(time.Now()) { - t.Fatal("timeout not in the future") - } - if !times[1].IsZero() { - t.Fatal("did not clear timeout on exit") - } -} +func TestTLSHandshakerConfigurable(t *testing.T) { + t.Run("Handshake", func(t *testing.T) { + t.Run("with error", func(t *testing.T) { -func TestTLSHandshakerConfigurableSuccess(t *testing.T) { - handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(200) - }) - srvr := httptest.NewTLSServer(handler) - defer srvr.Close() - URL, err := url.Parse(srvr.URL) - if err != nil { - t.Fatal(err) - } - conn, err := net.Dial("tcp", URL.Host) - if err != nil { - t.Fatal(err) - } - defer conn.Close() - handshaker := &tlsHandshakerConfigurable{} - ctx := context.Background() - config := &tls.Config{ - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS13, - MaxVersion: tls.VersionTLS13, - ServerName: URL.Hostname(), - } - tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) - if err != nil { - t.Fatal(err) - } - defer tlsConn.Close() - if connState.Version != tls.VersionTLS13 { - t.Fatal("unexpected TLS version") - } -} - -func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) { - expected := errors.New("mocked error") - var gotTLSConfig *tls.Config - handshaker := &tlsHandshakerConfigurable{ - NewConn: func(conn net.Conn, config *tls.Config) TLSConn { - gotTLSConfig = config - return &mocks.TLSConn{ - MockHandshakeContext: func(ctx context.Context) error { - return expected + var times []time.Time + h := &tlsHandshakerConfigurable{} + tcpConn := &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockSetDeadline: func(t time.Time) error { + times = append(times, t) + return nil }, } - }, - } - ctx := context.Background() - config := &tls.Config{} - conn := &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - } - tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero connState here") - } - if tlsConn != nil { - t.Fatal("expected nil tlsConn here") - } - if config.RootCAs != nil { - t.Fatal("config.RootCAs should still be nil") - } - if gotTLSConfig.RootCAs != defaultCertPool { - t.Fatal("gotTLSConfig.RootCAs has not been correctly set") - } + ctx := context.Background() + conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{ + ServerName: "x.org", + }) + if err != io.EOF { + t.Fatal("not the error that we expected") + } + if conn != nil { + t.Fatal("expected nil con here") + } + if len(times) != 2 { + t.Fatal("expected two time entries") + } + if !times[0].After(time.Now()) { + t.Fatal("timeout not in the future") + } + if !times[1].IsZero() { + t.Fatal("did not clear timeout on exit") + } + }) + + t.Run("with success", func(t *testing.T) { + handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(200) + }) + srvr := httptest.NewTLSServer(handler) + defer srvr.Close() + URL, err := url.Parse(srvr.URL) + if err != nil { + t.Fatal(err) + } + conn, err := net.Dial("tcp", URL.Host) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + handshaker := &tlsHandshakerConfigurable{} + ctx := context.Background() + config := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + ServerName: URL.Hostname(), + } + tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + if err != nil { + t.Fatal(err) + } + defer tlsConn.Close() + if connState.Version != tls.VersionTLS13 { + t.Fatal("unexpected TLS version") + } + }) + + t.Run("sets default root CA", func(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + handshaker := &tlsHandshakerConfigurable{ + NewConn: func(conn net.Conn, config *tls.Config) TLSConn { + gotTLSConfig = config + return &mocks.TLSConn{ + MockHandshakeContext: func(ctx context.Context) error { + return expected + }, + } + }, + } + ctx := context.Background() + config := &tls.Config{} + conn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + } + tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero connState here") + } + if tlsConn != nil { + t.Fatal("expected nil tlsConn here") + } + if config.RootCAs != nil { + t.Fatal("config.RootCAs should still be nil") + } + if gotTLSConfig.RootCAs != defaultCertPool { + t.Fatal("gotTLSConfig.RootCAs has not been correctly set") + } + }) + }) } -func TestTLSHandshakerLoggerSuccess(t *testing.T) { - th := &tlsHandshakerLogger{ - TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return tls.Client(conn, config), tls.ConnectionState{}, nil +func TestTLSHandshakerLogger(t *testing.T) { + t.Run("Handshake", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + th := &tlsHandshakerLogger{ + TLSHandshaker: &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return tls.Client(conn, config), tls.ConnectionState{}, nil + }, + }, + Logger: log.Log, + } + conn := &mocks.Conn{ + MockClose: func() error { + return nil + }, + } + config := &tls.Config{} + ctx := context.Background() + tlsConn, connState, err := th.Handshake(ctx, conn, config) + if err != nil { + t.Fatal(err) + } + if err := tlsConn.Close(); err != nil { + t.Fatal(err) + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero ConnectionState here") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + th := &tlsHandshakerLogger{ + TLSHandshaker: &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expected + }, + }, + Logger: log.Log, + } + conn := &mocks.Conn{ + MockClose: func() error { + return nil + }, + } + config := &tls.Config{} + ctx := context.Background() + tlsConn, connState, err := th.Handshake(ctx, conn, config) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if tlsConn != nil { + t.Fatal("expected nil conn here") + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero ConnectionState here") + } + }) + }) +} + +func TestTLSDialer(t *testing.T) { + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + dialer := &tlsDialer{ + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, }, - }, - Logger: log.Log, - } - conn := &mocks.Conn{ - MockClose: func() error { - return nil - }, - } - config := &tls.Config{} - ctx := context.Background() - tlsConn, connState, err := th.Handshake(ctx, conn, config) - if err != nil { - t.Fatal(err) - } - if err := tlsConn.Close(); err != nil { - t.Fatal(err) - } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero ConnectionState here") - } + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("DialTLSContext", func(t *testing.T) { + t.Run("failure to split host and port", func(t *testing.T) { + dialer := &tlsDialer{} + ctx := context.Background() + const address = "www.google.com" // missing port + conn, err := dialer.DialTLSContext(ctx, "tcp", address) + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("connection is not nil") + } + }) + + t.Run("failure dialing", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // immediately fail + dialer := tlsDialer{Dialer: &dialerSystem{}} + conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") + if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("connection is not nil") + } + }) + + t.Run("failure handshaking", func(t *testing.T) { + ctx := context.Background() + dialer := tlsDialer{ + Config: &tls.Config{}, + Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, MockClose: func() error { + return nil + }, MockSetDeadline: func(t time.Time) error { + return nil + }}, nil + }}, + TLSHandshaker: &tlsHandshakerConfigurable{}, + } + conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("connection is not nil") + } + }) + + t.Run("success handshaking", func(t *testing.T) { + ctx := context.Background() + dialer := tlsDialer{ + Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, MockClose: func() error { + return nil + }, MockSetDeadline: func(t time.Time) error { + return nil + }}, nil + }}, + TLSHandshaker: &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return tls.Client(conn, config), tls.ConnectionState{}, nil + }, + }, + } + conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("connection is nil") + } + conn.Close() + }) + }) + + t.Run("config", func(t *testing.T) { + t.Run("from empty config for web", func(t *testing.T) { + d := &tlsDialer{} + config := d.config("www.google.com", "443") + if config.ServerName != "www.google.com" { + t.Fatal("invalid server name") + } + if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("from empty config for dot", func(t *testing.T) { + d := &tlsDialer{} + config := d.config("dns.google", "853") + if config.ServerName != "dns.google" { + t.Fatal("invalid server name") + } + if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("with server name", func(t *testing.T) { + d := &tlsDialer{ + Config: &tls.Config{ + ServerName: "example.com", + }, + } + config := d.config("dns.google", "853") + if config.ServerName != "example.com" { + t.Fatal("invalid server name") + } + if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("with alpn", func(t *testing.T) { + d := &tlsDialer{ + Config: &tls.Config{ + NextProtos: []string{"h2"}, + }, + } + config := d.config("dns.google", "853") + if config.ServerName != "dns.google" { + t.Fatal("invalid server name") + } + if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" { + t.Fatal(diff) + } + }) + }) } -func TestTLSHandshakerLoggerFailure(t *testing.T) { - expected := errors.New("mocked error") - th := &tlsHandshakerLogger{ - TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, expected - }, - }, - Logger: log.Log, - } - conn := &mocks.Conn{ - MockClose: func() error { - return nil - }, - } - config := &tls.Config{} - ctx := context.Background() - tlsConn, connState, err := th.Handshake(ctx, conn, config) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if tlsConn != nil { - t.Fatal("expected nil conn here") - } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero ConnectionState here") - } -} - -func TestTLSDialerCloseIdleConnections(t *testing.T) { - var called bool - dialer := &tlsDialer{ - Dialer: &mocks.Dialer{ - MockCloseIdleConnections: func() { - called = true - }, - }, - } - dialer.CloseIdleConnections() - if !called { - t.Fatal("not called") - } -} - -func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { - dialer := &tlsDialer{} - ctx := context.Background() - const address = "www.google.com" // missing port - conn, err := dialer.DialTLSContext(ctx, "tcp", address) - if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { - t.Fatal("not the error we expected", err) - } - if conn != nil { - t.Fatal("connection is not nil") - } -} - -func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() // immediately fail - dialer := tlsDialer{Dialer: &dialerSystem{}} - conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") - if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { - t.Fatal("not the error we expected", err) - } - if conn != nil { - t.Fatal("connection is not nil") - } -} - -func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { - ctx := context.Background() - dialer := tlsDialer{ - Config: &tls.Config{}, - Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, MockClose: func() error { - return nil - }, MockSetDeadline: func(t time.Time) error { - return nil - }}, nil - }}, - TLSHandshaker: &tlsHandshakerConfigurable{}, - } - conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected", err) - } - if conn != nil { - t.Fatal("connection is not nil") - } -} - -func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { - ctx := context.Background() - dialer := tlsDialer{ - Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, MockClose: func() error { - return nil - }, MockSetDeadline: func(t time.Time) error { - return nil - }}, nil - }}, - TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return tls.Client(conn, config), tls.ConnectionState{}, nil - }, - }, - } - conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - if conn == nil { - t.Fatal("connection is nil") - } - conn.Close() -} - -func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { - d := &tlsDialer{} - config := d.config("www.google.com", "443") - if config.ServerName != "www.google.com" { - t.Fatal("invalid server name") - } - if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" { - t.Fatal(diff) - } -} - -func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { - d := &tlsDialer{} - config := d.config("dns.google", "853") - if config.ServerName != "dns.google" { - t.Fatal("invalid server name") - } - if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { - t.Fatal(diff) - } -} - -func TestTLSDialerConfigWithServerName(t *testing.T) { - d := &tlsDialer{ - Config: &tls.Config{ - ServerName: "example.com", - }, - } - config := d.config("dns.google", "853") - if config.ServerName != "example.com" { - t.Fatal("invalid server name") - } - if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { - t.Fatal(diff) - } -} - -func TestTLSDialerConfigWithALPN(t *testing.T) { - d := &tlsDialer{ - Config: &tls.Config{ - NextProtos: []string{"h2"}, - }, - } - config := d.config("dns.google", "853") - if config.ServerName != "dns.google" { - t.Fatal("invalid server name") - } - if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" { - t.Fatal(diff) - } -} - -func TestNewTLSHandshakerStdlibTypes(t *testing.T) { +func TestNewTLSHandshakerStdlib(t *testing.T) { th := NewTLSHandshakerStdlib(log.Log) - thl, okay := th.(*tlsHandshakerLogger) - if !okay { - t.Fatal("invalid type") - } - if thl.Logger != log.Log { + logger := th.(*tlsHandshakerLogger) + if logger.Logger != log.Log { t.Fatal("invalid logger") } - ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper) - if !okay { - t.Fatal("invalid type") - } - thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable) - if !okay { - t.Fatal("invalid type") - } - if thc.NewConn != nil { + errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper) + configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable) + if configurable.NewConn != nil { t.Fatal("expected nil NewConn") } } -func TestNewTLSDialerWorksAsIntended(t *testing.T) { +func TestNewTLSDialer(t *testing.T) { d := &mocks.Dialer{} - tlsh := &mocks.TLSHandshaker{} - td := NewTLSDialer(d, tlsh) - tdut, okay := td.(*tlsDialer) - if !okay { - t.Fatal("invalid type") - } - if tdut.Config == nil { + th := &mocks.TLSHandshaker{} + dialer := NewTLSDialer(d, th) + tlsd := dialer.(*tlsDialer) + if tlsd.Config == nil { t.Fatal("unexpected config") } - if tdut.Dialer != d { + if tlsd.Dialer != d { t.Fatal("unexpected dialer") } - if tdut.TLSHandshaker != tlsh { + if tlsd.TLSHandshaker != th { t.Fatal("invalid handshaker") } } -func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) { +func TestNewSingleUseTLSDialer(t *testing.T) { conn := &mocks.TLSConn{} d := NewSingleUseTLSDialer(conn) outconn, err := d.DialTLSContext(context.Background(), "", "") diff --git a/internal/netxlite/utls_test.go b/internal/netxlite/utls_test.go index cf2e894..2b3a7ea 100644 --- a/internal/netxlite/utls_test.go +++ b/internal/netxlite/utls_test.go @@ -2,9 +2,7 @@ package netxlite import ( "context" - "crypto/tls" "errors" - "net" "sync" "testing" "time" @@ -13,107 +11,84 @@ import ( utls "gitlab.com/yawning/utls.git" ) -func TestUTLSHandshakerChrome(t *testing.T) { - h := &tlsHandshakerConfigurable{ - NewConn: newConnUTLS(&utls.HelloChrome_Auto), - } - cfg := &tls.Config{ServerName: "google.com"} - conn, err := net.Dial("tcp", "google.com:443") - if err != nil { - t.Fatal("unexpected error", err) - } - conn, _, err = h.Handshake(context.Background(), conn, cfg) - if err != nil { - t.Fatal("unexpected error", err) - } - if conn == nil { - t.Fatal("nil connection") - } -} - -func TestNewTLSHandshakerUTLSTypes(t *testing.T) { +func TestNewTLSHandshakerUTLS(t *testing.T) { th := NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_83) - thl, okay := th.(*tlsHandshakerLogger) - if !okay { - t.Fatal("invalid type") - } - if thl.Logger != log.Log { + logger := th.(*tlsHandshakerLogger) + if logger.Logger != log.Log { t.Fatal("invalid logger") } - ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper) - if !okay { - t.Fatal("invalid type") - } - thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable) - if !okay { - t.Fatal("invalid type") - } - if thc.NewConn == nil { + errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper) + configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable) + if configurable.NewConn == nil { t.Fatal("expected non-nil NewConn") } } -func TestUTLSConnHandshakeNotInterruptedSuccess(t *testing.T) { - ctx := context.Background() - conn := &utlsConn{ - testableHandshake: func() error { - return nil - }, - } - err := conn.HandshakeContext(ctx) - if err != nil { - t.Fatal(err) - } -} +func TestUTLSConn(t *testing.T) { + t.Run("Handshake", func(t *testing.T) { + t.Run("not interrupted with success", func(t *testing.T) { + ctx := context.Background() + conn := &utlsConn{ + testableHandshake: func() error { + return nil + }, + } + err := conn.HandshakeContext(ctx) + if err != nil { + t.Fatal(err) + } + }) -func TestUTLSConnHandshakeNotInterruptedFailure(t *testing.T) { - expected := errors.New("mocked error") - ctx := context.Background() - conn := &utlsConn{ - testableHandshake: func() error { - return expected - }, - } - err := conn.HandshakeContext(ctx) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } -} + t.Run("not interrupted with failure", func(t *testing.T) { + expected := errors.New("mocked error") + ctx := context.Background() + conn := &utlsConn{ + testableHandshake: func() error { + return expected + }, + } + err := conn.HandshakeContext(ctx) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + }) -func TestUTLSConnHandshakeInterrupted(t *testing.T) { - wg := sync.WaitGroup{} - wg.Add(1) - sigch := make(chan interface{}) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) - defer cancel() - conn := &utlsConn{ - testableHandshake: func() error { - defer wg.Done() - <-sigch - return nil - }, - } - err := conn.HandshakeContext(ctx) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("not the error we expected", err) - } - close(sigch) - wg.Wait() -} + t.Run("interrupted", func(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) + sigch := make(chan interface{}) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + conn := &utlsConn{ + testableHandshake: func() error { + defer wg.Done() + <-sigch + return nil + }, + } + err := conn.HandshakeContext(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("not the error we expected", err) + } + close(sigch) + wg.Wait() + }) -func TestUTLSConnHandshakePanic(t *testing.T) { - wg := sync.WaitGroup{} - wg.Add(1) - ctx := context.Background() - conn := &utlsConn{ - testableHandshake: func() error { - defer wg.Done() - panic("mascetti") - }, - } - err := conn.HandshakeContext(ctx) - if !errors.Is(err, ErrUTLSHandshakePanic) { - t.Fatal("not the error we expected", err) - } - wg.Wait() + t.Run("with panic", func(t *testing.T) { + wg := sync.WaitGroup{} + wg.Add(1) + ctx := context.Background() + conn := &utlsConn{ + testableHandshake: func() error { + defer wg.Done() + panic("mascetti") + }, + } + err := conn.HandshakeContext(ctx) + if !errors.Is(err, ErrUTLSHandshakePanic) { + t.Fatal("not the error we expected", err) + } + wg.Wait() + }) + }) }