diff --git a/internal/engine/experiment/websteps/factory.go b/internal/engine/experiment/websteps/factory.go index c588611..621d3c6 100644 --- a/internal/engine/experiment/websteps/factory.go +++ b/internal/engine/experiment/websteps/factory.go @@ -83,7 +83,7 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper { func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper { transport := newBaseTransport() transport.DialContext = dialer.DialContext - transport.DialTLSContext = (&netxlite.TLSDialer{ + transport.DialTLSContext = (&netxlite.TLSDialerLegacy{ Config: tlsConfig, Dialer: netxlite.NewDialerLegacyAdapter(dialer), TLSHandshaker: handshaker, diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 3dbf3ce..2be290a 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -103,8 +103,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { // - SystemTLSHandshaker // // If you have others needs, manually build the chain you need. -func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer { - return &netxlite.TLSDialer{ +func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialerLegacy { + return &netxlite.TLSDialerLegacy{ Config: config, Dialer: netxlite.NewDialerLegacyAdapter(d), TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 5a68f26..4644353 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -207,7 +207,7 @@ func NewTLSDialer(config Config) TLSDialer { } config.TLSConfig.RootCAs = config.CertPool config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify - return &netxlite.TLSDialer{ + return &netxlite.TLSDialerLegacy{ Config: config.TLSConfig, Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer), TLSHandshaker: h, diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 103bedc..bc07d02 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -255,7 +255,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { func TestNewTLSDialerVanilla(t *testing.T) { td := netx.NewTLSDialer(netx.Config{}) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -287,7 +287,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSConfig: new(tls.Config), }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -316,7 +316,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ Logger: log.Log, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -356,7 +356,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSSaver: saver, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -396,7 +396,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { TLSConfig: new(tls.Config), NoTLSVerify: true, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -428,7 +428,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ NoTLSVerify: true, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -488,7 +488,7 @@ func TestNewWithDialer(t *testing.T) { func TestNewWithTLSDialer(t *testing.T) { expected := errors.New("mocked error") - tlsDialer := &netxlite.TLSDialer{ + tlsDialer := &netxlite.TLSDialerLegacy{ Config: new(tls.Config), Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index 735d18c..7d7237d 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -13,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) { t.Skip("skip test in short mode") } log.SetLevel(log.DebugLevel) - dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer, + dialer := &netxlite.TLSDialerLegacy{Dialer: netxlite.DefaultDialer, TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Logger: log.Log, diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 515ff59..716e141 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { } nextprotos := []string{"h2"} saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: netxlite.NewDialerLegacyAdapter( dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), @@ -117,7 +117,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { } nextprotos := []string{"h2"} saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ @@ -183,7 +183,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, @@ -216,7 +216,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, @@ -249,7 +249,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, @@ -282,7 +282,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Config: &tls.Config{InsecureSkipVerify: true}, Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index e67f284..be8796c 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -2,7 +2,9 @@ package netxlite import ( "context" + "errors" "net" + "sync" "time" ) @@ -137,3 +139,38 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string) func (d *dialerLogger) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } + +// ErrNoConnReuse indicates we cannot reuse the connection provided +// to a single use (possibly TLS) dialer. +var ErrNoConnReuse = errors.New("cannot reuse connection") + +// NewSingleUseDialer returns a dialer that returns the given connection once +// and after that always fails with the ErrNoConnReuse error. +func NewSingleUseDialer(conn net.Conn) Dialer { + return &dialerSingleUse{conn: conn} +} + +// dialerSingleUse is the type of Dialer returned by NewSingleDialer. +type dialerSingleUse struct { + sync.Mutex + conn net.Conn +} + +var _ Dialer = &dialerSingleUse{} + +// DialContext implements Dialer.DialContext. +func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + defer s.Unlock() + s.Lock() + if s.conn == nil { + return nil, ErrNoConnReuse + } + var conn net.Conn + conn, s.conn = s.conn, nil + return conn, nil +} + +// CloseIdleConnections closes idle connections. +func (s *dialerSingleUse) CloseIdleConnections() { + // nothing +} diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 91b5478..5d89793 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -235,3 +235,24 @@ func TestNewDialerWithoutResolverChain(t *testing.T) { t.Fatal("invalid type") } } + +func TestNewSingleUseDialerWorksAsIntended(t *testing.T) { + conn := &mocks.Conn{} + d := NewSingleUseDialer(conn) + outconn, err := d.DialContext(context.Background(), "", "") + if err != nil { + t.Fatal(err) + } + if conn != outconn { + t.Fatal("invalid outconn") + } + for i := 0; i < 4; i++ { + outconn, err = d.DialContext(context.Background(), "", "") + if !errors.Is(err, ErrNoConnReuse) { + t.Fatal("not the error we expected", err) + } + if outconn != nil { + t.Fatal("expected nil outconn here") + } + } +} diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index ef7eeb7..81fb7f0 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -73,7 +73,7 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config, txp := http.DefaultTransport.(*http.Transport).Clone() dialer = &httpDialerWithReadTimeout{dialer} txp.DialContext = dialer.DialContext - txp.DialTLSContext = (&TLSDialer{ + txp.DialTLSContext = (&tlsDialer{ Config: tlsConfig, Dialer: dialer, TLSHandshaker: handshaker, diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 8df12be..9b30829 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -62,6 +62,7 @@ type ( TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerLogger = tlsHandshakerLogger DialerSystem = dialerSystem + TLSDialerLegacy = tlsDialer ) // ResolverLegacy performs domain name resolutions. diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 7be093c..2f7d82b 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -216,8 +216,29 @@ func (h *tlsHandshakerLogger) Handshake( return tlsconn, state, nil } -// TLSDialer is the TLS dialer -type TLSDialer struct { +// TLSDialer is a Dialer dialing TLS connections. +type TLSDialer interface { + // CloseIdleConnections closes idle connections, if any. + CloseIdleConnections() + + // DialTLSContext dials a TLS connection. + DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// NewTLSDialer creates a new TLS dialer using the given dialer +// and TLS handshaker to establish TLS connections. +func NewTLSDialer(dialer Dialer, handshaker TLSHandshaker) TLSDialer { + return NewTLSDialerWithConfig(dialer, handshaker, &tls.Config{}) +} + +// NewTLSDialerWithConfig is like NewTLSDialer but takes an optional config +// parameter containing your desired TLS configuration. +func NewTLSDialerWithConfig(d Dialer, h TLSHandshaker, c *tls.Config) TLSDialer { + return &tlsDialer{Config: c, Dialer: d, TLSHandshaker: h} +} + +// tlsDialer is the TLS dialer +type tlsDialer struct { // Config is the OPTIONAL tls config. Config *tls.Config @@ -228,13 +249,15 @@ type TLSDialer struct { TLSHandshaker TLSHandshaker } -// CloseIdleConnection closes idle connections, if any. -func (d *TLSDialer) CloseIdleConnection() { +var _ TLSDialer = &tlsDialer{} + +// CloseIdleConnections implements TLSDialer.CloseIdleConnections. +func (d *tlsDialer) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -// DialTLSContext dials a TLS connection. -func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { +// DialTLSContext implements TLSDialer.DialTLSContext. +func (d *tlsDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err @@ -258,7 +281,7 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) // We set the ServerName field if not already set. // // We set the ALPN if the port is 443 or 853, if not already set. -func (d *TLSDialer) config(host, port string) *tls.Config { +func (d *tlsDialer) config(host, port string) *tls.Config { config := d.Config if config == nil { config = &tls.Config{} @@ -277,3 +300,22 @@ func (d *TLSDialer) config(host, port string) *tls.Config { } return config } + +// NewSingleUseTLSDialer is like NewSingleUseDialer but takes +// in input a TLSConn rather than a net.Conn. +func NewSingleUseTLSDialer(conn TLSConn) TLSDialer { + return &tlsDialerSingleUseAdapter{NewSingleUseDialer(conn)} +} + +// tlsDialerSingleUseAdapter adapts dialerSingleUse to +// be a TLSDialer type rather than a Dialer type. +type tlsDialerSingleUseAdapter struct { + Dialer +} + +var _ TLSDialer = &tlsDialerSingleUseAdapter{} + +// DialTLSContext implements TLSDialer.DialTLSContext. +func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.Dialer.DialContext(ctx, network, address) +} diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 4c7bcd4..db5fe38 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -280,21 +280,21 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) { func TestTLSDialerCloseIdleConnections(t *testing.T) { var called bool - dialer := &TLSDialer{ + dialer := &tlsDialer{ Dialer: &mocks.Dialer{ MockCloseIdleConnections: func() { called = true }, }, } - dialer.CloseIdleConnection() + dialer.CloseIdleConnections() if !called { t.Fatal("not called") } } func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { - dialer := &TLSDialer{} + dialer := &tlsDialer{} ctx := context.Background() const address = "www.google.com" // missing port conn, err := dialer.DialTLSContext(ctx, "tcp", address) @@ -309,7 +309,7 @@ func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail - dialer := TLSDialer{Dialer: defaultDialer} + dialer := tlsDialer{Dialer: defaultDialer} conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { t.Fatal("not the error we expected", err) @@ -321,7 +321,7 @@ func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { ctx := context.Background() - dialer := TLSDialer{ + dialer := tlsDialer{ Config: &tls.Config{}, Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{MockWrite: func(b []byte) (int, error) { @@ -345,7 +345,7 @@ func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { ctx := context.Background() - dialer := TLSDialer{ + dialer := tlsDialer{ Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{MockWrite: func(b []byte) (int, error) { return 0, io.EOF @@ -372,7 +372,7 @@ func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { } func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { - d := &TLSDialer{} + d := &tlsDialer{} config := d.config("www.google.com", "443") if config.ServerName != "www.google.com" { t.Fatal("invalid server name") @@ -383,7 +383,7 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { } func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { - d := &TLSDialer{} + d := &tlsDialer{} config := d.config("dns.google", "853") if config.ServerName != "dns.google" { t.Fatal("invalid server name") @@ -394,7 +394,7 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { } func TestTLSDialerConfigWithServerName(t *testing.T) { - d := &TLSDialer{ + d := &tlsDialer{ Config: &tls.Config{ ServerName: "example.com", }, @@ -409,7 +409,7 @@ func TestTLSDialerConfigWithServerName(t *testing.T) { } func TestTLSDialerConfigWithALPN(t *testing.T) { - d := &TLSDialer{ + d := &tlsDialer{ Config: &tls.Config{ NextProtos: []string{"h2"}, }, @@ -440,3 +440,43 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) { t.Fatal("expected nil NewConn") } } + +func TestNewTLSDialerWorksAsIntended(t *testing.T) { + d := &mocks.Dialer{} + tlsh := &mocks.TLSHandshaker{} + td := NewTLSDialer(d, tlsh) + tdut, okay := td.(*tlsDialer) + if !okay { + t.Fatal("invalid type") + } + if tdut.Config == nil { + t.Fatal("unexpected config") + } + if tdut.Dialer != d { + t.Fatal("unexpected dialer") + } + if tdut.TLSHandshaker != tlsh { + t.Fatal("invalid handshaker") + } +} + +func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) { + conn := &mocks.TLSConn{} + d := NewSingleUseTLSDialer(conn) + outconn, err := d.DialTLSContext(context.Background(), "", "") + if err != nil { + t.Fatal(err) + } + if conn != outconn { + t.Fatal("invalid outconn") + } + for i := 0; i < 4; i++ { + outconn, err = d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, ErrNoConnReuse) { + t.Fatal("not the error we expected", err) + } + if outconn != nil { + t.Fatal("expected nil outconn here") + } + } +}