refactor(dialer): it should close idle connections (#457)
Like we did before for the resolver, a dialer should propagate the request to close idle connections to underlying types. See https://github.com/ooni/probe/issues/1591
This commit is contained in:
@@ -10,15 +10,31 @@ import (
|
||||
type Dialer interface {
|
||||
// DialContext behaves like net.Dialer.DialContext.
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// CloseIdleConnections closes idle connections, if any.
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// defaultDialer is the Dialer we use by default.
|
||||
var defaultDialer = &net.Dialer{
|
||||
// underlyingDialer is the Dialer we use by default.
|
||||
var underlyingDialer = &net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 15 * time.Second,
|
||||
}
|
||||
|
||||
var _ Dialer = defaultDialer
|
||||
// dialerSystem dials using Go stdlib.
|
||||
type dialerSystem struct{}
|
||||
|
||||
// DialContext implements Dialer.DialContext.
|
||||
func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return underlyingDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||
func (d *dialerSystem) CloseIdleConnections() {
|
||||
// nothing
|
||||
}
|
||||
|
||||
var defaultDialer Dialer = &dialerSystem{}
|
||||
|
||||
// dialerResolver is a dialer that uses the configured Resolver to resolver a
|
||||
// domain name to IP addresses, and the configured Dialer to connect.
|
||||
@@ -66,6 +82,12 @@ func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]str
|
||||
return d.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||
func (d *dialerResolver) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
d.Resolver.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// dialerLogger is a Dialer with logging.
|
||||
type dialerLogger struct {
|
||||
// Dialer is the underlying dialer.
|
||||
@@ -90,3 +112,8 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string)
|
||||
d.Logger.Debugf("dial %s/%s... ok in %s", address, network, elapsed)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Dialer.CloseIdleConnections.
|
||||
func (d *dialerLogger) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
}
|
||||
|
||||
@@ -13,8 +13,13 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDialerSystemCloseIdleConnections(t *testing.T) {
|
||||
d := &dialerSystem{}
|
||||
d.CloseIdleConnections() // should not crash
|
||||
}
|
||||
|
||||
func TestDialerResolverNoPort(t *testing.T) {
|
||||
dialer := &dialerResolver{Dialer: &net.Dialer{}, Resolver: DefaultResolver}
|
||||
dialer := &dialerResolver{Dialer: defaultDialer, Resolver: DefaultResolver}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "ooni.nu")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||
t.Fatal("not the error we expected", err)
|
||||
@@ -25,7 +30,7 @@ func TestDialerResolverNoPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDialerResolverLookupHostAddress(t *testing.T) {
|
||||
dialer := &dialerResolver{Dialer: new(net.Dialer), Resolver: &mocks.Resolver{
|
||||
dialer := &dialerResolver{Dialer: defaultDialer, Resolver: &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, errors.New("we should not call this function")
|
||||
},
|
||||
@@ -41,7 +46,7 @@ func TestDialerResolverLookupHostAddress(t *testing.T) {
|
||||
|
||||
func TestDialerResolverLookupHostFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
dialer := &dialerResolver{Dialer: new(net.Dialer), Resolver: &mocks.Resolver{
|
||||
dialer := &dialerResolver{Dialer: defaultDialer, Resolver: &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
@@ -115,6 +120,29 @@ func TestDialerResolverDialForManyIPSuccess(t *testing.T) {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialerResolverCloseIdleConnections(t *testing.T) {
|
||||
var (
|
||||
calledDialer bool
|
||||
calledResolver bool
|
||||
)
|
||||
d := &dialerResolver{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
calledDialer = true
|
||||
},
|
||||
},
|
||||
Resolver: &mocks.Resolver{
|
||||
MockCloseIdleConnections: func() {
|
||||
calledResolver = true
|
||||
},
|
||||
},
|
||||
}
|
||||
d.CloseIdleConnections()
|
||||
if !calledDialer || !calledResolver {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerLoggerSuccess(t *testing.T) {
|
||||
d := &dialerLogger{
|
||||
Dialer: &mocks.Dialer{
|
||||
@@ -156,9 +184,26 @@ func TestDialerLoggerFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultDialerHasTimeout(t *testing.T) {
|
||||
func TestDialerLoggerCloseIdleConnections(t *testing.T) {
|
||||
var (
|
||||
calledDialer bool
|
||||
)
|
||||
d := &dialerLogger{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
calledDialer = true
|
||||
},
|
||||
},
|
||||
}
|
||||
d.CloseIdleConnections()
|
||||
if !calledDialer {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnderlyingDialerHasTimeout(t *testing.T) {
|
||||
expected := 15 * time.Second
|
||||
if defaultDialer.Timeout != expected {
|
||||
if underlyingDialer.Timeout != expected {
|
||||
t.Fatal("unexpected timeout value")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package netxlite
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/errorsx"
|
||||
@@ -59,6 +60,7 @@ type (
|
||||
ResolverIDNA = resolverIDNA
|
||||
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
||||
TLSHandshakerLogger = tlsHandshakerLogger
|
||||
DialerSystem = dialerSystem
|
||||
)
|
||||
|
||||
// ResolverLegacy performs domain name resolutions.
|
||||
@@ -122,3 +124,41 @@ func (r *ResolverLegacyAdapter) CloseIdleConnections() {
|
||||
ra.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// DialerLegacy establishes network connections.
|
||||
//
|
||||
// This definition is DEPRECATED. Please, use Dialer.
|
||||
//
|
||||
// Existing code in probe-cli can use it until we
|
||||
// have finished refactoring it.
|
||||
type DialerLegacy interface {
|
||||
// DialContext behaves like net.Dialer.DialContext.
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// NewDialerLegacyAdapter adapts a DialerrLegacy to
|
||||
// become compatible with the Dialer definition.
|
||||
func NewDialerLegacyAdapter(d DialerLegacy) Dialer {
|
||||
return &DialerLegacyAdapter{d}
|
||||
}
|
||||
|
||||
// DialerLegacyAdapter makes a DialerLegacy behave like
|
||||
// it was a Dialer type. If DialerLegacy is actually also
|
||||
// a Dialer, this adapter will just forward missing calls,
|
||||
// otherwise it will implement a sensible default action.
|
||||
type DialerLegacyAdapter struct {
|
||||
DialerLegacy
|
||||
}
|
||||
|
||||
var _ Dialer = &DialerLegacyAdapter{}
|
||||
|
||||
type dialerLegacyIdleConnectionsCloser interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Resolver.CloseIdleConnections.
|
||||
func (d *DialerLegacyAdapter) CloseIdleConnections() {
|
||||
if ra, ok := d.DialerLegacy.(dialerLegacyIdleConnectionsCloser); ok {
|
||||
ra.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,3 +82,21 @@ func TestResolverLegacyAdapterDefaults(t *testing.T) {
|
||||
}
|
||||
r.CloseIdleConnections() // does not crash
|
||||
}
|
||||
|
||||
func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) {
|
||||
var called bool
|
||||
r := NewDialerLegacyAdapter(&mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
})
|
||||
r.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerLegacyAdapterDefaults(t *testing.T) {
|
||||
r := NewDialerLegacyAdapter(&net.Dialer{})
|
||||
r.CloseIdleConnections() // does not crash
|
||||
}
|
||||
|
||||
@@ -7,10 +7,16 @@ import (
|
||||
|
||||
// Dialer is a mockable Dialer.
|
||||
type Dialer struct {
|
||||
MockDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
MockDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
MockCloseIdleConnections func()
|
||||
}
|
||||
|
||||
// DialContext calls MockDialContext.
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.MockDialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// CloseIdleConnections calls MockCloseIdleConnections.
|
||||
func (d *Dialer) CloseIdleConnections() {
|
||||
d.MockCloseIdleConnections()
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDialerWorks(t *testing.T) {
|
||||
func TestDialerDialContext(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
@@ -23,3 +23,16 @@ func TestDialerWorks(t *testing.T) {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerCloseIdleConnections(t *testing.T) {
|
||||
var called bool
|
||||
d := &Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
d.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,7 +294,7 @@ func TestTLSDialerFailureSplitHostPort(t *testing.T) {
|
||||
func TestTLSDialerFailureDialing(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // immediately fail
|
||||
dialer := TLSDialer{Dialer: &net.Dialer{}}
|
||||
dialer := TLSDialer{Dialer: defaultDialer}
|
||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
||||
t.Fatal("not the error we expected", err)
|
||||
|
||||
Reference in New Issue
Block a user