chore: improve testing and increase coverage (#794)

This diff improves testing and increases coverage inside the
./internal/netxlite and ./internal/tracex packages.

See https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso
2022-06-04 14:58:48 +02:00
committed by GitHub
parent 464d03184e
commit d5249a6cf7
11 changed files with 622 additions and 349 deletions
-17
View File
@@ -207,23 +207,6 @@ func (state *getaddrinfoState) addrinfoToString(r *C.struct_addrinfo) (string, e
}
}
// staticAddrinfoWithInvalidFamily is an helper to construct an addrinfo struct
// that we use in testing. (We cannot call CGO directly from tests.)
func staticAddrinfoWithInvalidFamily() *C.struct_addrinfo {
var value C.struct_addrinfo // zeroed by Go
value.ai_socktype = C.SOCK_STREAM // this is what the code expects
value.ai_family = 0 // but 0 is not AF_INET{,6}
return &value
}
// staticAddrinfoWithInvalidSocketType is an helper to construct an addrinfo struct
// that we use in testing. (We cannot call CGO directly from tests.)
func staticAddrinfoWithInvalidSocketType() *C.struct_addrinfo {
var value C.struct_addrinfo // zeroed by Go
value.ai_socktype = C.SOCK_DGRAM // not SOCK_STREAM
return &value
}
// getaddrinfoCopyIP copies a net.IP.
//
// This function is adapted from copyIP
+6 -6
View File
@@ -177,29 +177,29 @@ func NewOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode
// Ensure we correctly forward CloseIdleConnections.
return &httpTransportConnectionsCloser{
HTTPTransport: &stdlibTransport{&oohttp.StdlibTransport{Transport: txp}},
HTTPTransport: &httpTransportStdlib{&oohttp.StdlibTransport{Transport: txp}},
Dialer: dialer,
TLSDialer: tlsDialer,
}
}
// stdlibTransport wraps oohttp.StdlibTransport to add .Network()
type stdlibTransport struct {
type httpTransportStdlib struct {
StdlibTransport *oohttp.StdlibTransport
}
var _ model.HTTPTransport = &stdlibTransport{}
var _ model.HTTPTransport = &httpTransportStdlib{}
func (txp *stdlibTransport) CloseIdleConnections() {
func (txp *httpTransportStdlib) CloseIdleConnections() {
txp.StdlibTransport.CloseIdleConnections()
}
func (txp *stdlibTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) {
return txp.StdlibTransport.RoundTrip(req)
}
// Network implements HTTPTransport.Network.
func (txp *stdlibTransport) Network() string {
func (txp *httpTransportStdlib) Network() string {
return "tcp"
}
+151 -147
View File
@@ -272,7 +272,7 @@ func TestNewHTTPTransport(t *testing.T) {
if tlsWithReadTimeout.TLSDialer != td {
t.Fatal("invalid tls dialer")
}
stdlib := connectionsCloser.HTTPTransport.(*stdlibTransport)
stdlib := connectionsCloser.HTTPTransport.(*httpTransportStdlib)
if !stdlib.StdlibTransport.ForceAttemptHTTP2 {
t.Fatal("invalid ForceAttemptHTTP2")
}
@@ -292,83 +292,13 @@ func TestNewHTTPTransport(t *testing.T) {
}
func TestHTTPDialerWithReadTimeout(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.Conn{
MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
}
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpConnWithReadTimeout).Conn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
}
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.TLSConn{
Conn: mocks.Conn{
t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.Conn{
MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() {
case true:
@@ -381,81 +311,155 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
},
}
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
}
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
},
}
ctx := context.Background()
conn, err := d.DialTLSContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpConnWithReadTimeout).Conn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
conn, err := d.DialContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
})
}
t.Run("with invalid conn type", func(t *testing.T) {
var called bool
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
called = true
return nil
},
}, nil
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.TLSConn{
Conn: mocks.Conn{
MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, ErrNotTLSConn) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
if !called {
t.Fatal("not called")
}
}
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialTLSContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
t.Run("with invalid conn type", func(t *testing.T) {
var called bool
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
called = true
return nil
},
}, nil
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, ErrNotTLSConn) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
if !called {
t.Fatal("not called")
}
})
})
}
+7 -1
View File
@@ -140,7 +140,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string,
pconn.Close() // we own it on failure
return nil, err
}
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}, nil
return newQUICConnectionOwnsConn(qconn, pconn), nil
}
func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
@@ -183,6 +183,8 @@ type quicDialerHandshakeCompleter struct {
Dialer model.QUICDialer
}
var _ model.QUICDialer = &quicDialerHandshakeCompleter{}
// DialContext implements model.QUICDialer.DialContext.
func (d *quicDialerHandshakeCompleter) DialContext(
ctx context.Context, network, address string,
@@ -214,6 +216,10 @@ type quicConnectionOwnsConn struct {
conn model.UDPLikeConn
}
func newQUICConnectionOwnsConn(qconn quic.EarlyConnection, pconn model.UDPLikeConn) *quicConnectionOwnsConn {
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}
}
// CloseWithError implements quic.EarlyConnection.CloseWithError.
func (qconn *quicConnectionOwnsConn) CloseWithError(
code quic.ApplicationErrorCode, reason string) error {
+73
View File
@@ -302,6 +302,31 @@ func TestQUICDialerQUICGo(t *testing.T) {
t.Fatal("the ServerName field must match")
}
})
t.Run("returns a quicDialerOwnConn in case of success", func(t *testing.T) {
tlsConfig := &tls.Config{
ServerName: "dns.google",
}
fakeconn := &mocks.QUICEarlyConnection{}
systemdialer := quicDialerQUICGo{
QUICListener: &quicListenerStdlib{},
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlyConnection, error) {
return fakeconn, nil
},
}
ctx := context.Background()
qconn, err := systemdialer.DialContext(
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
if err != nil {
t.Fatal(err)
}
connOwner := qconn.(*quicConnectionOwnsConn)
if connOwner.EarlyConnection != fakeconn {
t.Fatal("invalid underlying conn")
}
})
})
}
@@ -406,6 +431,33 @@ func TestQUICDialerHandshakeCompleter(t *testing.T) {
})
}
func TestQUICConnectionOwnsConn(t *testing.T) {
var (
quicClose bool
udpClose bool
)
qconn := &mocks.QUICEarlyConnection{
MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
quicClose = true
return nil
},
}
pconn := &mocks.UDPLikeConn{
MockClose: func() error {
udpClose = true
return nil
},
}
conn := newQUICConnectionOwnsConn(qconn, pconn)
conn.CloseWithError(0, "")
if !quicClose {
t.Fatal("did not call qconn.CloseWithError")
}
if !udpClose {
t.Fatal("did not call pconn.Close")
}
}
func TestQUICDialerResolver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var (
@@ -518,6 +570,27 @@ func TestQUICDialerResolver(t *testing.T) {
t.Fatal("gotTLSConfig.ServerName has not been set")
}
})
t.Run("on success", func(t *testing.T) {
expectedQConn := &mocks.QUICEarlyConnection{}
dialer := &quicDialerResolver{
Resolver: NewResolverStdlib(log.Log),
Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return expectedQConn, nil
},
}}
qconn, err := dialer.DialContext(
context.Background(), "udp", "8.8.4.4:443",
&tls.Config{}, &quic.Config{})
if err != nil {
t.Fatal(err)
}
if qconn != expectedQConn {
t.Fatal("unexpected underlying qconn")
}
})
})
t.Run("lookup host with address", func(t *testing.T) {
+2
View File
@@ -22,6 +22,8 @@ type ParallelResolver struct {
Txp model.DNSTransport
}
var _ model.Resolver = &ParallelResolver{}
// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is
// not wrapped and you should wrap if before using it.
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
+2
View File
@@ -33,6 +33,8 @@ type SerialResolver struct {
Txp model.DNSTransport
}
var _ model.Resolver = &SerialResolver{}
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
return &SerialResolver{