feat(netxlite): implement single use {,tls} dialer (#464)

This basically adapts already existing code inside websteps to
instead be into the netxlite package, where it belongs.

In the process, abstract the TLSDialer but keep a reference to the
previous name to avoid refactoring existing code (just for now).

While there, notice that the right name is CloseIdleConnections (i.e.,
plural not singular) and change the name.

While there, since we abstracted TLSDialer to be an interface, create
suitable factories for making a TLSDialer type from a Dialer and a
TLSHandshaker.

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-06 14:12:30 +02:00 committed by GitHub
parent ef9592f75e
commit 2572376fdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 177 additions and 36 deletions

View File

@ -83,7 +83,7 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper {
func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper {
transport := newBaseTransport()
transport.DialContext = dialer.DialContext
transport.DialTLSContext = (&netxlite.TLSDialer{
transport.DialTLSContext = (&netxlite.TLSDialerLegacy{
Config: tlsConfig,
Dialer: netxlite.NewDialerLegacyAdapter(dialer),
TLSHandshaker: handshaker,

View File

@ -103,8 +103,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
// - SystemTLSHandshaker
//
// If you have others needs, manually build the chain you need.
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
return &netxlite.TLSDialer{
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialerLegacy {
return &netxlite.TLSDialerLegacy{
Config: config,
Dialer: netxlite.NewDialerLegacyAdapter(d),
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{

View File

@ -207,7 +207,7 @@ func NewTLSDialer(config Config) TLSDialer {
}
config.TLSConfig.RootCAs = config.CertPool
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
return &netxlite.TLSDialer{
return &netxlite.TLSDialerLegacy{
Config: config.TLSConfig,
Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer),
TLSHandshaker: h,

View File

@ -255,7 +255,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
func TestNewTLSDialerVanilla(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{})
rtd, ok := td.(*netxlite.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
@ -287,7 +287,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
TLSConfig: new(tls.Config),
})
rtd, ok := td.(*netxlite.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
@ -316,7 +316,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
Logger: log.Log,
})
rtd, ok := td.(*netxlite.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
@ -356,7 +356,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
TLSSaver: saver,
})
rtd, ok := td.(*netxlite.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
@ -396,7 +396,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
TLSConfig: new(tls.Config),
NoTLSVerify: true,
})
rtd, ok := td.(*netxlite.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
@ -428,7 +428,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{
NoTLSVerify: true,
})
rtd, ok := td.(*netxlite.TLSDialer)
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
@ -488,7 +488,7 @@ func TestNewWithDialer(t *testing.T) {
func TestNewWithTLSDialer(t *testing.T) {
expected := errors.New("mocked error")
tlsDialer := &netxlite.TLSDialer{
tlsDialer := &netxlite.TLSDialerLegacy{
Config: new(tls.Config),
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {

View File

@ -13,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) {
t.Skip("skip test in short mode")
}
log.SetLevel(log.DebugLevel)
dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer,
dialer := &netxlite.TLSDialerLegacy{Dialer: netxlite.DefaultDialer,
TLSHandshaker: &netxlite.TLSHandshakerLogger{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Logger: log.Log,

View File

@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{
tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: netxlite.NewDialerLegacyAdapter(
dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
@ -117,7 +117,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{
tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
@ -183,7 +183,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{
tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
@ -216,7 +216,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{
tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
@ -249,7 +249,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{
tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
@ -282,7 +282,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{
tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{

View File

@ -2,7 +2,9 @@ package netxlite
import (
"context"
"errors"
"net"
"sync"
"time"
)
@ -137,3 +139,38 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string)
func (d *dialerLogger) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// ErrNoConnReuse indicates we cannot reuse the connection provided
// to a single use (possibly TLS) dialer.
var ErrNoConnReuse = errors.New("cannot reuse connection")
// NewSingleUseDialer returns a dialer that returns the given connection once
// and after that always fails with the ErrNoConnReuse error.
func NewSingleUseDialer(conn net.Conn) Dialer {
return &dialerSingleUse{conn: conn}
}
// dialerSingleUse is the type of Dialer returned by NewSingleDialer.
type dialerSingleUse struct {
sync.Mutex
conn net.Conn
}
var _ Dialer = &dialerSingleUse{}
// DialContext implements Dialer.DialContext.
func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
defer s.Unlock()
s.Lock()
if s.conn == nil {
return nil, ErrNoConnReuse
}
var conn net.Conn
conn, s.conn = s.conn, nil
return conn, nil
}
// CloseIdleConnections closes idle connections.
func (s *dialerSingleUse) CloseIdleConnections() {
// nothing
}

View File

@ -235,3 +235,24 @@ func TestNewDialerWithoutResolverChain(t *testing.T) {
t.Fatal("invalid type")
}
}
func TestNewSingleUseDialerWorksAsIntended(t *testing.T) {
conn := &mocks.Conn{}
d := NewSingleUseDialer(conn)
outconn, err := d.DialContext(context.Background(), "", "")
if err != nil {
t.Fatal(err)
}
if conn != outconn {
t.Fatal("invalid outconn")
}
for i := 0; i < 4; i++ {
outconn, err = d.DialContext(context.Background(), "", "")
if !errors.Is(err, ErrNoConnReuse) {
t.Fatal("not the error we expected", err)
}
if outconn != nil {
t.Fatal("expected nil outconn here")
}
}
}

View File

@ -73,7 +73,7 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
txp := http.DefaultTransport.(*http.Transport).Clone()
dialer = &httpDialerWithReadTimeout{dialer}
txp.DialContext = dialer.DialContext
txp.DialTLSContext = (&TLSDialer{
txp.DialTLSContext = (&tlsDialer{
Config: tlsConfig,
Dialer: dialer,
TLSHandshaker: handshaker,

View File

@ -62,6 +62,7 @@ type (
TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger
DialerSystem = dialerSystem
TLSDialerLegacy = tlsDialer
)
// ResolverLegacy performs domain name resolutions.

View File

@ -216,8 +216,29 @@ func (h *tlsHandshakerLogger) Handshake(
return tlsconn, state, nil
}
// TLSDialer is the TLS dialer
type TLSDialer struct {
// TLSDialer is a Dialer dialing TLS connections.
type TLSDialer interface {
// CloseIdleConnections closes idle connections, if any.
CloseIdleConnections()
// DialTLSContext dials a TLS connection.
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
}
// NewTLSDialer creates a new TLS dialer using the given dialer
// and TLS handshaker to establish TLS connections.
func NewTLSDialer(dialer Dialer, handshaker TLSHandshaker) TLSDialer {
return NewTLSDialerWithConfig(dialer, handshaker, &tls.Config{})
}
// NewTLSDialerWithConfig is like NewTLSDialer but takes an optional config
// parameter containing your desired TLS configuration.
func NewTLSDialerWithConfig(d Dialer, h TLSHandshaker, c *tls.Config) TLSDialer {
return &tlsDialer{Config: c, Dialer: d, TLSHandshaker: h}
}
// tlsDialer is the TLS dialer
type tlsDialer struct {
// Config is the OPTIONAL tls config.
Config *tls.Config
@ -228,13 +249,15 @@ type TLSDialer struct {
TLSHandshaker TLSHandshaker
}
// CloseIdleConnection closes idle connections, if any.
func (d *TLSDialer) CloseIdleConnection() {
var _ TLSDialer = &tlsDialer{}
// CloseIdleConnections implements TLSDialer.CloseIdleConnections.
func (d *tlsDialer) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// DialTLSContext dials a TLS connection.
func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
// DialTLSContext implements TLSDialer.DialTLSContext.
func (d *tlsDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
@ -258,7 +281,7 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
// We set the ServerName field if not already set.
//
// We set the ALPN if the port is 443 or 853, if not already set.
func (d *TLSDialer) config(host, port string) *tls.Config {
func (d *tlsDialer) config(host, port string) *tls.Config {
config := d.Config
if config == nil {
config = &tls.Config{}
@ -277,3 +300,22 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
}
return config
}
// NewSingleUseTLSDialer is like NewSingleUseDialer but takes
// in input a TLSConn rather than a net.Conn.
func NewSingleUseTLSDialer(conn TLSConn) TLSDialer {
return &tlsDialerSingleUseAdapter{NewSingleUseDialer(conn)}
}
// tlsDialerSingleUseAdapter adapts dialerSingleUse to
// be a TLSDialer type rather than a Dialer type.
type tlsDialerSingleUseAdapter struct {
Dialer
}
var _ TLSDialer = &tlsDialerSingleUseAdapter{}
// DialTLSContext implements TLSDialer.DialTLSContext.
func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address)
}

View File

@ -280,21 +280,21 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) {
func TestTLSDialerCloseIdleConnections(t *testing.T) {
var called bool
dialer := &TLSDialer{
dialer := &tlsDialer{
Dialer: &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
},
}
dialer.CloseIdleConnection()
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}
func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
dialer := &TLSDialer{}
dialer := &tlsDialer{}
ctx := context.Background()
const address = "www.google.com" // missing port
conn, err := dialer.DialTLSContext(ctx, "tcp", address)
@ -309,7 +309,7 @@ func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // immediately fail
dialer := TLSDialer{Dialer: defaultDialer}
dialer := tlsDialer{Dialer: defaultDialer}
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
t.Fatal("not the error we expected", err)
@ -321,7 +321,7 @@ func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
ctx := context.Background()
dialer := TLSDialer{
dialer := tlsDialer{
Config: &tls.Config{},
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
@ -345,7 +345,7 @@ func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
ctx := context.Background()
dialer := TLSDialer{
dialer := tlsDialer{
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
@ -372,7 +372,7 @@ func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
}
func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
d := &TLSDialer{}
d := &tlsDialer{}
config := d.config("www.google.com", "443")
if config.ServerName != "www.google.com" {
t.Fatal("invalid server name")
@ -383,7 +383,7 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
}
func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
d := &TLSDialer{}
d := &tlsDialer{}
config := d.config("dns.google", "853")
if config.ServerName != "dns.google" {
t.Fatal("invalid server name")
@ -394,7 +394,7 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
}
func TestTLSDialerConfigWithServerName(t *testing.T) {
d := &TLSDialer{
d := &tlsDialer{
Config: &tls.Config{
ServerName: "example.com",
},
@ -409,7 +409,7 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
}
func TestTLSDialerConfigWithALPN(t *testing.T) {
d := &TLSDialer{
d := &tlsDialer{
Config: &tls.Config{
NextProtos: []string{"h2"},
},
@ -440,3 +440,43 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
t.Fatal("expected nil NewConn")
}
}
func TestNewTLSDialerWorksAsIntended(t *testing.T) {
d := &mocks.Dialer{}
tlsh := &mocks.TLSHandshaker{}
td := NewTLSDialer(d, tlsh)
tdut, okay := td.(*tlsDialer)
if !okay {
t.Fatal("invalid type")
}
if tdut.Config == nil {
t.Fatal("unexpected config")
}
if tdut.Dialer != d {
t.Fatal("unexpected dialer")
}
if tdut.TLSHandshaker != tlsh {
t.Fatal("invalid handshaker")
}
}
func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
conn := &mocks.TLSConn{}
d := NewSingleUseTLSDialer(conn)
outconn, err := d.DialTLSContext(context.Background(), "", "")
if err != nil {
t.Fatal(err)
}
if conn != outconn {
t.Fatal("invalid outconn")
}
for i := 0; i < 4; i++ {
outconn, err = d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, ErrNoConnReuse) {
t.Fatal("not the error we expected", err)
}
if outconn != nil {
t.Fatal("expected nil outconn here")
}
}
}