package archival

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"net"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/lucas-clemente/quic-go"
	"github.com/marten-seemann/qtls-go1-17" // it's annoying to depend on that
	"github.com/ooni/probe-cli/v3/internal/fakefill"
	"github.com/ooni/probe-cli/v3/internal/model"
	"github.com/ooni/probe-cli/v3/internal/model/mocks"
	"github.com/ooni/probe-cli/v3/internal/netxlite"
)

// TestSaverWriteTo checks that Saver.WriteTo forwards the call to the
// underlying UDP-like conn, returns its results unchanged, and records
// exactly one network event describing the operation.
func TestSaverWriteTo(t *testing.T) {
	// newAddr creates a new net.Addr for testing.
	newAddr := func(endpoint string) net.Addr {
		return &mocks.Addr{
			MockString: func() string {
				return endpoint
			},
			MockNetwork: func() string {
				return "udp"
			},
		}
	}

	// newConn is a helper function for creating a new connection whose
	// WriteTo returns the given byte count and error.
	newConn := func(numBytes int, err error) model.UDPLikeConn {
		return &mocks.UDPLikeConn{
			MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
				// sleep so the recorded event has a measurable duration
				time.Sleep(time.Microsecond)
				return numBytes, err
			},
		}
	}

	t.Run("on success", func(t *testing.T) {
		const mockedEndpoint = "8.8.4.4:443"
		const mockedNumBytes = 128
		addr := newAddr(mockedEndpoint)
		conn := newConn(mockedNumBytes, nil)
		saver := NewSaver()
		v := &SingleNetworkEventValidator{
			ExpectedCount:   mockedNumBytes,
			ExpectedErr:     nil,
			ExpectedNetwork: "udp",
			ExpectedOp:      netxlite.WriteToOperation,
			ExpectedEpnt:    mockedEndpoint,
			Saver:           saver,
		}
		buf := make([]byte, 1024)
		count, err := saver.WriteTo(conn, buf, addr)
		if err != nil {
			t.Fatal(err)
		}
		if count != mockedNumBytes {
			t.Fatal("invalid count")
		}
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})

	t.Run("on failure", func(t *testing.T) {
		const mockedEndpoint = "8.8.4.4:443"
		mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
		addr := newAddr(mockedEndpoint)
		conn := newConn(0, mockedError)
		saver := NewSaver()
		v := &SingleNetworkEventValidator{
			ExpectedCount:   0,
			ExpectedErr:     mockedError,
			ExpectedNetwork: "udp",
			ExpectedOp:      netxlite.WriteToOperation,
			ExpectedEpnt:    mockedEndpoint,
			Saver:           saver,
		}
		buf := make([]byte, 1024)
		count, err := saver.WriteTo(conn, buf, addr)
		if !errors.Is(err, mockedError) {
			t.Fatal("unexpected err", err)
		}
		if count != 0 {
			t.Fatal("invalid count")
		}
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})
}

