fix(netxlite): improve TLS auto-configuration (#409)

Auto-configure every relevant TLS field as close as possible to
where it's actually used.

As a side effect, add support for mocking the creation of a TLS
connection, which should possibly be useful for uTLS?

Work that is part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
Simone Basso 2021-06-25 20:51:59 +02:00 committed by GitHub
parent f1f5ed342e
commit b07890af4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 360 additions and 89 deletions

View File

@ -108,7 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
Dialer: d,
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
},
},
}

View File

@ -184,7 +184,7 @@ func NewTLSDialer(config Config) TLSDialer {
if config.Dialer == nil {
config.Dialer = NewDialer(config)
}
var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{}
var h tlsHandshaker = &netxlite.TLSHandshakerConfigurable{}
h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
if config.Logger != nil {
h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h}

View File

@ -234,7 +234,7 @@ func TestNewTLSDialerVanilla(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@ -263,7 +263,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@ -302,7 +302,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@ -342,7 +342,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@ -375,7 +375,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@ -410,7 +410,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@ -447,7 +447,7 @@ func TestNewWithTLSDialer(t *testing.T) {
tlsDialer := &netxlite.TLSDialer{
Config: new(tls.Config),
Dialer: netx.FakeDialer{Err: expected},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
}
txp := netx.NewHTTPTransport(netx.Config{
TLSDialer: tlsDialer,

View File

@ -16,7 +16,7 @@ func TestTLSDialerSuccess(t *testing.T) {
log.SetLevel(log.DebugLevel)
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
TLSHandshaker: &netxlite.TLSHandshakerLogger{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Logger: log.Log,
},
}

View File

@ -26,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
Config: &tls.Config{NextProtos: nextprotos},
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
}
@ -119,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
Config: &tls.Config{NextProtos: nextprotos},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
}
@ -184,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
}
@ -217,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
}
@ -250,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
}
@ -284,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
}

View File

@ -16,7 +16,7 @@ import (
)
func TestSystemTLSHandshakerEOFError(t *testing.T) {
h := &netxlite.TLSHandshakerStdlib{}
h := &netxlite.TLSHandshakerConfigurable{}
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
ServerName: "x.org",
})

View File

