diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index d550535..bca6f4a 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -2,6 +2,7 @@ package netxlite import ( "context" + "errors" "net" "net/http" "time" @@ -84,11 +85,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { } // NewHTTPTransport creates a new HTTP transport using the given -// dialer and TLS handshaker to create connections. -// -// We need a TLS handshaker here, as opposed to a TLSDialer, because we -// wrap the dialer we'll use to enforce timeouts for HTTP idle -// connections (see https://github.com/ooni/probe/issues/1609 for more info). +// dialer and TLS dialer to create connections. // // The returned transport will use the given Logger for logging. // @@ -101,7 +98,7 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { // The returned transport will disable transparent decompression // of compressed response bodies (and will not automatically // ask for such compression, though you can always do that manually). -func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker) HTTPTransport { +func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport { // Using oohttp to support any TLS library. txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone() @@ -109,7 +106,7 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsHandshaker TLSHandshaker) // are using HTTP; see https://github.com/ooni/probe/issues/1609. dialer = &httpDialerWithReadTimeout{dialer} txp.DialContext = dialer.DialContext - tlsDialer := NewTLSDialer(dialer, tlsHandshaker) + tlsDialer = &httpTLSDialerWithReadTimeout{tlsDialer} txp.DialTLSContext = tlsDialer.DialTLSContext // We are using a different strategy to implement proxy: we @@ -160,15 +157,73 @@ func (d *httpDialerWithReadTimeout) DialContext( return &httpConnWithReadTimeout{conn}, nil } +// httpTLSDialerWithReadTimeout enforces a read timeout for all HTTP +// connections. See https://github.com/ooni/probe/issues/1609. +type httpTLSDialerWithReadTimeout struct { + TLSDialer +} + +// ErrNotTLSConn indicates that a TLSDialer returns a net.Conn +// that does not implement the TLSConn interface. This error should +// only happen when we do something wrong setting up HTTP code. +var ErrNotTLSConn = errors.New("not a TLSConn") + +// DialTLSContext implements TLSDialer's DialTLSContext. +func (d *httpTLSDialerWithReadTimeout) DialTLSContext( + ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.TLSDialer.DialTLSContext(ctx, network, address) + if err != nil { + return nil, err + } + tconn, okay := conn.(TLSConn) + if !okay { + conn.Close() // we own the conn here + return nil, ErrNotTLSConn + } + return &httpTLSConnWithReadTimeout{tconn}, nil +} + // httpConnWithReadTimeout enforces a read timeout for all HTTP // connections. See https://github.com/ooni/probe/issues/1609. type httpConnWithReadTimeout struct { net.Conn } +// httpConnReadTimeout is the read timeout we apply to all HTTP +// conns (see https://github.com/ooni/probe/issues/1609). +// +// This timeout is meant as a fallback mechanism so that a stuck +// connection will _eventually_ fail. This is why it is set to +// a large value (300 seconds when writing this note). +// +// There should be other mechanisms to ensure that the code is +// lively: the context during the RoundTrip and iox.ReadAllContext +// when reading the body. They should kick in earlier. But we +// additionally want to avoid leaking a (parked?) connection and +// the corresponding goroutine, hence this large timeout. +// +// A future @bassosimone may understand this problem even better +// and possibly apply an even better fix to this issue. This +// will happen when we'll be able to further study the anomalies +// described in https://github.com/ooni/probe/issues/1609. +const httpConnReadTimeout = 300 * time.Second + // Read implements Conn.Read. func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) { - c.Conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + c.Conn.SetReadDeadline(time.Now().Add(httpConnReadTimeout)) defer c.Conn.SetReadDeadline(time.Time{}) return c.Conn.Read(b) } + +// httpTLSConnWithReadTimeout enforces a read timeout for all HTTP +// connections. See https://github.com/ooni/probe/issues/1609. +type httpTLSConnWithReadTimeout struct { + TLSConn +} + +// Read implements Conn.Read. +func (c *httpTLSConnWithReadTimeout) Read(b []byte) (int, error) { + c.TLSConn.SetReadDeadline(time.Now().Add(httpConnReadTimeout)) + defer c.TLSConn.SetReadDeadline(time.Time{}) + return c.TLSConn.Read(b) +} diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 65ef599..2e851b4 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/apex/log" oohttp "github.com/ooni/oohttp" @@ -111,7 +112,8 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) { func TestHTTPTransportWorks(t *testing.T) { d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log)) - txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log)) + td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) + txp := NewHTTPTransport(log.Log, d, td) client := &http.Client{Transport: txp} defer client.CloseIdleConnections() resp, err := client.Get("https://www.google.com/robots.txt") @@ -136,7 +138,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) { }, Resolver: NewResolverSystem(log.Log), } - txp := NewHTTPTransport(log.Log, d, NewTLSHandshakerStdlib(log.Log)) + td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) + txp := NewHTTPTransport(log.Log, d, td) client := &http.Client{Transport: txp} resp, err := client.Get("https://www.google.com/robots.txt") if !errors.Is(err, expected) { @@ -153,8 +156,8 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) { func TestNewHTTPTransport(t *testing.T) { d := &mocks.Dialer{} - th := &mocks.TLSHandshaker{} - txp := NewHTTPTransport(log.Log, d, th) + td := &mocks.TLSDialer{} + txp := NewHTTPTransport(log.Log, d, td) logtxp, okay := txp.(*httpTransportLogger) if !okay { t.Fatal("invalid type") @@ -173,8 +176,12 @@ func TestNewHTTPTransport(t *testing.T) { if udt.Dialer != d { t.Fatal("invalid dialer") } - if txpcc.TLSDialer.(*tlsDialer).TLSHandshaker != th { - t.Fatal("invalid tls handshaker") + utdt, okay := txpcc.TLSDialer.(*httpTLSDialerWithReadTimeout) + if !okay { + t.Fatal("invalid type") + } + if utdt.TLSDialer != td { + t.Fatal("invalid tls dialer") } stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport) if !okay { @@ -196,3 +203,167 @@ func TestNewHTTPTransport(t *testing.T) { t.Fatal("invalid DialContext") } } + +func TestHTTPDialerWithReadTimeout(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.Conn{ + MockSetReadDeadline: func(t time.Time) error { + switch t.IsZero() { + case true: + calledWithZeroTime = true + case false: + calledWithNonZeroTime = true + } + return nil + }, + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + } + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpConnWithReadTimeout).Conn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } +} + +func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.TLSConn{ + Conn: mocks.Conn{ + MockSetReadDeadline: func(t time.Time) error { + switch t.IsZero() { + case true: + calledWithZeroTime = true + case false: + calledWithNonZeroTime = true + } + return nil + }, + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + }, + } + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialTLSContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } +} + +func TestHTTPDialerWithReadTimeoutDialingFailure(t *testing.T) { + expected := errors.New("mocked error") + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + }, + } + conn, err := d.DialContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } +} + +func TestHTTPTLSDialerWithReadTimeoutDialingFailure(t *testing.T) { + expected := errors.New("mocked error") + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + }, + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } +} + +func TestHTTPTLSDialerWithInvalidConnType(t *testing.T) { + var called bool + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + called = true + return nil + }, + }, nil + }, + }, + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, ErrNotTLSConn) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + if !called { + t.Fatal("not called") + } +}