// TestSaverReadFrom checks that Saver.ReadFrom forwards the call to the
// underlying UDP-like conn, returns its results unchanged, and records
// exactly one network event describing the operation.
func TestSaverReadFrom(t *testing.T) {
	// newAddr creates a new net.Addr for testing.
	newAddr := func(endpoint string) net.Addr {
		return &mocks.Addr{
			MockString: func() string {
				return endpoint
			},
			MockNetwork: func() string {
				return "udp"
			},
		}
	}

	// newConn is a helper function for creating a new connection whose
	// ReadFrom returns the given byte count, peer address and error.
	newConn := func(numBytes int, addr net.Addr, err error) model.UDPLikeConn {
		return &mocks.UDPLikeConn{
			MockReadFrom: func(p []byte) (int, net.Addr, error) {
				// sleep so the recorded event has a measurable duration
				time.Sleep(time.Microsecond)
				return numBytes, addr, err
			},
		}
	}

	t.Run("on success", func(t *testing.T) {
		const mockedEndpoint = "8.8.4.4:443"
		const mockedNumBytes = 128
		expectedAddr := newAddr(mockedEndpoint)
		conn := newConn(mockedNumBytes, expectedAddr, nil)
		saver := NewSaver()
		v := &SingleNetworkEventValidator{
			ExpectedCount:   mockedNumBytes,
			ExpectedErr:     nil,
			ExpectedNetwork: "udp",
			ExpectedOp:      netxlite.ReadFromOperation,
			ExpectedEpnt:    mockedEndpoint,
			Saver:           saver,
		}
		buf := make([]byte, 1024)
		count, addr, err := saver.ReadFrom(conn, buf)
		if err != nil {
			t.Fatal(err)
		}
		if expectedAddr.Network() != addr.Network() {
			t.Fatal("invalid addr.Network")
		}
		if expectedAddr.String() != addr.String() {
			t.Fatal("invalid addr.String")
		}
		if count != mockedNumBytes {
			t.Fatal("invalid count")
		}
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})

	t.Run("on failure", func(t *testing.T) {
		mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
		conn := newConn(0, nil, mockedError)
		saver := NewSaver()
		v := &SingleNetworkEventValidator{
			ExpectedCount:   0,
			ExpectedErr:     mockedError,
			ExpectedNetwork: "udp",
			ExpectedOp:      netxlite.ReadFromOperation,
			ExpectedEpnt:    "",
			Saver:           saver,
		}
		buf := make([]byte, 1024)
		count, addr, err := saver.ReadFrom(conn, buf)
		if !errors.Is(err, mockedError) {
			t.Fatal(err)
		}
		if addr != nil {
			t.Fatal("invalid addr")
		}
		if count != 0 {
			t.Fatal("invalid count")
		}
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})
}

