diff --git a/internal/engine/netx/dialer/dialer.go b/internal/engine/netx/dialer/dialer.go index 567d61b..d385f38 100644 --- a/internal/engine/netx/dialer/dialer.go +++ b/internal/engine/netx/dialer/dialer.go @@ -54,21 +54,10 @@ func New(config *Config, resolver model.Resolver) model.Dialer { if config.Logger != nil { logger = config.Logger } - modifiers := []netxlite.DialerWrapper{ - func(dialer model.Dialer) model.Dialer { - if config.DialSaver != nil { - dialer = &tracex.SaverDialer{Dialer: dialer, Saver: config.DialSaver} - } - return dialer - }, - func(dialer model.Dialer) model.Dialer { - if config.ReadWriteSaver != nil { - dialer = &tracex.SaverConnDialer{Dialer: dialer, Saver: config.ReadWriteSaver} - } - return dialer - }, - } - d := netxlite.NewDialerWithResolver(logger, resolver, modifiers...) + d := netxlite.NewDialerWithResolver( + logger, resolver, config.DialSaver.NewConnectObserver(), + config.ReadWriteSaver.NewReadWriteObserver(), + ) d = &netxlite.MaybeProxyDialer{ProxyURL: config.ProxyURL, Dialer: d} if config.ContextByteCounting { d = &bytecounter.ContextAwareDialer{Dialer: d} diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index ae85470..bd4715d 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -132,12 +132,7 @@ func NewQUICDialer(config Config) model.QUICDialer { if config.Logger != nil { logger = config.Logger } - extensions := []netxlite.QUICDialerWrapper{ - func(dialer model.QUICDialer) model.QUICDialer { - return config.TLSSaver.WrapQUICDialer(dialer) // robust to nil TLSSaver - }, - } - return netxlite.NewQUICDialerWithResolver(ql, logger, config.FullResolver, extensions...) + return netxlite.NewQUICDialerWithResolver(ql, logger, config.FullResolver, config.TLSSaver) } // NewTLSDialer creates a new TLSDialer from the specified config diff --git a/internal/engine/netx/tracex/dialer.go b/internal/engine/netx/tracex/dialer.go index e6abdc9..528b930 100644 --- a/internal/engine/netx/tracex/dialer.go +++ b/internal/engine/netx/tracex/dialer.go @@ -22,6 +22,31 @@ type SaverDialer struct { Saver *Saver } +// NewConnectObserver returns a DialerWrapper that observes the +// connect event. This function will return nil, which is a valid +// DialerWrapper for netxlite.WrapDialer, if Saver is nil. +func (s *Saver) NewConnectObserver() model.DialerWrapper { + if s == nil { + return nil // valid DialerWrapper according to netxlite's docs + } + return &saverDialerWrapper{ + saver: s, + } +} + +type saverDialerWrapper struct { + saver *Saver +} + +var _ model.DialerWrapper = &saverDialerWrapper{} + +func (w *saverDialerWrapper) WrapDialer(d model.Dialer) model.Dialer { + return &SaverDialer{ + Dialer: d, + Saver: w.saver, + } +} + // DialContext implements Dialer.DialContext func (d *SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { start := time.Now() @@ -52,6 +77,31 @@ type SaverConnDialer struct { Saver *Saver } +// NewReadWriteObserver returns a DialerWrapper that observes the +// I/O events. This function will return nil, which is a valid +// DialerWrapper for netxlite.WrapDialer, if Saver is nil. +func (s *Saver) NewReadWriteObserver() model.DialerWrapper { + if s == nil { + return nil // valid DialerWrapper according to netxlite's docs + } + return &saverReadWriteWrapper{ + saver: s, + } +} + +type saverReadWriteWrapper struct { + saver *Saver +} + +var _ model.DialerWrapper = &saverReadWriteWrapper{} + +func (w *saverReadWriteWrapper) WrapDialer(d model.Dialer) model.Dialer { + return &SaverConnDialer{ + Dialer: d, + Saver: w.saver, + } +} + // DialContext implements Dialer.DialContext func (d *SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) diff --git a/internal/engine/netx/tracex/saver.go b/internal/engine/netx/tracex/saver.go index 174f1b1..426e9d5 100644 --- a/internal/engine/netx/tracex/saver.go +++ b/internal/engine/netx/tracex/saver.go @@ -7,7 +7,7 @@ package tracex import "sync" // The Saver saves a trace. The zero value of this type -// is valid and can be used without initializtion. +// is valid and can be used without initialization. type Saver struct { // ops contains the saved events. ops []Event diff --git a/internal/engine/netx/tracex/tls_test.go b/internal/engine/netx/tracex/tls_test.go index c6f9e98..5a322a0 100644 --- a/internal/engine/netx/tracex/tls_test.go +++ b/internal/engine/netx/tracex/tls_test.go @@ -23,12 +23,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { Dialer: netxlite.NewDialerWithResolver( model.DiscardLogger, netxlite.NewResolverStdlib(model.DiscardLogger), - func(dialer model.Dialer) model.Dialer { - return &SaverConnDialer{ - Dialer: dialer, - Saver: saver, - } - }, + saver.NewReadWriteObserver(), ), TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } diff --git a/internal/model/netx.go b/internal/model/netx.go index e1f017f..326acdc 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -119,6 +119,12 @@ type DNSTransport interface { CloseIdleConnections() } +// DialerWrapper is a type that takes in input a Dialer +// and returns in output a wrapped Dialer. +type DialerWrapper interface { + WrapDialer(d Dialer) Dialer +} + // SimpleDialer establishes network connections. type SimpleDialer interface { // DialContext behaves like net.Dialer.DialContext. @@ -171,6 +177,12 @@ type QUICListener interface { Listen(addr *net.UDPAddr) (UDPLikeConn, error) } +// QUICDialerWrapper is a type that takes in input a QUICDialer +// and returns in output a wrapped QUICDialer. +type QUICDialerWrapper interface { + WrapQUICDialer(qd QUICDialer) QUICDialer +} + // QUICDialer dials QUIC sessions. type QUICDialer interface { // DialContext establishes a new QUIC session using the given diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index 48f2a4c..8f2f71f 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -14,13 +14,9 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// DialerWrapper is a function that allows you to customize the kind of Dialer returned -// by WrapDialer, NewDialerWithResolver, and NewDialerWithoutResolver. -type DialerWrapper func(dialer model.Dialer) model.Dialer - // NewDialerWithResolver is equivalent to calling WrapDialer with // the dialer argument being equal to &DialerSystem{}. -func NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...DialerWrapper) model.Dialer { +func NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...model.DialerWrapper) model.Dialer { return WrapDialer(dl, r, &DialerSystem{}, w...) } @@ -40,7 +36,8 @@ func NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...DialerWr // 3. baseDialer is the dialer to wrap (MUST NOT be nil); // // 4. wrappers is a list of zero or more functions allowing you to -// modify the behavior of the returned dialer (see below). +// modify the behavior of the returned dialer (see below). Please note +// that this function will just ignore any nil wrapper. // // Return value // @@ -109,12 +106,15 @@ func NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...DialerWr // of a single connect operation. You may want to use the context to reduce // the overall time spent trying all addresses and timing out. func WrapDialer(logger model.DebugLogger, resolver model.Resolver, - baseDialer model.Dialer, wrappers ...DialerWrapper) (outDialer model.Dialer) { + baseDialer model.Dialer, wrappers ...model.DialerWrapper) (outDialer model.Dialer) { outDialer = &dialerErrWrapper{ Dialer: baseDialer, } for _, wrapper := range wrappers { - outDialer = wrapper(outDialer) // extend with user-supplied constructors + if wrapper == nil { + continue // ignore as documented + } + outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors } return &dialerLogger{ Dialer: &dialerResolver{ @@ -131,7 +131,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver, // NewDialerWithoutResolver is equivalent to calling NewDialerWithResolver // with the resolver argument being &NullResolver{}. -func NewDialerWithoutResolver(dl model.DebugLogger, w ...DialerWrapper) model.Dialer { +func NewDialerWithoutResolver(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { return NewDialerWithResolver(dl, &NullResolver{}, w...) } diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 795a09f..b8736cd 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -19,19 +19,27 @@ type extensionDialerFirst struct { model.Dialer } +type dialerWrapperFirst struct{} + +func (*dialerWrapperFirst) WrapDialer(d model.Dialer) model.Dialer { + return &extensionDialerFirst{d} +} + type extensionDialerSecond struct { model.Dialer } +type dialerWrapperSecond struct{} + +func (*dialerWrapperSecond) WrapDialer(d model.Dialer) model.Dialer { + return &extensionDialerSecond{d} +} func TestNewDialer(t *testing.T) { t.Run("produces a chain with the expected types", func(t *testing.T) { - modifiers := []DialerWrapper{ - func(dialer model.Dialer) model.Dialer { - return &extensionDialerFirst{dialer} - }, - func(dialer model.Dialer) model.Dialer { - return &extensionDialerSecond{dialer} - }, + modifiers := []model.DialerWrapper{ + &dialerWrapperFirst{}, + nil, // explicitly test for this documented case + &dialerWrapperSecond{}, } d := NewDialerWithoutResolver(log.Log, modifiers...) logger := d.(*dialerLogger) diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index a47d386..0969a40 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -32,18 +32,16 @@ func (qls *quicListenerStdlib) Listen(addr *net.UDPAddr) (model.UDPLikeConn, err return TProxy.ListenUDP("udp", addr) } -// QUICDialerWrapper is a function that allows you to customize the kind of QUICDialer -// returned by NewQUICDialerWithResolver and NewQUICDialerWithoutResolver. -type QUICDialerWrapper func(dialer model.QUICDialer) model.QUICDialer - // NewQUICDialerWithResolver is the WrapDialer equivalent for QUIC where // we return a composed QUICDialer modified by optional wrappers. // +// Please, note that this fuunction will just ignore any nil wrapper. +// // Unlike the dialer returned by WrapDialer, this dialer MAY attempt // happy eyeballs, perform parallel dial attempts, and return an error // that aggregates all the errors that occurred. func NewQUICDialerWithResolver(listener model.QUICListener, logger model.DebugLogger, - resolver model.Resolver, wrappers ...QUICDialerWrapper) (outDialer model.QUICDialer) { + resolver model.Resolver, wrappers ...model.QUICDialerWrapper) (outDialer model.QUICDialer) { outDialer = &quicDialerErrWrapper{ QUICDialer: &quicDialerHandshakeCompleter{ Dialer: &quicDialerQUICGo{ @@ -52,7 +50,10 @@ func NewQUICDialerWithResolver(listener model.QUICListener, logger model.DebugLo }, } for _, wrapper := range wrappers { - outDialer = wrapper(outDialer) // extend with user-supplied constructors + if wrapper == nil { + continue // ignore as documented + } + outDialer = wrapper.WrapQUICDialer(outDialer) // extend with user-supplied constructors } return &quicDialerLogger{ Dialer: &quicDialerResolver{ @@ -70,7 +71,7 @@ func NewQUICDialerWithResolver(listener model.QUICListener, logger model.DebugLo // NewQUICDialerWithoutResolver is equivalent to calling NewQUICDialerWithResolver // with the resolver argument set to &NullResolver{}. func NewQUICDialerWithoutResolver(listener model.QUICListener, - logger model.DebugLogger, wrappers ...QUICDialerWrapper) model.QUICDialer { + logger model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer { return NewQUICDialerWithResolver(listener, logger, &NullResolver{}, wrappers...) } diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index e730e2e..aca1860 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -26,19 +26,30 @@ type extensionQUICDialerFirst struct { model.QUICDialer } +type quicDialerWrapperFirst struct{} + +func (*quicDialerWrapperFirst) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer { + return &extensionQUICDialerFirst{qd} +} + type extensionQUICDialerSecond struct { model.QUICDialer } +type quicDialerWrapperSecond struct { + model.QUICDialer +} + +func (*quicDialerWrapperSecond) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer { + return &extensionQUICDialerSecond{qd} +} + func TestNewQUICDialer(t *testing.T) { ql := NewQUICListener() - extensions := []QUICDialerWrapper{ - func(dialer model.QUICDialer) model.QUICDialer { - return &extensionQUICDialerFirst{dialer} - }, - func(dialer model.QUICDialer) model.QUICDialer { - return &extensionQUICDialerSecond{dialer} - }, + extensions := []model.QUICDialerWrapper{ + &quicDialerWrapperFirst{}, + nil, // explicitly test for this documented case + &quicDialerWrapperSecond{}, } dlr := NewQUICDialerWithoutResolver(ql, log.Log, extensions...) logger := dlr.(*quicDialerLogger)