diff --git a/internal/netxlite/http3.go b/internal/netxlite/http3.go index d502146..32e380e 100644 --- a/internal/netxlite/http3.go +++ b/internal/netxlite/http3.go @@ -3,6 +3,7 @@ package netxlite import ( "context" "crypto/tls" + "io" "net/http" "github.com/lucas-clemente/quic-go" @@ -13,19 +14,25 @@ import ( // an http3.RoundTripper. This is necessary because the // http3.RoundTripper does not support DialContext. type http3Dialer struct { - Dialer QUICDialer + QUICDialer } // dial is like QUICContextDialer.DialContext but without context. func (d *http3Dialer) dial(network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { - return d.Dialer.DialContext( + return d.QUICDialer.DialContext( context.Background(), network, address, tlsConfig, quicConfig) } +// http3RoundTripper is the abstract type of quic-go/http3.RoundTripper. +type http3RoundTripper interface { + http.RoundTripper + io.Closer +} + // http3Transport is an HTTPTransport using the http3 protocol. type http3Transport struct { - child *http3.RoundTripper + child http3RoundTripper dialer QUICDialer } diff --git a/internal/netxlite/http3_test.go b/internal/netxlite/http3_test.go index 234e4f3..ee3544c 100644 --- a/internal/netxlite/http3_test.go +++ b/internal/netxlite/http3_test.go @@ -1,42 +1,99 @@ package netxlite import ( + "context" "crypto/tls" + "errors" "net/http" "testing" - "github.com/apex/log" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/http3" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestHTTP3TransportWorks(t *testing.T) { - d := &quicDialerResolver{ - Dialer: &quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - }, - Resolver: NewResolverSystem(log.Log), - } - txp := NewHTTP3Transport(d, &tls.Config{}) - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com/robots.txt") - if err != nil { - t.Fatal(err) - } - resp.Body.Close() - txp.CloseIdleConnections() +func TestHTTP3Dialer(t *testing.T) { + t.Run("Dial", func(t *testing.T) { + expected := errors.New("mocked error") + d := &http3Dialer{ + QUICDialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + return nil, expected + }, + }, + } + sess, err := d.dial("", "", &tls.Config{}, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if sess != nil { + t.Fatal("unexpected resp") + } + }) } func TestHTTP3TransportClosesIdleConnections(t *testing.T) { - var called bool - d := &mocks.QUICDialer{ - MockCloseIdleConnections: func() { - called = true - }, - } - txp := NewHTTP3Transport(d, &tls.Config{}) - client := &http.Client{Transport: txp} - client.CloseIdleConnections() - if !called { - t.Fatal("not called") - } + t.Run("CloseIdleConnections", func(t *testing.T) { + var ( + calledHTTP3 bool + calledDialer bool + ) + txp := &http3Transport{ + child: &mocks.HTTP3RoundTripper{ + MockClose: func() error { + calledHTTP3 = true + return nil + }, + }, + dialer: &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + calledDialer = true + }, + }, + } + txp.CloseIdleConnections() + if !calledHTTP3 || !calledDialer { + t.Fatal("not called") + } + }) + + t.Run("RoundTrip", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &http3Transport{ + child: &mocks.HTTP3RoundTripper{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, + }, + } + resp, err := txp.RoundTrip(&http.Request{}) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + }) +} + +func TestNewHTTP3Transport(t *testing.T) { + t.Run("creates the correct type chain", func(t *testing.T) { + qd := &mocks.QUICDialer{} + config := &tls.Config{} + txp := NewHTTP3Transport(qd, config) + h3txp := txp.(*http3Transport) + if h3txp.dialer != qd { + t.Fatal("invalid dialer") + } + h3 := h3txp.child.(*http3.RoundTripper) + if h3.Dial == nil { + t.Fatal("invalid Dial") + } + if !h3.DisableCompression { + t.Fatal("invalid DisableCompression") + } + if h3.TLSClientConfig != config { + t.Fatal("invalid TLSClientConfig") + } + }) } diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 2e851b4..131817b 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -18,249 +18,227 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestHTTPTransportLoggerFailure(t *testing.T) { - txp := &httpTransportLogger{ - Logger: log.Log, - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return nil, io.EOF - }, - }, - } - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("expected nil response here") - } -} - -func TestHTTPTransportLoggerFailureWithNoHostHeader(t *testing.T) { - foundHost := &atomicx.Int64{} - txp := &httpTransportLogger{ - Logger: log.Log, - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - if req.Header.Get("Host") == "www.google.com" { - foundHost.Add(1) - } - return nil, io.EOF - }, - }, - } - req := &http.Request{ - Header: http.Header{}, - URL: &url.URL{ - Scheme: "https", - Host: "www.google.com", - Path: "/", - }, - } - resp, err := txp.RoundTrip(req) - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("expected nil response here") - } - if foundHost.Load() != 1 { - t.Fatal("host header was not added") - } -} - -func TestHTTPTransportLoggerSuccess(t *testing.T) { - txp := &httpTransportLogger{ - Logger: log.Log, - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - Body: io.NopCloser(strings.NewReader("")), - Header: http.Header{ - "Server": []string{"antani/0.1.0"}, +func TestHTTPTransportLogger(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + txp := &httpTransportLogger{ + Logger: log.Log, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF }, - StatusCode: 200, - }, nil + }, + } + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response here") + } + }) + + t.Run("we add the host header", func(t *testing.T) { + foundHost := &atomicx.Int64{} + txp := &httpTransportLogger{ + Logger: log.Log, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + if req.Header.Get("Host") == "www.google.com" { + foundHost.Add(1) + } + return nil, io.EOF + }, + }, + } + req := &http.Request{ + Header: http.Header{}, + URL: &url.URL{ + Scheme: "https", + Host: "www.google.com", + Path: "/", + }, + } + resp, err := txp.RoundTrip(req) + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response here") + } + if foundHost.Load() != 1 { + t.Fatal("host header was not added") + } + }) + + t.Run("with success", func(t *testing.T) { + txp := &httpTransportLogger{ + Logger: log.Log, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Body: io.NopCloser(strings.NewReader("")), + Header: http.Header{ + "Server": []string{"antani/0.1.0"}, + }, + StatusCode: 200, + }, nil + }, + }, + } + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + iox.ReadAllContext(context.Background(), resp.Body) + resp.Body.Close() + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + calls := &atomicx.Int64{} + txp := &httpTransportLogger{ + HTTPTransport: &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + calls.Add(1) + }, }, - }, - } - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if err != nil { - t.Fatal(err) - } - iox.ReadAllContext(context.Background(), resp.Body) - resp.Body.Close() + Logger: log.Log, + } + txp.CloseIdleConnections() + if calls.Load() != 1 { + t.Fatal("not called") + } + }) } -func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) { - calls := &atomicx.Int64{} - txp := &httpTransportLogger{ - HTTPTransport: &mocks.HTTPTransport{ - MockCloseIdleConnections: func() { - calls.Add(1) +func TestHTTPTransportConnectionsCloser(t *testing.T) { + t.Run("CloseIdleConnections", func(t *testing.T) { + var ( + calledTxp bool + calledDialer bool + calledTLS bool + ) + txp := &httpTransportConnectionsCloser{ + HTTPTransport: &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + calledTxp = true + }, }, - }, - Logger: log.Log, - } - txp.CloseIdleConnections() - if calls.Load() != 1 { - t.Fatal("not called") - } -} + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + calledDialer = true + }, + }, + TLSDialer: &mocks.TLSDialer{ + MockCloseIdleConnections: func() { + calledTLS = true + }, + }, + } + txp.CloseIdleConnections() + if !calledDialer || !calledTLS || !calledTxp { + t.Fatal("not called") + } + }) -func TestHTTPTransportWorks(t *testing.T) { - d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log)) - td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) - txp := NewHTTPTransport(log.Log, d, td) - client := &http.Client{Transport: txp} - defer client.CloseIdleConnections() - resp, err := client.Get("https://www.google.com/robots.txt") - if err != nil { - t.Fatal(err) - } - resp.Body.Close() -} - -func TestHTTPTransportWithFailingDialer(t *testing.T) { - called := &atomicx.Int64{} - expected := errors.New("mocked error") - d := &dialerResolver{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, - network, address string) (net.Conn, error) { - return nil, expected + t.Run("RoundTrip", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &httpTransportConnectionsCloser{ + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, }, - MockCloseIdleConnections: func() { - called.Add(1) - }, - }, - Resolver: NewResolverSystem(log.Log), - } - td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) - txp := NewHTTPTransport(log.Log, d, td) - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com/robots.txt") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected", err) - } - if resp != nil { - t.Fatal("expected non-nil response here") - } - client.CloseIdleConnections() - if called.Load() < 1 { - t.Fatal("did not propagate CloseIdleConnections") - } + } + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + }) } func TestNewHTTPTransport(t *testing.T) { - d := &mocks.Dialer{} - td := &mocks.TLSDialer{} - txp := NewHTTPTransport(log.Log, d, td) - logtxp, okay := txp.(*httpTransportLogger) - if !okay { - t.Fatal("invalid type") - } - if logtxp.Logger != log.Log { - t.Fatal("invalid logger") - } - txpcc, okay := logtxp.HTTPTransport.(*httpTransportConnectionsCloser) - if !okay { - t.Fatal("invalid type") - } - udt, okay := txpcc.Dialer.(*httpDialerWithReadTimeout) - if !okay { - t.Fatal("invalid type") - } - if udt.Dialer != d { - t.Fatal("invalid dialer") - } - utdt, okay := txpcc.TLSDialer.(*httpTLSDialerWithReadTimeout) - if !okay { - t.Fatal("invalid type") - } - if utdt.TLSDialer != td { - t.Fatal("invalid tls dialer") - } - stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport) - if !okay { - t.Fatal("invalid type") - } - if !stdwtxp.Transport.ForceAttemptHTTP2 { - t.Fatal("invalid ForceAttemptHTTP2") - } - if !stdwtxp.Transport.DisableCompression { - t.Fatal("invalid DisableCompression") - } - if stdwtxp.Transport.MaxConnsPerHost != 1 { - t.Fatal("invalid MaxConnPerHost") - } - if stdwtxp.Transport.DialTLSContext == nil { - t.Fatal("invalid DialTLSContext") - } - if stdwtxp.Transport.DialContext == nil { - t.Fatal("invalid DialContext") - } + t.Run("works as intended with failing dialer", func(t *testing.T) { + called := &atomicx.Int64{} + expected := errors.New("mocked error") + d := &dialerResolver{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, + network, address string) (net.Conn, error) { + return nil, expected + }, + MockCloseIdleConnections: func() { + called.Add(1) + }, + }, + Resolver: NewResolverSystem(log.Log), + } + td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) + txp := NewHTTPTransport(log.Log, d, td) + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com/robots.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected non-nil response here") + } + client.CloseIdleConnections() + if called.Load() < 1 { + t.Fatal("did not propagate CloseIdleConnections") + } + }) + + t.Run("creates the correct type chain", func(t *testing.T) { + d := &mocks.Dialer{} + td := &mocks.TLSDialer{} + txp := NewHTTPTransport(log.Log, d, td) + logger := txp.(*httpTransportLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + connectionsCloser := logger.HTTPTransport.(*httpTransportConnectionsCloser) + withReadTimeout := connectionsCloser.Dialer.(*httpDialerWithReadTimeout) + if withReadTimeout.Dialer != d { + t.Fatal("invalid dialer") + } + tlsWithReadTimeout := connectionsCloser.TLSDialer.(*httpTLSDialerWithReadTimeout) + if tlsWithReadTimeout.TLSDialer != td { + t.Fatal("invalid tls dialer") + } + stdlib := connectionsCloser.HTTPTransport.(*oohttp.StdlibTransport) + if !stdlib.Transport.ForceAttemptHTTP2 { + t.Fatal("invalid ForceAttemptHTTP2") + } + if !stdlib.Transport.DisableCompression { + t.Fatal("invalid DisableCompression") + } + if stdlib.Transport.MaxConnsPerHost != 1 { + t.Fatal("invalid MaxConnPerHost") + } + if stdlib.Transport.DialTLSContext == nil { + t.Fatal("invalid DialTLSContext") + } + if stdlib.Transport.DialContext == nil { + t.Fatal("invalid DialContext") + } + }) } func TestHTTPDialerWithReadTimeout(t *testing.T) { - var ( - calledWithZeroTime bool - calledWithNonZeroTime bool - ) - origConn := &mocks.Conn{ - MockSetReadDeadline: func(t time.Time) error { - switch t.IsZero() { - case true: - calledWithZeroTime = true - case false: - calledWithNonZeroTime = true - } - return nil - }, - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - } - d := &httpDialerWithReadTimeout{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return origConn, nil - }, - }, - } - ctx := context.Background() - conn, err := d.DialContext(ctx, "", "") - if err != nil { - t.Fatal(err) - } - if _, okay := conn.(*httpConnWithReadTimeout); !okay { - t.Fatal("invalid conn type") - } - if conn.(*httpConnWithReadTimeout).Conn != origConn { - t.Fatal("invalid origin conn") - } - b := make([]byte, 1024) - count, err := conn.Read(b) - if !errors.Is(err, io.EOF) { - t.Fatal("invalid error") - } - if count != 0 { - t.Fatal("invalid count") - } - if !calledWithZeroTime || !calledWithNonZeroTime { - t.Fatal("not called") - } -} - -func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { - var ( - calledWithZeroTime bool - calledWithNonZeroTime bool - ) - origConn := &mocks.TLSConn{ - Conn: mocks.Conn{ + t.Run("on success", func(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.Conn{ MockSetReadDeadline: func(t time.Time) error { switch t.IsZero() { case true: @@ -273,97 +251,151 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { MockRead: func(b []byte) (int, error) { return 0, io.EOF }, - }, - } - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return origConn, nil + } + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, }, - }, - } - ctx := context.Background() - conn, err := d.DialTLSContext(ctx, "", "") - if err != nil { - t.Fatal(err) - } - if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { - t.Fatal("invalid conn type") - } - if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { - t.Fatal("invalid origin conn") - } - b := make([]byte, 1024) - count, err := conn.Read(b) - if !errors.Is(err, io.EOF) { - t.Fatal("invalid error") - } - if count != 0 { - t.Fatal("invalid count") - } - if !calledWithZeroTime || !calledWithNonZeroTime { - t.Fatal("not called") - } + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpConnWithReadTimeout).Conn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + }, + } + conn, err := d.DialContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) } -func TestHTTPDialerWithReadTimeoutDialingFailure(t *testing.T) { - expected := errors.New("mocked error") - d := &httpDialerWithReadTimeout{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, expected +func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { + t.Run("on success", func(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.TLSConn{ + Conn: mocks.Conn{ + MockSetReadDeadline: func(t time.Time) error { + switch t.IsZero() { + case true: + calledWithZeroTime = true + case false: + calledWithNonZeroTime = true + } + return nil + }, + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, }, - }, - } - conn, err := d.DialContext(context.Background(), "", "") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } -} + } + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialTLSContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } + }) -func TestHTTPTLSDialerWithReadTimeoutDialingFailure(t *testing.T) { - expected := errors.New("mocked error") - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, expected + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, }, - }, - } - conn, err := d.DialTLSContext(context.Background(), "", "") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } -} + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) -func TestHTTPTLSDialerWithInvalidConnType(t *testing.T) { - var called bool - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockClose: func() error { - called = true - return nil - }, - }, nil + t.Run("with invalid conn type", func(t *testing.T) { + var called bool + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + called = true + return nil + }, + }, nil + }, }, - }, - } - conn, err := d.DialTLSContext(context.Background(), "", "") - if !errors.Is(err, ErrNotTLSConn) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - if !called { - t.Fatal("not called") - } + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, ErrNotTLSConn) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + if !called { + t.Fatal("not called") + } + }) } diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go new file mode 100644 index 0000000..4740145 --- /dev/null +++ b/internal/netxlite/integration_test.go @@ -0,0 +1,51 @@ +package netxlite_test + +import ( + "crypto/tls" + "net/http" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +func TestHTTPTransport(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + t.Run("works as intended", func(t *testing.T) { + d := netxlite.NewDialerWithResolver(log.Log, netxlite.NewResolverSystem(log.Log)) + td := netxlite.NewTLSDialer(d, netxlite.NewTLSHandshakerStdlib(log.Log)) + txp := netxlite.NewHTTPTransport(log.Log, d, td) + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com/robots.txt") + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + client.CloseIdleConnections() + }) +} + +func TestHTTP3Transport(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + t.Run("works as intended", func(t *testing.T) { + d := netxlite.NewQUICDialerWithResolver( + netxlite.NewQUICListener(), + log.Log, + netxlite.NewResolverSystem(log.Log), + ) + txp := netxlite.NewHTTP3Transport(d, &tls.Config{}) + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com/robots.txt") + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + txp.CloseIdleConnections() + }) +} diff --git a/internal/netxlite/mocks/http3.go b/internal/netxlite/mocks/http3.go new file mode 100644 index 0000000..8682e58 --- /dev/null +++ b/internal/netxlite/mocks/http3.go @@ -0,0 +1,19 @@ +package mocks + +import "net/http" + +// HTTP3RoundTripper allows mocking http3.RoundTripper. +type HTTP3RoundTripper struct { + MockRoundTrip func(req *http.Request) (*http.Response, error) + MockClose func() error +} + +// RoundTrip calls MockRoundTrip. +func (txp *HTTP3RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return txp.MockRoundTrip(req) +} + +// Close calls MockClose. +func (txp *HTTP3RoundTripper) Close() error { + return txp.MockClose() +} diff --git a/internal/netxlite/mocks/http3_test.go b/internal/netxlite/mocks/http3_test.go new file mode 100644 index 0000000..45d85d1 --- /dev/null +++ b/internal/netxlite/mocks/http3_test.go @@ -0,0 +1,37 @@ +package mocks + +import ( + "errors" + "net/http" + "testing" +) + +func TestHTTP3RoundTripper(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &HTTP3RoundTripper{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, + } + resp, err := txp.RoundTrip(&http.Request{}) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + }) + + t.Run("Close", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &HTTP3RoundTripper{ + MockClose: func() error { + return expected + }, + } + if err := txp.Close(); !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + }) +}