refactor(netxlite): finish grouping tests (#488)

They are now more readable. I'll do another pass and start
separating integration testing from unit testing.

I think we need to have some always on integration testing
for netxlite that runs on macOS, linux, and windows.

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-08 11:39:27 +02:00 committed by GitHub
parent 493b72b170
commit f2e3e5cc08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1140 additions and 1128 deletions

View File

@ -1,12 +1,15 @@
package netxlite_test package netxlite_test
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net"
"net/http" "net/http"
"testing" "testing"
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
utls "gitlab.com/yawning/utls.git"
) )
func TestHTTPTransport(t *testing.T) { func TestHTTPTransport(t *testing.T) {
@ -49,3 +52,21 @@ func TestHTTP3Transport(t *testing.T) {
txp.CloseIdleConnections() txp.CloseIdleConnections()
}) })
} }
func TestUTLSHandshaker(t *testing.T) {
t.Run("with chrome fingerprint", func(t *testing.T) {
h := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_Auto)
cfg := &tls.Config{ServerName: "google.com"}
conn, err := net.Dial("tcp", "google.com:443")
if err != nil {
t.Fatal("unexpected error", err)
}
conn, _, err = h.Handshake(context.Background(), conn, cfg)
if err != nil {
t.Fatal("unexpected error", err)
}
if conn == nil {
t.Fatal("nil connection")
}
})
}

View File

