feat(netxlite): implement single use {,tls} dialer (#464)

This basically adapts already existing code inside websteps to
instead be into the netxlite package, where it belongs.

In the process, abstract the TLSDialer but keep a reference to the
previous name to avoid refactoring existing code (just for now).

While there, notice that the right name is CloseIdleConnections (i.e.,
plural not singular) and change the name.

While there, since we abstracted TLSDialer to be an interface, create
suitable factories for making a TLSDialer type from a Dialer and a
TLSHandshaker.

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-06 14:12:30 +02:00 committed by GitHub
parent ef9592f75e
commit 2572376fdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 177 additions and 36 deletions

View File

@ -83,7 +83,7 @@ func NewSingleTransport(conn net.Conn) http.RoundTripper {
func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper { func NewTransportWithDialer(dialer netxlite.DialerLegacy, tlsConfig *tls.Config, handshaker netxlite.TLSHandshaker) http.RoundTripper {
transport := newBaseTransport() transport := newBaseTransport()
transport.DialContext = dialer.DialContext transport.DialContext = dialer.DialContext
transport.DialTLSContext = (&netxlite.TLSDialer{ transport.DialTLSContext = (&netxlite.TLSDialerLegacy{
Config: tlsConfig, Config: tlsConfig,
Dialer: netxlite.NewDialerLegacyAdapter(dialer), Dialer: netxlite.NewDialerLegacyAdapter(dialer),
TLSHandshaker: handshaker, TLSHandshaker: handshaker,

View File

@ -103,8 +103,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
// - SystemTLSHandshaker // - SystemTLSHandshaker
// //
// If you have others needs, manually build the chain you need. // If you have others needs, manually build the chain you need.
func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer { func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialerLegacy {
return &netxlite.TLSDialer{ return &netxlite.TLSDialerLegacy{
Config: config, Config: config,
Dialer: netxlite.NewDialerLegacyAdapter(d), Dialer: netxlite.NewDialerLegacyAdapter(d),
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ TLSHandshaker: tlsdialer.EmitterTLSHandshaker{

View File

@ -207,7 +207,7 @@ func NewTLSDialer(config Config) TLSDialer {
} }
config.TLSConfig.RootCAs = config.CertPool config.TLSConfig.RootCAs = config.CertPool
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
return &netxlite.TLSDialer{ return &netxlite.TLSDialerLegacy{
Config: config.TLSConfig, Config: config.TLSConfig,
Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer), Dialer: netxlite.NewDialerLegacyAdapter(config.Dialer),
TLSHandshaker: h, TLSHandshaker: h,

View File

@ -255,7 +255,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
func TestNewTLSDialerVanilla(t *testing.T) { func TestNewTLSDialerVanilla(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{}) td := netx.NewTLSDialer(netx.Config{})
rtd, ok := td.(*netxlite.TLSDialer) rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok { if !ok {
t.Fatal("not the TLSDialer we expected") t.Fatal("not the TLSDialer we expected")
} }
@ -287,7 +287,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{ td := netx.NewTLSDialer(netx.Config{
TLSConfig: new(tls.Config), TLSConfig: new(tls.Config),
}) })
rtd, ok := td.(*netxlite.TLSDialer) rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok { if !ok {
t.Fatal("not the TLSDialer we expected") t.Fatal("not the TLSDialer we expected")
} }
@ -316,7 +316,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{ td := netx.NewTLSDialer(netx.Config{
Logger: log.Log, Logger: log.Log,
}) })
rtd, ok := td.(*netxlite.TLSDialer) rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok { if !ok {
t.Fatal("not the TLSDialer we expected") t.Fatal("not the TLSDialer we expected")
} }
@ -356,7 +356,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{ td := netx.NewTLSDialer(netx.Config{
TLSSaver: saver, TLSSaver: saver,
}) })
rtd, ok := td.(*netxlite.TLSDialer) rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok { if !ok {
t.Fatal("not the TLSDialer we expected") t.Fatal("not the TLSDialer we expected")
} }
@ -396,7 +396,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
TLSConfig: new(tls.Config), TLSConfig: new(tls.Config),
NoTLSVerify: true, NoTLSVerify: true,
}) })
rtd, ok := td.(*netxlite.TLSDialer) rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok { if !ok {
t.Fatal("not the TLSDialer we expected") t.Fatal("not the TLSDialer we expected")
} }
@ -428,7 +428,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
td := netx.NewTLSDialer(netx.Config{ td := netx.NewTLSDialer(netx.Config{
NoTLSVerify: true, NoTLSVerify: true,
}) })
rtd, ok := td.(*netxlite.TLSDialer) rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok { if !ok {
t.Fatal("not the TLSDialer we expected") t.Fatal("not the TLSDialer we expected")
} }
@ -488,7 +488,7 @@ func TestNewWithDialer(t *testing.T) {
func TestNewWithTLSDialer(t *testing.T) { func TestNewWithTLSDialer(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
tlsDialer := &netxlite.TLSDialer{ tlsDialer := &netxlite.TLSDialerLegacy{
Config: new(tls.Config), Config: new(tls.Config),
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {

View File

@ -13,7 +13,7 @@ func TestTLSDialerSuccess(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
dialer := &netxlite.TLSDialer{Dialer: netxlite.DefaultDialer, dialer := &netxlite.TLSDialerLegacy{Dialer: netxlite.DefaultDialer,
TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerLogger{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Logger: log.Log, Logger: log.Log,

View File

@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
} }
nextprotos := []string{"h2"} nextprotos := []string{"h2"}
saver := &trace.Saver{} saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{NextProtos: nextprotos}, Config: &tls.Config{NextProtos: nextprotos},
Dialer: netxlite.NewDialerLegacyAdapter( Dialer: netxlite.NewDialerLegacyAdapter(
dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
@ -117,7 +117,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
} }
nextprotos := []string{"h2"} nextprotos := []string{"h2"}
saver := &trace.Saver{} saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{NextProtos: nextprotos}, Config: &tls.Config{NextProtos: nextprotos},
Dialer: netxlite.DefaultDialer, Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
@ -183,7 +183,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
saver := &trace.Saver{} saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer, Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
@ -216,7 +216,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
saver := &trace.Saver{} saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer, Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
@ -249,7 +249,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
saver := &trace.Saver{} saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer, Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
@ -282,7 +282,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
saver := &trace.Saver{} saver := &trace.Saver{}
tlsdlr := &netxlite.TLSDialer{ tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{InsecureSkipVerify: true}, Config: &tls.Config{InsecureSkipVerify: true},
Dialer: netxlite.DefaultDialer, Dialer: netxlite.DefaultDialer,
TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SaverTLSHandshaker{

View File

@ -2,7 +2,9 @@ package netxlite
import ( import (
"context" "context"
"errors"
"net" "net"
"sync"
"time" "time"
) )
@ -137,3 +139,38 @@ func (d *dialerLogger) DialContext(ctx context.Context, network, address string)
func (d *dialerLogger) CloseIdleConnections() { func (d *dialerLogger) CloseIdleConnections() {
d.Dialer.CloseIdleConnections() d.Dialer.CloseIdleConnections()
} }
// ErrNoConnReuse indicates we cannot reuse the connection provided
// to a single use (possibly TLS) dialer.
var ErrNoConnReuse = errors.New("cannot reuse connection")
// NewSingleUseDialer returns a dialer that returns the given connection once
// and after that always fails with the ErrNoConnReuse error.
func NewSingleUseDialer(conn net.Conn) Dialer {
return &dialerSingleUse{conn: conn}
}
// dialerSingleUse is the type of Dialer returned by NewSingleDialer.
type dialerSingleUse struct {
sync.Mutex
conn net.Conn
}
var _ Dialer = &dialerSingleUse{}
// DialContext implements Dialer.DialContext.
func (s *dialerSingleUse) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
defer s.Unlock()
s.Lock()
if s.conn == nil {
return nil, ErrNoConnReuse
}
var conn net.Conn
conn, s.conn = s.conn, nil
return conn, nil
}
// CloseIdleConnections closes idle connections.
func (s *dialerSingleUse) CloseIdleConnections() {
// nothing
}

View File

@ -235,3 +235,24 @@ func TestNewDialerWithoutResolverChain(t *testing.T) {
t.Fatal("invalid type") t.Fatal("invalid type")
} }
} }
func TestNewSingleUseDialerWorksAsIntended(t *testing.T) {
conn := &mocks.Conn{}
d := NewSingleUseDialer(conn)
outconn, err := d.DialContext(context.Background(), "", "")
if err != nil {
t.Fatal(err)
}
if conn != outconn {
t.Fatal("invalid outconn")
}
for i := 0; i < 4; i++ {
outconn, err = d.DialContext(context.Background(), "", "")
if !errors.Is(err, ErrNoConnReuse) {
t.Fatal("not the error we expected", err)
}
if outconn != nil {
t.Fatal("expected nil outconn here")
}
}
}

View File

@ -73,7 +73,7 @@ func NewHTTPTransport(dialer Dialer, tlsConfig *tls.Config,
txp := http.DefaultTransport.(*http.Transport).Clone() txp := http.DefaultTransport.(*http.Transport).Clone()
dialer = &httpDialerWithReadTimeout{dialer} dialer = &httpDialerWithReadTimeout{dialer}
txp.DialContext = dialer.DialContext txp.DialContext = dialer.DialContext
txp.DialTLSContext = (&TLSDialer{ txp.DialTLSContext = (&tlsDialer{
Config: tlsConfig, Config: tlsConfig,
Dialer: dialer, Dialer: dialer,
TLSHandshaker: handshaker, TLSHandshaker: handshaker,

View File

@ -62,6 +62,7 @@ type (
TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger TLSHandshakerLogger = tlsHandshakerLogger
DialerSystem = dialerSystem DialerSystem = dialerSystem
TLSDialerLegacy = tlsDialer
) )
// ResolverLegacy performs domain name resolutions. // ResolverLegacy performs domain name resolutions.

View File

@ -216,8 +216,29 @@ func (h *tlsHandshakerLogger) Handshake(
return tlsconn, state, nil return tlsconn, state, nil
} }
// TLSDialer is the TLS dialer // TLSDialer is a Dialer dialing TLS connections.
type TLSDialer struct { type TLSDialer interface {
// CloseIdleConnections closes idle connections, if any.
CloseIdleConnections()
// DialTLSContext dials a TLS connection.
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
}
// NewTLSDialer creates a new TLS dialer using the given dialer
// and TLS handshaker to establish TLS connections.
func NewTLSDialer(dialer Dialer, handshaker TLSHandshaker) TLSDialer {
return NewTLSDialerWithConfig(dialer, handshaker, &tls.Config{})
}
// NewTLSDialerWithConfig is like NewTLSDialer but takes an optional config
// parameter containing your desired TLS configuration.
func NewTLSDialerWithConfig(d Dialer, h TLSHandshaker, c *tls.Config) TLSDialer {
return &tlsDialer{Config: c, Dialer: d, TLSHandshaker: h}
}
// tlsDialer is the TLS dialer
type tlsDialer struct {
// Config is the OPTIONAL tls config. // Config is the OPTIONAL tls config.
Config *tls.Config Config *tls.Config
@ -228,13 +249,15 @@ type TLSDialer struct {
TLSHandshaker TLSHandshaker TLSHandshaker TLSHandshaker
} }
// CloseIdleConnection closes idle connections, if any. var _ TLSDialer = &tlsDialer{}
func (d *TLSDialer) CloseIdleConnection() {
// CloseIdleConnections implements TLSDialer.CloseIdleConnections.
func (d *tlsDialer) CloseIdleConnections() {
d.Dialer.CloseIdleConnections() d.Dialer.CloseIdleConnections()
} }
// DialTLSContext dials a TLS connection. // DialTLSContext implements TLSDialer.DialTLSContext.
func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *tlsDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -258,7 +281,7 @@ func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string)
// We set the ServerName field if not already set. // We set the ServerName field if not already set.
// //
// We set the ALPN if the port is 443 or 853, if not already set. // We set the ALPN if the port is 443 or 853, if not already set.
func (d *TLSDialer) config(host, port string) *tls.Config { func (d *tlsDialer) config(host, port string) *tls.Config {
config := d.Config config := d.Config
if config == nil { if config == nil {
config = &tls.Config{} config = &tls.Config{}
@ -277,3 +300,22 @@ func (d *TLSDialer) config(host, port string) *tls.Config {
} }
return config return config
} }
// NewSingleUseTLSDialer is like NewSingleUseDialer but takes
// in input a TLSConn rather than a net.Conn.
func NewSingleUseTLSDialer(conn TLSConn) TLSDialer {
return &tlsDialerSingleUseAdapter{NewSingleUseDialer(conn)}
}
// tlsDialerSingleUseAdapter adapts dialerSingleUse to
// be a TLSDialer type rather than a Dialer type.
type tlsDialerSingleUseAdapter struct {
Dialer
}
var _ TLSDialer = &tlsDialerSingleUseAdapter{}
// DialTLSContext implements TLSDialer.DialTLSContext.
func (d *tlsDialerSingleUseAdapter) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address)
}

View File

@ -280,21 +280,21 @@ func TestTLSHandshakerLoggerFailure(t *testing.T) {
func TestTLSDialerCloseIdleConnections(t *testing.T) { func TestTLSDialerCloseIdleConnections(t *testing.T) {
var called bool var called bool
dialer := &TLSDialer{ dialer := &tlsDialer{
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
MockCloseIdleConnections: func() { MockCloseIdleConnections: func() {
called = true called = true
}, },
}, },
} }
dialer.CloseIdleConnection() dialer.CloseIdleConnections()
if !called { if !called {
t.Fatal("not called") t.Fatal("not called")
} }
} }
func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) { func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
dialer := &TLSDialer{} dialer := &tlsDialer{}
ctx := context.Background() ctx := context.Background()
const address = "www.google.com" // missing port const address = "www.google.com" // missing port
conn, err := dialer.DialTLSContext(ctx, "tcp", address) conn, err := dialer.DialTLSContext(ctx, "tcp", address)
@ -309,7 +309,7 @@ func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) { func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() // immediately fail cancel() // immediately fail
dialer := TLSDialer{Dialer: defaultDialer} dialer := tlsDialer{Dialer: defaultDialer}
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
@ -321,7 +321,7 @@ func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) { func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
ctx := context.Background() ctx := context.Background()
dialer := TLSDialer{ dialer := tlsDialer{
Config: &tls.Config{}, Config: &tls.Config{},
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{MockWrite: func(b []byte) (int, error) { return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
@ -345,7 +345,7 @@ func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) { func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
ctx := context.Background() ctx := context.Background()
dialer := TLSDialer{ dialer := tlsDialer{
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{MockWrite: func(b []byte) (int, error) { return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
return 0, io.EOF return 0, io.EOF
@ -372,7 +372,7 @@ func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
} }
func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
d := &TLSDialer{} d := &tlsDialer{}
config := d.config("www.google.com", "443") config := d.config("www.google.com", "443")
if config.ServerName != "www.google.com" { if config.ServerName != "www.google.com" {
t.Fatal("invalid server name") t.Fatal("invalid server name")
@ -383,7 +383,7 @@ func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
} }
func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
d := &TLSDialer{} d := &tlsDialer{}
config := d.config("dns.google", "853") config := d.config("dns.google", "853")
if config.ServerName != "dns.google" { if config.ServerName != "dns.google" {
t.Fatal("invalid server name") t.Fatal("invalid server name")
@ -394,7 +394,7 @@ func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
} }
func TestTLSDialerConfigWithServerName(t *testing.T) { func TestTLSDialerConfigWithServerName(t *testing.T) {
d := &TLSDialer{ d := &tlsDialer{
Config: &tls.Config{ Config: &tls.Config{
ServerName: "example.com", ServerName: "example.com",
}, },
@ -409,7 +409,7 @@ func TestTLSDialerConfigWithServerName(t *testing.T) {
} }
func TestTLSDialerConfigWithALPN(t *testing.T) { func TestTLSDialerConfigWithALPN(t *testing.T) {
d := &TLSDialer{ d := &tlsDialer{
Config: &tls.Config{ Config: &tls.Config{
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
}, },
@ -440,3 +440,43 @@ func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
t.Fatal("expected nil NewConn") t.Fatal("expected nil NewConn")
} }
} }
func TestNewTLSDialerWorksAsIntended(t *testing.T) {
d := &mocks.Dialer{}
tlsh := &mocks.TLSHandshaker{}
td := NewTLSDialer(d, tlsh)
tdut, okay := td.(*tlsDialer)
if !okay {
t.Fatal("invalid type")
}
if tdut.Config == nil {
t.Fatal("unexpected config")
}
if tdut.Dialer != d {
t.Fatal("unexpected dialer")
}
if tdut.TLSHandshaker != tlsh {
t.Fatal("invalid handshaker")
}
}
func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
conn := &mocks.TLSConn{}
d := NewSingleUseTLSDialer(conn)
outconn, err := d.DialTLSContext(context.Background(), "", "")
if err != nil {
t.Fatal(err)
}
if conn != outconn {
t.Fatal("invalid outconn")
}
for i := 0; i < 4; i++ {
outconn, err = d.DialTLSContext(context.Background(), "", "")
if !errors.Is(err, ErrNoConnReuse) {
t.Fatal("not the error we expected", err)
}
if outconn != nil {
t.Fatal("expected nil outconn here")
}
}
}