diff --git a/internal/cmd/oohelperd/internal/internal_test.go b/internal/cmd/oohelperd/internal/internal_test.go index 04d71e8..114b1b1 100644 --- a/internal/cmd/oohelperd/internal/internal_test.go +++ b/internal/cmd/oohelperd/internal/internal_test.go @@ -56,7 +56,7 @@ func TestWorkingAsIntended(t *testing.T) { Client: http.DefaultClient, Dialer: new(net.Dialer), MaxAcceptableBody: 1 << 24, - Resolver: netxlite.ResolverSystem{}, + Resolver: &netxlite.ResolverSystem{}, } srv := httptest.NewServer(handler) defer srv.Close() diff --git a/internal/engine/experiment/ndt7/dial.go b/internal/engine/experiment/ndt7/dial.go index 8a89b1a..c627513 100644 --- a/internal/engine/experiment/ndt7/dial.go +++ b/internal/engine/experiment/ndt7/dial.go @@ -34,8 +34,8 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM } func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { - var reso resolver.Resolver = netxlite.ResolverSystem{} - reso = netxlite.ResolverLogger{Resolver: reso, Logger: mgr.logger} + var reso resolver.Resolver = &netxlite.ResolverSystem{} + reso = &netxlite.ResolverLogger{Resolver: reso, Logger: mgr.logger} dlr := dialer.New(&dialer.Config{ ContextByteCounting: true, Logger: mgr.logger, diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 4b8b1ec..841a1d8 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -14,6 +14,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) // Dialer performs measurements while dialing. @@ -107,9 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) tlsdialer.TLSDialer { Dialer: d, TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ - TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, - }, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, }, }, } diff --git a/internal/engine/legacy/netx/resolver.go b/internal/engine/legacy/netx/resolver.go index 7b5c574..cb02a5b 100644 --- a/internal/engine/legacy/netx/resolver.go +++ b/internal/engine/legacy/netx/resolver.go @@ -159,7 +159,7 @@ func resolverWrapTransport(txp resolver.RoundTripper) resolver.EmitterResolver { } func newResolverSystem() resolver.EmitterResolver { - return resolverWrapResolver(netxlite.ResolverSystem{}) + return resolverWrapResolver(&netxlite.ResolverSystem{}) } func newResolverUDP(dialer resolver.Dialer, address string) resolver.EmitterResolver { diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 13a766e..2955575 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -115,7 +115,7 @@ var defaultCertPool *x509.CertPool = tlsx.NewDefaultCertPool() // NewResolver creates a new resolver from the specified config func NewResolver(config Config) Resolver { if config.BaseResolver == nil { - config.BaseResolver = netxlite.ResolverSystem{} + config.BaseResolver = &netxlite.ResolverSystem{} } var r Resolver = config.BaseResolver r = resolver.AddressResolver{Resolver: r} @@ -134,7 +134,7 @@ func NewResolver(config Config) Resolver { } r = resolver.ErrorWrapperResolver{Resolver: r} if config.Logger != nil { - r = netxlite.ResolverLogger{Logger: config.Logger, Resolver: r} + r = &netxlite.ResolverLogger{Logger: config.Logger, Resolver: r} } if config.ResolveSaver != nil { r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver} @@ -176,8 +176,7 @@ func NewTLSDialer(config Config) TLSDialer { if config.Dialer == nil { config.Dialer = NewDialer(config) } - var h tlsHandshaker = tlsdialer.SystemTLSHandshaker{} - h = tlsdialer.TimeoutTLSHandshaker{TLSHandshaker: h} + var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{} h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h} if config.Logger != nil { h = tlsdialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h} @@ -318,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, } switch resolverURL.Scheme { case "system": - c.Resolver = netxlite.ResolverSystem{} + c.Resolver = &netxlite.ResolverSystem{} return c, nil case "https": config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index dbca660..c192f4e 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -32,7 +32,7 @@ func TestNewResolverVanilla(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystem) if !ok { t.Fatal("not the resolver we expected") } @@ -82,7 +82,7 @@ func TestNewResolverWithBogonFilter(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystem) if !ok { t.Fatal("not the resolver we expected") } @@ -96,7 +96,7 @@ func TestNewResolverWithLogging(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - lr, ok := ir.Resolver.(netxlite.ResolverLogger) + lr, ok := ir.Resolver.(*netxlite.ResolverLogger) if !ok { t.Fatal("not the resolver we expected") } @@ -111,7 +111,7 @@ func TestNewResolverWithLogging(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystem) if !ok { t.Fatal("not the resolver we expected") } @@ -141,7 +141,7 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystem) if !ok { t.Fatal("not the resolver we expected") } @@ -170,7 +170,7 @@ func TestNewResolverWithReadWriteCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystem) if !ok { t.Fatal("not the resolver we expected") } @@ -204,7 +204,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ar.Resolver.(netxlite.ResolverSystem) + _, ok = ar.Resolver.(*netxlite.ResolverSystem) if !ok { t.Fatal("not the resolver we expected") } @@ -235,11 +235,7 @@ func TestNewTLSDialerVanilla(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -268,11 +264,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -311,11 +303,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -355,11 +343,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -392,11 +376,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -431,11 +411,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -472,7 +448,7 @@ func TestNewWithTLSDialer(t *testing.T) { tlsDialer := tlsdialer.TLSDialer{ Config: new(tls.Config), Dialer: netx.FakeDialer{Err: expected}, - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, } txp := netx.NewHTTPTransport(netx.Config{ TLSDialer: tlsDialer, @@ -598,7 +574,7 @@ func TestNewDNSClientSystemResolver(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := dnsclient.Resolver.(netxlite.ResolverSystem); !ok { + if _, ok := dnsclient.Resolver.(*netxlite.ResolverSystem); !ok { t.Fatal("not the resolver we expected") } dnsclient.CloseIdleConnections() @@ -610,7 +586,7 @@ func TestNewDNSClientEmpty(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := dnsclient.Resolver.(netxlite.ResolverSystem); !ok { + if _, ok := dnsclient.Resolver.(*netxlite.ResolverSystem); !ok { t.Fatal("not the resolver we expected") } dnsclient.CloseIdleConnections() diff --git a/internal/engine/netx/resolver/chain_test.go b/internal/engine/netx/resolver/chain_test.go index 3a15045..bb9c8c2 100644 --- a/internal/engine/netx/resolver/chain_test.go +++ b/internal/engine/netx/resolver/chain_test.go @@ -11,7 +11,7 @@ import ( func TestChainLookupHost(t *testing.T) { r := resolver.ChainResolver{ Primary: resolver.NewFakeResolverThatFails(), - Secondary: netxlite.ResolverSystem{}, + Secondary: &netxlite.ResolverSystem{}, } if r.Address() != "" { t.Fatal("invalid address") diff --git a/internal/engine/netx/resolver/integration_test.go b/internal/engine/netx/resolver/integration_test.go index 39cebb9..96038df 100644 --- a/internal/engine/netx/resolver/integration_test.go +++ b/internal/engine/netx/resolver/integration_test.go @@ -19,7 +19,7 @@ func testresolverquick(t *testing.T, reso resolver.Resolver) { if testing.Short() { t.Skip("skip test in short mode") } - reso = netxlite.ResolverLogger{Logger: log.Log, Resolver: reso} + reso = &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso} addrs, err := reso.LookupHost(context.Background(), "dns.google.com") if err != nil { t.Fatal(err) @@ -45,7 +45,7 @@ func testresolverquickidna(t *testing.T, reso resolver.Resolver) { t.Skip("skip test in short mode") } reso = resolver.IDNAResolver{ - netxlite.ResolverLogger{Logger: log.Log, Resolver: reso}, + &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso}, } addrs, err := reso.LookupHost(context.Background(), "яндекс.рф") if err != nil { @@ -57,7 +57,7 @@ func testresolverquickidna(t *testing.T, reso resolver.Resolver) { } func TestNewResolverSystem(t *testing.T) { - reso := netxlite.ResolverSystem{} + reso := &netxlite.ResolverSystem{} testresolverquick(t, reso) testresolverquickidna(t, reso) } diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index 694d204..2f6e619 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -8,6 +8,7 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) func TestTLSDialerSuccess(t *testing.T) { @@ -17,7 +18,7 @@ func TestTLSDialerSuccess(t *testing.T) { log.SetLevel(log.DebugLevel) dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.LoggingTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Logger: log.Log, }, } diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 8ad79c4..37b01b7 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -12,6 +12,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { @@ -25,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { Config: &tls.Config{NextProtos: nextprotos}, Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Saver: saver, }, } @@ -118,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { Config: &tls.Config{NextProtos: nextprotos}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Saver: saver, }, } @@ -183,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { tlsdlr := tlsdialer.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Saver: saver, }, } @@ -216,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { tlsdlr := tlsdialer.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Saver: saver, }, } @@ -249,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { tlsdlr := tlsdialer.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Saver: saver, }, } @@ -283,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { Config: &tls.Config{InsecureSkipVerify: true}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Saver: saver, }, } diff --git a/internal/engine/netx/tlsdialer/tls.go b/internal/engine/netx/tlsdialer/tls.go index 1fc05cf..4c58e01 100644 --- a/internal/engine/netx/tlsdialer/tls.go +++ b/internal/engine/netx/tlsdialer/tls.go @@ -22,42 +22,6 @@ type TLSHandshaker interface { net.Conn, tls.ConnectionState, error) } -// SystemTLSHandshaker is the system TLS handshaker. -type SystemTLSHandshaker struct{} - -// Handshake implements Handshaker.Handshake -func (h SystemTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - tlsconn := tls.Client(conn, config) - if err := tlsconn.Handshake(); err != nil { - return nil, tls.ConnectionState{}, err - } - return tlsconn, tlsconn.ConnectionState(), nil -} - -// TimeoutTLSHandshaker is a TLSHandshaker with timeout -type TimeoutTLSHandshaker struct { - TLSHandshaker - HandshakeTimeout time.Duration // default: 10 second -} - -// Handshake implements Handshaker.Handshake -func (h TimeoutTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - timeout := 10 * time.Second - if h.HandshakeTimeout != 0 { - timeout = h.HandshakeTimeout - } - if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { - return nil, tls.ConnectionState{}, err - } - tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config) - conn.SetDeadline(time.Time{}) - return tlsconn, connstate, err -} - // ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error type ErrorWrapperTLSHandshaker struct { TLSHandshaker diff --git a/internal/engine/netx/tlsdialer/tls_test.go b/internal/engine/netx/tlsdialer/tls_test.go index d88449a..3b3cd8b 100644 --- a/internal/engine/netx/tlsdialer/tls_test.go +++ b/internal/engine/netx/tlsdialer/tls_test.go @@ -13,10 +13,11 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) func TestSystemTLSHandshakerEOFError(t *testing.T) { - h := tlsdialer.SystemTLSHandshaker{} + h := &netxlite.TLSHandshakerStdlib{} conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{ ServerName: "x.org", }) @@ -28,63 +29,6 @@ func TestSystemTLSHandshakerEOFError(t *testing.T) { } } -func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) { - h := tlsdialer.TimeoutTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, - HandshakeTimeout: 200 * time.Millisecond, - } - expected := errors.New("mocked error") - conn, _, err := h.Handshake( - context.Background(), &tlsdialer.FakeConn{SetDeadlineError: expected}, - new(tls.Config)) - if !errors.Is(err, expected) { - t.Fatal("not the error that we expected") - } - if conn != nil { - t.Fatal("expected nil con here") - } -} - -func TestTimeoutTLSHandshakerEOFError(t *testing.T) { - h := tlsdialer.TimeoutTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, - HandshakeTimeout: 200 * time.Millisecond, - } - conn, _, err := h.Handshake( - context.Background(), tlsdialer.EOFConn{}, &tls.Config{ServerName: "x.org"}) - if !errors.Is(err, io.EOF) { - t.Fatal("not the error that we expected") - } - if conn != nil { - t.Fatal("expected nil con here") - } -} - -func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) { - h := tlsdialer.TimeoutTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, - HandshakeTimeout: 200 * time.Millisecond, - } - underlying := &SetDeadlineConn{} - conn, _, err := h.Handshake( - context.Background(), underlying, &tls.Config{ServerName: "x.org"}) - if !errors.Is(err, io.EOF) { - t.Fatal("not the error that we expected") - } - if conn != nil { - t.Fatal("expected nil con here") - } - if len(underlying.deadlines) != 2 { - t.Fatal("SetDeadline not called twice") - } - if underlying.deadlines[0].Before(time.Now()) { - t.Fatal("the first SetDeadline call was incorrect") - } - if !underlying.deadlines[1].IsZero() { - t.Fatal("the second SetDeadline call was incorrect") - } -} - type SetDeadlineConn struct { tlsdialer.EOFConn deadlines []time.Time @@ -179,7 +123,7 @@ func TestTLSDialerFailureDialing(t *testing.T) { } func TestTLSDialerFailureHandshaking(t *testing.T) { - rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}} + rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}} dialer := tlsdialer.TLSDialer{ Dialer: tlsdialer.EOFConnDialer{}, TLSHandshaker: rec, @@ -198,7 +142,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) { } func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) { - rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}} + rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}} dialer := tlsdialer.TLSDialer{ Config: &tls.Config{ ServerName: "x.org", @@ -235,7 +179,7 @@ func TestDialTLSContextGood(t *testing.T) { dialer := tlsdialer.TLSDialer{ Config: &tls.Config{ServerName: "google.com"}, Dialer: new(net.Dialer), - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, } conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443") if err != nil { @@ -252,9 +196,8 @@ func TestDialTLSContextTimeout(t *testing.T) { Config: &tls.Config{ServerName: "google.com"}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ - TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{ - TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, - HandshakeTimeout: 10 * time.Microsecond, + TLSHandshaker: &netxlite.TLSHandshakerStdlib{ + Timeout: 10 * time.Microsecond, }, }, } diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index 6f4b1f8..37db6e4 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -15,25 +15,25 @@ type Resolver interface { // ResolverSystem is the system resolver. type ResolverSystem struct{} -var _ Resolver = ResolverSystem{} +var _ Resolver = &ResolverSystem{} // LookupHost implements Resolver.LookupHost. -func (r ResolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { +func (r *ResolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) { return net.DefaultResolver.LookupHost(ctx, hostname) } // Network implements Resolver.Network. -func (r ResolverSystem) Network() string { +func (r *ResolverSystem) Network() string { return "system" } // Address implements Resolver.Address. -func (r ResolverSystem) Address() string { +func (r *ResolverSystem) Address() string { return "" } // DefaultResolver is the resolver we use by default. -var DefaultResolver = ResolverSystem{} +var DefaultResolver = &ResolverSystem{} // ResolverLogger is a resolver that emits events type ResolverLogger struct { @@ -41,10 +41,10 @@ type ResolverLogger struct { Logger Logger } -var _ Resolver = ResolverLogger{} +var _ Resolver = &ResolverLogger{} // LookupHost returns the IP addresses of a host -func (r ResolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) { +func (r *ResolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) { r.Logger.Debugf("resolve %s...", hostname) start := time.Now() addrs, err := r.Resolver.LookupHost(ctx, hostname) @@ -62,7 +62,7 @@ type resolverNetworker interface { } // Network implements Resolver.Network. -func (r ResolverLogger) Network() string { +func (r *ResolverLogger) Network() string { if rn, ok := r.Resolver.(resolverNetworker); ok { return rn.Network() } @@ -74,7 +74,7 @@ type resolverAddresser interface { } // Address implements Resolver.Address. -func (r ResolverLogger) Address() string { +func (r *ResolverLogger) Address() string { if ra, ok := r.Resolver.(resolverAddresser); ok { return ra.Address() } diff --git a/internal/netxlite/tlshandshaker.go b/internal/netxlite/tlshandshaker.go new file mode 100644 index 0000000..391ea53 --- /dev/null +++ b/internal/netxlite/tlshandshaker.go @@ -0,0 +1,46 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "net" + "time" +) + +// TLSHandshaker is the generic TLS handshaker. +type TLSHandshaker interface { + // Handshake creates a new TLS connection from the given connection and + // the given config. This function DOES NOT take ownership of the connection + // and it's your responsibility to close it on failure. + Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( + net.Conn, tls.ConnectionState, error) +} + +// TLSHandshakerStdlib is the stdlib's TLS handshaker. +type TLSHandshakerStdlib struct { + // Timeout is the timeout imposed on the TLS handshake. If zero + // or negative, we will use default timeout of 10 seconds. + Timeout time.Duration +} + +var _ TLSHandshaker = &TLSHandshakerStdlib{} + +// Handshake implements Handshaker.Handshake +func (h *TLSHandshakerStdlib) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config, +) (net.Conn, tls.ConnectionState, error) { + timeout := h.Timeout + if timeout <= 0 { + timeout = 10 * time.Second + } + defer conn.SetDeadline(time.Time{}) + conn.SetDeadline(time.Now().Add(timeout)) + tlsconn := tls.Client(conn, config) + if err := tlsconn.Handshake(); err != nil { + return nil, tls.ConnectionState{}, err + } + return tlsconn, tlsconn.ConnectionState(), nil +} + +// DefaultTLSHandshaker is the default TLS handshaker. +var DefaultTLSHandshaker = &TLSHandshakerStdlib{} diff --git a/internal/netxlite/tlshandshaker_test.go b/internal/netxlite/tlshandshaker_test.go new file mode 100644 index 0000000..5d9cd7a --- /dev/null +++ b/internal/netxlite/tlshandshaker_test.go @@ -0,0 +1,81 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/netxmocks" +) + +func TestTLSHandshakerStdlibWithError(t *testing.T) { + var times []time.Time + h := &TLSHandshakerStdlib{} + tcpConn := &netxmocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockSetDeadline: func(t time.Time) error { + times = append(times, t) + return nil + }, + } + ctx := context.Background() + conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{ + ServerName: "x.org", + }) + if err != io.EOF { + t.Fatal("not the error that we expected") + } + if conn != nil { + t.Fatal("expected nil con here") + } + if len(times) != 2 { + t.Fatal("expected two time entries") + } + if !times[0].After(time.Now()) { + t.Fatal("timeout not in the future") + } + if !times[1].IsZero() { + t.Fatal("did not clear timeout on exit") + } +} + +func TestTLSHandshakerStdlibSuccess(t *testing.T) { + handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(200) + }) + srvr := httptest.NewTLSServer(handler) + defer srvr.Close() + URL, err := url.Parse(srvr.URL) + if err != nil { + t.Fatal(err) + } + conn, err := net.Dial("tcp", URL.Host) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + handshaker := &TLSHandshakerStdlib{} + ctx := context.Background() + config := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + ServerName: URL.Hostname(), + } + tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + if err != nil { + t.Fatal(err) + } + defer tlsConn.Close() + if connState.Version != tls.VersionTLS13 { + t.Fatal("unexpected TLS version") + } +}