diff --git a/internal/engine/netx/dialer/dialer.go b/internal/engine/netx/dialer/dialer.go index 17e093c..f134c56 100644 --- a/internal/engine/netx/dialer/dialer.go +++ b/internal/engine/netx/dialer/dialer.go @@ -50,23 +50,25 @@ type Config struct { // New creates a new Dialer from the specified config and resolver. func New(config *Config, resolver model.Resolver) model.Dialer { - var d model.Dialer = &netxlite.ErrorWrapperDialer{Dialer: netxlite.DefaultDialer} + var logger model.DebugLogger = model.DiscardLogger if config.Logger != nil { - d = &netxlite.DialerLogger{ - Dialer: d, - DebugLogger: config.Logger, - } + logger = config.Logger } - if config.DialSaver != nil { - d = &saverDialer{Dialer: d, Saver: config.DialSaver} - } - if config.ReadWriteSaver != nil { - d = &saverConnDialer{Dialer: d, Saver: config.ReadWriteSaver} - } - d = &netxlite.DialerResolver{ - Resolver: resolver, - Dialer: d, + modifiers := []netxlite.DialerWrapper{ + func(dialer model.Dialer) model.Dialer { + if config.DialSaver != nil { + dialer = &saverDialer{Dialer: dialer, Saver: config.DialSaver} + } + return dialer + }, + func(dialer model.Dialer) model.Dialer { + if config.ReadWriteSaver != nil { + dialer = &saverConnDialer{Dialer: dialer, Saver: config.ReadWriteSaver} + } + return dialer + }, } + d := netxlite.NewDialerWithResolver(logger, resolver, modifiers...) d = &netxlite.MaybeProxyDialer{ProxyURL: config.ProxyURL, Dialer: d} if config.ContextByteCounting { d = &bytecounter.ContextAwareDialer{Dialer: d} diff --git a/internal/engine/netx/dialer/dialer_test.go b/internal/engine/netx/dialer/dialer_test.go index 41bd5b5..2ecf349 100644 --- a/internal/engine/netx/dialer/dialer_test.go +++ b/internal/engine/netx/dialer/dialer_test.go @@ -24,34 +24,12 @@ func TestNewCreatesTheExpectedChain(t *testing.T) { if !ok { t.Fatal("not a byteCounterDialer") } - pd, ok := bcd.Dialer.(*netxlite.MaybeProxyDialer) + _, ok = bcd.Dialer.(*netxlite.MaybeProxyDialer) if !ok { t.Fatal("not a proxyDialer") } - dnsd, ok := pd.Dialer.(*netxlite.DialerResolver) - if !ok { - t.Fatal("not a dnsDialer") - } - scd, ok := dnsd.Dialer.(*saverConnDialer) - if !ok { - t.Fatal("not a saverConnDialer") - } - sd, ok := scd.Dialer.(*saverDialer) - if !ok { - t.Fatal("not a saverDialer") - } - ld, ok := sd.Dialer.(*netxlite.DialerLogger) - if !ok { - t.Fatal("not a loggingDialer") - } - ewd, ok := ld.Dialer.(*netxlite.ErrorWrapperDialer) - if !ok { - t.Fatal("not an errorWrappingDialer") - } - _, ok = ewd.Dialer.(*netxlite.DialerSystem) - if !ok { - t.Fatal("not a DialerSystem") - } + // We can safely stop here: the rest is tested by + // the internal/netxlite package } func TestDialerNewSuccess(t *testing.T) { diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index 43cade5..a0e3c5e 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -14,47 +14,112 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// NewDialerWithResolver calls WrapDialer for the stdlib dialer. -func NewDialerWithResolver(logger model.DebugLogger, resolver model.Resolver) model.Dialer { - return WrapDialer(logger, resolver, &dialerSystem{}) +// DialerWrapper is a function that allows you to customize the kind of Dialer returned +// by WrapDialer, NewDialerWithResolver, and NewDialerWithoutResolver. +type DialerWrapper func(dialer model.Dialer) model.Dialer + +// NewDialerWithResolver is equivalent to calling WrapDialer with +// the dialer argument being equal to &DialerSystem{}. +func NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...DialerWrapper) model.Dialer { + return WrapDialer(dl, r, &DialerSystem{}, w...) } -// WrapDialer creates a new Dialer that wraps the given -// Dialer. The returned Dialer has the following properties: +// WrapDialer wraps an existing Dialer to add extra functionality +// such as separting DNS lookup and connecting, error wrapping, logging, etc. // -// 1. logs events using the given logger; +// When possible use NewDialerWithResolver or NewDialerWithoutResolver +// instead of using this rather low-level function. // -// 2. resolves domain names using the givern resolver; +// Arguments // -// 3. when the resolver is not a "null" resolver, -// each available enpoint is tried -// sequentially. On error, the code will return what it believes -// to be the most representative error in the pack. Most often, -// the first error that occurred. Choosing the -// error to return using this logic is a QUIRK that we owe -// to the original implementation of netx. We cannot change -// this behavior until we refactor legacy code using it. +// 1. logger is used to emit debug messages (MUST NOT be nil); // -// Removing this quirk from the codebase is documented as -// TODO(https://github.com/ooni/probe/issues/1779). +// 2. resolver is the resolver to use when dialing for endpoint +// addresses containing domain names (MUST NOT be nil); // -// 4. wraps errors; +// 3. baseDialer is the dialer to wrap (MUST NOT be nil); // -// 5. has a configured connect timeout; +// 4. wrappers is a list of zero or more functions allowing you to +// modify the behavior of the returned dialer (see below). // -// 6. if a dialer wraps a resolver, the dialer will forward -// the CloseIdleConnection call to its resolver (which is -// instrumental to manage a DoH resolver connections properly). +// Return value // -// In general, do not use WrapDialer directly but try to use -// more high-level factories, e.g., NewDialerWithResolver. -func WrapDialer(logger model.DebugLogger, resolver model.Resolver, dialer model.Dialer) model.Dialer { +// The returned dialer is an opaque type consisting of the composition of +// several simple dialers. The following pseudo code illustrates the general +// behavior of the returned composed dialer: +// +// addrs, err := dnslookup() +// if err != nil { +// return nil, err +// } +// errors := []error{} +// for _, a := range addrs { +// conn, err := tcpconnect(a) +// if err != nil { +// errors = append(errors, err) +// continue +// } +// return conn, nil +// } +// return nil, errors[0] +// +// +// The following table describes the structure of the returned dialer: +// +// +-------+-----------------+------------------------------------------+ +// | Index | Name | Description | +// +-------+-----------------+------------------------------------------+ +// | 0 | base | the baseDialer argument | +// +-------+-----------------+------------------------------------------+ +// | 1 | errWrapper | wraps Go errors to be consistent with | +// | | | OONI df-007-errors spec | +// +-------+-----------------+------------------------------------------+ +// | 2 | ??? | if there are wrappers, result of calling | +// | | | the first one on the errWrapper dialer | +// +-------+-----------------+------------------------------------------+ +// | ... | ... | ... | +// +-------+-----------------+------------------------------------------+ +// | N | ??? | if there are wrappers, result of calling | +// | | | the last one on the N-1 dialer | +// +-------+-----------------+------------------------------------------+ +// | N+1 | logger (inner) | logs TCP connect operations | +// +-------+-----------------+------------------------------------------+ +// | N+2 | resolver | DNS lookup and try connect each IP in | +// | | | sequence until one of them succeeds | +// +-------+-----------------+------------------------------------------+ +// | N+3 | logger (outer) | logs the overall dial operation | +// +-------+-----------------+------------------------------------------+ +// +// The list of wrappers allows to insert modified dialers in the correct +// place for observing and saving I/O events (connect, read, etc.). +// +// Remarks +// +// When the resolver is &NullResolver{} any attempt to perform DNS resolutions +// in the dialer at index N+2 will fail with ErrNoResolver. +// +// Otherwise, the dialer at index N+2 will try each resolver IP address +// sequentially. In case of failure, such a resolver will return the first +// error that occurred. This implementation strategy is a QUIRK that is +// documented at TODO(https://github.com/ooni/probe/issues/1779). +// +// If the baseDialer is &DialerSystem{}, there will be a fixed TCP connect +// timeout for each connect operation. Because there may be multiple IP +// addresses per dial, the overall timeout would be a multiple of the timeout +// of a single connect operation. You may want to use the context to reduce +// the overall time spent trying all addresses and timing out. +func WrapDialer(logger model.DebugLogger, resolver model.Resolver, + baseDialer model.Dialer, wrappers ...DialerWrapper) (outDialer model.Dialer) { + outDialer = &dialerErrWrapper{ + Dialer: baseDialer, + } + for _, wrapper := range wrappers { + outDialer = wrapper(outDialer) // extend with user-supplied constructors + } return &dialerLogger{ Dialer: &dialerResolver{ Dialer: &dialerLogger{ - Dialer: &dialerErrWrapper{ - Dialer: dialer, - }, + Dialer: outDialer, DebugLogger: logger, operationSuffix: "_address", }, @@ -64,25 +129,25 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver, dialer model. } } -// NewDialerWithoutResolver calls NewDialerWithResolver with a "null" resolver. -// -// The returned dialer fails with ErrNoResolver if passed a domain name. -func NewDialerWithoutResolver(logger model.DebugLogger) model.Dialer { - return NewDialerWithResolver(logger, &nullResolver{}) +// NewDialerWithoutResolver is equivalent to calling NewDialerWithResolver +// with the resolver argument being &NullResolver{}. +func NewDialerWithoutResolver(dl model.DebugLogger, w ...DialerWrapper) model.Dialer { + return NewDialerWithResolver(dl, &NullResolver{}, w...) } -// dialerSystem uses system facilities to perform domain name -// resolution and guarantees we have a dialer timeout. -type dialerSystem struct { - // timeout is the OPTIONAL timeout used for testing. +// DialerSystem is a model.Dialer that users TProxy.NewSimplerDialer +// to construct the new SimpleDialer used for dialing. This dialer has +// a fixed timeout for each connect operation equal to 15 seconds. +type DialerSystem struct { + // timeout is the OPTIONAL timeout (for testing). timeout time.Duration } -var _ model.Dialer = &dialerSystem{} +var _ model.Dialer = &DialerSystem{} const dialerDefaultTimeout = 15 * time.Second -func (d *dialerSystem) newUnderlyingDialer() model.SimpleDialer { +func (d *DialerSystem) newUnderlyingDialer() model.SimpleDialer { t := d.timeout if t <= 0 { t = dialerDefaultTimeout @@ -90,11 +155,11 @@ func (d *dialerSystem) newUnderlyingDialer() model.SimpleDialer { return TProxy.NewSimpleDialer(t) } -func (d *dialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerSystem) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.newUnderlyingDialer().DialContext(ctx, network, address) } -func (d *dialerSystem) CloseIdleConnections() { +func (d *DialerSystem) CloseIdleConnections() { // nothing to do here } diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index f421a1c..795a09f 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -11,32 +11,51 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +type extensionDialerFirst struct { + model.Dialer +} + +type extensionDialerSecond struct { + model.Dialer +} + func TestNewDialer(t *testing.T) { t.Run("produces a chain with the expected types", func(t *testing.T) { - d := NewDialerWithoutResolver(log.Log) + modifiers := []DialerWrapper{ + func(dialer model.Dialer) model.Dialer { + return &extensionDialerFirst{dialer} + }, + func(dialer model.Dialer) model.Dialer { + return &extensionDialerSecond{dialer} + }, + } + d := NewDialerWithoutResolver(log.Log, modifiers...) logger := d.(*dialerLogger) if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } reso := logger.Dialer.(*dialerResolver) - if _, okay := reso.Resolver.(*nullResolver); !okay { + if _, okay := reso.Resolver.(*NullResolver); !okay { t.Fatal("invalid Resolver type") } logger = reso.Dialer.(*dialerLogger) if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } - errWrapper := logger.Dialer.(*dialerErrWrapper) - _ = errWrapper.Dialer.(*dialerSystem) + ext2 := logger.Dialer.(*extensionDialerSecond) + ext1 := ext2.Dialer.(*extensionDialerFirst) + errWrapper := ext1.Dialer.(*dialerErrWrapper) + _ = errWrapper.Dialer.(*DialerSystem) }) } func TestDialerSystem(t *testing.T) { t.Run("has a default timeout", func(t *testing.T) { - d := &dialerSystem{} + d := &DialerSystem{} ud := d.newUnderlyingDialer() if ud.(*net.Dialer).Timeout != dialerDefaultTimeout { t.Fatal("unexpected default timeout") @@ -45,7 +64,7 @@ func TestDialerSystem(t *testing.T) { t.Run("we can change the timeout for testing", func(t *testing.T) { const smaller = 1 * time.Second - d := &dialerSystem{timeout: smaller} + d := &DialerSystem{timeout: smaller} ud := d.newUnderlyingDialer() if ud.(*net.Dialer).Timeout != smaller { t.Fatal("unexpected timeout") @@ -53,13 +72,13 @@ func TestDialerSystem(t *testing.T) { }) t.Run("CloseIdleConnections", func(t *testing.T) { - d := &dialerSystem{} + d := &DialerSystem{} d.CloseIdleConnections() // to avoid missing coverage }) t.Run("DialContext", func(t *testing.T) { t.Run("with canceled context", func(t *testing.T) { - d := &dialerSystem{} + d := &DialerSystem{} ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately! conn, err := d.DialContext(ctx, "tcp", "8.8.8.8:443") @@ -73,7 +92,7 @@ func TestDialerSystem(t *testing.T) { t.Run("enforces the configured timeout", func(t *testing.T) { const timeout = 1 * time.Nanosecond - d := &dialerSystem{timeout: timeout} + d := &DialerSystem{timeout: timeout} ctx := context.Background() start := time.Now() conn, err := d.DialContext(ctx, "tcp", "dns.google:443") @@ -95,7 +114,7 @@ func TestDialerResolver(t *testing.T) { t.Run("DialContext", func(t *testing.T) { t.Run("fails without a port", func(t *testing.T) { d := &dialerResolver{ - Dialer: &dialerSystem{}, + Dialer: &DialerSystem{}, Resolver: &resolverSystem{}, } const missingPort = "ooni.nu" @@ -115,7 +134,7 @@ func TestDialerResolver(t *testing.T) { return nil, io.EOF }, }, - Resolver: &nullResolver{}, + Resolver: &NullResolver{}, } conn, err := d.DialContext(context.Background(), "tcp", "1.1.1.1:853") if !errors.Is(err, io.EOF) { @@ -335,8 +354,8 @@ func TestDialerResolver(t *testing.T) { t.Run("lookupHost", func(t *testing.T) { t.Run("handles addresses correctly", func(t *testing.T) { dialer := &dialerResolver{ - Dialer: &dialerSystem{}, - Resolver: &nullResolver{}, + Dialer: &DialerSystem{}, + Resolver: &NullResolver{}, } addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1") if err != nil { @@ -349,8 +368,8 @@ func TestDialerResolver(t *testing.T) { t.Run("fails correctly on lookup error", func(t *testing.T) { dialer := &dialerResolver{ - Dialer: &dialerSystem{}, - Resolver: &nullResolver{}, + Dialer: &DialerSystem{}, + Resolver: &NullResolver{}, } ctx := context.Background() conn, err := dialer.DialContext(ctx, "tcp", "dns.google.com:853") diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 6bbf6f1..0e72d09 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -8,7 +8,7 @@ package netxlite // // Deprecated: do not use these names in new code. var ( - DefaultDialer = &dialerSystem{} + DefaultDialer = &DialerSystem{} DefaultTLSHandshaker = defaultTLSHandshaker NewConnUTLS = newConnUTLS DefaultResolver = &resolverSystem{} @@ -36,7 +36,6 @@ type ( ResolverIDNA = resolverIDNA TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerLogger = tlsHandshakerLogger - DialerSystem = dialerSystem TLSDialerLegacy = tlsDialer AddressResolver = resolverShortCircuitIPAddr ) diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index 8bfe027..9e41f08 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -79,7 +79,7 @@ func NewQUICDialerWithResolver(listener model.QUICListener, // an address containing a domain name, the dial will fail with // the ErrNoResolver failure. func NewQUICDialerWithoutResolver(listener model.QUICListener, logger model.DebugLogger) model.QUICDialer { - return NewQUICDialerWithResolver(listener, logger, &nullResolver{}) + return NewQUICDialerWithResolver(listener, logger, &NullResolver{}) } // quicDialerQUICGo dials using the lucas-clemente/quic-go library. diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index f974911..56951e4 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -30,7 +30,7 @@ func TestNewQUICDialer(t *testing.T) { t.Fatal("invalid logger") } resolver := logger.Dialer.(*quicDialerResolver) - if _, okay := resolver.Resolver.(*nullResolver); !okay { + if _, okay := resolver.Resolver.(*NullResolver); !okay { t.Fatal("invalid resolver type") } logger = resolver.Dialer.(*quicDialerLogger) diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index c0898d7..ac97c22 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -334,32 +334,32 @@ func isIPv6(candidate string) bool { // since they can only dial for endpoints containing IP addresses. var ErrNoResolver = errors.New("no configured resolver") -// nullResolver is a resolver that is not capable of resolving +// NullResolver is a resolver that is not capable of resolving // domain names to IP addresses and always returns ErrNoResolver. -type nullResolver struct{} +type NullResolver struct{} -func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { +func (r *NullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { return nil, ErrNoResolver } -func (r *nullResolver) Network() string { +func (r *NullResolver) Network() string { return "null" } -func (r *nullResolver) Address() string { +func (r *NullResolver) Address() string { return "" } -func (r *nullResolver) CloseIdleConnections() { +func (r *NullResolver) CloseIdleConnections() { // nothing to do } -func (r *nullResolver) LookupHTTPS( +func (r *NullResolver) LookupHTTPS( ctx context.Context, domain string) (*model.HTTPSSvc, error) { return nil, ErrNoResolver } -func (r *nullResolver) LookupNS( +func (r *NullResolver) LookupNS( ctx context.Context, domain string) ([]*net.NS, error) { return nil, ErrNoResolver } diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index 17230b8..b0f265d 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -807,7 +807,7 @@ func TestIsIPv6(t *testing.T) { func TestNullResolver(t *testing.T) { t.Run("LookupHost", func(t *testing.T) { - r := &nullResolver{} + r := &NullResolver{} ctx := context.Background() addrs, err := r.LookupHost(ctx, "dns.google") if !errors.Is(err, ErrNoResolver) { @@ -826,7 +826,7 @@ func TestNullResolver(t *testing.T) { }) t.Run("LookupHTTPS", func(t *testing.T) { - r := &nullResolver{} + r := &NullResolver{} ctx := context.Background() addrs, err := r.LookupHTTPS(ctx, "dns.google") if !errors.Is(err, ErrNoResolver) { @@ -845,7 +845,7 @@ func TestNullResolver(t *testing.T) { }) t.Run("LookupNS", func(t *testing.T) { - r := &nullResolver{} + r := &NullResolver{} ctx := context.Background() ns, err := r.LookupNS(ctx, "dns.google") if !errors.Is(err, ErrNoResolver) { diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 828864c..1c0cd88 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -392,7 +392,7 @@ func TestTLSDialer(t *testing.T) { t.Run("failure dialing", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail - dialer := tlsDialer{Dialer: &dialerSystem{}} + dialer := tlsDialer{Dialer: &DialerSystem{}} conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { t.Fatal("not the error we expected", err)