refactor(netxlite): let NewHTTPTransport work with single-use dialers (#467)
To make this happen, we need to take as argument a TLSDialer rather than a TLSHandshaker. Then, we need to arrange the code so that we always enforce a timeout for both TCP and TLS connections. Because a TLSDialer can be constructed with a custom TLSConfig, we cover also the case where the users wants to provide such a config. While there, make sure we have better unit tests of the HTTP code. See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
3114d6ca0e
commit
ba5bae4769
|
@ -2,6 +2,7 @@ package netxlite
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -84,11 +85,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
|
|||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTP transport using the given
|
||||
// dialer and TLS handshaker to create connections.
|
||||
//
|
||||
// We need a TLS handshaker here, as opposed to a TLSDialer, because we
|
||||
// wrap the dialer we'll use to enforce timeouts for HTTP idle
|
||||
// connections (see https://github.com/ooni/probe/issues/1609 for more info).
|
||||
// dialer and TLS dialer to create connections.
|
||||
//
|
||||
// The returned transport will use the given Logger for logging.
|
||||
//
|
||||
|
@ -101,7 +98,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
|
|||
// The returned transport will disable transparent decompression
|
||||
// of compressed response bodies (and will not automatically
|
||||
// ask for such compression, though you can always do that manually).
|
||||
func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport {
|
||||
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
|
||||
// Using oohttp to support any TLS library.
|
||||
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
|
||||
|
||||
|
@ -109,7 +106,7 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker)
|
|||
// are using HTTP; see https://github.com/ooni/probe/issues/1609.
|
||||
dialer = &httpDialerWithReadTimeout{dialer}
|
||||
txp.DialContext = dialer.DialContext
|
||||
tlsDialer := NewTLSDialer(dialer, tlsHandshaker)
|
||||
tlsDialer = &httpTLSDialerWithReadTimeout{tlsDialer}
|
||||
txp.DialTLSContext = tlsDialer.DialTLSContext
|
||||
|
||||
// We are using a different strategy to implement proxy: we
|
||||
|
@ -160,15 +157,73 @@ func (d *httpDialerWithReadTimeout) DialContext(
|
|||
return &httpConnWithReadTimeout{conn}, nil
|
||||
}
|
||||
|
||||
// httpTLSDialerWithReadTimeout enforces a read timeout for all HTTP
|
||||
// connections. See https://github.com/ooni/probe/issues/1609.
|
||||
type httpTLSDialerWithReadTimeout struct {
|
||||
TLSDialer
|
||||
}
|
||||
|
||||
// ErrNotTLSConn indicates that a TLSDialer returns a net.Conn
|
||||
// that does not implement the TLSConn interface. This error should
|
||||
// only happen when we do something wrong setting up HTTP code.
|
||||
var ErrNotTLSConn = errors.New("not a TLSConn")
|
||||
|
||||
// DialTLSContext implements TLSDialer's DialTLSContext.
|
||||
func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.TLSDialer.DialTLSContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tconn, okay := conn.(TLSConn)
|
||||
if !okay {
|
||||
conn.Close() // we own the conn here
|
||||
return nil, ErrNotTLSConn
|
||||
}
|
||||
return &httpTLSConnWithReadTimeout{tconn}, nil
|
||||
}
|
||||
|
||||
// httpConnWithReadTimeout enforces a read timeout for all HTTP
|
||||
// connections. See https://github.com/ooni/probe/issues/1609.
|
||||
type httpConnWithReadTimeout struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
// httpConnReadTimeout is the read timeout we apply to all HTTP
|
||||
// conns (see https://github.com/ooni/probe/issues/1609).
|
||||
//
|
||||
// This timeout is meant as a fallback mechanism so that a stuck
|
||||
// connection will _eventually_ fail. This is why it is set to
|
||||
// a large value (300 seconds when writing this note).
|
||||
//
|
||||
// There should be other mechanisms to ensure that the code is
|
||||
// lively: the context during the RoundTrip and iox.ReadAllContext
|
||||
// when reading the body. They should kick in earlier. But we
|
||||
// additionally want to avoid leaking a (parked?) connection and
|
||||
// the corresponding goroutine, hence this large timeout.
|
||||
//
|
||||
// A future @bassosimone may understand this problem even better
|
||||
// and possibly apply an even better fix to this issue. This
|
||||
// will happen when we'll be able to further study the anomalies
|
||||
// described in https://github.com/ooni/probe/issues/1609.
|
||||
const httpConnReadTimeout = 300 * time.Second
|
||||
|
||||
// Read implements Conn.Read.
|
||||
func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
c.Conn.SetReadDeadline(time.Now().Add(httpConnReadTimeout))
|
||||
defer c.Conn.SetReadDeadline(time.Time{})
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
// httpTLSConnWithReadTimeout enforces a read timeout for all HTTP
|
||||
// connections. See https://github.com/ooni/probe/issues/1609.
|
||||
type httpTLSConnWithReadTimeout struct {
|
||||
TLSConn
|
||||
}
|
||||
|
||||
// Read implements Conn.Read.
|
||||
func (c *httpTLSConnWithReadTimeout) Read(b []byte) (int, error) {
|
||||
c.TLSConn.SetReadDeadline(time.Now().Add(httpConnReadTimeout))
|
||||
defer c.TLSConn.SetReadDeadline(time.Time{})
|
||||
return c.TLSConn.Read(b)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apex/log"
|
||||
oohttp "github.com/ooni/oohttp"
|
||||
|
@ -111,7 +112,8 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
|
|||
|
||||
func TestHTTPTransportWorks(t *testing.T) {
|
||||
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
|
||||
txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log))
|
||||
td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
|
||||
txp := NewHTTPTransport(log.Log, d, td)
|
||||
client := &http.Client{Transport: txp}
|
||||
defer client.CloseIdleConnections()
|
||||
resp, err := client.Get("https://www.google.com/robots.txt")
|
||||
|
@ -136,7 +138,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
|
|||
},
|
||||
Resolver: NewResolverSystem(log.Log),
|
||||
}
|
||||
txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log))
|
||||
td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
|
||||
txp := NewHTTPTransport(log.Log, d, td)
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com/robots.txt")
|
||||
if !errors.Is(err, expected) {
|
||||
|
@ -153,8 +156,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
|
|||
|
||||
func TestNewHTTPTransport(t *testing.T) {
|
||||
d := &mocks.Dialer{}
|
||||
th := &mocks.TLSHandshaker{}
|
||||
txp := NewHTTPTransport(log.Log, d, th)
|
||||
td := &mocks.TLSDialer{}
|
||||
txp := NewHTTPTransport(log.Log, d, td)
|
||||
logtxp, okay := txp.(*httpTransportLogger)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
|
@ -173,8 +176,12 @@ func TestNewHTTPTransport(t *testing.T) {
|
|||
if udt.Dialer != d {
|
||||
t.Fatal("invalid dialer")
|
||||
}
|
||||
if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th {
|
||||
t.Fatal("invalid tls handshaker")
|
||||
utdt, okay := txpcc.TLSDialer.(*httpTLSDialerWithReadTimeout)
|
||||
if !okay {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
if utdt.TLSDialer != td {
|
||||
t.Fatal("invalid tls dialer")
|
||||
}
|
||||
stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport)
|
||||
if !okay {
|
||||
|
@ -196,3 +203,167 @@ func TestNewHTTPTransport(t *testing.T) {
|
|||
t.Fatal("invalid DialContext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPDialerWithReadTimeout(t *testing.T) {
|
||||
var (
|
||||
calledWithZeroTime bool
|
||||
calledWithNonZeroTime bool
|
||||
)
|
||||
origConn := &mocks.Conn{
|
||||
MockSetReadDeadline: func(t time.Time) error {
|
||||
switch t.IsZero() {
|
||||
case true:
|
||||
calledWithZeroTime = true
|
||||
case false:
|
||||
calledWithNonZeroTime = true
|
||||
}
|
||||
return nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
}
|
||||
d := &httpDialerWithReadTimeout{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return origConn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
|
||||
t.Fatal("invalid conn type")
|
||||
}
|
||||
if conn.(*httpConnWithReadTimeout).Conn != origConn {
|
||||
t.Fatal("invalid origin conn")
|
||||
}
|
||||
b := make([]byte, 1024)
|
||||
count, err := conn.Read(b)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("invalid error")
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||
var (
|
||||
calledWithZeroTime bool
|
||||
calledWithNonZeroTime bool
|
||||
)
|
||||
origConn := &mocks.TLSConn{
|
||||
Conn: mocks.Conn{
|
||||
MockSetReadDeadline: func(t time.Time) error {
|
||||
switch t.IsZero() {
|
||||
case true:
|
||||
calledWithZeroTime = true
|
||||
case false:
|
||||
calledWithNonZeroTime = true
|
||||
}
|
||||
return nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
},
|
||||
}
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return origConn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialTLSContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
|
||||
t.Fatal("invalid conn type")
|
||||
}
|
||||
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
|
||||
t.Fatal("invalid origin conn")
|
||||
}
|
||||
b := make([]byte, 1024)
|
||||
count, err := conn.Read(b)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("invalid error")
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPDialerWithReadTimeoutDialingFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := &httpDialerWithReadTimeout{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "", "")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTLSDialerWithReadTimeoutDialingFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTLSDialerWithInvalidConnType(t *testing.T) {
|
||||
var called bool
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||
if !errors.Is(err, ErrNotTLSConn) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user