From 69fd0c5119e3b0e0a9eaadece0e7d1a8b79f1a0f Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Tue, 31 May 2022 20:02:11 +0200 Subject: [PATCH] refactor(netxlite): allow easy dialer chain customization (#770) This diff modifies the construction of a dialer to allow one to insert custom dialer wrappers into the dialers chain. The point of the chain in which we allow custom wrappers is the optimal one for connect, read, and write measurements. This new design is better than the previous netx design since we don't need to construct the whole chain manually now. The work in this diff is part of the effort to make engine/netx just a tiny wrapper around netxlite. See https://github.com/ooni/probe/issues/2121. --- internal/engine/netx/dialer/dialer.go | 30 +++-- internal/engine/netx/dialer/dialer_test.go | 28 +--- internal/netxlite/dialer.go | 147 +++++++++++++++------ internal/netxlite/dialer_test.go | 49 ++++--- internal/netxlite/legacy.go | 3 +- internal/netxlite/quic.go | 2 +- internal/netxlite/quic_test.go | 2 +- internal/netxlite/resolver.go | 16 +-- internal/netxlite/resolver_test.go | 6 +- internal/netxlite/tls_test.go | 2 +- 10 files changed, 174 insertions(+), 111 deletions(-) diff --git a/internal/engine/netx/dialer/dialer.go b/internal/engine/netx/dialer/dialer.go index 17e093c..f134c56 100644 --- a/internal/engine/netx/dialer/dialer.go +++ b/internal/engine/netx/dialer/dialer.go @@ -50,23 +50,25 @@ type Config struct { // New creates a new Dialer from the specified config and resolver. func New(config *Config, resolver model.Resolver) model.Dialer { - var d model.Dialer = &netxlite.ErrorWrapperDialer{Dialer: netxlite.DefaultDialer} + var logger model.DebugLogger = model.DiscardLogger if config.Logger != nil { - d = &netxlite.DialerLogger{ - Dialer: d, - DebugLogger: config.Logger, - } + logger = config.Logger } - if config.DialSaver != nil { - d = &saverDialer{Dialer: d, Saver: config.DialSaver} - } - if config.ReadWriteSaver != nil { - d = &saverConnDialer{Dialer: d, Saver: config.ReadWriteSaver} - } - d = &netxlite.DialerResolver{ - Resolver: resolver, - Dialer: d, + modifiers := []netxlite.DialerWrapper{ + func(dialer model.Dialer) model.Dialer { + if config.DialSaver != nil { + dialer = &saverDialer{Dialer: dialer, Saver: config.DialSaver} + } + return dialer + }, + func(dialer model.Dialer) model.Dialer { + if config.ReadWriteSaver != nil { + dialer = &saverConnDialer{Dialer: dialer, Saver: config.ReadWriteSaver} + } + return dialer + }, } + d := netxlite.NewDialerWithResolver(logger, resolver, modifiers...) d = &netxlite.MaybeProxyDialer{ProxyURL: config.ProxyURL, Dialer: d} if config.ContextByteCounting { d = &bytecounter.ContextAwareDialer{Dialer: d} diff --git a/internal/engine/netx/dialer/dialer_test.go b/internal/engine/netx/dialer/dialer_test.go index 41bd5b5..2ecf349 100644 --- a/internal/engine/netx/dialer/dialer_test.go +++ b/internal/engine/netx/dialer/dialer_test.go @@ -24,34 +24,12 @@ func TestNewCreatesTheExpectedChain(t *testing.T) { if !ok { t.Fatal("not a byteCounterDialer") } - pd, ok := bcd.Dialer.(*netxlite.MaybeProxyDialer) + _, ok = bcd.Dialer.(*netxlite.MaybeProxyDialer) if !ok { t.Fatal("not a proxyDialer") } - dnsd, ok := pd.Dialer.(*netxlite.DialerResolver) - if !ok { - t.Fatal("not a dnsDialer") - } - scd, ok := dnsd.Dialer.(*saverConnDialer) - if !ok { - t.Fatal("not a saverConnDialer") - } - sd, ok := scd.Dialer.(*saverDialer) - if !ok { - t.Fatal("not a saverDialer") - } - ld, ok := sd.Dialer.(*netxlite.DialerLogger) - if !ok { - t.Fatal("not a loggingDialer") - } - ewd, ok := ld.Dialer.(*netxlite.ErrorWrapperDialer) - if !ok { - t.Fatal("not an errorWrappingDialer") - } - _, ok = ewd.Dialer.(*netxlite.DialerSystem) - if !ok { - t.Fatal("not a DialerSystem") - } + // We can safely stop here: the rest is tested by + // the internal/netxlite package } func TestDialerNewSuccess(t *testing.T) { diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index 43cade5..a0e3c5e 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -14,47 +14,112 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// NewDialerWithResolver calls WrapDialer for the stdlib dialer. -func NewDialerWithResolver(logger model.DebugLogger, resolver model.Resolver) model.Dialer { - return WrapDialer(logger, resolver, &dialerSystem{}) +// DialerWrapper is a function that allows you to customize the kind of Dialer returned +// by WrapDialer, NewDialerWithResolver, and NewDialerWithoutResolver. +type DialerWrapper func(dialer model.Dialer) model.Dialer + +// NewDialerWithResolver is equivalent to calling WrapDialer with +// the dialer argument being equal to &DialerSystem{}. +func NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...DialerWrapper) model.Dialer { + return WrapDialer(dl, r, &DialerSystem{}, w...) } -// WrapDialer creates a new Dialer that wraps the given -// Dialer. The returned Dialer has the following properties: +// WrapDialer wraps an existing Dialer to add extra functionality +// such as separting DNS lookup and connecting, error wrapping, logging, etc. // -// 1. logs events using the given logger; +// When possible use NewDialerWithResolver or NewDialerWithoutResolver +// instead of using this rather low-level function. // -// 2. resolves domain names using the givern resolver; +// Arguments // -// 3. when the resolver is not a "null" resolver, -// each available enpoint is tried -// sequentially. On error, the code will return what it believes -// to be the most representative error in the pack. Most often, -// the first error that occurred. Choosing the -// error to return using this logic is a QUIRK that we owe -// to the original implementation of netx. We cannot change -// this behavior until we refactor legacy code using it. +// 1. logger is used to emit debug messages (MUST NOT be nil); // -// Removing this quirk from the codebase is documented as -// TODO(https://github.com/ooni/probe/issues/1779). +// 2. resolver is the resolver to use when dialing for endpoint +// addresses containing domain names (MUST NOT be nil); // -// 4. wraps errors; +// 3. baseDialer is the dialer to wrap (MUST NOT be nil); // -// 5. has a configured connect timeout; +// 4. wrappers is a list of zero or more functions allowing you to +// modify the behavior of the returned dialer (see below). // -// 6. if a dialer wraps a resolver, the dialer will forward -// the CloseIdleConnection call to its resolver (which is -// instrumental to manage a DoH resolver connections properly). +// Return value // -// In general, do not use WrapDialer directly but try to use -// more high-level factories, e.g., NewDialerWithResolver. -func WrapDialer(logger model.DebugLogger, resolver model.Resolver, dialer model.Dialer) model.Dialer { +// The returned dialer is an opaque type consisting of the composition of +// several simple dialers. The following pseudo code illustrates the general +// behavior of the returned composed dialer: +// +// addrs, err := dnslookup() +// if err != nil { +// return nil, err +// } +// errors := []error{} +// for _, a := range addrs { +// conn, err := tcpconnect(a) +// if err != nil { +// errors = append(errors, err) +// continue +// } +// return conn, nil +// } +// return nil, errors[0] +// +// +// The following table describes the structure of the returned dialer: +// +// +-------+-----------------+------------------------------------------+ +// | Index | Name | Description | +// +-------+-----------------+------------------------------------------+ +// | 0 | base | the baseDialer argument | +// +-------+-----------------+------------------------------------------+ +// | 1 | errWrapper | wraps Go errors to be consistent with | +// | | | OONI df-007-errors spec | +// +-------+-----------------+------------------------------------------+ +// | 2 | ??? | if there are wrappers, result of calling | +// | | | the first one on the errWrapper dialer | +// +-------+-----------------+------------------------------------------+ +// | ... | ... | ... | +// +-------+-----------------+------------------------------------------+ +// | N | ??? | if there are wrappers, result of calling | +// | | | the last one on the N-1 dialer | +// +-------+-----------------+------------------------------------------+ +// | N+1 | logger (inner) | logs TCP connect operations | +// +-------+-----------------+------------------------------------------+ +// | N+2 | resolver | DNS lookup and try connect each IP in | +// | | | sequence until one of them succeeds | +// +-------+-----------------+------------------------------------------+ +// | N+3 | logger (outer) | logs the overall dial operation | +// +-------+-----------------+------------------------------------------+ +// +// The list of wrappers allows to insert modified dialers in the correct +// place for observing and saving I/O events (connect, read, etc.). +// +// Remarks +// +// When the resolver is &NullResolver{} any attempt to perform DNS resolutions +// in the dialer at index N+2 will fail with ErrNoResolver. +// +// Otherwise, the dialer at index N+2 will try each resolver IP address +// sequentially. In case of failure, such a resolver will return the first +// error that occurred. This implementation strategy is a QUIRK that is +// documented at TODO(https://github.com/ooni/probe/issues/1779). +// +// If the baseDialer is &DialerSystem{}, there will be a fixed TCP connect +// timeout for each connect operation. Because there may be multiple IP +// addresses per dial, the overall timeout would be a multiple of the timeout +// of a single connect operation. You may want to use the context to reduce +// the overall time spent trying all addresses and timing out. +func WrapDialer(logger model.DebugLogger, resolver model.Resolver, + baseDialer model.Dialer, wrappers ...DialerWrapper) (outDialer model.Dialer) { + outDialer = &dialerErrWrapper{ + Dialer: baseDialer, + } + for _, wrapper := range wrappers { + outDialer = wrapper(outDialer) // extend with user-supplied constructors + } return &dialerLogger{ Dialer: &dialerResolver{ Dialer: &dialerLogger{ - Dialer: &dialerErrWrapper{ - Dialer: dialer, - }, + Dialer: outDialer, DebugLogger: logger, operationSuffix: "_address", }, @@ -64,25 +129,25 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver, dialer model. } } -// NewDialerWithoutResolver calls NewDialerWithResolver with a "null" resolver. -// -// The returned dialer fails with ErrNoResolver if passed a domain name. -func NewDialerWithoutResolver(logger model.DebugLogger) model.Dialer { - return NewDialerWithResolver(logger, &nullResolver{}) +// NewDialerWithoutResolver is equivalent to calling NewDialerWithResolver +// with the resolver argument being &NullResolver{}. +func NewDialerWithoutResolver(dl model.DebugLogger, w ...DialerWrapper) model.Dialer { + return NewDialerWithResolver(dl, &NullResolver{}, w...) } -// dialerSystem uses system facilities to perform domain name -// resolution and guarantees we have a dialer timeout. -type dialerSystem struct { - // timeout is the OPTIONAL timeout used for testing. +// DialerSystem is a model.Dialer that users TProxy.NewSimplerDialer +// to construct the new SimpleDialer used for dialing. This dialer has +// a fixed timeout for each connect operation equal to 15 seconds. +type DialerSystem struct { + // timeout is the OPTIONAL timeout (for testing). timeout time.Duration } -var _ model.Dialer = &dialerSystem{} +var _ model.Dialer = &DialerSystem{} const dialerDefaultTimeout = 15 * time.Second -func (d *dialerSystem) newUnderlyingDialer() model.SimpleDialer { +func (d *DialerSystem) newUnderlyingDialer() model.SimpleDialer { t := d.timeout if t <= 0 { t = dialerDefaultTimeout @@ -90,11 +155,11 @@ func (d *dialerSystem) newUnderlyingDialer() model.SimpleDialer { return TProxy.NewSimpleDialer(t) } -func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.newUnderlyingDialer().DialContext(ctx, network, address) } -func (d *dialerSystem) CloseIdleConnections() { +func (d *DialerSystem) CloseIdleConnections() { // nothing to do here } diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index f421a1c..795a09f 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -11,32 +11,51 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +type extensionDialerFirst struct { + model.Dialer +} + +type extensionDialerSecond struct { + model.Dialer +} + func TestNewDialer(t *testing.T) { t.Run("produces a chain with the expected types", func(t *testing.T) { - d := NewDialerWithoutResolver(log.Log) + modifiers := []DialerWrapper{ + func(dialer model.Dialer) model.Dialer { + return &extensionDialerFirst{dialer} + }, + func(dialer model.Dialer) model.Dialer { + return &extensionDialerSecond{dialer} + }, + } + d := NewDialerWithoutResolver(log.Log, modifiers...) logger := d.(*dialerLogger) if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } reso := logger.Dialer.(*dialerResolver) - if _, okay := reso.Resolver.(*nullResolver); !okay { + if _, okay := reso.Resolver.(*NullResolver); !okay { t.Fatal("invalid Resolver type") } logger = reso.Dialer.(*dialerLogger) if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } - errWrapper := logger.Dialer.(*dialerErrWrapper) - _ = errWrapper.Dialer.(*dialerSystem) + ext2 := logger.Dialer.(*extensionDialerSecond) + ext1 := ext2.Dialer.(*extensionDialerFirst) + errWrapper := ext1.Dialer.(*dialerErrWrapper) + _ = errWrapper.Dialer.(*DialerSystem) }) } func TestDialerSystem(t *testing.T) { t.Run("has a default timeout", func(t *testing.T) { - d := &dialerSystem{} + d := &DialerSystem{} ud := d.newUnderlyingDialer() if ud.(*net.Dialer).Timeout != dialerDefaultTimeout { t.Fatal("unexpected default timeout") @@ -45,7 +64,7 @@ func TestDialerSystem(t *testing.T) { t.Run("we can change the timeout for testing", func(t *testing.T) { const smaller = 1 * time.Second - d := &dialerSystem{timeout: smaller} + d := &DialerSystem{timeout: smaller} ud := d.newUnderlyingDialer() if ud.(*net.Dialer).Timeout != smaller { t.Fatal("unexpected timeout") @@ -53,13 +72,13 @@ func TestDialerSystem(t *testing.T) { }) t.Run("CloseIdleConnections", func(t *testing.T) { - d := &dialerSystem{} + d := &DialerSystem{} d.CloseIdleConnections() // to avoid missing coverage }) t.Run("DialContext", func(t *testing.T) { t.Run("with canceled context", func(t *testing.T) { - d := &dialerSystem{} + d := &DialerSystem{} ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately! conn, err := d.DialContext(ctx, "tcp", "8.8.8.8:443") @@ -73,7 +92,7 @@ func TestDialerSystem(t *testing.T) { t.Run("enforces the configured timeout", func(t *testing.T) { const timeout = 1 * time.Nanosecond - d := &dialerSystem{timeout: timeout} + d := &DialerSystem{timeout: timeout} ctx := context.Background() start := time.Now() conn, err := d.DialContext(ctx, "tcp", "dns.google:443") @@ -95,7 +114,7 @@ func TestDialerResolver(t *testing.T) { t.Run("DialContext", func(t *testing.T) { t.Run("fails without a port", func(t *testing.T) { d := &dialerResolver{ - Dialer: &dialerSystem{}, + Dialer: &DialerSystem{}, Resolver: &resolverSystem{}, } const missingPort = "ooni.nu" @@ -115,7 +134,7 @@ func TestDialerResolver(t *testing.T) { return nil, io.EOF }, }, - Resolver: &nullResolver{}, + Resolver: &NullResolver{}, } conn, err := d.DialContext(context.Background(), "tcp", "1.1.1.1:853") if !errors.Is(err, io.EOF) { @@ -335,8 +354,8 @@ func TestDialerResolver(t *testing.T) { t.Run("lookupHost", func(t *testing.T) { t.Run("handles addresses correctly", func(t *testing.T) { dialer := &dialerResolver{ - Dialer: &dialerSystem{}, - Resolver: &nullResolver{}, + Dialer: &DialerSystem{}, + Resolver: &NullResolver{}, } addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1") if err != nil { @@ -349,8 +368,8 @@ func TestDialerResolver(t *testing.T) { t.Run("fails correctly on lookup error", func(t *testing.T) { dialer := &dialerResolver{ - Dialer: &dialerSystem{}, - Resolver: &nullResolver{}, + Dialer: &DialerSystem{}, + Resolver: &NullResolver{}, } ctx := context.Background() conn, err := dialer.DialContext(ctx, "tcp", "dns.google.com:853") diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 6bbf6f1..0e72d09 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -8,7 +8,7 @@ package netxlite // // Deprecated: do not use these names in new code. var ( - DefaultDialer = &dialerSystem{} + DefaultDialer = &DialerSystem{} DefaultTLSHandshaker = defaultTLSHandshaker NewConnUTLS = newConnUTLS DefaultResolver = &resolverSystem{} @@ -36,7 +36,6 @@ type ( ResolverIDNA = resolverIDNA TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerLogger = tlsHandshakerLogger - DialerSystem = dialerSystem TLSDialerLegacy = tlsDialer AddressResolver = resolverShortCircuitIPAddr ) diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index 8bfe027..9e41f08 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -79,7 +79,7 @@ func NewQUICDialerWithResolver(listener model.QUICListener, // an address containing a domain name, the dial will fail with // the ErrNoResolver failure. func NewQUICDialerWithoutResolver(listener model.QUICListener, logger model.DebugLogger) model.QUICDialer { - return NewQUICDialerWithResolver(listener, logger, &nullResolver{}) + return NewQUICDialerWithResolver(listener, logger, &NullResolver{}) } // quicDialerQUICGo dials using the lucas-clemente/quic-go library. diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index f974911..56951e4 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -30,7 +30,7 @@ func TestNewQUICDialer(t *testing.T) { t.Fatal("invalid logger") } resolver := logger.Dialer.(*quicDialerResolver) - if _, okay := resolver.Resolver.(*nullResolver); !okay { + if _, okay := resolver.Resolver.(*NullResolver); !okay { t.Fatal("invalid resolver type") } logger = resolver.Dialer.(*quicDialerLogger) diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index c0898d7..ac97c22 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -334,32 +334,32 @@ func isIPv6(candidate string) bool { // since they can only dial for endpoints containing IP addresses. var ErrNoResolver = errors.New("no configured resolver") -// nullResolver is a resolver that is not capable of resolving +// NullResolver is a resolver that is not capable of resolving // domain names to IP addresses and always returns ErrNoResolver. -type nullResolver struct{} +type NullResolver struct{} -func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { +func (r *NullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { return nil, ErrNoResolver } -func (r *nullResolver) Network() string { +func (r *NullResolver) Network() string { return "null" } -func (r *nullResolver) Address() string { +func (r *NullResolver) Address() string { return "" } -func (r *nullResolver) CloseIdleConnections() { +func (r *NullResolver) CloseIdleConnections() { // nothing to do } -func (r *nullResolver) LookupHTTPS( +func (r *NullResolver) LookupHTTPS( ctx context.Context, domain string) (*model.HTTPSSvc, error) { return nil, ErrNoResolver } -func (r *nullResolver) LookupNS( +func (r *NullResolver) LookupNS( ctx context.Context, domain string) ([]*net.NS, error) { return nil, ErrNoResolver } diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index 17230b8..b0f265d 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -807,7 +807,7 @@ func TestIsIPv6(t *testing.T) { func TestNullResolver(t *testing.T) { t.Run("LookupHost", func(t *testing.T) { - r := &nullResolver{} + r := &NullResolver{} ctx := context.Background() addrs, err := r.LookupHost(ctx, "dns.google") if !errors.Is(err, ErrNoResolver) { @@ -826,7 +826,7 @@ func TestNullResolver(t *testing.T) { }) t.Run("LookupHTTPS", func(t *testing.T) { - r := &nullResolver{} + r := &NullResolver{} ctx := context.Background() addrs, err := r.LookupHTTPS(ctx, "dns.google") if !errors.Is(err, ErrNoResolver) { @@ -845,7 +845,7 @@ func TestNullResolver(t *testing.T) { }) t.Run("LookupNS", func(t *testing.T) { - r := &nullResolver{} + r := &NullResolver{} ctx := context.Background() ns, err := r.LookupNS(ctx, "dns.google") if !errors.Is(err, ErrNoResolver) { diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 828864c..1c0cd88 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -392,7 +392,7 @@ func TestTLSDialer(t *testing.T) { t.Run("failure dialing", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail - dialer := tlsDialer{Dialer: &dialerSystem{}} + dialer := tlsDialer{Dialer: &DialerSystem{}} conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { t.Fatal("not the error we expected", err)