diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 73e14cf..7be093c 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -228,6 +228,11 @@ type TLSDialer struct { TLSHandshaker TLSHandshaker } +// CloseIdleConnection closes idle connections, if any. +func (d *TLSDialer) CloseIdleConnection() { + d.Dialer.CloseIdleConnections() +} + // DialTLSContext dials a TLS connection. func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 50b5276..4c7bcd4 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -278,7 +278,22 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) { } } -func TestTLSDialerFailureSplitHostPort(t *testing.T) { +func TestTLSDialerCloseIdleConnections(t *testing.T) { + var called bool + dialer := &TLSDialer{ + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + }, + } + dialer.CloseIdleConnection() + if !called { + t.Fatal("not called") + } +} + +func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { dialer := &TLSDialer{} ctx := context.Background() const address = "www.google.com" // missing port @@ -291,7 +306,7 @@ func TestTLSDialerFailureSplitHostPort(t *testing.T) { } } -func TestTLSDialerFailureDialing(t *testing.T) { +func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail dialer := TLSDialer{Dialer: defaultDialer} @@ -304,7 +319,7 @@ func TestTLSDialerFailureDialing(t *testing.T) { } } -func TestTLSDialerFailureHandshaking(t *testing.T) { +func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { ctx := context.Background() dialer := TLSDialer{ Config: &tls.Config{}, @@ -328,7 +343,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) { } } -func TestTLSDialerSuccessHandshaking(t *testing.T) { +func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { ctx := context.Background() dialer := TLSDialer{ Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {