refactor(netxlite): improve tests for http and http3 (#487)

* refactor(netxlite): improve tests for http and http3

See https://github.com/ooni/probe/issues/1591

* Update internal/netxlite/http3.go
This commit is contained in:
Simone Basso 2021-09-08 00:59:48 +02:00 committed by GitHub
parent 6d39118b26
commit 493b72b170
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 549 additions and 346 deletions

View File

@ -3,6 +3,7 @@ package netxlite
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"io"
"net/http" "net/http"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@ -13,19 +14,25 @@ import (
// an http3.RoundTripper. This is necessary because the // an http3.RoundTripper. This is necessary because the
// http3.RoundTripper does not support DialContext. // http3.RoundTripper does not support DialContext.
type http3Dialer struct { type http3Dialer struct {
Dialer QUICDialer QUICDialer
} }
// dial is like QUICContextDialer.DialContext but without context. // dial is like QUICContextDialer.DialContext but without context.
func (d *http3Dialer) dial(network, address string, tlsConfig *tls.Config, func (d *http3Dialer) dial(network, address string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error) { quicConfig *quic.Config) (quic.EarlySession, error) {
return d.Dialer.DialContext( return d.QUICDialer.DialContext(
context.Background(), network, address, tlsConfig, quicConfig) context.Background(), network, address, tlsConfig, quicConfig)
} }
// http3RoundTripper is the abstract type of quic-go/http3.RoundTripper.
type http3RoundTripper interface {
http.RoundTripper
io.Closer
}
// http3Transport is an HTTPTransport using the http3 protocol. // http3Transport is an HTTPTransport using the http3 protocol.
type http3Transport struct { type http3Transport struct {
child *http3.RoundTripper child http3RoundTripper
dialer QUICDialer dialer QUICDialer
} }

View File

@ -1,42 +1,99 @@
package netxlite package netxlite
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors"
"net/http" "net/http"
"testing" "testing"
"github.com/apex/log" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestHTTP3TransportWorks(t *testing.T) { func TestHTTP3Dialer(t *testing.T) {
d := &quicDialerResolver{ t.Run("Dial", func(t *testing.T) {
Dialer: &quicDialerQUICGo{ expected := errors.New("mocked error")
QUICListener: &quicListenerStdlib{}, d := &http3Dialer{
}, QUICDialer: &mocks.QUICDialer{
Resolver: NewResolverSystem(log.Log), MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
} return nil, expected
txp := NewHTTP3Transport(d, &tls.Config{}) },
client := &http.Client{Transport: txp} },
resp, err := client.Get("https://www.google.com/robots.txt") }
if err != nil { sess, err := d.dial("", "", &tls.Config{}, &quic.Config{})
t.Fatal(err) if !errors.Is(err, expected) {
} t.Fatal("unexpected err", err)
resp.Body.Close() }
txp.CloseIdleConnections() if sess != nil {
t.Fatal("unexpected resp")
}
})
} }
func TestHTTP3TransportClosesIdleConnections(t *testing.T) { func TestHTTP3TransportClosesIdleConnections(t *testing.T) {
var called bool t.Run("CloseIdleConnections", func(t *testing.T) {
d := &mocks.QUICDialer{ var (
MockCloseIdleConnections: func() { calledHTTP3 bool
called = true calledDialer bool
}, )
} txp := &http3Transport{
txp := NewHTTP3Transport(d, &tls.Config{}) child: &mocks.HTTP3RoundTripper{
client := &http.Client{Transport: txp} MockClose: func() error {
client.CloseIdleConnections() calledHTTP3 = true
if !called { return nil
t.Fatal("not called") },
} },
dialer: &mocks.QUICDialer{
MockCloseIdleConnections: func() {
calledDialer = true
},
},
}
txp.CloseIdleConnections()
if !calledHTTP3 || !calledDialer {
t.Fatal("not called")
}
})
t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &http3Transport{
child: &mocks.HTTP3RoundTripper{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, expected
},
},
}
resp, err := txp.RoundTrip(&http.Request{})
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("unexpected resp")
}
})
}
func TestNewHTTP3Transport(t *testing.T) {
t.Run("creates the correct type chain", func(t *testing.T) {
qd := &mocks.QUICDialer{}
config := &tls.Config{}
txp := NewHTTP3Transport(qd, config)
h3txp := txp.(*http3Transport)
if h3txp.dialer != qd {
t.Fatal("invalid dialer")
}
h3 := h3txp.child.(*http3.RoundTripper)
if h3.Dial == nil {
t.Fatal("invalid Dial")
}
if !h3.DisableCompression {
t.Fatal("invalid DisableCompression")
}
if h3.TLSClientConfig != config {
t.Fatal("invalid TLSClientConfig")
}
})
} }

