refactor(netxlite): let NewHTTPTransport work with single-use dialers (#467)

To make this happen, we need to take as argument a TLSDialer rather than
a TLSHandshaker. Then, we need to arrange the code so that we always
enforce a timeout for both TCP and TLS connections.

Because a TLSDialer can be constructed with a custom TLSConfig, we cover
also the case where the users wants to provide such a config.

While there, make sure we have better unit tests of the HTTP code.

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-06 19:27:59 +02:00 committed by GitHub
parent 3114d6ca0e
commit ba5bae4769
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 240 additions and 14 deletions

View File

@ -2,6 +2,7 @@ package netxlite
import ( import (
"context" "context"
"errors"
"net" "net"
"net/http" "net/http"
"time" "time"
@ -84,11 +85,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
} }
// NewHTTPTransport creates a new HTTP transport using the given // NewHTTPTransport creates a new HTTP transport using the given
// dialer and TLS handshaker to create connections. // dialer and TLS dialer to create connections.
//
// We need a TLS handshaker here, as opposed to a TLSDialer, because we
// wrap the dialer we'll use to enforce timeouts for HTTP idle
// connections (see https://github.com/ooni/probe/issues/1609 for more info).
// //
// The returned transport will use the given Logger for logging. // The returned transport will use the given Logger for logging.
// //
@ -101,7 +98,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
// The returned transport will disable transparent decompression // The returned transport will disable transparent decompression
// of compressed response bodies (and will not automatically // of compressed response bodies (and will not automatically
// ask for such compression, though you can always do that manually). // ask for such compression, though you can always do that manually).
func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport { func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
// Using oohttp to support any TLS library. // Using oohttp to support any TLS library.
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone() txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
@ -109,7 +106,7 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker)
// are using HTTP; see https://github.com/ooni/probe/issues/1609. // are using HTTP; see https://github.com/ooni/probe/issues/1609.
dialer = &httpDialerWithReadTimeout{dialer} dialer = &httpDialerWithReadTimeout{dialer}
txp.DialContext = dialer.DialContext txp.DialContext = dialer.DialContext
tlsDialer := NewTLSDialer(dialer, tlsHandshaker) tlsDialer = &httpTLSDialerWithReadTimeout{tlsDialer}
txp.DialTLSContext = tlsDialer.DialTLSContext txp.DialTLSContext = tlsDialer.DialTLSContext
// We are using a different strategy to implement proxy: we // We are using a different strategy to implement proxy: we
@ -160,15 +157,73 @@ func (d *httpDialerWithReadTimeout) DialContext(
return &httpConnWithReadTimeout{conn}, nil return &httpConnWithReadTimeout{conn}, nil
} }
// httpTLSDialerWithReadTimeout enforces a read timeout for all HTTP
// connections. See https://github.com/ooni/probe/issues/1609.
type httpTLSDialerWithReadTimeout struct {
TLSDialer
}
// ErrNotTLSConn indicates that a TLSDialer returns a net.Conn
// that does not implement the TLSConn interface. This error should
// only happen when we do something wrong setting up HTTP code.
var ErrNotTLSConn = errors.New("not a TLSConn")
// DialTLSContext implements TLSDialer's DialTLSContext.
func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.TLSDialer.DialTLSContext(ctx, network, address)
if err != nil {
return nil, err
}
tconn, okay := conn.(TLSConn)
if !okay {
conn.Close() // we own the conn here
return nil, ErrNotTLSConn
}
return &httpTLSConnWithReadTimeout{tconn}, nil
}
// httpConnWithReadTimeout enforces a read timeout for all HTTP // httpConnWithReadTimeout enforces a read timeout for all HTTP
// connections. See https://github.com/ooni/probe/issues/1609. // connections. See https://github.com/ooni/probe/issues/1609.
type httpConnWithReadTimeout struct { type httpConnWithReadTimeout struct {
net.Conn net.Conn
} }
// httpConnReadTimeout is the read timeout we apply to all HTTP
// conns (see https://github.com/ooni/probe/issues/1609).
//
// This timeout is meant as a fallback mechanism so that a stuck
// connection will _eventually_ fail. This is why it is set to
// a large value (300 seconds when writing this note).
//
// There should be other mechanisms to ensure that the code is
// lively: the context during the RoundTrip and iox.ReadAllContext
// when reading the body. They should kick in earlier. But we
// additionally want to avoid leaking a (parked?) connection and
// the corresponding goroutine, hence this large timeout.
//
// A future @bassosimone may understand this problem even better
// and possibly apply an even better fix to this issue. This
// will happen when we'll be able to further study the anomalies
// described in https://github.com/ooni/probe/issues/1609.
const httpConnReadTimeout = 300 * time.Second
// Read implements Conn.Read. // Read implements Conn.Read.
func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) { func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) {
c.Conn.SetReadDeadline(time.Now().Add(30 * time.Second)) c.Conn.SetReadDeadline(time.Now().Add(httpConnReadTimeout))
defer c.Conn.SetReadDeadline(time.Time{}) defer c.Conn.SetReadDeadline(time.Time{})
return c.Conn.Read(b) return c.Conn.Read(b)
} }
// httpTLSConnWithReadTimeout enforces a read timeout for all HTTP
// connections. See https://github.com/ooni/probe/issues/1609.
type httpTLSConnWithReadTimeout struct {
TLSConn
}
// Read implements Conn.Read.
func (c *httpTLSConnWithReadTimeout) Read(b []byte) (int, error) {
c.TLSConn.SetReadDeadline(time.Now().Add(httpConnReadTimeout))
defer c.TLSConn.SetReadDeadline(time.Time{})
return c.TLSConn.Read(b)
}

