netxlite: code quality, improve tests, docs (#494)
See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
3cd88debdc
commit
50b58672c6
|
@ -40,7 +40,11 @@ type Dialer interface {
|
|||
//
|
||||
// 4. wraps errors;
|
||||
//
|
||||
// 5. has a configured connect timeout.
|
||||
// 5. has a configured connect timeout;
|
||||
//
|
||||
// 6. if a dialer wraps a resolver, the dialer will forward
|
||||
// the CloseIdleConnection call to its resolver (which is
|
||||
// instrumental to manage a DoH resolver connections properly).
|
||||
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
|
||||
return &dialerLogger{
|
||||
Dialer: &dialerResolver{
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
//
|
||||
// We also want to mock any underlying dependency for testing.
|
||||
//
|
||||
// We also want to map errors to OONI failures, which are described by
|
||||
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
|
||||
//
|
||||
// Operations
|
||||
//
|
||||
// This package implements the following operations:
|
||||
|
|
|
@ -30,7 +30,6 @@ type httpTransportLogger struct {
|
|||
|
||||
var _ HTTPTransport = &httpTransportLogger{}
|
||||
|
||||
// RoundTrip implements HTTPTransport.RoundTrip.
|
||||
func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
|
@ -64,13 +63,12 @@ func (txp *httpTransportLogger) logTrip(req *http.Request) (*http.Response, erro
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
// CloseIdleConnections implement HTTPTransport.CloseIdleConnections.
|
||||
func (txp *httpTransportLogger) CloseIdleConnections() {
|
||||
txp.HTTPTransport.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// httpTransportConnectionsCloser is an HTTPTransport that
|
||||
// correctly forwards CloseIdleConnections.
|
||||
// correctly forwards CloseIdleConnections calls.
|
||||
type httpTransportConnectionsCloser struct {
|
||||
HTTPTransport
|
||||
Dialer
|
||||
|
@ -98,6 +96,16 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() {
|
|||
// The returned transport will disable transparent decompression
|
||||
// of compressed response bodies (and will not automatically
|
||||
// ask for such compression, though you can always do that manually).
|
||||
//
|
||||
// The returned transport will configure TCP and TLS connections
|
||||
// created using its dialer and TLS dialer to always have a
|
||||
// read watchdog timeout to address https://github.com/ooni/probe/issues/1609.
|
||||
//
|
||||
// The returned transport will always enforce 1 connection per host
|
||||
// and we cannot get rid of this QUIRK requirement because it is
|
||||
// necessary to perform sane measurements with tracing. We will be
|
||||
// able to possibly relax this requirement after we change the
|
||||
// way in which we perform measurements.
|
||||
func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTransport {
|
||||
// Using oohttp to support any TLS library.
|
||||
txp := oohttp.DefaultTransport.(*oohttp.Transport).Clone()
|
||||
|
@ -115,8 +123,6 @@ func NewHTTPTransport(logger Logger, dialer Dialer, tlsDialer TLSDialer) HTTPTra
|
|||
|
||||
// Better for Cloudflare DNS and also better because we have less
|
||||
// noisy events and we can better understand what happened.
|
||||
//
|
||||
// UNDOCUMENTED: I am wondering whether we can relax this constraint.
|
||||
txp.MaxConnsPerHost = 1
|
||||
|
||||
// The following (1) reduces the number of headers that Go will
|
||||
|
@ -175,7 +181,7 @@ func (d *httpTLSDialerWithReadTimeout) DialTLSContext(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tconn, okay := conn.(TLSConn)
|
||||
tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful
|
||||
if !okay {
|
||||
conn.Close() // we own the conn here
|
||||
return nil, ErrNotTLSConn
|
||||
|
|
|
@ -33,7 +33,7 @@ func TestHTTP3Dialer(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestHTTP3TransportClosesIdleConnections(t *testing.T) {
|
||||
func TestHTTP3Transport(t *testing.T) {
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var (
|
||||
calledHTTP3 bool
|
||||
|
|
|
@ -21,8 +21,17 @@ import (
|
|||
func TestHTTPTransportLogger(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("with failure", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebug: func(message string) {
|
||||
count++
|
||||
},
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
txp := &httpTransportLogger{
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
HTTPTransport: &mocks.HTTPTransport{
|
||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, io.EOF
|
||||
|
@ -37,6 +46,9 @@ func TestHTTPTransportLogger(t *testing.T) {
|
|||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
if count < 1 {
|
||||
t.Fatal("no logs?!")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("we add the host header", func(t *testing.T) {
|
||||
|
@ -73,8 +85,17 @@ func TestHTTPTransportLogger(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("with success", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebug: func(message string) {
|
||||
count++
|
||||
},
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
txp := &httpTransportLogger{
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
HTTPTransport: &mocks.HTTPTransport{
|
||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
|
@ -94,6 +115,9 @@ func TestHTTPTransportLogger(t *testing.T) {
|
|||
}
|
||||
iox.ReadAllContext(context.Background(), resp.Body)
|
||||
resp.Body.Close()
|
||||
if count < 1 {
|
||||
t.Fatal("no logs?!")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -12,6 +12,26 @@ import (
|
|||
utls "gitlab.com/yawning/utls.git"
|
||||
)
|
||||
|
||||
func TestResolver(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
|
||||
t.Run("works as intended", func(t *testing.T) {
|
||||
// TODO(bassosimone): this is actually an integration
|
||||
// test but how to test this case?
|
||||
r := netxlite.NewResolverSystem(log.Log)
|
||||
defer r.CloseIdleConnections()
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil result here")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPTransport(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
|
|
|
@ -13,6 +13,7 @@ var (
|
|||
DefaultDialer = &dialerSystem{}
|
||||
DefaultTLSHandshaker = defaultTLSHandshaker
|
||||
NewConnUTLS = newConnUTLS
|
||||
DefaultResolver = &resolverSystem{}
|
||||
)
|
||||
|
||||
// These types export internal names to legacy ooni/probe-cli code.
|
||||
|
|
|
@ -50,6 +50,24 @@ type QUICDialer interface {
|
|||
// NewQUICDialerWithResolver returns a QUICDialer using the given
|
||||
// QUICListener to create listening connections and the given Resolver
|
||||
// to resolve domain names (if needed).
|
||||
//
|
||||
// Properties of the dialer:
|
||||
//
|
||||
// 1. logs events using the given logger;
|
||||
//
|
||||
// 2. resolves domain names using the givern resolver;
|
||||
//
|
||||
// 3. when using a resolver, _may_ attempt multiple dials
|
||||
// in parallel (happy eyeballs) and _may_ return an aggregate
|
||||
// error to the caller;
|
||||
//
|
||||
// 4. wraps errors;
|
||||
//
|
||||
// 5. has a configured connect timeout;
|
||||
//
|
||||
// 6. if a dialer wraps a resolver, the dialer will forward
|
||||
// the CloseIdleConnection call to its resolver (which is
|
||||
// instrumental to manage a DoH resolver connections properly).
|
||||
func NewQUICDialerWithResolver(listener QUICListener,
|
||||
logger Logger, resolver Resolver) QUICDialer {
|
||||
return &quicDialerLogger{
|
||||
|
@ -210,12 +228,9 @@ func (d *quicDialerResolver) DialContext(
|
|||
return nil, err
|
||||
}
|
||||
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, onlyhost)
|
||||
// TODO(bassosimone): here we should be using multierror rather
|
||||
// than just calling ReduceErrors. We are not ready to do that
|
||||
// yet, though. To do that, we need first to modify nettests so
|
||||
// that we actually avoid dialing when measuring.
|
||||
//
|
||||
// See also the quirks.go file. This is clearly a QUIRK.
|
||||
// See TODO(https://github.com/ooni/probe/issues/1779) however
|
||||
// this is less of a problem for QUIC because so far we have been
|
||||
// using it to perform research only (i.e., urlgetter).
|
||||
addrs = quirkSortIPAddrs(addrs)
|
||||
var errorslist []error
|
||||
for _, addr := range addrs {
|
||||
|
|
|
@ -17,6 +17,34 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
||||
)
|
||||
|
||||
func TestNewQUICListener(t *testing.T) {
|
||||
ql := NewQUICListener()
|
||||
qew := ql.(*quicListenerErrWrapper)
|
||||
_ = qew.QUICListener.(*quicListenerStdlib)
|
||||
}
|
||||
|
||||
func TestNewQUICDialer(t *testing.T) {
|
||||
ql := NewQUICListener()
|
||||
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
|
||||
logger := dlr.(*quicDialerLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
resolver := logger.Dialer.(*quicDialerResolver)
|
||||
if _, okay := resolver.Resolver.(*nullResolver); !okay {
|
||||
t.Fatal("invalid resolver type")
|
||||
}
|
||||
logger = resolver.Dialer.(*quicDialerLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
|
||||
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
|
||||
if base.QUICListener != ql {
|
||||
t.Fatal("invalid quic listener")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQUICDialerQUICGo(t *testing.T) {
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("cannot split host port", func(t *testing.T) {
|
||||
|
@ -223,7 +251,6 @@ func TestQUICDialerQUICGo(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQUICDialerResolver(t *testing.T) {
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var (
|
||||
forDialer bool
|
||||
|
@ -302,7 +329,7 @@ func TestQUICDialerResolver(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("with invalid port", func(t *testing.T) {
|
||||
t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) {
|
||||
// This test allows us to check for the case where every attempt
|
||||
// to establish a connection leads to a failure
|
||||
tlsConf := &tls.Config{}
|
||||
|
@ -376,7 +403,6 @@ func TestQUICDialerResolver(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQUICLoggerDialer(t *testing.T) {
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var forDialer bool
|
||||
d := &quicDialerLogger{
|
||||
|
@ -394,6 +420,12 @@ func TestQUICLoggerDialer(t *testing.T) {
|
|||
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
var called int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
called++
|
||||
},
|
||||
}
|
||||
d := &quicDialerLogger{
|
||||
Dialer: &mocks.QUICDialer{
|
||||
MockDialContext: func(ctx context.Context, network string,
|
||||
|
@ -407,7 +439,7 @@ func TestQUICLoggerDialer(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
}
|
||||
ctx := context.Background()
|
||||
tlsConfig := &tls.Config{}
|
||||
|
@ -419,9 +451,18 @@ func TestQUICLoggerDialer(t *testing.T) {
|
|||
if err := sess.CloseWithError(0, ""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if called != 2 {
|
||||
t.Fatal("invalid number of calls")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
var called int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
called++
|
||||
},
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
d := &quicDialerLogger{
|
||||
Dialer: &mocks.QUICDialer{
|
||||
|
@ -431,7 +472,7 @@ func TestQUICLoggerDialer(t *testing.T) {
|
|||
return nil, expected
|
||||
},
|
||||
},
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
}
|
||||
ctx := context.Background()
|
||||
tlsConfig := &tls.Config{}
|
||||
|
@ -443,32 +484,13 @@ func TestQUICLoggerDialer(t *testing.T) {
|
|||
if sess != nil {
|
||||
t.Fatal("expected nil session")
|
||||
}
|
||||
if called != 2 {
|
||||
t.Fatal("invalid number of calls")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewQUICDialer(t *testing.T) {
|
||||
ql := NewQUICListener()
|
||||
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
|
||||
logger := dlr.(*quicDialerLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
resolver := logger.Dialer.(*quicDialerResolver)
|
||||
if _, okay := resolver.Resolver.(*nullResolver); !okay {
|
||||
t.Fatal("invalid resolver type")
|
||||
}
|
||||
logger = resolver.Dialer.(*quicDialerLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
|
||||
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
|
||||
if base.QUICListener != ql {
|
||||
t.Fatal("invalid quic listener")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSingleUseQUICDialer(t *testing.T) {
|
||||
sess := &mocks.QUICEarlySession{}
|
||||
qd := NewSingleUseQUICDialer(sess)
|
||||
|
|
|
@ -54,26 +54,34 @@ func TestQuirkReduceErrors(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestQuirkSortIPAddrs(t *testing.T) {
|
||||
addrs := []string{
|
||||
"::1",
|
||||
"192.168.1.2",
|
||||
"2a00:1450:4002:404::2004",
|
||||
"142.250.184.36",
|
||||
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
||||
"198.145.29.83",
|
||||
"95.216.163.36",
|
||||
}
|
||||
expected := []string{
|
||||
"192.168.1.2",
|
||||
"142.250.184.36",
|
||||
"198.145.29.83",
|
||||
"95.216.163.36",
|
||||
"::1",
|
||||
"2a00:1450:4002:404::2004",
|
||||
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
||||
}
|
||||
out := quirkSortIPAddrs(addrs)
|
||||
if diff := cmp.Diff(expected, out); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
t.Run("with some addrs", func(t *testing.T) {
|
||||
addrs := []string{
|
||||
"::1",
|
||||
"192.168.1.2",
|
||||
"2a00:1450:4002:404::2004",
|
||||
"142.250.184.36",
|
||||
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
||||
"198.145.29.83",
|
||||
"95.216.163.36",
|
||||
}
|
||||
expected := []string{
|
||||
"192.168.1.2",
|
||||
"142.250.184.36",
|
||||
"198.145.29.83",
|
||||
"95.216.163.36",
|
||||
"::1",
|
||||
"2a00:1450:4002:404::2004",
|
||||
"2604:8800:5000:82:466:38ff:fecb:d46e",
|
||||
}
|
||||
out := quirkSortIPAddrs(addrs)
|
||||
if diff := cmp.Diff(expected, out); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with an empty list", func(t *testing.T) {
|
||||
if quirkSortIPAddrs(nil) != nil {
|
||||
t.Fatal("expected nil output")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -27,6 +27,20 @@ type Resolver interface {
|
|||
|
||||
// NewResolverSystem creates a new resolver using system
|
||||
// facilities for resolving domain names (e.g., getaddrinfo).
|
||||
//
|
||||
// The resolver will provide the following guarantees:
|
||||
//
|
||||
// 1. handles IDNA;
|
||||
//
|
||||
// 2. performs logging;
|
||||
//
|
||||
// 3. short-circuits IP addresses like getaddrinfo does (i.e.,
|
||||
// resolving "1.1.1.1" yields []string{"1.1.1.1"};
|
||||
//
|
||||
// 4. wraps errors;
|
||||
//
|
||||
// 5. enforces reasonable timeouts (
|
||||
// see https://github.com/ooni/probe/issues/1726).
|
||||
func NewResolverSystem(logger Logger) Resolver {
|
||||
return &resolverIDNA{
|
||||
Resolver: &resolverLogger{
|
||||
|
@ -48,7 +62,6 @@ type resolverSystem struct {
|
|||
|
||||
var _ Resolver = &resolverSystem{}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *resolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
// This code forces adding a shorter timeout to the domain name
|
||||
// resolutions when using the system resolver. We have seen cases
|
||||
|
@ -89,24 +102,18 @@ func (r *resolverSystem) lookupHost() func(ctx context.Context, domain string) (
|
|||
return net.DefaultResolver.LookupHost
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r *resolverSystem) Network() string {
|
||||
return "system"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address.
|
||||
func (r *resolverSystem) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Resolver.CloseIdleConnections.
|
||||
func (r *resolverSystem) CloseIdleConnections() {
|
||||
// nothing
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
// DefaultResolver is the resolver we use by default.
|
||||
var DefaultResolver = &resolverSystem{}
|
||||
|
||||
// resolverLogger is a resolver that emits events
|
||||
type resolverLogger struct {
|
||||
Resolver
|
||||
|
@ -115,7 +122,6 @@ type resolverLogger struct {
|
|||
|
||||
var _ Resolver = &resolverLogger{}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
r.Logger.Debugf("resolve %s...", hostname)
|
||||
start := time.Now()
|
||||
|
@ -136,7 +142,6 @@ type resolverIDNA struct {
|
|||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
host, err := idna.ToASCII(hostname)
|
||||
if err != nil {
|
||||
|
@ -151,7 +156,6 @@ type resolverShortCircuitIPAddr struct {
|
|||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
|
@ -166,24 +170,20 @@ var ErrNoResolver = errors.New("no configured resolver")
|
|||
// domain names to IP addresses and always returns ErrNoResolver.
|
||||
type nullResolver struct{}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) {
|
||||
return nil, ErrNoResolver
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r *nullResolver) Network() string {
|
||||
return "null"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address.
|
||||
func (r *nullResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements Resolver.CloseIdleConnections.
|
||||
func (r *nullResolver) CloseIdleConnections() {
|
||||
// nothing
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
// resolverErrWrapper is a Resolver that knows about wrapping errors.
|
||||
|
@ -193,7 +193,6 @@ type resolverErrWrapper struct {
|
|||
|
||||
var _ Resolver = &resolverErrWrapper{}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
|
|
|
@ -15,6 +15,18 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestNewResolverSystem(t *testing.T) {
|
||||
resolver := NewResolverSystem(log.Log)
|
||||
idna := resolver.(*resolverIDNA)
|
||||
logger := idna.Resolver.(*resolverLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||
_ = errWrapper.Resolver.(*resolverSystem)
|
||||
}
|
||||
|
||||
func TestResolverSystem(t *testing.T) {
|
||||
t.Run("Network and Address", func(t *testing.T) {
|
||||
r := &resolverSystem{}
|
||||
|
@ -26,16 +38,9 @@ func TestResolverSystem(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("works as intended", func(t *testing.T) {
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
r := &resolverSystem{}
|
||||
defer r.CloseIdleConnections()
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil result here")
|
||||
}
|
||||
r.CloseIdleConnections() // to cover it
|
||||
})
|
||||
|
||||
t.Run("check default timeout", func(t *testing.T) {
|
||||
|
@ -45,7 +50,30 @@ func TestResolverSystem(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("check default lookup host func not nil", func(t *testing.T) {
|
||||
r := &resolverSystem{}
|
||||
if r.lookupHost() == nil {
|
||||
t.Fatal("expected non-nil func here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("with success", func(t *testing.T) {
|
||||
r := &resolverSystem{
|
||||
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
addrs, err := r.LookupHost(ctx, "example.antani")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("invalid addrs")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with timeout and success", func(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
@ -111,9 +139,15 @@ func TestResolverSystem(t *testing.T) {
|
|||
func TestResolverLogger(t *testing.T) {
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("with success", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
expected := []string{"1.1.1.1"}
|
||||
r := resolverLogger{
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return expected, nil
|
||||
|
@ -127,12 +161,21 @@ func TestResolverLogger(t *testing.T) {
|
|||
if diff := cmp.Diff(expected, addrs); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with failure", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
r := resolverLogger{
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return nil, expected
|
||||
|
@ -146,6 +189,9 @@ func TestResolverLogger(t *testing.T) {
|
|||
if addrs != nil {
|
||||
t.Fatal("expected nil addr here")
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -193,18 +239,6 @@ func TestResolverIDNA(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNewResolverSystem(t *testing.T) {
|
||||
resolver := NewResolverSystem(log.Log)
|
||||
idna := resolver.(*resolverIDNA)
|
||||
logger := idna.Resolver.(*resolverLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||
_ = errWrapper.Resolver.(*resolverSystem)
|
||||
}
|
||||
|
||||
func TestResolverShortCircuitIPAddr(t *testing.T) {
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("with IP addr", func(t *testing.T) {
|
||||
|
@ -261,7 +295,7 @@ func TestNullResolver(t *testing.T) {
|
|||
if r.Address() != "" {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
r.CloseIdleConnections() // should not crash
|
||||
r.CloseIdleConnections() // for coverage
|
||||
}
|
||||
|
||||
func TestResolverErrWrapper(t *testing.T) {
|
||||
|
|
|
@ -99,7 +99,7 @@ func ConfigureTLSVersion(config *tls.Config, version string) error {
|
|||
config.MinVersion = tls.VersionTLS10
|
||||
config.MaxVersion = tls.VersionTLS10
|
||||
case "":
|
||||
// nothing
|
||||
// nothing to do
|
||||
default:
|
||||
return ErrInvalidTLSVersion
|
||||
}
|
||||
|
@ -119,7 +119,7 @@ type TLSHandshaker interface {
|
|||
// the given config. This function DOES NOT take ownership of the connection
|
||||
// and it's your responsibility to close it on failure.
|
||||
//
|
||||
// The returned connection will always implement the TLSConn interface
|
||||
// QUIRK: The returned connection will always implement the TLSConn interface
|
||||
// exposed by this package. A future version of this interface will instead
|
||||
// return directly a TLSConn and remove the ConnectionState param.
|
||||
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
||||
|
@ -128,10 +128,21 @@ type TLSHandshaker interface {
|
|||
|
||||
// NewTLSHandshakerStdlib creates a new TLS handshaker using the
|
||||
// go standard library to create TLS connections.
|
||||
//
|
||||
// The handshaker guarantees:
|
||||
//
|
||||
// 1. logging
|
||||
//
|
||||
// 2. error wrapping
|
||||
func NewTLSHandshakerStdlib(logger Logger) TLSHandshaker {
|
||||
return newTLSHandshaker(&tlsHandshakerConfigurable{}, logger)
|
||||
}
|
||||
|
||||
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
|
||||
func newTLSHandshaker(th TLSHandshaker, logger Logger) TLSHandshaker {
|
||||
return &tlsHandshakerLogger{
|
||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
||||
TLSHandshaker: th,
|
||||
},
|
||||
Logger: logger,
|
||||
}
|
||||
|
@ -191,11 +202,8 @@ var defaultTLSHandshaker = &tlsHandshakerConfigurable{}
|
|||
|
||||
// tlsHandshakerLogger is a TLSHandshaker with logging.
|
||||
type tlsHandshakerLogger struct {
|
||||
// TLSHandshaker is the underlying handshaker.
|
||||
TLSHandshaker TLSHandshaker
|
||||
|
||||
// Logger is the underlying logger.
|
||||
Logger Logger
|
||||
TLSHandshaker
|
||||
Logger
|
||||
}
|
||||
|
||||
var _ TLSHandshaker = &tlsHandshakerLogger{}
|
||||
|
|
|
@ -118,10 +118,22 @@ func TestConfigureTLSVersion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||
th := NewTLSHandshakerStdlib(log.Log)
|
||||
logger := th.(*tlsHandshakerLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if configurable.NewConn != nil {
|
||||
t.Fatal("expected nil NewConn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("with error", func(t *testing.T) {
|
||||
|
||||
var times []time.Time
|
||||
h := &tlsHandshakerConfigurable{}
|
||||
tcpConn := &mocks.Conn{
|
||||
|
@ -230,13 +242,19 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
|||
func TestTLSHandshakerLogger(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
th := &tlsHandshakerLogger{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
||||
},
|
||||
},
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
}
|
||||
conn := &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
|
@ -255,9 +273,18 @@ func TestTLSHandshakerLogger(t *testing.T) {
|
|||
if !reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected zero ConnectionState here")
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
th := &tlsHandshakerLogger{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
|
@ -265,7 +292,7 @@ func TestTLSHandshakerLogger(t *testing.T) {
|
|||
return nil, tls.ConnectionState{}, expected
|
||||
},
|
||||
},
|
||||
Logger: log.Log,
|
||||
Logger: lo,
|
||||
}
|
||||
conn := &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
|
@ -284,10 +311,29 @@ func TestTLSHandshakerLogger(t *testing.T) {
|
|||
if !reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected zero ConnectionState here")
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewTLSDialer(t *testing.T) {
|
||||
d := &mocks.Dialer{}
|
||||
th := &mocks.TLSHandshaker{}
|
||||
dialer := NewTLSDialer(d, th)
|
||||
tlsd := dialer.(*tlsDialer)
|
||||
if tlsd.Config == nil {
|
||||
t.Fatal("unexpected config")
|
||||
}
|
||||
if tlsd.Dialer != d {
|
||||
t.Fatal("unexpected dialer")
|
||||
}
|
||||
if tlsd.TLSHandshaker != th {
|
||||
t.Fatal("invalid handshaker")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialer(t *testing.T) {
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
|
@ -439,35 +485,6 @@ func TestTLSDialer(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||
th := NewTLSHandshakerStdlib(log.Log)
|
||||
logger := th.(*tlsHandshakerLogger)
|
||||
if logger.Logger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if configurable.NewConn != nil {
|
||||
t.Fatal("expected nil NewConn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialer(t *testing.T) {
|
||||
d := &mocks.Dialer{}
|
||||
th := &mocks.TLSHandshaker{}
|
||||
dialer := NewTLSDialer(d, th)
|
||||
tlsd := dialer.(*tlsDialer)
|
||||
if tlsd.Config == nil {
|
||||
t.Fatal("unexpected config")
|
||||
}
|
||||
if tlsd.Dialer != d {
|
||||
t.Fatal("unexpected dialer")
|
||||
}
|
||||
if tlsd.TLSHandshaker != th {
|
||||
t.Fatal("invalid handshaker")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSingleUseTLSDialer(t *testing.T) {
|
||||
conn := &mocks.TLSConn{}
|
||||
d := NewSingleUseTLSDialer(conn)
|
||||
|
|
|
@ -11,15 +11,18 @@ import (
|
|||
|
||||
// NewTLSHandshakerUTLS creates a new TLS handshaker using the
|
||||
// gitlab.com/yawning/utls library to create TLS conns.
|
||||
//
|
||||
// The handshaker guarantees:
|
||||
//
|
||||
// 1. logging
|
||||
//
|
||||
// 2. error wrapping
|
||||
//
|
||||
// Passing a nil `id` will make this function panic.
|
||||
func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
|
||||
return &tlsHandshakerLogger{
|
||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &tlsHandshakerConfigurable{
|
||||
NewConn: newConnUTLS(id),
|
||||
},
|
||||
},
|
||||
Logger: logger,
|
||||
}
|
||||
return newTLSHandshaker(&tlsHandshakerConfigurable{
|
||||
NewConn: newConnUTLS(id),
|
||||
}, logger)
|
||||
}
|
||||
|
||||
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection
|
||||
|
|
Loading…
Reference in New Issue
Block a user