fix(netxlite): improve TLS auto-configuration (#409)

Auto-configure every relevant TLS field as close as possible to
where it's actually used.

As a side effect, add support for mocking the creation of a TLS
connection, which should possibly be useful for uTLS?

Work that is part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
Simone Basso 2021-06-25 20:51:59 +02:00 committed by GitHub
parent f1f5ed342e
commit b07890af4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 360 additions and 89 deletions

View File

@ -108,7 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
Dialer: d, Dialer: d,
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
}, },
}, },
} }

View File

@ -184,7 +184,7 @@ func NewTLSDialer(config Config) TLSDialer {
if config.Dialer == nil { if config.Dialer == nil {
config.Dialer = NewDialer(config) config.Dialer = NewDialer(config)
} }
var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{} var h tlsHandshaker = &netxlite.TLSHandshakerConfigurable{}
h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h} h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
if config.Logger != nil { if config.Logger != nil {
h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h} h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h}

View File

@ -234,7 +234,7 @@ func TestNewTLSDialerVanilla(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
} }
@ -263,7 +263,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
} }
@ -302,7 +302,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
} }
@ -342,7 +342,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
} }
@ -375,7 +375,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
} }
@ -410,7 +410,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
} }
@ -447,7 +447,7 @@ func TestNewWithTLSDialer(t *testing.T) {
tlsDialer := &netxlite.TLSDialer{ tlsDialer := &netxlite.TLSDialer{
Config: new(tls.Config), Config: new(tls.Config),
Dialer: netx.FakeDialer{Err: expected}, Dialer: netx.FakeDialer{Err: expected},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
} }
txp := netx.NewHTTPTransport(netx.Config{ txp := netx.NewHTTPTransport(netx.Config{
TLSDialer: tlsDialer, TLSDialer: tlsDialer,

View File

@ -16,7 +16,7 @@ func TestTLSDialerSuccess(t *testing.T) {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer), dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerLogger{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Logger: log.Log, Logger: log.Log,
}, },
} }

View File

@ -26,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
Config: &tls.Config{NextProtos: nextprotos}, Config: &tls.Config{NextProtos: nextprotos},
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver, Saver: saver,
}, },
} }
@ -119,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
Config: &tls.Config{NextProtos: nextprotos}, Config: &tls.Config{NextProtos: nextprotos},
Dialer: new(net.Dialer), Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver, Saver: saver,
}, },
} }
@ -184,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer), Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver, Saver: saver,
}, },
} }
@ -217,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer), Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver, Saver: saver,
}, },
} }
@ -250,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialer{
Dialer: new(net.Dialer), Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver, Saver: saver,
}, },
} }
@ -284,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
Config: &tls.Config{InsecureSkipVerify: true}, Config: &tls.Config{InsecureSkipVerify: true},
Dialer: new(net.Dialer), Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver, Saver: saver,
}, },
} }

View File

@ -16,7 +16,7 @@ import (
) )
func TestSystemTLSHandshakerEOFError(t *testing.T) { func TestSystemTLSHandshakerEOFError(t *testing.T) {
h := &netxlite.TLSHandshakerStdlib{} h := &netxlite.TLSHandshakerConfigurable{}
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{ conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
ServerName: "x.org", ServerName: "x.org",
}) })

View File