@ -10,7 +10,7 @@ import (
"github.com/lucas-clemente/quic-go"
)
// QUICDialerContext is a dialer for QUIC using Context.
// QUICContextDialer is a dialer for QUIC using Context.
type QUICContextDialer interface {
// DialContext establishes a new QUIC session using the given
// network and address. The tlsConfig and the quicConfig arguments
@ -39,6 +39,11 @@ func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error)
type QUICDialerQUICGo struct {
// QUICListener is the underlying QUICListener to use.
QUICListener QUICListener
// mockDialEarlyContext allows to mock quic.DialEarlyContext.
mockDialEarlyContext func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error)
}
var _ QUICContextDialer = &QUICDialerQUICGo{}
@ -46,7 +51,14 @@ var _ QUICContextDialer = &QUICDialerQUICGo{}
// errInvalidIP indicates that a string is not a valid IP.
var errInvalidIP = errors.New("netxlite: invalid IP")
// DialContext implements ContextDialer.DialContext
// DialContext implements ContextDialer.DialContext. This function will
// apply the following TLS defaults:
//
// 1. if tlsConfig.RootCAs is nil, we use the Mozilla CA that we
// bundle with this measurement library;
//
// 2. if tlsConfig.NextProtos is empty _and_ the port is 443 or 8853,
// then we configure, respectively, "h3" and "dq".
func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
address string, tlsConfig *tls.Config, quicConfig *quic.Config) (
quic.EarlySession, error) {
@ -67,7 +79,8 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
return nil, err
}
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
sess, err := quic.DialEarlyContext(
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, port)
sess, err := d.dialEarlyContext(
ctx, pconn, udpAddr, address, tlsConfig, quicConfig)
if err != nil {
return nil, err
@ -75,6 +88,36 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil
}
func (d *QUICDialerQUICGo) dialEarlyContext(ctx context.Context,
pconn net.PacketConn, remoteAddr net.Addr, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
if d.mockDialEarlyContext != nil {
return d.mockDialEarlyContext(
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
}
return quic.DialEarlyContext(
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
}
// maybeApplyTLSDefaults ensures that we're using our certificate pool, if
// needed, and that we use a suitable ALPN, if needed, for h3 and dq.
func (d *QUICDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config {
config = config.Clone()
if config.RootCAs == nil {
config.RootCAs = defaultCertPool
}
if len(config.NextProtos) <= 0 {
switch port {
case 443:
config.NextProtos = []string{"h3"}
case 8853:
// See https://datatracker.ietf.org/doc/html/draft-ietf-dprive-dnsoquic-02#section-10
config.NextProtos = []string{"dq"}
}
}
return config
}
// quicSessionOwnsConn ensures that we close the PacketConn.
type quicSessionOwnsConn struct {
// EarlySession is the embedded early session
@ -102,7 +145,11 @@ type QUICDialerResolver struct {
Resolver Resolver
}
// DialContext implements QUICContextDialer.DialContext
// DialContext implements QUICContextDialer.DialContext. This function
// will apply the following TLS defaults:
//
// 1. if tlsConfig.ServerName is empty, we will use the hostname
// contained inside of the `address` endpoint.
func (d *QUICDialerResolver) DialContext(
ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
@ -110,15 +157,11 @@ func (d *QUICDialerResolver) DialContext(
if err != nil {
return nil, err
}
// TODO(kelmenhorst): Should this be somewhere else?
// failure if tlsCfg is nil but that should not happen
if tlsConfig.ServerName == "" {
tlsConfig.ServerName = onlyhost
}
addrs, err := d.lookupHost(ctx, onlyhost)
if err != nil {
return nil, err
}
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
// TODO(bassosimone): here we should be using multierror rather
// than just calling ReduceErrors. We are not ready to do that
// yet, though. To do that, we need first to modify nettests so
@ -136,6 +179,15 @@ func (d *QUICDialerResolver) DialContext(
return nil, reduceErrors(errorslist)
}
// maybeApplyTLSDefaults sets the SNI if it's not already configured.
func (d *QUICDialerResolver) maybeApplyTLSDefaults(config *tls.Config, host string) *tls.Config {
config = config.Clone()
if config.ServerName == "" {
config.ServerName = host
}
return config
}
// lookupHost performs a domain name resolution.
func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {

View File

@ -9,13 +9,13 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/netxmocks"
)
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com",
}
systemdialer := QUICDialerQUICGo{
@ -34,7 +34,6 @@ func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com",
}
systemdialer := QUICDialerQUICGo{
@ -53,7 +52,6 @@ func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com",
}
systemdialer := QUICDialerQUICGo{
@ -73,7 +71,6 @@ func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
func TestQUICDialerQUICGoCannotListen(t *testing.T) {
expected := errors.New("mocked error")
tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com",
}
systemdialer := QUICDialerQUICGo{
@ -94,9 +91,8 @@ func TestQUICDialerQUICGoCannotListen(t *testing.T) {
}
}
func TestQUICDialerCannotPerformHandshake(t *testing.T) {
func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) {
tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "dns.google",
}
systemdialer := QUICDialerQUICGo{
@ -114,9 +110,8 @@ func TestQUICDialerCannotPerformHandshake(t *testing.T) {
}
}
func TestQUICDialerWorksAsIntended(t *testing.T) {
func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) {
tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "dns.google",
}
systemdialer := QUICDialerQUICGo{
@ -134,8 +129,90 @@ func TestQUICDialerWorksAsIntended(t *testing.T) {
}
}
func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{
ServerName: "dns.google",
}
systemdialer := QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{},
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig
return nil, expected
},
}
ctx := context.Background()
sess, err := systemdialer.DialContext(
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if sess != nil {
t.Fatal("expected nil session here")
}
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
}
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" {
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
}
if tlsConfig.ServerName != gotTLSConfig.ServerName {
t.Fatal("the ServerName field must match")
}
}
func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{
ServerName: "dns.google",
}
systemdialer := QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{},
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig
return nil, expected
},
}
ctx := context.Background()
sess, err := systemdialer.DialContext(
ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if sess != nil {
t.Fatal("expected nil session here")
}
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
}
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" {
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
}
if tlsConfig.ServerName != gotTLSConfig.ServerName {
t.Fatal("the ServerName field must match")
}
}
func TestQUICDialerResolverSuccess(t *testing.T) {
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
tlsConfig := &tls.Config{}
dialer := &QUICDialerResolver{
Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{},
@ -153,7 +230,7 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
}
func TestQUICDialerResolverNoPort(t *testing.T) {
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
tlsConfig := &tls.Config{}
dialer := &QUICDialerResolver{
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}}
sess, err := dialer.DialContext(
@ -185,7 +262,7 @@ func TestQUICDialerResolverLookupHostAddress(t *testing.T) {
}
func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
tlsConfig := &tls.Config{}
expected := errors.New("mocked error")
dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
@ -206,7 +283,7 @@ func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
func TestQUICDialerResolverInvalidPort(t *testing.T) {
// This test allows us to check for the case where every attempt
// to establish a connection leads to a failure
tlsConf := &tls.Config{NextProtos: []string{"h3"}}
tlsConf := &tls.Config{}
dialer := &QUICDialerResolver{
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{},
@ -225,3 +302,32 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
t.Fatal("expected nil sess")
}
}
func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{}
dialer := &QUICDialerResolver{
Resolver: new(net.Resolver), Dialer: &netxmocks.QUICContextDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig
return nil, expected
},
}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com:443",
tlsConfig, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if sess != nil {
t.Fatal("expected nil session here")
}
if tlsConfig.ServerName != "" {
t.Fatal("should not have changed tlsConfig.ServerName")
}
if gotTLSConfig.ServerName != "www.google.com" {
t.Fatal("gotTLSConfig.ServerName has not been set")
}
}