// TestSaverQUICDialContext checks that Saver.QUICDialContext returns the
// dialer's session or error and records a single QUIC handshake event
// in the trace, for success, handshake-timeout and dial-failure cases.
func TestSaverQUICDialContext(t *testing.T) {
	// newQUICDialer creates a new QUICDialer for testing whose
	// DialContext returns the given session and error.
	newQUICDialer := func(sess quic.EarlySession, err error) model.QUICDialer {
		return &mocks.QUICDialer{
			MockDialContext: func(ctx context.Context, network, address string,
				tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
				// sleep so the recorded event has a measurable duration
				time.Sleep(time.Microsecond)
				return sess, err
			},
		}
	}

	// newQUICSession creates a new quic.EarlySession for testing. The
	// handshake is considered complete once handshakeComplete is done, and
	// the session then reports the given TLS connection state.
	newQUICSession := func(handshakeComplete context.Context, state tls.ConnectionState) quic.EarlySession {
		return &mocks.QUICEarlySession{
			MockHandshakeComplete: func() context.Context {
				return handshakeComplete
			},
			MockConnectionState: func() quic.ConnectionState {
				return quic.ConnectionState{
					TLS: qtls.ConnectionStateWith0RTT{
						ConnectionState: state,
					},
				}
			},
			MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
				return nil
			},
		}
	}

	t.Run("on success", func(t *testing.T) {
		handshakeCtx, handshakeCancel := context.WithCancel(context.Background())
		handshakeCancel() // simulate a completed handshake
		const expectedNetwork = "udp"
		const mockedEndpoint = "8.8.4.4:443"
		saver := NewSaver()
		var peerCerts [][]byte
		ff := &fakefill.Filler{}
		ff.Fill(&peerCerts)
		if len(peerCerts) < 1 {
			t.Fatal("did not fill peerCerts")
		}
		v := &SingleQUICTLSHandshakeValidator{
			ExpectedALPN:               []string{"h3"},
			ExpectedSNI:                "dns.google",
			ExpectedSkipVerify:         true,
			ExpectedCipherSuite:        tls.TLS_AES_128_GCM_SHA256,
			ExpectedNegotiatedProtocol: "h3",
			ExpectedPeerCerts:          peerCerts,
			ExpectedVersion:            tls.VersionTLS13,
			ExpectedNetwork:            "quic",
			ExpectedRemoteAddr:         mockedEndpoint,
			// a non-nil QUICConfig makes Validate inspect the QUIC handshakes
			QUICConfig:      &quic.Config{},
			ExpectedFailure: nil,
			Saver:           saver,
		}
		sess := newQUICSession(handshakeCtx, v.NewTLSConnectionState())
		dialer := newQUICDialer(sess, nil)
		ctx := context.Background()
		sess, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
			mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
		if err != nil {
			t.Fatal(err)
		}
		if sess == nil {
			t.Fatal("expected non-nil sess")
		}
		sess.CloseWithError(0, "")
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})

	t.Run("on handshake timeout", func(t *testing.T) {
		// the handshake context is never canceled during the dial, so
		// the handshake never completes and the dial ctx expires first
		handshakeCtx, handshakeCancel := context.WithCancel(context.Background())
		defer handshakeCancel()
		const expectedNetwork = "udp"
		const mockedEndpoint = "8.8.4.4:443"
		saver := NewSaver()
		v := &SingleQUICTLSHandshakeValidator{
			ExpectedALPN:               []string{"h3"},
			ExpectedSNI:                "dns.google",
			ExpectedSkipVerify:         true,
			ExpectedCipherSuite:        0,
			ExpectedNegotiatedProtocol: "",
			ExpectedPeerCerts:          nil,
			ExpectedVersion:            0,
			ExpectedNetwork:            "quic",
			ExpectedRemoteAddr:         mockedEndpoint,
			QUICConfig:                 &quic.Config{},
			ExpectedFailure:            context.DeadlineExceeded,
			Saver:                      saver,
		}
		sess := newQUICSession(handshakeCtx, tls.ConnectionState{})
		dialer := newQUICDialer(sess, nil)
		ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
		defer cancel()
		sess, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
			mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Fatal("unexpected error", err)
		}
		if sess != nil {
			t.Fatal("expected nil sess")
		}
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})

	t.Run("on other error", func(t *testing.T) {
		mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
		const expectedNetwork = "udp"
		const mockedEndpoint = "8.8.4.4:443"
		saver := NewSaver()
		v := &SingleQUICTLSHandshakeValidator{
			ExpectedALPN:               []string{"h3"},
			ExpectedSNI:                "dns.google",
			ExpectedSkipVerify:         true,
			ExpectedCipherSuite:        0,
			ExpectedNegotiatedProtocol: "",
			ExpectedPeerCerts:          nil,
			ExpectedVersion:            0,
			ExpectedNetwork:            "quic",
			ExpectedRemoteAddr:         mockedEndpoint,
			QUICConfig:                 &quic.Config{},
			ExpectedFailure:            mockedError,
			Saver:                      saver,
		}
		dialer := newQUICDialer(nil, mockedError)
		ctx := context.Background()
		sess, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
			mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
		if !errors.Is(err, mockedError) {
			t.Fatal("unexpected error", err)
		}
		if sess != nil {
			t.Fatal("expected nil sess")
		}
		if err := v.Validate(); err != nil {
			t.Fatal(err)
		}
	})

	// TODO(bassosimone): here we're not testing the case in which
	// the certificate is invalid for the required SNI.
	//
	// We need first to figure out whether this is what happens
	// when we validate for QUIC in such cases. If that's the case
	// indeed, then we can write the tests.

	t.Run("on x509.HostnameError", func(t *testing.T) {
		t.Skip("test not implemented")
	})

	t.Run("on x509.UnknownAuthorityError", func(t *testing.T) {
		t.Skip("test not implemented")
	})

	t.Run("on x509.CertificateInvalidError", func(t *testing.T) {
		t.Skip("test not implemented")
	})
}

// SingleQUICTLSHandshakeValidator checks that a Saver's trace contains
// exactly one TLS or QUIC handshake event matching the expected fields.
// It also builds the tls.Config and tls.ConnectionState that the tests
// feed to the code under test, so inputs and expectations stay in sync.
type SingleQUICTLSHandshakeValidator struct {
	// related to the tls.Config
	ExpectedALPN       []string
	ExpectedSNI        string
	ExpectedSkipVerify bool

	// related to the tls.ConnectionState
	ExpectedCipherSuite        uint16
	ExpectedNegotiatedProtocol string
	ExpectedPeerCerts          [][]byte
	ExpectedVersion            uint16

	// related to the mocked conn (TLS) / dial params (QUIC)
	ExpectedNetwork    string
	ExpectedRemoteAddr string

	// tells us whether we're using QUIC: when non-nil, Validate inspects
	// the trace's QUICHandshake entries, otherwise its TLSHandshake ones
	QUICConfig *quic.Config

	// other fields
	ExpectedFailure error
	Saver           *Saver
}