@ -10,7 +10,7 @@ import (
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
) )
// QUICDialerContext is a dialer for QUIC using Context. // QUICContextDialer is a dialer for QUIC using Context.
type QUICContextDialer interface { type QUICContextDialer interface {
// DialContext establishes a new QUIC session using the given // DialContext establishes a new QUIC session using the given
// network and address. The tlsConfig and the quicConfig arguments // network and address. The tlsConfig and the quicConfig arguments
@ -39,6 +39,11 @@ func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error)
type QUICDialerQUICGo struct { type QUICDialerQUICGo struct {
// QUICListener is the underlying QUICListener to use. // QUICListener is the underlying QUICListener to use.
QUICListener QUICListener QUICListener QUICListener
// mockDialEarlyContext allows to mock quic.DialEarlyContext.
mockDialEarlyContext func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error)
} }
var _ QUICContextDialer = &QUICDialerQUICGo{} var _ QUICContextDialer = &QUICDialerQUICGo{}
@ -46,7 +51,14 @@ var _ QUICContextDialer = &QUICDialerQUICGo{}
// errInvalidIP indicates that a string is not a valid IP. // errInvalidIP indicates that a string is not a valid IP.
var errInvalidIP = errors.New("netxlite: invalid IP") var errInvalidIP = errors.New("netxlite: invalid IP")
// DialContext implements ContextDialer.DialContext // DialContext implements ContextDialer.DialContext. This function will
// apply the following TLS defaults:
//
// 1. if tlsConfig.RootCAs is nil, we use the Mozilla CA that we
// bundle with this measurement library;
//
// 2. if tlsConfig.NextProtos is empty _and_ the port is 443 or 8853,
// then we configure, respectively, "h3" and "dq".
func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string, func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
address string, tlsConfig *tls.Config, quicConfig *quic.Config) ( address string, tlsConfig *tls.Config, quicConfig *quic.Config) (
quic.EarlySession, error) { quic.EarlySession, error) {
@ -67,7 +79,8 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
return nil, err return nil, err
} }
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""} udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
sess, err := quic.DialEarlyContext( tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, port)
sess, err := d.dialEarlyContext(
ctx, pconn, udpAddr, address, tlsConfig, quicConfig) ctx, pconn, udpAddr, address, tlsConfig, quicConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,6 +88,36 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil
} }
func (d *QUICDialerQUICGo) dialEarlyContext(ctx context.Context,
pconn net.PacketConn, remoteAddr net.Addr, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
if d.mockDialEarlyContext != nil {
return d.mockDialEarlyContext(
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
}
return quic.DialEarlyContext(
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
}
// maybeApplyTLSDefaults ensures that we're using our certificate pool, if
// needed, and that we use a suitable ALPN, if needed, for h3 and dq.
func (d *QUICDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config {
config = config.Clone()
if config.RootCAs == nil {
config.RootCAs = defaultCertPool
}
if len(config.NextProtos) <= 0 {
switch port {
case 443:
config.NextProtos = []string{"h3"}
case 8853:
// See https://datatracker.ietf.org/doc/html/draft-ietf-dprive-dnsoquic-02#section-10
config.NextProtos = []string{"dq"}
}
}
return config
}
// quicSessionOwnsConn ensures that we close the PacketConn. // quicSessionOwnsConn ensures that we close the PacketConn.
type quicSessionOwnsConn struct { type quicSessionOwnsConn struct {
// EarlySession is the embedded early session // EarlySession is the embedded early session
@ -102,7 +145,11 @@ type QUICDialerResolver struct {
Resolver Resolver Resolver Resolver
} }
// DialContext implements QUICContextDialer.DialContext // DialContext implements QUICContextDialer.DialContext. This function
// will apply the following TLS defaults:
//
// 1. if tlsConfig.ServerName is empty, we will use the hostname
// contained inside of the `address` endpoint.
func (d *QUICDialerResolver) DialContext( func (d *QUICDialerResolver) DialContext(
ctx context.Context, network, address string, ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
@ -110,15 +157,11 @@ func (d *QUICDialerResolver) DialContext(
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO(kelmenhorst): Should this be somewhere else?
// failure if tlsCfg is nil but that should not happen
if tlsConfig.ServerName == "" {
tlsConfig.ServerName = onlyhost
}
addrs, err := d.lookupHost(ctx, onlyhost) addrs, err := d.lookupHost(ctx, onlyhost)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
// TODO(bassosimone): here we should be using multierror rather // TODO(bassosimone): here we should be using multierror rather
// than just calling ReduceErrors. We are not ready to do that // than just calling ReduceErrors. We are not ready to do that
// yet, though. To do that, we need first to modify nettests so // yet, though. To do that, we need first to modify nettests so
@ -136,6 +179,15 @@ func (d *QUICDialerResolver) DialContext(
return nil, reduceErrors(errorslist) return nil, reduceErrors(errorslist)
} }
// maybeApplyTLSDefaults sets the SNI if it's not already configured.
func (d *QUICDialerResolver) maybeApplyTLSDefaults(config *tls.Config, host string) *tls.Config {
config = config.Clone()
if config.ServerName == "" {
config.ServerName = host
}
return config
}
// lookupHost performs a domain name resolution. // lookupHost performs a domain name resolution.
func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) { func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil { if net.ParseIP(hostname) != nil {

View File

@ -9,13 +9,13 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/netxmocks" "github.com/ooni/probe-cli/v3/internal/netxmocks"
) )
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) { func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com", ServerName: "www.google.com",
} }
systemdialer := QUICDialerQUICGo{ systemdialer := QUICDialerQUICGo{
@ -34,7 +34,6 @@ func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
func TestQUICDialerQUICGoInvalidPort(t *testing.T) { func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com", ServerName: "www.google.com",
} }
systemdialer := QUICDialerQUICGo{ systemdialer := QUICDialerQUICGo{
@ -53,7 +52,6 @@ func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
func TestQUICDialerQUICGoInvalidIP(t *testing.T) { func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com", ServerName: "www.google.com",
} }
systemdialer := QUICDialerQUICGo{ systemdialer := QUICDialerQUICGo{
@ -73,7 +71,6 @@ func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
func TestQUICDialerQUICGoCannotListen(t *testing.T) { func TestQUICDialerQUICGoCannotListen(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "www.google.com", ServerName: "www.google.com",
} }
systemdialer := QUICDialerQUICGo{ systemdialer := QUICDialerQUICGo{
@ -94,9 +91,8 @@ func TestQUICDialerQUICGoCannotListen(t *testing.T) {
} }
} }
func TestQUICDialerCannotPerformHandshake(t *testing.T) { func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "dns.google", ServerName: "dns.google",
} }
systemdialer := QUICDialerQUICGo{ systemdialer := QUICDialerQUICGo{
@ -114,9 +110,8 @@ func TestQUICDialerCannotPerformHandshake(t *testing.T) {
} }
} }
func TestQUICDialerWorksAsIntended(t *testing.T) { func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{"h3"},
ServerName: "dns.google", ServerName: "dns.google",
} }
systemdialer := QUICDialerQUICGo{ systemdialer := QUICDialerQUICGo{
@ -134,8 +129,90 @@ func TestQUICDialerWorksAsIntended(t *testing.T) {
} }
} }
func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{
ServerName: "dns.google",
}
systemdialer := QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{},
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig
return nil, expected
},
}
ctx := context.Background()
sess, err := systemdialer.DialContext(
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if sess != nil {
t.Fatal("expected nil session here")
}
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
}
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" {
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
}
if tlsConfig.ServerName != gotTLSConfig.ServerName {
t.Fatal("the ServerName field must match")
}
}
func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{
ServerName: "dns.google",
}
systemdialer := QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{},
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig
return nil, expected
},
}
ctx := context.Background()
sess, err := systemdialer.DialContext(
ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if sess != nil {
t.Fatal("expected nil session here")
}
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
}
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" {
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
}
if tlsConfig.ServerName != gotTLSConfig.ServerName {
t.Fatal("the ServerName field must match")
}
}
func TestQUICDialerResolverSuccess(t *testing.T) { func TestQUICDialerResolverSuccess(t *testing.T) {
tlsConfig := &tls.Config{NextProtos: []string{"h3"}} tlsConfig := &tls.Config{}
dialer := &QUICDialerResolver{ dialer := &QUICDialerResolver{
Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{ Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{}, QUICListener: &QUICListenerStdlib{},
@ -153,7 +230,7 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
} }
func TestQUICDialerResolverNoPort(t *testing.T) { func TestQUICDialerResolverNoPort(t *testing.T) {
tlsConfig := &tls.Config{NextProtos: []string{"h3"}} tlsConfig := &tls.Config{}
dialer := &QUICDialerResolver{ dialer := &QUICDialerResolver{
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}} Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}}
sess, err := dialer.DialContext( sess, err := dialer.DialContext(
@ -185,7 +262,7 @@ func TestQUICDialerResolverLookupHostAddress(t *testing.T) {
} }
func TestQUICDialerResolverLookupHostFailure(t *testing.T) { func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
tlsConfig := &tls.Config{NextProtos: []string{"h3"}} tlsConfig := &tls.Config{}
expected := errors.New("mocked error") expected := errors.New("mocked error")
dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{ dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
@ -206,7 +283,7 @@ func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
func TestQUICDialerResolverInvalidPort(t *testing.T) { func TestQUICDialerResolverInvalidPort(t *testing.T) {
// This test allows us to check for the case where every attempt // This test allows us to check for the case where every attempt
// to establish a connection leads to a failure // to establish a connection leads to a failure
tlsConf := &tls.Config{NextProtos: []string{"h3"}} tlsConf := &tls.Config{}
dialer := &QUICDialerResolver{ dialer := &QUICDialerResolver{
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{ Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{
QUICListener: &QUICListenerStdlib{}, QUICListener: &QUICListenerStdlib{},
@ -225,3 +302,32 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
t.Fatal("expected nil sess") t.Fatal("expected nil sess")
} }
} }
func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{}
dialer := &QUICDialerResolver{
Resolver: new(net.Resolver), Dialer: &netxmocks.QUICContextDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig
return nil, expected
},
}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com:443",
tlsConfig, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if sess != nil {
t.Fatal("expected nil session here")
}
if tlsConfig.ServerName != "" {
t.Fatal("should not have changed tlsConfig.ServerName")
}
if gotTLSConfig.ServerName != "www.google.com" {
t.Fatal("gotTLSConfig.ServerName has not been set")
}
}

View File

@ -0,0 +1,18 @@
package netxlite
import (
"crypto/tls"
"net"
)
// TLSConn is any tls.Conn-like structure.
type TLSConn interface {
// net.Conn is the embedded conn.
net.Conn
// ConnectionState returns the TLS connection state.
ConnectionState() tls.ConnectionState
// Handshake performs the handshake.
Handshake() error
}

View File

@ -43,8 +43,6 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
// We set the ServerName field if not already set. // We set the ServerName field if not already set.
// //
// We set the ALPN if the port is 443 or 853, if not already set. // We set the ALPN if the port is 443 or 853, if not already set.
//
// We force using our root CA, unless it's already set.
func (d *TLSDialer) config(host, port string) *tls.Config { func (d *TLSDialer) config(host, port string) *tls.Config {
config := d.Config config := d.Config
if config == nil { if config == nil {
@ -62,8 +60,5 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
config.NextProtos = []string{"dot"} config.NextProtos = []string{"dot"}
} }
} }
if config.RootCAs == nil {
config.RootCAs = NewDefaultCertPool()
}
return config return config
} }

View File

@ -3,7 +3,6 @@ package netxlite
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors" "errors"
"io" "io"
"net" "net"
@ -54,7 +53,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) {
return nil return nil
}}, nil }}, nil
}}, }},
TLSHandshaker: &TLSHandshakerStdlib{}, TLSHandshaker: &TLSHandshakerConfigurable{},
} }
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
@ -99,9 +98,6 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
if config.ServerName != "www.google.com" { if config.ServerName != "www.google.com" {
t.Fatal("invalid server name") t.Fatal("invalid server name")
} }
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
@ -113,9 +109,6 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
if config.ServerName != "dns.google" { if config.ServerName != "dns.google" {
t.Fatal("invalid server name") t.Fatal("invalid server name")
} }
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
@ -131,9 +124,6 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
if config.ServerName != "example.com" { if config.ServerName != "example.com" {
t.Fatal("invalid server name") t.Fatal("invalid server name")
} }
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
@ -149,29 +139,7 @@ func TestTLSDialerConfigWithALPN(t *testing.T) {
if config.ServerName != "dns.google" { if config.ServerName != "dns.google" {
t.Fatal("invalid server name") t.Fatal("invalid server name")
} }
if config.RootCAs == nil {
t.Fatal("expected non-nil root CAs")
}
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" { if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
} }
func TestTLSDialerConfigWithRootCA(t *testing.T) {
pool := &x509.CertPool{}
d := &TLSDialer{
Config: &tls.Config{
RootCAs: pool,
},
}
config := d.config("dns.google", "853")
if config.ServerName != "dns.google" {
t.Fatal("invalid server name")
}
if config.RootCAs != pool {
t.Fatal("not the RootCAs we expected")
}
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
t.Fatal(diff)
}
}