View File

@ -0,0 +1,18 @@
package netxlite
import (
"crypto/tls"
"net"
)
// TLSConn is any tls.Conn-like structure.
type TLSConn interface {
// net.Conn is the embedded conn.
net.Conn
// ConnectionState returns the TLS connection state.
ConnectionState() tls.ConnectionState
// Handshake performs the handshake.
Handshake() error
}

View File

@ -43,8 +43,6 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
// We set the ServerName field if not already set.
//
// We set the ALPN if the port is 443 or 853, if not already set.
//
// We force using our root CA, unless it's already set.
func (d *TLSDialer) config(host, port string) *tls.Config {
config := d.Config
if config == nil {
@ -62,8 +60,5 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
config.NextProtos = []string{"dot"}
}
}
if config.RootCAs == nil {
config.RootCAs = NewDefaultCertPool()
}
return config
}

View File

@ -3,7 +3,6 @@ package netxlite
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
@ -54,7 +53,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) {
return nil
}}, nil
}},
TLSHandshaker: &TLSHandshakerStdlib{},
TLSHandshaker: &TLSHandshakerConfigurable{},
}
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
@ -99,9 +98,6 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
if config.ServerName != "www.google.com" {
t.Fatal("invalid server name")
}
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
t.Fatal(diff)
}
@ -113,9 +109,6 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
if config.ServerName != "dns.google" {
t.Fatal("invalid server name")
}
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff)
}
@ -131,9 +124,6 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
if config.ServerName != "example.com" {
t.Fatal("invalid server name")
}
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff)
}
@ -149,29 +139,7 @@ func TestTLSDialerConfigWithALPN(t *testing.T) {
if config.ServerName != "dns.google" {
t.Fatal("invalid server name")
}
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
t.Fatal(diff)
}
}
func TestTLSDialerConfigWithRootCA(t *testing.T) {
pool := &x509.CertPool{}
d := &TLSDialer{
Config: &tls.Config{
RootCAs: pool,
},
}
config := d.config("dns.google", "853")
if config.ServerName != "dns.google" {
t.Fatal("invalid server name")
}
if config.RootCAs != pool {
t.Fatal("not the RootCAs we expected")
}
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff)
}
}

View File

