refactor(netxlite): let NewHTTPTransport work with single-use dialers (#467)
To make this happen, we need to take as argument a TLSDialer rather than a TLSHandshaker. Then, we need to arrange the code so that we always enforce a timeout for both TCP and TLS connections. Because a TLSDialer can be constructed with a custom TLSConfig, we cover also the case where the users wants to provide such a config. While there, make sure we have better unit tests of the HTTP code. See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
3114d6ca0e
commit
ba5bae4769
|
@ -2,6 +2,7 @@ package netxlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -84,11 +85,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPTransport creates a new HTTP transport using the given
|
// NewHTTPTransport creates a new HTTP transport using the given
|
||||||
// dialer and TLS handshaker to create connections.
|
// dialer and TLS dialer to create connections.
|
||||||
//
|
|
||||||
// We need a TLS handshaker here, as opposed to a TLSDialer, because we
|
|
||||||
// wrap the dialer we'll use to enforce timeouts for HTTP idle
|
|
||||||
// connections (see https://github.com/ooni/probe/issues/1609 for more info).
|
|
||||||
//
|
//
|
||||||
// The returned transport will use the given Logger for logging.
|
// The returned transport will use the given Logger for logging.
|
||||||
//
|
//
|
||||||
|
@ -101,7 +98,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
|
||||||
// The returned transport will disable transparent decompression
|
// The returned transport will disable transparent decompression
|
||||||
// of compressed response bodies (and will not automatically
|
// of compressed response bodies (and will not automatically
|
||||||
// ask for such compression, though you can always do that manually).
|
// ask for such compression, though you can always do that manually).
|
||||||
func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport {
|
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
|
||||||
// Using oohttp to support any TLS library.
|
// Using oohttp to support any TLS library.
|
||||||
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
|
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
|
||||||
|
|
||||||
|
@ -109,7 +106,7 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker)
|
||||||
// are using HTTP; see https://github.com/ooni/probe/issues/1609.
|
// are using HTTP; see https://github.com/ooni/probe/issues/1609.
|
||||||
dialer = &httpDialerWithReadTimeout{dialer}
|
dialer = &httpDialerWithReadTimeout{dialer}
|
||||||
txp.DialContext = dialer.DialContext
|
txp.DialContext = dialer.DialContext
|
||||||
tlsDialer := NewTLSDialer(dialer, tlsHandshaker)
|
tlsDialer = &httpTLSDialerWithReadTimeout{tlsDialer}
|
||||||
txp.DialTLSContext = tlsDialer.DialTLSContext
|
txp.DialTLSContext = tlsDialer.DialTLSContext
|
||||||
|
|
||||||
// We are using a different strategy to implement proxy: we
|
// We are using a different strategy to implement proxy: we
|
||||||
|
@ -160,15 +157,73 @@ func (d *httpDialerWithReadTimeout) DialContext(
|
||||||
return &httpConnWithReadTimeout{conn}, nil
|
return &httpConnWithReadTimeout{conn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// httpTLSDialerWithReadTimeout enforces a read timeout for all HTTP
|
||||||
|
// connections. See https://github.com/ooni/probe/issues/1609.
|
||||||
|
type httpTLSDialerWithReadTimeout struct {
|
||||||
|
TLSDialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrNotTLSConn indicates that a TLSDialer returns a net.Conn
|
||||||
|
// that does not implement the TLSConn interface. This error should
|
||||||
|
// only happen when we do something wrong setting up HTTP code.
|
||||||
|
var ErrNotTLSConn = errors.New("not a TLSConn")
|
||||||
|
|
||||||
|
// DialTLSContext implements TLSDialer's DialTLSContext.
|
||||||
|
func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
|
||||||
|
ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
conn, err := d.TLSDialer.DialTLSContext(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tconn, okay := conn.(TLSConn)
|
||||||
|
if !okay {
|
||||||
|
conn.Close() // we own the conn here
|
||||||
|
return nil, ErrNotTLSConn
|
||||||
|
}
|
||||||
|
return &httpTLSConnWithReadTimeout{tconn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// httpConnWithReadTimeout enforces a read timeout for all HTTP
|
// httpConnWithReadTimeout enforces a read timeout for all HTTP
|
||||||
// connections. See https://github.com/ooni/probe/issues/1609.
|
// connections. See https://github.com/ooni/probe/issues/1609.
|
||||||
type httpConnWithReadTimeout struct {
|
type httpConnWithReadTimeout struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// httpConnReadTimeout is the read timeout we apply to all HTTP
|
||||||
|
// conns (see https://github.com/ooni/probe/issues/1609).
|
||||||
|
//
|
||||||
|
// This timeout is meant as a fallback mechanism so that a stuck
|
||||||
|
// connection will _eventually_ fail. This is why it is set to
|
||||||
|
// a large value (300 seconds when writing this note).
|
||||||
|
//
|
||||||
|
// There should be other mechanisms to ensure that the code is
|
||||||
|
// lively: the context during the RoundTrip and iox.ReadAllContext
|
||||||
|
// when reading the body. They should kick in earlier. But we
|
||||||
|
// additionally want to avoid leaking a (parked?) connection and
|
||||||
|
// the corresponding goroutine, hence this large timeout.
|
||||||
|
//
|
||||||
|
// A future @bassosimone may understand this problem even better
|
||||||
|
// and possibly apply an even better fix to this issue. This
|
||||||
|
// will happen when we'll be able to further study the anomalies
|
||||||
|
// described in https://github.com/ooni/probe/issues/1609.
|
||||||
|
const httpConnReadTimeout = 300 * time.Second
|
||||||
|
|
||||||
// Read implements Conn.Read.
|
// Read implements Conn.Read.
|
||||||
func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) {
|
func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) {
|
||||||
c.Conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
c.Conn.SetReadDeadline(time.Now().Add(httpConnReadTimeout))
|
||||||
defer c.Conn.SetReadDeadline(time.Time{})
|
defer c.Conn.SetReadDeadline(time.Time{})
|
||||||
return c.Conn.Read(b)
|
return c.Conn.Read(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// httpTLSConnWithReadTimeout enforces a read timeout for all HTTP
|
||||||
|
// connections. See https://github.com/ooni/probe/issues/1609.
|
||||||
|
type httpTLSConnWithReadTimeout struct {
|
||||||
|
TLSConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements Conn.Read.
|
||||||
|
func (c *httpTLSConnWithReadTimeout) Read(b []byte) (int, error) {
|
||||||
|
c.TLSConn.SetReadDeadline(time.Now().Add(httpConnReadTimeout))
|
||||||
|
defer c.TLSConn.SetReadDeadline(time.Time{})
|
||||||
|
return c.TLSConn.Read(b)
|
||||||
|
}
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
oohttp "github.com/ooni/oohttp"
|
oohttp "github.com/ooni/oohttp"
|
||||||
|
@ -111,7 +112,8 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
|
||||||
|
|
||||||
func TestHTTPTransportWorks(t *testing.T) {
|
func TestHTTPTransportWorks(t *testing.T) {
|
||||||
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
|
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log))
|
||||||
txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log))
|
td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
|
||||||
|
txp := NewHTTPTransport(log.Log, d, td)
|
||||||
client := &http.Client{Transport: txp}
|
client := &http.Client{Transport: txp}
|
||||||
defer client.CloseIdleConnections()
|
defer client.CloseIdleConnections()
|
||||||
resp, err := client.Get("https://www.google.com/robots.txt")
|
resp, err := client.Get("https://www.google.com/robots.txt")
|
||||||
|
@ -136,7 +138,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
|
||||||
},
|
},
|
||||||
Resolver: NewResolverSystem(log.Log),
|
Resolver: NewResolverSystem(log.Log),
|
||||||
}
|
}
|
||||||
txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log))
|
td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
|
||||||
|
txp := NewHTTPTransport(log.Log, d, td)
|
||||||
client := &http.Client{Transport: txp}
|
client := &http.Client{Transport: txp}
|
||||||
resp, err := client.Get("https://www.google.com/robots.txt")
|
resp, err := client.Get("https://www.google.com/robots.txt")
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
|
@ -153,8 +156,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
|
||||||
|
|
||||||
func TestNewHTTPTransport(t *testing.T) {
|
func TestNewHTTPTransport(t *testing.T) {
|
||||||
d := &mocks.Dialer{}
|
d := &mocks.Dialer{}
|
||||||
th := &mocks.TLSHandshaker{}
|
td := &mocks.TLSDialer{}
|
||||||
txp := NewHTTPTransport(log.Log, d, th)
|
txp := NewHTTPTransport(log.Log, d, td)
|
||||||
logtxp, okay := txp.(*httpTransportLogger)
|
logtxp, okay := txp.(*httpTransportLogger)
|
||||||
if !okay {
|
if !okay {
|
||||||
t.Fatal("invalid type")
|
t.Fatal("invalid type")
|
||||||
|
@ -173,8 +176,12 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||||
if udt.Dialer != d {
|
if udt.Dialer != d {
|
||||||
t.Fatal("invalid dialer")
|
t.Fatal("invalid dialer")
|
||||||
}
|
}
|
||||||
if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th {
|
utdt, okay := txpcc.TLSDialer.(*httpTLSDialerWithReadTimeout)
|
||||||
t.Fatal("invalid tls handshaker")
|
if !okay {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
if utdt.TLSDialer != td {
|
||||||
|
t.Fatal("invalid tls dialer")
|
||||||
}
|
}
|
||||||
stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport)
|
stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport)
|
||||||
if !okay {
|
if !okay {
|
||||||
|
@ -196,3 +203,167 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||||
t.Fatal("invalid DialContext")
|
t.Fatal("invalid DialContext")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTPDialerWithReadTimeout(t *testing.T) {
|
||||||
|
var (
|
||||||
|
calledWithZeroTime bool
|
||||||
|
calledWithNonZeroTime bool
|
||||||
|
)
|
||||||
|
origConn := &mocks.Conn{
|
||||||
|
MockSetReadDeadline: func(t time.Time) error {
|
||||||
|
switch t.IsZero() {
|
||||||
|
case true:
|
||||||
|
calledWithZeroTime = true
|
||||||
|
case false:
|
||||||
|
calledWithNonZeroTime = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
},
|
||||||
|
}
|
||||||
|
d := &httpDialerWithReadTimeout{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return origConn, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := d.DialContext(ctx, "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
|
||||||
|
t.Fatal("invalid conn type")
|
||||||
|
}
|
||||||
|
if conn.(*httpConnWithReadTimeout).Conn != origConn {
|
||||||
|
t.Fatal("invalid origin conn")
|
||||||
|
}
|
||||||
|
b := make([]byte, 1024)
|
||||||
|
count, err := conn.Read(b)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("invalid error")
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||||
|
var (
|
||||||
|
calledWithZeroTime bool
|
||||||
|
calledWithNonZeroTime bool
|
||||||
|
)
|
||||||
|
origConn := &mocks.TLSConn{
|
||||||
|
Conn: mocks.Conn{
|
||||||
|
MockSetReadDeadline: func(t time.Time) error {
|
||||||
|
switch t.IsZero() {
|
||||||
|
case true:
|
||||||
|
calledWithZeroTime = true
|
||||||
|
case false:
|
||||||
|
calledWithNonZeroTime = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
d := &httpTLSDialerWithReadTimeout{
|
||||||
|
TLSDialer: &mocks.TLSDialer{
|
||||||
|
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return origConn, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := d.DialTLSContext(ctx, "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
|
||||||
|
t.Fatal("invalid conn type")
|
||||||
|
}
|
||||||
|
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
|
||||||
|
t.Fatal("invalid origin conn")
|
||||||
|
}
|
||||||
|
b := make([]byte, 1024)
|
||||||
|
count, err := conn.Read(b)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("invalid error")
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPDialerWithReadTimeoutDialingFailure(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
d := &httpDialerWithReadTimeout{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := d.DialContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTLSDialerWithReadTimeoutDialingFailure(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
d := &httpTLSDialerWithReadTimeout{
|
||||||
|
TLSDialer: &mocks.TLSDialer{
|
||||||
|
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTPTLSDialerWithInvalidConnType(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
d := &httpTLSDialerWithReadTimeout{
|
||||||
|
TLSDialer: &mocks.TLSDialer{
|
||||||
|
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return &mocks.Conn{
|
||||||
|
MockClose: func() error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, ErrNotTLSConn) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user