diff --git a/internal/cmd/oohelperd/internal/websteps/generate.go b/internal/cmd/oohelperd/internal/websteps/generate.go index 3a7fa51..dee55c9 100644 --- a/internal/cmd/oohelperd/internal/websteps/generate.go +++ b/internal/cmd/oohelperd/internal/websteps/generate.go @@ -22,7 +22,7 @@ type Generator interface { // DefaultGenerator is the default Generator. type DefaultGenerator struct { - dialer netxlite.Dialer + dialer netxlite.DialerLegacy quicDialer netxlite.QUICContextDialer resolver netxlite.ResolverLegacy transport http.RoundTripper diff --git a/internal/engine/experiment/websteps/factory.go b/internal/engine/experiment/websteps/factory.go index 0b7f317..c588611 100644 --- a/internal/engine/experiment/websteps/factory.go +++ b/internal/engine/experiment/websteps/factory.go @@ -33,12 +33,12 @@ func NewRequest(ctx context.Context, URL *url.URL, headers http.Header) *http.Re // NewDialerResolver contructs a new dialer for TCP connections, // with default, errorwrapping and resolve functionalities -func NewDialerResolver(resolver netxlite.ResolverLegacy) netxlite.Dialer { - var d netxlite.Dialer = netxlite.DefaultDialer +func NewDialerResolver(resolver netxlite.ResolverLegacy) netxlite.DialerLegacy { + var d netxlite.DialerLegacy = netxlite.DefaultDialer d = &errorsx.ErrorWrapperDialer{Dialer: d} d = &netxlite.DialerResolver{ Resolver: netxlite.NewResolverLegacyAdapter(resolver), - Dialer: d, + Dialer: netxlite.NewDialerLegacyAdapter(d), } return d } @@ -80,12 +80,12 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper { } // NewSingleTransport creates a new HTTP transport with a custom dialer and handshaker. -func NewTransportWithDialer(dialer netxlite.Dialer, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper { +func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper { transport := newBaseTransport() transport.DialContext = dialer.DialContext transport.DialTLSContext = (&netxlite.TLSDialer{ Config: tlsConfig, - Dialer: dialer, + Dialer: netxlite.NewDialerLegacyAdapter(dialer), TLSHandshaker: handshaker, }).DialTLSContext return transport diff --git a/internal/engine/experiment/websteps/tcp.go b/internal/engine/experiment/websteps/tcp.go index 67bea47..a8b34c7 100644 --- a/internal/engine/experiment/websteps/tcp.go +++ b/internal/engine/experiment/websteps/tcp.go @@ -8,7 +8,7 @@ import ( ) type TCPConfig struct { - Dialer netxlite.Dialer + Dialer netxlite.DialerLegacy Endpoint string Resolver netxlite.ResolverLegacy } diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 48276de..3dbf3ce 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -106,7 +106,7 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer { return &netxlite.TLSDialer{ Config: config, - Dialer: d, + Dialer: netxlite.NewDialerLegacyAdapter(d), TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ TLSHandshaker: &errorsx.ErrorWrapperTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, diff --git a/internal/engine/netx/dialer/dialer.go b/internal/engine/netx/dialer/dialer.go index 31a196a..276bf60 100644 --- a/internal/engine/netx/dialer/dialer.go +++ b/internal/engine/netx/dialer/dialer.go @@ -72,7 +72,10 @@ func New(config *Config, resolver Resolver) Dialer { var d Dialer = netxlite.DefaultDialer d = &errorsx.ErrorWrapperDialer{Dialer: d} if config.Logger != nil { - d = &netxlite.DialerLogger{Dialer: d, Logger: config.Logger} + d = &netxlite.DialerLogger{ + Dialer: netxlite.NewDialerLegacyAdapter(d), + Logger: config.Logger, + } } if config.DialSaver != nil { d = &saverDialer{Dialer: d, Saver: config.DialSaver} @@ -82,7 +85,7 @@ func New(config *Config, resolver Resolver) Dialer { } d = &netxlite.DialerResolver{ Resolver: netxlite.NewResolverLegacyAdapter(resolver), - Dialer: d, + Dialer: netxlite.NewDialerLegacyAdapter(d), } d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d} if config.ContextByteCounting { diff --git a/internal/engine/netx/dialer/dialer_test.go b/internal/engine/netx/dialer/dialer_test.go index e038bc9..9535592 100644 --- a/internal/engine/netx/dialer/dialer_test.go +++ b/internal/engine/netx/dialer/dialer_test.go @@ -36,7 +36,11 @@ func TestNewCreatesTheExpectedChain(t *testing.T) { if !ok { t.Fatal("not a dnsDialer") } - scd, ok := dnsd.Dialer.(*saverConnDialer) + dad, ok := dnsd.Dialer.(*netxlite.DialerLegacyAdapter) + if !ok { + t.Fatal("invalid type") + } + scd, ok := dad.DialerLegacy.(*saverConnDialer) if !ok { t.Fatal("not a saverConnDialer") } @@ -48,12 +52,16 @@ func TestNewCreatesTheExpectedChain(t *testing.T) { if !ok { t.Fatal("not a loggingDialer") } - ewd, ok := ld.Dialer.(*errorsx.ErrorWrapperDialer) + dad, ok = ld.Dialer.(*netxlite.DialerLegacyAdapter) + if !ok { + t.Fatal("invalid type") + } + ewd, ok := dad.DialerLegacy.(*errorsx.ErrorWrapperDialer) if !ok { t.Fatal("not an errorWrappingDialer") } - _, ok = ewd.Dialer.(*net.Dialer) + _, ok = ewd.Dialer.(*netxlite.DialerSystem) if !ok { - t.Fatal("not a net.Dialer") + t.Fatal("not a DialerSystem") } } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 354b087..5a68f26 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -209,7 +209,7 @@ func NewTLSDialer(config Config) TLSDialer { config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify return &netxlite.TLSDialer{ Config: config.TLSConfig, - Dialer: config.Dialer, + Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer), TLSHandshaker: h, } } diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 357be66..103bedc 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -1,8 +1,10 @@ package netx_test import ( + "context" "crypto/tls" "errors" + "net" "net/http" "strings" "testing" @@ -16,6 +18,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" "github.com/ooni/probe-cli/v3/internal/errorsx" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) func TestNewResolverVanilla(t *testing.T) { @@ -486,8 +489,15 @@ func TestNewWithDialer(t *testing.T) { func TestNewWithTLSDialer(t *testing.T) { expected := errors.New("mocked error") tlsDialer := &netxlite.TLSDialer{ - Config: new(tls.Config), - Dialer: netx.FakeDialer{Err: expected}, + Config: new(tls.Config), + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, + MockCloseIdleConnections: func() { + // nothing + }, + }, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, } txp := netx.NewHTTPTransport(netx.Config{ diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index d9795c3..735d18c 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -1,7 +1,6 @@ package tlsdialer_test import ( - "net" "net/http" "testing" @@ -14,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) { t.Skip("skip test in short mode") } log.SetLevel(log.DebugLevel) - dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer), + dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer, TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Logger: log.Log, diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 83d153c..515ff59 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -24,7 +24,9 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { saver := &trace.Saver{} tlsdlr := &netxlite.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, - Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), + Dialer: netxlite.NewDialerLegacyAdapter( + dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), + ), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, @@ -117,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { saver := &trace.Saver{} tlsdlr := &netxlite.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, - Dialer: new(net.Dialer), + Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, @@ -182,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { } saver := &trace.Saver{} tlsdlr := &netxlite.TLSDialer{ - Dialer: new(net.Dialer), + Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, @@ -215,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { } saver := &trace.Saver{} tlsdlr := &netxlite.TLSDialer{ - Dialer: new(net.Dialer), + Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, @@ -248,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { } saver := &trace.Saver{} tlsdlr := &netxlite.TLSDialer{ - Dialer: new(net.Dialer), + Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, @@ -282,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { saver := &trace.Saver{} tlsdlr := &netxlite.TLSDialer{ Config: &tls.Config{InsecureSkipVerify: true}, - Dialer: new(net.Dialer), + Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index 224b983..3548cb1 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -10,15 +10,31 @@ import ( type Dialer interface { // DialContext behaves like net.Dialer.DialContext. DialContext(ctx context.Context, network, address string) (net.Conn, error) + + // CloseIdleConnections closes idle connections, if any. + CloseIdleConnections() } -// defaultDialer is the Dialer we use by default. -var defaultDialer = &net.Dialer{ +// underlyingDialer is the Dialer we use by default. +var underlyingDialer = &net.Dialer{ Timeout: 15 * time.Second, KeepAlive: 15 * time.Second, } -var _ Dialer = defaultDialer +// dialerSystem dials using Go stdlib. +type dialerSystem struct{} + +// DialContext implements Dialer.DialContext. +func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return underlyingDialer.DialContext(ctx, network, address) +} + +// CloseIdleConnections implements Dialer.CloseIdleConnections. +func (d *dialerSystem) CloseIdleConnections() { + // nothing +} + +var defaultDialer Dialer = &dialerSystem{} // dialerResolver is a dialer that uses the configured Resolver to resolver a // domain name to IP addresses, and the configured Dialer to connect. @@ -66,6 +82,12 @@ func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]str return d.Resolver.LookupHost(ctx, hostname) } +// CloseIdleConnections implements Dialer.CloseIdleConnections. +func (d *dialerResolver) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() + d.Resolver.CloseIdleConnections() +} + // dialerLogger is a Dialer with logging. type dialerLogger struct { // Dialer is the underlying dialer. @@ -90,3 +112,8 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string) d.Logger.Debugf("dial %s/%s... ok in %s", address, network, elapsed) return conn, nil } + +// CloseIdleConnections implements Dialer.CloseIdleConnections. +func (d *dialerLogger) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index dd6b042..c91d69c 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -13,8 +13,13 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) +func TestDialerSystemCloseIdleConnections(t *testing.T) { + d := &dialerSystem{} + d.CloseIdleConnections() // should not crash +} + func TestDialerResolverNoPort(t *testing.T) { - dialer := &dialerResolver{Dialer: &net.Dialer{}, Resolver: DefaultResolver} + dialer := &dialerResolver{Dialer: defaultDialer, Resolver: DefaultResolver} conn, err := dialer.DialContext(context.Background(), "tcp", "ooni.nu") if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { t.Fatal("not the error we expected", err) @@ -25,7 +30,7 @@ func TestDialerResolverNoPort(t *testing.T) { } func TestDialerResolverLookupHostAddress(t *testing.T) { - dialer := &dialerResolver{Dialer: new(net.Dialer), Resolver: &mocks.Resolver{ + dialer := &dialerResolver{Dialer: defaultDialer, Resolver: &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, errors.New("we should not call this function") }, @@ -41,7 +46,7 @@ func TestDialerResolverLookupHostAddress(t *testing.T) { func TestDialerResolverLookupHostFailure(t *testing.T) { expected := errors.New("mocked error") - dialer := &dialerResolver{Dialer: new(net.Dialer), Resolver: &mocks.Resolver{ + dialer := &dialerResolver{Dialer: defaultDialer, Resolver: &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, expected }, @@ -115,6 +120,29 @@ func TestDialerResolverDialForManyIPSuccess(t *testing.T) { conn.Close() } +func TestDialerResolverCloseIdleConnections(t *testing.T) { + var ( + calledDialer bool + calledResolver bool + ) + d := &dialerResolver{ + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + calledDialer = true + }, + }, + Resolver: &mocks.Resolver{ + MockCloseIdleConnections: func() { + calledResolver = true + }, + }, + } + d.CloseIdleConnections() + if !calledDialer || !calledResolver { + t.Fatal("not called") + } +} + func TestDialerLoggerSuccess(t *testing.T) { d := &dialerLogger{ Dialer: &mocks.Dialer{ @@ -156,9 +184,26 @@ func TestDialerLoggerFailure(t *testing.T) { } } -func TestDefaultDialerHasTimeout(t *testing.T) { +func TestDialerLoggerCloseIdleConnections(t *testing.T) { + var ( + calledDialer bool + ) + d := &dialerLogger{ + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + calledDialer = true + }, + }, + } + d.CloseIdleConnections() + if !calledDialer { + t.Fatal("not called") + } +} + +func TestUnderlyingDialerHasTimeout(t *testing.T) { expected := 15 * time.Second - if defaultDialer.Timeout != expected { + if underlyingDialer.Timeout != expected { t.Fatal("unexpected timeout value") } } diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 6e6aad1..ba4b941 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -3,6 +3,7 @@ package netxlite import ( "context" "errors" + "net" "strings" "github.com/ooni/probe-cli/v3/internal/errorsx" @@ -59,6 +60,7 @@ type ( ResolverIDNA = resolverIDNA TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerLogger = tlsHandshakerLogger + DialerSystem = dialerSystem ) // ResolverLegacy performs domain name resolutions. @@ -122,3 +124,41 @@ func (r *ResolverLegacyAdapter) CloseIdleConnections() { ra.CloseIdleConnections() } } + +// DialerLegacy establishes network connections. +// +// This definition is DEPRECATED. Please, use Dialer. +// +// Existing code in probe-cli can use it until we +// have finished refactoring it. +type DialerLegacy interface { + // DialContext behaves like net.Dialer.DialContext. + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// NewDialerLegacyAdapter adapts a DialerrLegacy to +// become compatible with the Dialer definition. +func NewDialerLegacyAdapter(d DialerLegacy) Dialer { + return &DialerLegacyAdapter{d} +} + +// DialerLegacyAdapter makes a DialerLegacy behave like +// it was a Dialer type. If DialerLegacy is actually also +// a Dialer, this adapter will just forward missing calls, +// otherwise it will implement a sensible default action. +type DialerLegacyAdapter struct { + DialerLegacy +} + +var _ Dialer = &DialerLegacyAdapter{} + +type dialerLegacyIdleConnectionsCloser interface { + CloseIdleConnections() +} + +// CloseIdleConnections implements Resolver.CloseIdleConnections. +func (d *DialerLegacyAdapter) CloseIdleConnections() { + if ra, ok := d.DialerLegacy.(dialerLegacyIdleConnectionsCloser); ok { + ra.CloseIdleConnections() + } +} diff --git a/internal/netxlite/legacy_test.go b/internal/netxlite/legacy_test.go index 8d44cce..9cc54e6 100644 --- a/internal/netxlite/legacy_test.go +++ b/internal/netxlite/legacy_test.go @@ -82,3 +82,21 @@ func TestResolverLegacyAdapterDefaults(t *testing.T) { } r.CloseIdleConnections() // does not crash } + +func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) { + var called bool + r := NewDialerLegacyAdapter(&mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + }) + r.CloseIdleConnections() + if !called { + t.Fatal("not called") + } +} + +func TestDialerLegacyAdapterDefaults(t *testing.T) { + r := NewDialerLegacyAdapter(&net.Dialer{}) + r.CloseIdleConnections() // does not crash +} diff --git a/internal/netxlite/mocks/dialer.go b/internal/netxlite/mocks/dialer.go index f37affc..f896cca 100644 --- a/internal/netxlite/mocks/dialer.go +++ b/internal/netxlite/mocks/dialer.go @@ -7,10 +7,16 @@ import ( // Dialer is a mockable Dialer. type Dialer struct { - MockDialContext func(ctx context.Context, network, address string) (net.Conn, error) + MockDialContext func(ctx context.Context, network, address string) (net.Conn, error) + MockCloseIdleConnections func() } // DialContext calls MockDialContext. func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.MockDialContext(ctx, network, address) } + +// CloseIdleConnections calls MockCloseIdleConnections. +func (d *Dialer) CloseIdleConnections() { + d.MockCloseIdleConnections() +} diff --git a/internal/netxlite/mocks/dialer_test.go b/internal/netxlite/mocks/dialer_test.go index a114d14..6166e94 100644 --- a/internal/netxlite/mocks/dialer_test.go +++ b/internal/netxlite/mocks/dialer_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -func TestDialerWorks(t *testing.T) { +func TestDialerDialContext(t *testing.T) { expected := errors.New("mocked error") d := Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { @@ -23,3 +23,16 @@ func TestDialerWorks(t *testing.T) { t.Fatal("expected nil conn") } } + +func TestDialerCloseIdleConnections(t *testing.T) { + var called bool + d := &Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + d.CloseIdleConnections() + if !called { + t.Fatal("not called") + } +} diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 4866bac..c9b39a3 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -294,7 +294,7 @@ func TestTLSDialerFailureSplitHostPort(t *testing.T) { func TestTLSDialerFailureDialing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail - dialer := TLSDialer{Dialer: &net.Dialer{}} + dialer := TLSDialer{Dialer: defaultDialer} conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { t.Fatal("not the error we expected", err)