refactor(dialer): it should close idle connections (#457)
Like we did before for the resolver, a dialer should propagate the request to close idle connections to underlying types. See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
a3a27b1ebf
commit
7a9499fee3
|
@ -22,7 +22,7 @@ type Generator interface {
|
||||||
|
|
||||||
// DefaultGenerator is the default Generator.
|
// DefaultGenerator is the default Generator.
|
||||||
type DefaultGenerator struct {
|
type DefaultGenerator struct {
|
||||||
dialer netxlite.Dialer
|
dialer netxlite.DialerLegacy
|
||||||
quicDialer netxlite.QUICContextDialer
|
quicDialer netxlite.QUICContextDialer
|
||||||
resolver netxlite.ResolverLegacy
|
resolver netxlite.ResolverLegacy
|
||||||
transport http.RoundTripper
|
transport http.RoundTripper
|
||||||
|
|
|
@ -33,12 +33,12 @@ func NewRequest(ctx context.Context, URL *url.URL, headers http.Header) *http.Re
|
||||||
|
|
||||||
// NewDialerResolver contructs a new dialer for TCP connections,
|
// NewDialerResolver contructs a new dialer for TCP connections,
|
||||||
// with default, errorwrapping and resolve functionalities
|
// with default, errorwrapping and resolve functionalities
|
||||||
func NewDialerResolver(resolver netxlite.ResolverLegacy) netxlite.Dialer {
|
func NewDialerResolver(resolver netxlite.ResolverLegacy) netxlite.DialerLegacy {
|
||||||
var d netxlite.Dialer = netxlite.DefaultDialer
|
var d netxlite.DialerLegacy = netxlite.DefaultDialer
|
||||||
d = &errorsx.ErrorWrapperDialer{Dialer: d}
|
d = &errorsx.ErrorWrapperDialer{Dialer: d}
|
||||||
d = &netxlite.DialerResolver{
|
d = &netxlite.DialerResolver{
|
||||||
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
|
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
|
||||||
Dialer: d,
|
Dialer: netxlite.NewDialerLegacyAdapter(d),
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
@ -80,12 +80,12 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSingleTransport creates a new HTTP transport with a custom dialer and handshaker.
|
// NewSingleTransport creates a new HTTP transport with a custom dialer and handshaker.
|
||||||
func NewTransportWithDialer(dialer netxlite.Dialer, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper {
|
func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper {
|
||||||
transport := newBaseTransport()
|
transport := newBaseTransport()
|
||||||
transport.DialContext = dialer.DialContext
|
transport.DialContext = dialer.DialContext
|
||||||
transport.DialTLSContext = (&netxlite.TLSDialer{
|
transport.DialTLSContext = (&netxlite.TLSDialer{
|
||||||
Config: tlsConfig,
|
Config: tlsConfig,
|
||||||
Dialer: dialer,
|
Dialer: netxlite.NewDialerLegacyAdapter(dialer),
|
||||||
TLSHandshaker: handshaker,
|
TLSHandshaker: handshaker,
|
||||||
}).DialTLSContext
|
}).DialTLSContext
|
||||||
return transport
|
return transport
|
||||||
|
|
|
@ -8,7 +8,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type TCPConfig struct {
|
type TCPConfig struct {
|
||||||
Dialer netxlite.Dialer
|
Dialer netxlite.DialerLegacy
|
||||||
Endpoint string
|
Endpoint string
|
||||||
Resolver netxlite.ResolverLegacy
|
Resolver netxlite.ResolverLegacy
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,7 +106,7 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
|
||||||
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
|
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
|
||||||
return &netxlite.TLSDialer{
|
return &netxlite.TLSDialer{
|
||||||
Config: config,
|
Config: config,
|
||||||
Dialer: d,
|
Dialer: netxlite.NewDialerLegacyAdapter(d),
|
||||||
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
||||||
TLSHandshaker: &errorsx.ErrorWrapperTLSHandshaker{
|
TLSHandshaker: &errorsx.ErrorWrapperTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
|
|
|
@ -72,7 +72,10 @@ func New(config *Config, resolver Resolver) Dialer {
|
||||||
var d Dialer = netxlite.DefaultDialer
|
var d Dialer = netxlite.DefaultDialer
|
||||||
d = &errorsx.ErrorWrapperDialer{Dialer: d}
|
d = &errorsx.ErrorWrapperDialer{Dialer: d}
|
||||||
if config.Logger != nil {
|
if config.Logger != nil {
|
||||||
d = &netxlite.DialerLogger{Dialer: d, Logger: config.Logger}
|
d = &netxlite.DialerLogger{
|
||||||
|
Dialer: netxlite.NewDialerLegacyAdapter(d),
|
||||||
|
Logger: config.Logger,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if config.DialSaver != nil {
|
if config.DialSaver != nil {
|
||||||
d = &saverDialer{Dialer: d, Saver: config.DialSaver}
|
d = &saverDialer{Dialer: d, Saver: config.DialSaver}
|
||||||
|
@ -82,7 +85,7 @@ func New(config *Config, resolver Resolver) Dialer {
|
||||||
}
|
}
|
||||||
d = &netxlite.DialerResolver{
|
d = &netxlite.DialerResolver{
|
||||||
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
|
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
|
||||||
Dialer: d,
|
Dialer: netxlite.NewDialerLegacyAdapter(d),
|
||||||
}
|
}
|
||||||
d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d}
|
d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d}
|
||||||
if config.ContextByteCounting {
|
if config.ContextByteCounting {
|
||||||
|
|
|
@ -36,7 +36,11 @@ func TestNewCreatesTheExpectedChain(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not a dnsDialer")
|
t.Fatal("not a dnsDialer")
|
||||||
}
|
}
|
||||||
scd, ok := dnsd.Dialer.(*saverConnDialer)
|
dad, ok := dnsd.Dialer.(*netxlite.DialerLegacyAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
scd, ok := dad.DialerLegacy.(*saverConnDialer)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not a saverConnDialer")
|
t.Fatal("not a saverConnDialer")
|
||||||
}
|
}
|
||||||
|
@ -48,12 +52,16 @@ func TestNewCreatesTheExpectedChain(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not a loggingDialer")
|
t.Fatal("not a loggingDialer")
|
||||||
}
|
}
|
||||||
ewd, ok := ld.Dialer.(*errorsx.ErrorWrapperDialer)
|
dad, ok = ld.Dialer.(*netxlite.DialerLegacyAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
ewd, ok := dad.DialerLegacy.(*errorsx.ErrorWrapperDialer)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not an errorWrappingDialer")
|
t.Fatal("not an errorWrappingDialer")
|
||||||
}
|
}
|
||||||
_, ok = ewd.Dialer.(*net.Dialer)
|
_, ok = ewd.Dialer.(*netxlite.DialerSystem)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not a net.Dialer")
|
t.Fatal("not a DialerSystem")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -209,7 +209,7 @@ func NewTLSDialer(config Config) TLSDialer {
|
||||||
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
||||||
return &netxlite.TLSDialer{
|
return &netxlite.TLSDialer{
|
||||||
Config: config.TLSConfig,
|
Config: config.TLSConfig,
|
||||||
Dialer: config.Dialer,
|
Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer),
|
||||||
TLSHandshaker: h,
|
TLSHandshaker: h,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
package netx_test
|
package netx_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -16,6 +18,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||||
"github.com/ooni/probe-cli/v3/internal/errorsx"
|
"github.com/ooni/probe-cli/v3/internal/errorsx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewResolverVanilla(t *testing.T) {
|
func TestNewResolverVanilla(t *testing.T) {
|
||||||
|
@ -487,7 +490,14 @@ func TestNewWithTLSDialer(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
tlsDialer := &netxlite.TLSDialer{
|
tlsDialer := &netxlite.TLSDialer{
|
||||||
Config: new(tls.Config),
|
Config: new(tls.Config),
|
||||||
Dialer: netx.FakeDialer{Err: expected},
|
Dialer: &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
// nothing
|
||||||
|
},
|
||||||
|
},
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
}
|
}
|
||||||
txp := netx.NewHTTPTransport(netx.Config{
|
txp := netx.NewHTTPTransport(netx.Config{
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package tlsdialer_test
|
package tlsdialer_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -14,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
|
dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Logger: log.Log,
|
Logger: log.Log,
|
||||||
|
|
|
@ -24,7 +24,9 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
Config: &tls.Config{NextProtos: nextprotos},
|
||||||
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
Dialer: netxlite.NewDialerLegacyAdapter(
|
||||||
|
dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
||||||
|
),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
@ -117,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
Config: &tls.Config{NextProtos: nextprotos},
|
||||||
Dialer: new(net.Dialer),
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
@ -182,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Dialer: new(net.Dialer),
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
@ -215,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Dialer: new(net.Dialer),
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
@ -248,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Dialer: new(net.Dialer),
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
@ -282,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Config: &tls.Config{InsecureSkipVerify: true},
|
Config: &tls.Config{InsecureSkipVerify: true},
|
||||||
Dialer: new(net.Dialer),
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
|
|
@ -10,15 +10,31 @@ import (
|
||||||
type Dialer interface {
|
type Dialer interface {
|
||||||
// DialContext behaves like net.Dialer.DialContext.
|
// DialContext behaves like net.Dialer.DialContext.
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
|
||||||
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
|
CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultDialer is the Dialer we use by default.
|
// underlyingDialer is the Dialer we use by default.
|
||||||
var defaultDialer = &net.Dialer{
|
var underlyingDialer = &net.Dialer{
|
||||||
Timeout: 15 * time.Second,
|
Timeout: 15 * time.Second,
|
||||||
KeepAlive: 15 * time.Second,
|
KeepAlive: 15 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Dialer = defaultDialer
|
// dialerSystem dials using Go stdlib.
|
||||||
|
type dialerSystem struct{}
|
||||||
|
|
||||||
|
// DialContext implements Dialer.DialContext.
|
||||||
|
func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return underlyingDialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||||
|
func (d *dialerSystem) CloseIdleConnections() {
|
||||||
|
// nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultDialer Dialer = &dialerSystem{}
|
||||||
|
|
||||||
// dialerResolver is a dialer that uses the configured Resolver to resolver a
|
// dialerResolver is a dialer that uses the configured Resolver to resolver a
|
||||||
// domain name to IP addresses, and the configured Dialer to connect.
|
// domain name to IP addresses, and the configured Dialer to connect.
|
||||||
|
@ -66,6 +82,12 @@ func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]str
|
||||||
return d.Resolver.LookupHost(ctx, hostname)
|
return d.Resolver.LookupHost(ctx, hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||||
|
func (d *dialerResolver) CloseIdleConnections() {
|
||||||
|
d.Dialer.CloseIdleConnections()
|
||||||
|
d.Resolver.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
// dialerLogger is a Dialer with logging.
|
// dialerLogger is a Dialer with logging.
|
||||||
type dialerLogger struct {
|
type dialerLogger struct {
|
||||||
// Dialer is the underlying dialer.
|
// Dialer is the underlying dialer.
|
||||||
|
@ -90,3 +112,8 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string)
|
||||||
d.Logger.Debugf("dial %s/%s... ok in %s", address, network, elapsed)
|
d.Logger.Debugf("dial %s/%s... ok in %s", address, network, elapsed)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||||
|
func (d *dialerLogger) CloseIdleConnections() {
|
||||||
|
d.Dialer.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
|
@ -13,8 +13,13 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestDialerSystemCloseIdleConnections(t *testing.T) {
|
||||||
|
d := &dialerSystem{}
|
||||||
|
d.CloseIdleConnections() // should not crash
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialerResolverNoPort(t *testing.T) {
|
func TestDialerResolverNoPort(t *testing.T) {
|
||||||
dialer := &dialerResolver{Dialer: &net.Dialer{}, Resolver: DefaultResolver}
|
dialer := &dialerResolver{Dialer: defaultDialer, Resolver: DefaultResolver}
|
||||||
conn, err := dialer.DialContext(context.Background(), "tcp", "ooni.nu")
|
conn, err := dialer.DialContext(context.Background(), "tcp", "ooni.nu")
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
|
@ -25,7 +30,7 @@ func TestDialerResolverNoPort(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialerResolverLookupHostAddress(t *testing.T) {
|
func TestDialerResolverLookupHostAddress(t *testing.T) {
|
||||||
dialer := &dialerResolver{Dialer: new(net.Dialer), Resolver: &mocks.Resolver{
|
dialer := &dialerResolver{Dialer: defaultDialer, Resolver: &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return nil, errors.New("we should not call this function")
|
return nil, errors.New("we should not call this function")
|
||||||
},
|
},
|
||||||
|
@ -41,7 +46,7 @@ func TestDialerResolverLookupHostAddress(t *testing.T) {
|
||||||
|
|
||||||
func TestDialerResolverLookupHostFailure(t *testing.T) {
|
func TestDialerResolverLookupHostFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
dialer := &dialerResolver{Dialer: new(net.Dialer), Resolver: &mocks.Resolver{
|
dialer := &dialerResolver{Dialer: defaultDialer, Resolver: &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
|
@ -115,6 +120,29 @@ func TestDialerResolverDialForManyIPSuccess(t *testing.T) {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialerResolverCloseIdleConnections(t *testing.T) {
|
||||||
|
var (
|
||||||
|
calledDialer bool
|
||||||
|
calledResolver bool
|
||||||
|
)
|
||||||
|
d := &dialerResolver{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
calledDialer = true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Resolver: &mocks.Resolver{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
calledResolver = true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
d.CloseIdleConnections()
|
||||||
|
if !calledDialer || !calledResolver {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialerLoggerSuccess(t *testing.T) {
|
func TestDialerLoggerSuccess(t *testing.T) {
|
||||||
d := &dialerLogger{
|
d := &dialerLogger{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
|
@ -156,9 +184,26 @@ func TestDialerLoggerFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultDialerHasTimeout(t *testing.T) {
|
func TestDialerLoggerCloseIdleConnections(t *testing.T) {
|
||||||
|
var (
|
||||||
|
calledDialer bool
|
||||||
|
)
|
||||||
|
d := &dialerLogger{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
calledDialer = true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
d.CloseIdleConnections()
|
||||||
|
if !calledDialer {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnderlyingDialerHasTimeout(t *testing.T) {
|
||||||
expected := 15 * time.Second
|
expected := 15 * time.Second
|
||||||
if defaultDialer.Timeout != expected {
|
if underlyingDialer.Timeout != expected {
|
||||||
t.Fatal("unexpected timeout value")
|
t.Fatal("unexpected timeout value")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package netxlite
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/errorsx"
|
"github.com/ooni/probe-cli/v3/internal/errorsx"
|
||||||
|
@ -59,6 +60,7 @@ type (
|
||||||
ResolverIDNA = resolverIDNA
|
ResolverIDNA = resolverIDNA
|
||||||
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
||||||
TLSHandshakerLogger = tlsHandshakerLogger
|
TLSHandshakerLogger = tlsHandshakerLogger
|
||||||
|
DialerSystem = dialerSystem
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResolverLegacy performs domain name resolutions.
|
// ResolverLegacy performs domain name resolutions.
|
||||||
|
@ -122,3 +124,41 @@ func (r *ResolverLegacyAdapter) CloseIdleConnections() {
|
||||||
ra.CloseIdleConnections()
|
ra.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialerLegacy establishes network connections.
|
||||||
|
//
|
||||||
|
// This definition is DEPRECATED. Please, use Dialer.
|
||||||
|
//
|
||||||
|
// Existing code in probe-cli can use it until we
|
||||||
|
// have finished refactoring it.
|
||||||
|
type DialerLegacy interface {
|
||||||
|
// DialContext behaves like net.Dialer.DialContext.
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDialerLegacyAdapter adapts a DialerrLegacy to
|
||||||
|
// become compatible with the Dialer definition.
|
||||||
|
func NewDialerLegacyAdapter(d DialerLegacy) Dialer {
|
||||||
|
return &DialerLegacyAdapter{d}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialerLegacyAdapter makes a DialerLegacy behave like
|
||||||
|
// it was a Dialer type. If DialerLegacy is actually also
|
||||||
|
// a Dialer, this adapter will just forward missing calls,
|
||||||
|
// otherwise it will implement a sensible default action.
|
||||||
|
type DialerLegacyAdapter struct {
|
||||||
|
DialerLegacy
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Dialer = &DialerLegacyAdapter{}
|
||||||
|
|
||||||
|
type dialerLegacyIdleConnectionsCloser interface {
|
||||||
|
CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements Resolver.CloseIdleConnections.
|
||||||
|
func (d *DialerLegacyAdapter) CloseIdleConnections() {
|
||||||
|
if ra, ok := d.DialerLegacy.(dialerLegacyIdleConnectionsCloser); ok {
|
||||||
|
ra.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -82,3 +82,21 @@ func TestResolverLegacyAdapterDefaults(t *testing.T) {
|
||||||
}
|
}
|
||||||
r.CloseIdleConnections() // does not crash
|
r.CloseIdleConnections() // does not crash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
r := NewDialerLegacyAdapter(&mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
r.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialerLegacyAdapterDefaults(t *testing.T) {
|
||||||
|
r := NewDialerLegacyAdapter(&net.Dialer{})
|
||||||
|
r.CloseIdleConnections() // does not crash
|
||||||
|
}
|
||||||
|
|
|
@ -8,9 +8,15 @@ import (
|
||||||
// Dialer is a mockable Dialer.
|
// Dialer is a mockable Dialer.
|
||||||
type Dialer struct {
|
type Dialer struct {
|
||||||
MockDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
MockDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
MockCloseIdleConnections func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext calls MockDialContext.
|
// DialContext calls MockDialContext.
|
||||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return d.MockDialContext(ctx, network, address)
|
return d.MockDialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections calls MockCloseIdleConnections.
|
||||||
|
func (d *Dialer) CloseIdleConnections() {
|
||||||
|
d.MockCloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDialerWorks(t *testing.T) {
|
func TestDialerDialContext(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
d := Dialer{
|
d := Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
@ -23,3 +23,16 @@ func TestDialerWorks(t *testing.T) {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialerCloseIdleConnections(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
d := &Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
d.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -294,7 +294,7 @@ func TestTLSDialerFailureSplitHostPort(t *testing.T) {
|
||||||
func TestTLSDialerFailureDialing(t *testing.T) {
|
func TestTLSDialerFailureDialing(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel() // immediately fail
|
cancel() // immediately fail
|
||||||
dialer := TLSDialer{Dialer: &net.Dialer{}}
|
dialer := TLSDialer{Dialer: defaultDialer}
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user