View File

@ -16,17 +16,33 @@ type TLSHandshaker interface {
net.Conn, tls.ConnectionState, error) net.Conn, tls.ConnectionState, error)
} }
// TLSHandshakerStdlib is the stdlib's TLS handshaker. // TLSHandshakerConfigurable is a configurable TLS handshaker that
type TLSHandshakerStdlib struct { // uses by default the standard library's TLS implementation.
// Timeout is the timeout imposed on the TLS handshake. If zero type TLSHandshakerConfigurable struct {
// NewConn is the OPTIONAL factory for creating a new connection. If
// this factory is not set, we'll use the stdlib.
NewConn func(conn net.Conn, config *tls.Config) TLSConn
// Timeout is the OPTIONAL timeout imposed on the TLS handshake. If zero
// or negative, we will use default timeout of 10 seconds. // or negative, we will use default timeout of 10 seconds.
Timeout time.Duration Timeout time.Duration
} }
var _ TLSHandshaker = &TLSHandshakerStdlib{} var _ TLSHandshaker = &TLSHandshakerConfigurable{}
// Handshake implements Handshaker.Handshake // defaultCertPool is the cert pool we use by default. We store this
func (h *TLSHandshakerStdlib) Handshake( // value into a private variable to enable for unit testing.
var defaultCertPool = NewDefaultCertPool()
// Handshake implements Handshaker.Handshake. This function will
// configure the code to use the built-in Mozilla CA if the config
// field contains a nil RootCAs field.
//
// Bug
//
// Until Go 1.17 is released, this function will not honour
// the context. We'll however always enforce an overall timeout.
func (h *TLSHandshakerConfigurable) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config, ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) { ) (net.Conn, tls.ConnectionState, error) {
timeout := h.Timeout timeout := h.Timeout
@ -35,15 +51,27 @@ func (h *TLSHandshakerStdlib) Handshake(
} }
defer conn.SetDeadline(time.Time{}) defer conn.SetDeadline(time.Time{})
conn.SetDeadline(time.Now().Add(timeout)) conn.SetDeadline(time.Now().Add(timeout))
tlsconn := tls.Client(conn, config) if config.RootCAs == nil {
config = config.Clone()
config.RootCAs = defaultCertPool
}
tlsconn := h.newConn(conn, config)
if err := tlsconn.Handshake(); err != nil { if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err return nil, tls.ConnectionState{}, err
} }
return tlsconn, tlsconn.ConnectionState(), nil return tlsconn, tlsconn.ConnectionState(), nil
} }
// newConn creates a new TLSConn.
func (h *TLSHandshakerConfigurable) newConn(conn net.Conn, config *tls.Config) TLSConn {
if h.NewConn != nil {
return h.NewConn(conn, config)
}
return tls.Client(conn, config)
}
// DefaultTLSHandshaker is the default TLS handshaker. // DefaultTLSHandshaker is the default TLS handshaker.
var DefaultTLSHandshaker = &TLSHandshakerStdlib{} var DefaultTLSHandshaker = &TLSHandshakerConfigurable{}
// TLSHandshakerLogger is a TLSHandshaker with logging. // TLSHandshakerLogger is a TLSHandshaker with logging.
type TLSHandshakerLogger struct { type TLSHandshakerLogger struct {

View File

@ -17,9 +17,9 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxmocks" "github.com/ooni/probe-cli/v3/internal/netxmocks"
) )
func TestTLSHandshakerStdlibWithError(t *testing.T) { func TestTLSHandshakerConfigurableWithError(t *testing.T) {
var times []time.Time var times []time.Time
h := &TLSHandshakerStdlib{} h := &TLSHandshakerConfigurable{}
tcpConn := &netxmocks.Conn{ tcpConn := &netxmocks.Conn{
MockWrite: func(b []byte) (int, error) { MockWrite: func(b []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
@ -50,7 +50,7 @@ func TestTLSHandshakerStdlibWithError(t *testing.T) {
} }
} }
func TestTLSHandshakerStdlibSuccess(t *testing.T) { func TestTLSHandshakerConfigurableSuccess(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200) rw.WriteHeader(200)
}) })
@ -65,7 +65,7 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
handshaker := &TLSHandshakerStdlib{} handshaker := &TLSHandshakerConfigurable{}
ctx := context.Background() ctx := context.Background()
config := &tls.Config{ config := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@ -83,6 +83,44 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
} }
} }
func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
expected := errors.New("mocked error")
var gotTLSConfig *tls.Config
handshaker := &TLSHandshakerConfigurable{
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
gotTLSConfig = config
return &netxmocks.TLSConn{
MockHandshake: func() error {
return expected
},
}
},
}
ctx := context.Background()
config := &tls.Config{}
conn := &netxmocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
}
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero connState here")
}
if tlsConn != nil {
t.Fatal("expected nil tlsConn here")
}
if config.RootCAs != nil {
t.Fatal("config.RootCAs should still be nil")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
}
}
func TestTLSHandshakerLoggerSuccess(t *testing.T) { func TestTLSHandshakerLoggerSuccess(t *testing.T) {
th := &TLSHandshakerLogger{ th := &TLSHandshakerLogger{
TLSHandshaker: &netxmocks.TLSHandshaker{ TLSHandshaker: &netxmocks.TLSHandshaker{

View File

@ -1,6 +1,12 @@
package netxmocks package netxmocks
import "net" import (
"context"
"crypto/tls"
"net"
"github.com/lucas-clemente/quic-go"
)
// QUICListener is a mockable netxlite.QUICListener. // QUICListener is a mockable netxlite.QUICListener.
type QUICListener struct { type QUICListener struct {
@ -11,3 +17,15 @@ type QUICListener struct {
func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) { func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) {
return ql.MockListen(addr) return ql.MockListen(addr)
} }
// QUICContextDialer is a mockable netxlite.QUICContextDialer.
type QUICContextDialer struct {
MockDialContext func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error)
}
// DialContext calls MockDialContext.
func (qcd *QUICContextDialer) DialContext(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return qcd.MockDialContext(ctx, network, address, tlsConfig, quicConfig)
}

View File

@ -1,9 +1,13 @@
package netxmocks package netxmocks
import ( import (
"context"
"crypto/tls"
"errors" "errors"
"net" "net"
"testing" "testing"
"github.com/lucas-clemente/quic-go"
) )
func TestQUICListenerListen(t *testing.T) { func TestQUICListenerListen(t *testing.T) {
@ -21,3 +25,22 @@ func TestQUICListenerListen(t *testing.T) {
t.Fatal("expected nil conn here") t.Fatal("expected nil conn here")
} }
} }
func TestQUICContextDialerDialContext(t *testing.T) {
expected := errors.New("mocked error")
qcd := &QUICContextDialer{
MockDialContext: func(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
return nil, expected
},
}
ctx := context.Background()
tlsConfig := &tls.Config{}
quicConfig := &quic.Config{}
sess, err := qcd.DialContext(ctx, "udp", "dns.google:443", tlsConfig, quicConfig)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if sess != nil {
t.Fatal("expected nil session")
}
}

View File

@ -0,0 +1,25 @@
package netxmocks
import "crypto/tls"
// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
// Conn is the embedded mockable Conn.
Conn
// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState
// MockHandshake allows to mock the Handshake method.
MockHandshake func() error
}
// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}
// Handshake calls MockHandshake.
func (c *TLSConn) Handshake() error {
return c.MockHandshake()
}