netxlite: code quality, improve tests, docs (#494)
See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
3cd88debdc
commit
50b58672c6
|
@ -40,7 +40,11 @@ type Dialer interface {
|
||||||
//
|
//
|
||||||
// 4. wraps errors;
|
// 4. wraps errors;
|
||||||
//
|
//
|
||||||
// 5. has a configured connect timeout.
|
// 5. has a configured connect timeout;
|
||||||
|
//
|
||||||
|
// 6. if a dialer wraps a resolver, the dialer will forward
|
||||||
|
// the CloseIdleConnection call to its resolver (which is
|
||||||
|
// instrumental to manage a DoH resolver connections properly).
|
||||||
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
|
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
|
||||||
return &dialerLogger{
|
return &dialerLogger{
|
||||||
Dialer: &dialerResolver{
|
Dialer: &dialerResolver{
|
||||||
|
|
|
@ -24,6 +24,9 @@
|
||||||
//
|
//
|
||||||
// We also want to mock any underlying dependency for testing.
|
// We also want to mock any underlying dependency for testing.
|
||||||
//
|
//
|
||||||
|
// We also want to map errors to OONI failures, which are described by
|
||||||
|
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
|
||||||
|
//
|
||||||
// Operations
|
// Operations
|
||||||
//
|
//
|
||||||
// This package implements the following operations:
|
// This package implements the following operations:
|
||||||
|
|
|
@ -30,7 +30,6 @@ type httpTransportLogger struct {
|
||||||
|
|
||||||
var _ HTTPTransport = &httpTransportLogger{}
|
var _ HTTPTransport = &httpTransportLogger{}
|
||||||
|
|
||||||
// RoundTrip implements HTTPTransport.RoundTrip.
|
|
||||||
func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
host := req.Host
|
host := req.Host
|
||||||
if host == "" {
|
if host == "" {
|
||||||
|
@ -64,13 +63,12 @@ func (txp *httpTransportLogger) logTrip(req *http.Request) (*http.Response, erro
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections implement HTTPTransport.CloseIdleConnections.
|
|
||||||
func (txp *httpTransportLogger) CloseIdleConnections() {
|
func (txp *httpTransportLogger) CloseIdleConnections() {
|
||||||
txp.HTTPTransport.CloseIdleConnections()
|
txp.HTTPTransport.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// httpTransportConnectionsCloser is an HTTPTransport that
|
// httpTransportConnectionsCloser is an HTTPTransport that
|
||||||
// correctly forwards CloseIdleConnections.
|
// correctly forwards CloseIdleConnections calls.
|
||||||
type httpTransportConnectionsCloser struct {
|
type httpTransportConnectionsCloser struct {
|
||||||
HTTPTransport
|
HTTPTransport
|
||||||
Dialer
|
Dialer
|
||||||
|
@ -98,6 +96,16 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
|
||||||
// The returned transport will disable transparent decompression
|
// The returned transport will disable transparent decompression
|
||||||
// of compressed response bodies (and will not automatically
|
// of compressed response bodies (and will not automatically
|
||||||
// ask for such compression, though you can always do that manually).
|
// ask for such compression, though you can always do that manually).
|
||||||
|
//
|
||||||
|
// The returned transport will configure TCP and TLS connections
|
||||||
|
// created using its dialer and TLS dialer to always have a
|
||||||
|
// read watchdog timeout to address https://github.com/ooni/probe/issues/1609.
|
||||||
|
//
|
||||||
|
// The returned transport will always enforce 1 connection per host
|
||||||
|
// and we cannot get rid of this QUIRK requirement because it is
|
||||||
|
// necessary to perform sane measurements with tracing. We will be
|
||||||
|
// able to possibly relax this requirement after we change the
|
||||||
|
// way in which we perform measurements.
|
||||||
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
|
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
|
||||||
// Using oohttp to support any TLS library.
|
// Using oohttp to support any TLS library.
|
||||||
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
|
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
|
||||||
|
@ -115,8 +123,6 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTra
|
||||||
|
|
||||||
// Better for Cloudflare DNS and also better because we have less
|
// Better for Cloudflare DNS and also better because we have less
|
||||||
// noisy events and we can better understand what happened.
|
// noisy events and we can better understand what happened.
|
||||||
//
|
|
||||||
// UNDOCUMENTED: I am wondering whether we can relax this constraint.
|
|
||||||
txp.MaxConnsPerHost = 1
|
txp.MaxConnsPerHost = 1
|
||||||
|
|
||||||
// The following (1) reduces the number of headers that Go will
|
// The following (1) reduces the number of headers that Go will
|
||||||
|
@ -175,7 +181,7 @@ func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tconn, okay := conn.(TLSConn)
|
tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful
|
||||||
if !okay {
|
if !okay {
|
||||||
conn.Close() // we own the conn here
|
conn.Close() // we own the conn here
|
||||||
return nil, ErrNotTLSConn
|
return nil, ErrNotTLSConn
|
||||||
|
|
|
@ -33,7 +33,7 @@ func TestHTTP3Dialer(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTP3TransportClosesIdleConnections(t *testing.T) {
|
func TestHTTP3Transport(t *testing.T) {
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
calledHTTP3 bool
|
calledHTTP3 bool
|
||||||
|
|
|
@ -21,8 +21,17 @@ import (
|
||||||
func TestHTTPTransportLogger(t *testing.T) {
|
func TestHTTPTransportLogger(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
t.Run("with failure", func(t *testing.T) {
|
t.Run("with failure", func(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebug: func(message string) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
}
|
||||||
txp := &httpTransportLogger{
|
txp := &httpTransportLogger{
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
HTTPTransport: &mocks.HTTPTransport{
|
HTTPTransport: &mocks.HTTPTransport{
|
||||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
|
@ -37,6 +46,9 @@ func TestHTTPTransportLogger(t *testing.T) {
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil response here")
|
t.Fatal("expected nil response here")
|
||||||
}
|
}
|
||||||
|
if count < 1 {
|
||||||
|
t.Fatal("no logs?!")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("we add the host header", func(t *testing.T) {
|
t.Run("we add the host header", func(t *testing.T) {
|
||||||
|
@ -73,8 +85,17 @@ func TestHTTPTransportLogger(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with success", func(t *testing.T) {
|
t.Run("with success", func(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebug: func(message string) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
}
|
||||||
txp := &httpTransportLogger{
|
txp := &httpTransportLogger{
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
HTTPTransport: &mocks.HTTPTransport{
|
HTTPTransport: &mocks.HTTPTransport{
|
||||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -94,6 +115,9 @@ func TestHTTPTransportLogger(t *testing.T) {
|
||||||
}
|
}
|
||||||
iox.ReadAllContext(context.Background(), resp.Body)
|
iox.ReadAllContext(context.Background(), resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
|
if count < 1 {
|
||||||
|
t.Fatal("no logs?!")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,26 @@ import (
|
||||||
utls "gitlab.com/yawning/utls.git"
|
utls "gitlab.com/yawning/utls.git"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestResolver(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skip test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("works as intended", func(t *testing.T) {
|
||||||
|
// TODO(bassosimone): this is actually an integration
|
||||||
|
// test but how to test this case?
|
||||||
|
r := netxlite.NewResolverSystem(log.Log)
|
||||||
|
defer r.CloseIdleConnections()
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if addrs == nil {
|
||||||
|
t.Fatal("expected non-nil result here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestHTTPTransport(t *testing.T) {
|
func TestHTTPTransport(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skip test in short mode")
|
t.Skip("skip test in short mode")
|
||||||
|
|
|
@ -13,6 +13,7 @@ var (
|
||||||
DefaultDialer = &dialerSystem{}
|
DefaultDialer = &dialerSystem{}
|
||||||
DefaultTLSHandshaker = defaultTLSHandshaker
|
DefaultTLSHandshaker = defaultTLSHandshaker
|
||||||
NewConnUTLS = newConnUTLS
|
NewConnUTLS = newConnUTLS
|
||||||
|
DefaultResolver = &resolverSystem{}
|
||||||
)
|
)
|
||||||
|
|
||||||
// These types export internal names to legacy ooni/probe-cli code.
|
// These types export internal names to legacy ooni/probe-cli code.
|
||||||
|
|
|
@ -50,6 +50,24 @@ type QUICDialer interface {
|
||||||
// NewQUICDialerWithResolver returns a QUICDialer using the given
|
// NewQUICDialerWithResolver returns a QUICDialer using the given
|
||||||
// QUICListener to create listening connections and the given Resolver
|
// QUICListener to create listening connections and the given Resolver
|
||||||
// to resolve domain names (if needed).
|
// to resolve domain names (if needed).
|
||||||
|
//
|
||||||
|
// Properties of the dialer:
|
||||||
|
//
|
||||||
|
// 1. logs events using the given logger;
|
||||||
|
//
|
||||||
|
// 2. resolves domain names using the givern resolver;
|
||||||
|
//
|
||||||
|
// 3. when using a resolver, _may_ attempt multiple dials
|
||||||
|
// in parallel (happy eyeballs) and _may_ return an aggregate
|
||||||
|
// error to the caller;
|
||||||
|
//
|
||||||
|
// 4. wraps errors;
|
||||||
|
//
|
||||||
|
// 5. has a configured connect timeout;
|
||||||
|
//
|
||||||
|
// 6. if a dialer wraps a resolver, the dialer will forward
|
||||||
|
// the CloseIdleConnection call to its resolver (which is
|
||||||
|
// instrumental to manage a DoH resolver connections properly).
|
||||||
func NewQUICDialerWithResolver(listener QUICListener,
|
func NewQUICDialerWithResolver(listener QUICListener,
|
||||||
logger Logger, resolver Resolver) QUICDialer {
|
logger Logger, resolver Resolver) QUICDialer {
|
||||||
return &quicDialerLogger{
|
return &quicDialerLogger{
|
||||||
|
@ -210,12 +228,9 @@ func (d *quicDialerResolver) DialContext(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
|
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
|
||||||
// TODO(bassosimone): here we should be using multierror rather
|
// See TODO(https://github.com/ooni/probe/issues/1779) however
|
||||||
// than just calling ReduceErrors. We are not ready to do that
|
// this is less of a problem for QUIC because so far we have been
|
||||||
// yet, though. To do that, we need first to modify nettests so
|
// using it to perform research only (i.e., urlgetter).
|
||||||
// that we actually avoid dialing when measuring.
|
|
||||||
//
|
|
||||||
// See also the quirks.go file. This is clearly a QUIRK.
|
|
||||||
addrs = quirkSortIPAddrs(addrs)
|
addrs = quirkSortIPAddrs(addrs)
|
||||||
var errorslist []error
|
var errorslist []error
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
|
|
|
@ -17,6 +17,34 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestNewQUICListener(t *testing.T) {
|
||||||
|
ql := NewQUICListener()
|
||||||
|
qew := ql.(*quicListenerErrWrapper)
|
||||||
|
_ = qew.QUICListener.(*quicListenerStdlib)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewQUICDialer(t *testing.T) {
|
||||||
|
ql := NewQUICListener()
|
||||||
|
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
|
||||||
|
logger := dlr.(*quicDialerLogger)
|
||||||
|
if logger.Logger != log.Log {
|
||||||
|
t.Fatal("invalid logger")
|
||||||
|
}
|
||||||
|
resolver := logger.Dialer.(*quicDialerResolver)
|
||||||
|
if _, okay := resolver.Resolver.(*nullResolver); !okay {
|
||||||
|
t.Fatal("invalid resolver type")
|
||||||
|
}
|
||||||
|
logger = resolver.Dialer.(*quicDialerLogger)
|
||||||
|
if logger.Logger != log.Log {
|
||||||
|
t.Fatal("invalid logger")
|
||||||
|
}
|
||||||
|
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
|
||||||
|
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
|
||||||
|
if base.QUICListener != ql {
|
||||||
|
t.Fatal("invalid quic listener")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQUICDialerQUICGo(t *testing.T) {
|
func TestQUICDialerQUICGo(t *testing.T) {
|
||||||
t.Run("DialContext", func(t *testing.T) {
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
t.Run("cannot split host port", func(t *testing.T) {
|
t.Run("cannot split host port", func(t *testing.T) {
|
||||||
|
@ -223,7 +251,6 @@ func TestQUICDialerQUICGo(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICDialerResolver(t *testing.T) {
|
func TestQUICDialerResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
forDialer bool
|
forDialer bool
|
||||||
|
@ -302,7 +329,7 @@ func TestQUICDialerResolver(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with invalid port", func(t *testing.T) {
|
t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) {
|
||||||
// This test allows us to check for the case where every attempt
|
// This test allows us to check for the case where every attempt
|
||||||
// to establish a connection leads to a failure
|
// to establish a connection leads to a failure
|
||||||
tlsConf := &tls.Config{}
|
tlsConf := &tls.Config{}
|
||||||
|
@ -376,7 +403,6 @@ func TestQUICDialerResolver(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICLoggerDialer(t *testing.T) {
|
func TestQUICLoggerDialer(t *testing.T) {
|
||||||
|
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
var forDialer bool
|
var forDialer bool
|
||||||
d := &quicDialerLogger{
|
d := &quicDialerLogger{
|
||||||
|
@ -394,6 +420,12 @@ func TestQUICLoggerDialer(t *testing.T) {
|
||||||
|
|
||||||
t.Run("DialContext", func(t *testing.T) {
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
t.Run("on success", func(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
var called int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
called++
|
||||||
|
},
|
||||||
|
}
|
||||||
d := &quicDialerLogger{
|
d := &quicDialerLogger{
|
||||||
Dialer: &mocks.QUICDialer{
|
Dialer: &mocks.QUICDialer{
|
||||||
MockDialContext: func(ctx context.Context, network string,
|
MockDialContext: func(ctx context.Context, network string,
|
||||||
|
@ -407,7 +439,7 @@ func TestQUICLoggerDialer(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tlsConfig := &tls.Config{}
|
tlsConfig := &tls.Config{}
|
||||||
|
@ -419,9 +451,18 @@ func TestQUICLoggerDialer(t *testing.T) {
|
||||||
if err := sess.CloseWithError(0, ""); err != nil {
|
if err := sess.CloseWithError(0, ""); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if called != 2 {
|
||||||
|
t.Fatal("invalid number of calls")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("on failure", func(t *testing.T) {
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
var called int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
called++
|
||||||
|
},
|
||||||
|
}
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
d := &quicDialerLogger{
|
d := &quicDialerLogger{
|
||||||
Dialer: &mocks.QUICDialer{
|
Dialer: &mocks.QUICDialer{
|
||||||
|
@ -431,7 +472,7 @@ func TestQUICLoggerDialer(t *testing.T) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
tlsConfig := &tls.Config{}
|
tlsConfig := &tls.Config{}
|
||||||
|
@ -443,32 +484,13 @@ func TestQUICLoggerDialer(t *testing.T) {
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
t.Fatal("expected nil session")
|
t.Fatal("expected nil session")
|
||||||
}
|
}
|
||||||
|
if called != 2 {
|
||||||
|
t.Fatal("invalid number of calls")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewQUICDialer(t *testing.T) {
|
|
||||||
ql := NewQUICListener()
|
|
||||||
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
|
|
||||||
logger := dlr.(*quicDialerLogger)
|
|
||||||
if logger.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
|
||||||
}
|
|
||||||
resolver := logger.Dialer.(*quicDialerResolver)
|
|
||||||
if _, okay := resolver.Resolver.(*nullResolver); !okay {
|
|
||||||
t.Fatal("invalid resolver type")
|
|
||||||
}
|
|
||||||
logger = resolver.Dialer.(*quicDialerLogger)
|
|
||||||
if logger.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
|
||||||
}
|
|
||||||
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
|
|
||||||
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
|
|
||||||
if base.QUICListener != ql {
|
|
||||||
t.Fatal("invalid quic listener")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewSingleUseQUICDialer(t *testing.T) {
|
func TestNewSingleUseQUICDialer(t *testing.T) {
|
||||||
sess := &mocks.QUICEarlySession{}
|
sess := &mocks.QUICEarlySession{}
|
||||||
qd := NewSingleUseQUICDialer(sess)
|
qd := NewSingleUseQUICDialer(sess)
|
||||||
|
|
|
@ -54,26 +54,34 @@ func TestQuirkReduceErrors(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQuirkSortIPAddrs(t *testing.T) {
|
func TestQuirkSortIPAddrs(t *testing.T) {
|
||||||
addrs := []string{
|
t.Run("with some addrs", func(t *testing.T) {
|
||||||
"::1",
|
addrs := []string{
|
||||||
"192.168.1.2",
|
"::1",
|
||||||
"2a00:1450:4002:404::2004",
|
"192.168.1.2",
|
||||||
"142.250.184.36",
|
"2a00:1450:4002:404::2004",
|
||||||
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
"142.250.184.36",
|
||||||
"198.145.29.83",
|
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
||||||
"95.216.163.36",
|
"198.145.29.83",
|
||||||
}
|
"95.216.163.36",
|
||||||
expected := []string{
|
}
|
||||||
"192.168.1.2",
|
expected := []string{
|
||||||
"142.250.184.36",
|
"192.168.1.2",
|
||||||
"198.145.29.83",
|
"142.250.184.36",
|
||||||
"95.216.163.36",
|
"198.145.29.83",
|
||||||
"::1",
|
"95.216.163.36",
|
||||||
"2a00:1450:4002:404::2004",
|
"::1",
|
||||||
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
"2a00:1450:4002:404::2004",
|
||||||
}
|
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
||||||
out := quirkSortIPAddrs(addrs)
|
}
|
||||||
if diff := cmp.Diff(expected, out); diff != "" {
|
out := quirkSortIPAddrs(addrs)
|
||||||
t.Fatal(diff)
|
if diff := cmp.Diff(expected, out); diff != "" {
|
||||||
}
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with an empty list", func(t *testing.T) {
|
||||||
|
if quirkSortIPAddrs(nil) != nil {
|
||||||
|
t.Fatal("expected nil output")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,20 @@ type Resolver interface {
|
||||||
|
|
||||||
// NewResolverSystem creates a new resolver using system
|
// NewResolverSystem creates a new resolver using system
|
||||||
// facilities for resolving domain names (e.g., getaddrinfo).
|
// facilities for resolving domain names (e.g., getaddrinfo).
|
||||||
|
//
|
||||||
|
// The resolver will provide the following guarantees:
|
||||||
|
//
|
||||||
|
// 1. handles IDNA;
|
||||||
|
//
|
||||||
|
// 2. performs logging;
|
||||||
|
//
|
||||||
|
// 3. short-circuits IP addresses like getaddrinfo does (i.e.,
|
||||||
|
// resolving "1.1.1.1" yields []string{"1.1.1.1"};
|
||||||
|
//
|
||||||
|
// 4. wraps errors;
|
||||||
|
//
|
||||||
|
// 5. enforces reasonable timeouts (
|
||||||
|
// see https://github.com/ooni/probe/issues/1726).
|
||||||
func NewResolverSystem(logger Logger) Resolver {
|
func NewResolverSystem(logger Logger) Resolver {
|
||||||
return &resolverIDNA{
|
return &resolverIDNA{
|
||||||
Resolver: &resolverLogger{
|
Resolver: &resolverLogger{
|
||||||
|
@ -48,7 +62,6 @@ type resolverSystem struct {
|
||||||
|
|
||||||
var _ Resolver = &resolverSystem{}
|
var _ Resolver = &resolverSystem{}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost.
|
|
||||||
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
// This code forces adding a shorter timeout to the domain name
|
// This code forces adding a shorter timeout to the domain name
|
||||||
// resolutions when using the system resolver. We have seen cases
|
// resolutions when using the system resolver. We have seen cases
|
||||||
|
@ -89,24 +102,18 @@ func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) (
|
||||||
return net.DefaultResolver.LookupHost
|
return net.DefaultResolver.LookupHost
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network implements Resolver.Network.
|
|
||||||
func (r *resolverSystem) Network() string {
|
func (r *resolverSystem) Network() string {
|
||||||
return "system"
|
return "system"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address implements Resolver.Address.
|
|
||||||
func (r *resolverSystem) Address() string {
|
func (r *resolverSystem) Address() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections implements Resolver.CloseIdleConnections.
|
|
||||||
func (r *resolverSystem) CloseIdleConnections() {
|
func (r *resolverSystem) CloseIdleConnections() {
|
||||||
// nothing
|
// nothing to do
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultResolver is the resolver we use by default.
|
|
||||||
var DefaultResolver = &resolverSystem{}
|
|
||||||
|
|
||||||
// resolverLogger is a resolver that emits events
|
// resolverLogger is a resolver that emits events
|
||||||
type resolverLogger struct {
|
type resolverLogger struct {
|
||||||
Resolver
|
Resolver
|
||||||
|
@ -115,7 +122,6 @@ type resolverLogger struct {
|
||||||
|
|
||||||
var _ Resolver = &resolverLogger{}
|
var _ Resolver = &resolverLogger{}
|
||||||
|
|
||||||
// LookupHost returns the IP addresses of a host
|
|
||||||
func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
r.Logger.Debugf("resolve %s...", hostname)
|
r.Logger.Debugf("resolve %s...", hostname)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
@ -136,7 +142,6 @@ type resolverIDNA struct {
|
||||||
Resolver
|
Resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost.
|
|
||||||
func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
host, err := idna.ToASCII(hostname)
|
host, err := idna.ToASCII(hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -151,7 +156,6 @@ type resolverShortCircuitIPAddr struct {
|
||||||
Resolver
|
Resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost.
|
|
||||||
func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
if net.ParseIP(hostname) != nil {
|
if net.ParseIP(hostname) != nil {
|
||||||
return []string{hostname}, nil
|
return []string{hostname}, nil
|
||||||
|
@ -166,24 +170,20 @@ var ErrNoResolver = errors.New("no configured resolver")
|
||||||
// domain names to IP addresses and always returns ErrNoResolver.
|
// domain names to IP addresses and always returns ErrNoResolver.
|
||||||
type nullResolver struct{}
|
type nullResolver struct{}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost.
|
|
||||||
func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) {
|
func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) {
|
||||||
return nil, ErrNoResolver
|
return nil, ErrNoResolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network implements Resolver.Network.
|
|
||||||
func (r *nullResolver) Network() string {
|
func (r *nullResolver) Network() string {
|
||||||
return "null"
|
return "null"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address implements Resolver.Address.
|
|
||||||
func (r *nullResolver) Address() string {
|
func (r *nullResolver) Address() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections implements Resolver.CloseIdleConnections.
|
|
||||||
func (r *nullResolver) CloseIdleConnections() {
|
func (r *nullResolver) CloseIdleConnections() {
|
||||||
// nothing
|
// nothing to do
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolverErrWrapper is a Resolver that knows about wrapping errors.
|
// resolverErrWrapper is a Resolver that knows about wrapping errors.
|
||||||
|
@ -193,7 +193,6 @@ type resolverErrWrapper struct {
|
||||||
|
|
||||||
var _ Resolver = &resolverErrWrapper{}
|
var _ Resolver = &resolverErrWrapper{}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost
|
|
||||||
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -15,6 +15,18 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestNewResolverSystem(t *testing.T) {
|
||||||
|
resolver := NewResolverSystem(log.Log)
|
||||||
|
idna := resolver.(*resolverIDNA)
|
||||||
|
logger := idna.Resolver.(*resolverLogger)
|
||||||
|
if logger.Logger != log.Log {
|
||||||
|
t.Fatal("invalid logger")
|
||||||
|
}
|
||||||
|
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||||
|
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||||
|
_ = errWrapper.Resolver.(*resolverSystem)
|
||||||
|
}
|
||||||
|
|
||||||
func TestResolverSystem(t *testing.T) {
|
func TestResolverSystem(t *testing.T) {
|
||||||
t.Run("Network and Address", func(t *testing.T) {
|
t.Run("Network and Address", func(t *testing.T) {
|
||||||
r := &resolverSystem{}
|
r := &resolverSystem{}
|
||||||
|
@ -26,16 +38,9 @@ func TestResolverSystem(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("works as intended", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
r := &resolverSystem{}
|
r := &resolverSystem{}
|
||||||
defer r.CloseIdleConnections()
|
r.CloseIdleConnections() // to cover it
|
||||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if addrs == nil {
|
|
||||||
t.Fatal("expected non-nil result here")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check default timeout", func(t *testing.T) {
|
t.Run("check default timeout", func(t *testing.T) {
|
||||||
|
@ -45,7 +50,30 @@ func TestResolverSystem(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("check default lookup host func not nil", func(t *testing.T) {
|
||||||
|
r := &resolverSystem{}
|
||||||
|
if r.lookupHost() == nil {
|
||||||
|
t.Fatal("expected non-nil func here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
|
t.Run("with success", func(t *testing.T) {
|
||||||
|
r := &resolverSystem{
|
||||||
|
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
return []string{"8.8.8.8"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "example.antani")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||||
|
t.Fatal("invalid addrs")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("with timeout and success", func(t *testing.T) {
|
t.Run("with timeout and success", func(t *testing.T) {
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
@ -111,9 +139,15 @@ func TestResolverSystem(t *testing.T) {
|
||||||
func TestResolverLogger(t *testing.T) {
|
func TestResolverLogger(t *testing.T) {
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
t.Run("with success", func(t *testing.T) {
|
t.Run("with success", func(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
}
|
||||||
expected := []string{"1.1.1.1"}
|
expected := []string{"1.1.1.1"}
|
||||||
r := resolverLogger{
|
r := resolverLogger{
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
Resolver: &mocks.Resolver{
|
Resolver: &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return expected, nil
|
return expected, nil
|
||||||
|
@ -127,12 +161,21 @@ func TestResolverLogger(t *testing.T) {
|
||||||
if diff := cmp.Diff(expected, addrs); diff != "" {
|
if diff := cmp.Diff(expected, addrs); diff != "" {
|
||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Fatal("unexpected count")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with failure", func(t *testing.T) {
|
t.Run("with failure", func(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
}
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := resolverLogger{
|
r := resolverLogger{
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
Resolver: &mocks.Resolver{
|
Resolver: &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
@ -146,6 +189,9 @@ func TestResolverLogger(t *testing.T) {
|
||||||
if addrs != nil {
|
if addrs != nil {
|
||||||
t.Fatal("expected nil addr here")
|
t.Fatal("expected nil addr here")
|
||||||
}
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Fatal("unexpected count")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -193,18 +239,6 @@ func TestResolverIDNA(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverSystem(t *testing.T) {
|
|
||||||
resolver := NewResolverSystem(log.Log)
|
|
||||||
idna := resolver.(*resolverIDNA)
|
|
||||||
logger := idna.Resolver.(*resolverLogger)
|
|
||||||
if logger.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
|
||||||
}
|
|
||||||
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
|
||||||
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
|
||||||
_ = errWrapper.Resolver.(*resolverSystem)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverShortCircuitIPAddr(t *testing.T) {
|
func TestResolverShortCircuitIPAddr(t *testing.T) {
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
t.Run("with IP addr", func(t *testing.T) {
|
t.Run("with IP addr", func(t *testing.T) {
|
||||||
|
@ -261,7 +295,7 @@ func TestNullResolver(t *testing.T) {
|
||||||
if r.Address() != "" {
|
if r.Address() != "" {
|
||||||
t.Fatal("invalid address")
|
t.Fatal("invalid address")
|
||||||
}
|
}
|
||||||
r.CloseIdleConnections() // should not crash
|
r.CloseIdleConnections() // for coverage
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolverErrWrapper(t *testing.T) {
|
func TestResolverErrWrapper(t *testing.T) {
|
||||||
|
|
|
@ -99,7 +99,7 @@ func ConfigureTLSVersion(config *tls.Config, version string) error {
|
||||||
config.MinVersion = tls.VersionTLS10
|
config.MinVersion = tls.VersionTLS10
|
||||||
config.MaxVersion = tls.VersionTLS10
|
config.MaxVersion = tls.VersionTLS10
|
||||||
case "":
|
case "":
|
||||||
// nothing
|
// nothing to do
|
||||||
default:
|
default:
|
||||||
return ErrInvalidTLSVersion
|
return ErrInvalidTLSVersion
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ type TLSHandshaker interface {
|
||||||
// the given config. This function DOES NOT take ownership of the connection
|
// the given config. This function DOES NOT take ownership of the connection
|
||||||
// and it's your responsibility to close it on failure.
|
// and it's your responsibility to close it on failure.
|
||||||
//
|
//
|
||||||
// The returned connection will always implement the TLSConn interface
|
// QUIRK: The returned connection will always implement the TLSConn interface
|
||||||
// exposed by this package. A future version of this interface will instead
|
// exposed by this package. A future version of this interface will instead
|
||||||
// return directly a TLSConn and remove the ConnectionState param.
|
// return directly a TLSConn and remove the ConnectionState param.
|
||||||
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
||||||
|
@ -128,10 +128,21 @@ type TLSHandshaker interface {
|
||||||
|
|
||||||
// NewTLSHandshakerStdlib creates a new TLS handshaker using the
|
// NewTLSHandshakerStdlib creates a new TLS handshaker using the
|
||||||
// go standard library to create TLS connections.
|
// go standard library to create TLS connections.
|
||||||
|
//
|
||||||
|
// The handshaker guarantees:
|
||||||
|
//
|
||||||
|
// 1. logging
|
||||||
|
//
|
||||||
|
// 2. error wrapping
|
||||||
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
|
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
|
||||||
|
return newTLSHandshaker(&tlsHandshakerConfigurable{}, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
|
||||||
|
func newTLSHandshaker(th TLSHandshaker, logger Logger) TLSHandshaker {
|
||||||
return &tlsHandshakerLogger{
|
return &tlsHandshakerLogger{
|
||||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
TLSHandshaker: th,
|
||||||
},
|
},
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
}
|
}
|
||||||
|
@ -191,11 +202,8 @@ var defaultTLSHandshaker = &tlsHandshakerConfigurable{}
|
||||||
|
|
||||||
// tlsHandshakerLogger is a TLSHandshaker with logging.
|
// tlsHandshakerLogger is a TLSHandshaker with logging.
|
||||||
type tlsHandshakerLogger struct {
|
type tlsHandshakerLogger struct {
|
||||||
// TLSHandshaker is the underlying handshaker.
|
TLSHandshaker
|
||||||
TLSHandshaker TLSHandshaker
|
Logger
|
||||||
|
|
||||||
// Logger is the underlying logger.
|
|
||||||
Logger Logger
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TLSHandshaker = &tlsHandshakerLogger{}
|
var _ TLSHandshaker = &tlsHandshakerLogger{}
|
||||||
|
|
|
@ -118,10 +118,22 @@ func TestConfigureTLSVersion(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||||
|
th := NewTLSHandshakerStdlib(log.Log)
|
||||||
|
logger := th.(*tlsHandshakerLogger)
|
||||||
|
if logger.Logger != log.Log {
|
||||||
|
t.Fatal("invalid logger")
|
||||||
|
}
|
||||||
|
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||||
|
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||||
|
if configurable.NewConn != nil {
|
||||||
|
t.Fatal("expected nil NewConn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerConfigurable(t *testing.T) {
|
func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
t.Run("Handshake", func(t *testing.T) {
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
t.Run("with error", func(t *testing.T) {
|
t.Run("with error", func(t *testing.T) {
|
||||||
|
|
||||||
var times []time.Time
|
var times []time.Time
|
||||||
h := &tlsHandshakerConfigurable{}
|
h := &tlsHandshakerConfigurable{}
|
||||||
tcpConn := &mocks.Conn{
|
tcpConn := &mocks.Conn{
|
||||||
|
@ -230,13 +242,19 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
func TestTLSHandshakerLogger(t *testing.T) {
|
func TestTLSHandshakerLogger(t *testing.T) {
|
||||||
t.Run("Handshake", func(t *testing.T) {
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
t.Run("on success", func(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
}
|
||||||
th := &tlsHandshakerLogger{
|
th := &tlsHandshakerLogger{
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
TLSHandshaker: &mocks.TLSHandshaker{
|
||||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
}
|
}
|
||||||
conn := &mocks.Conn{
|
conn := &mocks.Conn{
|
||||||
MockClose: func() error {
|
MockClose: func() error {
|
||||||
|
@ -255,9 +273,18 @@ func TestTLSHandshakerLogger(t *testing.T) {
|
||||||
if !reflect.ValueOf(connState).IsZero() {
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
t.Fatal("expected zero ConnectionState here")
|
t.Fatal("expected zero ConnectionState here")
|
||||||
}
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("on failure", func(t *testing.T) {
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
lo := &mocks.Logger{
|
||||||
|
MockDebugf: func(format string, v ...interface{}) {
|
||||||
|
count++
|
||||||
|
},
|
||||||
|
}
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
th := &tlsHandshakerLogger{
|
th := &tlsHandshakerLogger{
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
TLSHandshaker: &mocks.TLSHandshaker{
|
||||||
|
@ -265,7 +292,7 @@ func TestTLSHandshakerLogger(t *testing.T) {
|
||||||
return nil, tls.ConnectionState{}, expected
|
return nil, tls.ConnectionState{}, expected
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Logger: log.Log,
|
Logger: lo,
|
||||||
}
|
}
|
||||||
conn := &mocks.Conn{
|
conn := &mocks.Conn{
|
||||||
MockClose: func() error {
|
MockClose: func() error {
|
||||||
|
@ -284,10 +311,29 @@ func TestTLSHandshakerLogger(t *testing.T) {
|
||||||
if !reflect.ValueOf(connState).IsZero() {
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
t.Fatal("expected zero ConnectionState here")
|
t.Fatal("expected zero ConnectionState here")
|
||||||
}
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewTLSDialer(t *testing.T) {
|
||||||
|
d := &mocks.Dialer{}
|
||||||
|
th := &mocks.TLSHandshaker{}
|
||||||
|
dialer := NewTLSDialer(d, th)
|
||||||
|
tlsd := dialer.(*tlsDialer)
|
||||||
|
if tlsd.Config == nil {
|
||||||
|
t.Fatal("unexpected config")
|
||||||
|
}
|
||||||
|
if tlsd.Dialer != d {
|
||||||
|
t.Fatal("unexpected dialer")
|
||||||
|
}
|
||||||
|
if tlsd.TLSHandshaker != th {
|
||||||
|
t.Fatal("invalid handshaker")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTLSDialer(t *testing.T) {
|
func TestTLSDialer(t *testing.T) {
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
var called bool
|
var called bool
|
||||||
|
@ -439,35 +485,6 @@ func TestTLSDialer(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
|
||||||
th := NewTLSHandshakerStdlib(log.Log)
|
|
||||||
logger := th.(*tlsHandshakerLogger)
|
|
||||||
if logger.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
|
||||||
}
|
|
||||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
|
||||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
|
||||||
if configurable.NewConn != nil {
|
|
||||||
t.Fatal("expected nil NewConn")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSDialer(t *testing.T) {
|
|
||||||
d := &mocks.Dialer{}
|
|
||||||
th := &mocks.TLSHandshaker{}
|
|
||||||
dialer := NewTLSDialer(d, th)
|
|
||||||
tlsd := dialer.(*tlsDialer)
|
|
||||||
if tlsd.Config == nil {
|
|
||||||
t.Fatal("unexpected config")
|
|
||||||
}
|
|
||||||
if tlsd.Dialer != d {
|
|
||||||
t.Fatal("unexpected dialer")
|
|
||||||
}
|
|
||||||
if tlsd.TLSHandshaker != th {
|
|
||||||
t.Fatal("invalid handshaker")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewSingleUseTLSDialer(t *testing.T) {
|
func TestNewSingleUseTLSDialer(t *testing.T) {
|
||||||
conn := &mocks.TLSConn{}
|
conn := &mocks.TLSConn{}
|
||||||
d := NewSingleUseTLSDialer(conn)
|
d := NewSingleUseTLSDialer(conn)
|
||||||
|
|
|
@ -11,15 +11,18 @@ import (
|
||||||
|
|
||||||
// NewTLSHandshakerUTLS creates a new TLS handshaker using the
|
// NewTLSHandshakerUTLS creates a new TLS handshaker using the
|
||||||
// gitlab.com/yawning/utls library to create TLS conns.
|
// gitlab.com/yawning/utls library to create TLS conns.
|
||||||
|
//
|
||||||
|
// The handshaker guarantees:
|
||||||
|
//
|
||||||
|
// 1. logging
|
||||||
|
//
|
||||||
|
// 2. error wrapping
|
||||||
|
//
|
||||||
|
// Passing a nil `id` will make this function panic.
|
||||||
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
|
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
|
||||||
return &tlsHandshakerLogger{
|
return newTLSHandshaker(&tlsHandshakerConfigurable{
|
||||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
NewConn: newConnUTLS(id),
|
||||||
TLSHandshaker: &tlsHandshakerConfigurable{
|
}, logger)
|
||||||
NewConn: newConnUTLS(id),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Logger: logger,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection
|
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection
|
||||||
|
|
Loading…
Reference in New Issue
Block a user