ooni-probe-cli/internal/netxlite/tlshandshaker.go
Simone Basso b07890af4d
fix(netxlite): improve TLS auto-configuration (#409)
Auto-configure every relevant TLS field as close as possible to
where it's actually used.

As a side effect, add support for mocking the creation of a TLS
connection, which should possibly be useful for uTLS?

Work that is part of https://github.com/ooni/probe/issues/1505
2021-06-25 20:51:59 +02:00

109 lines
3.4 KiB
Go

package netxlite
import (
"context"
"crypto/tls"
"net"
"time"
)
// TLSHandshaker is the generic interface implemented by every type
// in this package that knows how to perform a TLS handshake.
type TLSHandshaker interface {
	// Handshake runs a TLS handshake over conn using config and returns
	// the resulting TLS connection along with its connection state.
	//
	// Ownership of conn is NOT transferred to this function: when the
	// handshake fails, closing conn remains the caller's responsibility.
	Handshake(
		ctx context.Context,
		conn net.Conn,
		config *tls.Config,
	) (net.Conn, tls.ConnectionState, error)
}
// TLSHandshakerConfigurable is a configurable TLSHandshaker that, by
// default, uses the standard library's TLS implementation. The zero
// value is ready to use: every field is optional.
type TLSHandshakerConfigurable struct {
// NewConn is the OPTIONAL factory for creating a new TLS connection on
// top of conn using config. If this factory is nil, we use tls.Client.
NewConn func(conn net.Conn, config *tls.Config) TLSConn
// Timeout is the OPTIONAL timeout imposed on the TLS handshake. If zero
// or negative, we use a default timeout of 10 seconds.
Timeout time.Duration
}
// Compile-time check: *TLSHandshakerConfigurable implements TLSHandshaker.
var _ TLSHandshaker = (*TLSHandshakerConfigurable)(nil)
// defaultCertPool is the cert pool that Handshake installs as RootCAs
// when the caller's config has a nil RootCAs field. We keep this value
// in a private package-level variable so that unit tests can replace it.
var defaultCertPool = NewDefaultCertPool()
// Handshake implements TLSHandshaker.Handshake.
//
// When config.RootCAs is nil, the handshake runs on a clone of config
// whose RootCAs is defaultCertPool (the built-in Mozilla CA pool), so
// the config passed by the caller is never modified.
//
// The handshake is bounded by a deadline set on conn. The deadline is
// h.Timeout from now (ten seconds when h.Timeout is zero or negative),
// or the deadline of ctx when that comes sooner. The deadline is cleared
// again before returning. If ctx is already done when this function is
// called, we return ctx.Err() right away without touching conn.
//
// On failure we return a nil connection, a zero connection state and
// the error. We DO NOT close conn, which remains owned by the caller.
//
// Bug
//
// Until Go 1.17 is released, cancelling ctx while the handshake is in
// progress does not interrupt it. We only honour a context that is
// already done on entry and the context's deadline, and we always
// enforce an overall timeout.
func (h *TLSHandshakerConfigurable) Handshake(
	ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
	// Earlier versions of this method never consulted ctx, so a nil ctx
	// used to work; keep tolerating it rather than panicking.
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return nil, tls.ConnectionState{}, err
		}
	}
	timeout := h.Timeout
	if timeout <= 0 {
		timeout = 10 * time.Second
	}
	deadline := time.Now().Add(timeout)
	if ctx != nil {
		// Never wait longer than the caller's context allows.
		if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
			deadline = ctxDeadline
		}
	}
	// Best effort: errors from SetDeadline are deliberately ignored, as a
	// conn that cannot take a deadline will surface its problem during
	// the handshake itself. The deferred call removes the deadline so it
	// does not affect I/O performed by the caller after we return.
	defer conn.SetDeadline(time.Time{})
	conn.SetDeadline(deadline)
	if config.RootCAs == nil {
		// Clone before mutating so the caller's config stays untouched.
		config = config.Clone()
		config.RootCAs = defaultCertPool
	}
	tlsconn := h.newConn(conn, config)
	if err := tlsconn.Handshake(); err != nil {
		return nil, tls.ConnectionState{}, err
	}
	return tlsconn, tlsconn.ConnectionState(), nil
}
// newConn wraps conn into a TLSConn configured with config. It calls the
// h.NewConn factory when one is set and falls back to the standard
// library's tls.Client otherwise.
func (h *TLSHandshakerConfigurable) newConn(conn net.Conn, config *tls.Config) TLSConn {
	if factory := h.NewConn; factory != nil {
		return factory(conn, config)
	}
	return tls.Client(conn, config)
}
// DefaultTLSHandshaker is the default TLS handshaker: a zero-valued
// TLSHandshakerConfigurable, with no custom factory or timeout set.
var DefaultTLSHandshaker = new(TLSHandshakerConfigurable)
// TLSHandshakerLogger is a TLSHandshaker that wraps another TLSHandshaker
// and emits debug log messages before and after each handshake. Its
// Handshake method uses both fields without checking them for nil.
type TLSHandshakerLogger struct {
// TLSHandshaker is the underlying handshaker performing the handshake.
TLSHandshaker TLSHandshaker
// Logger is the logger receiving the debug messages.
Logger Logger
}
// Handshake implements TLSHandshaker.Handshake. It delegates the actual
// handshake to h.TLSHandshaker and emits debug messages through h.Logger:
// one before the handshake starts, and one afterwards that reports either
// the error or the negotiated protocol, cipher suite and TLS version,
// together with the elapsed time.
func (h *TLSHandshakerLogger) Handshake(
	ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
	h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos)
	began := time.Now()
	secured, connState, failure := h.TLSHandshaker.Handshake(ctx, conn, config)
	took := time.Since(began)
	if failure != nil {
		h.Logger.Debugf(
			"tls {sni=%s next=%+v}... %s in %s",
			config.ServerName, config.NextProtos, failure, took,
		)
		// Whatever the underlying handshaker returned alongside the
		// error is discarded in favour of zero values.
		return nil, tls.ConnectionState{}, failure
	}
	h.Logger.Debugf(
		"tls {sni=%s next=%+v}... ok in %s {next=%s cipher=%s v=%s}",
		config.ServerName,
		config.NextProtos,
		took,
		connState.NegotiatedProtocol,
		TLSCipherSuiteString(connState.CipherSuite),
		TLSVersionString(connState.Version),
	)
	return secured, connState, nil
}
// Compile-time check: *TLSHandshakerLogger implements TLSHandshaker.
var _ TLSHandshaker = (*TLSHandshakerLogger)(nil)