chore: improve testing and increase coverage (#794)

This diff improves testing and increases coverage inside the
./internal/netxlite and ./internal/tracex packages.

See https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso 2022-06-04 14:58:48 +02:00 committed by GitHub
parent 464d03184e
commit d5249a6cf7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 622 additions and 349 deletions

View File

@ -207,23 +207,6 @@ func (state *getaddrinfoState) addrinfoToString(r *C.struct_addrinfo) (string, e
} }
} }
// staticAddrinfoWithInvalidFamily is an helper to construct an addrinfo struct
// that we use in testing. (We cannot call CGO directly from tests.)
func staticAddrinfoWithInvalidFamily() *C.struct_addrinfo {
var value C.struct_addrinfo // zeroed by Go
value.ai_socktype = C.SOCK_STREAM // this is what the code expects
value.ai_family = 0 // but 0 is not AF_INET{,6}
return &value
}
// staticAddrinfoWithInvalidSocketType is an helper to construct an addrinfo struct
// that we use in testing. (We cannot call CGO directly from tests.)
func staticAddrinfoWithInvalidSocketType() *C.struct_addrinfo {
var value C.struct_addrinfo // zeroed by Go
value.ai_socktype = C.SOCK_DGRAM // not SOCK_STREAM
return &value
}
// getaddrinfoCopyIP copies a net.IP. // getaddrinfoCopyIP copies a net.IP.
// //
// This function is adapted from copyIP // This function is adapted from copyIP

View File

@ -177,29 +177,29 @@ func NewOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode
// Ensure we correctly forward CloseIdleConnections. // Ensure we correctly forward CloseIdleConnections.
return &httpTransportConnectionsCloser{ return &httpTransportConnectionsCloser{
HTTPTransport: &stdlibTransport{&oohttp.StdlibTransport{Transport: txp}}, HTTPTransport: &httpTransportStdlib{&oohttp.StdlibTransport{Transport: txp}},
Dialer: dialer, Dialer: dialer,
TLSDialer: tlsDialer, TLSDialer: tlsDialer,
} }
} }
// stdlibTransport wraps oohttp.StdlibTransport to add .Network() // stdlibTransport wraps oohttp.StdlibTransport to add .Network()
type stdlibTransport struct { type httpTransportStdlib struct {
StdlibTransport *oohttp.StdlibTransport StdlibTransport *oohttp.StdlibTransport
} }
var _ model.HTTPTransport = &stdlibTransport{} var _ model.HTTPTransport = &httpTransportStdlib{}
func (txp *stdlibTransport) CloseIdleConnections() { func (txp *httpTransportStdlib) CloseIdleConnections() {
txp.StdlibTransport.CloseIdleConnections() txp.StdlibTransport.CloseIdleConnections()
} }
func (txp *stdlibTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) {
return txp.StdlibTransport.RoundTrip(req) return txp.StdlibTransport.RoundTrip(req)
} }
// Network implements HTTPTransport.Network. // Network implements HTTPTransport.Network.
func (txp *stdlibTransport) Network() string { func (txp *httpTransportStdlib) Network() string {
return "tcp" return "tcp"
} }

View File

