464 lines
12 KiB
Go
464 lines
12 KiB
Go
|
package archival
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/google/go-cmp/cmp"
|
||
|
"github.com/lucas-clemente/quic-go"
|
||
|
"github.com/marten-seemann/qtls-go1-17" // it's annoying to depend on that
|
||
|
"github.com/ooni/probe-cli/v3/internal/fakefill"
|
||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||
|
)
|
||
|
|
||
|
func TestSaverWriteTo(t *testing.T) {
|
||
|
// newAddr creates an new net.Addr for testing.
|
||
|
newAddr := func(endpoint string) net.Addr {
|
||
|
return &mocks.Addr{
|
||
|
MockString: func() string {
|
||
|
return endpoint
|
||
|
},
|
||
|
MockNetwork: func() string {
|
||
|
return "udp"
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// newConn is a helper function for creating a new connection.
|
||
|
newConn := func(numBytes int, err error) model.UDPLikeConn {
|
||
|
return &mocks.UDPLikeConn{
|
||
|
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
|
||
|
time.Sleep(time.Microsecond)
|
||
|
return numBytes, err
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t.Run("on success", func(t *testing.T) {
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
const mockedNumBytes = 128
|
||
|
addr := newAddr(mockedEndpoint)
|
||
|
conn := newConn(mockedNumBytes, nil)
|
||
|
saver := NewSaver()
|
||
|
v := &SingleNetworkEventValidator{
|
||
|
ExpectedCount: mockedNumBytes,
|
||
|
ExpectedErr: nil,
|
||
|
ExpectedNetwork: "udp",
|
||
|
ExpectedOp: netxlite.WriteToOperation,
|
||
|
ExpectedEpnt: mockedEndpoint,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
buf := make([]byte, 1024)
|
||
|
count, err := saver.WriteTo(conn, buf, addr)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if count != mockedNumBytes {
|
||
|
t.Fatal("invalid count")
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on failure", func(t *testing.T) {
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
|
||
|
addr := newAddr(mockedEndpoint)
|
||
|
conn := newConn(0, mockedError)
|
||
|
saver := NewSaver()
|
||
|
v := &SingleNetworkEventValidator{
|
||
|
ExpectedCount: 0,
|
||
|
ExpectedErr: mockedError,
|
||
|
ExpectedNetwork: "udp",
|
||
|
ExpectedOp: netxlite.WriteToOperation,
|
||
|
ExpectedEpnt: mockedEndpoint,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
buf := make([]byte, 1024)
|
||
|
count, err := saver.WriteTo(conn, buf, addr)
|
||
|
if !errors.Is(err, mockedError) {
|
||
|
t.Fatal("unexpected err", err)
|
||
|
}
|
||
|
if count != 0 {
|
||
|
t.Fatal("invalid count")
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestSaverReadFrom(t *testing.T) {
|
||
|
// newAddr creates an new net.Addr for testing.
|
||
|
newAddr := func(endpoint string) net.Addr {
|
||
|
return &mocks.Addr{
|
||
|
MockString: func() string {
|
||
|
return endpoint
|
||
|
},
|
||
|
MockNetwork: func() string {
|
||
|
return "udp"
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// newConn is a helper function for creating a new connection.
|
||
|
newConn := func(numBytes int, addr net.Addr, err error) model.UDPLikeConn {
|
||
|
return &mocks.UDPLikeConn{
|
||
|
MockReadFrom: func(p []byte) (int, net.Addr, error) {
|
||
|
time.Sleep(time.Microsecond)
|
||
|
return numBytes, addr, err
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t.Run("on success", func(t *testing.T) {
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
const mockedNumBytes = 128
|
||
|
expectedAddr := newAddr(mockedEndpoint)
|
||
|
conn := newConn(mockedNumBytes, expectedAddr, nil)
|
||
|
saver := NewSaver()
|
||
|
v := &SingleNetworkEventValidator{
|
||
|
ExpectedCount: mockedNumBytes,
|
||
|
ExpectedErr: nil,
|
||
|
ExpectedNetwork: "udp",
|
||
|
ExpectedOp: netxlite.ReadFromOperation,
|
||
|
ExpectedEpnt: mockedEndpoint,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
buf := make([]byte, 1024)
|
||
|
count, addr, err := saver.ReadFrom(conn, buf)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if expectedAddr.Network() != addr.Network() {
|
||
|
t.Fatal("invalid addr.Network")
|
||
|
}
|
||
|
if expectedAddr.String() != addr.String() {
|
||
|
t.Fatal("invalid addr.String")
|
||
|
}
|
||
|
if count != mockedNumBytes {
|
||
|
t.Fatal("invalid count")
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on failure", func(t *testing.T) {
|
||
|
mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
|
||
|
conn := newConn(0, nil, mockedError)
|
||
|
saver := NewSaver()
|
||
|
v := &SingleNetworkEventValidator{
|
||
|
ExpectedCount: 0,
|
||
|
ExpectedErr: mockedError,
|
||
|
ExpectedNetwork: "udp",
|
||
|
ExpectedOp: netxlite.ReadFromOperation,
|
||
|
ExpectedEpnt: "",
|
||
|
Saver: saver,
|
||
|
}
|
||
|
buf := make([]byte, 1024)
|
||
|
count, addr, err := saver.ReadFrom(conn, buf)
|
||
|
if !errors.Is(err, mockedError) {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if addr != nil {
|
||
|
t.Fatal("invalid addr")
|
||
|
}
|
||
|
if count != 0 {
|
||
|
t.Fatal("invalid count")
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestSaverQUICDialContext(t *testing.T) {
|
||
|
// newQUICDialer creates a new QUICDialer for testing.
|
||
|
newQUICDialer := func(sess quic.EarlySession, err error) model.QUICDialer {
|
||
|
return &mocks.QUICDialer{
|
||
|
MockDialContext: func(
|
||
|
ctx context.Context, network, address string, tlsConfig *tls.Config,
|
||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||
|
time.Sleep(time.Microsecond)
|
||
|
return sess, err
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// newQUICSession creates a new quic.EarlySession for testing.
|
||
|
newQUICSession := func(handshakeComplete context.Context, state tls.ConnectionState) quic.EarlySession {
|
||
|
return &mocks.QUICEarlySession{
|
||
|
MockHandshakeComplete: func() context.Context {
|
||
|
return handshakeComplete
|
||
|
},
|
||
|
MockConnectionState: func() quic.ConnectionState {
|
||
|
return quic.ConnectionState{
|
||
|
TLS: qtls.ConnectionStateWith0RTT{
|
||
|
ConnectionState: state,
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
|
||
|
return nil
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t.Run("on success", func(t *testing.T) {
|
||
|
handshakeCtx := context.Background()
|
||
|
handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx)
|
||
|
handshakeCancel() // simulate a completed handshake
|
||
|
const expectedNetwork = "udp"
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
saver := NewSaver()
|
||
|
var peerCerts [][]byte
|
||
|
ff := &fakefill.Filler{}
|
||
|
ff.Fill(&peerCerts)
|
||
|
if len(peerCerts) < 1 {
|
||
|
t.Fatal("did not fill peerCerts")
|
||
|
}
|
||
|
v := &SingleQUICTLSHandshakeValidator{
|
||
|
ExpectedALPN: []string{"h3"},
|
||
|
ExpectedSNI: "dns.google",
|
||
|
ExpectedSkipVerify: true,
|
||
|
//
|
||
|
ExpectedCipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||
|
ExpectedNegotiatedProtocol: "h3",
|
||
|
ExpectedPeerCerts: peerCerts,
|
||
|
ExpectedVersion: tls.VersionTLS13,
|
||
|
//
|
||
|
ExpectedNetwork: "quic",
|
||
|
ExpectedRemoteAddr: mockedEndpoint,
|
||
|
//
|
||
|
QUICConfig: &quic.Config{},
|
||
|
//
|
||
|
ExpectedFailure: nil,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
sess := newQUICSession(handshakeCtx, v.NewTLSConnectionState())
|
||
|
dialer := newQUICDialer(sess, nil)
|
||
|
ctx := context.Background()
|
||
|
sess, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
|
||
|
mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if sess == nil {
|
||
|
t.Fatal("expected nil sess")
|
||
|
}
|
||
|
sess.CloseWithError(0, "")
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on handshake timeout", func(t *testing.T) {
|
||
|
handshakeCtx := context.Background()
|
||
|
handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx)
|
||
|
defer handshakeCancel()
|
||
|
const expectedNetwork = "udp"
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
saver := NewSaver()
|
||
|
v := &SingleQUICTLSHandshakeValidator{
|
||
|
ExpectedALPN: []string{"h3"},
|
||
|
ExpectedSNI: "dns.google",
|
||
|
ExpectedSkipVerify: true,
|
||
|
//
|
||
|
ExpectedCipherSuite: 0,
|
||
|
ExpectedNegotiatedProtocol: "",
|
||
|
ExpectedPeerCerts: nil,
|
||
|
ExpectedVersion: 0,
|
||
|
//
|
||
|
ExpectedNetwork: "quic",
|
||
|
ExpectedRemoteAddr: mockedEndpoint,
|
||
|
//
|
||
|
QUICConfig: &quic.Config{},
|
||
|
//
|
||
|
ExpectedFailure: context.DeadlineExceeded,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
sess := newQUICSession(handshakeCtx, tls.ConnectionState{})
|
||
|
dialer := newQUICDialer(sess, nil)
|
||
|
ctx := context.Background()
|
||
|
ctx, cancel := context.WithTimeout(ctx, time.Microsecond)
|
||
|
defer cancel()
|
||
|
sess, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
|
||
|
mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
|
||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||
|
t.Fatal("unexpected error")
|
||
|
}
|
||
|
if sess != nil {
|
||
|
t.Fatal("expected nil sess")
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on other error", func(t *testing.T) {
|
||
|
mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
|
||
|
const expectedNetwork = "udp"
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
saver := NewSaver()
|
||
|
v := &SingleQUICTLSHandshakeValidator{
|
||
|
ExpectedALPN: []string{"h3"},
|
||
|
ExpectedSNI: "dns.google",
|
||
|
ExpectedSkipVerify: true,
|
||
|
//
|
||
|
ExpectedCipherSuite: 0,
|
||
|
ExpectedNegotiatedProtocol: "",
|
||
|
ExpectedPeerCerts: nil,
|
||
|
ExpectedVersion: 0,
|
||
|
//
|
||
|
ExpectedNetwork: "quic",
|
||
|
ExpectedRemoteAddr: mockedEndpoint,
|
||
|
//
|
||
|
QUICConfig: &quic.Config{},
|
||
|
//
|
||
|
ExpectedFailure: mockedError,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
dialer := newQUICDialer(nil, mockedError)
|
||
|
ctx := context.Background()
|
||
|
sess, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
|
||
|
mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
|
||
|
if !errors.Is(err, mockedError) {
|
||
|
t.Fatal("unexpected error")
|
||
|
}
|
||
|
if sess != nil {
|
||
|
t.Fatal("expected nil sess")
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
// TODO(bassosimone): here we're not testing the case in which
|
||
|
// the certificate is invalid for the required SNI.
|
||
|
//
|
||
|
// We need first to figure out whether this is what happens
|
||
|
// when we validate for QUIC in such cases. If that's the case
|
||
|
// indeed, then we can write the tests.
|
||
|
|
||
|
t.Run("on x509.HostnameError", func(t *testing.T) {
|
||
|
t.Skip("test not implemented")
|
||
|
})
|
||
|
|
||
|
t.Run("on x509.UnknownAuthorityError", func(t *testing.T) {
|
||
|
t.Skip("test not implemented")
|
||
|
})
|
||
|
|
||
|
t.Run("on x509.CertificateInvalidError", func(t *testing.T) {
|
||
|
t.Skip("test not implemented")
|
||
|
})
|
||
|
}
|
||
|
|
||
|
type SingleQUICTLSHandshakeValidator struct {
|
||
|
// related to the tls.Config
|
||
|
ExpectedALPN []string
|
||
|
ExpectedSNI string
|
||
|
ExpectedSkipVerify bool
|
||
|
|
||
|
// related to the tls.ConnectionState
|
||
|
ExpectedCipherSuite uint16
|
||
|
ExpectedNegotiatedProtocol string
|
||
|
ExpectedPeerCerts [][]byte
|
||
|
ExpectedVersion uint16
|
||
|
|
||
|
// related to the mocked conn (TLS) / dial params (QUIC)
|
||
|
ExpectedNetwork string
|
||
|
ExpectedRemoteAddr string
|
||
|
|
||
|
// tells us whether we're using QUIC
|
||
|
QUICConfig *quic.Config
|
||
|
|
||
|
// other fields
|
||
|
ExpectedFailure error
|
||
|
Saver *Saver
|
||
|
}
|
||
|
|
||
|
func (v *SingleQUICTLSHandshakeValidator) NewTLSConfig() *tls.Config {
|
||
|
return &tls.Config{
|
||
|
NextProtos: v.ExpectedALPN,
|
||
|
ServerName: v.ExpectedSNI,
|
||
|
InsecureSkipVerify: v.ExpectedSkipVerify,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (v *SingleQUICTLSHandshakeValidator) NewTLSConnectionState() tls.ConnectionState {
|
||
|
var state tls.ConnectionState
|
||
|
if v.ExpectedCipherSuite != 0 {
|
||
|
state.CipherSuite = v.ExpectedCipherSuite
|
||
|
}
|
||
|
if v.ExpectedNegotiatedProtocol != "" {
|
||
|
state.NegotiatedProtocol = v.ExpectedNegotiatedProtocol
|
||
|
}
|
||
|
for _, cert := range v.ExpectedPeerCerts {
|
||
|
state.PeerCertificates = append(state.PeerCertificates, &x509.Certificate{
|
||
|
Raw: cert,
|
||
|
})
|
||
|
}
|
||
|
if v.ExpectedVersion != 0 {
|
||
|
state.Version = v.ExpectedVersion
|
||
|
}
|
||
|
return state
|
||
|
}
|
||
|
|
||
|
func (v *SingleQUICTLSHandshakeValidator) Validate() error {
|
||
|
trace := v.Saver.MoveOutTrace()
|
||
|
var entries []*QUICTLSHandshakeEvent
|
||
|
if v.QUICConfig != nil {
|
||
|
entries = trace.QUICHandshake
|
||
|
} else {
|
||
|
entries = trace.TLSHandshake
|
||
|
}
|
||
|
if len(entries) != 1 {
|
||
|
return errors.New("expected to see a single entry")
|
||
|
}
|
||
|
entry := entries[0]
|
||
|
if diff := cmp.Diff(entry.ALPN, v.ExpectedALPN); diff != "" {
|
||
|
return errors.New(diff)
|
||
|
}
|
||
|
if entry.CipherSuite != netxlite.TLSCipherSuiteString(v.ExpectedCipherSuite) {
|
||
|
return errors.New("unexpected .CipherSuite")
|
||
|
}
|
||
|
if !errors.Is(entry.Failure, v.ExpectedFailure) {
|
||
|
return errors.New("unexpected .Failure")
|
||
|
}
|
||
|
if !entry.Finished.After(entry.Started) {
|
||
|
return errors.New(".Finished is not after .Started")
|
||
|
}
|
||
|
if entry.NegotiatedProto != v.ExpectedNegotiatedProtocol {
|
||
|
return errors.New("unexpected .NegotiatedProto")
|
||
|
}
|
||
|
if entry.Network != v.ExpectedNetwork {
|
||
|
return errors.New("unexpected .Network")
|
||
|
}
|
||
|
if diff := cmp.Diff(entry.PeerCerts, v.ExpectedPeerCerts); diff != "" {
|
||
|
return errors.New("unexpected .PeerCerts")
|
||
|
}
|
||
|
if entry.RemoteAddr != v.ExpectedRemoteAddr {
|
||
|
return errors.New("unexpected .RemoteAddr")
|
||
|
}
|
||
|
if entry.SNI != v.ExpectedSNI {
|
||
|
return errors.New("unexpected .ServerName")
|
||
|
}
|
||
|
if entry.SkipVerify != v.ExpectedSkipVerify {
|
||
|
return errors.New("unexpected .SkipVerify")
|
||
|
}
|
||
|
if entry.TLSVersion != netxlite.TLSVersionString(v.ExpectedVersion) {
|
||
|
return errors.New("unexpected .Version")
|
||
|
}
|
||
|
return nil
|
||
|
}
|