@ -16,17 +16,33 @@ type TLSHandshaker interface {
net.Conn, tls.ConnectionState, error)
}
// TLSHandshakerStdlib is the stdlib's TLS handshaker.
type TLSHandshakerStdlib struct {
// Timeout is the timeout imposed on the TLS handshake. If zero
// TLSHandshakerConfigurable is a configurable TLS handshaker that
// uses by default the standard library's TLS implementation.
type TLSHandshakerConfigurable struct {
// NewConn is the OPTIONAL factory for creating a new connection. If
// this factory is not set, we'll use the stdlib.
NewConn func(conn net.Conn, config *tls.Config) TLSConn
// Timeout is the OPTIONAL timeout imposed on the TLS handshake. If zero
// or negative, we will use default timeout of 10 seconds.
Timeout time.Duration
}
var _ TLSHandshaker = &TLSHandshakerStdlib{}
var _ TLSHandshaker = &TLSHandshakerConfigurable{}
// Handshake implements Handshaker.Handshake
func (h *TLSHandshakerStdlib) Handshake(
// defaultCertPool is the cert pool we use by default. We store this
// value into a private variable to enable for unit testing.
var defaultCertPool = NewDefaultCertPool()
// Handshake implements Handshaker.Handshake. This function will
// configure the code to use the built-in Mozilla CA if the config
// field contains a nil RootCAs field.
//
// Bug
//
// Until Go 1.17 is released, this function will not honour
// the context. We'll however always enforce an overall timeout.
func (h *TLSHandshakerConfigurable) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
timeout := h.Timeout
@ -35,15 +51,27 @@ func (h *TLSHandshakerStdlib) Handshake(
}
defer conn.SetDeadline(time.Time{})
conn.SetDeadline(time.Now().Add(timeout))
tlsconn := tls.Client(conn, config)
if config.RootCAs == nil {
config = config.Clone()
config.RootCAs = defaultCertPool
}
tlsconn := h.newConn(conn, config)
if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
}
// newConn creates a new TLSConn.
func (h *TLSHandshakerConfigurable) newConn(conn net.Conn, config *tls.Config) TLSConn {
if h.NewConn != nil {
return h.NewConn(conn, config)
}
return tls.Client(conn, config)
}
// DefaultTLSHandshaker is the default TLS handshaker.
var DefaultTLSHandshaker = &TLSHandshakerStdlib{}
var DefaultTLSHandshaker = &TLSHandshakerConfigurable{}
// TLSHandshakerLogger is a TLSHandshaker with logging.
type TLSHandshakerLogger struct {

View File

@ -17,9 +17,9 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxmocks"
)
func TestTLSHandshakerStdlibWithError(t *testing.T) {
func TestTLSHandshakerConfigurableWithError(t *testing.T) {
var times []time.Time
h := &TLSHandshakerStdlib{}
h := &TLSHandshakerConfigurable{}
tcpConn := &netxmocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
@ -50,7 +50,7 @@ func TestTLSHandshakerStdlibWithError(t *testing.T) {
}
}
func TestTLSHandshakerStdlibSuccess(t *testing.T) {
func TestTLSHandshakerConfigurableSuccess(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
})
@ -65,7 +65,7 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
t.Fatal(err)
}
defer conn.Close()
handshaker := &TLSHandshakerStdlib{}
handshaker := &TLSHandshakerConfigurable{}
ctx := context.Background()
config := &tls.Config{
InsecureSkipVerify: true,
@ -83,6 +83,44 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
}
}
func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
handshaker := &TLSHandshakerConfigurable{
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
gotTLSConfig = config
return &netxmocks.TLSConn{
MockHandshake: func() error {
return expected
},
}
},
}
ctx := context.Background()
config := &tls.Config{}
conn := &netxmocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
}
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero connState here")
}
if tlsConn != nil {
t.Fatal("expected nil tlsConn here")
}
if config.RootCAs != nil {
t.Fatal("config.RootCAs should still be nil")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
}
}
func TestTLSHandshakerLoggerSuccess(t *testing.T) {
th := &TLSHandshakerLogger{
TLSHandshaker: &netxmocks.TLSHandshaker{

View File

@ -1,6 +1,12 @@
package netxmocks
import "net"
import (
"context"
"crypto/tls"
"net"
"github.com/lucas-clemente/quic-go"
)
// QUICListener is a mockable netxlite.QUICListener.
type QUICListener struct {
@ -11,3 +17,15 @@ type QUICListener struct {
func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) {
return ql.MockListen(addr)
}
// QUICContextDialer is a mockable netxlite.QUICContextDialer.
type QUICContextDialer struct {
MockDialContext func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error)
}
// DialContext calls MockDialContext.
func (qcd *QUICContextDialer) DialContext(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return qcd.MockDialContext(ctx, network, address, tlsConfig, quicConfig)
}

View File

@ -1,9 +1,13 @@
package netxmocks
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"github.com/lucas-clemente/quic-go"
)
func TestQUICListenerListen(t *testing.T) {
@ -21,3 +25,22 @@ func TestQUICListenerListen(t *testing.T) {
t.Fatal("expected nil conn here")
}
}
func TestQUICContextDialerDialContext(t *testing.T) {
expected := errors.New("mocked error")
qcd := &QUICContextDialer{
MockDialContext: func(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return nil, expected
},
}
ctx := context.Background()
tlsConfig := &tls.Config{}
quicConfig := &quic.Config{}
sess, err := qcd.DialContext(ctx, "udp", "dns.google:443", tlsConfig, quicConfig)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if sess != nil {
t.Fatal("expected nil session")
}
}

View File

@ -0,0 +1,25 @@
package netxmocks
import "crypto/tls"
// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
// Conn is the embedded mockable Conn.
Conn
// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState
// MockHandshake allows to mock the Handshake method.
MockHandshake func() error
}
// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}
// Handshake calls MockHandshake.
func (c *TLSConn) Handshake() error {
return c.MockHandshake()
}