fix(netxlite): improve TLS auto-configuration (#409)
Auto-configure every relevant TLS field as close as possible to where it's actually used. As a side effect, add support for mocking the creation of a TLS connection, which should possibly be useful for uTLS? Work that is part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
parent
f1f5ed342e
commit
b07890af4d
|
@ -108,7 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
|
|||
Dialer: d,
|
||||
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
||||
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -184,7 +184,7 @@ func NewTLSDialer(config Config) TLSDialer {
|
|||
if config.Dialer == nil {
|
||||
config.Dialer = NewDialer(config)
|
||||
}
|
||||
var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{}
|
||||
var h tlsHandshaker = &netxlite.TLSHandshakerConfigurable{}
|
||||
h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
|
||||
if config.Logger != nil {
|
||||
h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h}
|
||||
|
|
|
@ -234,7 +234,7 @@ func TestNewTLSDialerVanilla(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
@ -263,7 +263,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
@ -302,7 +302,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
@ -342,7 +342,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
@ -375,7 +375,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
@ -410,7 +410,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
@ -447,7 +447,7 @@ func TestNewWithTLSDialer(t *testing.T) {
|
|||
tlsDialer := &netxlite.TLSDialer{
|
||||
Config: new(tls.Config),
|
||||
Dialer: netx.FakeDialer{Err: expected},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
}
|
||||
txp := netx.NewHTTPTransport(netx.Config{
|
||||
TLSDialer: tlsDialer,
|
||||
|
|
|
@ -16,7 +16,7 @@ func TestTLSDialerSuccess(t *testing.T) {
|
|||
log.SetLevel(log.DebugLevel)
|
||||
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
|
||||
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Logger: log.Log,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
|||
Config: &tls.Config{NextProtos: nextprotos},
|
||||
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
|
@ -119,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
|||
Config: &tls.Config{NextProtos: nextprotos},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
|
@ -184,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
|||
tlsdlr := &netxlite.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
|
@ -217,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
|||
tlsdlr := &netxlite.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
|
@ -250,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
|||
tlsdlr := &netxlite.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
|
@ -284,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
|||
Config: &tls.Config{InsecureSkipVerify: true},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
func TestSystemTLSHandshakerEOFError(t *testing.T) {
|
||||
h := &netxlite.TLSHandshakerStdlib{}
|
||||
h := &netxlite.TLSHandshakerConfigurable{}
|
||||
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "x.org",
|
||||
})
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// QUICDialerContext is a dialer for QUIC using Context.
|
||||
// QUICContextDialer is a dialer for QUIC using Context.
|
||||
type QUICContextDialer interface {
|
||||
// DialContext establishes a new QUIC session using the given
|
||||
// network and address. The tlsConfig and the quicConfig arguments
|
||||
|
@ -39,6 +39,11 @@ func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error)
|
|||
type QUICDialerQUICGo struct {
|
||||
// QUICListener is the underlying QUICListener to use.
|
||||
QUICListener QUICListener
|
||||
|
||||
// mockDialEarlyContext allows to mock quic.DialEarlyContext.
|
||||
mockDialEarlyContext func(ctx context.Context, pconn net.PacketConn,
|
||||
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||
quicConfig *quic.Config) (quic.EarlySession, error)
|
||||
}
|
||||
|
||||
var _ QUICContextDialer = &QUICDialerQUICGo{}
|
||||
|
@ -46,7 +51,14 @@ var _ QUICContextDialer = &QUICDialerQUICGo{}
|
|||
// errInvalidIP indicates that a string is not a valid IP.
|
||||
var errInvalidIP = errors.New("netxlite: invalid IP")
|
||||
|
||||
// DialContext implements ContextDialer.DialContext
|
||||
// DialContext implements ContextDialer.DialContext. This function will
|
||||
// apply the following TLS defaults:
|
||||
//
|
||||
// 1. if tlsConfig.RootCAs is nil, we use the Mozilla CA that we
|
||||
// bundle with this measurement library;
|
||||
//
|
||||
// 2. if tlsConfig.NextProtos is empty _and_ the port is 443 or 8853,
|
||||
// then we configure, respectively, "h3" and "dq".
|
||||
func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
||||
address string, tlsConfig *tls.Config, quicConfig *quic.Config) (
|
||||
quic.EarlySession, error) {
|
||||
|
@ -67,7 +79,8 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
|||
return nil, err
|
||||
}
|
||||
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
|
||||
sess, err := quic.DialEarlyContext(
|
||||
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, port)
|
||||
sess, err := d.dialEarlyContext(
|
||||
ctx, pconn, udpAddr, address, tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -75,6 +88,36 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
|||
return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil
|
||||
}
|
||||
|
||||
func (d *QUICDialerQUICGo) dialEarlyContext(ctx context.Context,
|
||||
pconn net.PacketConn, remoteAddr net.Addr, address string,
|
||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
if d.mockDialEarlyContext != nil {
|
||||
return d.mockDialEarlyContext(
|
||||
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
|
||||
}
|
||||
return quic.DialEarlyContext(
|
||||
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
|
||||
}
|
||||
|
||||
// maybeApplyTLSDefaults ensures that we're using our certificate pool, if
|
||||
// needed, and that we use a suitable ALPN, if needed, for h3 and dq.
|
||||
func (d *QUICDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config {
|
||||
config = config.Clone()
|
||||
if config.RootCAs == nil {
|
||||
config.RootCAs = defaultCertPool
|
||||
}
|
||||
if len(config.NextProtos) <= 0 {
|
||||
switch port {
|
||||
case 443:
|
||||
config.NextProtos = []string{"h3"}
|
||||
case 8853:
|
||||
// See https://datatracker.ietf.org/doc/html/draft-ietf-dprive-dnsoquic-02#section-10
|
||||
config.NextProtos = []string{"dq"}
|
||||
}
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// quicSessionOwnsConn ensures that we close the PacketConn.
|
||||
type quicSessionOwnsConn struct {
|
||||
// EarlySession is the embedded early session
|
||||
|
@ -102,7 +145,11 @@ type QUICDialerResolver struct {
|
|||
Resolver Resolver
|
||||
}
|
||||
|
||||
// DialContext implements QUICContextDialer.DialContext
|
||||
// DialContext implements QUICContextDialer.DialContext. This function
|
||||
// will apply the following TLS defaults:
|
||||
//
|
||||
// 1. if tlsConfig.ServerName is empty, we will use the hostname
|
||||
// contained inside of the `address` endpoint.
|
||||
func (d *QUICDialerResolver) DialContext(
|
||||
ctx context.Context, network, address string,
|
||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
|
@ -110,15 +157,11 @@ func (d *QUICDialerResolver) DialContext(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(kelmenhorst): Should this be somewhere else?
|
||||
// failure if tlsCfg is nil but that should not happen
|
||||
if tlsConfig.ServerName == "" {
|
||||
tlsConfig.ServerName = onlyhost
|
||||
}
|
||||
addrs, err := d.lookupHost(ctx, onlyhost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
|
||||
// TODO(bassosimone): here we should be using multierror rather
|
||||
// than just calling ReduceErrors. We are not ready to do that
|
||||
// yet, though. To do that, we need first to modify nettests so
|
||||
|
@ -136,6 +179,15 @@ func (d *QUICDialerResolver) DialContext(
|
|||
return nil, reduceErrors(errorslist)
|
||||
}
|
||||
|
||||
// maybeApplyTLSDefaults sets the SNI if it's not already configured.
|
||||
func (d *QUICDialerResolver) maybeApplyTLSDefaults(config *tls.Config, host string) *tls.Config {
|
||||
config = config.Clone()
|
||||
if config.ServerName == "" {
|
||||
config.ServerName = host
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
||||
// lookupHost performs a domain name resolution.
|
||||
func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
|
|
|
@ -9,13 +9,13 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
||||
)
|
||||
|
||||
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h3"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
|
@ -34,7 +34,6 @@ func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
|
|||
|
||||
func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h3"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
|
@ -53,7 +52,6 @@ func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
|
|||
|
||||
func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h3"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
|
@ -73,7 +71,6 @@ func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
|
|||
func TestQUICDialerQUICGoCannotListen(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h3"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
|
@ -94,9 +91,8 @@ func TestQUICDialerQUICGoCannotListen(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerCannotPerformHandshake(t *testing.T) {
|
||||
func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) {
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h3"},
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
|
@ -114,9 +110,8 @@ func TestQUICDialerCannotPerformHandshake(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerWorksAsIntended(t *testing.T) {
|
||||
func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) {
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h3"},
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
|
@ -134,8 +129,90 @@ func TestQUICDialerWorksAsIntended(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var gotTLSConfig *tls.Config
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
QUICListener: &QUICListenerStdlib{},
|
||||
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
gotTLSConfig = tlsConfig
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
sess, err := systemdialer.DialContext(
|
||||
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil session here")
|
||||
}
|
||||
if tlsConfig.RootCAs != nil {
|
||||
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
||||
}
|
||||
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid gotTLSConfig.RootCAs")
|
||||
}
|
||||
if tlsConfig.NextProtos != nil {
|
||||
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
||||
}
|
||||
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" {
|
||||
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
||||
}
|
||||
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
||||
t.Fatal("the ServerName field must match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var gotTLSConfig *tls.Config
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
systemdialer := QUICDialerQUICGo{
|
||||
QUICListener: &QUICListenerStdlib{},
|
||||
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
gotTLSConfig = tlsConfig
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
sess, err := systemdialer.DialContext(
|
||||
ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil session here")
|
||||
}
|
||||
if tlsConfig.RootCAs != nil {
|
||||
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
||||
}
|
||||
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid gotTLSConfig.RootCAs")
|
||||
}
|
||||
if tlsConfig.NextProtos != nil {
|
||||
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
||||
}
|
||||
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" {
|
||||
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
||||
}
|
||||
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
||||
t.Fatal("the ServerName field must match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerResolverSuccess(t *testing.T) {
|
||||
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
|
||||
tlsConfig := &tls.Config{}
|
||||
dialer := &QUICDialerResolver{
|
||||
Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{
|
||||
QUICListener: &QUICListenerStdlib{},
|
||||
|
@ -153,7 +230,7 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQUICDialerResolverNoPort(t *testing.T) {
|
||||
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
|
||||
tlsConfig := &tls.Config{}
|
||||
dialer := &QUICDialerResolver{
|
||||
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}}
|
||||
sess, err := dialer.DialContext(
|
||||
|
@ -185,7 +262,7 @@ func TestQUICDialerResolverLookupHostAddress(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
|
||||
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
|
||||
tlsConfig := &tls.Config{}
|
||||
expected := errors.New("mocked error")
|
||||
dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
|
@ -206,7 +283,7 @@ func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
|
|||
func TestQUICDialerResolverInvalidPort(t *testing.T) {
|
||||
// This test allows us to check for the case where every attempt
|
||||
// to establish a connection leads to a failure
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3"}}
|
||||
tlsConf := &tls.Config{}
|
||||
dialer := &QUICDialerResolver{
|
||||
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{
|
||||
QUICListener: &QUICListenerStdlib{},
|
||||
|
@ -225,3 +302,32 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
|
|||
t.Fatal("expected nil sess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var gotTLSConfig *tls.Config
|
||||
tlsConfig := &tls.Config{}
|
||||
dialer := &QUICDialerResolver{
|
||||
Resolver: new(net.Resolver), Dialer: &netxmocks.QUICContextDialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string,
|
||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
gotTLSConfig = tlsConfig
|
||||
return nil, expected
|
||||
},
|
||||
}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "www.google.com:443",
|
||||
tlsConfig, &quic.Config{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil session here")
|
||||
}
|
||||
if tlsConfig.ServerName != "" {
|
||||
t.Fatal("should not have changed tlsConfig.ServerName")
|
||||
}
|
||||
if gotTLSConfig.ServerName != "www.google.com" {
|
||||
t.Fatal("gotTLSConfig.ServerName has not been set")
|
||||
}
|
||||
}
|
||||
|
|
18
internal/netxlite/tlsconn.go
Normal file
18
internal/netxlite/tlsconn.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package netxlite
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
// TLSConn is any tls.Conn-like structure.
|
||||
type TLSConn interface {
|
||||
// net.Conn is the embedded conn.
|
||||
net.Conn
|
||||
|
||||
// ConnectionState returns the TLS connection state.
|
||||
ConnectionState() tls.ConnectionState
|
||||
|
||||
// Handshake performs the handshake.
|
||||
Handshake() error
|
||||
}
|
|
@ -43,8 +43,6 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
|
|||
// We set the ServerName field if not already set.
|
||||
//
|
||||
// We set the ALPN if the port is 443 or 853, if not already set.
|
||||
//
|
||||
// We force using our root CA, unless it's already set.
|
||||
func (d *TLSDialer) config(host, port string) *tls.Config {
|
||||
config := d.Config
|
||||
if config == nil {
|
||||
|
@ -62,8 +60,5 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
|
|||
config.NextProtos = []string{"dot"}
|
||||
}
|
||||
}
|
||||
if config.RootCAs == nil {
|
||||
config.RootCAs = NewDefaultCertPool()
|
||||
}
|
||||
return config
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package netxlite
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -54,7 +53,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) {
|
|||
return nil
|
||||
}}, nil
|
||||
}},
|
||||
TLSHandshaker: &TLSHandshakerStdlib{},
|
||||
TLSHandshaker: &TLSHandshakerConfigurable{},
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
|
@ -99,9 +98,6 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
|
|||
if config.ServerName != "www.google.com" {
|
||||
t.Fatal("invalid server name")
|
||||
}
|
||||
if config.RootCAs == nil {
|
||||
t.Fatal("expected non-nil root CAs")
|
||||
}
|
||||
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
@ -113,9 +109,6 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
|
|||
if config.ServerName != "dns.google" {
|
||||
t.Fatal("invalid server name")
|
||||
}
|
||||
if config.RootCAs == nil {
|
||||
t.Fatal("expected non-nil root CAs")
|
||||
}
|
||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
@ -131,9 +124,6 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
|
|||
if config.ServerName != "example.com" {
|
||||
t.Fatal("invalid server name")
|
||||
}
|
||||
if config.RootCAs == nil {
|
||||
t.Fatal("expected non-nil root CAs")
|
||||
}
|
||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
|
@ -149,29 +139,7 @@ func TestTLSDialerConfigWithALPN(t *testing.T) {
|
|||
if config.ServerName != "dns.google" {
|
||||
t.Fatal("invalid server name")
|
||||
}
|
||||
if config.RootCAs == nil {
|
||||
t.Fatal("expected non-nil root CAs")
|
||||
}
|
||||
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerConfigWithRootCA(t *testing.T) {
|
||||
pool := &x509.CertPool{}
|
||||
d := &TLSDialer{
|
||||
Config: &tls.Config{
|
||||
RootCAs: pool,
|
||||
},
|
||||
}
|
||||
config := d.config("dns.google", "853")
|
||||
if config.ServerName != "dns.google" {
|
||||
t.Fatal("invalid server name")
|
||||
}
|
||||
if config.RootCAs != pool {
|
||||
t.Fatal("not the RootCAs we expected")
|
||||
}
|
||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,17 +16,33 @@ type TLSHandshaker interface {
|
|||
net.Conn, tls.ConnectionState, error)
|
||||
}
|
||||
|
||||
// TLSHandshakerStdlib is the stdlib's TLS handshaker.
|
||||
type TLSHandshakerStdlib struct {
|
||||
// Timeout is the timeout imposed on the TLS handshake. If zero
|
||||
// TLSHandshakerConfigurable is a configurable TLS handshaker that
|
||||
// uses by default the standard library's TLS implementation.
|
||||
type TLSHandshakerConfigurable struct {
|
||||
// NewConn is the OPTIONAL factory for creating a new connection. If
|
||||
// this factory is not set, we'll use the stdlib.
|
||||
NewConn func(conn net.Conn, config *tls.Config) TLSConn
|
||||
|
||||
// Timeout is the OPTIONAL timeout imposed on the TLS handshake. If zero
|
||||
// or negative, we will use default timeout of 10 seconds.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
var _ TLSHandshaker = &TLSHandshakerStdlib{}
|
||||
var _ TLSHandshaker = &TLSHandshakerConfigurable{}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h *TLSHandshakerStdlib) Handshake(
|
||||
// defaultCertPool is the cert pool we use by default. We store this
|
||||
// value into a private variable to enable for unit testing.
|
||||
var defaultCertPool = NewDefaultCertPool()
|
||||
|
||||
// Handshake implements Handshaker.Handshake. This function will
|
||||
// configure the code to use the built-in Mozilla CA if the config
|
||||
// field contains a nil RootCAs field.
|
||||
//
|
||||
// Bug
|
||||
//
|
||||
// Until Go 1.17 is released, this function will not honour
|
||||
// the context. We'll however always enforce an overall timeout.
|
||||
func (h *TLSHandshakerConfigurable) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
timeout := h.Timeout
|
||||
|
@ -35,15 +51,27 @@ func (h *TLSHandshakerStdlib) Handshake(
|
|||
}
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
tlsconn := tls.Client(conn, config)
|
||||
if config.RootCAs == nil {
|
||||
config = config.Clone()
|
||||
config.RootCAs = defaultCertPool
|
||||
}
|
||||
tlsconn := h.newConn(conn, config)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
}
|
||||
|
||||
// newConn creates a new TLSConn.
|
||||
func (h *TLSHandshakerConfigurable) newConn(conn net.Conn, config *tls.Config) TLSConn {
|
||||
if h.NewConn != nil {
|
||||
return h.NewConn(conn, config)
|
||||
}
|
||||
return tls.Client(conn, config)
|
||||
}
|
||||
|
||||
// DefaultTLSHandshaker is the default TLS handshaker.
|
||||
var DefaultTLSHandshaker = &TLSHandshakerStdlib{}
|
||||
var DefaultTLSHandshaker = &TLSHandshakerConfigurable{}
|
||||
|
||||
// TLSHandshakerLogger is a TLSHandshaker with logging.
|
||||
type TLSHandshakerLogger struct {
|
||||
|
|
|
@ -17,9 +17,9 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
||||
)
|
||||
|
||||
func TestTLSHandshakerStdlibWithError(t *testing.T) {
|
||||
func TestTLSHandshakerConfigurableWithError(t *testing.T) {
|
||||
var times []time.Time
|
||||
h := &TLSHandshakerStdlib{}
|
||||
h := &TLSHandshakerConfigurable{}
|
||||
tcpConn := &netxmocks.Conn{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
|
@ -50,7 +50,7 @@ func TestTLSHandshakerStdlibWithError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
||||
func TestTLSHandshakerConfigurableSuccess(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
})
|
||||
|
@ -65,7 +65,7 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
handshaker := &TLSHandshakerStdlib{}
|
||||
handshaker := &TLSHandshakerConfigurable{}
|
||||
ctx := context.Background()
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
|
@ -83,6 +83,44 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var gotTLSConfig *tls.Config
|
||||
handshaker := &TLSHandshakerConfigurable{
|
||||
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
|
||||
gotTLSConfig = config
|
||||
return &netxmocks.TLSConn{
|
||||
MockHandshake: func() error {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
config := &tls.Config{}
|
||||
conn := &netxmocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if !reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected zero connState here")
|
||||
}
|
||||
if tlsConn != nil {
|
||||
t.Fatal("expected nil tlsConn here")
|
||||
}
|
||||
if config.RootCAs != nil {
|
||||
t.Fatal("config.RootCAs should still be nil")
|
||||
}
|
||||
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerLoggerSuccess(t *testing.T) {
|
||||
th := &TLSHandshakerLogger{
|
||||
TLSHandshaker: &netxmocks.TLSHandshaker{
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
package netxmocks
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// QUICListener is a mockable netxlite.QUICListener.
|
||||
type QUICListener struct {
|
||||
|
@ -11,3 +17,15 @@ type QUICListener struct {
|
|||
func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) {
|
||||
return ql.MockListen(addr)
|
||||
}
|
||||
|
||||
// QUICContextDialer is a mockable netxlite.QUICContextDialer.
|
||||
type QUICContextDialer struct {
|
||||
MockDialContext func(ctx context.Context, network, address string,
|
||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error)
|
||||
}
|
||||
|
||||
// DialContext calls MockDialContext.
|
||||
func (qcd *QUICContextDialer) DialContext(ctx context.Context, network, address string,
|
||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
return qcd.MockDialContext(ctx, network, address, tlsConfig, quicConfig)
|
||||
}
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
package netxmocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
func TestQUICListenerListen(t *testing.T) {
|
||||
|
@ -21,3 +25,22 @@ func TestQUICListenerListen(t *testing.T) {
|
|||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICContextDialerDialContext(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
qcd := &QUICContextDialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
tlsConfig := &tls.Config{}
|
||||
quicConfig := &quic.Config{}
|
||||
sess, err := qcd.DialContext(ctx, "udp", "dns.google:443", tlsConfig, quicConfig)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil session")
|
||||
}
|
||||
}
|
||||
|
|
25
internal/netxmocks/tlsconn.go
Normal file
25
internal/netxmocks/tlsconn.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package netxmocks
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
// TLSConn allows to mock netxlite.TLSConn.
|
||||
type TLSConn struct {
|
||||
// Conn is the embedded mockable Conn.
|
||||
Conn
|
||||
|
||||
// MockConnectionState allows to mock the ConnectionState method.
|
||||
MockConnectionState func() tls.ConnectionState
|
||||
|
||||
// MockHandshake allows to mock the Handshake method.
|
||||
MockHandshake func() error
|
||||
}
|
||||
|
||||
// ConnectionState calls MockConnectionState.
|
||||
func (c *TLSConn) ConnectionState() tls.ConnectionState {
|
||||
return c.MockConnectionState()
|
||||
}
|
||||
|
||||
// Handshake calls MockHandshake.
|
||||
func (c *TLSConn) Handshake() error {
|
||||
return c.MockHandshake()
|
||||
}
|
Loading…
Reference in New Issue
Block a user