feat(netxlite): implement single use {,tls} dialer (#464)
This basically adapts already existing code inside websteps to instead be into the netxlite package, where it belongs. In the process, abstract the TLSDialer but keep a reference to the previous name to avoid refactoring existing code (just for now). While there, notice that the right name is CloseIdleConnections (i.e., plural not singular) and change the name. While there, since we abstracted TLSDialer to be an interface, create suitable factories for making a TLSDialer type from a Dialer and a TLSHandshaker. See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
ef9592f75e
commit
2572376fdb
|
@ -83,7 +83,7 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper {
|
||||||
func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper {
|
func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper {
|
||||||
transport := newBaseTransport()
|
transport := newBaseTransport()
|
||||||
transport.DialContext = dialer.DialContext
|
transport.DialContext = dialer.DialContext
|
||||||
transport.DialTLSContext = (&netxlite.TLSDialer{
|
transport.DialTLSContext = (&netxlite.TLSDialerLegacy{
|
||||||
Config: tlsConfig,
|
Config: tlsConfig,
|
||||||
Dialer: netxlite.NewDialerLegacyAdapter(dialer),
|
Dialer: netxlite.NewDialerLegacyAdapter(dialer),
|
||||||
TLSHandshaker: handshaker,
|
TLSHandshaker: handshaker,
|
||||||
|
|
|
@ -103,8 +103,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
|
||||||
// - SystemTLSHandshaker
|
// - SystemTLSHandshaker
|
||||||
//
|
//
|
||||||
// If you have others needs, manually build the chain you need.
|
// If you have others needs, manually build the chain you need.
|
||||||
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer {
|
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialerLegacy {
|
||||||
return &netxlite.TLSDialer{
|
return &netxlite.TLSDialerLegacy{
|
||||||
Config: config,
|
Config: config,
|
||||||
Dialer: netxlite.NewDialerLegacyAdapter(d),
|
Dialer: netxlite.NewDialerLegacyAdapter(d),
|
||||||
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
|
||||||
|
|
|
@ -207,7 +207,7 @@ func NewTLSDialer(config Config) TLSDialer {
|
||||||
}
|
}
|
||||||
config.TLSConfig.RootCAs = config.CertPool
|
config.TLSConfig.RootCAs = config.CertPool
|
||||||
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
||||||
return &netxlite.TLSDialer{
|
return &netxlite.TLSDialerLegacy{
|
||||||
Config: config.TLSConfig,
|
Config: config.TLSConfig,
|
||||||
Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer),
|
Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer),
|
||||||
TLSHandshaker: h,
|
TLSHandshaker: h,
|
||||||
|
|
|
@ -255,7 +255,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
|
||||||
|
|
||||||
func TestNewTLSDialerVanilla(t *testing.T) {
|
func TestNewTLSDialerVanilla(t *testing.T) {
|
||||||
td := netx.NewTLSDialer(netx.Config{})
|
td := netx.NewTLSDialer(netx.Config{})
|
||||||
rtd, ok := td.(*netxlite.TLSDialer)
|
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("not the TLSDialer we expected")
|
||||||
}
|
}
|
||||||
|
@ -287,7 +287,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
|
||||||
td := netx.NewTLSDialer(netx.Config{
|
td := netx.NewTLSDialer(netx.Config{
|
||||||
TLSConfig: new(tls.Config),
|
TLSConfig: new(tls.Config),
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialer)
|
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("not the TLSDialer we expected")
|
||||||
}
|
}
|
||||||
|
@ -316,7 +316,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
|
||||||
td := netx.NewTLSDialer(netx.Config{
|
td := netx.NewTLSDialer(netx.Config{
|
||||||
Logger: log.Log,
|
Logger: log.Log,
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialer)
|
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("not the TLSDialer we expected")
|
||||||
}
|
}
|
||||||
|
@ -356,7 +356,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
|
||||||
td := netx.NewTLSDialer(netx.Config{
|
td := netx.NewTLSDialer(netx.Config{
|
||||||
TLSSaver: saver,
|
TLSSaver: saver,
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialer)
|
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("not the TLSDialer we expected")
|
||||||
}
|
}
|
||||||
|
@ -396,7 +396,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
|
||||||
TLSConfig: new(tls.Config),
|
TLSConfig: new(tls.Config),
|
||||||
NoTLSVerify: true,
|
NoTLSVerify: true,
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialer)
|
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("not the TLSDialer we expected")
|
||||||
}
|
}
|
||||||
|
@ -428,7 +428,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
|
||||||
td := netx.NewTLSDialer(netx.Config{
|
td := netx.NewTLSDialer(netx.Config{
|
||||||
NoTLSVerify: true,
|
NoTLSVerify: true,
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialer)
|
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("not the TLSDialer we expected")
|
||||||
}
|
}
|
||||||
|
@ -488,7 +488,7 @@ func TestNewWithDialer(t *testing.T) {
|
||||||
|
|
||||||
func TestNewWithTLSDialer(t *testing.T) {
|
func TestNewWithTLSDialer(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
tlsDialer := &netxlite.TLSDialer{
|
tlsDialer := &netxlite.TLSDialerLegacy{
|
||||||
Config: new(tls.Config),
|
Config: new(tls.Config),
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
|
|
@ -13,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer,
|
dialer := &netxlite.TLSDialerLegacy{Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
TLSHandshaker: &netxlite.TLSHandshakerLogger{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
Logger: log.Log,
|
Logger: log.Log,
|
||||||
|
|
|
@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
||||||
}
|
}
|
||||||
nextprotos := []string{"h2"}
|
nextprotos := []string{"h2"}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialerLegacy{
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
Config: &tls.Config{NextProtos: nextprotos},
|
||||||
Dialer: netxlite.NewDialerLegacyAdapter(
|
Dialer: netxlite.NewDialerLegacyAdapter(
|
||||||
dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
|
||||||
|
@ -117,7 +117,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
||||||
}
|
}
|
||||||
nextprotos := []string{"h2"}
|
nextprotos := []string{"h2"}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialerLegacy{
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
Config: &tls.Config{NextProtos: nextprotos},
|
||||||
Dialer: netxlite.DefaultDialer,
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
|
@ -183,7 +183,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialerLegacy{
|
||||||
Dialer: netxlite.DefaultDialer,
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
|
@ -216,7 +216,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialerLegacy{
|
||||||
Dialer: netxlite.DefaultDialer,
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
|
@ -249,7 +249,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialerLegacy{
|
||||||
Dialer: netxlite.DefaultDialer,
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||||
|
@ -282,7 +282,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
}
|
}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialer{
|
tlsdlr := &netxlite.TLSDialerLegacy{
|
||||||
Config: &tls.Config{InsecureSkipVerify: true},
|
Config: &tls.Config{InsecureSkipVerify: true},
|
||||||
Dialer: netxlite.DefaultDialer,
|
Dialer: netxlite.DefaultDialer,
|
||||||
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
|
||||||
|
|
|
@ -2,7 +2,9 @@ package netxlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -137,3 +139,38 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string)
|
||||||
func (d *dialerLogger) CloseIdleConnections() {
|
func (d *dialerLogger) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrNoConnReuse indicates we cannot reuse the connection provided
|
||||||
|
// to a single use (possibly TLS) dialer.
|
||||||
|
var ErrNoConnReuse = errors.New("cannot reuse connection")
|
||||||
|
|
||||||
|
// NewSingleUseDialer returns a dialer that returns the given connection once
|
||||||
|
// and after that always fails with the ErrNoConnReuse error.
|
||||||
|
func NewSingleUseDialer(conn net.Conn) Dialer {
|
||||||
|
return &dialerSingleUse{conn: conn}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialerSingleUse is the type of Dialer returned by NewSingleDialer.
|
||||||
|
type dialerSingleUse struct {
|
||||||
|
sync.Mutex
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Dialer = &dialerSingleUse{}
|
||||||
|
|
||||||
|
// DialContext implements Dialer.DialContext.
|
||||||
|
func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
|
defer s.Unlock()
|
||||||
|
s.Lock()
|
||||||
|
if s.conn == nil {
|
||||||
|
return nil, ErrNoConnReuse
|
||||||
|
}
|
||||||
|
var conn net.Conn
|
||||||
|
conn, s.conn = s.conn, nil
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections closes idle connections.
|
||||||
|
func (s *dialerSingleUse) CloseIdleConnections() {
|
||||||
|
// nothing
|
||||||
|
}
|
||||||
|
|
|
@ -235,3 +235,24 @@ func TestNewDialerWithoutResolverChain(t *testing.T) {
|
||||||
t.Fatal("invalid type")
|
t.Fatal("invalid type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewSingleUseDialerWorksAsIntended(t *testing.T) {
|
||||||
|
conn := &mocks.Conn{}
|
||||||
|
d := NewSingleUseDialer(conn)
|
||||||
|
outconn, err := d.DialContext(context.Background(), "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if conn != outconn {
|
||||||
|
t.Fatal("invalid outconn")
|
||||||
|
}
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
outconn, err = d.DialContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, ErrNoConnReuse) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if outconn != nil {
|
||||||
|
t.Fatal("expected nil outconn here")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -73,7 +73,7 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
|
||||||
txp := http.DefaultTransport.(*http.Transport).Clone()
|
txp := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
dialer = &httpDialerWithReadTimeout{dialer}
|
dialer = &httpDialerWithReadTimeout{dialer}
|
||||||
txp.DialContext = dialer.DialContext
|
txp.DialContext = dialer.DialContext
|
||||||
txp.DialTLSContext = (&TLSDialer{
|
txp.DialTLSContext = (&tlsDialer{
|
||||||
Config: tlsConfig,
|
Config: tlsConfig,
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
TLSHandshaker: handshaker,
|
TLSHandshaker: handshaker,
|
||||||
|
|
|
@ -62,6 +62,7 @@ type (
|
||||||
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
||||||
TLSHandshakerLogger = tlsHandshakerLogger
|
TLSHandshakerLogger = tlsHandshakerLogger
|
||||||
DialerSystem = dialerSystem
|
DialerSystem = dialerSystem
|
||||||
|
TLSDialerLegacy = tlsDialer
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResolverLegacy performs domain name resolutions.
|
// ResolverLegacy performs domain name resolutions.
|
||||||
|
|
|
@ -216,8 +216,29 @@ func (h *tlsHandshakerLogger) Handshake(
|
||||||
return tlsconn, state, nil
|
return tlsconn, state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TLSDialer is the TLS dialer
|
// TLSDialer is a Dialer dialing TLS connections.
|
||||||
type TLSDialer struct {
|
type TLSDialer interface {
|
||||||
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
|
CloseIdleConnections()
|
||||||
|
|
||||||
|
// DialTLSContext dials a TLS connection.
|
||||||
|
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLSDialer creates a new TLS dialer using the given dialer
|
||||||
|
// and TLS handshaker to establish TLS connections.
|
||||||
|
func NewTLSDialer(dialer Dialer, handshaker TLSHandshaker) TLSDialer {
|
||||||
|
return NewTLSDialerWithConfig(dialer, handshaker, &tls.Config{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLSDialerWithConfig is like NewTLSDialer but takes an optional config
|
||||||
|
// parameter containing your desired TLS configuration.
|
||||||
|
func NewTLSDialerWithConfig(d Dialer, h TLSHandshaker, c *tls.Config) TLSDialer {
|
||||||
|
return &tlsDialer{Config: c, Dialer: d, TLSHandshaker: h}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsDialer is the TLS dialer
|
||||||
|
type tlsDialer struct {
|
||||||
// Config is the OPTIONAL tls config.
|
// Config is the OPTIONAL tls config.
|
||||||
Config *tls.Config
|
Config *tls.Config
|
||||||
|
|
||||||
|
@ -228,13 +249,15 @@ type TLSDialer struct {
|
||||||
TLSHandshaker TLSHandshaker
|
TLSHandshaker TLSHandshaker
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnection closes idle connections, if any.
|
var _ TLSDialer = &tlsDialer{}
|
||||||
func (d *TLSDialer) CloseIdleConnection() {
|
|
||||||
|
// CloseIdleConnections implements TLSDialer.CloseIdleConnections.
|
||||||
|
func (d *tlsDialer) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialTLSContext dials a TLS connection.
|
// DialTLSContext implements TLSDialer.DialTLSContext.
|
||||||
func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *tlsDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
host, port, err := net.SplitHostPort(address)
|
host, port, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -258,7 +281,7 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
|
||||||
// We set the ServerName field if not already set.
|
// We set the ServerName field if not already set.
|
||||||
//
|
//
|
||||||
// We set the ALPN if the port is 443 or 853, if not already set.
|
// We set the ALPN if the port is 443 or 853, if not already set.
|
||||||
func (d *TLSDialer) config(host, port string) *tls.Config {
|
func (d *tlsDialer) config(host, port string) *tls.Config {
|
||||||
config := d.Config
|
config := d.Config
|
||||||
if config == nil {
|
if config == nil {
|
||||||
config = &tls.Config{}
|
config = &tls.Config{}
|
||||||
|
@ -277,3 +300,22 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
|
||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewSingleUseTLSDialer is like NewSingleUseDialer but takes
|
||||||
|
// in input a TLSConn rather than a net.Conn.
|
||||||
|
func NewSingleUseTLSDialer(conn TLSConn) TLSDialer {
|
||||||
|
return &tlsDialerSingleUseAdapter{NewSingleUseDialer(conn)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsDialerSingleUseAdapter adapts dialerSingleUse to
|
||||||
|
// be a TLSDialer type rather than a Dialer type.
|
||||||
|
type tlsDialerSingleUseAdapter struct {
|
||||||
|
Dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ TLSDialer = &tlsDialerSingleUseAdapter{}
|
||||||
|
|
||||||
|
// DialTLSContext implements TLSDialer.DialTLSContext.
|
||||||
|
func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return d.Dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
|
@ -280,21 +280,21 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) {
|
||||||
|
|
||||||
func TestTLSDialerCloseIdleConnections(t *testing.T) {
|
func TestTLSDialerCloseIdleConnections(t *testing.T) {
|
||||||
var called bool
|
var called bool
|
||||||
dialer := &TLSDialer{
|
dialer := &tlsDialer{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockCloseIdleConnections: func() {
|
MockCloseIdleConnections: func() {
|
||||||
called = true
|
called = true
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
dialer.CloseIdleConnection()
|
dialer.CloseIdleConnections()
|
||||||
if !called {
|
if !called {
|
||||||
t.Fatal("not called")
|
t.Fatal("not called")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
|
func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
|
||||||
dialer := &TLSDialer{}
|
dialer := &tlsDialer{}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
const address = "www.google.com" // missing port
|
const address = "www.google.com" // missing port
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", address)
|
conn, err := dialer.DialTLSContext(ctx, "tcp", address)
|
||||||
|
@ -309,7 +309,7 @@ func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
|
||||||
func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
|
func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel() // immediately fail
|
cancel() // immediately fail
|
||||||
dialer := TLSDialer{Dialer: defaultDialer}
|
dialer := tlsDialer{Dialer: defaultDialer}
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
|
@ -321,7 +321,7 @@ func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
|
func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
dialer := TLSDialer{
|
dialer := tlsDialer{
|
||||||
Config: &tls.Config{},
|
Config: &tls.Config{},
|
||||||
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
||||||
|
@ -345,7 +345,7 @@ func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
|
func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
dialer := TLSDialer{
|
dialer := tlsDialer{
|
||||||
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
@ -372,7 +372,7 @@ func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
|
func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
|
||||||
d := &TLSDialer{}
|
d := &tlsDialer{}
|
||||||
config := d.config("www.google.com", "443")
|
config := d.config("www.google.com", "443")
|
||||||
if config.ServerName != "www.google.com" {
|
if config.ServerName != "www.google.com" {
|
||||||
t.Fatal("invalid server name")
|
t.Fatal("invalid server name")
|
||||||
|
@ -383,7 +383,7 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
|
func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
|
||||||
d := &TLSDialer{}
|
d := &tlsDialer{}
|
||||||
config := d.config("dns.google", "853")
|
config := d.config("dns.google", "853")
|
||||||
if config.ServerName != "dns.google" {
|
if config.ServerName != "dns.google" {
|
||||||
t.Fatal("invalid server name")
|
t.Fatal("invalid server name")
|
||||||
|
@ -394,7 +394,7 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSDialerConfigWithServerName(t *testing.T) {
|
func TestTLSDialerConfigWithServerName(t *testing.T) {
|
||||||
d := &TLSDialer{
|
d := &tlsDialer{
|
||||||
Config: &tls.Config{
|
Config: &tls.Config{
|
||||||
ServerName: "example.com",
|
ServerName: "example.com",
|
||||||
},
|
},
|
||||||
|
@ -409,7 +409,7 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSDialerConfigWithALPN(t *testing.T) {
|
func TestTLSDialerConfigWithALPN(t *testing.T) {
|
||||||
d := &TLSDialer{
|
d := &tlsDialer{
|
||||||
Config: &tls.Config{
|
Config: &tls.Config{
|
||||||
NextProtos: []string{"h2"},
|
NextProtos: []string{"h2"},
|
||||||
},
|
},
|
||||||
|
@ -440,3 +440,43 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
|
||||||
t.Fatal("expected nil NewConn")
|
t.Fatal("expected nil NewConn")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewTLSDialerWorksAsIntended(t *testing.T) {
|
||||||
|
d := &mocks.Dialer{}
|
||||||
|
tlsh := &mocks.TLSHandshaker{}
|
||||||
|
td := NewTLSDialer(d, tlsh)
|
||||||
|
tdut, okay := td.(*tlsDialer)
|
||||||
|
if !okay {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
if tdut.Config == nil {
|
||||||
|
t.Fatal("unexpected config")
|
||||||
|
}
|
||||||
|
if tdut.Dialer != d {
|
||||||
|
t.Fatal("unexpected dialer")
|
||||||
|
}
|
||||||
|
if tdut.TLSHandshaker != tlsh {
|
||||||
|
t.Fatal("invalid handshaker")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
|
||||||
|
conn := &mocks.TLSConn{}
|
||||||
|
d := NewSingleUseTLSDialer(conn)
|
||||||
|
outconn, err := d.DialTLSContext(context.Background(), "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if conn != outconn {
|
||||||
|
t.Fatal("invalid outconn")
|
||||||
|
}
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
outconn, err = d.DialTLSContext(context.Background(), "", "")
|
||||||
|
if !errors.Is(err, ErrNoConnReuse) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if outconn != nil {
|
||||||
|
t.Fatal("expected nil outconn here")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user