@ -272,7 +272,7 @@ func TestNewHTTPTransport(t *testing.T) {
if tlsWithReadTimeout.TLSDialer != td { if tlsWithReadTimeout.TLSDialer != td {
t.Fatal("invalid tls dialer") t.Fatal("invalid tls dialer")
} }
stdlib := connectionsCloser.HTTPTransport.(*stdlibTransport) stdlib := connectionsCloser.HTTPTransport.(*httpTransportStdlib)
if !stdlib.StdlibTransport.ForceAttemptHTTP2 { if !stdlib.StdlibTransport.ForceAttemptHTTP2 {
t.Fatal("invalid ForceAttemptHTTP2") t.Fatal("invalid ForceAttemptHTTP2")
} }
@ -292,83 +292,13 @@ func TestNewHTTPTransport(t *testing.T) {
} }
func TestHTTPDialerWithReadTimeout(t *testing.T) { func TestHTTPDialerWithReadTimeout(t *testing.T) {
t.Run("on success", func(t *testing.T) { t.Run("DialContext", func(t *testing.T) {
var ( t.Run("on success", func(t *testing.T) {
calledWithZeroTime bool var (
calledWithNonZeroTime bool calledWithZeroTime bool
) calledWithNonZeroTime bool
origConn := &mocks.Conn{ )
MockSetReadDeadline: func(t time.Time) error { origConn := &mocks.Conn{
switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
}
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return origConn, nil
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "", "")
if err != nil {
t.Fatal(err)
}
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpConnWithReadTimeout).Conn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpDialerWithReadTimeout{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
}
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
t.Run("on success", func(t *testing.T) {
var (
calledWithZeroTime bool
calledWithNonZeroTime bool
)
origConn := &mocks.TLSConn{
Conn: mocks.Conn{
MockSetReadDeadline: func(t time.Time) error { MockSetReadDeadline: func(t time.Time) error {
switch t.IsZero() { switch t.IsZero() {
case true: case true:
@ -381,81 +311,155 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
MockRead: func(b []byte) (int, error) { MockRead: func(b []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
}, },
}, }
} d := &httpDialerWithReadTimeout{
d := &httpTLSDialerWithReadTimeout{ Dialer: &mocks.Dialer{
TLSDialer: &mocks.TLSDialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { return origConn, nil
return origConn, nil },
}, },
}, }
} ctx := context.Background()
ctx := context.Background() conn, err := d.DialContext(ctx, "", "")
conn, err := d.DialTLSContext(ctx, "", "") if err != nil {
if err != nil { t.Fatal(err)
t.Fatal(err) }
} if _, okay := conn.(*httpConnWithReadTimeout); !okay {
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { t.Fatal("invalid conn type")
t.Fatal("invalid conn type") }
} if conn.(*httpConnWithReadTimeout).Conn != origConn {
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { t.Fatal("invalid origin conn")
t.Fatal("invalid origin conn") }
} b := make([]byte, 1024)
b := make([]byte, 1024) count, err := conn.Read(b)
count, err := conn.Read(b) if !errors.Is(err, io.EOF) {
if !errors.Is(err, io.EOF) { t.Fatal("invalid error")
t.Fatal("invalid error") }
} if count != 0 {
if count != 0 { t.Fatal("invalid count")
t.Fatal("invalid count") }
} if !calledWithZeroTime || !calledWithNonZeroTime {
if !calledWithZeroTime || !calledWithNonZeroTime { t.Fatal("not called")
t.Fatal("not called") }
} })
})
t.Run("on failure", func(t *testing.T) { t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
d := &httpTLSDialerWithReadTimeout{ d := &httpDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{ Dialer: &mocks.Dialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected return nil, expected
},
}, },
}, }
} conn, err := d.DialContext(context.Background(), "", "")
conn, err := d.DialTLSContext(context.Background(), "", "") if !errors.Is(err, expected) {
if !errors.Is(err, expected) { t.Fatal("not the error we expected")
t.Fatal("not the error we expected") }
} if conn != nil {
if conn != nil { t.Fatal("expected nil conn here")
t.Fatal("expected nil conn here") }
} })
}) })
}
t.Run("with invalid conn type", func(t *testing.T) { func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
var called bool t.Run("DialContext", func(t *testing.T) {
d := &httpTLSDialerWithReadTimeout{ t.Run("on success", func(t *testing.T) {
TLSDialer: &mocks.TLSDialer{ var (
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { calledWithZeroTime bool
return &mocks.Conn{ calledWithNonZeroTime bool
MockClose: func() error { )
called = true origConn := &mocks.TLSConn{
return nil Conn: mocks.Conn{
}, MockSetReadDeadline: func(t time.Time) error {
}, nil switch t.IsZero() {
case true:
calledWithZeroTime = true
case false:
calledWithNonZeroTime = true
}
return nil
},
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
}, },
}, }
} d := &httpTLSDialerWithReadTimeout{
conn, err := d.DialTLSContext(context.Background(), "", "") TLSDialer: &mocks.TLSDialer{
if !errors.Is(err, ErrNotTLSConn) { MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
t.Fatal("not the error we expected") return origConn, nil
} },
if conn != nil { },
t.Fatal("expected nil conn here") }
} ctx := context.Background()
if !called { conn, err := d.DialTLSContext(ctx, "", "")
t.Fatal("not called") if err != nil {
} t.Fatal(err)
}
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
t.Fatal("invalid conn type")
}
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
t.Fatal("invalid origin conn")
}
b := make([]byte, 1024)
count, err := conn.Read(b)
if !errors.Is(err, io.EOF) {
t.Fatal("invalid error")
}
if count != 0 {
t.Fatal("invalid count")
}
if !calledWithZeroTime || !calledWithNonZeroTime {
t.Fatal("not called")
}
})
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, expected
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
t.Run("with invalid conn type", func(t *testing.T) {
var called bool
d := &httpTLSDialerWithReadTimeout{
TLSDialer: &mocks.TLSDialer{
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
called = true
return nil
},
}, nil
},
},
}
conn, err := d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, ErrNotTLSConn) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
if !called {
t.Fatal("not called")
}
})
}) })
} }