@ -7,7 +7,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestResolverLegacyAdapterWithCompatibleType(t *testing.T) { func TestResolverLegacyAdapter(t *testing.T) {
t.Run("with compatible type", func(t *testing.T) {
var called bool var called bool
r := NewResolverLegacyAdapter(&mocks.Resolver{ r := NewResolverLegacyAdapter(&mocks.Resolver{
MockNetwork: func() string { MockNetwork: func() string {
@ -30,9 +31,9 @@ func TestResolverLegacyAdapterWithCompatibleType(t *testing.T) {
if !called { if !called {
t.Fatal("not called") t.Fatal("not called")
} }
} })
func TestResolverLegacyAdapterDefaults(t *testing.T) { t.Run("with incompatible type", func(t *testing.T) {
r := NewResolverLegacyAdapter(&net.Resolver{}) r := NewResolverLegacyAdapter(&net.Resolver{})
if r.Network() != "adapter" { if r.Network() != "adapter" {
t.Fatal("invalid Network") t.Fatal("invalid Network")
@ -41,9 +42,11 @@ func TestResolverLegacyAdapterDefaults(t *testing.T) {
t.Fatal("invalid Address") t.Fatal("invalid Address")
} }
r.CloseIdleConnections() // does not crash r.CloseIdleConnections() // does not crash
})
} }
func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) { func TestDialerLegacyAdapter(t *testing.T) {
t.Run("with compatible type", func(t *testing.T) {
var called bool var called bool
r := NewDialerLegacyAdapter(&mocks.Dialer{ r := NewDialerLegacyAdapter(&mocks.Dialer{
MockCloseIdleConnections: func() { MockCloseIdleConnections: func() {
@ -54,14 +57,16 @@ func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) {
if !called { if !called {
t.Fatal("not called") t.Fatal("not called")
} }
} })
func TestDialerLegacyAdapterDefaults(t *testing.T) { t.Run("with incompatible type", func(t *testing.T) {
r := NewDialerLegacyAdapter(&net.Dialer{}) r := NewDialerLegacyAdapter(&net.Dialer{})
r.CloseIdleConnections() // does not crash r.CloseIdleConnections() // does not crash
})
} }
func TestQUICContextDialerAdapterWithCompatibleType(t *testing.T) { func TestQUICContextDialerAdapter(t *testing.T) {
t.Run("with compatible type", func(t *testing.T) {
var called bool var called bool
d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICDialer{ d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICDialer{
MockCloseIdleConnections: func() { MockCloseIdleConnections: func() {
@ -72,9 +77,10 @@ func TestQUICContextDialerAdapterWithCompatibleType(t *testing.T) {
if !called { if !called {
t.Fatal("not called") t.Fatal("not called")
} }
} })
func TestQUICContextDialerAdapterDefaults(t *testing.T) { t.Run("with incompatible type", func(t *testing.T) {
d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICContextDialer{}) d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICContextDialer{})
d.CloseIdleConnections() // does not crash d.CloseIdleConnections() // does not crash
})
} }

View File

@ -17,7 +17,9 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
) )
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) { func TestQUICDialerQUICGo(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("cannot split host port", func(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: "www.google.com", ServerName: "www.google.com",
} }
@ -34,9 +36,9 @@ func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil sess here") t.Fatal("expected nil sess here")
} }
} })
func TestQUICDialerQUICGoInvalidPort(t *testing.T) { t.Run("with invalid port", func(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: "www.google.com", ServerName: "www.google.com",
} }
@ -52,9 +54,9 @@ func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil sess here") t.Fatal("expected nil sess here")
} }
} })
func TestQUICDialerQUICGoInvalidIP(t *testing.T) { t.Run("with invalid IP", func(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: "www.google.com", ServerName: "www.google.com",
} }
@ -70,9 +72,9 @@ func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil sess here") t.Fatal("expected nil sess here")
} }
} })
func TestQUICDialerQUICGoCannotListen(t *testing.T) { t.Run("with listen error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: "www.google.com", ServerName: "www.google.com",
@ -93,9 +95,9 @@ func TestQUICDialerQUICGoCannotListen(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil sess here") t.Fatal("expected nil sess here")
} }
} })
func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) { t.Run("with handshake failure", func(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: "dns.google", ServerName: "dns.google",
} }
@ -112,9 +114,12 @@ func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) {
if sess != nil { if sess != nil {
log.Fatal("expected nil session here") log.Fatal("expected nil session here")
} }
} })
func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) { t.Run("works as intended", func(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: "dns.google", ServerName: "dns.google",
} }
@ -131,9 +136,9 @@ func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) {
if err := sess.CloseWithError(0, ""); err != nil { if err := sess.CloseWithError(0, ""); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} })
func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) { t.Run("TLS defaults for web", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
var gotTLSConfig *tls.Config var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@ -172,9 +177,9 @@ func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) {
if tlsConfig.ServerName != gotTLSConfig.ServerName { if tlsConfig.ServerName != gotTLSConfig.ServerName {
t.Fatal("the ServerName field must match") t.Fatal("the ServerName field must match")
} }
} })
func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) { t.Run("TLS defaults for DoQ", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
var gotTLSConfig *tls.Config var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@ -213,9 +218,13 @@ func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
if tlsConfig.ServerName != gotTLSConfig.ServerName { if tlsConfig.ServerName != gotTLSConfig.ServerName {
t.Fatal("the ServerName field must match") t.Fatal("the ServerName field must match")
} }
})
})
} }
func TestQUICDialerResolverCloseIdleConnections(t *testing.T) { func TestQUICDialerResolver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var ( var (
forDialer bool forDialer bool
forResolver bool forResolver bool
@ -236,9 +245,10 @@ func TestQUICDialerResolverCloseIdleConnections(t *testing.T) {
if !forDialer || !forResolver { if !forDialer || !forResolver {
t.Fatal("not called") t.Fatal("not called")
} }
} })
func TestQUICDialerResolverSuccess(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: NewResolverSystem(log.Log), Resolver: NewResolverSystem(log.Log),
@ -255,9 +265,9 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
if err := sess.CloseWithError(0, ""); err != nil { if err := sess.CloseWithError(0, ""); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} })
func TestQUICDialerResolverNoPort(t *testing.T) { t.Run("with missing port", func(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: NewResolverSystem(log.Log), Resolver: NewResolverSystem(log.Log),
@ -271,26 +281,9 @@ func TestQUICDialerResolverNoPort(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected a nil sess here") t.Fatal("expected a nil sess here")
} }
} })
func TestQUICDialerResolverLookupHostAddress(t *testing.T) { t.Run("with lookup host failure", func(t *testing.T) {
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
// We should not arrive here and call this function but if we do then
// there is going to be an error that fails this test.
return nil, errors.New("mocked error")
},
}}
addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
t.Fatal("not the result we expected")
}
}
func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
expected := errors.New("mocked error") expected := errors.New("mocked error")
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{ dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
@ -307,9 +300,9 @@ func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil sess") t.Fatal("expected nil sess")
} }
} })
func TestQUICDialerResolverInvalidPort(t *testing.T) { t.Run("with invalid port", func(t *testing.T) {
// This test allows us to check for the case where every attempt // This test allows us to check for the case where every attempt
// to establish a connection leads to a failure // to establish a connection leads to a failure
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
@ -331,9 +324,9 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil sess") t.Fatal("expected nil sess")
} }
} })
func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) { t.Run("we apply TLS defaults", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
var gotTLSConfig *tls.Config var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
@ -361,9 +354,30 @@ func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
if gotTLSConfig.ServerName != "www.google.com" { if gotTLSConfig.ServerName != "www.google.com" {
t.Fatal("gotTLSConfig.ServerName has not been set") t.Fatal("gotTLSConfig.ServerName has not been set")
} }
})
})
t.Run("lookup host with address", func(t *testing.T) {
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
// We should not arrive here and call this function but if we do then
// there is going to be an error that fails this test.
return nil, errors.New("mocked error")
},
}}
addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
t.Fatal("not the result we expected")
}
})
} }
func TestQUICDialerLoggerCloseIdleConnections(t *testing.T) { func TestQUICLoggerDialer(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var forDialer bool var forDialer bool
d := &quicDialerLogger{ d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{ Dialer: &mocks.QUICDialer{
@ -376,9 +390,10 @@ func TestQUICDialerLoggerCloseIdleConnections(t *testing.T) {
if !forDialer { if !forDialer {
t.Fatal("not called") t.Fatal("not called")
} }
} })
func TestQUICDialerLoggerSuccess(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
d := &quicDialerLogger{ d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{ Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network string, MockDialContext: func(ctx context.Context, network string,
@ -404,9 +419,9 @@ func TestQUICDialerLoggerSuccess(t *testing.T) {
if err := sess.CloseWithError(0, ""); err != nil { if err := sess.CloseWithError(0, ""); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} })
func TestQUICDialerLoggerFailure(t *testing.T) { t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
d := &quicDialerLogger{ d := &quicDialerLogger{
Dialer: &mocks.QUICDialer{ Dialer: &mocks.QUICDialer{
@ -428,46 +443,33 @@ func TestQUICDialerLoggerFailure(t *testing.T) {
if sess != nil { if sess != nil {
t.Fatal("expected nil session") t.Fatal("expected nil session")
} }
})
})
} }
func TestNewQUICDialerWithoutResolverChain(t *testing.T) { func TestNewQUICDialer(t *testing.T) {
ql := NewQUICListener() ql := NewQUICListener()
dlr := NewQUICDialerWithoutResolver(ql, log.Log) dlr := NewQUICDialerWithoutResolver(ql, log.Log)
dlog, okay := dlr.(*quicDialerLogger) logger := dlr.(*quicDialerLogger)
if !okay { if logger.Logger != log.Log {
t.Fatal("invalid type")
}
if dlog.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
dr, okay := dlog.Dialer.(*quicDialerResolver) resolver := logger.Dialer.(*quicDialerResolver)
if !okay { if _, okay := resolver.Resolver.(*nullResolver); !okay {
t.Fatal("invalid type")
}
if _, okay := dr.Resolver.(*nullResolver); !okay {
t.Fatal("invalid resolver type") t.Fatal("invalid resolver type")
} }
dlog, okay = dr.Dialer.(*quicDialerLogger) logger = resolver.Dialer.(*quicDialerLogger)
if !okay { if logger.Logger != log.Log {
t.Fatal("invalid type")
}
if dlog.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
ew, okay := dlog.Dialer.(*quicDialerErrWrapper) errWrapper := logger.Dialer.(*quicDialerErrWrapper)
if !okay { base := errWrapper.QUICDialer.(*quicDialerQUICGo)
t.Fatal("invalid type") if base.QUICListener != ql {
}
dgo, okay := ew.QUICDialer.(*quicDialerQUICGo)
if !okay {
t.Fatal("invalid type")
}
if dgo.QUICListener != ql {
t.Fatal("invalid quic listener") t.Fatal("invalid quic listener")
} }
} }
func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) { func TestNewSingleUseQUICDialer(t *testing.T) {
sess := &mocks.QUICEarlySession{} sess := &mocks.QUICEarlySession{}
qd := NewSingleUseQUICDialer(sess) qd := NewSingleUseQUICDialer(sess)
outsess, err := qd.DialContext( outsess, err := qd.DialContext(

View File

@ -15,6 +15,7 @@ func TestQuirkReduceErrors(t *testing.T) {
t.Fatal("wrong result") t.Fatal("wrong result")
} }
}) })
t.Run("single error", func(t *testing.T) { t.Run("single error", func(t *testing.T) {
err := errors.New("mocked error") err := errors.New("mocked error")
result := quirkReduceErrors([]error{err}) result := quirkReduceErrors([]error{err})
@ -22,6 +23,7 @@ func TestQuirkReduceErrors(t *testing.T) {
t.Fatal("wrong result") t.Fatal("wrong result")
} }
}) })
t.Run("multiple errors", func(t *testing.T) { t.Run("multiple errors", func(t *testing.T) {
err1 := errors.New("mocked error #1") err1 := errors.New("mocked error #1")
err2 := errors.New("mocked error #2") err2 := errors.New("mocked error #2")
@ -30,6 +32,7 @@ func TestQuirkReduceErrors(t *testing.T) {
t.Fatal("wrong result") t.Fatal("wrong result")
} }
}) })
t.Run("multiple errors with meaningful ones", func(t *testing.T) { t.Run("multiple errors with meaningful ones", func(t *testing.T) {
err1 := errors.New("mocked error #1") err1 := errors.New("mocked error #1")
err2 := &errorsx.ErrWrapper{ err2 := &errorsx.ErrWrapper{

View File

@ -15,7 +15,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestResolverSystemNetworkAddress(t *testing.T) { func TestResolverSystem(t *testing.T) {
t.Run("Network and Address", func(t *testing.T) {
r := &resolverSystem{} r := &resolverSystem{}
if r.Network() != "system" { if r.Network() != "system" {
t.Fatal("invalid Network") t.Fatal("invalid Network")
@ -23,9 +24,9 @@ func TestResolverSystemNetworkAddress(t *testing.T) {
if r.Address() != "" { if r.Address() != "" {
t.Fatal("invalid Address") t.Fatal("invalid Address")
} }
} })
func TestResolverSystemWorksAsIntended(t *testing.T) { t.Run("works as intended", func(t *testing.T) {
r := &resolverSystem{} r := &resolverSystem{}
defer r.CloseIdleConnections() defer r.CloseIdleConnections()
addrs, err := r.LookupHost(context.Background(), "dns.google.com") addrs, err := r.LookupHost(context.Background(), "dns.google.com")
@ -35,16 +36,17 @@ func TestResolverSystemWorksAsIntended(t *testing.T) {
if addrs == nil { if addrs == nil {
t.Fatal("expected non-nil result here") t.Fatal("expected non-nil result here")
} }
} })
func TestResolverSystemDefaultTimeout(t *testing.T) { t.Run("check default timeout", func(t *testing.T) {
r := &resolverSystem{} r := &resolverSystem{}
if r.timeout() != 15*time.Second { if r.timeout() != 15*time.Second {
t.Fatal("unexpected default timeout") t.Fatal("unexpected default timeout")
} }
} })
func TestResolverSystemWithTimeoutAndSuccess(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("with timeout and success", func(t *testing.T) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(1) wg.Add(1)
r := &resolverSystem{ r := &resolverSystem{
@ -64,9 +66,9 @@ func TestResolverSystemWithTimeoutAndSuccess(t *testing.T) {
t.Fatal("invalid addrs") t.Fatal("invalid addrs")
} }
wg.Wait() wg.Wait()
} })
func TestResolverSystemWithTimeoutAndFailure(t *testing.T) { t.Run("with timeout and failure", func(t *testing.T) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(1) wg.Add(1)
r := &resolverSystem{ r := &resolverSystem{
@ -86,9 +88,9 @@ func TestResolverSystemWithTimeoutAndFailure(t *testing.T) {
t.Fatal("invalid addrs") t.Fatal("invalid addrs")
} }
wg.Wait() wg.Wait()
} })
func TestResolverSystemWithNXDOMAIN(t *testing.T) { t.Run("with NXDOMAIN", func(t *testing.T) {
r := &resolverSystem{ r := &resolverSystem{
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) { testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, errors.New("no such host") return nil, errors.New("no such host")
@ -102,9 +104,13 @@ func TestResolverSystemWithNXDOMAIN(t *testing.T) {
if addrs != nil { if addrs != nil {
t.Fatal("invalid addrs") t.Fatal("invalid addrs")
} }
})
})
} }
func TestResolverLoggerWithSuccess(t *testing.T) { func TestResolverLogger(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("with success", func(t *testing.T) {
expected := []string{"1.1.1.1"} expected := []string{"1.1.1.1"}
r := resolverLogger{ r := resolverLogger{
Logger: log.Log, Logger: log.Log,
@ -121,9 +127,9 @@ func TestResolverLoggerWithSuccess(t *testing.T) {
if diff := cmp.Diff(expected, addrs); diff != "" { if diff := cmp.Diff(expected, addrs); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} })
func TestResolverLoggerWithFailure(t *testing.T) { t.Run("with failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := resolverLogger{ r := resolverLogger{
Logger: log.Log, Logger: log.Log,
@ -140,9 +146,13 @@ func TestResolverLoggerWithFailure(t *testing.T) {
if addrs != nil { if addrs != nil {
t.Fatal("expected nil addr here") t.Fatal("expected nil addr here")
} }
})
})
} }
func TestResolverIDNAWorksAsIntended(t *testing.T) { func TestResolverIDNA(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("with valid IDNA in input", func(t *testing.T) {
expectedIPs := []string{"77.88.55.66"} expectedIPs := []string{"77.88.55.66"}
r := &resolverIDNA{ r := &resolverIDNA{
Resolver: &mocks.Resolver{ Resolver: &mocks.Resolver{
@ -162,9 +172,9 @@ func TestResolverIDNAWorksAsIntended(t *testing.T) {
if diff := cmp.Diff(expectedIPs, addrs); diff != "" { if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} })
func TestResolverIDNAWithInvalidPunycode(t *testing.T) { t.Run("with invalid punycode", func(t *testing.T) {
r := &resolverIDNA{Resolver: &mocks.Resolver{ r := &resolverIDNA{Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, errors.New("should not happen") return nil, errors.New("should not happen")
@ -179,35 +189,25 @@ func TestResolverIDNAWithInvalidPunycode(t *testing.T) {
if addrs != nil { if addrs != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
})
})
} }
func TestNewResolverTypeChain(t *testing.T) { func TestNewResolverSystem(t *testing.T) {
r := NewResolverSystem(log.Log) resolver := NewResolverSystem(log.Log)
ridna, ok := r.(*resolverIDNA) idna := resolver.(*resolverIDNA)
if !ok { logger := idna.Resolver.(*resolverLogger)
t.Fatal("invalid resolver") if logger.Logger != log.Log {
}
rl, ok := ridna.Resolver.(*resolverLogger)
if !ok {
t.Fatal("invalid resolver")
}
if rl.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
scia, ok := rl.Resolver.(*resolverShortCircuitIPAddr) shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
if !ok { errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
t.Fatal("invalid resolver") _ = errWrapper.Resolver.(*resolverSystem)
}
ew, ok := scia.Resolver.(*resolverErrWrapper)
if !ok {
t.Fatal("invalid resolver")
}
if _, ok := ew.Resolver.(*resolverSystem); !ok {
t.Fatal("invalid resolver")
}
} }
func TestResolverShortCircuitIPAddrWithIPAddr(t *testing.T) { func TestResolverShortCircuitIPAddr(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("with IP addr", func(t *testing.T) {
r := &resolverShortCircuitIPAddr{ r := &resolverShortCircuitIPAddr{
Resolver: &mocks.Resolver{ Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
@ -223,9 +223,9 @@ func TestResolverShortCircuitIPAddrWithIPAddr(t *testing.T) {
if len(addrs) != 1 || addrs[0] != "8.8.8.8" { if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("invalid result") t.Fatal("invalid result")
} }
} })
func TestResolverShortCircuitIPAddrWithDomain(t *testing.T) { t.Run("with domain", func(t *testing.T) {
r := &resolverShortCircuitIPAddr{ r := &resolverShortCircuitIPAddr{
Resolver: &mocks.Resolver{ Resolver: &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
@ -241,9 +241,11 @@ func TestResolverShortCircuitIPAddrWithDomain(t *testing.T) {
if addrs != nil { if addrs != nil {
t.Fatal("invalid result") t.Fatal("invalid result")
} }
})
})
} }
func TestNullResolverWorksAsIntended(t *testing.T) { func TestNullResolver(t *testing.T) {
r := &nullResolver{} r := &nullResolver{}
ctx := context.Background() ctx := context.Background()
addrs, err := r.LookupHost(ctx, "dns.google") addrs, err := r.LookupHost(ctx, "dns.google")

View File

@ -118,7 +118,10 @@ func TestConfigureTLSVersion(t *testing.T) {
} }
} }
func TestTLSHandshakerConfigurableWithError(t *testing.T) { func TestTLSHandshakerConfigurable(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("with error", func(t *testing.T) {
var times []time.Time var times []time.Time
h := &tlsHandshakerConfigurable{} h := &tlsHandshakerConfigurable{}
tcpConn := &mocks.Conn{ tcpConn := &mocks.Conn{
@ -149,9 +152,9 @@ func TestTLSHandshakerConfigurableWithError(t *testing.T) {
if !times[1].IsZero() { if !times[1].IsZero() {
t.Fatal("did not clear timeout on exit") t.Fatal("did not clear timeout on exit")
} }
} })
func TestTLSHandshakerConfigurableSuccess(t *testing.T) { t.Run("with success", func(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200) rw.WriteHeader(200)
}) })
@ -182,9 +185,9 @@ func TestTLSHandshakerConfigurableSuccess(t *testing.T) {
if connState.Version != tls.VersionTLS13 { if connState.Version != tls.VersionTLS13 {
t.Fatal("unexpected TLS version") t.Fatal("unexpected TLS version")
} }
} })
func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) { t.Run("sets default root CA", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
var gotTLSConfig *tls.Config var gotTLSConfig *tls.Config
handshaker := &tlsHandshakerConfigurable{ handshaker := &tlsHandshakerConfigurable{
@ -220,9 +223,13 @@ func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
if gotTLSConfig.RootCAs != defaultCertPool { if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("gotTLSConfig.RootCAs has not been correctly set") t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
} }
})
})
} }
func TestTLSHandshakerLoggerSuccess(t *testing.T) { func TestTLSHandshakerLogger(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
th := &tlsHandshakerLogger{ th := &tlsHandshakerLogger{
TLSHandshaker: &mocks.TLSHandshaker{ TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
@ -248,9 +255,9 @@ func TestTLSHandshakerLoggerSuccess(t *testing.T) {
if !reflect.ValueOf(connState).IsZero() { if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here") t.Fatal("expected zero ConnectionState here")
} }
} })
func TestTLSHandshakerLoggerFailure(t *testing.T) { t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
th := &tlsHandshakerLogger{ th := &tlsHandshakerLogger{
TLSHandshaker: &mocks.TLSHandshaker{ TLSHandshaker: &mocks.TLSHandshaker{
@ -277,9 +284,12 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) {
if !reflect.ValueOf(connState).IsZero() { if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here") t.Fatal("expected zero ConnectionState here")
} }
})
})
} }
func TestTLSDialerCloseIdleConnections(t *testing.T) { func TestTLSDialer(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool var called bool
dialer := &tlsDialer{ dialer := &tlsDialer{
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
@ -292,9 +302,10 @@ func TestTLSDialerCloseIdleConnections(t *testing.T) {
if !called { if !called {
t.Fatal("not called") t.Fatal("not called")
} }
} })
func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { t.Run("DialTLSContext", func(t *testing.T) {
t.Run("failure to split host and port", func(t *testing.T) {
dialer := &tlsDialer{} dialer := &tlsDialer{}
ctx := context.Background() ctx := context.Background()
const address = "www.google.com" // missing port const address = "www.google.com" // missing port
@ -305,9 +316,9 @@ func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
if conn != nil { if conn != nil {
t.Fatal("connection is not nil") t.Fatal("connection is not nil")
} }
} })
func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { t.Run("failure dialing", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // immediately fail cancel() // immediately fail
dialer := tlsDialer{Dialer: &dialerSystem{}} dialer := tlsDialer{Dialer: &dialerSystem{}}
@ -318,9 +329,9 @@ func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
if conn != nil { if conn != nil {
t.Fatal("connection is not nil") t.Fatal("connection is not nil")
} }
} })
func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { t.Run("failure handshaking", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
dialer := tlsDialer{ dialer := tlsDialer{
Config: &tls.Config{}, Config: &tls.Config{},
@ -342,9 +353,9 @@ func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
if conn != nil { if conn != nil {
t.Fatal("connection is not nil") t.Fatal("connection is not nil")
} }
} })
func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { t.Run("success handshaking", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
dialer := tlsDialer{ dialer := tlsDialer{
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
@ -370,9 +381,11 @@ func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
t.Fatal("connection is nil") t.Fatal("connection is nil")
} }
conn.Close() conn.Close()
} })
})
func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { t.Run("config", func(t *testing.T) {
t.Run("from empty config for web", func(t *testing.T) {
d := &tlsDialer{} d := &tlsDialer{}
config := d.config("www.google.com", "443") config := d.config("www.google.com", "443")
if config.ServerName != "www.google.com" { if config.ServerName != "www.google.com" {
@ -381,9 +394,9 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} })
func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { t.Run("from empty config for dot", func(t *testing.T) {
d := &tlsDialer{} d := &tlsDialer{}
config := d.config("dns.google", "853") config := d.config("dns.google", "853")
if config.ServerName != "dns.google" { if config.ServerName != "dns.google" {
@ -392,9 +405,9 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} })
func TestTLSDialerConfigWithServerName(t *testing.T) { t.Run("with server name", func(t *testing.T) {
d := &tlsDialer{ d := &tlsDialer{
Config: &tls.Config{ Config: &tls.Config{
ServerName: "example.com", ServerName: "example.com",
@ -407,9 +420,9 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} })
func TestTLSDialerConfigWithALPN(t *testing.T) { t.Run("with alpn", func(t *testing.T) {
d := &tlsDialer{ d := &tlsDialer{
Config: &tls.Config{ Config: &tls.Config{
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
@ -422,50 +435,40 @@ func TestTLSDialerConfigWithALPN(t *testing.T) {
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
})
})
} }
func TestNewTLSHandshakerStdlibTypes(t *testing.T) { func TestNewTLSHandshakerStdlib(t *testing.T) {
th := NewTLSHandshakerStdlib(log.Log) th := NewTLSHandshakerStdlib(log.Log)
thl, okay := th.(*tlsHandshakerLogger) logger := th.(*tlsHandshakerLogger)
if !okay { if logger.Logger != log.Log {
t.Fatal("invalid type")
}
if thl.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper) errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
if !okay { configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
t.Fatal("invalid type") if configurable.NewConn != nil {
}
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
if !okay {
t.Fatal("invalid type")
}
if thc.NewConn != nil {
t.Fatal("expected nil NewConn") t.Fatal("expected nil NewConn")
} }
} }
func TestNewTLSDialerWorksAsIntended(t *testing.T) { func TestNewTLSDialer(t *testing.T) {
d := &mocks.Dialer{} d := &mocks.Dialer{}
tlsh := &mocks.TLSHandshaker{} th := &mocks.TLSHandshaker{}
td := NewTLSDialer(d, tlsh) dialer := NewTLSDialer(d, th)
tdut, okay := td.(*tlsDialer) tlsd := dialer.(*tlsDialer)
if !okay { if tlsd.Config == nil {
t.Fatal("invalid type")
}
if tdut.Config == nil {
t.Fatal("unexpected config") t.Fatal("unexpected config")
} }
if tdut.Dialer != d { if tlsd.Dialer != d {
t.Fatal("unexpected dialer") t.Fatal("unexpected dialer")
} }
if tdut.TLSHandshaker != tlsh { if tlsd.TLSHandshaker != th {
t.Fatal("invalid handshaker") t.Fatal("invalid handshaker")
} }
} }
func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) { func TestNewSingleUseTLSDialer(t *testing.T) {
conn := &mocks.TLSConn{} conn := &mocks.TLSConn{}
d := NewSingleUseTLSDialer(conn) d := NewSingleUseTLSDialer(conn)
outconn, err := d.DialTLSContext(context.Background(), "", "") outconn, err := d.DialTLSContext(context.Background(), "", "")

View File

@ -2,9 +2,7 @@ package netxlite
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"net"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -13,47 +11,22 @@ import (
utls "gitlab.com/yawning/utls.git" utls "gitlab.com/yawning/utls.git"
) )
func TestUTLSHandshakerChrome(t *testing.T) { func TestNewTLSHandshakerUTLS(t *testing.T) {
h := &tlsHandshakerConfigurable{
NewConn: newConnUTLS(&utls.HelloChrome_Auto),
}
cfg := &tls.Config{ServerName: "google.com"}
conn, err := net.Dial("tcp", "google.com:443")
if err != nil {
t.Fatal("unexpected error", err)
}
conn, _, err = h.Handshake(context.Background(), conn, cfg)
if err != nil {
t.Fatal("unexpected error", err)
}
if conn == nil {
t.Fatal("nil connection")
}
}
func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
th := NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_83) th := NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_83)
thl, okay := th.(*tlsHandshakerLogger) logger := th.(*tlsHandshakerLogger)
if !okay { if logger.Logger != log.Log {
t.Fatal("invalid type")
}
if thl.Logger != log.Log {
t.Fatal("invalid logger") t.Fatal("invalid logger")
} }
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper) errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
if !okay { configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
t.Fatal("invalid type") if configurable.NewConn == nil {
}
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
if !okay {
t.Fatal("invalid type")
}
if thc.NewConn == nil {
t.Fatal("expected non-nil NewConn") t.Fatal("expected non-nil NewConn")
} }
} }
func TestUTLSConnHandshakeNotInterruptedSuccess(t *testing.T) { func TestUTLSConn(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("not interrupted with success", func(t *testing.T) {
ctx := context.Background() ctx := context.Background()
conn := &utlsConn{ conn := &utlsConn{
testableHandshake: func() error { testableHandshake: func() error {
@ -64,9 +37,9 @@ func TestUTLSConnHandshakeNotInterruptedSuccess(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
} })
func TestUTLSConnHandshakeNotInterruptedFailure(t *testing.T) { t.Run("not interrupted with failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
ctx := context.Background() ctx := context.Background()
conn := &utlsConn{ conn := &utlsConn{
@ -78,9 +51,9 @@ func TestUTLSConnHandshakeNotInterruptedFailure(t *testing.T) {
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
} })
func TestUTLSConnHandshakeInterrupted(t *testing.T) { t.Run("interrupted", func(t *testing.T) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
sigch := make(chan interface{}) sigch := make(chan interface{})
@ -99,9 +72,9 @@ func TestUTLSConnHandshakeInterrupted(t *testing.T) {
} }
close(sigch) close(sigch)
wg.Wait() wg.Wait()
} })
func TestUTLSConnHandshakePanic(t *testing.T) { t.Run("with panic", func(t *testing.T) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
ctx := context.Background() ctx := context.Background()
@ -116,4 +89,6 @@ func TestUTLSConnHandshakePanic(t *testing.T) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
wg.Wait() wg.Wait()
})
})
} }