View File

@ -9,6 +9,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/apex/log" "github.com/apex/log"
oohttp "github.com/ooni/oohttp" oohttp "github.com/ooni/oohttp"
@ -111,7 +112,8 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
func TestHTTPTransportWorks(t *testing.T) { func TestHTTPTransportWorks(t *testing.T) {
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log)) d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log)) td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
txp := NewHTTPTransport(log.Log, d, td)
client := &http.Client{Transport: txp} client := &http.Client{Transport: txp}
defer client.CloseIdleConnections() defer client.CloseIdleConnections()
resp, err := client.Get("https://www.google.com/robots.txt") resp, err := client.Get("https://www.google.com/robots.txt")
@ -136,7 +138,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
}, },
Resolver: NewResolverSystem(log.Log), Resolver: NewResolverSystem(log.Log),
} }
txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log)) td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
txp := NewHTTPTransport(log.Log, d, td)
client := &http.Client{Transport: txp} client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/robots.txt") resp, err := client.Get("https://www.google.com/robots.txt")
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
@ -153,8 +156,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
func TestNewHTTPTransport(t *testing.T) { func TestNewHTTPTransport(t *testing.T) {
d := &mocks.Dialer{} d := &mocks.Dialer{}
th := &mocks.TLSHandshaker{} td := &mocks.TLSDialer{}
txp := NewHTTPTransport(log.Log, d, th) txp := NewHTTPTransport(log.Log, d, td)
logtxp, okay := txp.(*httpTransportLogger) logtxp, okay := txp.(*httpTransportLogger)
if !okay { if !okay {
t.Fatal("invalid type") t.Fatal("invalid type")
@ -173,8 +176,12 @@ func TestNewHTTPTransport(t *testing.T) {
if udt.Dialer != d { if udt.Dialer != d {
t.Fatal("invalid dialer") t.Fatal("invalid dialer")
} }
if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th { utdt, okay := txpcc.TLSDialer.(*httpTLSDialerWithReadTimeout)
t.Fatal("invalid tls handshaker") if !okay {
t.Fatal("invalid type")
}
if utdt.TLSDialer != td {
t.Fatal("invalid tls dialer")
} }
stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport) stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport)
if !okay { if !okay {
@ -196,3 +203,167 @@ func TestNewHTTPTransport(t *testing.T) {
t.Fatal("invalid DialContext") t.Fatal("invalid DialContext")
} }
} }
func TestHTTPDialerWithReadTimeout(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.Conn{
MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
}
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpConnWithReadTimeout).Conn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
}
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.TLSConn{
Conn: mocks.Conn{
MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
},
}
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialTLSContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
}
func TestHTTPDialerWithReadTimeoutDialingFailure(t *testing.T) {
expected := errors.New("mocked error")
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
func TestHTTPTLSDialerWithReadTimeoutDialingFailure(t *testing.T) {
expected := errors.New("mocked error")
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
func TestHTTPTLSDialerWithInvalidConnType(t *testing.T) {
var called bool
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
called = true
return nil
},
}, nil
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, ErrNotTLSConn) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
if !called {
t.Fatal("not called")
}
}