// NewTLSConfig returns a tls.Config whose ALPN, SNI and skip-verify
// settings match the validator's expectations.
func (v *SingleQUICTLSHandshakeValidator) NewTLSConfig() *tls.Config {
	return &tls.Config{
		NextProtos:         v.ExpectedALPN,
		ServerName:         v.ExpectedSNI,
		InsecureSkipVerify: v.ExpectedSkipVerify,
	}
}

// NewTLSConnectionState returns a tls.ConnectionState carrying the expected
// cipher suite, negotiated protocol, version, and peer certificates (each
// raw DER blob wrapped in an x509.Certificate with only Raw set).
func (v *SingleQUICTLSHandshakeValidator) NewTLSConnectionState() tls.ConnectionState {
	state := tls.ConnectionState{
		CipherSuite:        v.ExpectedCipherSuite,
		NegotiatedProtocol: v.ExpectedNegotiatedProtocol,
		Version:            v.ExpectedVersion,
	}
	for _, cert := range v.ExpectedPeerCerts {
		state.PeerCertificates = append(state.PeerCertificates, &x509.Certificate{
			Raw: cert,
		})
	}
	return state
}

// Validate takes the trace out of v.Saver (via MoveOutTrace) and returns
// an error describing the first mismatch between the single recorded
// handshake event and the validator's expectations, or nil if they match.
func (v *SingleQUICTLSHandshakeValidator) Validate() error {
	trace := v.Saver.MoveOutTrace()
	var entries []*QUICTLSHandshakeEvent
	if v.QUICConfig != nil {
		entries = trace.QUICHandshake
	} else {
		entries = trace.TLSHandshake
	}
	if len(entries) != 1 {
		return fmt.Errorf("expected to see a single entry, found %d", len(entries))
	}
	entry := entries[0]
	if diff := cmp.Diff(entry.ALPN, v.ExpectedALPN); diff != "" {
		return fmt.Errorf("unexpected .ALPN (-got +want):\n%s", diff)
	}
	if want := netxlite.TLSCipherSuiteString(v.ExpectedCipherSuite); entry.CipherSuite != want {
		return fmt.Errorf("unexpected .CipherSuite: got %q, want %q", entry.CipherSuite, want)
	}
	if !errors.Is(entry.Failure, v.ExpectedFailure) {
		return fmt.Errorf("unexpected .Failure: got %v, want %v", entry.Failure, v.ExpectedFailure)
	}
	if !entry.Finished.After(entry.Started) {
		return errors.New(".Finished is not after .Started")
	}
	if entry.NegotiatedProto != v.ExpectedNegotiatedProtocol {
		return fmt.Errorf("unexpected .NegotiatedProto: got %q, want %q",
			entry.NegotiatedProto, v.ExpectedNegotiatedProtocol)
	}
	if entry.Network != v.ExpectedNetwork {
		return fmt.Errorf("unexpected .Network: got %q, want %q", entry.Network, v.ExpectedNetwork)
	}
	if diff := cmp.Diff(entry.PeerCerts, v.ExpectedPeerCerts); diff != "" {
		return fmt.Errorf("unexpected .PeerCerts (-got +want):\n%s", diff)
	}
	if entry.RemoteAddr != v.ExpectedRemoteAddr {
		return fmt.Errorf("unexpected .RemoteAddr: got %q, want %q", entry.RemoteAddr, v.ExpectedRemoteAddr)
	}
	if entry.SNI != v.ExpectedSNI {
		return fmt.Errorf("unexpected .SNI: got %q, want %q", entry.SNI, v.ExpectedSNI)
	}
	if entry.SkipVerify != v.ExpectedSkipVerify {
		return fmt.Errorf("unexpected .SkipVerify: got %v, want %v", entry.SkipVerify, v.ExpectedSkipVerify)
	}
	if want := netxlite.TLSVersionString(v.ExpectedVersion); entry.TLSVersion != want {
		return fmt.Errorf("unexpected .Version: got %q, want %q", entry.TLSVersion, want)
	}
	return nil
}