View File

@ -140,7 +140,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string,
pconn.Close() // we own it on failure pconn.Close() // we own it on failure
return nil, err return nil, err
} }
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}, nil return newQUICConnectionOwnsConn(qconn, pconn), nil
} }
func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context, func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
@ -183,6 +183,8 @@ type quicDialerHandshakeCompleter struct {
Dialer model.QUICDialer Dialer model.QUICDialer
} }
var _ model.QUICDialer = &quicDialerHandshakeCompleter{}
// DialContext implements model.QUICDialer.DialContext. // DialContext implements model.QUICDialer.DialContext.
func (d *quicDialerHandshakeCompleter) DialContext( func (d *quicDialerHandshakeCompleter) DialContext(
ctx context.Context, network, address string, ctx context.Context, network, address string,
@ -214,6 +216,10 @@ type quicConnectionOwnsConn struct {
conn model.UDPLikeConn conn model.UDPLikeConn
} }
func newQUICConnectionOwnsConn(qconn quic.EarlyConnection, pconn model.UDPLikeConn) *quicConnectionOwnsConn {
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}
}
// CloseWithError implements quic.EarlyConnection.CloseWithError. // CloseWithError implements quic.EarlyConnection.CloseWithError.
func (qconn *quicConnectionOwnsConn) CloseWithError( func (qconn *quicConnectionOwnsConn) CloseWithError(
code quic.ApplicationErrorCode, reason string) error { code quic.ApplicationErrorCode, reason string) error {

View File

@ -302,6 +302,31 @@ func TestQUICDialerQUICGo(t *testing.T) {
t.Fatal("the ServerName field must match") t.Fatal("the ServerName field must match")
} }
}) })
t.Run("returns a quicDialerOwnConn in case of success", func(t *testing.T) {
tlsConfig := &tls.Config{
ServerName: "dns.google",
}
fakeconn := &mocks.QUICEarlyConnection{}
systemdialer := quicDialerQUICGo{
QUICListener: &quicListenerStdlib{},
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlyConnection, error) {
return fakeconn, nil
},
}
ctx := context.Background()
qconn, err := systemdialer.DialContext(
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
if err != nil {
t.Fatal(err)
}
connOwner := qconn.(*quicConnectionOwnsConn)
if connOwner.EarlyConnection != fakeconn {
t.Fatal("invalid underlying conn")
}
})
}) })
} }
@ -406,6 +431,33 @@ func TestQUICDialerHandshakeCompleter(t *testing.T) {
}) })
} }
func TestQUICConnectionOwnsConn(t *testing.T) {
var (
quicClose bool
udpClose bool
)
qconn := &mocks.QUICEarlyConnection{
MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
quicClose = true
return nil
},
}
pconn := &mocks.UDPLikeConn{
MockClose: func() error {
udpClose = true
return nil
},
}
conn := newQUICConnectionOwnsConn(qconn, pconn)
conn.CloseWithError(0, "")
if !quicClose {
t.Fatal("did not call qconn.CloseWithError")
}
if !udpClose {
t.Fatal("did not call pconn.Close")
}
}
func TestQUICDialerResolver(t *testing.T) { func TestQUICDialerResolver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
var ( var (
@ -518,6 +570,27 @@ func TestQUICDialerResolver(t *testing.T) {
t.Fatal("gotTLSConfig.ServerName has not been set") t.Fatal("gotTLSConfig.ServerName has not been set")
} }
}) })
t.Run("on success", func(t *testing.T) {
expectedQConn := &mocks.QUICEarlyConnection{}
dialer := &quicDialerResolver{
Resolver: NewResolverStdlib(log.Log),
Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return expectedQConn, nil
},
}}
qconn, err := dialer.DialContext(
context.Background(), "udp", "8.8.4.4:443",
&tls.Config{}, &quic.Config{})
if err != nil {
t.Fatal(err)
}
if qconn != expectedQConn {
t.Fatal("unexpected underlying qconn")
}
})
}) })
t.Run("lookup host with address", func(t *testing.T) { t.Run("lookup host with address", func(t *testing.T) {

View File

@ -22,6 +22,8 @@ type ParallelResolver struct {
Txp model.DNSTransport Txp model.DNSTransport
} }
var _ model.Resolver = &ParallelResolver{}
// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is // UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is
// not wrapped and you should wrap if before using it. // not wrapped and you should wrap if before using it.
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {

View File

@ -33,6 +33,8 @@ type SerialResolver struct {
Txp model.DNSTransport Txp model.DNSTransport
} }
var _ model.Resolver = &SerialResolver{}
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance. // NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver { func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
return &SerialResolver{ return &SerialResolver{

View File

@ -0,0 +1,36 @@
package tracex
import "testing"
func TestUnusedEventsNames(t *testing.T) {
// Tests that we don't break the names of events we're currently
// not getting the name of directly even if they're saved.
t.Run("EventQUICHandshakeStart", func(t *testing.T) {
ev := &EventQUICHandshakeStart{}
if ev.Name() != "quic_handshake_start" {
t.Fatal("invalid event name")
}
})
t.Run("EventQUICHandshakeDone", func(t *testing.T) {
ev := &EventQUICHandshakeDone{}
if ev.Name() != "quic_handshake_done" {
t.Fatal("invalid event name")
}
})
t.Run("EventTLSHandshakeStart", func(t *testing.T) {
ev := &EventTLSHandshakeStart{}
if ev.Name() != "tls_handshake_start" {
t.Fatal("invalid event name")
}
})
t.Run("EventTLSHandshakeDone", func(t *testing.T) {
ev := &EventTLSHandshakeDone{}
if ev.Name() != "tls_handshake_done" {
t.Fatal("invalid event name")
}
})
}

View File

@ -177,6 +177,14 @@ func TestQUICDialerSaver(t *testing.T) {
}) })
} }
func TestWrapQUICListener(t *testing.T) {
var saver *Saver
ql := &mocks.QUICListener{}
if saver.WrapQUICListener(ql) != ql {
t.Fatal("unexpected result")
}
}
func TestQUICListenerSaver(t *testing.T) { func TestQUICListenerSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) { t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")

View File

@ -15,219 +15,370 @@ import (
"github.com/ooni/probe-cli/v3/internal/runtimex" "github.com/ooni/probe-cli/v3/internal/runtimex"
) )
func TestWrapResolver(t *testing.T) {
var saver *Saver
reso := &mocks.Resolver{}
if saver.WrapResolver(reso) != reso {
t.Fatal("unexpected result")
}
}
func TestResolverSaver(t *testing.T) { func TestResolverSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
expected := netxlite.ErrOODNSNoSuchHost t.Run("on failure", func(t *testing.T) {
expected := netxlite.ErrOODNSNoSuchHost
saver := &Saver{}
reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected))
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if ev[1].Value().Addresses != nil {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
t.Run("on success", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &Saver{}
reso := saver.WrapResolver(newFakeResolverWithResult(expected))
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal("expected nil error here")
}
if !reflect.DeepEqual(addrs, expected) {
t.Fatal("not the result we expected")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err.IsNotNil() {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
})
t.Run("Network", func(t *testing.T) {
saver := &Saver{} saver := &Saver{}
reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) child := &mocks.Resolver{
addrs, err := reso.LookupHost(context.Background(), "www.google.com") MockNetwork: func() string {
if !errors.Is(err, expected) { return "x"
t.Fatal("not the error we expected") },
} }
if addrs != nil { reso := saver.WrapResolver(child)
t.Fatal("expected nil address here") if reso.Network() != "x" {
} t.Fatal("unexpected result")
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if ev[1].Value().Addresses != nil {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
} }
}) })
t.Run("on success", func(t *testing.T) { t.Run("Address", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &Saver{} saver := &Saver{}
reso := saver.WrapResolver(newFakeResolverWithResult(expected)) child := &mocks.Resolver{
addrs, err := reso.LookupHost(context.Background(), "www.google.com") MockAddress: func() string {
if err != nil { return "x"
t.Fatal("expected nil error here") },
} }
if !reflect.DeepEqual(addrs, expected) { reso := saver.WrapResolver(child)
t.Fatal("not the result we expected") if reso.Address() != "x" {
t.Fatal("unexpected result")
} }
ev := saver.Read() })
if len(ev) != 2 {
t.Fatal("expected number of events") t.Run("LookupHTTPS", func(t *testing.T) {
expected := errors.New("mocked")
saver := &Saver{}
child := &mocks.Resolver{
MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
return nil, expected
},
} }
if ev[0].Value().Hostname != "www.google.com" { reso := saver.WrapResolver(child)
t.Fatal("unexpected Hostname") https, err := reso.LookupHTTPS(context.Background(), "dns.google")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
} }
if ev[0].Name() != "resolve_start" { if https != nil {
t.Fatal("unexpected name") t.Fatal("expected nil")
} }
if !ev[0].Value().Time.Before(time.Now()) { })
t.Fatal("the saved time is wrong")
t.Run("LookupNS", func(t *testing.T) {
expected := errors.New("mocked")
saver := &Saver{}
child := &mocks.Resolver{
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
return nil, expected
},
} }
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { reso := saver.WrapResolver(child)
t.Fatal("unexpected Addresses") ns, err := reso.LookupNS(context.Background(), "dns.google")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
} }
if ev[1].Value().Duration <= 0 { if len(ns) != 0 {
t.Fatal("unexpected Duration") t.Fatal("expected zero length array")
} }
if ev[1].Value().Err.IsNotNil() { })
t.Fatal("unexpected Err")
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
saver := &Saver{}
child := &mocks.Resolver{
MockCloseIdleConnections: func() {
called = true
},
} }
if ev[1].Value().Hostname != "www.google.com" { reso := saver.WrapResolver(child)
t.Fatal("unexpected Hostname") reso.CloseIdleConnections()
} if !called {
if ev[1].Name() != "resolve_done" { t.Fatal("not called")
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
} }
}) })
} }
func TestWrapDNSTransport(t *testing.T) {
var saver *Saver
txp := &mocks.DNSTransport{}
if saver.WrapDNSTransport(txp) != txp {
t.Fatal("unexpected result")
}
}
func TestDNSTransportSaver(t *testing.T) { func TestDNSTransportSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
expected := netxlite.ErrOODNSNoSuchHost t.Run("on failure", func(t *testing.T) {
saver := &Saver{} expected := netxlite.ErrOODNSNoSuchHost
txp := saver.WrapDNSTransport(&mocks.DNSTransport{ saver := &Saver{}
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { txp := saver.WrapDNSTransport(&mocks.DNSTransport{
return nil, expected MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
}, return nil, expected
MockNetwork: func() string { },
return "fake" MockNetwork: func() string {
}, return "fake"
MockAddress: func() string { },
return "" MockAddress: func() string {
}, return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].Value().DNSResponse != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
}) })
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{ t.Run("on success", func(t *testing.T) {
MockBytes: func() ([]byte, error) { expected := []byte{0xef, 0xbe, 0xad, 0xde}
return rawQuery, nil saver := &Saver{}
response := &mocks.DNSResponse{
MockBytes: func() []byte {
return expected
},
}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply.Bytes(), expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err.IsNotNil() {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
})
t.Run("Network", func(t *testing.T) {
saver := &Saver{}
child := &mocks.DNSTransport{
MockNetwork: func() string {
return "x"
}, },
} }
reply, err := txp.RoundTrip(context.Background(), query) txp := saver.WrapDNSTransport(child)
if !errors.Is(err, expected) { if txp.Network() != "x" {
t.Fatal("not the error we expected") t.Fatal("unexpected result")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].Value().DNSResponse != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
} }
}) })
t.Run("on success", func(t *testing.T) { t.Run("Address", func(t *testing.T) {
expected := []byte{0xef, 0xbe, 0xad, 0xde}
saver := &Saver{} saver := &Saver{}
response := &mocks.DNSResponse{ child := &mocks.DNSTransport{
MockBytes: func() []byte {
return expected
},
}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string { MockAddress: func() string {
return "" return "x"
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
}, },
} }
reply, err := txp.RoundTrip(context.Background(), query) txp := saver.WrapDNSTransport(child)
if err != nil { if txp.Address() != "x" {
t.Fatal("we expected nil error here") t.Fatal("unexpected result")
} }
if !bytes.Equal(reply.Bytes(), expected) { })
t.Fatal("expected another reply here")
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
saver := &Saver{}
child := &mocks.DNSTransport{
MockCloseIdleConnections: func() {
called = true
},
} }
ev := saver.Read() txp := saver.WrapDNSTransport(child)
if len(ev) != 2 { txp.CloseIdleConnections()
t.Fatal("expected number of events") if !called {
t.Fatal("not called")
} }
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { })
t.Fatal("unexpected DNSQuery")
t.Run("RequiresPadding", func(t *testing.T) {
saver := &Saver{}
child := &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return true
},
} }
if ev[0].Name() != "dns_round_trip_start" { txp := saver.WrapDNSTransport(child)
t.Fatal("unexpected name") if !txp.RequiresPadding() {
} t.Fatal("unexpected result")
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err.IsNotNil() {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
} }
}) })
} }

View File

@ -12,6 +12,14 @@ import (
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestWrapTLSHandshaker(t *testing.T) {
var saver *Saver
thx := &mocks.TLSHandshaker{}
if saver.WrapTLSHandshaker(thx) != thx {
t.Fatal("unexpected result")
}
}
func TestTLSHandshakerSaver(t *testing.T) { func TestTLSHandshakerSaver(t *testing.T) {
t.Run("Handshake", func(t *testing.T) { t.Run("Handshake", func(t *testing.T) {