chore: improve testing and increase coverage (#794)
This diff improves testing and increases coverage inside the ./internal/netxlite and ./internal/tracex packages. See https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
464d03184e
commit
d5249a6cf7
|
@ -207,23 +207,6 @@ func (state *getaddrinfoState) addrinfoToString(r *C.struct_addrinfo) (string, e
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// staticAddrinfoWithInvalidFamily is an helper to construct an addrinfo struct
|
|
||||||
// that we use in testing. (We cannot call CGO directly from tests.)
|
|
||||||
func staticAddrinfoWithInvalidFamily() *C.struct_addrinfo {
|
|
||||||
var value C.struct_addrinfo // zeroed by Go
|
|
||||||
value.ai_socktype = C.SOCK_STREAM // this is what the code expects
|
|
||||||
value.ai_family = 0 // but 0 is not AF_INET{,6}
|
|
||||||
return &value
|
|
||||||
}
|
|
||||||
|
|
||||||
// staticAddrinfoWithInvalidSocketType is an helper to construct an addrinfo struct
|
|
||||||
// that we use in testing. (We cannot call CGO directly from tests.)
|
|
||||||
func staticAddrinfoWithInvalidSocketType() *C.struct_addrinfo {
|
|
||||||
var value C.struct_addrinfo // zeroed by Go
|
|
||||||
value.ai_socktype = C.SOCK_DGRAM // not SOCK_STREAM
|
|
||||||
return &value
|
|
||||||
}
|
|
||||||
|
|
||||||
// getaddrinfoCopyIP copies a net.IP.
|
// getaddrinfoCopyIP copies a net.IP.
|
||||||
//
|
//
|
||||||
// This function is adapted from copyIP
|
// This function is adapted from copyIP
|
||||||
|
|
|
@ -177,29 +177,29 @@ func NewOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode
|
||||||
|
|
||||||
// Ensure we correctly forward CloseIdleConnections.
|
// Ensure we correctly forward CloseIdleConnections.
|
||||||
return &httpTransportConnectionsCloser{
|
return &httpTransportConnectionsCloser{
|
||||||
HTTPTransport: &stdlibTransport{&oohttp.StdlibTransport{Transport: txp}},
|
HTTPTransport: &httpTransportStdlib{&oohttp.StdlibTransport{Transport: txp}},
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
TLSDialer: tlsDialer,
|
TLSDialer: tlsDialer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// stdlibTransport wraps oohttp.StdlibTransport to add .Network()
|
// stdlibTransport wraps oohttp.StdlibTransport to add .Network()
|
||||||
type stdlibTransport struct {
|
type httpTransportStdlib struct {
|
||||||
StdlibTransport *oohttp.StdlibTransport
|
StdlibTransport *oohttp.StdlibTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.HTTPTransport = &stdlibTransport{}
|
var _ model.HTTPTransport = &httpTransportStdlib{}
|
||||||
|
|
||||||
func (txp *stdlibTransport) CloseIdleConnections() {
|
func (txp *httpTransportStdlib) CloseIdleConnections() {
|
||||||
txp.StdlibTransport.CloseIdleConnections()
|
txp.StdlibTransport.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *stdlibTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return txp.StdlibTransport.RoundTrip(req)
|
return txp.StdlibTransport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network implements HTTPTransport.Network.
|
// Network implements HTTPTransport.Network.
|
||||||
func (txp *stdlibTransport) Network() string {
|
func (txp *httpTransportStdlib) Network() string {
|
||||||
return "tcp"
|
return "tcp"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -272,7 +272,7 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||||
if tlsWithReadTimeout.TLSDialer != td {
|
if tlsWithReadTimeout.TLSDialer != td {
|
||||||
t.Fatal("invalid tls dialer")
|
t.Fatal("invalid tls dialer")
|
||||||
}
|
}
|
||||||
stdlib := connectionsCloser.HTTPTransport.(*stdlibTransport)
|
stdlib := connectionsCloser.HTTPTransport.(*httpTransportStdlib)
|
||||||
if !stdlib.StdlibTransport.ForceAttemptHTTP2 {
|
if !stdlib.StdlibTransport.ForceAttemptHTTP2 {
|
||||||
t.Fatal("invalid ForceAttemptHTTP2")
|
t.Fatal("invalid ForceAttemptHTTP2")
|
||||||
}
|
}
|
||||||
|
@ -292,83 +292,13 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPDialerWithReadTimeout(t *testing.T) {
|
func TestHTTPDialerWithReadTimeout(t *testing.T) {
|
||||||
t.Run("on success", func(t *testing.T) {
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
var (
|
t.Run("on success", func(t *testing.T) {
|
||||||
calledWithZeroTime bool
|
var (
|
||||||
calledWithNonZeroTime bool
|
calledWithZeroTime bool
|
||||||
)
|
calledWithNonZeroTime bool
|
||||||
origConn := &mocks.Conn{
|
)
|
||||||
MockSetReadDeadline: func(t time.Time) error {
|
origConn := &mocks.Conn{
|
||||||
switch t.IsZero() {
|
|
||||||
case true:
|
|
||||||
calledWithZeroTime = true
|
|
||||||
case false:
|
|
||||||
calledWithNonZeroTime = true
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
MockRead: func(b []byte) (int, error) {
|
|
||||||
return 0, io.EOF
|
|
||||||
},
|
|
||||||
}
|
|
||||||
d := &httpDialerWithReadTimeout{
|
|
||||||
Dialer: &mocks.Dialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return origConn, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
conn, err := d.DialContext(ctx, "", "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
|
|
||||||
t.Fatal("invalid conn type")
|
|
||||||
}
|
|
||||||
if conn.(*httpConnWithReadTimeout).Conn != origConn {
|
|
||||||
t.Fatal("invalid origin conn")
|
|
||||||
}
|
|
||||||
b := make([]byte, 1024)
|
|
||||||
count, err := conn.Read(b)
|
|
||||||
if !errors.Is(err, io.EOF) {
|
|
||||||
t.Fatal("invalid error")
|
|
||||||
}
|
|
||||||
if count != 0 {
|
|
||||||
t.Fatal("invalid count")
|
|
||||||
}
|
|
||||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("on failure", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
d := &httpDialerWithReadTimeout{
|
|
||||||
Dialer: &mocks.Dialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn, err := d.DialContext(context.Background(), "", "")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil conn here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
|
||||||
t.Run("on success", func(t *testing.T) {
|
|
||||||
var (
|
|
||||||
calledWithZeroTime bool
|
|
||||||
calledWithNonZeroTime bool
|
|
||||||
)
|
|
||||||
origConn := &mocks.TLSConn{
|
|
||||||
Conn: mocks.Conn{
|
|
||||||
MockSetReadDeadline: func(t time.Time) error {
|
MockSetReadDeadline: func(t time.Time) error {
|
||||||
switch t.IsZero() {
|
switch t.IsZero() {
|
||||||
case true:
|
case true:
|
||||||
|
@ -381,81 +311,155 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||||
MockRead: func(b []byte) (int, error) {
|
MockRead: func(b []byte) (int, error) {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
}
|
d := &httpDialerWithReadTimeout{
|
||||||
d := &httpTLSDialerWithReadTimeout{
|
Dialer: &mocks.Dialer{
|
||||||
TLSDialer: &mocks.TLSDialer{
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
return origConn, nil
|
||||||
return origConn, nil
|
},
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
}
|
ctx := context.Background()
|
||||||
ctx := context.Background()
|
conn, err := d.DialContext(ctx, "", "")
|
||||||
conn, err := d.DialTLSContext(ctx, "", "")
|
if err != nil {
|
||||||
if err != nil {
|
t.Fatal(err)
|
||||||
t.Fatal(err)
|
}
|
||||||
}
|
if _, okay := conn.(*httpConnWithReadTimeout); !okay {
|
||||||
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
|
t.Fatal("invalid conn type")
|
||||||
t.Fatal("invalid conn type")
|
}
|
||||||
}
|
if conn.(*httpConnWithReadTimeout).Conn != origConn {
|
||||||
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
|
t.Fatal("invalid origin conn")
|
||||||
t.Fatal("invalid origin conn")
|
}
|
||||||
}
|
b := make([]byte, 1024)
|
||||||
b := make([]byte, 1024)
|
count, err := conn.Read(b)
|
||||||
count, err := conn.Read(b)
|
if !errors.Is(err, io.EOF) {
|
||||||
if !errors.Is(err, io.EOF) {
|
t.Fatal("invalid error")
|
||||||
t.Fatal("invalid error")
|
}
|
||||||
}
|
if count != 0 {
|
||||||
if count != 0 {
|
t.Fatal("invalid count")
|
||||||
t.Fatal("invalid count")
|
}
|
||||||
}
|
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||||
if !calledWithZeroTime || !calledWithNonZeroTime {
|
t.Fatal("not called")
|
||||||
t.Fatal("not called")
|
}
|
||||||
}
|
})
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("on failure", func(t *testing.T) {
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
d := &httpTLSDialerWithReadTimeout{
|
d := &httpDialerWithReadTimeout{
|
||||||
TLSDialer: &mocks.TLSDialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
}
|
conn, err := d.DialContext(context.Background(), "", "")
|
||||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
if !errors.Is(err, expected) {
|
||||||
if !errors.Is(err, expected) {
|
t.Fatal("not the error we expected")
|
||||||
t.Fatal("not the error we expected")
|
}
|
||||||
}
|
if conn != nil {
|
||||||
if conn != nil {
|
t.Fatal("expected nil conn here")
|
||||||
t.Fatal("expected nil conn here")
|
}
|
||||||
}
|
})
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("with invalid conn type", func(t *testing.T) {
|
func TestHTTPTLSDialerWithReadTimeout(t *testing.T) {
|
||||||
var called bool
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
d := &httpTLSDialerWithReadTimeout{
|
t.Run("on success", func(t *testing.T) {
|
||||||
TLSDialer: &mocks.TLSDialer{
|
var (
|
||||||
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
calledWithZeroTime bool
|
||||||
return &mocks.Conn{
|
calledWithNonZeroTime bool
|
||||||
MockClose: func() error {
|
)
|
||||||
called = true
|
origConn := &mocks.TLSConn{
|
||||||
return nil
|
Conn: mocks.Conn{
|
||||||
},
|
MockSetReadDeadline: func(t time.Time) error {
|
||||||
}, nil
|
switch t.IsZero() {
|
||||||
|
case true:
|
||||||
|
calledWithZeroTime = true
|
||||||
|
case false:
|
||||||
|
calledWithNonZeroTime = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
}
|
d := &httpTLSDialerWithReadTimeout{
|
||||||
conn, err := d.DialTLSContext(context.Background(), "", "")
|
TLSDialer: &mocks.TLSDialer{
|
||||||
if !errors.Is(err, ErrNotTLSConn) {
|
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
t.Fatal("not the error we expected")
|
return origConn, nil
|
||||||
}
|
},
|
||||||
if conn != nil {
|
},
|
||||||
t.Fatal("expected nil conn here")
|
}
|
||||||
}
|
ctx := context.Background()
|
||||||
if !called {
|
conn, err := d.DialTLSContext(ctx, "", "")
|
||||||
t.Fatal("not called")
|
if err != nil {
|
||||||
}
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay {
|
||||||
|
t.Fatal("invalid conn type")
|
||||||
|
}
|
||||||
|
if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn {
|
||||||
|
t.Fatal("invalid origin conn")
|
||||||
|
}
|
||||||
|
b := make([]byte, 1024)
|
||||||
|
count, err := conn.Read(b)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("invalid error")
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if !calledWithZeroTime || !calledWithNonZeroTime {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
d := &httpTLSDialerWithReadTimeout{
|
||||||
|
TLSDialer: &mocks.TLSDialer{
|
||||||
|
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with invalid conn type", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
d := &httpTLSDialerWithReadTimeout{
|
||||||
|
TLSDialer: &mocks.TLSDialer{
|
||||||
|
MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return &mocks.Conn{
|
||||||
|
MockClose: func() error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := d.DialTLSContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, ErrNotTLSConn) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string,
|
||||||
pconn.Close() // we own it on failure
|
pconn.Close() // we own it on failure
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}, nil
|
return newQUICConnectionOwnsConn(qconn, pconn), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
|
func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
|
||||||
|
@ -183,6 +183,8 @@ type quicDialerHandshakeCompleter struct {
|
||||||
Dialer model.QUICDialer
|
Dialer model.QUICDialer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ model.QUICDialer = &quicDialerHandshakeCompleter{}
|
||||||
|
|
||||||
// DialContext implements model.QUICDialer.DialContext.
|
// DialContext implements model.QUICDialer.DialContext.
|
||||||
func (d *quicDialerHandshakeCompleter) DialContext(
|
func (d *quicDialerHandshakeCompleter) DialContext(
|
||||||
ctx context.Context, network, address string,
|
ctx context.Context, network, address string,
|
||||||
|
@ -214,6 +216,10 @@ type quicConnectionOwnsConn struct {
|
||||||
conn model.UDPLikeConn
|
conn model.UDPLikeConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newQUICConnectionOwnsConn(qconn quic.EarlyConnection, pconn model.UDPLikeConn) *quicConnectionOwnsConn {
|
||||||
|
return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}
|
||||||
|
}
|
||||||
|
|
||||||
// CloseWithError implements quic.EarlyConnection.CloseWithError.
|
// CloseWithError implements quic.EarlyConnection.CloseWithError.
|
||||||
func (qconn *quicConnectionOwnsConn) CloseWithError(
|
func (qconn *quicConnectionOwnsConn) CloseWithError(
|
||||||
code quic.ApplicationErrorCode, reason string) error {
|
code quic.ApplicationErrorCode, reason string) error {
|
||||||
|
|
|
@ -302,6 +302,31 @@ func TestQUICDialerQUICGo(t *testing.T) {
|
||||||
t.Fatal("the ServerName field must match")
|
t.Fatal("the ServerName field must match")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("returns a quicDialerOwnConn in case of success", func(t *testing.T) {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
fakeconn := &mocks.QUICEarlyConnection{}
|
||||||
|
systemdialer := quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
return fakeconn, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
qconn, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
connOwner := qconn.(*quicConnectionOwnsConn)
|
||||||
|
if connOwner.EarlyConnection != fakeconn {
|
||||||
|
t.Fatal("invalid underlying conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -406,6 +431,33 @@ func TestQUICDialerHandshakeCompleter(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQUICConnectionOwnsConn(t *testing.T) {
|
||||||
|
var (
|
||||||
|
quicClose bool
|
||||||
|
udpClose bool
|
||||||
|
)
|
||||||
|
qconn := &mocks.QUICEarlyConnection{
|
||||||
|
MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
|
||||||
|
quicClose = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pconn := &mocks.UDPLikeConn{
|
||||||
|
MockClose: func() error {
|
||||||
|
udpClose = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn := newQUICConnectionOwnsConn(qconn, pconn)
|
||||||
|
conn.CloseWithError(0, "")
|
||||||
|
if !quicClose {
|
||||||
|
t.Fatal("did not call qconn.CloseWithError")
|
||||||
|
}
|
||||||
|
if !udpClose {
|
||||||
|
t.Fatal("did not call pconn.Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQUICDialerResolver(t *testing.T) {
|
func TestQUICDialerResolver(t *testing.T) {
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
|
@ -518,6 +570,27 @@ func TestQUICDialerResolver(t *testing.T) {
|
||||||
t.Fatal("gotTLSConfig.ServerName has not been set")
|
t.Fatal("gotTLSConfig.ServerName has not been set")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
expectedQConn := &mocks.QUICEarlyConnection{}
|
||||||
|
dialer := &quicDialerResolver{
|
||||||
|
Resolver: NewResolverStdlib(log.Log),
|
||||||
|
Dialer: &mocks.QUICDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
return expectedQConn, nil
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
qconn, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "8.8.4.4:443",
|
||||||
|
&tls.Config{}, &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if qconn != expectedQConn {
|
||||||
|
t.Fatal("unexpected underlying qconn")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("lookup host with address", func(t *testing.T) {
|
t.Run("lookup host with address", func(t *testing.T) {
|
||||||
|
|
|
@ -22,6 +22,8 @@ type ParallelResolver struct {
|
||||||
Txp model.DNSTransport
|
Txp model.DNSTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ model.Resolver = &ParallelResolver{}
|
||||||
|
|
||||||
// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is
|
// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is
|
||||||
// not wrapped and you should wrap if before using it.
|
// not wrapped and you should wrap if before using it.
|
||||||
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
||||||
|
|
|
@ -33,6 +33,8 @@ type SerialResolver struct {
|
||||||
Txp model.DNSTransport
|
Txp model.DNSTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ model.Resolver = &SerialResolver{}
|
||||||
|
|
||||||
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
|
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
|
||||||
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
|
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
|
||||||
return &SerialResolver{
|
return &SerialResolver{
|
||||||
|
|
36
internal/tracex/event_test.go
Normal file
36
internal/tracex/event_test.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package tracex
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestUnusedEventsNames(t *testing.T) {
|
||||||
|
// Tests that we don't break the names of events we're currently
|
||||||
|
// not getting the name of directly even if they're saved.
|
||||||
|
|
||||||
|
t.Run("EventQUICHandshakeStart", func(t *testing.T) {
|
||||||
|
ev := &EventQUICHandshakeStart{}
|
||||||
|
if ev.Name() != "quic_handshake_start" {
|
||||||
|
t.Fatal("invalid event name")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EventQUICHandshakeDone", func(t *testing.T) {
|
||||||
|
ev := &EventQUICHandshakeDone{}
|
||||||
|
if ev.Name() != "quic_handshake_done" {
|
||||||
|
t.Fatal("invalid event name")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EventTLSHandshakeStart", func(t *testing.T) {
|
||||||
|
ev := &EventTLSHandshakeStart{}
|
||||||
|
if ev.Name() != "tls_handshake_start" {
|
||||||
|
t.Fatal("invalid event name")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("EventTLSHandshakeDone", func(t *testing.T) {
|
||||||
|
ev := &EventTLSHandshakeDone{}
|
||||||
|
if ev.Name() != "tls_handshake_done" {
|
||||||
|
t.Fatal("invalid event name")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -177,6 +177,14 @@ func TestQUICDialerSaver(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWrapQUICListener(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
ql := &mocks.QUICListener{}
|
||||||
|
if saver.WrapQUICListener(ql) != ql {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQUICListenerSaver(t *testing.T) {
|
func TestQUICListenerSaver(t *testing.T) {
|
||||||
t.Run("on failure", func(t *testing.T) {
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
|
|
|
@ -15,219 +15,370 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestWrapResolver(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
reso := &mocks.Resolver{}
|
||||||
|
if saver.WrapResolver(reso) != reso {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResolverSaver(t *testing.T) {
|
func TestResolverSaver(t *testing.T) {
|
||||||
t.Run("on failure", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
expected := netxlite.ErrOODNSNoSuchHost
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := netxlite.ErrOODNSNoSuchHost
|
||||||
|
saver := &Saver{}
|
||||||
|
reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected))
|
||||||
|
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("expected nil address here")
|
||||||
|
}
|
||||||
|
ev := saver.Read()
|
||||||
|
if len(ev) != 2 {
|
||||||
|
t.Fatal("expected number of events")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Hostname != "www.google.com" {
|
||||||
|
t.Fatal("unexpected Hostname")
|
||||||
|
}
|
||||||
|
if ev[0].Name() != "resolve_start" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Addresses != nil {
|
||||||
|
t.Fatal("unexpected Addresses")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Duration <= 0 {
|
||||||
|
t.Fatal("unexpected Duration")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
|
||||||
|
t.Fatal("unexpected Err")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Hostname != "www.google.com" {
|
||||||
|
t.Fatal("unexpected Hostname")
|
||||||
|
}
|
||||||
|
if ev[1].Name() != "resolve_done" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||||
|
saver := &Saver{}
|
||||||
|
reso := saver.WrapResolver(newFakeResolverWithResult(expected))
|
||||||
|
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("expected nil error here")
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(addrs, expected) {
|
||||||
|
t.Fatal("not the result we expected")
|
||||||
|
}
|
||||||
|
ev := saver.Read()
|
||||||
|
if len(ev) != 2 {
|
||||||
|
t.Fatal("expected number of events")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Hostname != "www.google.com" {
|
||||||
|
t.Fatal("unexpected Hostname")
|
||||||
|
}
|
||||||
|
if ev[0].Name() != "resolve_start" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) {
|
||||||
|
t.Fatal("unexpected Addresses")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Duration <= 0 {
|
||||||
|
t.Fatal("unexpected Duration")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Err.IsNotNil() {
|
||||||
|
t.Fatal("unexpected Err")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Hostname != "www.google.com" {
|
||||||
|
t.Fatal("unexpected Hostname")
|
||||||
|
}
|
||||||
|
if ev[1].Name() != "resolve_done" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network", func(t *testing.T) {
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected))
|
child := &mocks.Resolver{
|
||||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
MockNetwork: func() string {
|
||||||
if !errors.Is(err, expected) {
|
return "x"
|
||||||
t.Fatal("not the error we expected")
|
},
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
reso := saver.WrapResolver(child)
|
||||||
t.Fatal("expected nil address here")
|
if reso.Network() != "x" {
|
||||||
}
|
t.Fatal("unexpected result")
|
||||||
ev := saver.Read()
|
|
||||||
if len(ev) != 2 {
|
|
||||||
t.Fatal("expected number of events")
|
|
||||||
}
|
|
||||||
if ev[0].Value().Hostname != "www.google.com" {
|
|
||||||
t.Fatal("unexpected Hostname")
|
|
||||||
}
|
|
||||||
if ev[0].Name() != "resolve_start" {
|
|
||||||
t.Fatal("unexpected name")
|
|
||||||
}
|
|
||||||
if !ev[0].Value().Time.Before(time.Now()) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Addresses != nil {
|
|
||||||
t.Fatal("unexpected Addresses")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Hostname != "www.google.com" {
|
|
||||||
t.Fatal("unexpected Hostname")
|
|
||||||
}
|
|
||||||
if ev[1].Name() != "resolve_done" {
|
|
||||||
t.Fatal("unexpected name")
|
|
||||||
}
|
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("on success", func(t *testing.T) {
|
t.Run("Address", func(t *testing.T) {
|
||||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
reso := saver.WrapResolver(newFakeResolverWithResult(expected))
|
child := &mocks.Resolver{
|
||||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
MockAddress: func() string {
|
||||||
if err != nil {
|
return "x"
|
||||||
t.Fatal("expected nil error here")
|
},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(addrs, expected) {
|
reso := saver.WrapResolver(child)
|
||||||
t.Fatal("not the result we expected")
|
if reso.Address() != "x" {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
ev := saver.Read()
|
})
|
||||||
if len(ev) != 2 {
|
|
||||||
t.Fatal("expected number of events")
|
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked")
|
||||||
|
saver := &Saver{}
|
||||||
|
child := &mocks.Resolver{
|
||||||
|
MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if ev[0].Value().Hostname != "www.google.com" {
|
reso := saver.WrapResolver(child)
|
||||||
t.Fatal("unexpected Hostname")
|
https, err := reso.LookupHTTPS(context.Background(), "dns.google")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if ev[0].Name() != "resolve_start" {
|
if https != nil {
|
||||||
t.Fatal("unexpected name")
|
t.Fatal("expected nil")
|
||||||
}
|
}
|
||||||
if !ev[0].Value().Time.Before(time.Now()) {
|
})
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
|
t.Run("LookupNS", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked")
|
||||||
|
saver := &Saver{}
|
||||||
|
child := &mocks.Resolver{
|
||||||
|
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) {
|
reso := saver.WrapResolver(child)
|
||||||
t.Fatal("unexpected Addresses")
|
ns, err := reso.LookupNS(context.Background(), "dns.google")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if ev[1].Value().Duration <= 0 {
|
if len(ns) != 0 {
|
||||||
t.Fatal("unexpected Duration")
|
t.Fatal("expected zero length array")
|
||||||
}
|
}
|
||||||
if ev[1].Value().Err.IsNotNil() {
|
})
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
saver := &Saver{}
|
||||||
|
child := &mocks.Resolver{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if ev[1].Value().Hostname != "www.google.com" {
|
reso := saver.WrapResolver(child)
|
||||||
t.Fatal("unexpected Hostname")
|
reso.CloseIdleConnections()
|
||||||
}
|
if !called {
|
||||||
if ev[1].Name() != "resolve_done" {
|
t.Fatal("not called")
|
||||||
t.Fatal("unexpected name")
|
|
||||||
}
|
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWrapDNSTransport(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
txp := &mocks.DNSTransport{}
|
||||||
|
if saver.WrapDNSTransport(txp) != txp {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSTransportSaver(t *testing.T) {
|
func TestDNSTransportSaver(t *testing.T) {
|
||||||
t.Run("on failure", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
expected := netxlite.ErrOODNSNoSuchHost
|
t.Run("on failure", func(t *testing.T) {
|
||||||
saver := &Saver{}
|
expected := netxlite.ErrOODNSNoSuchHost
|
||||||
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
|
saver := &Saver{}
|
||||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
|
||||||
return nil, expected
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
},
|
return nil, expected
|
||||||
MockNetwork: func() string {
|
},
|
||||||
return "fake"
|
MockNetwork: func() string {
|
||||||
},
|
return "fake"
|
||||||
MockAddress: func() string {
|
},
|
||||||
return ""
|
MockAddress: func() string {
|
||||||
},
|
return ""
|
||||||
|
},
|
||||||
|
})
|
||||||
|
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return rawQuery, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reply, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if reply != nil {
|
||||||
|
t.Fatal("expected nil reply here")
|
||||||
|
}
|
||||||
|
ev := saver.Read()
|
||||||
|
if len(ev) != 2 {
|
||||||
|
t.Fatal("expected number of events")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
|
||||||
|
t.Fatal("unexpected DNSQuery")
|
||||||
|
}
|
||||||
|
if ev[0].Name() != "dns_round_trip_start" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
|
||||||
|
t.Fatal("unexpected DNSQuery")
|
||||||
|
}
|
||||||
|
if ev[1].Value().DNSResponse != nil {
|
||||||
|
t.Fatal("unexpected DNSReply")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Duration <= 0 {
|
||||||
|
t.Fatal("unexpected Duration")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
|
||||||
|
t.Fatal("unexpected Err")
|
||||||
|
}
|
||||||
|
if ev[1].Name() != "dns_round_trip_done" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
|
||||||
query := &mocks.DNSQuery{
|
t.Run("on success", func(t *testing.T) {
|
||||||
MockBytes: func() ([]byte, error) {
|
expected := []byte{0xef, 0xbe, 0xad, 0xde}
|
||||||
return rawQuery, nil
|
saver := &Saver{}
|
||||||
|
response := &mocks.DNSResponse{
|
||||||
|
MockBytes: func() []byte {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return response, nil
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "fake"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return ""
|
||||||
|
},
|
||||||
|
})
|
||||||
|
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return rawQuery, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
reply, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("we expected nil error here")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(reply.Bytes(), expected) {
|
||||||
|
t.Fatal("expected another reply here")
|
||||||
|
}
|
||||||
|
ev := saver.Read()
|
||||||
|
if len(ev) != 2 {
|
||||||
|
t.Fatal("expected number of events")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
|
||||||
|
t.Fatal("unexpected DNSQuery")
|
||||||
|
}
|
||||||
|
if ev[0].Name() != "dns_round_trip_start" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
|
||||||
|
t.Fatal("unexpected DNSQuery")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
|
||||||
|
t.Fatal("unexpected DNSReply")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Duration <= 0 {
|
||||||
|
t.Fatal("unexpected Duration")
|
||||||
|
}
|
||||||
|
if ev[1].Value().Err.IsNotNil() {
|
||||||
|
t.Fatal("unexpected Err")
|
||||||
|
}
|
||||||
|
if ev[1].Name() != "dns_round_trip_done" {
|
||||||
|
t.Fatal("unexpected name")
|
||||||
|
}
|
||||||
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
|
t.Fatal("the saved time is wrong")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "x"
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
reply, err := txp.RoundTrip(context.Background(), query)
|
txp := saver.WrapDNSTransport(child)
|
||||||
if !errors.Is(err, expected) {
|
if txp.Network() != "x" {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("unexpected result")
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply here")
|
|
||||||
}
|
|
||||||
ev := saver.Read()
|
|
||||||
if len(ev) != 2 {
|
|
||||||
t.Fatal("expected number of events")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
|
|
||||||
t.Fatal("unexpected DNSQuery")
|
|
||||||
}
|
|
||||||
if ev[0].Name() != "dns_round_trip_start" {
|
|
||||||
t.Fatal("unexpected name")
|
|
||||||
}
|
|
||||||
if !ev[0].Value().Time.Before(time.Now()) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
|
|
||||||
t.Fatal("unexpected DNSQuery")
|
|
||||||
}
|
|
||||||
if ev[1].Value().DNSResponse != nil {
|
|
||||||
t.Fatal("unexpected DNSReply")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[1].Name() != "dns_round_trip_done" {
|
|
||||||
t.Fatal("unexpected name")
|
|
||||||
}
|
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("on success", func(t *testing.T) {
|
t.Run("Address", func(t *testing.T) {
|
||||||
expected := []byte{0xef, 0xbe, 0xad, 0xde}
|
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
response := &mocks.DNSResponse{
|
child := &mocks.DNSTransport{
|
||||||
MockBytes: func() []byte {
|
|
||||||
return expected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
|
|
||||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
|
||||||
return response, nil
|
|
||||||
},
|
|
||||||
MockNetwork: func() string {
|
|
||||||
return "fake"
|
|
||||||
},
|
|
||||||
MockAddress: func() string {
|
MockAddress: func() string {
|
||||||
return ""
|
return "x"
|
||||||
},
|
|
||||||
})
|
|
||||||
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
|
||||||
query := &mocks.DNSQuery{
|
|
||||||
MockBytes: func() ([]byte, error) {
|
|
||||||
return rawQuery, nil
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
reply, err := txp.RoundTrip(context.Background(), query)
|
txp := saver.WrapDNSTransport(child)
|
||||||
if err != nil {
|
if txp.Address() != "x" {
|
||||||
t.Fatal("we expected nil error here")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(reply.Bytes(), expected) {
|
})
|
||||||
t.Fatal("expected another reply here")
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
saver := &Saver{}
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
}
|
}
|
||||||
ev := saver.Read()
|
txp := saver.WrapDNSTransport(child)
|
||||||
if len(ev) != 2 {
|
txp.CloseIdleConnections()
|
||||||
t.Fatal("expected number of events")
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
|
})
|
||||||
t.Fatal("unexpected DNSQuery")
|
|
||||||
|
t.Run("RequiresPadding", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if ev[0].Name() != "dns_round_trip_start" {
|
txp := saver.WrapDNSTransport(child)
|
||||||
t.Fatal("unexpected name")
|
if !txp.RequiresPadding() {
|
||||||
}
|
t.Fatal("unexpected result")
|
||||||
if !ev[0].Value().Time.Before(time.Now()) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
|
|
||||||
t.Fatal("unexpected DNSQuery")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
|
|
||||||
t.Fatal("unexpected DNSReply")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Err.IsNotNil() {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[1].Name() != "dns_round_trip_done" {
|
|
||||||
t.Fatal("unexpected name")
|
|
||||||
}
|
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
|
||||||
t.Fatal("the saved time is wrong")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,14 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestWrapTLSHandshaker(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
thx := &mocks.TLSHandshaker{}
|
||||||
|
if saver.WrapTLSHandshaker(thx) != thx {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerSaver(t *testing.T) {
|
func TestTLSHandshakerSaver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Handshake", func(t *testing.T) {
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user