chore: improve testing and increase coverage (#794)
This diff improves testing and increases coverage inside the ./internal/netxlite and ./internal/tracex packages. See https://github.com/ooni/probe/issues/2121
This commit is contained in:
@@ -207,23 +207,6 @@ func (state *getaddrinfoState) addrinfoToString(r *C.struct_addrinfo) (string, e
|
||||
}
|
||||
}
|
||||
|
||||
// staticAddrinfoWithInvalidFamily is an helper to construct an addrinfo struct
|
||||
// that we use in testing. (We cannot call CGO directly from tests.)
|
||||
func staticAddrinfoWithInvalidFamily() *C.struct_addrinfo {
|
||||
var value C.struct_addrinfo // zeroed by Go
|
||||
value.ai_socktype = C.SOCK_STREAM // this is what the code expects
|
||||
value.ai_family = 0 // but 0 is not AF_INET{,6}
|
||||
return &value
|
||||
}
|
||||
|
||||
// staticAddrinfoWithInvalidSocketType is an helper to construct an addrinfo struct
|
||||
// that we use in testing. (We cannot call CGO directly from tests.)
|
||||
func staticAddrinfoWithInvalidSocketType() *C.struct_addrinfo {
|
||||
var value C.struct_addrinfo // zeroed by Go
|
||||
value.ai_socktype = C.SOCK_DGRAM // not SOCK_STREAM
|
||||
return &value
|
||||
}
|
||||
|
||||
// getaddrinfoCopyIP copies a net.IP.
|
||||
//
|
||||
// This function is adapted from copyIP
|
||||
|
||||
@@ -177,29 +177,29 @@ func NewOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode
|
||||
|
||||
// Ensure we correctly forward CloseIdleConnections.
|
||||
return &httpTransportConnectionsCloser{
|
||||
HTTPTransport: &stdlibTransport{&oohttp.StdlibTransport{Transport: txp}},
|
||||
HTTPTransport: &httpTransportStdlib{&oohttp.StdlibTransport{Transport: txp}},
|
||||
Dialer: dialer,
|
||||
TLSDialer: tlsDialer,
|
||||
}
|
||||
}
|
||||
|
||||
// stdlibTransport wraps oohttp.StdlibTransport to add .Network()
|
||||
type stdlibTransport struct {
|
||||
type httpTransportStdlib struct {
|
||||
StdlibTransport *oohttp.StdlibTransport
|
||||
}
|
||||
|
||||
var _ model.HTTPTransport = &stdlibTransport{}
|
||||
var _ model.HTTPTransport = &httpTransportStdlib{}
|
||||
|
||||
func (txp *stdlibTransport) CloseIdleConnections() {
|
||||
func (txp *httpTransportStdlib) CloseIdleConnections() {
|
||||
txp.StdlibTransport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func (txp *stdlibTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return txp.StdlibTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// Network implements HTTPTransport.Network.
|
||||
func (txp *stdlibTransport) Network() string {
|
||||
func (txp *httpTransportStdlib) Network() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
|
||||
+151
-147
@@ -272,7 +272,7 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||
if tlsWithReadTimeout.TLSDialer != td {
|
||||
t.Fatal("invalid tls dialer")
|
||||
}
|
||||
stdlib := connectionsCloser.HTTPTransport.(*stdlibTransport)
|
||||
stdlib := connectionsCloser.HTTPTransport.(*httpTransportStdlib)
|
||||
if !stdlib.StdlibTransport.ForceAttemptHTTP2 {
|
||||
t.Fatal("invalid ForceAttemptHTTP2")
|
||||
}
|
||||
@@ -292,83 +292,13 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHTTPDialerWithReadTimeout(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
var (
|
||||
calledWithZeroTime bool
|
||||
calledWithNonZeroTime bool
|
||||
)
|
||||
origConn := &mocks.Conn{
|
||||
MockSetReadDeadline: func(t time.Time) error {
|
||||
switch t.IsZero() {
|
||||
case true:
|
||||
calledWithZeroTime = true
|
||||
case false:
|
||||
calledWithNonZeroTime = true
|
||||
}
|
||||
return nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
}
|
||||
d := &httpDialerWithReadTimeout{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return origConn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
|
||||
t.Fatal("invalid conn type")
|
||||
}
|
||||
if conn.(*httpConnWithReadTimeout).Conn != origConn {
|
||||
t.Fatal("invalid origin conn")
|
||||
}
|
||||
b := make([]byte, 1024)
|
||||
count, err := conn.Read(b)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("invalid error")
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := &httpDialerWithReadTimeout{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "", "")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
var (
|
||||
calledWithZeroTime bool
|
||||
calledWithNonZeroTime bool
|
||||
)
|
||||
origConn := &mocks.TLSConn{
|
||||
Conn: mocks.Conn{
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
var (
|
||||
calledWithZeroTime bool
|
||||
calledWithNonZeroTime bool
|
||||
)
|
||||
origConn := &mocks.Conn{
|
||||
MockSetReadDeadline: func(t time.Time) error {
|
||||
switch t.IsZero() {
|
||||
case true:
|
||||
@@ -381,81 +311,155 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
},
|
||||
}
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return origConn, nil
|
||||
}
|
||||
d := &httpDialerWithReadTimeout{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return origConn, nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialTLSContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
|
||||
t.Fatal("invalid conn type")
|
||||
}
|
||||
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
|
||||
t.Fatal("invalid origin conn")
|
||||
}
|
||||
b := make([]byte, 1024)
|
||||
count, err := conn.Read(b)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("invalid error")
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
|
||||
t.Fatal("invalid conn type")
|
||||
}
|
||||
if conn.(*httpConnWithReadTimeout).Conn != origConn {
|
||||
t.Fatal("invalid origin conn")
|
||||
}
|
||||
b := make([]byte, 1024)
|
||||
count, err := conn.Read(b)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("invalid error")
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := &httpDialerWithReadTimeout{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "", "")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("with invalid conn type", func(t *testing.T) {
|
||||
var called bool
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
var (
|
||||
calledWithZeroTime bool
|
||||
calledWithNonZeroTime bool
|
||||
)
|
||||
origConn := &mocks.TLSConn{
|
||||
Conn: mocks.Conn{
|
||||
MockSetReadDeadline: func(t time.Time) error {
|
||||
switch t.IsZero() {
|
||||
case true:
|
||||
calledWithZeroTime = true
|
||||
case false:
|
||||
calledWithNonZeroTime = true
|
||||
}
|
||||
return nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||
if !errors.Is(err, ErrNotTLSConn) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
}
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return origConn, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, err := d.DialTLSContext(ctx, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
|
||||
t.Fatal("invalid conn type")
|
||||
}
|
||||
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
|
||||
t.Fatal("invalid origin conn")
|
||||
}
|
||||
b := make([]byte, 1024)
|
||||
count, err := conn.Read(b)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("invalid error")
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with invalid conn type", func(t *testing.T) {
|
||||
var called bool
|
||||
d := &httpTLSDialerWithReadTimeout{
|
||||
TLSDialer: &mocks.TLSDialer{
|
||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||
if !errors.Is(err, ErrNotTLSConn) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string,
|
||||
pconn.Close() // we own it on failure
|
||||
return nil, err
|
||||
}
|
||||
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}, nil
|
||||
return newQUICConnectionOwnsConn(qconn, pconn), nil
|
||||
}
|
||||
|
||||
func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
|
||||
@@ -183,6 +183,8 @@ type quicDialerHandshakeCompleter struct {
|
||||
Dialer model.QUICDialer
|
||||
}
|
||||
|
||||
var _ model.QUICDialer = &quicDialerHandshakeCompleter{}
|
||||
|
||||
// DialContext implements model.QUICDialer.DialContext.
|
||||
func (d *quicDialerHandshakeCompleter) DialContext(
|
||||
ctx context.Context, network, address string,
|
||||
@@ -214,6 +216,10 @@ type quicConnectionOwnsConn struct {
|
||||
conn model.UDPLikeConn
|
||||
}
|
||||
|
||||
func newQUICConnectionOwnsConn(qconn quic.EarlyConnection, pconn model.UDPLikeConn) *quicConnectionOwnsConn {
|
||||
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}
|
||||
}
|
||||
|
||||
// CloseWithError implements quic.EarlyConnection.CloseWithError.
|
||||
func (qconn *quicConnectionOwnsConn) CloseWithError(
|
||||
code quic.ApplicationErrorCode, reason string) error {
|
||||
|
||||
@@ -302,6 +302,31 @@ func TestQUICDialerQUICGo(t *testing.T) {
|
||||
t.Fatal("the ServerName field must match")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns a quicDialerOwnConn in case of success", func(t *testing.T) {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
fakeconn := &mocks.QUICEarlyConnection{}
|
||||
systemdialer := quicDialerQUICGo{
|
||||
QUICListener: &quicListenerStdlib{},
|
||||
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||
quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
||||
return fakeconn, nil
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
qconn, err := systemdialer.DialContext(
|
||||
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
connOwner := qconn.(*quicConnectionOwnsConn)
|
||||
if connOwner.EarlyConnection != fakeconn {
|
||||
t.Fatal("invalid underlying conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -406,6 +431,33 @@ func TestQUICDialerHandshakeCompleter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestQUICConnectionOwnsConn(t *testing.T) {
|
||||
var (
|
||||
quicClose bool
|
||||
udpClose bool
|
||||
)
|
||||
qconn := &mocks.QUICEarlyConnection{
|
||||
MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
|
||||
quicClose = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
pconn := &mocks.UDPLikeConn{
|
||||
MockClose: func() error {
|
||||
udpClose = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
conn := newQUICConnectionOwnsConn(qconn, pconn)
|
||||
conn.CloseWithError(0, "")
|
||||
if !quicClose {
|
||||
t.Fatal("did not call qconn.CloseWithError")
|
||||
}
|
||||
if !udpClose {
|
||||
t.Fatal("did not call pconn.Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerResolver(t *testing.T) {
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var (
|
||||
@@ -518,6 +570,27 @@ func TestQUICDialerResolver(t *testing.T) {
|
||||
t.Fatal("gotTLSConfig.ServerName has not been set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedQConn := &mocks.QUICEarlyConnection{}
|
||||
dialer := &quicDialerResolver{
|
||||
Resolver: NewResolverStdlib(log.Log),
|
||||
Dialer: &mocks.QUICDialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string,
|
||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
||||
return expectedQConn, nil
|
||||
},
|
||||
}}
|
||||
qconn, err := dialer.DialContext(
|
||||
context.Background(), "udp", "8.8.4.4:443",
|
||||
&tls.Config{}, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if qconn != expectedQConn {
|
||||
t.Fatal("unexpected underlying qconn")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("lookup host with address", func(t *testing.T) {
|
||||
|
||||
@@ -22,6 +22,8 @@ type ParallelResolver struct {
|
||||
Txp model.DNSTransport
|
||||
}
|
||||
|
||||
var _ model.Resolver = &ParallelResolver{}
|
||||
|
||||
// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is
|
||||
// not wrapped and you should wrap if before using it.
|
||||
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
||||
|
||||
@@ -33,6 +33,8 @@ type SerialResolver struct {
|
||||
Txp model.DNSTransport
|
||||
}
|
||||
|
||||
var _ model.Resolver = &SerialResolver{}
|
||||
|
||||
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
|
||||
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
|
||||
return &SerialResolver{
|
||||
|
||||
Reference in New Issue
Block a user