View File

@ -18,249 +18,227 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestHTTPTransportLoggerFailure(t *testing.T) { func TestHTTPTransportLogger(t *testing.T) {
txp := &httpTransportLogger{ t.Run("RoundTrip", func(t *testing.T) {
Logger: log.Log, t.Run("with failure", func(t *testing.T) {
HTTPTransport: &mocks.HTTPTransport{ txp := &httpTransportLogger{
MockRoundTrip: func(req *http.Request) (*http.Response, error) { Logger: log.Log,
return nil, io.EOF HTTPTransport: &mocks.HTTPTransport{
}, MockRoundTrip: func(req *http.Request) (*http.Response, error) {
}, return nil, io.EOF
}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
}
func TestHTTPTransportLoggerFailureWithNoHostHeader(t *testing.T) {
foundHost := &atomicx.Int64{}
txp := &httpTransportLogger{
Logger: log.Log,
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
if req.Header.Get("Host") == "www.google.com" {
foundHost.Add(1)
}
return nil, io.EOF
},
},
}
req := &http.Request{
Header: http.Header{},
URL: &url.URL{
Scheme: "https",
Host: "www.google.com",
Path: "/",
},
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
if foundHost.Load() != 1 {
t.Fatal("host header was not added")
}
}
func TestHTTPTransportLoggerSuccess(t *testing.T) {
txp := &httpTransportLogger{
Logger: log.Log,
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{
"Server": []string{"antani/0.1.0"},
}, },
StatusCode: 200, },
}, nil }
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("we add the host header", func(t *testing.T) {
foundHost := &atomicx.Int64{}
txp := &httpTransportLogger{
Logger: log.Log,
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
if req.Header.Get("Host") == "www.google.com" {
foundHost.Add(1)
}
return nil, io.EOF
},
},
}
req := &http.Request{
Header: http.Header{},
URL: &url.URL{
Scheme: "https",
Host: "www.google.com",
Path: "/",
},
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
if foundHost.Load() != 1 {
t.Fatal("host header was not added")
}
})
t.Run("with success", func(t *testing.T) {
txp := &httpTransportLogger{
Logger: log.Log,
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{
"Server": []string{"antani/0.1.0"},
},
StatusCode: 200,
}, nil
},
},
}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatal(err)
}
iox.ReadAllContext(context.Background(), resp.Body)
resp.Body.Close()
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
calls := &atomicx.Int64{}
txp := &httpTransportLogger{
HTTPTransport: &mocks.HTTPTransport{
MockCloseIdleConnections: func() {
calls.Add(1)
},
}, },
}, Logger: log.Log,
} }
client := &http.Client{Transport: txp} txp.CloseIdleConnections()
resp, err := client.Get("https://www.google.com") if calls.Load() != 1 {
if err != nil { t.Fatal("not called")
t.Fatal(err) }
} })
iox.ReadAllContext(context.Background(), resp.Body)
resp.Body.Close()
} }
func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) { func TestHTTPTransportConnectionsCloser(t *testing.T) {
calls := &atomicx.Int64{} t.Run("CloseIdleConnections", func(t *testing.T) {
txp := &httpTransportLogger{ var (
HTTPTransport: &mocks.HTTPTransport{ calledTxp bool
MockCloseIdleConnections: func() { calledDialer bool
calls.Add(1) calledTLS bool
)
txp := &httpTransportConnectionsCloser{
HTTPTransport: &mocks.HTTPTransport{
MockCloseIdleConnections: func() {
calledTxp = true
},
}, },
}, Dialer: &mocks.Dialer{
Logger: log.Log, MockCloseIdleConnections: func() {
} calledDialer = true
txp.CloseIdleConnections() },
if calls.Load() != 1 { },
t.Fatal("not called") TLSDialer: &mocks.TLSDialer{
} MockCloseIdleConnections: func() {
} calledTLS = true
},
},
}
txp.CloseIdleConnections()
if !calledDialer || !calledTLS || !calledTxp {
t.Fatal("not called")
}
})
func TestHTTPTransportWorks(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
d := NewDialerWithResolver(log.Log, NewResolverSystem(log.Log)) expected := errors.New("mocked error")
td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) txp := &httpTransportConnectionsCloser{
txp := NewHTTPTransport(log.Log, d, td) HTTPTransport: &mocks.HTTPTransport{
client := &http.Client{Transport: txp} MockRoundTrip: func(req *http.Request) (*http.Response, error) {
defer client.CloseIdleConnections() return nil, expected
resp, err := client.Get("https://www.google.com/robots.txt") },
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
}
func TestHTTPTransportWithFailingDialer(t *testing.T) {
called := &atomicx.Int64{}
expected := errors.New("mocked error")
d := &dialerResolver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context,
network, address string) (net.Conn, error) {
return nil, expected
}, },
MockCloseIdleConnections: func() { }
called.Add(1) client := &http.Client{Transport: txp}
}, resp, err := client.Get("https://www.google.com")
}, if !errors.Is(err, expected) {
Resolver: NewResolverSystem(log.Log), t.Fatal("unexpected err", err)
} }
td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) if resp != nil {
txp := NewHTTPTransport(log.Log, d, td) t.Fatal("unexpected resp")
client := &http.Client{Transport: txp} }
resp, err := client.Get("https://www.google.com/robots.txt") })
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("expected non-nil response here")
}
client.CloseIdleConnections()
if called.Load() < 1 {
t.Fatal("did not propagate CloseIdleConnections")
}
} }
func TestNewHTTPTransport(t *testing.T) { func TestNewHTTPTransport(t *testing.T) {
d := &mocks.Dialer{} t.Run("works as intended with failing dialer", func(t *testing.T) {
td := &mocks.TLSDialer{} called := &atomicx.Int64{}
txp := NewHTTPTransport(log.Log, d, td) expected := errors.New("mocked error")
logtxp, okay := txp.(*httpTransportLogger) d := &dialerResolver{
if !okay { Dialer: &mocks.Dialer{
t.Fatal("invalid type") MockDialContext: func(ctx context.Context,
} network, address string) (net.Conn, error) {
if logtxp.Logger != log.Log { return nil, expected
t.Fatal("invalid logger") },
} MockCloseIdleConnections: func() {
txpcc, okay := logtxp.HTTPTransport.(*httpTransportConnectionsCloser) called.Add(1)
if !okay { },
t.Fatal("invalid type") },
} Resolver: NewResolverSystem(log.Log),
udt, okay := txpcc.Dialer.(*httpDialerWithReadTimeout) }
if !okay { td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log))
t.Fatal("invalid type") txp := NewHTTPTransport(log.Log, d, td)
} client := &http.Client{Transport: txp}
if udt.Dialer != d { resp, err := client.Get("https://www.google.com/robots.txt")
t.Fatal("invalid dialer") if !errors.Is(err, expected) {
} t.Fatal("not the error we expected", err)
utdt, okay := txpcc.TLSDialer.(*httpTLSDialerWithReadTimeout) }
if !okay { if resp != nil {
t.Fatal("invalid type") t.Fatal("expected non-nil response here")
} }
if utdt.TLSDialer != td { client.CloseIdleConnections()
t.Fatal("invalid tls dialer") if called.Load() < 1 {
} t.Fatal("did not propagate CloseIdleConnections")
stdwtxp, okay := txpcc.HTTPTransport.(*oohttp.StdlibTransport) }
if !okay { })
t.Fatal("invalid type")
} t.Run("creates the correct type chain", func(t *testing.T) {
if !stdwtxp.Transport.ForceAttemptHTTP2 { d := &mocks.Dialer{}
t.Fatal("invalid ForceAttemptHTTP2") td := &mocks.TLSDialer{}
} txp := NewHTTPTransport(log.Log, d, td)
if !stdwtxp.Transport.DisableCompression { logger := txp.(*httpTransportLogger)
t.Fatal("invalid DisableCompression") if logger.Logger != log.Log {
} t.Fatal("invalid logger")
if stdwtxp.Transport.MaxConnsPerHost != 1 { }
t.Fatal("invalid MaxConnPerHost") connectionsCloser := logger.HTTPTransport.(*httpTransportConnectionsCloser)
} withReadTimeout := connectionsCloser.Dialer.(*httpDialerWithReadTimeout)
if stdwtxp.Transport.DialTLSContext == nil { if withReadTimeout.Dialer != d {
t.Fatal("invalid DialTLSContext") t.Fatal("invalid dialer")
} }
if stdwtxp.Transport.DialContext == nil { tlsWithReadTimeout := connectionsCloser.TLSDialer.(*httpTLSDialerWithReadTimeout)
t.Fatal("invalid DialContext") if tlsWithReadTimeout.TLSDialer != td {
} t.Fatal("invalid tls dialer")
}
stdlib := connectionsCloser.HTTPTransport.(*oohttp.StdlibTransport)
if !stdlib.Transport.ForceAttemptHTTP2 {
t.Fatal("invalid ForceAttemptHTTP2")
}
if !stdlib.Transport.DisableCompression {
t.Fatal("invalid DisableCompression")
}
if stdlib.Transport.MaxConnsPerHost != 1 {
t.Fatal("invalid MaxConnPerHost")
}
if stdlib.Transport.DialTLSContext == nil {
t.Fatal("invalid DialTLSContext")
}
if stdlib.Transport.DialContext == nil {
t.Fatal("invalid DialContext")
}
})
} }
func TestHTTPDialerWithReadTimeout(t *testing.T) { func TestHTTPDialerWithReadTimeout(t *testing.T) {
var ( t.Run("on success", func(t *testing.T) {
calledWithZeroTime bool var (
calledWithNonZeroTime bool calledWithZeroTime bool
) calledWithNonZeroTime bool
origConn := &mocks.Conn{ )
MockSetReadDeadline: func(t time.Time) error { origConn := &mocks.Conn{
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
}
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpConnWithReadTimeout).Conn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
}
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.TLSConn{
Conn: mocks.Conn{
MockSetReadDeadline: func(t time.Time) error { MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() { switch t.IsZero() {
case true: case true:
@ -273,97 +251,151 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
MockRead: func(b []byte) (int, error) { MockRead: func(b []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
}, },
}, }
} d := &httpDialerWithReadTimeout{
d := &httpTLSDialerWithReadTimeout{ Dialer: &mocks.Dialer{
TLSDialer: &mocks.TLSDialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { return origConn, nil
return origConn, nil },
}, },
}, }
} ctx := context.Background()
ctx := context.Background() conn, err := d.DialContext(ctx, "", "")
conn, err := d.DialTLSContext(ctx, "", "") if err != nil {
if err != nil { t.Fatal(err)
t.Fatal(err) }
} if _, okay := conn.(*httpConnWithReadTimeout); !okay {
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { t.Fatal("invalid conn type")
t.Fatal("invalid conn type") }
} if conn.(*httpConnWithReadTimeout).Conn != origConn {
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { t.Fatal("invalid origin conn")
t.Fatal("invalid origin conn") }
} b := make([]byte, 1024)
b := make([]byte, 1024) count, err := conn.Read(b)
count, err := conn.Read(b) if !errors.Is(err, io.EOF) {
if !errors.Is(err, io.EOF) { t.Fatal("invalid error")
t.Fatal("invalid error") }
} if count != 0 {
if count != 0 { t.Fatal("invalid count")
t.Fatal("invalid count") }
} if !calledWithZeroTime || !calledWithNonZeroTime {
if !calledWithZeroTime || !calledWithNonZeroTime { t.Fatal("not called")
t.Fatal("not called") }
} })
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
} }
func TestHTTPDialerWithReadTimeoutDialingFailure(t *testing.T) { func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
expected := errors.New("mocked error") t.Run("on success", func(t *testing.T) {
d := &httpDialerWithReadTimeout{ var (
Dialer: &mocks.Dialer{ calledWithZeroTime bool
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { calledWithNonZeroTime bool
return nil, expected )
origConn := &mocks.TLSConn{
Conn: mocks.Conn{
MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
}, },
}, }
} d := &httpTLSDialerWithReadTimeout{
conn, err := d.DialContext(context.Background(), "", "") TLSDialer: &mocks.TLSDialer{
if !errors.Is(err, expected) { MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
t.Fatal("not the error we expected") return origConn, nil
} },
if conn != nil { },
t.Fatal("expected nil conn here") }
} ctx := context.Background()
} conn, err := d.DialTLSContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
func TestHTTPTLSDialerWithReadTimeoutDialingFailure(t *testing.T) { t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
d := &httpTLSDialerWithReadTimeout{ d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{ TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected return nil, expected
},
}, },
}, }
} conn, err := d.DialTLSContext(context.Background(), "", "")
conn, err := d.DialTLSContext(context.Background(), "", "") if !errors.Is(err, expected) {
if !errors.Is(err, expected) { t.Fatal("not the error we expected")
t.Fatal("not the error we expected") }
} if conn != nil {
if conn != nil { t.Fatal("expected nil conn here")
t.Fatal("expected nil conn here") }
} })
}
func TestHTTPTLSDialerWithInvalidConnType(t *testing.T) { t.Run("with invalid conn type", func(t *testing.T) {
var called bool var called bool
d := &httpTLSDialerWithReadTimeout{ d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{ TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
MockClose: func() error { MockClose: func() error {
called = true called = true
return nil return nil
}, },
}, nil }, nil
},
}, },
}, }
} conn, err := d.DialTLSContext(context.Background(), "", "")
conn, err := d.DialTLSContext(context.Background(), "", "") if !errors.Is(err, ErrNotTLSConn) {
if !errors.Is(err, ErrNotTLSConn) { t.Fatal("not the error we expected")
t.Fatal("not the error we expected") }
} if conn != nil {
if conn != nil { t.Fatal("expected nil conn here")
t.Fatal("expected nil conn here") }
} if !called {
if !called { t.Fatal("not called")
t.Fatal("not called") }
} })
} }

