fix(netxlite): http factory that propagates close-idle-connections (#465)

While there reorganize mocks' tls implementation to use a single file
called tls.go (and tls_test.go) just like netxlite does.

While there write tests ensuring we always add timeouts when we are
making TCP connections (be them TLS or cleartext).

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-06 16:53:28 +02:00 committed by GitHub
parent 2572376fdb
commit 6df27d919d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 235 additions and 136 deletions

View File

@ -2,7 +2,6 @@ package netxlite
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
@ -67,17 +66,37 @@ func (txp *httpTransportLogger) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
}
// NewHTTPTransport creates a new HTTP transport using Go stdlib.
func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
handshaker TLSHandshaker) HTTPTransport {
// httpTransportConnectionsCloser is an HTTPTransport that
// correctly forwards CloseIdleConnections.
type httpTransportConnectionsCloser struct {
HTTPTransport
Dialer
TLSDialer
}
// CloseIdleConnections forwards the CloseIdleConnections calls.
func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
txp.Dialer.CloseIdleConnections()
txp.TLSDialer.CloseIdleConnections()
}
// NewHTTPTransport creates a new HTTP transport using the given
// dialer and TLS handshaker to create connections.
//
// We need a TLS handshaker here, as opposed to a TLSDialer, because we
// wrap the dialer we'll use to enforce timeouts for HTTP idle
// connections (see https://github.com/ooni/probe/issues/1609 for more info).
func NewHTTPTransport(dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport {
// TODO(bassosimone): here we should copy code living inside the
// websteps prototype to use the oohttp library.
txp := http.DefaultTransport.(*http.Transport).Clone()
// This wrapping ensures that we always have a timeout when we
// are using HTTP; see https://github.com/ooni/probe/issues/1609.
dialer = &httpDialerWithReadTimeout{dialer}
txp.DialContext = dialer.DialContext
txp.DialTLSContext = (&tlsDialer{
Config: tlsConfig,
Dialer: dialer,
TLSHandshaker: handshaker,
}).DialTLSContext
tlsDialer := NewTLSDialer(dialer, tlsHandshaker)
txp.DialTLSContext = tlsDialer.DialTLSContext
// Better for Cloudflare DNS and also better because we have less
// noisy events and we can better understand what happened.
txp.MaxConnsPerHost = 1
@ -86,7 +105,13 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
// back the true headers, such as Content-Length. This change is
// functional to OONI's goal of observing the network.
txp.DisableCompression = true
return txp
txp.ForceAttemptHTTP2 = true
// Ensure we correctly forward CloseIdleConnections.
return &httpTransportConnectionsCloser{
HTTPTransport: txp,
Dialer: dialer,
TLSDialer: tlsDialer,
}
}
// httpDialerWithReadTimeout enforces a read timeout for all HTTP

View File

@ -2,7 +2,6 @@ package netxlite
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
@ -110,22 +109,19 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
}
func TestHTTPTransportWorks(t *testing.T) {
d := &dialerResolver{
Dialer: defaultDialer,
Resolver: NewResolverSystem(log.Log),
}
th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th)
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log))
client := &http.Client{Transport: txp}
defer client.CloseIdleConnections()
resp, err := client.Get("https://www.google.com/robots.txt")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
txp.CloseIdleConnections()
}
func TestHTTPTransportWithFailingDialer(t *testing.T) {
called := &atomicx.Int64{}
expected := errors.New("mocked error")
d := &dialerResolver{
Dialer: &mocks.Dialer{
@ -133,11 +129,13 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
network, address string) (net.Conn, error) {
return nil, expected
},
MockCloseIdleConnections: func() {
called.Add(1)
},
},
Resolver: NewResolverSystem(log.Log),
}
th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th)
txp := NewHTTPTransport(d, NewTLSHandshakerStdlib(log.Log))
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/robots.txt")
if !errors.Is(err, expected) {
@ -146,5 +144,47 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
if resp != nil {
t.Fatal("expected non-nil response here")
}
txp.CloseIdleConnections()
client.CloseIdleConnections()
if called.Load() < 1 {
t.Fatal("did not propagate CloseIdleConnections")
}
}
func TestNewHTTPTransport(t *testing.T) {
d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{}
txp := NewHTTPTransport(d, th)
txpcc, okay := txp.(*httpTransportConnectionsCloser)
if !okay {
t.Fatal("invalid type")
}
udt, okay := txpcc.Dialer.(*httpDialerWithReadTimeout)
if !okay {
t.Fatal("invalid type")
}
if udt.Dialer != d {
t.Fatal("invalid dialer")
}
if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th {
t.Fatal("invalid tls handshaker")
}
htxp, okay := txpcc.HTTPTransport.(*http.Transport)
if !okay {
t.Fatal("invalid type")
}
if !htxp.ForceAttemptHTTP2 {
t.Fatal("invalid ForceAttemptHTTP2")
}
if !htxp.DisableCompression {
t.Fatal("invalid DisableCompression")
}
if htxp.MaxConnsPerHost != 1 {
t.Fatal("invalid MaxConnPerHost")
}
if htxp.DialTLSContext == nil {
t.Fatal("invalid DialTLSContext")
}
if htxp.DialContext == nil {
t.Fatal("invalid DialContext")
}
}

