diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 63ac676..dffa231 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -108,7 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer { Dialer: d, TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, }, }, } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index eb25850..ba2fd77 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -184,7 +184,7 @@ func NewTLSDialer(config Config) TLSDialer { if config.Dialer == nil { config.Dialer = NewDialer(config) } - var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{} + var h tlsHandshaker = &netxlite.TLSHandshakerConfigurable{} h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h} if config.Logger != nil { h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h} diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 569d1f5..1f2d318 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -234,7 +234,7 @@ func TestNewTLSDialerVanilla(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -263,7 +263,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -302,7 +302,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -342,7 +342,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -375,7 +375,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -410,7 +410,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok { + if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -447,7 +447,7 @@ func TestNewWithTLSDialer(t *testing.T) { tlsDialer := &netxlite.TLSDialer{ Config: new(tls.Config), Dialer: netx.FakeDialer{Err: expected}, - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, } txp := netx.NewHTTPTransport(netx.Config{ TLSDialer: tlsDialer, diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index 24bd2d7..d9795c3 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -16,7 +16,7 @@ func TestTLSDialerSuccess(t *testing.T) { log.SetLevel(log.DebugLevel) dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer), TLSHandshaker: &netxlite.TLSHandshakerLogger{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Logger: log.Log, }, } diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 9420fe8..5aeb820 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -26,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { Config: &tls.Config{NextProtos: nextprotos}, Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, }, } @@ -119,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { Config: &tls.Config{NextProtos: nextprotos}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, }, } @@ -184,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { tlsdlr := &netxlite.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, }, } @@ -217,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { tlsdlr := &netxlite.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, }, } @@ -250,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { tlsdlr := &netxlite.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, }, } @@ -284,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { Config: &tls.Config{InsecureSkipVerify: true}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, + TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, Saver: saver, }, } diff --git a/internal/engine/netx/tlsdialer/tls_test.go b/internal/engine/netx/tlsdialer/tls_test.go index bdbabab..64161bc 100644 --- a/internal/engine/netx/tlsdialer/tls_test.go +++ b/internal/engine/netx/tlsdialer/tls_test.go @@ -16,7 +16,7 @@ import ( ) func TestSystemTLSHandshakerEOFError(t *testing.T) { - h := &netxlite.TLSHandshakerStdlib{} + h := &netxlite.TLSHandshakerConfigurable{} conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{ ServerName: "x.org", }) diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index f1cea3f..3971842 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -10,7 +10,7 @@ import ( "github.com/lucas-clemente/quic-go" ) -// QUICDialerContext is a dialer for QUIC using Context. +// QUICContextDialer is a dialer for QUIC using Context. type QUICContextDialer interface { // DialContext establishes a new QUIC session using the given // network and address. The tlsConfig and the quicConfig arguments @@ -39,6 +39,11 @@ func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error) type QUICDialerQUICGo struct { // QUICListener is the underlying QUICListener to use. QUICListener QUICListener + + // mockDialEarlyContext allows to mock quic.DialEarlyContext. + mockDialEarlyContext func(ctx context.Context, pconn net.PacketConn, + remoteAddr net.Addr, host string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) } var _ QUICContextDialer = &QUICDialerQUICGo{} @@ -46,7 +51,14 @@ var _ QUICContextDialer = &QUICDialerQUICGo{} // errInvalidIP indicates that a string is not a valid IP. var errInvalidIP = errors.New("netxlite: invalid IP") -// DialContext implements ContextDialer.DialContext +// DialContext implements ContextDialer.DialContext. This function will +// apply the following TLS defaults: +// +// 1. if tlsConfig.RootCAs is nil, we use the Mozilla CA that we +// bundle with this measurement library; +// +// 2. if tlsConfig.NextProtos is empty _and_ the port is 443 or 8853, +// then we configure, respectively, "h3" and "dq". func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) ( quic.EarlySession, error) { @@ -67,7 +79,8 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string, return nil, err } udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""} - sess, err := quic.DialEarlyContext( + tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, port) + sess, err := d.dialEarlyContext( ctx, pconn, udpAddr, address, tlsConfig, quicConfig) if err != nil { return nil, err @@ -75,6 +88,36 @@ func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string, return &quicSessionOwnsConn{EarlySession: sess, conn: pconn}, nil } +func (d *QUICDialerQUICGo) dialEarlyContext(ctx context.Context, + pconn net.PacketConn, remoteAddr net.Addr, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + if d.mockDialEarlyContext != nil { + return d.mockDialEarlyContext( + ctx, pconn, remoteAddr, address, tlsConfig, quicConfig) + } + return quic.DialEarlyContext( + ctx, pconn, remoteAddr, address, tlsConfig, quicConfig) +} + +// maybeApplyTLSDefaults ensures that we're using our certificate pool, if +// needed, and that we use a suitable ALPN, if needed, for h3 and dq. +func (d *QUICDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config { + config = config.Clone() + if config.RootCAs == nil { + config.RootCAs = defaultCertPool + } + if len(config.NextProtos) <= 0 { + switch port { + case 443: + config.NextProtos = []string{"h3"} + case 8853: + // See https://datatracker.ietf.org/doc/html/draft-ietf-dprive-dnsoquic-02#section-10 + config.NextProtos = []string{"dq"} + } + } + return config +} + // quicSessionOwnsConn ensures that we close the PacketConn. type quicSessionOwnsConn struct { // EarlySession is the embedded early session @@ -102,7 +145,11 @@ type QUICDialerResolver struct { Resolver Resolver } -// DialContext implements QUICContextDialer.DialContext +// DialContext implements QUICContextDialer.DialContext. This function +// will apply the following TLS defaults: +// +// 1. if tlsConfig.ServerName is empty, we will use the hostname +// contained inside of the `address` endpoint. func (d *QUICDialerResolver) DialContext( ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { @@ -110,15 +157,11 @@ func (d *QUICDialerResolver) DialContext( if err != nil { return nil, err } - // TODO(kelmenhorst): Should this be somewhere else? - // failure if tlsCfg is nil but that should not happen - if tlsConfig.ServerName == "" { - tlsConfig.ServerName = onlyhost - } addrs, err := d.lookupHost(ctx, onlyhost) if err != nil { return nil, err } + tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost) // TODO(bassosimone): here we should be using multierror rather // than just calling ReduceErrors. We are not ready to do that // yet, though. To do that, we need first to modify nettests so @@ -136,6 +179,15 @@ func (d *QUICDialerResolver) DialContext( return nil, reduceErrors(errorslist) } +// maybeApplyTLSDefaults sets the SNI if it's not already configured. +func (d *QUICDialerResolver) maybeApplyTLSDefaults(config *tls.Config, host string) *tls.Config { + config = config.Clone() + if config.ServerName == "" { + config.ServerName = host + } + return config +} + // lookupHost performs a domain name resolution. func (d *QUICDialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) { if net.ParseIP(hostname) != nil { diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index c1553bb..fde7ce5 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -9,13 +9,13 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/netxmocks" ) func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) { tlsConfig := &tls.Config{ - NextProtos: []string{"h3"}, ServerName: "www.google.com", } systemdialer := QUICDialerQUICGo{ @@ -34,7 +34,6 @@ func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) { func TestQUICDialerQUICGoInvalidPort(t *testing.T) { tlsConfig := &tls.Config{ - NextProtos: []string{"h3"}, ServerName: "www.google.com", } systemdialer := QUICDialerQUICGo{ @@ -53,7 +52,6 @@ func TestQUICDialerQUICGoInvalidPort(t *testing.T) { func TestQUICDialerQUICGoInvalidIP(t *testing.T) { tlsConfig := &tls.Config{ - NextProtos: []string{"h3"}, ServerName: "www.google.com", } systemdialer := QUICDialerQUICGo{ @@ -73,7 +71,6 @@ func TestQUICDialerQUICGoInvalidIP(t *testing.T) { func TestQUICDialerQUICGoCannotListen(t *testing.T) { expected := errors.New("mocked error") tlsConfig := &tls.Config{ - NextProtos: []string{"h3"}, ServerName: "www.google.com", } systemdialer := QUICDialerQUICGo{ @@ -94,9 +91,8 @@ func TestQUICDialerQUICGoCannotListen(t *testing.T) { } } -func TestQUICDialerCannotPerformHandshake(t *testing.T) { +func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) { tlsConfig := &tls.Config{ - NextProtos: []string{"h3"}, ServerName: "dns.google", } systemdialer := QUICDialerQUICGo{ @@ -114,9 +110,8 @@ func TestQUICDialerCannotPerformHandshake(t *testing.T) { } } -func TestQUICDialerWorksAsIntended(t *testing.T) { +func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) { tlsConfig := &tls.Config{ - NextProtos: []string{"h3"}, ServerName: "dns.google", } systemdialer := QUICDialerQUICGo{ @@ -134,8 +129,90 @@ func TestQUICDialerWorksAsIntended(t *testing.T) { } } +func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &QUICListenerStdlib{}, + mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, + remoteAddr net.Addr, host string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) { + gotTLSConfig = tlsConfig + return nil, expected + }, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session here") + } + if tlsConfig.RootCAs != nil { + t.Fatal("tlsConfig.RootCAs should not have been changed") + } + if gotTLSConfig.RootCAs != defaultCertPool { + t.Fatal("invalid gotTLSConfig.RootCAs") + } + if tlsConfig.NextProtos != nil { + t.Fatal("tlsConfig.NextProtos should not have been changed") + } + if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" { + t.Fatal("invalid gotTLSConfig.NextProtos", diff) + } + if tlsConfig.ServerName != gotTLSConfig.ServerName { + t.Fatal("the ServerName field must match") + } +} + +func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + tlsConfig := &tls.Config{ + ServerName: "dns.google", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &QUICListenerStdlib{}, + mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, + remoteAddr net.Addr, host string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) { + gotTLSConfig = tlsConfig + return nil, expected + }, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session here") + } + if tlsConfig.RootCAs != nil { + t.Fatal("tlsConfig.RootCAs should not have been changed") + } + if gotTLSConfig.RootCAs != defaultCertPool { + t.Fatal("invalid gotTLSConfig.RootCAs") + } + if tlsConfig.NextProtos != nil { + t.Fatal("tlsConfig.NextProtos should not have been changed") + } + if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" { + t.Fatal("invalid gotTLSConfig.NextProtos", diff) + } + if tlsConfig.ServerName != gotTLSConfig.ServerName { + t.Fatal("the ServerName field must match") + } +} + func TestQUICDialerResolverSuccess(t *testing.T) { - tlsConfig := &tls.Config{NextProtos: []string{"h3"}} + tlsConfig := &tls.Config{} dialer := &QUICDialerResolver{ Resolver: &net.Resolver{}, Dialer: &QUICDialerQUICGo{ QUICListener: &QUICListenerStdlib{}, @@ -153,7 +230,7 @@ func TestQUICDialerResolverSuccess(t *testing.T) { } func TestQUICDialerResolverNoPort(t *testing.T) { - tlsConfig := &tls.Config{NextProtos: []string{"h3"}} + tlsConfig := &tls.Config{} dialer := &QUICDialerResolver{ Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{}} sess, err := dialer.DialContext( @@ -185,7 +262,7 @@ func TestQUICDialerResolverLookupHostAddress(t *testing.T) { } func TestQUICDialerResolverLookupHostFailure(t *testing.T) { - tlsConfig := &tls.Config{NextProtos: []string{"h3"}} + tlsConfig := &tls.Config{} expected := errors.New("mocked error") dialer := &QUICDialerResolver{Resolver: &netxmocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -206,7 +283,7 @@ func TestQUICDialerResolverLookupHostFailure(t *testing.T) { func TestQUICDialerResolverInvalidPort(t *testing.T) { // This test allows us to check for the case where every attempt // to establish a connection leads to a failure - tlsConf := &tls.Config{NextProtos: []string{"h3"}} + tlsConf := &tls.Config{} dialer := &QUICDialerResolver{ Resolver: new(net.Resolver), Dialer: &QUICDialerQUICGo{ QUICListener: &QUICListenerStdlib{}, @@ -225,3 +302,32 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) { t.Fatal("expected nil sess") } } + +func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + tlsConfig := &tls.Config{} + dialer := &QUICDialerResolver{ + Resolver: new(net.Resolver), Dialer: &netxmocks.QUICContextDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + gotTLSConfig = tlsConfig + return nil, expected + }, + }} + sess, err := dialer.DialContext( + context.Background(), "udp", "www.google.com:443", + tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil session here") + } + if tlsConfig.ServerName != "" { + t.Fatal("should not have changed tlsConfig.ServerName") + } + if gotTLSConfig.ServerName != "www.google.com" { + t.Fatal("gotTLSConfig.ServerName has not been set") + } +} diff --git a/internal/netxlite/tlsconn.go b/internal/netxlite/tlsconn.go new file mode 100644 index 0000000..0dbb5fc --- /dev/null +++ b/internal/netxlite/tlsconn.go @@ -0,0 +1,18 @@ +package netxlite + +import ( + "crypto/tls" + "net" +) + +// TLSConn is any tls.Conn-like structure. +type TLSConn interface { + // net.Conn is the embedded conn. + net.Conn + + // ConnectionState returns the TLS connection state. + ConnectionState() tls.ConnectionState + + // Handshake performs the handshake. + Handshake() error +} diff --git a/internal/netxlite/tlsdialer.go b/internal/netxlite/tlsdialer.go index 9a959d0..cc09e3a 100644 --- a/internal/netxlite/tlsdialer.go +++ b/internal/netxlite/tlsdialer.go @@ -43,8 +43,6 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) // We set the ServerName field if not already set. // // We set the ALPN if the port is 443 or 853, if not already set. -// -// We force using our root CA, unless it's already set. func (d *TLSDialer) config(host, port string) *tls.Config { config := d.Config if config == nil { @@ -62,8 +60,5 @@ func (d *TLSDialer) config(host, port string) *tls.Config { config.NextProtos = []string{"dot"} } } - if config.RootCAs == nil { - config.RootCAs = NewDefaultCertPool() - } return config } diff --git a/internal/netxlite/tlsdialer_test.go b/internal/netxlite/tlsdialer_test.go index 6e6c6ed..5195585 100644 --- a/internal/netxlite/tlsdialer_test.go +++ b/internal/netxlite/tlsdialer_test.go @@ -3,7 +3,6 @@ package netxlite import ( "context" "crypto/tls" - "crypto/x509" "errors" "io" "net" @@ -54,7 +53,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) { return nil }}, nil }}, - TLSHandshaker: &TLSHandshakerStdlib{}, + TLSHandshaker: &TLSHandshakerConfigurable{}, } conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") if !errors.Is(err, io.EOF) { @@ -99,9 +98,6 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { if config.ServerName != "www.google.com" { t.Fatal("invalid server name") } - if config.RootCAs == nil { - t.Fatal("expected non-nil root CAs") - } if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" { t.Fatal(diff) } @@ -113,9 +109,6 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { if config.ServerName != "dns.google" { t.Fatal("invalid server name") } - if config.RootCAs == nil { - t.Fatal("expected non-nil root CAs") - } if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { t.Fatal(diff) } @@ -131,9 +124,6 @@ func TestTLSDialerConfigWithServerName(t *testing.T) { if config.ServerName != "example.com" { t.Fatal("invalid server name") } - if config.RootCAs == nil { - t.Fatal("expected non-nil root CAs") - } if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { t.Fatal(diff) } @@ -149,29 +139,7 @@ func TestTLSDialerConfigWithALPN(t *testing.T) { if config.ServerName != "dns.google" { t.Fatal("invalid server name") } - if config.RootCAs == nil { - t.Fatal("expected non-nil root CAs") - } if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" { t.Fatal(diff) } } - -func TestTLSDialerConfigWithRootCA(t *testing.T) { - pool := &x509.CertPool{} - d := &TLSDialer{ - Config: &tls.Config{ - RootCAs: pool, - }, - } - config := d.config("dns.google", "853") - if config.ServerName != "dns.google" { - t.Fatal("invalid server name") - } - if config.RootCAs != pool { - t.Fatal("not the RootCAs we expected") - } - if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { - t.Fatal(diff) - } -} diff --git a/internal/netxlite/tlshandshaker.go b/internal/netxlite/tlshandshaker.go index 7e10233..64fe5df 100644 --- a/internal/netxlite/tlshandshaker.go +++ b/internal/netxlite/tlshandshaker.go @@ -16,17 +16,33 @@ type TLSHandshaker interface { net.Conn, tls.ConnectionState, error) } -// TLSHandshakerStdlib is the stdlib's TLS handshaker. -type TLSHandshakerStdlib struct { - // Timeout is the timeout imposed on the TLS handshake. If zero +// TLSHandshakerConfigurable is a configurable TLS handshaker that +// uses by default the standard library's TLS implementation. +type TLSHandshakerConfigurable struct { + // NewConn is the OPTIONAL factory for creating a new connection. If + // this factory is not set, we'll use the stdlib. + NewConn func(conn net.Conn, config *tls.Config) TLSConn + + // Timeout is the OPTIONAL timeout imposed on the TLS handshake. If zero // or negative, we will use default timeout of 10 seconds. Timeout time.Duration } -var _ TLSHandshaker = &TLSHandshakerStdlib{} +var _ TLSHandshaker = &TLSHandshakerConfigurable{} -// Handshake implements Handshaker.Handshake -func (h *TLSHandshakerStdlib) Handshake( +// defaultCertPool is the cert pool we use by default. We store this +// value into a private variable to enable for unit testing. +var defaultCertPool = NewDefaultCertPool() + +// Handshake implements Handshaker.Handshake. This function will +// configure the code to use the built-in Mozilla CA if the config +// field contains a nil RootCAs field. +// +// Bug +// +// Until Go 1.17 is released, this function will not honour +// the context. We'll however always enforce an overall timeout. +func (h *TLSHandshakerConfigurable) Handshake( ctx context.Context, conn net.Conn, config *tls.Config, ) (net.Conn, tls.ConnectionState, error) { timeout := h.Timeout @@ -35,15 +51,27 @@ func (h *TLSHandshakerStdlib) Handshake( } defer conn.SetDeadline(time.Time{}) conn.SetDeadline(time.Now().Add(timeout)) - tlsconn := tls.Client(conn, config) + if config.RootCAs == nil { + config = config.Clone() + config.RootCAs = defaultCertPool + } + tlsconn := h.newConn(conn, config) if err := tlsconn.Handshake(); err != nil { return nil, tls.ConnectionState{}, err } return tlsconn, tlsconn.ConnectionState(), nil } +// newConn creates a new TLSConn. +func (h *TLSHandshakerConfigurable) newConn(conn net.Conn, config *tls.Config) TLSConn { + if h.NewConn != nil { + return h.NewConn(conn, config) + } + return tls.Client(conn, config) +} + // DefaultTLSHandshaker is the default TLS handshaker. -var DefaultTLSHandshaker = &TLSHandshakerStdlib{} +var DefaultTLSHandshaker = &TLSHandshakerConfigurable{} // TLSHandshakerLogger is a TLSHandshaker with logging. type TLSHandshakerLogger struct { diff --git a/internal/netxlite/tlshandshaker_test.go b/internal/netxlite/tlshandshaker_test.go index 24f1cd2..817a63b 100644 --- a/internal/netxlite/tlshandshaker_test.go +++ b/internal/netxlite/tlshandshaker_test.go @@ -17,9 +17,9 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxmocks" ) -func TestTLSHandshakerStdlibWithError(t *testing.T) { +func TestTLSHandshakerConfigurableWithError(t *testing.T) { var times []time.Time - h := &TLSHandshakerStdlib{} + h := &TLSHandshakerConfigurable{} tcpConn := &netxmocks.Conn{ MockWrite: func(b []byte) (int, error) { return 0, io.EOF @@ -50,7 +50,7 @@ func TestTLSHandshakerStdlibWithError(t *testing.T) { } } -func TestTLSHandshakerStdlibSuccess(t *testing.T) { +func TestTLSHandshakerConfigurableSuccess(t *testing.T) { handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(200) }) @@ -65,7 +65,7 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) { t.Fatal(err) } defer conn.Close() - handshaker := &TLSHandshakerStdlib{} + handshaker := &TLSHandshakerConfigurable{} ctx := context.Background() config := &tls.Config{ InsecureSkipVerify: true, @@ -83,6 +83,44 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) { } } +func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) { + expected := errors.New("mocked error") + var gotTLSConfig *tls.Config + handshaker := &TLSHandshakerConfigurable{ + NewConn: func(conn net.Conn, config *tls.Config) TLSConn { + gotTLSConfig = config + return &netxmocks.TLSConn{ + MockHandshake: func() error { + return expected + }, + } + }, + } + ctx := context.Background() + config := &tls.Config{} + conn := &netxmocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + } + tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero connState here") + } + if tlsConn != nil { + t.Fatal("expected nil tlsConn here") + } + if config.RootCAs != nil { + t.Fatal("config.RootCAs should still be nil") + } + if gotTLSConfig.RootCAs != defaultCertPool { + t.Fatal("gotTLSConfig.RootCAs has not been correctly set") + } +} + func TestTLSHandshakerLoggerSuccess(t *testing.T) { th := &TLSHandshakerLogger{ TLSHandshaker: &netxmocks.TLSHandshaker{ diff --git a/internal/netxmocks/quic.go b/internal/netxmocks/quic.go index a0238ec..765e2ef 100644 --- a/internal/netxmocks/quic.go +++ b/internal/netxmocks/quic.go @@ -1,6 +1,12 @@ package netxmocks -import "net" +import ( + "context" + "crypto/tls" + "net" + + "github.com/lucas-clemente/quic-go" +) // QUICListener is a mockable netxlite.QUICListener. type QUICListener struct { @@ -11,3 +17,15 @@ type QUICListener struct { func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) { return ql.MockListen(addr) } + +// QUICContextDialer is a mockable netxlite.QUICContextDialer. +type QUICContextDialer struct { + MockDialContext func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) +} + +// DialContext calls MockDialContext. +func (qcd *QUICContextDialer) DialContext(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + return qcd.MockDialContext(ctx, network, address, tlsConfig, quicConfig) +} diff --git a/internal/netxmocks/quic_test.go b/internal/netxmocks/quic_test.go index fbddfb7..1afc610 100644 --- a/internal/netxmocks/quic_test.go +++ b/internal/netxmocks/quic_test.go @@ -1,9 +1,13 @@ package netxmocks import ( + "context" + "crypto/tls" "errors" "net" "testing" + + "github.com/lucas-clemente/quic-go" ) func TestQUICListenerListen(t *testing.T) { @@ -21,3 +25,22 @@ func TestQUICListenerListen(t *testing.T) { t.Fatal("expected nil conn here") } } + +func TestQUICContextDialerDialContext(t *testing.T) { + expected := errors.New("mocked error") + qcd := &QUICContextDialer{ + MockDialContext: func(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { + return nil, expected + }, + } + ctx := context.Background() + tlsConfig := &tls.Config{} + quicConfig := &quic.Config{} + sess, err := qcd.DialContext(ctx, "udp", "dns.google:443", tlsConfig, quicConfig) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if sess != nil { + t.Fatal("expected nil session") + } +} diff --git a/internal/netxmocks/tlsconn.go b/internal/netxmocks/tlsconn.go new file mode 100644 index 0000000..3a99af7 --- /dev/null +++ b/internal/netxmocks/tlsconn.go @@ -0,0 +1,25 @@ +package netxmocks + +import "crypto/tls" + +// TLSConn allows to mock netxlite.TLSConn. +type TLSConn struct { + // Conn is the embedded mockable Conn. + Conn + + // MockConnectionState allows to mock the ConnectionState method. + MockConnectionState func() tls.ConnectionState + + // MockHandshake allows to mock the Handshake method. + MockHandshake func() error +} + +// ConnectionState calls MockConnectionState. +func (c *TLSConn) ConnectionState() tls.ConnectionState { + return c.MockConnectionState() +} + +// Handshake calls MockHandshake. +func (c *TLSConn) Handshake() error { + return c.MockHandshake() +}