View File

@ -0,0 +1,51 @@
package netxlite_test
import (
"crypto/tls"
"net/http"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestHTTPTransport(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
t.Run("works as intended", func(t *testing.T) {
d := netxlite.NewDialerWithResolver(log.Log, netxlite.NewResolverSystem(log.Log))
td := netxlite.NewTLSDialer(d, netxlite.NewTLSHandshakerStdlib(log.Log))
txp := netxlite.NewHTTPTransport(log.Log, d, td)
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/robots.txt")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
client.CloseIdleConnections()
})
}
func TestHTTP3Transport(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
t.Run("works as intended", func(t *testing.T) {
d := netxlite.NewQUICDialerWithResolver(
netxlite.NewQUICListener(),
log.Log,
netxlite.NewResolverSystem(log.Log),
)
txp := netxlite.NewHTTP3Transport(d, &tls.Config{})
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/robots.txt")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
txp.CloseIdleConnections()
})
}

View File

@ -0,0 +1,19 @@
package mocks
import "net/http"
// HTTP3RoundTripper allows mocking http3.RoundTripper.
type HTTP3RoundTripper struct {
MockRoundTrip func(req *http.Request) (*http.Response, error)
MockClose func() error
}
// RoundTrip calls MockRoundTrip.
func (txp *HTTP3RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return txp.MockRoundTrip(req)
}
// Close calls MockClose.
func (txp *HTTP3RoundTripper) Close() error {
return txp.MockClose()
}

View File

@ -0,0 +1,37 @@
package mocks
import (
"errors"
"net/http"
"testing"
)
func TestHTTP3RoundTripper(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &HTTP3RoundTripper{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(&http.Request{})
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("unexpected resp")
}
})
t.Run("Close", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &HTTP3RoundTripper{
MockClose: func() error {
return expected
},
}
if err := txp.Close(); !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
})
}