From 2572376fdb74c4d8693fa45223d3d4bc43e0e472 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 6 Sep 2021 14:12:30 +0200 Subject: [PATCH] feat(netxlite): implement single use {,tls} dialer (#464) This basically adapts already existing code inside websteps to instead be into the netxlite package, where it belongs. In the process, abstract the TLSDialer but keep a reference to the previous name to avoid refactoring existing code (just for now). While there, notice that the right name is CloseIdleConnections (i.e., plural not singular) and change the name. While there, since we abstracted TLSDialer to be an interface, create suitable factories for making a TLSDialer type from a Dialer and a TLSHandshaker. See https://github.com/ooni/probe/issues/1591 --- .../engine/experiment/websteps/factory.go | 2 +- internal/engine/legacy/netx/dialer.go | 4 +- internal/engine/netx/netx.go | 2 +- internal/engine/netx/netx_test.go | 14 ++--- .../engine/netx/tlsdialer/integration_test.go | 2 +- internal/engine/netx/tlsdialer/saver_test.go | 12 ++-- internal/netxlite/dialer.go | 37 ++++++++++++ internal/netxlite/dialer_test.go | 21 +++++++ internal/netxlite/http.go | 2 +- internal/netxlite/legacy.go | 1 + internal/netxlite/tls.go | 56 ++++++++++++++--- internal/netxlite/tls_test.go | 60 +++++++++++++++---- 12 files changed, 177 insertions(+), 36 deletions(-) diff --git a/internal/engine/experiment/websteps/factory.go b/internal/engine/experiment/websteps/factory.go index c588611..621d3c6 100644 --- a/internal/engine/experiment/websteps/factory.go +++ b/internal/engine/experiment/websteps/factory.go @@ -83,7 +83,7 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper { func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper { transport := newBaseTransport() transport.DialContext = dialer.DialContext - transport.DialTLSContext = (&netxlite.TLSDialer{ + transport.DialTLSContext = (&netxlite.TLSDialerLegacy{ Config: tlsConfig, Dialer: netxlite.NewDialerLegacyAdapter(dialer), TLSHandshaker: handshaker, diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 3dbf3ce..2be290a 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -103,8 +103,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { // - SystemTLSHandshaker // // If you have others needs, manually build the chain you need. -func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer { - return &netxlite.TLSDialer{ +func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialerLegacy { + return &netxlite.TLSDialerLegacy{ Config: config, Dialer: netxlite.NewDialerLegacyAdapter(d), TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 5a68f26..4644353 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -207,7 +207,7 @@ func NewTLSDialer(config Config) TLSDialer { } config.TLSConfig.RootCAs = config.CertPool config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify - return &netxlite.TLSDialer{ + return &netxlite.TLSDialerLegacy{ Config: config.TLSConfig, Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer), TLSHandshaker: h, diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 103bedc..bc07d02 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -255,7 +255,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { func TestNewTLSDialerVanilla(t *testing.T) { td := netx.NewTLSDialer(netx.Config{}) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -287,7 +287,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSConfig: new(tls.Config), }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -316,7 +316,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ Logger: log.Log, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -356,7 +356,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSSaver: saver, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -396,7 +396,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { TLSConfig: new(tls.Config), NoTLSVerify: true, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -428,7 +428,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ NoTLSVerify: true, }) - rtd, ok := td.(*netxlite.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialerLegacy) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -488,7 +488,7 @@ func TestNewWithDialer(t *testing.T) { func TestNewWithTLSDialer(t *testing.T) { expected := errors.New("mocked error") - tlsDialer := &netxlite.TLSDialer{ + tlsDialer := &netxlite.TLSDialerLegacy{ Config: new(tls.Config), Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index 735d18c..7d7237d 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -13,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) { t.Skip("skip test in short mode") } log.SetLevel(log.DebugLevel) - dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer, + dialer := &netxlite.TLSDialerLegacy{Dialer: netxlite.DefaultDialer, TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Logger: log.Log, diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 515ff59..716e141 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { } nextprotos := []string{"h2"} saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: netxlite.NewDialerLegacyAdapter( dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), @@ -117,7 +117,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { } nextprotos := []string{"h2"} saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ @@ -183,7 +183,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, @@ -216,7 +216,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, @@ -249,7 +249,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, @@ -282,7 +282,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := &netxlite.TLSDialer{ + tlsdlr := &netxlite.TLSDialerLegacy{ Config: &tls.Config{InsecureSkipVerify: true}, Dialer: netxlite.DefaultDialer, TLSHandshaker: tlsdialer.SaverTLSHandshaker{ diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index e67f284..be8796c 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -2,7 +2,9 @@ package netxlite import ( "context" + "errors" "net" + "sync" "time" ) @@ -137,3 +139,38 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string) func (d *dialerLogger) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } + +// ErrNoConnReuse indicates we cannot reuse the connection provided +// to a single use (possibly TLS) dialer. +var ErrNoConnReuse = errors.New("cannot reuse connection") + +// NewSingleUseDialer returns a dialer that returns the given connection once +// and after that always fails with the ErrNoConnReuse error. +func NewSingleUseDialer(conn net.Conn) Dialer { + return &dialerSingleUse{conn: conn} +} + +// dialerSingleUse is the type of Dialer returned by NewSingleDialer. +type dialerSingleUse struct { + sync.Mutex + conn net.Conn +} + +var _ Dialer = &dialerSingleUse{} + +// DialContext implements Dialer.DialContext. +func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + defer s.Unlock() + s.Lock() + if s.conn == nil { + return nil, ErrNoConnReuse + } + var conn net.Conn + conn, s.conn = s.conn, nil + return conn, nil +} + +// CloseIdleConnections closes idle connections. +func (s *dialerSingleUse) CloseIdleConnections() { + // nothing +} diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 91b5478..5d89793 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -235,3 +235,24 @@ func TestNewDialerWithoutResolverChain(t *testing.T) { t.Fatal("invalid type") } } + +func TestNewSingleUseDialerWorksAsIntended(t *testing.T) { + conn := &mocks.Conn{} + d := NewSingleUseDialer(conn) + outconn, err := d.DialContext(context.Background(), "", "") + if err != nil { + t.Fatal(err) + } + if conn != outconn { + t.Fatal("invalid outconn") + } + for i := 0; i < 4; i++ { + outconn, err = d.DialContext(context.Background(), "", "") + if !errors.Is(err, ErrNoConnReuse) { + t.Fatal("not the error we expected", err) + } + if outconn != nil { + t.Fatal("expected nil outconn here") + } + } +} diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index ef7eeb7..81fb7f0 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -73,7 +73,7 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config, txp := http.DefaultTransport.(*http.Transport).Clone() dialer = &httpDialerWithReadTimeout{dialer} txp.DialContext = dialer.DialContext - txp.DialTLSContext = (&TLSDialer{ + txp.DialTLSContext = (&tlsDialer{ Config: tlsConfig, Dialer: dialer, TLSHandshaker: handshaker, diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 8df12be..9b30829 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -62,6 +62,7 @@ type ( TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerLogger = tlsHandshakerLogger DialerSystem = dialerSystem + TLSDialerLegacy = tlsDialer ) // ResolverLegacy performs domain name resolutions. diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 7be093c..2f7d82b 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -216,8 +216,29 @@ func (h *tlsHandshakerLogger) Handshake( return tlsconn, state, nil } -// TLSDialer is the TLS dialer -type TLSDialer struct { +// TLSDialer is a Dialer dialing TLS connections. +type TLSDialer interface { + // CloseIdleConnections closes idle connections, if any. + CloseIdleConnections() + + // DialTLSContext dials a TLS connection. + DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// NewTLSDialer creates a new TLS dialer using the given dialer +// and TLS handshaker to establish TLS connections. +func NewTLSDialer(dialer Dialer, handshaker TLSHandshaker) TLSDialer { + return NewTLSDialerWithConfig(dialer, handshaker, &tls.Config{}) +} + +// NewTLSDialerWithConfig is like NewTLSDialer but takes an optional config +// parameter containing your desired TLS configuration. +func NewTLSDialerWithConfig(d Dialer, h TLSHandshaker, c *tls.Config) TLSDialer { + return &tlsDialer{Config: c, Dialer: d, TLSHandshaker: h} +} + +// tlsDialer is the TLS dialer +type tlsDialer struct { // Config is the OPTIONAL tls config. Config *tls.Config @@ -228,13 +249,15 @@ type TLSDialer struct { TLSHandshaker TLSHandshaker } -// CloseIdleConnection closes idle connections, if any. -func (d *TLSDialer) CloseIdleConnection() { +var _ TLSDialer = &tlsDialer{} + +// CloseIdleConnections implements TLSDialer.CloseIdleConnections. +func (d *tlsDialer) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -// DialTLSContext dials a TLS connection. -func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { +// DialTLSContext implements TLSDialer.DialTLSContext. +func (d *tlsDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err @@ -258,7 +281,7 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) // We set the ServerName field if not already set. // // We set the ALPN if the port is 443 or 853, if not already set. -func (d *TLSDialer) config(host, port string) *tls.Config { +func (d *tlsDialer) config(host, port string) *tls.Config { config := d.Config if config == nil { config = &tls.Config{} @@ -277,3 +300,22 @@ func (d *TLSDialer) config(host, port string) *tls.Config { } return config } + +// NewSingleUseTLSDialer is like NewSingleUseDialer but takes +// in input a TLSConn rather than a net.Conn. +func NewSingleUseTLSDialer(conn TLSConn) TLSDialer { + return &tlsDialerSingleUseAdapter{NewSingleUseDialer(conn)} +} + +// tlsDialerSingleUseAdapter adapts dialerSingleUse to +// be a TLSDialer type rather than a Dialer type. +type tlsDialerSingleUseAdapter struct { + Dialer +} + +var _ TLSDialer = &tlsDialerSingleUseAdapter{} + +// DialTLSContext implements TLSDialer.DialTLSContext. +func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.Dialer.DialContext(ctx, network, address) +} diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 4c7bcd4..db5fe38 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -280,21 +280,21 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) { func TestTLSDialerCloseIdleConnections(t *testing.T) { var called bool - dialer := &TLSDialer{ + dialer := &tlsDialer{ Dialer: &mocks.Dialer{ MockCloseIdleConnections: func() { called = true }, }, } - dialer.CloseIdleConnection() + dialer.CloseIdleConnections() if !called { t.Fatal("not called") } } func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { - dialer := &TLSDialer{} + dialer := &tlsDialer{} ctx := context.Background() const address = "www.google.com" // missing port conn, err := dialer.DialTLSContext(ctx, "tcp", address) @@ -309,7 +309,7 @@ func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail - dialer := TLSDialer{Dialer: defaultDialer} + dialer := tlsDialer{Dialer: defaultDialer} conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { t.Fatal("not the error we expected", err) @@ -321,7 +321,7 @@ func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { ctx := context.Background() - dialer := TLSDialer{ + dialer := tlsDialer{ Config: &tls.Config{}, Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{MockWrite: func(b []byte) (int, error) { @@ -345,7 +345,7 @@ func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { ctx := context.Background() - dialer := TLSDialer{ + dialer := tlsDialer{ Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{MockWrite: func(b []byte) (int, error) { return 0, io.EOF @@ -372,7 +372,7 @@ func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { } func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { - d := &TLSDialer{} + d := &tlsDialer{} config := d.config("www.google.com", "443") if config.ServerName != "www.google.com" { t.Fatal("invalid server name") @@ -383,7 +383,7 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { } func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { - d := &TLSDialer{} + d := &tlsDialer{} config := d.config("dns.google", "853") if config.ServerName != "dns.google" { t.Fatal("invalid server name") @@ -394,7 +394,7 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { } func TestTLSDialerConfigWithServerName(t *testing.T) { - d := &TLSDialer{ + d := &tlsDialer{ Config: &tls.Config{ ServerName: "example.com", }, @@ -409,7 +409,7 @@ func TestTLSDialerConfigWithServerName(t *testing.T) { } func TestTLSDialerConfigWithALPN(t *testing.T) { - d := &TLSDialer{ + d := &tlsDialer{ Config: &tls.Config{ NextProtos: []string{"h2"}, }, @@ -440,3 +440,43 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) { t.Fatal("expected nil NewConn") } } + +func TestNewTLSDialerWorksAsIntended(t *testing.T) { + d := &mocks.Dialer{} + tlsh := &mocks.TLSHandshaker{} + td := NewTLSDialer(d, tlsh) + tdut, okay := td.(*tlsDialer) + if !okay { + t.Fatal("invalid type") + } + if tdut.Config == nil { + t.Fatal("unexpected config") + } + if tdut.Dialer != d { + t.Fatal("unexpected dialer") + } + if tdut.TLSHandshaker != tlsh { + t.Fatal("invalid handshaker") + } +} + +func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) { + conn := &mocks.TLSConn{} + d := NewSingleUseTLSDialer(conn) + outconn, err := d.DialTLSContext(context.Background(), "", "") + if err != nil { + t.Fatal(err) + } + if conn != outconn { + t.Fatal("invalid outconn") + } + for i := 0; i < 4; i++ { + outconn, err = d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, ErrNoConnReuse) { + t.Fatal("not the error we expected", err) + } + if outconn != nil { + t.Fatal("expected nil outconn here") + } + } +}