View File

@ -0,0 +1,60 @@
package mocks
import (
"context"
"crypto/tls"
"net"
)
// TLSHandshaker is a mockable TLS handshaker.
type TLSHandshaker struct {
MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}
// Handshake calls MockHandshake.
func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error) {
return th.MockHandshake(ctx, conn, config)
}
// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
// Conn is the embedded mockable Conn.
Conn
// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState
// MockHandshakeContext allows to mock the HandshakeContext method.
MockHandshakeContext func(ctx context.Context) error
}
// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}
// HandshakeContext calls MockHandshakeContext.
func (c *TLSConn) HandshakeContext(ctx context.Context) error {
return c.MockHandshakeContext(ctx)
}
// TLSDialer allows to mock netxlite.TLSDialer.
type TLSDialer struct {
// MockCloseIdleConnections allows to mock the CloseIdleConnections method.
MockCloseIdleConnections func()
// MockDialTLSContext allows to mock the DialTLSContext method.
MockDialTLSContext func(ctx context.Context, network, address string) (net.Conn, error)
}
// CloseIdleConnections calls MockCloseIdleConnections.
func (d *TLSDialer) CloseIdleConnections() {
d.MockCloseIdleConnections()
}
// DialTLSContext calls MockDialTLSContext.
func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.MockDialTLSContext(ctx, network, address)
}

View File

@ -0,0 +1,89 @@
package mocks
import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
)
func TestTLSHandshakerHandshake(t *testing.T) {
expected := errors.New("mocked error")
conn := &Conn{}
ctx := context.Background()
config := &tls.Config{}
th := &TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expected
},
}
tlsConn, connState, err := th.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here")
}
if tlsConn != nil {
t.Fatal("expected nil conn here")
}
}
func TestTLSConnConnectionState(t *testing.T) {
state := tls.ConnectionState{Version: tls.VersionTLS12}
c := &TLSConn{
MockConnectionState: func() tls.ConnectionState {
return state
},
}
out := c.ConnectionState()
if !reflect.DeepEqual(out, state) {
t.Fatal("not the result we expected")
}
}
func TestTLSConnHandshakeContext(t *testing.T) {
expected := errors.New("mocked error")
c := &TLSConn{
MockHandshakeContext: func(ctx context.Context) error {
return expected
},
}
err := c.HandshakeContext(context.Background())
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
}
func TestTLSDialerCloseIdleConnections(t *testing.T) {
var called bool
td := &TLSDialer{
MockCloseIdleConnections: func() {
called = true
},
}
td.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}
func TestTLSDialerDialTLSContext(t *testing.T) {
expected := errors.New("mocked error")
td := &TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
}
ctx := context.Background()
conn, err := td.DialTLSContext(ctx, "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}

View File

@ -1,28 +0,0 @@
package mocks
import (
"context"
"crypto/tls"
)
// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
// Conn is the embedded mockable Conn.
Conn
// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState
// MockHandshakeContext allows to mock the HandshakeContext method.
MockHandshakeContext func(ctx context.Context) error
}
// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}
// HandshakeContext calls MockHandshakeContext.
func (c *TLSConn) HandshakeContext(ctx context.Context) error {
return c.MockHandshakeContext(ctx)
}

View File

@ -1,35 +0,0 @@
package mocks
import (
"context"
"crypto/tls"
"errors"
"reflect"
"testing"
)
func TestTLSConnConnectionState(t *testing.T) {
state := tls.ConnectionState{Version: tls.VersionTLS12}
c := &TLSConn{
MockConnectionState: func() tls.ConnectionState {
return state
},
}
out := c.ConnectionState()
if !reflect.DeepEqual(out, state) {
t.Fatal("not the result we expected")
}
}
func TestTLSConnHandshakeContext(t *testing.T) {
expected := errors.New("mocked error")
c := &TLSConn{
MockHandshakeContext: func(ctx context.Context) error {
return expected
},
}
err := c.HandshakeContext(context.Background())
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
}

View File

@ -1,19 +0,0 @@
package mocks
import (
"context"
"crypto/tls"
"net"
)
// TLSHandshaker is a mockable TLS handshaker.
type TLSHandshaker struct {
MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}
// Handshake calls MockHandshake.
func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error) {
return th.MockHandshake(ctx, conn, config)
}

View File

@ -1,33 +0,0 @@
package mocks
import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
)
func TestTLSHandshakerHandshake(t *testing.T) {
expected := errors.New("mocked error")
conn := &Conn{}
ctx := context.Background()
config := &tls.Config{}
th := &TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expected
},
}
tlsConn, connState, err := th.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero ConnectionState here")
}
if tlsConn != nil {
t.Fatal("expected nil conn here")
}
}