diff --git a/internal/netxlite/getaddrinfo_cgo.go b/internal/netxlite/getaddrinfo_cgo.go index 0e7ace5..f061d0d 100644 --- a/internal/netxlite/getaddrinfo_cgo.go +++ b/internal/netxlite/getaddrinfo_cgo.go @@ -207,23 +207,6 @@ func (state *getaddrinfoState) addrinfoToString(r *C.struct_addrinfo) (string, e } } -// staticAddrinfoWithInvalidFamily is an helper to construct an addrinfo struct -// that we use in testing. (We cannot call CGO directly from tests.) -func staticAddrinfoWithInvalidFamily() *C.struct_addrinfo { - var value C.struct_addrinfo // zeroed by Go - value.ai_socktype = C.SOCK_STREAM // this is what the code expects - value.ai_family = 0 // but 0 is not AF_INET{,6} - return &value -} - -// staticAddrinfoWithInvalidSocketType is an helper to construct an addrinfo struct -// that we use in testing. (We cannot call CGO directly from tests.) -func staticAddrinfoWithInvalidSocketType() *C.struct_addrinfo { - var value C.struct_addrinfo // zeroed by Go - value.ai_socktype = C.SOCK_DGRAM // not SOCK_STREAM - return &value -} - // getaddrinfoCopyIP copies a net.IP. // // This function is adapted from copyIP diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index 437b7a9..1599cb4 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -177,29 +177,29 @@ func NewOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode // Ensure we correctly forward CloseIdleConnections. return &httpTransportConnectionsCloser{ - HTTPTransport: &stdlibTransport{&oohttp.StdlibTransport{Transport: txp}}, + HTTPTransport: &httpTransportStdlib{&oohttp.StdlibTransport{Transport: txp}}, Dialer: dialer, TLSDialer: tlsDialer, } } // stdlibTransport wraps oohttp.StdlibTransport to add .Network() -type stdlibTransport struct { +type httpTransportStdlib struct { StdlibTransport *oohttp.StdlibTransport } -var _ model.HTTPTransport = &stdlibTransport{} +var _ model.HTTPTransport = &httpTransportStdlib{} -func (txp *stdlibTransport) CloseIdleConnections() { +func (txp *httpTransportStdlib) CloseIdleConnections() { txp.StdlibTransport.CloseIdleConnections() } -func (txp *stdlibTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) { return txp.StdlibTransport.RoundTrip(req) } // Network implements HTTPTransport.Network. -func (txp *stdlibTransport) Network() string { +func (txp *httpTransportStdlib) Network() string { return "tcp" } diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 56fe4b0..d4946cf 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -272,7 +272,7 @@ func TestNewHTTPTransport(t *testing.T) { if tlsWithReadTimeout.TLSDialer != td { t.Fatal("invalid tls dialer") } - stdlib := connectionsCloser.HTTPTransport.(*stdlibTransport) + stdlib := connectionsCloser.HTTPTransport.(*httpTransportStdlib) if !stdlib.StdlibTransport.ForceAttemptHTTP2 { t.Fatal("invalid ForceAttemptHTTP2") } @@ -292,83 +292,13 @@ func TestNewHTTPTransport(t *testing.T) { } func TestHTTPDialerWithReadTimeout(t *testing.T) { - t.Run("on success", func(t *testing.T) { - var ( - calledWithZeroTime bool - calledWithNonZeroTime bool - ) - origConn := &mocks.Conn{ - MockSetReadDeadline: func(t time.Time) error { - switch t.IsZero() { - case true: - calledWithZeroTime = true - case false: - calledWithNonZeroTime = true - } - return nil - }, - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - } - d := &httpDialerWithReadTimeout{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return origConn, nil - }, - }, - } - ctx := context.Background() - conn, err := d.DialContext(ctx, "", "") - if err != nil { - t.Fatal(err) - } - if _, okay := conn.(*httpConnWithReadTimeout); !okay { - t.Fatal("invalid conn type") - } - if conn.(*httpConnWithReadTimeout).Conn != origConn { - t.Fatal("invalid origin conn") - } - b := make([]byte, 1024) - count, err := conn.Read(b) - if !errors.Is(err, io.EOF) { - t.Fatal("invalid error") - } - if count != 0 { - t.Fatal("invalid count") - } - if !calledWithZeroTime || !calledWithNonZeroTime { - t.Fatal("not called") - } - }) - - t.Run("on failure", func(t *testing.T) { - expected := errors.New("mocked error") - d := &httpDialerWithReadTimeout{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, expected - }, - }, - } - conn, err := d.DialContext(context.Background(), "", "") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - }) -} - -func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { - t.Run("on success", func(t *testing.T) { - var ( - calledWithZeroTime bool - calledWithNonZeroTime bool - ) - origConn := &mocks.TLSConn{ - Conn: mocks.Conn{ + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.Conn{ MockSetReadDeadline: func(t time.Time) error { switch t.IsZero() { case true: @@ -381,81 +311,155 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { MockRead: func(b []byte) (int, error) { return 0, io.EOF }, - }, - } - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return origConn, nil + } + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, }, - }, - } - ctx := context.Background() - conn, err := d.DialTLSContext(ctx, "", "") - if err != nil { - t.Fatal(err) - } - if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { - t.Fatal("invalid conn type") - } - if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { - t.Fatal("invalid origin conn") - } - b := make([]byte, 1024) - count, err := conn.Read(b) - if !errors.Is(err, io.EOF) { - t.Fatal("invalid error") - } - if count != 0 { - t.Fatal("invalid count") - } - if !calledWithZeroTime || !calledWithNonZeroTime { - t.Fatal("not called") - } - }) + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpConnWithReadTimeout).Conn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } + }) - t.Run("on failure", func(t *testing.T) { - expected := errors.New("mocked error") - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, expected + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, }, - }, - } - conn, err := d.DialTLSContext(context.Background(), "", "") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } + } + conn, err := d.DialContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) }) +} - t.Run("with invalid conn type", func(t *testing.T) { - var called bool - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockClose: func() error { - called = true - return nil - }, - }, nil +func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.TLSConn{ + Conn: mocks.Conn{ + MockSetReadDeadline: func(t time.Time) error { + switch t.IsZero() { + case true: + calledWithZeroTime = true + case false: + calledWithNonZeroTime = true + } + return nil + }, + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, }, - }, - } - conn, err := d.DialTLSContext(context.Background(), "", "") - if !errors.Is(err, ErrNotTLSConn) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - if !called { - t.Fatal("not called") - } + } + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialTLSContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + }, + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + + t.Run("with invalid conn type", func(t *testing.T) { + var called bool + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + called = true + return nil + }, + }, nil + }, + }, + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, ErrNotTLSConn) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + if !called { + t.Fatal("not called") + } + }) }) } diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index 0969a40..f989771 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -140,7 +140,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string, pconn.Close() // we own it on failure return nil, err } - return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}, nil + return newQUICConnectionOwnsConn(qconn, pconn), nil } func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context, @@ -183,6 +183,8 @@ type quicDialerHandshakeCompleter struct { Dialer model.QUICDialer } +var _ model.QUICDialer = &quicDialerHandshakeCompleter{} + // DialContext implements model.QUICDialer.DialContext. func (d *quicDialerHandshakeCompleter) DialContext( ctx context.Context, network, address string, @@ -214,6 +216,10 @@ type quicConnectionOwnsConn struct { conn model.UDPLikeConn } +func newQUICConnectionOwnsConn(qconn quic.EarlyConnection, pconn model.UDPLikeConn) *quicConnectionOwnsConn { + return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn} +} + // CloseWithError implements quic.EarlyConnection.CloseWithError. func (qconn *quicConnectionOwnsConn) CloseWithError( code quic.ApplicationErrorCode, reason string) error { diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index aca1860..60c3a77 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -302,6 +302,31 @@ func TestQUICDialerQUICGo(t *testing.T) { t.Fatal("the ServerName field must match") } }) + + t.Run("returns a quicDialerOwnConn in case of success", func(t *testing.T) { + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + fakeconn := &mocks.QUICEarlyConnection{} + systemdialer := quicDialerQUICGo{ + QUICListener: &quicListenerStdlib{}, + mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, + remoteAddr net.Addr, host string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlyConnection, error) { + return fakeconn, nil + }, + } + ctx := context.Background() + qconn, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + connOwner := qconn.(*quicConnectionOwnsConn) + if connOwner.EarlyConnection != fakeconn { + t.Fatal("invalid underlying conn") + } + }) }) } @@ -406,6 +431,33 @@ func TestQUICDialerHandshakeCompleter(t *testing.T) { }) } +func TestQUICConnectionOwnsConn(t *testing.T) { + var ( + quicClose bool + udpClose bool + ) + qconn := &mocks.QUICEarlyConnection{ + MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error { + quicClose = true + return nil + }, + } + pconn := &mocks.UDPLikeConn{ + MockClose: func() error { + udpClose = true + return nil + }, + } + conn := newQUICConnectionOwnsConn(qconn, pconn) + conn.CloseWithError(0, "") + if !quicClose { + t.Fatal("did not call qconn.CloseWithError") + } + if !udpClose { + t.Fatal("did not call pconn.Close") + } +} + func TestQUICDialerResolver(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) { var ( @@ -518,6 +570,27 @@ func TestQUICDialerResolver(t *testing.T) { t.Fatal("gotTLSConfig.ServerName has not been set") } }) + + t.Run("on success", func(t *testing.T) { + expectedQConn := &mocks.QUICEarlyConnection{} + dialer := &quicDialerResolver{ + Resolver: NewResolverStdlib(log.Log), + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return expectedQConn, nil + }, + }} + qconn, err := dialer.DialContext( + context.Background(), "udp", "8.8.4.4:443", + &tls.Config{}, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + if qconn != expectedQConn { + t.Fatal("unexpected underlying qconn") + } + }) }) t.Run("lookup host with address", func(t *testing.T) { diff --git a/internal/netxlite/resolverparallel.go b/internal/netxlite/resolverparallel.go index 8c22bb5..c210290 100644 --- a/internal/netxlite/resolverparallel.go +++ b/internal/netxlite/resolverparallel.go @@ -22,6 +22,8 @@ type ParallelResolver struct { Txp model.DNSTransport } +var _ model.Resolver = &ParallelResolver{} + // UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is // not wrapped and you should wrap if before using it. func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { diff --git a/internal/netxlite/resolverserial.go b/internal/netxlite/resolverserial.go index 9ce54ec..4159779 100644 --- a/internal/netxlite/resolverserial.go +++ b/internal/netxlite/resolverserial.go @@ -33,6 +33,8 @@ type SerialResolver struct { Txp model.DNSTransport } +var _ model.Resolver = &SerialResolver{} + // NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance. func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver { return &SerialResolver{ diff --git a/internal/tracex/event_test.go b/internal/tracex/event_test.go new file mode 100644 index 0000000..57abbc2 --- /dev/null +++ b/internal/tracex/event_test.go @@ -0,0 +1,36 @@ +package tracex + +import "testing" + +func TestUnusedEventsNames(t *testing.T) { + // Tests that we don't break the names of events we're currently + // not getting the name of directly even if they're saved. + + t.Run("EventQUICHandshakeStart", func(t *testing.T) { + ev := &EventQUICHandshakeStart{} + if ev.Name() != "quic_handshake_start" { + t.Fatal("invalid event name") + } + }) + + t.Run("EventQUICHandshakeDone", func(t *testing.T) { + ev := &EventQUICHandshakeDone{} + if ev.Name() != "quic_handshake_done" { + t.Fatal("invalid event name") + } + }) + + t.Run("EventTLSHandshakeStart", func(t *testing.T) { + ev := &EventTLSHandshakeStart{} + if ev.Name() != "tls_handshake_start" { + t.Fatal("invalid event name") + } + }) + + t.Run("EventTLSHandshakeDone", func(t *testing.T) { + ev := &EventTLSHandshakeDone{} + if ev.Name() != "tls_handshake_done" { + t.Fatal("invalid event name") + } + }) +} diff --git a/internal/tracex/quic_test.go b/internal/tracex/quic_test.go index 7a4aac9..4de86c8 100644 --- a/internal/tracex/quic_test.go +++ b/internal/tracex/quic_test.go @@ -177,6 +177,14 @@ func TestQUICDialerSaver(t *testing.T) { }) } +func TestWrapQUICListener(t *testing.T) { + var saver *Saver + ql := &mocks.QUICListener{} + if saver.WrapQUICListener(ql) != ql { + t.Fatal("unexpected result") + } +} + func TestQUICListenerSaver(t *testing.T) { t.Run("on failure", func(t *testing.T) { expected := errors.New("mocked error") diff --git a/internal/tracex/resolver_test.go b/internal/tracex/resolver_test.go index d71785b..58b6295 100644 --- a/internal/tracex/resolver_test.go +++ b/internal/tracex/resolver_test.go @@ -15,219 +15,370 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) +func TestWrapResolver(t *testing.T) { + var saver *Saver + reso := &mocks.Resolver{} + if saver.WrapResolver(reso) != reso { + t.Fatal("unexpected result") + } +} + func TestResolverSaver(t *testing.T) { - t.Run("on failure", func(t *testing.T) { - expected := netxlite.ErrOODNSNoSuchHost + t.Run("LookupHost", func(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := netxlite.ErrOODNSNoSuchHost + saver := &Saver{} + reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) + addrs, err := reso.LookupHost(context.Background(), "www.google.com") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if ev[0].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[0].Name() != "resolve_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if ev[1].Value().Addresses != nil { + t.Fatal("unexpected Addresses") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { + t.Fatal("unexpected Err") + } + if ev[1].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[1].Name() != "resolve_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) + + t.Run("on success", func(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} + saver := &Saver{} + reso := saver.WrapResolver(newFakeResolverWithResult(expected)) + addrs, err := reso.LookupHost(context.Background(), "www.google.com") + if err != nil { + t.Fatal("expected nil error here") + } + if !reflect.DeepEqual(addrs, expected) { + t.Fatal("not the result we expected") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if ev[0].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[0].Name() != "resolve_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { + t.Fatal("unexpected Addresses") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err.IsNotNil() { + t.Fatal("unexpected Err") + } + if ev[1].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[1].Name() != "resolve_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) + }) + + t.Run("Network", func(t *testing.T) { saver := &Saver{} - reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) - addrs, err := reso.LookupHost(context.Background(), "www.google.com") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") + child := &mocks.Resolver{ + MockNetwork: func() string { + return "x" + }, } - if addrs != nil { - t.Fatal("expected nil address here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if ev[0].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[0].Name() != "resolve_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if ev[1].Value().Addresses != nil { - t.Fatal("unexpected Addresses") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { - t.Fatal("unexpected Err") - } - if ev[1].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[1].Name() != "resolve_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") + reso := saver.WrapResolver(child) + if reso.Network() != "x" { + t.Fatal("unexpected result") } }) - t.Run("on success", func(t *testing.T) { - expected := []string{"8.8.8.8", "8.8.4.4"} + t.Run("Address", func(t *testing.T) { saver := &Saver{} - reso := saver.WrapResolver(newFakeResolverWithResult(expected)) - addrs, err := reso.LookupHost(context.Background(), "www.google.com") - if err != nil { - t.Fatal("expected nil error here") + child := &mocks.Resolver{ + MockAddress: func() string { + return "x" + }, } - if !reflect.DeepEqual(addrs, expected) { - t.Fatal("not the result we expected") + reso := saver.WrapResolver(child) + if reso.Address() != "x" { + t.Fatal("unexpected result") } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") + }) + + t.Run("LookupHTTPS", func(t *testing.T) { + expected := errors.New("mocked") + saver := &Saver{} + child := &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return nil, expected + }, } - if ev[0].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") + reso := saver.WrapResolver(child) + https, err := reso.LookupHTTPS(context.Background(), "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) } - if ev[0].Name() != "resolve_start" { - t.Fatal("unexpected name") + if https != nil { + t.Fatal("expected nil") } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") + }) + + t.Run("LookupNS", func(t *testing.T) { + expected := errors.New("mocked") + saver := &Saver{} + child := &mocks.Resolver{ + MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) { + return nil, expected + }, } - if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { - t.Fatal("unexpected Addresses") + reso := saver.WrapResolver(child) + ns, err := reso.LookupNS(context.Background(), "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") + if len(ns) != 0 { + t.Fatal("expected zero length array") } - if ev[1].Value().Err.IsNotNil() { - t.Fatal("unexpected Err") + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + saver := &Saver{} + child := &mocks.Resolver{ + MockCloseIdleConnections: func() { + called = true + }, } - if ev[1].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[1].Name() != "resolve_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") + reso := saver.WrapResolver(child) + reso.CloseIdleConnections() + if !called { + t.Fatal("not called") } }) } +func TestWrapDNSTransport(t *testing.T) { + var saver *Saver + txp := &mocks.DNSTransport{} + if saver.WrapDNSTransport(txp) != txp { + t.Fatal("unexpected result") + } +} + func TestDNSTransportSaver(t *testing.T) { - t.Run("on failure", func(t *testing.T) { - expected := netxlite.ErrOODNSNoSuchHost - saver := &Saver{} - txp := saver.WrapDNSTransport(&mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return nil, expected - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, + t.Run("RoundTrip", func(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := netxlite.ErrOODNSNoSuchHost + saver := &Saver{} + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expected + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } + reply, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[0].Name() != "dns_round_trip_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[1].Value().DNSResponse != nil { + t.Fatal("unexpected DNSReply") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { + t.Fatal("unexpected Err") + } + if ev[1].Name() != "dns_round_trip_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } }) - rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return rawQuery, nil + + t.Run("on success", func(t *testing.T) { + expected := []byte{0xef, 0xbe, 0xad, 0xde} + saver := &Saver{} + response := &mocks.DNSResponse{ + MockBytes: func() []byte { + return expected + }, + } + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return response, nil + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } + reply, err := txp.RoundTrip(context.Background(), query) + if err != nil { + t.Fatal("we expected nil error here") + } + if !bytes.Equal(reply.Bytes(), expected) { + t.Fatal("expected another reply here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[0].Name() != "dns_round_trip_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if !bytes.Equal(ev[1].Value().DNSResponse, expected) { + t.Fatal("unexpected DNSReply") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err.IsNotNil() { + t.Fatal("unexpected Err") + } + if ev[1].Name() != "dns_round_trip_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) + }) + + t.Run("Network", func(t *testing.T) { + saver := &Saver{} + child := &mocks.DNSTransport{ + MockNetwork: func() string { + return "x" }, } - reply, err := txp.RoundTrip(context.Background(), query) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[0].Name() != "dns_round_trip_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[1].Value().DNSResponse != nil { - t.Fatal("unexpected DNSReply") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "dns_round_trip_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") + txp := saver.WrapDNSTransport(child) + if txp.Network() != "x" { + t.Fatal("unexpected result") } }) - t.Run("on success", func(t *testing.T) { - expected := []byte{0xef, 0xbe, 0xad, 0xde} + t.Run("Address", func(t *testing.T) { saver := &Saver{} - response := &mocks.DNSResponse{ - MockBytes: func() []byte { - return expected - }, - } - txp := saver.WrapDNSTransport(&mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return response, nil - }, - MockNetwork: func() string { - return "fake" - }, + child := &mocks.DNSTransport{ MockAddress: func() string { - return "" - }, - }) - rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return rawQuery, nil + return "x" }, } - reply, err := txp.RoundTrip(context.Background(), query) - if err != nil { - t.Fatal("we expected nil error here") + txp := saver.WrapDNSTransport(child) + if txp.Address() != "x" { + t.Fatal("unexpected result") } - if !bytes.Equal(reply.Bytes(), expected) { - t.Fatal("expected another reply here") + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + saver := &Saver{} + child := &mocks.DNSTransport{ + MockCloseIdleConnections: func() { + called = true + }, } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") + txp := saver.WrapDNSTransport(child) + txp.CloseIdleConnections() + if !called { + t.Fatal("not called") } - if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") + }) + + t.Run("RequiresPadding", func(t *testing.T) { + saver := &Saver{} + child := &mocks.DNSTransport{ + MockRequiresPadding: func() bool { + return true + }, } - if ev[0].Name() != "dns_round_trip_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if !bytes.Equal(ev[1].Value().DNSResponse, expected) { - t.Fatal("unexpected DNSReply") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err.IsNotNil() { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "dns_round_trip_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") + txp := saver.WrapDNSTransport(child) + if !txp.RequiresPadding() { + t.Fatal("unexpected result") } }) } diff --git a/internal/tracex/tls_test.go b/internal/tracex/tls_test.go index 76b36d4..765e49d 100644 --- a/internal/tracex/tls_test.go +++ b/internal/tracex/tls_test.go @@ -12,6 +12,14 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +func TestWrapTLSHandshaker(t *testing.T) { + var saver *Saver + thx := &mocks.TLSHandshaker{} + if saver.WrapTLSHandshaker(thx) != thx { + t.Fatal("unexpected result") + } +} + func TestTLSHandshakerSaver(t *testing.T) { t.Run("Handshake", func(t *testing.T) {