diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index 81fb7f0..205bb3e 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -2,7 +2,6 @@ package netxlite import ( "context" - "crypto/tls" "net" "net/http" "time" @@ -67,17 +66,37 @@ func (txp *httpTransportLogger) CloseIdleConnections() { txp.HTTPTransport.CloseIdleConnections() } -// NewHTTPTransport creates a new HTTP transport using Go stdlib. -func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config, - handshaker TLSHandshaker) HTTPTransport { +// httpTransportConnectionsCloser is an HTTPTransport that +// correctly forwards CloseIdleConnections. +type httpTransportConnectionsCloser struct { + HTTPTransport + Dialer + TLSDialer +} + +// CloseIdleConnections forwards the CloseIdleConnections calls. +func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() + txp.Dialer.CloseIdleConnections() + txp.TLSDialer.CloseIdleConnections() +} + +// NewHTTPTransport creates a new HTTP transport using the given +// dialer and TLS handshaker to create connections. +// +// We need a TLS handshaker here, as opposed to a TLSDialer, because we +// wrap the dialer we'll use to enforce timeouts for HTTP idle +// connections (see https://github.com/ooni/probe/issues/1609 for more info). +func NewHTTPTransport(dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport { + // TODO(bassosimone): here we should copy code living inside the + // websteps prototype to use the oohttp library. txp := http.DefaultTransport.(*http.Transport).Clone() + // This wrapping ensures that we always have a timeout when we + // are using HTTP; see https://github.com/ooni/probe/issues/1609. dialer = &httpDialerWithReadTimeout{dialer} txp.DialContext = dialer.DialContext - txp.DialTLSContext = (&tlsDialer{ - Config: tlsConfig, - Dialer: dialer, - TLSHandshaker: handshaker, - }).DialTLSContext + tlsDialer := NewTLSDialer(dialer, tlsHandshaker) + txp.DialTLSContext = tlsDialer.DialTLSContext // Better for Cloudflare DNS and also better because we have less // noisy events and we can better understand what happened. txp.MaxConnsPerHost = 1 @@ -86,7 +105,13 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config, // back the true headers, such as Content-Length. This change is // functional to OONI's goal of observing the network. txp.DisableCompression = true - return txp + txp.ForceAttemptHTTP2 = true + // Ensure we correctly forward CloseIdleConnections. + return &httpTransportConnectionsCloser{ + HTTPTransport: txp, + Dialer: dialer, + TLSDialer: tlsDialer, + } } // httpDialerWithReadTimeout enforces a read timeout for all HTTP diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index eca16c4..56e0ad9 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -2,7 +2,6 @@ package netxlite import ( "context" - "crypto/tls" "errors" "io" "net" @@ -110,22 +109,19 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) { } func TestHTTPTransportWorks(t *testing.T) { - d := &dialerResolver{ - Dialer: defaultDialer, - Resolver: NewResolverSystem(log.Log), - } - th := &tlsHandshakerConfigurable{} - txp := NewHTTPTransport(d, &tls.Config{}, th) + d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log)) + txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log)) client := &http.Client{Transport: txp} + defer client.CloseIdleConnections() resp, err := client.Get("https://www.google.com/robots.txt") if err != nil { t.Fatal(err) } resp.Body.Close() - txp.CloseIdleConnections() } func TestHTTPTransportWithFailingDialer(t *testing.T) { + called := &atomicx.Int64{} expected := errors.New("mocked error") d := &dialerResolver{ Dialer: &mocks.Dialer{ @@ -133,11 +129,13 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) { network, address string) (net.Conn, error) { return nil, expected }, + MockCloseIdleConnections: func() { + called.Add(1) + }, }, Resolver: NewResolverSystem(log.Log), } - th := &tlsHandshakerConfigurable{} - txp := NewHTTPTransport(d, &tls.Config{}, th) + txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log)) client := &http.Client{Transport: txp} resp, err := client.Get("https://www.google.com/robots.txt") if !errors.Is(err, expected) { @@ -146,5 +144,47 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) { if resp != nil { t.Fatal("expected non-nil response here") } - txp.CloseIdleConnections() + client.CloseIdleConnections() + if called.Load() < 1 { + t.Fatal("did not propagate CloseIdleConnections") + } +} + +func TestNewHTTPTransport(t *testing.T) { + d := &mocks.Dialer{} + th := &mocks.TLSHandshaker{} + txp := NewHTTPTransport(d, th) + txpcc, okay := txp.(*httpTransportConnectionsCloser) + if !okay { + t.Fatal("invalid type") + } + udt, okay := txpcc.Dialer.(*httpDialerWithReadTimeout) + if !okay { + t.Fatal("invalid type") + } + if udt.Dialer != d { + t.Fatal("invalid dialer") + } + if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th { + t.Fatal("invalid tls handshaker") + } + htxp, okay := txpcc.HTTPTransport.(*http.Transport) + if !okay { + t.Fatal("invalid type") + } + if !htxp.ForceAttemptHTTP2 { + t.Fatal("invalid ForceAttemptHTTP2") + } + if !htxp.DisableCompression { + t.Fatal("invalid DisableCompression") + } + if htxp.MaxConnsPerHost != 1 { + t.Fatal("invalid MaxConnPerHost") + } + if htxp.DialTLSContext == nil { + t.Fatal("invalid DialTLSContext") + } + if htxp.DialContext == nil { + t.Fatal("invalid DialContext") + } } diff --git a/internal/netxlite/mocks/tls.go b/internal/netxlite/mocks/tls.go new file mode 100644 index 0000000..347cd34 --- /dev/null +++ b/internal/netxlite/mocks/tls.go @@ -0,0 +1,60 @@ +package mocks + +import ( + "context" + "crypto/tls" + "net" +) + +// TLSHandshaker is a mockable TLS handshaker. +type TLSHandshaker struct { + MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) ( + net.Conn, tls.ConnectionState, error) +} + +// Handshake calls MockHandshake. +func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( + net.Conn, tls.ConnectionState, error) { + return th.MockHandshake(ctx, conn, config) +} + +// TLSConn allows to mock netxlite.TLSConn. +type TLSConn struct { + // Conn is the embedded mockable Conn. + Conn + + // MockConnectionState allows to mock the ConnectionState method. + MockConnectionState func() tls.ConnectionState + + // MockHandshakeContext allows to mock the HandshakeContext method. + MockHandshakeContext func(ctx context.Context) error +} + +// ConnectionState calls MockConnectionState. +func (c *TLSConn) ConnectionState() tls.ConnectionState { + return c.MockConnectionState() +} + +// HandshakeContext calls MockHandshakeContext. +func (c *TLSConn) HandshakeContext(ctx context.Context) error { + return c.MockHandshakeContext(ctx) +} + +// TLSDialer allows to mock netxlite.TLSDialer. +type TLSDialer struct { + // MockCloseIdleConnections allows to mock the CloseIdleConnections method. + MockCloseIdleConnections func() + + // MockDialTLSContext allows to mock the DialTLSContext method. + MockDialTLSContext func(ctx context.Context, network, address string) (net.Conn, error) +} + +// CloseIdleConnections calls MockCloseIdleConnections. +func (d *TLSDialer) CloseIdleConnections() { + d.MockCloseIdleConnections() +} + +// DialTLSContext calls MockDialTLSContext. +func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.MockDialTLSContext(ctx, network, address) +} diff --git a/internal/netxlite/mocks/tls_test.go b/internal/netxlite/mocks/tls_test.go new file mode 100644 index 0000000..3e4bac6 --- /dev/null +++ b/internal/netxlite/mocks/tls_test.go @@ -0,0 +1,89 @@ +package mocks + +import ( + "context" + "crypto/tls" + "errors" + "net" + "reflect" + "testing" +) + +func TestTLSHandshakerHandshake(t *testing.T) { + expected := errors.New("mocked error") + conn := &Conn{} + ctx := context.Background() + config := &tls.Config{} + th := &TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, + config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expected + }, + } + tlsConn, connState, err := th.Handshake(ctx, conn, config) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero ConnectionState here") + } + if tlsConn != nil { + t.Fatal("expected nil conn here") + } +} + +func TestTLSConnConnectionState(t *testing.T) { + state := tls.ConnectionState{Version: tls.VersionTLS12} + c := &TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return state + }, + } + out := c.ConnectionState() + if !reflect.DeepEqual(out, state) { + t.Fatal("not the result we expected") + } +} + +func TestTLSConnHandshakeContext(t *testing.T) { + expected := errors.New("mocked error") + c := &TLSConn{ + MockHandshakeContext: func(ctx context.Context) error { + return expected + }, + } + err := c.HandshakeContext(context.Background()) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } +} + +func TestTLSDialerCloseIdleConnections(t *testing.T) { + var called bool + td := &TLSDialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + td.CloseIdleConnections() + if !called { + t.Fatal("not called") + } +} + +func TestTLSDialerDialTLSContext(t *testing.T) { + expected := errors.New("mocked error") + td := &TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + } + ctx := context.Background() + conn, err := td.DialTLSContext(ctx, "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("expected nil conn here") + } +} diff --git a/internal/netxlite/mocks/tlsconn.go b/internal/netxlite/mocks/tlsconn.go deleted file mode 100644 index 8307b3a..0000000 --- a/internal/netxlite/mocks/tlsconn.go +++ /dev/null @@ -1,28 +0,0 @@ -package mocks - -import ( - "context" - "crypto/tls" -) - -// TLSConn allows to mock netxlite.TLSConn. -type TLSConn struct { - // Conn is the embedded mockable Conn. - Conn - - // MockConnectionState allows to mock the ConnectionState method. - MockConnectionState func() tls.ConnectionState - - // MockHandshakeContext allows to mock the HandshakeContext method. - MockHandshakeContext func(ctx context.Context) error -} - -// ConnectionState calls MockConnectionState. -func (c *TLSConn) ConnectionState() tls.ConnectionState { - return c.MockConnectionState() -} - -// HandshakeContext calls MockHandshakeContext. -func (c *TLSConn) HandshakeContext(ctx context.Context) error { - return c.MockHandshakeContext(ctx) -} diff --git a/internal/netxlite/mocks/tlsconn_test.go b/internal/netxlite/mocks/tlsconn_test.go deleted file mode 100644 index 7b55a81..0000000 --- a/internal/netxlite/mocks/tlsconn_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package mocks - -import ( - "context" - "crypto/tls" - "errors" - "reflect" - "testing" -) - -func TestTLSConnConnectionState(t *testing.T) { - state := tls.ConnectionState{Version: tls.VersionTLS12} - c := &TLSConn{ - MockConnectionState: func() tls.ConnectionState { - return state - }, - } - out := c.ConnectionState() - if !reflect.DeepEqual(out, state) { - t.Fatal("not the result we expected") - } -} - -func TestTLSConnHandshakeContext(t *testing.T) { - expected := errors.New("mocked error") - c := &TLSConn{ - MockHandshakeContext: func(ctx context.Context) error { - return expected - }, - } - err := c.HandshakeContext(context.Background()) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } -} diff --git a/internal/netxlite/mocks/tlshandshaker.go b/internal/netxlite/mocks/tlshandshaker.go deleted file mode 100644 index ddca993..0000000 --- a/internal/netxlite/mocks/tlshandshaker.go +++ /dev/null @@ -1,19 +0,0 @@ -package mocks - -import ( - "context" - "crypto/tls" - "net" -) - -// TLSHandshaker is a mockable TLS handshaker. -type TLSHandshaker struct { - MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) ( - net.Conn, tls.ConnectionState, error) -} - -// Handshake calls MockHandshake. -func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( - net.Conn, tls.ConnectionState, error) { - return th.MockHandshake(ctx, conn, config) -} diff --git a/internal/netxlite/mocks/tlshandshaker_test.go b/internal/netxlite/mocks/tlshandshaker_test.go deleted file mode 100644 index 8bca25a..0000000 --- a/internal/netxlite/mocks/tlshandshaker_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package mocks - -import ( - "context" - "crypto/tls" - "errors" - "net" - "reflect" - "testing" -) - -func TestTLSHandshakerHandshake(t *testing.T) { - expected := errors.New("mocked error") - conn := &Conn{} - ctx := context.Background() - config := &tls.Config{} - th := &TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, - config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, expected - }, - } - tlsConn, connState, err := th.Handshake(ctx, conn, config) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if !reflect.ValueOf(connState).IsZero() { - t.Fatal("expected zero ConnectionState here") - } - if tlsConn != nil { - t.Fatal("expected nil conn here") - } -}