fix(netxlite): improve TLS auto-configuration (#409)
Auto-configure every relevant TLS field as close as possible to where it's actually used. As a side effect, add support for mocking the creation of a TLS connection, which should possibly be useful for uTLS? Work that is part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
parent
f1f5ed342e
commit
b07890af4d
|
@ -108,7 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
|
||||||
Dialer: d,
|
Dialer: d,
|
||||||
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
||||||
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
|
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -184,7 +184,7 @@ func NewTLSDialer(config Config) TLSDialer {
|
||||||
if config.Dialer == nil {
|
if config.Dialer == nil {
|
||||||
config.Dialer = NewDialer(config)
|
config.Dialer = NewDialer(config)
|
||||||
}
|
}
|
||||||
var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{}
|
var h tlsHandshaker = &netxlite.TLSHandshakerConfigurable{}
|
||||||
h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
|
h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
|
||||||
if config.Logger != nil {
|
if config.Logger != nil {
|
||||||
h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h}
|
h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h}
|
||||||
|
|
|
@ -234,7 +234,7 @@ func TestNewTLSDialerVanilla(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -263,7 +263,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -302,7 +302,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -342,7 +342,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -375,7 +375,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -410,7 +410,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
|
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -447,7 +447,7 @@ func TestNewWithTLSDialer(t *testing.T) {
|
||||||
tlsDialer := &netxlite.TLSDialer{
|
tlsDialer := &netxlite.TLSDialer{
|
||||||
Config: new(tls.Config),
|
Config: new(tls.Config),
|
||||||
Dialer: netx.FakeDialer{Err: expected},
|
Dialer: netx.FakeDialer{Err: expected},
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
}
|
}
|
||||||
txp := netx.NewHTTPTransport(netx.Config{
|
txp := netx.NewHTTPTransport(netx.Config{
|
||||||
TLSDialer: tlsDialer,
|
TLSDialer: tlsDialer,
|
||||||
|
|
|
@ -16,7 +16,7 @@ func TestTLSDialerSuccess(t *testing.T) {
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
|
dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer),
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Logger: log.Log,
|
Logger: log.Log,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
Config: &tls.Config{NextProtos: nextprotos},
|
||||||
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
Config: &tls.Config{NextProtos: nextprotos},
|
||||||
Dialer: new(net.Dialer),
|
Dialer: new(net.Dialer),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -184,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Dialer: new(net.Dialer),
|
Dialer: new(net.Dialer),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -217,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Dialer: new(net.Dialer),
|
Dialer: new(net.Dialer),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -250,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialer{
|
||||||
Dialer: new(net.Dialer),
|
Dialer: new(net.Dialer),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -284,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
||||||
Config: &tls.Config{InsecureSkipVerify: true},
|
Config: &tls.Config{InsecureSkipVerify: true},
|
||||||
Dialer: new(net.Dialer),
|
Dialer: new(net.Dialer),
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSystemTLSHandshakerEOFError(t *testing.T) {
|
func TestSystemTLSHandshakerEOFError(t *testing.T) {
|
||||||
h := &netxlite.TLSHandshakerStdlib{}
|
h := &netxlite.TLSHandshakerConfigurable{}
|
||||||
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
|
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
|
||||||
ServerName: "x.org",
|
ServerName: "x.org",
|
||||||
})
|
})
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QUICDialerContext is a dialer for QUIC using Context.
|
// QUICContextDialer is a dialer for QUIC using Context.
|
||||||
type QUICContextDialer interface {
|
type QUICContextDialer interface {
|
||||||
// DialContext establishes a new QUIC session using the given
|
// DialContext establishes a new QUIC session using the given
|
||||||
// network and address. The tlsConfig and the quicConfig arguments
|
// network and address. The tlsConfig and the quicConfig arguments
|
||||||
|
@ -39,6 +39,11 @@ func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error)
|
||||||
type QUICDialerQUICGo struct {
|
type QUICDialerQUICGo struct {
|
||||||
// QUICListener is the underlying QUICListener to use.
|
// QUICListener is the underlying QUICListener to use.
|
||||||
QUICListener QUICListener
|
QUICListener QUICListener
|
||||||
|
|
||||||
|
// mockDialEarlyContext allows to mock quic.DialEarlyContext.
|
||||||
|
mockDialEarlyContext func(ctx context.Context, pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ QUICContextDialer = &QUICDialerQUICGo{}
|
var _ QUICContextDialer = &QUICDialerQUICGo{}
|
||||||
|
@ -46,7 +51,14 @@ var _ QUICContextDialer = &QUICDialerQUICGo{}
|
||||||
// errInvalidIP indicates that a string is not a valid IP.
|
// errInvalidIP indicates that a string is not a valid IP.
|
||||||
var errInvalidIP = errors.New("netxlite: invalid IP")
|
var errInvalidIP = errors.New("netxlite: invalid IP")
|
||||||
|
|
||||||
// DialContext implements ContextDialer.DialContext
|
// DialContext implements ContextDialer.DialContext. This function will
|
||||||
|
// apply the following TLS defaults:
|
||||||
|
//
|
||||||
|
// 1. if tlsConfig.RootCAs is nil, we use the Mozilla CA that we
|
||||||
|
// bundle with this measurement library;
|
||||||
|
//
|
||||||
|
// 2. if tlsConfig.NextProtos is empty _and_ the port is 443 or 8853,
|
||||||
|
// then we configure, respectively, "h3" and "dq".
|
||||||
func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
||||||
address string, tlsConfig *tls.Config, quicConfig *quic.Config) (
|
address string, tlsConfig *tls.Config, quicConfig *quic.Config) (
|
||||||
quic.EarlySession, error) {
|
quic.EarlySession, error) {
|
||||||
|
@ -67,7 +79,8 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
|
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
|
||||||
sess, err := quic.DialEarlyContext(
|
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, port)
|
||||||
|
sess, err := d.dialEarlyContext(
|
||||||
ctx, pconn, udpAddr, address, tlsConfig, quicConfig)
|
ctx, pconn, udpAddr, address, tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -75,6 +88,36 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string,
|
||||||
return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil
|
return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *QUICDialerQUICGo) dialEarlyContext(ctx context.Context,
|
||||||
|
pconn net.PacketConn, remoteAddr net.Addr, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
if d.mockDialEarlyContext != nil {
|
||||||
|
return d.mockDialEarlyContext(
|
||||||
|
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
|
||||||
|
}
|
||||||
|
return quic.DialEarlyContext(
|
||||||
|
ctx, pconn, remoteAddr, address, tlsConfig, quicConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeApplyTLSDefaults ensures that we're using our certificate pool, if
|
||||||
|
// needed, and that we use a suitable ALPN, if needed, for h3 and dq.
|
||||||
|
func (d *QUICDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config {
|
||||||
|
config = config.Clone()
|
||||||
|
if config.RootCAs == nil {
|
||||||
|
config.RootCAs = defaultCertPool
|
||||||
|
}
|
||||||
|
if len(config.NextProtos) <= 0 {
|
||||||
|
switch port {
|
||||||
|
case 443:
|
||||||
|
config.NextProtos = []string{"h3"}
|
||||||
|
case 8853:
|
||||||
|
// See https://datatracker.ietf.org/doc/html/draft-ietf-dprive-dnsoquic-02#section-10
|
||||||
|
config.NextProtos = []string{"dq"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
// quicSessionOwnsConn ensures that we close the PacketConn.
|
// quicSessionOwnsConn ensures that we close the PacketConn.
|
||||||
type quicSessionOwnsConn struct {
|
type quicSessionOwnsConn struct {
|
||||||
// EarlySession is the embedded early session
|
// EarlySession is the embedded early session
|
||||||
|
@ -102,7 +145,11 @@ type QUICDialerResolver struct {
|
||||||
Resolver Resolver
|
Resolver Resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext implements QUICContextDialer.DialContext
|
// DialContext implements QUICContextDialer.DialContext. This function
|
||||||
|
// will apply the following TLS defaults:
|
||||||
|
//
|
||||||
|
// 1. if tlsConfig.ServerName is empty, we will use the hostname
|
||||||
|
// contained inside of the `address` endpoint.
|
||||||
func (d *QUICDialerResolver) DialContext(
|
func (d *QUICDialerResolver) DialContext(
|
||||||
ctx context.Context, network, address string,
|
ctx context.Context, network, address string,
|
||||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
@ -110,15 +157,11 @@ func (d *QUICDialerResolver) DialContext(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// TODO(kelmenhorst): Should this be somewhere else?
|
|
||||||
// failure if tlsCfg is nil but that should not happen
|
|
||||||
if tlsConfig.ServerName == "" {
|
|
||||||
tlsConfig.ServerName = onlyhost
|
|
||||||
}
|
|
||||||
addrs, err := d.lookupHost(ctx, onlyhost)
|
addrs, err := d.lookupHost(ctx, onlyhost)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
|
||||||
// TODO(bassosimone): here we should be using multierror rather
|
// TODO(bassosimone): here we should be using multierror rather
|
||||||
// than just calling ReduceErrors. We are not ready to do that
|
// than just calling ReduceErrors. We are not ready to do that
|
||||||
// yet, though. To do that, we need first to modify nettests so
|
// yet, though. To do that, we need first to modify nettests so
|
||||||
|
@ -136,6 +179,15 @@ func (d *QUICDialerResolver) DialContext(
|
||||||
return nil, reduceErrors(errorslist)
|
return nil, reduceErrors(errorslist)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maybeApplyTLSDefaults sets the SNI if it's not already configured.
|
||||||
|
func (d *QUICDialerResolver) maybeApplyTLSDefaults(config *tls.Config, host string) *tls.Config {
|
||||||
|
config = config.Clone()
|
||||||
|
if config.ServerName == "" {
|
||||||
|
config.ServerName = host
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
// lookupHost performs a domain name resolution.
|
// lookupHost performs a domain name resolution.
|
||||||
func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
if net.ParseIP(hostname) != nil {
|
if net.ParseIP(hostname) != nil {
|
||||||
|
|
|
@ -9,13 +9,13 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
|
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := QUICDialerQUICGo{
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
@ -34,7 +34,6 @@ func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
|
||||||
|
|
||||||
func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
|
func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := QUICDialerQUICGo{
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
@ -53,7 +52,6 @@ func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
|
||||||
|
|
||||||
func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
|
func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := QUICDialerQUICGo{
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
@ -73,7 +71,6 @@ func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
|
||||||
func TestQUICDialerQUICGoCannotListen(t *testing.T) {
|
func TestQUICDialerQUICGoCannotListen(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := QUICDialerQUICGo{
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
@ -94,9 +91,8 @@ func TestQUICDialerQUICGoCannotListen(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICDialerCannotPerformHandshake(t *testing.T) {
|
func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: "dns.google",
|
ServerName: "dns.google",
|
||||||
}
|
}
|
||||||
systemdialer := QUICDialerQUICGo{
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
@ -114,9 +110,8 @@ func TestQUICDialerCannotPerformHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICDialerWorksAsIntended(t *testing.T) {
|
func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: "dns.google",
|
ServerName: "dns.google",
|
||||||
}
|
}
|
||||||
systemdialer := QUICDialerQUICGo{
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
@ -134,8 +129,90 @@ func TestQUICDialerWorksAsIntended(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
QUICListener: &QUICListenerStdlib{},
|
||||||
|
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
gotTLSConfig = tlsConfig
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
sess, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
if tlsConfig.RootCAs != nil {
|
||||||
|
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||||
|
t.Fatal("invalid gotTLSConfig.RootCAs")
|
||||||
|
}
|
||||||
|
if tlsConfig.NextProtos != nil {
|
||||||
|
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" {
|
||||||
|
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
||||||
|
}
|
||||||
|
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
||||||
|
t.Fatal("the ServerName field must match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
systemdialer := QUICDialerQUICGo{
|
||||||
|
QUICListener: &QUICListenerStdlib{},
|
||||||
|
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
gotTLSConfig = tlsConfig
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
sess, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
if tlsConfig.RootCAs != nil {
|
||||||
|
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||||
|
t.Fatal("invalid gotTLSConfig.RootCAs")
|
||||||
|
}
|
||||||
|
if tlsConfig.NextProtos != nil {
|
||||||
|
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" {
|
||||||
|
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
||||||
|
}
|
||||||
|
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
||||||
|
t.Fatal("the ServerName field must match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQUICDialerResolverSuccess(t *testing.T) {
|
func TestQUICDialerResolverSuccess(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
|
tlsConfig := &tls.Config{}
|
||||||
dialer := &QUICDialerResolver{
|
dialer := &QUICDialerResolver{
|
||||||
Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{
|
Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{
|
||||||
QUICListener: &QUICListenerStdlib{},
|
QUICListener: &QUICListenerStdlib{},
|
||||||
|
@ -153,7 +230,7 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICDialerResolverNoPort(t *testing.T) {
|
func TestQUICDialerResolverNoPort(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
|
tlsConfig := &tls.Config{}
|
||||||
dialer := &QUICDialerResolver{
|
dialer := &QUICDialerResolver{
|
||||||
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}}
|
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}}
|
||||||
sess, err := dialer.DialContext(
|
sess, err := dialer.DialContext(
|
||||||
|
@ -185,7 +262,7 @@ func TestQUICDialerResolverLookupHostAddress(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
|
func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{NextProtos: []string{"h3"}}
|
tlsConfig := &tls.Config{}
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{
|
dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
@ -206,7 +283,7 @@ func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
|
||||||
func TestQUICDialerResolverInvalidPort(t *testing.T) {
|
func TestQUICDialerResolverInvalidPort(t *testing.T) {
|
||||||
// This test allows us to check for the case where every attempt
|
// This test allows us to check for the case where every attempt
|
||||||
// to establish a connection leads to a failure
|
// to establish a connection leads to a failure
|
||||||
tlsConf := &tls.Config{NextProtos: []string{"h3"}}
|
tlsConf := &tls.Config{}
|
||||||
dialer := &QUICDialerResolver{
|
dialer := &QUICDialerResolver{
|
||||||
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{
|
Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{
|
||||||
QUICListener: &QUICListenerStdlib{},
|
QUICListener: &QUICListenerStdlib{},
|
||||||
|
@ -225,3 +302,32 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
|
||||||
t.Fatal("expected nil sess")
|
t.Fatal("expected nil sess")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
dialer := &QUICDialerResolver{
|
||||||
|
Resolver: new(net.Resolver), Dialer: &netxmocks.QUICContextDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
gotTLSConfig = tlsConfig
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
sess, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "www.google.com:443",
|
||||||
|
tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
if tlsConfig.ServerName != "" {
|
||||||
|
t.Fatal("should not have changed tlsConfig.ServerName")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.ServerName != "www.google.com" {
|
||||||
|
t.Fatal("gotTLSConfig.ServerName has not been set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
18
internal/netxlite/tlsconn.go
Normal file
18
internal/netxlite/tlsconn.go
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TLSConn is any tls.Conn-like structure.
|
||||||
|
type TLSConn interface {
|
||||||
|
// net.Conn is the embedded conn.
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
// ConnectionState returns the TLS connection state.
|
||||||
|
ConnectionState() tls.ConnectionState
|
||||||
|
|
||||||
|
// Handshake performs the handshake.
|
||||||
|
Handshake() error
|
||||||
|
}
|
|
@ -43,8 +43,6 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
|
||||||
// We set the ServerName field if not already set.
|
// We set the ServerName field if not already set.
|
||||||
//
|
//
|
||||||
// We set the ALPN if the port is 443 or 853, if not already set.
|
// We set the ALPN if the port is 443 or 853, if not already set.
|
||||||
//
|
|
||||||
// We force using our root CA, unless it's already set.
|
|
||||||
func (d *TLSDialer) config(host, port string) *tls.Config {
|
func (d *TLSDialer) config(host, port string) *tls.Config {
|
||||||
config := d.Config
|
config := d.Config
|
||||||
if config == nil {
|
if config == nil {
|
||||||
|
@ -62,8 +60,5 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
|
||||||
config.NextProtos = []string{"dot"}
|
config.NextProtos = []string{"dot"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
|
||||||
config.RootCAs = NewDefaultCertPool()
|
|
||||||
}
|
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package netxlite
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
@ -54,7 +53,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
}}, nil
|
}}, nil
|
||||||
}},
|
}},
|
||||||
TLSHandshaker: &TLSHandshakerStdlib{},
|
TLSHandshaker: &TLSHandshakerConfigurable{},
|
||||||
}
|
}
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
|
@ -99,9 +98,6 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
|
||||||
if config.ServerName != "www.google.com" {
|
if config.ServerName != "www.google.com" {
|
||||||
t.Fatal("invalid server name")
|
t.Fatal("invalid server name")
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
|
||||||
t.Fatal("expected non-nil root CAs")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
|
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
|
||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
|
@ -113,9 +109,6 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
|
||||||
if config.ServerName != "dns.google" {
|
if config.ServerName != "dns.google" {
|
||||||
t.Fatal("invalid server name")
|
t.Fatal("invalid server name")
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
|
||||||
t.Fatal("expected non-nil root CAs")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
|
@ -131,9 +124,6 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
|
||||||
if config.ServerName != "example.com" {
|
if config.ServerName != "example.com" {
|
||||||
t.Fatal("invalid server name")
|
t.Fatal("invalid server name")
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
|
||||||
t.Fatal("expected non-nil root CAs")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
|
@ -149,29 +139,7 @@ func TestTLSDialerConfigWithALPN(t *testing.T) {
|
||||||
if config.ServerName != "dns.google" {
|
if config.ServerName != "dns.google" {
|
||||||
t.Fatal("invalid server name")
|
t.Fatal("invalid server name")
|
||||||
}
|
}
|
||||||
if config.RootCAs == nil {
|
|
||||||
t.Fatal("expected non-nil root CAs")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
|
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
|
||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSDialerConfigWithRootCA(t *testing.T) {
|
|
||||||
pool := &x509.CertPool{}
|
|
||||||
d := &TLSDialer{
|
|
||||||
Config: &tls.Config{
|
|
||||||
RootCAs: pool,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config := d.config("dns.google", "853")
|
|
||||||
if config.ServerName != "dns.google" {
|
|
||||||
t.Fatal("invalid server name")
|
|
||||||
}
|
|
||||||
if config.RootCAs != pool {
|
|
||||||
t.Fatal("not the RootCAs we expected")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -16,17 +16,33 @@ type TLSHandshaker interface {
|
||||||
net.Conn, tls.ConnectionState, error)
|
net.Conn, tls.ConnectionState, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSHandshakerStdlib is the stdlib's TLS handshaker.
|
// TLSHandshakerConfigurable is a configurable TLS handshaker that
|
||||||
type TLSHandshakerStdlib struct {
|
// uses by default the standard library's TLS implementation.
|
||||||
// Timeout is the timeout imposed on the TLS handshake. If zero
|
type TLSHandshakerConfigurable struct {
|
||||||
|
// NewConn is the OPTIONAL factory for creating a new connection. If
|
||||||
|
// this factory is not set, we'll use the stdlib.
|
||||||
|
NewConn func(conn net.Conn, config *tls.Config) TLSConn
|
||||||
|
|
||||||
|
// Timeout is the OPTIONAL timeout imposed on the TLS handshake. If zero
|
||||||
// or negative, we will use default timeout of 10 seconds.
|
// or negative, we will use default timeout of 10 seconds.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TLSHandshaker = &TLSHandshakerStdlib{}
|
var _ TLSHandshaker = &TLSHandshakerConfigurable{}
|
||||||
|
|
||||||
// Handshake implements Handshaker.Handshake
|
// defaultCertPool is the cert pool we use by default. We store this
|
||||||
func (h *TLSHandshakerStdlib) Handshake(
|
// value into a private variable to enable for unit testing.
|
||||||
|
var defaultCertPool = NewDefaultCertPool()
|
||||||
|
|
||||||
|
// Handshake implements Handshaker.Handshake. This function will
|
||||||
|
// configure the code to use the built-in Mozilla CA if the config
|
||||||
|
// field contains a nil RootCAs field.
|
||||||
|
//
|
||||||
|
// Bug
|
||||||
|
//
|
||||||
|
// Until Go 1.17 is released, this function will not honour
|
||||||
|
// the context. We'll however always enforce an overall timeout.
|
||||||
|
func (h *TLSHandshakerConfigurable) Handshake(
|
||||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||||
) (net.Conn, tls.ConnectionState, error) {
|
) (net.Conn, tls.ConnectionState, error) {
|
||||||
timeout := h.Timeout
|
timeout := h.Timeout
|
||||||
|
@ -35,15 +51,27 @@ func (h *TLSHandshakerStdlib) Handshake(
|
||||||
}
|
}
|
||||||
defer conn.SetDeadline(time.Time{})
|
defer conn.SetDeadline(time.Time{})
|
||||||
conn.SetDeadline(time.Now().Add(timeout))
|
conn.SetDeadline(time.Now().Add(timeout))
|
||||||
tlsconn := tls.Client(conn, config)
|
if config.RootCAs == nil {
|
||||||
|
config = config.Clone()
|
||||||
|
config.RootCAs = defaultCertPool
|
||||||
|
}
|
||||||
|
tlsconn := h.newConn(conn, config)
|
||||||
if err := tlsconn.Handshake(); err != nil {
|
if err := tlsconn.Handshake(); err != nil {
|
||||||
return nil, tls.ConnectionState{}, err
|
return nil, tls.ConnectionState{}, err
|
||||||
}
|
}
|
||||||
return tlsconn, tlsconn.ConnectionState(), nil
|
return tlsconn, tlsconn.ConnectionState(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newConn creates a new TLSConn.
|
||||||
|
func (h *TLSHandshakerConfigurable) newConn(conn net.Conn, config *tls.Config) TLSConn {
|
||||||
|
if h.NewConn != nil {
|
||||||
|
return h.NewConn(conn, config)
|
||||||
|
}
|
||||||
|
return tls.Client(conn, config)
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultTLSHandshaker is the default TLS handshaker.
|
// DefaultTLSHandshaker is the default TLS handshaker.
|
||||||
var DefaultTLSHandshaker = &TLSHandshakerStdlib{}
|
var DefaultTLSHandshaker = &TLSHandshakerConfigurable{}
|
||||||
|
|
||||||
// TLSHandshakerLogger is a TLSHandshaker with logging.
|
// TLSHandshakerLogger is a TLSHandshaker with logging.
|
||||||
type TLSHandshakerLogger struct {
|
type TLSHandshakerLogger struct {
|
||||||
|
|
|
@ -17,9 +17,9 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTLSHandshakerStdlibWithError(t *testing.T) {
|
func TestTLSHandshakerConfigurableWithError(t *testing.T) {
|
||||||
var times []time.Time
|
var times []time.Time
|
||||||
h := &TLSHandshakerStdlib{}
|
h := &TLSHandshakerConfigurable{}
|
||||||
tcpConn := &netxmocks.Conn{
|
tcpConn := &netxmocks.Conn{
|
||||||
MockWrite: func(b []byte) (int, error) {
|
MockWrite: func(b []byte) (int, error) {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
@ -50,7 +50,7 @@ func TestTLSHandshakerStdlibWithError(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
func TestTLSHandshakerConfigurableSuccess(t *testing.T) {
|
||||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
rw.WriteHeader(200)
|
rw.WriteHeader(200)
|
||||||
})
|
})
|
||||||
|
@ -65,7 +65,7 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
handshaker := &TLSHandshakerStdlib{}
|
handshaker := &TLSHandshakerConfigurable{}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
|
@ -83,6 +83,44 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
handshaker := &TLSHandshakerConfigurable{
|
||||||
|
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
|
||||||
|
gotTLSConfig = config
|
||||||
|
return &netxmocks.TLSConn{
|
||||||
|
MockHandshake: func() error {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
config := &tls.Config{}
|
||||||
|
conn := &netxmocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
|
t.Fatal("expected zero connState here")
|
||||||
|
}
|
||||||
|
if tlsConn != nil {
|
||||||
|
t.Fatal("expected nil tlsConn here")
|
||||||
|
}
|
||||||
|
if config.RootCAs != nil {
|
||||||
|
t.Fatal("config.RootCAs should still be nil")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||||
|
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerLoggerSuccess(t *testing.T) {
|
func TestTLSHandshakerLoggerSuccess(t *testing.T) {
|
||||||
th := &TLSHandshakerLogger{
|
th := &TLSHandshakerLogger{
|
||||||
TLSHandshaker: &netxmocks.TLSHandshaker{
|
TLSHandshaker: &netxmocks.TLSHandshaker{
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
package netxmocks
|
package netxmocks
|
||||||
|
|
||||||
import "net"
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
)
|
||||||
|
|
||||||
// QUICListener is a mockable netxlite.QUICListener.
|
// QUICListener is a mockable netxlite.QUICListener.
|
||||||
type QUICListener struct {
|
type QUICListener struct {
|
||||||
|
@ -11,3 +17,15 @@ type QUICListener struct {
|
||||||
func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) {
|
func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) {
|
||||||
return ql.MockListen(addr)
|
return ql.MockListen(addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QUICContextDialer is a mockable netxlite.QUICContextDialer.
|
||||||
|
type QUICContextDialer struct {
|
||||||
|
MockDialContext func(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext calls MockDialContext.
|
||||||
|
func (qcd *QUICContextDialer) DialContext(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
return qcd.MockDialContext(ctx, network, address, tlsConfig, quicConfig)
|
||||||
|
}
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
package netxmocks
|
package netxmocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQUICListenerListen(t *testing.T) {
|
func TestQUICListenerListen(t *testing.T) {
|
||||||
|
@ -21,3 +25,22 @@ func TestQUICListenerListen(t *testing.T) {
|
||||||
t.Fatal("expected nil conn here")
|
t.Fatal("expected nil conn here")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQUICContextDialerDialContext(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
qcd := &QUICContextDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
quicConfig := &quic.Config{}
|
||||||
|
sess, err := qcd.DialContext(ctx, "udp", "dns.google:443", tlsConfig, quicConfig)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
25
internal/netxmocks/tlsconn.go
Normal file
25
internal/netxmocks/tlsconn.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package netxmocks
|
||||||
|
|
||||||
|
import "crypto/tls"
|
||||||
|
|
||||||
|
// TLSConn allows to mock netxlite.TLSConn.
|
||||||
|
type TLSConn struct {
|
||||||
|
// Conn is the embedded mockable Conn.
|
||||||
|
Conn
|
||||||
|
|
||||||
|
// MockConnectionState allows to mock the ConnectionState method.
|
||||||
|
MockConnectionState func() tls.ConnectionState
|
||||||
|
|
||||||
|
// MockHandshake allows to mock the Handshake method.
|
||||||
|
MockHandshake func() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionState calls MockConnectionState.
|
||||||
|
func (c *TLSConn) ConnectionState() tls.ConnectionState {
|
||||||
|
return c.MockConnectionState()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handshake calls MockHandshake.
|
||||||
|
func (c *TLSConn) Handshake() error {
|
||||||
|
return c.MockHandshake()
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user