179 lines
4.8 KiB
Go
179 lines
4.8 KiB
Go
|
package archival
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/google/go-cmp/cmp"
|
||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||
|
"github.com/ooni/probe-cli/v3/internal/fakefill"
|
||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||
|
)
|
||
|
|
||
|
func TestSaverTLSHandshake(t *testing.T) {
|
||
|
// newTLSHandshaker helps with building a TLS handshaker
|
||
|
newTLSHandshaker := func(tlsConn net.Conn, state tls.ConnectionState, err error) model.TLSHandshaker {
|
||
|
return &mocks.TLSHandshaker{
|
||
|
MockHandshake: func(ctx context.Context, tcpConn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||
|
time.Sleep(1 * time.Microsecond)
|
||
|
return tlsConn, state, err
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// newTCPConn creates a suitable net.Conn
|
||
|
newTCPConn := func(address string) net.Conn {
|
||
|
return &mocks.Conn{
|
||
|
MockRemoteAddr: func() net.Addr {
|
||
|
return &mocks.Addr{
|
||
|
MockString: func() string {
|
||
|
return address
|
||
|
},
|
||
|
MockNetwork: func() string {
|
||
|
return "tcp"
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
MockClose: func() error {
|
||
|
return nil
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
t.Run("on success", func(t *testing.T) {
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
var certs [][]byte
|
||
|
ff := &fakefill.Filler{}
|
||
|
ff.Fill(&certs)
|
||
|
if len(certs) < 1 {
|
||
|
t.Fatal("did not fill certs")
|
||
|
}
|
||
|
saver := NewSaver()
|
||
|
v := &SingleQUICTLSHandshakeValidator{
|
||
|
ExpectedALPN: []string{"h2", "http/1.1"},
|
||
|
ExpectedSNI: "dns.google",
|
||
|
ExpectedSkipVerify: true,
|
||
|
//
|
||
|
ExpectedCipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||
|
ExpectedNegotiatedProtocol: "h2",
|
||
|
ExpectedPeerCerts: certs,
|
||
|
ExpectedVersion: tls.VersionTLS12,
|
||
|
//
|
||
|
ExpectedNetwork: "tcp",
|
||
|
ExpectedRemoteAddr: mockedEndpoint,
|
||
|
//
|
||
|
QUICConfig: nil, // this is not QUIC
|
||
|
ExpectedFailure: nil,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
expectedState := v.NewTLSConnectionState()
|
||
|
thx := newTLSHandshaker(newTCPConn(mockedEndpoint), expectedState, nil)
|
||
|
ctx := context.Background()
|
||
|
tcpConn := newTCPConn(mockedEndpoint)
|
||
|
conn, state, err := saver.TLSHandshake(ctx, thx, tcpConn, v.NewTLSConfig())
|
||
|
if conn == nil {
|
||
|
t.Fatal("expected non-nil conn")
|
||
|
}
|
||
|
conn.Close()
|
||
|
if diff := cmp.Diff(expectedState, state, cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" {
|
||
|
t.Fatal(diff)
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if err := v.Validate(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
// failureFlow is the flow we run on failure.
|
||
|
failureFlow := func(mockedError error, peerCerts [][]byte) error {
|
||
|
const mockedEndpoint = "8.8.4.4:443"
|
||
|
saver := NewSaver()
|
||
|
v := &SingleQUICTLSHandshakeValidator{
|
||
|
ExpectedALPN: []string{"h2", "http/1.1"},
|
||
|
ExpectedSNI: "dns.google",
|
||
|
ExpectedSkipVerify: true,
|
||
|
//
|
||
|
ExpectedCipherSuite: 0,
|
||
|
ExpectedNegotiatedProtocol: "",
|
||
|
ExpectedPeerCerts: peerCerts,
|
||
|
ExpectedVersion: 0,
|
||
|
//
|
||
|
ExpectedNetwork: "tcp",
|
||
|
ExpectedRemoteAddr: mockedEndpoint,
|
||
|
//
|
||
|
QUICConfig: nil, // this is not QUIC
|
||
|
ExpectedFailure: mockedError,
|
||
|
Saver: saver,
|
||
|
}
|
||
|
expectedState := v.NewTLSConnectionState()
|
||
|
thx := newTLSHandshaker(nil, expectedState, mockedError)
|
||
|
ctx := context.Background()
|
||
|
tcpConn := newTCPConn(mockedEndpoint)
|
||
|
conn, state, err := saver.TLSHandshake(ctx, thx, tcpConn, v.NewTLSConfig())
|
||
|
if conn != nil {
|
||
|
return errors.New("expected nil conn")
|
||
|
}
|
||
|
if diff := cmp.Diff(expectedState, state, cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" {
|
||
|
return errors.New(diff)
|
||
|
}
|
||
|
if !errors.Is(err, mockedError) {
|
||
|
return fmt.Errorf("unexpected err: %w", err)
|
||
|
}
|
||
|
return v.Validate()
|
||
|
}
|
||
|
|
||
|
t.Run("on generic failure", func(t *testing.T) {
|
||
|
mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
|
||
|
if err := failureFlow(mockedError, nil); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on x509.HostnameError", func(t *testing.T) {
|
||
|
var certificate []byte
|
||
|
ff := &fakefill.Filler{}
|
||
|
ff.Fill(&certificate)
|
||
|
mockedError := x509.HostnameError{
|
||
|
Certificate: &x509.Certificate{Raw: certificate},
|
||
|
}
|
||
|
if err := failureFlow(mockedError, [][]byte{certificate}); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on x509.UnknownAuthorityError", func(t *testing.T) {
|
||
|
var certificate []byte
|
||
|
ff := &fakefill.Filler{}
|
||
|
ff.Fill(&certificate)
|
||
|
mockedError := x509.UnknownAuthorityError{
|
||
|
Cert: &x509.Certificate{Raw: certificate},
|
||
|
}
|
||
|
if err := failureFlow(mockedError, [][]byte{certificate}); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("on x509.CertificateInvalidError", func(t *testing.T) {
|
||
|
var certificate []byte
|
||
|
ff := &fakefill.Filler{}
|
||
|
ff.Fill(&certificate)
|
||
|
mockedError := x509.CertificateInvalidError{
|
||
|
Cert: &x509.Certificate{Raw: certificate},
|
||
|
}
|
||
|
if err := failureFlow(mockedError, [][]byte{certificate}); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
}
|