diff --git a/internal/engine/netx/dialer.go b/internal/engine/netx/dialer.go index 90c712f..29dd31d 100644 --- a/internal/engine/netx/dialer.go +++ b/internal/engine/netx/dialer.go @@ -21,7 +21,7 @@ func NewDialer(config Config) model.Dialer { logger, config.FullResolver, config.Saver.NewConnectObserver(), config.ReadWriteSaver.NewReadWriteObserver(), ) - d = netxlite.NewMaybeProxyDialer(d, config.ProxyURL) + d = netxlite.MaybeWrapWithProxyDialer(d, config.ProxyURL) d = bytecounter.MaybeWrapWithContextAwareDialer(config.ContextByteCounting, d) return d } diff --git a/internal/engine/session.go b/internal/engine/session.go index a6ab2c9..bcc0f7e 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -198,7 +198,7 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) { ProxyURL: proxyURL, } dialer := netxlite.NewDialerWithResolver(sess.logger, sess.resolver) - dialer = netxlite.NewMaybeProxyDialer(dialer, proxyURL) + dialer = netxlite.MaybeWrapWithProxyDialer(dialer, proxyURL) handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger) tlsDialer := netxlite.NewTLSDialer(dialer, handshaker) txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer) diff --git a/internal/netxlite/maybeproxy.go b/internal/netxlite/maybeproxy.go index 6aaf09c..edab70b 100644 --- a/internal/netxlite/maybeproxy.go +++ b/internal/netxlite/maybeproxy.go @@ -1,5 +1,9 @@ package netxlite +// +// Optional proxy support +// + import ( "context" "errors" @@ -10,38 +14,37 @@ import ( "golang.org/x/net/proxy" ) -// MaybeProxyDialer is a dialer that may use a proxy. If the ProxyURL is not configured, -// this dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally -// create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer. -type MaybeProxyDialer struct { +// proxyDialer is a dialer using a proxy. +type proxyDialer struct { Dialer model.Dialer ProxyURL *url.URL } -// NewMaybeProxyDialer creates a new NewMaybeProxyDialer. -func NewMaybeProxyDialer(dialer model.Dialer, proxyURL *url.URL) *MaybeProxyDialer { - return &MaybeProxyDialer{ +// MaybeWrapWithProxyDialer returns the original dialer if the proxyURL is nil +// and otherwise returns a wrapped dialer that implements proxying. +func MaybeWrapWithProxyDialer(dialer model.Dialer, proxyURL *url.URL) model.Dialer { + if proxyURL == nil { + return dialer + } + return &proxyDialer{ Dialer: dialer, ProxyURL: proxyURL, } } -var _ model.Dialer = &MaybeProxyDialer{} +var _ model.Dialer = &proxyDialer{} // CloseIdleConnections implements Dialer.CloseIdleConnections. -func (d *MaybeProxyDialer) CloseIdleConnections() { +func (d *proxyDialer) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -// ErrProxyUnsupportedScheme indicates we don't support a protocol scheme. +// ErrProxyUnsupportedScheme indicates we don't support the proxy scheme. var ErrProxyUnsupportedScheme = errors.New("proxy: unsupported scheme") // DialContext implements Dialer.DialContext. -func (d *MaybeProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { url := d.ProxyURL - if url == nil { - return d.Dialer.DialContext(ctx, network, address) - } if url.Scheme != "socks5" { return nil, ErrProxyUnsupportedScheme } @@ -50,7 +53,7 @@ func (d *MaybeProxyDialer) DialContext(ctx context.Context, network, address str return d.dial(ctx, child, network, address) } -func (d *MaybeProxyDialer) dial( +func (d *proxyDialer) dial( ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) { cd := child.(proxy.ContextDialer) // will work return cd.DialContext(ctx, network, address) diff --git a/internal/netxlite/maybeproxy_test.go b/internal/netxlite/maybeproxy_test.go index af3a502..cb9fd07 100644 --- a/internal/netxlite/maybeproxy_test.go +++ b/internal/netxlite/maybeproxy_test.go @@ -12,28 +12,34 @@ import ( ) func TestMaybeProxyDialer(t *testing.T) { - t.Run("DialContext", func(t *testing.T) { - t.Run("missing proxy URL", func(t *testing.T) { - expected := errors.New("mocked error") - d := &MaybeProxyDialer{ - Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }}, - ProxyURL: nil, - } - conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, expected) { - t.Fatal(err) - } - if conn != nil { - t.Fatal("conn is not nil") + t.Run("MaybeWrapWithProxyDialer", func(t *testing.T) { + t.Run("without a proxy URL", func(t *testing.T) { + underlying := &mocks.Dialer{} + dialer := MaybeWrapWithProxyDialer(underlying, nil) + if dialer != underlying { + t.Fatal("should not have wrapped") } }) + t.Run("with a proxy URL", func(t *testing.T) { + URL := &url.URL{} + underlying := &mocks.Dialer{} + dialer := MaybeWrapWithProxyDialer(underlying, URL) + real := dialer.(*proxyDialer) + if real.Dialer != underlying { + t.Fatal("did not wrap correctly") + } + if real.ProxyURL != URL { + t.Fatal("invalid URL") + } + }) + }) + + t.Run("DialContext", func(t *testing.T) { t.Run("invalid scheme", func(t *testing.T) { child := &mocks.Dialer{} URL := &url.URL{Scheme: "antani"} - d := NewMaybeProxyDialer(child, URL) + d := MaybeWrapWithProxyDialer(child, URL) conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") if !errors.Is(err, ErrProxyUnsupportedScheme) { t.Fatal("not the error we expected") @@ -45,7 +51,7 @@ func TestMaybeProxyDialer(t *testing.T) { t.Run("underlying dial fails with EOF", func(t *testing.T) { const expect = "10.0.0.1:9050" - d := &MaybeProxyDialer{ + d := &proxyDialer{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { if address != expect { @@ -77,7 +83,7 @@ func TestMaybeProxyDialer(t *testing.T) { }, } URL := &url.URL{} - dialer := NewMaybeProxyDialer(child, URL) + dialer := MaybeWrapWithProxyDialer(child, URL) dialer.CloseIdleConnections() if !called { t.Fatal("not called")