From 06ee0e55a9dde842f64879e6f5dc0b405e3c878b Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 9 Jun 2021 09:42:31 +0200 Subject: [PATCH] refactor(netx/dialer): hide implementation complexity (#372) * refactor(netx/dialer): hide implementation complexity This follows the blueprint of `module.Config` and `nodule.New` described at https://github.com/ooni/probe/issues/1591. * fix: ndt7 bug where we were not using the right resolver * fix(legacy/netx): clarify irrelevant implementation change * fix: improve comments * fix(hhfm): do not use dialer.New b/c it breaks it Unclear to me why this is happening. Still, improve upon the previous situation by adding a timeout. It does not seem a priority to look into this issue now. --- internal/engine/experiment/hhfm/hhfm.go | 7 +- internal/engine/experiment/hhfm/hhfm_test.go | 22 +- internal/engine/experiment/ndt7/dial.go | 12 +- internal/engine/legacy/netx/dialer.go | 24 +- internal/engine/netx/dialer/bytecounter.go | 38 +-- .../engine/netx/dialer/bytecounter_test.go | 11 +- internal/engine/netx/dialer/dialer.go | 76 ++++- internal/engine/netx/dialer/dialer_test.go | 57 ++++ internal/engine/netx/dialer/dns.go | 19 +- internal/engine/netx/dialer/dns_test.go | 28 +- internal/engine/netx/dialer/doc.go | 4 + internal/engine/netx/dialer/errorwrapper.go | 18 +- .../engine/netx/dialer/errorwrapper_test.go | 7 +- internal/engine/netx/dialer/example_test.go | 30 ++ .../engine/netx/dialer/integration_test.go | 12 +- internal/engine/netx/dialer/logging.go | 12 +- internal/engine/netx/dialer/logging_test.go | 5 +- internal/engine/netx/dialer/proxy.go | 12 +- internal/engine/netx/dialer/proxy_test.go | 6 +- internal/engine/netx/dialer/saver.go | 23 +- internal/engine/netx/dialer/saver_test.go | 11 +- .../engine/netx/dialer/shaping_disabled.go | 6 +- .../engine/netx/dialer/shaping_enabled.go | 10 +- internal/engine/netx/dialer/shaping_test.go | 14 +- internal/engine/netx/dialer/system.go | 16 +- internal/engine/netx/dialer/system_test.go | 6 +- .../netx/httptransport/http3transport_test.go | 19 +- internal/engine/netx/netx.go | 25 +- internal/engine/netx/netx_test.go | 294 ------------------ internal/engine/netx/tlsdialer/saver_test.go | 5 +- 30 files changed, 312 insertions(+), 517 deletions(-) create mode 100644 internal/engine/netx/dialer/dialer_test.go create mode 100644 internal/engine/netx/dialer/doc.go create mode 100644 internal/engine/netx/dialer/example_test.go diff --git a/internal/engine/experiment/hhfm/hhfm.go b/internal/engine/experiment/hhfm/hhfm.go index 1d07dd0..a2eee07 100644 --- a/internal/engine/experiment/hhfm/hhfm.go +++ b/internal/engine/experiment/hhfm/hhfm.go @@ -18,7 +18,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/experiment/urlgetter" "github.com/ooni/probe-cli/v3/internal/engine/httpheader" "github.com/ooni/probe-cli/v3/internal/engine/model" - "github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx/archival" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" @@ -312,7 +311,7 @@ type JSONHeaders struct { // guarantee that the connection is used for a single request and that // such a request does not contain any body. type Dialer struct { - Dialer netx.Dialer // used for testing + Dialer dialer.Dialer // used for testing Headers map[string]string } @@ -321,7 +320,9 @@ type Dialer struct { func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { child := d.Dialer if child == nil { - child = dialer.Default + // TODO(bassosimone): figure out why using dialer.New here + // causes the experiment to fail with eof_error + child = &net.Dialer{Timeout: 15 * time.Second} } conn, err := child.DialContext(ctx, network, address) if err != nil { diff --git a/internal/engine/experiment/hhfm/hhfm_test.go b/internal/engine/experiment/hhfm/hhfm_test.go index b3190fd..090dd4e 100644 --- a/internal/engine/experiment/hhfm/hhfm_test.go +++ b/internal/engine/experiment/hhfm/hhfm_test.go @@ -13,7 +13,6 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" - engine "github.com/ooni/probe-cli/v3/internal/engine" "github.com/ooni/probe-cli/v3/internal/engine/experiment/hhfm" "github.com/ooni/probe-cli/v3/internal/engine/experiment/urlgetter" "github.com/ooni/probe-cli/v3/internal/engine/internal/mockable" @@ -55,7 +54,7 @@ func TestSuccess(t *testing.T) { t.Fatal("invalid Agent") } if tk.Failure != nil { - t.Fatal("invalid Failure") + t.Fatal("invalid Failure", *tk.Failure) } if len(tk.Requests) != 1 { t.Fatal("invalid Requests") @@ -557,25 +556,6 @@ func TestTransactCannotReadBody(t *testing.T) { } } -func newsession(t *testing.T) model.ExperimentSession { - sess, err := engine.NewSession(context.Background(), engine.SessionConfig{ - AvailableProbeServices: []model.Service{{ - Address: "https://ams-pg-test.ooni.org", - Type: "https", - }}, - Logger: log.Log, - SoftwareName: "ooniprobe-engine", - SoftwareVersion: "0.0.1", - }) - if err != nil { - t.Fatal(err) - } - if err := sess.MaybeLookupBackends(); err != nil { - t.Fatal(err) - } - return sess -} - func TestTestKeys_FillTampering(t *testing.T) { type fields struct { Agent string diff --git a/internal/engine/experiment/ndt7/dial.go b/internal/engine/experiment/ndt7/dial.go index 3829da6..31645d7 100644 --- a/internal/engine/experiment/ndt7/dial.go +++ b/internal/engine/experiment/ndt7/dial.go @@ -35,13 +35,11 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { var reso resolver.Resolver = resolver.SystemResolver{} reso = resolver.LoggingResolver{Resolver: reso, Logger: mgr.logger} - var dlr dialer.Dialer = dialer.Default - dlr = dialer.ErrorWrapperDialer{Dialer: dlr} - dlr = dialer.LoggingDialer{Dialer: dlr, Logger: mgr.logger} - dlr = dialer.DNSDialer{Dialer: dlr, Resolver: reso} - dlr = dialer.ProxyDialer{Dialer: dlr, ProxyURL: mgr.proxyURL} - dlr = dialer.ByteCounterDialer{Dialer: dlr} - dlr = dialer.ShapingDialer{Dialer: dlr} + dlr := dialer.New(&dialer.Config{ + ContextByteCounting: true, + Logger: mgr.logger, + ProxyURL: mgr.proxyURL, + }, reso) dialer := websocket.Dialer{ NetDialContext: dlr.DialContext, ReadBufferSize: mgr.readBufferSize, diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index e3ecf64..56721f0 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -64,17 +64,19 @@ func maybeWithMeasurementRoot( // - dialer.Default // // If you have others needs, manually build the chain you need. -func newDNSDialer(resolver dialer.Resolver) dialer.DNSDialer { - return dialer.DNSDialer{ - Dialer: EmitterDialer{ - Dialer: dialer.ErrorWrapperDialer{ - Dialer: dialer.ByteCounterDialer{ - Dialer: dialer.Default, - }, - }, - }, - Resolver: resolver, - } +func newDNSDialer(resolver dialer.Resolver) dialer.Dialer { + // Implementation note: we're wrapping the result of dialer.New + // on the outside, while previously we were puttting the + // EmitterDialer before the DNSDialer (see the above comment). + // + // Yet, this is fine because the only experiment which is + // using this code is tor, for which it doesn't matter. + // + // Also (and I am always scared to write this kind of + // comments), we should rewrite tor soon. + return &EmitterDialer{dialer.New(&dialer.Config{ + ContextByteCounting: true, + }, resolver)} } // DialContext is like Dial but the context allows to interrupt a diff --git a/internal/engine/netx/dialer/bytecounter.go b/internal/engine/netx/dialer/bytecounter.go index c818ba8..2aede34 100644 --- a/internal/engine/netx/dialer/bytecounter.go +++ b/internal/engine/netx/dialer/bytecounter.go @@ -7,59 +7,49 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" ) -// ByteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you +// byteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you // should make sure that you insert this dialer in the dialing chain. -// -// Bug -// -// This implementation cannot properly account for the bytes that are sent by -// persistent connections, because they strick to the counters set when the -// connection was established. This typically means we miss the bytes sent and -// received when submitting a measurement. Such bytes are specifically not -// see by the experiment specific byte counter. -// -// For this reason, this implementation may be heavily changed/removed. -type ByteCounterDialer struct { +type byteCounterDialer struct { Dialer } // DialContext implements Dialer.DialContext -func (d ByteCounterDialer) DialContext( +func (d *byteCounterDialer) DialContext( ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { return nil, err } - exp := ContextExperimentByteCounter(ctx) - sess := ContextSessionByteCounter(ctx) + exp := contextExperimentByteCounter(ctx) + sess := contextSessionByteCounter(ctx) if exp == nil && sess == nil { return conn, nil // no point in wrapping } - return byteCounterConnWrapper{Conn: conn, exp: exp, sess: sess}, nil + return &byteCounterConnWrapper{Conn: conn, exp: exp, sess: sess}, nil } type byteCounterSessionKey struct{} -// ContextSessionByteCounter retrieves the session byte counter from the context -func ContextSessionByteCounter(ctx context.Context) *bytecounter.Counter { +// contextSessionByteCounter retrieves the session byte counter from the context +func contextSessionByteCounter(ctx context.Context) *bytecounter.Counter { counter, _ := ctx.Value(byteCounterSessionKey{}).(*bytecounter.Counter) return counter } -// WithSessionByteCounter assigns the session byte counter to the context +// WithSessionByteCounter assigns the session byte counter to the context. func WithSessionByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context { return context.WithValue(ctx, byteCounterSessionKey{}, counter) } type byteCounterExperimentKey struct{} -// ContextExperimentByteCounter retrieves the experiment byte counter from the context -func ContextExperimentByteCounter(ctx context.Context) *bytecounter.Counter { +// contextExperimentByteCounter retrieves the experiment byte counter from the context +func contextExperimentByteCounter(ctx context.Context) *bytecounter.Counter { counter, _ := ctx.Value(byteCounterExperimentKey{}).(*bytecounter.Counter) return counter } -// WithExperimentByteCounter assigns the experiment byte counter to the context +// WithExperimentByteCounter assigns the experiment byte counter to the context. func WithExperimentByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context { return context.WithValue(ctx, byteCounterExperimentKey{}, counter) } @@ -70,7 +60,7 @@ type byteCounterConnWrapper struct { sess *bytecounter.Counter } -func (c byteCounterConnWrapper) Read(p []byte) (int, error) { +func (c *byteCounterConnWrapper) Read(p []byte) (int, error) { count, err := c.Conn.Read(p) if c.exp != nil { c.exp.CountBytesReceived(count) @@ -81,7 +71,7 @@ func (c byteCounterConnWrapper) Read(p []byte) (int, error) { return count, err } -func (c byteCounterConnWrapper) Write(p []byte) (int, error) { +func (c *byteCounterConnWrapper) Write(p []byte) (int, error) { count, err := c.Conn.Write(p) if c.exp != nil { c.exp.CountBytesSent(count) diff --git a/internal/engine/netx/dialer/bytecounter_test.go b/internal/engine/netx/dialer/bytecounter_test.go index 8fd09a2..2e9ff65 100644 --- a/internal/engine/netx/dialer/bytecounter_test.go +++ b/internal/engine/netx/dialer/bytecounter_test.go @@ -1,4 +1,4 @@ -package dialer_test +package dialer import ( "context" @@ -10,14 +10,13 @@ import ( "testing" "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func dorequest(ctx context.Context, url string) error { txp := http.DefaultTransport.(*http.Transport).Clone() defer txp.CloseIdleConnections() - dialer := dialer.ByteCounterDialer{Dialer: new(net.Dialer)} + dialer := &byteCounterDialer{Dialer: new(net.Dialer)} txp.DialContext = dialer.DialContext client := &http.Client{Transport: txp} req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil) @@ -40,12 +39,12 @@ func TestByteCounterNormalUsage(t *testing.T) { } sess := bytecounter.New() ctx := context.Background() - ctx = dialer.WithSessionByteCounter(ctx, sess) + ctx = WithSessionByteCounter(ctx, sess) if err := dorequest(ctx, "http://www.google.com"); err != nil { t.Fatal(err) } exp := bytecounter.New() - ctx = dialer.WithExperimentByteCounter(ctx, exp) + ctx = WithExperimentByteCounter(ctx, exp) if err := dorequest(ctx, "http://facebook.com"); err != nil { t.Fatal(err) } @@ -71,7 +70,7 @@ func TestByteCounterNoHandlers(t *testing.T) { } func TestByteCounterConnectFailure(t *testing.T) { - dialer := dialer.ByteCounterDialer{Dialer: mockablex.Dialer{ + dialer := &byteCounterDialer{Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, diff --git a/internal/engine/netx/dialer/dialer.go b/internal/engine/netx/dialer/dialer.go index b50250d..1bcd272 100644 --- a/internal/engine/netx/dialer/dialer.go +++ b/internal/engine/netx/dialer/dialer.go @@ -3,9 +3,83 @@ package dialer import ( "context" "net" + "net/url" + + "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) -// Dialer is the interface we expect from a dialer +// Dialer establishes network connections. type Dialer interface { + // DialContext behaves like net.Dialer.DialContext. DialContext(ctx context.Context, network, address string) (net.Conn, error) } + +// Resolver is the interface we expect from a DNS resolver. +type Resolver interface { + // LookupHost behaves like net.Resolver.LookupHost. + LookupHost(ctx context.Context, hostname string) (addrs []string, err error) +} + +// Logger is the interface we expect from a logger. +type Logger interface { + // Debugf formats and emits a debug message. + Debugf(format string, v ...interface{}) +} + +// Config contains the settings for New. +type Config struct { + // ContextByteCounting optionally configures context-based + // byte counting. By default we don't do that. + // + // Use WithExperimentByteCounter and WithSessionByteCounter + // to assign byte counters to a context. The code will use + // corresponding, private functions to access the configured + // byte counters and will notify them about I/O events. + // + // Bug + // + // This implementation cannot properly account for the bytes that are sent by + // persistent connections, because they strick to the counters set when the + // connection was established. This typically means we miss the bytes sent and + // received when submitting a measurement. Such bytes are specifically not + // seen by the experiment specific byte counter. + // + // For this reason, this implementation may be heavily changed/removed. + ContextByteCounting bool + + // DialSaver is the optional saver for dialing events. If not + // set, we will not save any dialing event. + DialSaver *trace.Saver + + // Logger is the optional logger. If not set, there + // will be no logging from the new dialer. + Logger Logger + + // ProxyURL is the optional proxy URL. + ProxyURL *url.URL + + // ReadWriteSaver is like DialSaver but for I/O events. + ReadWriteSaver *trace.Saver +} + +// New creates a new Dialer from the specified config and resolver. +func New(config *Config, resolver Resolver) Dialer { + var d Dialer = systemDialer + d = &errorWrapperDialer{Dialer: d} + if config.Logger != nil { + d = &loggingDialer{Dialer: d, Logger: config.Logger} + } + if config.DialSaver != nil { + d = &saverDialer{Dialer: d, Saver: config.DialSaver} + } + if config.ReadWriteSaver != nil { + d = &saverConnDialer{Dialer: d, Saver: config.ReadWriteSaver} + } + d = &dnsDialer{Resolver: resolver, Dialer: d} + d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d} + if config.ContextByteCounting { + d = &byteCounterDialer{Dialer: d} + } + d = &shapingDialer{Dialer: d} + return d +} diff --git a/internal/engine/netx/dialer/dialer_test.go b/internal/engine/netx/dialer/dialer_test.go new file mode 100644 index 0000000..b861be0 --- /dev/null +++ b/internal/engine/netx/dialer/dialer_test.go @@ -0,0 +1,57 @@ +package dialer + +import ( + "net" + "net/url" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" +) + +func TestNewCreatesTheExpectedChain(t *testing.T) { + saver := &trace.Saver{} + dlr := New(&Config{ + ContextByteCounting: true, + DialSaver: saver, + Logger: log.Log, + ProxyURL: &url.URL{}, + ReadWriteSaver: saver, + }, &net.Resolver{}) + shd, ok := dlr.(*shapingDialer) + if !ok { + t.Fatal("not a shapingDialer") + } + bcd, ok := shd.Dialer.(*byteCounterDialer) + if !ok { + t.Fatal("not a byteCounterDialer") + } + pd, ok := bcd.Dialer.(*proxyDialer) + if !ok { + t.Fatal("not a proxyDialer") + } + dnsd, ok := pd.Dialer.(*dnsDialer) + if !ok { + t.Fatal("not a dnsDialer") + } + scd, ok := dnsd.Dialer.(*saverConnDialer) + if !ok { + t.Fatal("not a saverConnDialer") + } + sd, ok := scd.Dialer.(*saverDialer) + if !ok { + t.Fatal("not a saverDialer") + } + ld, ok := sd.Dialer.(*loggingDialer) + if !ok { + t.Fatal("not a loggingDialer") + } + ewd, ok := ld.Dialer.(*errorWrapperDialer) + if !ok { + t.Fatal("not an errorWrappingDialer") + } + _, ok = ewd.Dialer.(*net.Dialer) + if !ok { + t.Fatal("not a net.Dialer") + } +} diff --git a/internal/engine/netx/dialer/dns.go b/internal/engine/netx/dialer/dns.go index 452eedd..0062954 100644 --- a/internal/engine/netx/dialer/dns.go +++ b/internal/engine/netx/dialer/dns.go @@ -9,26 +9,21 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" ) -// Resolver is the interface we expect from a resolver -type Resolver interface { - LookupHost(ctx context.Context, hostname string) (addrs []string, err error) -} - -// DNSDialer is a dialer that uses the configured Resolver to resolver a +// dnsDialer is a dialer that uses the configured Resolver to resolver a // domain name to IP addresses, and the configured Dialer to connect. -type DNSDialer struct { +type dnsDialer struct { Dialer Resolver Resolver } // DialContext implements Dialer.DialContext. -func (d DNSDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *dnsDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { onlyhost, onlyport, err := net.SplitHostPort(address) if err != nil { return nil, err } var addrs []string - addrs, err = d.LookupHost(ctx, onlyhost) + addrs, err = d.lookupHost(ctx, onlyhost) if err != nil { return nil, err } @@ -44,7 +39,7 @@ func (d DNSDialer) DialContext(ctx context.Context, network, address string) (ne return nil, ReduceErrors(errorslist) } -// ReduceErrors finds a known error in a list of errors since it's probably most relevant +// ReduceErrors finds a known error in a list of errors since it's probably most relevant. func ReduceErrors(errorslist []error) error { if len(errorslist) == 0 { return nil @@ -67,8 +62,8 @@ func ReduceErrors(errorslist []error) error { return errorslist[0] } -// LookupHost implements Resolver.LookupHost -func (d DNSDialer) LookupHost(ctx context.Context, hostname string) ([]string, error) { +// lookupHost performs a domain name resolution. +func (d *dnsDialer) lookupHost(ctx context.Context, hostname string) ([]string, error) { if net.ParseIP(hostname) != nil { return []string{hostname}, nil } diff --git a/internal/engine/netx/dialer/dns_test.go b/internal/engine/netx/dialer/dns_test.go index 6921633..bdbdf68 100644 --- a/internal/engine/netx/dialer/dns_test.go +++ b/internal/engine/netx/dialer/dns_test.go @@ -1,4 +1,4 @@ -package dialer_test +package dialer import ( "context" @@ -7,13 +7,12 @@ import ( "net" "testing" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestDNSDialerNoPort(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: new(net.Resolver)} + dialer := &dnsDialer{Dialer: new(net.Dialer), Resolver: new(net.Resolver)} conn, err := dialer.DialContext(context.Background(), "tcp", "antani.ooni.nu") if err == nil { t.Fatal("expected an error here") @@ -24,10 +23,10 @@ func TestDNSDialerNoPort(t *testing.T) { } func TestDNSDialerLookupHostAddress(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{ + dialer := &dnsDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{ Err: errors.New("mocked error"), }} - addrs, err := dialer.LookupHost(context.Background(), "1.1.1.1") + addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1") if err != nil { t.Fatal(err) } @@ -38,7 +37,7 @@ func TestDNSDialerLookupHostAddress(t *testing.T) { func TestDNSDialerLookupHostFailure(t *testing.T) { expected := errors.New("mocked error") - dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{ + dialer := &dnsDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{ Err: expected, }} conn, err := dialer.DialContext(context.Background(), "tcp", "dns.google.com:853") @@ -60,7 +59,7 @@ func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string } func TestDNSDialerDialForSingleIPFails(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: mockablex.Dialer{ + dialer := &dnsDialer{Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, @@ -75,7 +74,7 @@ func TestDNSDialerDialForSingleIPFails(t *testing.T) { } func TestDNSDialerDialForManyIPFails(t *testing.T) { - dialer := dialer.DNSDialer{ + dialer := &dnsDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF @@ -93,7 +92,7 @@ func TestDNSDialerDialForManyIPFails(t *testing.T) { } func TestDNSDialerDialForManyIPSuccess(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: mockablex.Dialer{ + dialer := &dnsDialer{Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &mockablex.Conn{ MockClose: func() error { @@ -114,12 +113,9 @@ func TestDNSDialerDialForManyIPSuccess(t *testing.T) { conn.Close() } -// TODO(bassosimone): remove the dialID etc since the only -// test still using legacy/netx does not care. - func TestReduceErrors(t *testing.T) { t.Run("no errors", func(t *testing.T) { - result := dialer.ReduceErrors(nil) + result := ReduceErrors(nil) if result != nil { t.Fatal("wrong result") } @@ -127,7 +123,7 @@ func TestReduceErrors(t *testing.T) { t.Run("single error", func(t *testing.T) { err := errors.New("mocked error") - result := dialer.ReduceErrors([]error{err}) + result := ReduceErrors([]error{err}) if result != err { t.Fatal("wrong result") } @@ -136,7 +132,7 @@ func TestReduceErrors(t *testing.T) { t.Run("multiple errors", func(t *testing.T) { err1 := errors.New("mocked error #1") err2 := errors.New("mocked error #2") - result := dialer.ReduceErrors([]error{err1, err2}) + result := ReduceErrors([]error{err1, err2}) if result.Error() != "mocked error #1" { t.Fatal("wrong result") } @@ -151,7 +147,7 @@ func TestReduceErrors(t *testing.T) { Failure: errorx.FailureConnectionRefused, } err4 := errors.New("mocked error #3") - result := dialer.ReduceErrors([]error{err1, err2, err3, err4}) + result := ReduceErrors([]error{err1, err2, err3, err4}) if result.Error() != errorx.FailureConnectionRefused { t.Fatal("wrong result") } diff --git a/internal/engine/netx/dialer/doc.go b/internal/engine/netx/dialer/doc.go new file mode 100644 index 0000000..0ead6b0 --- /dev/null +++ b/internal/engine/netx/dialer/doc.go @@ -0,0 +1,4 @@ +// Package dialer allows you to create a net.Dialer-compatible +// DialContext-enabled dialer with error wrapping, optional logging, +// optional network-events saving, and optional proxying. +package dialer diff --git a/internal/engine/netx/dialer/errorwrapper.go b/internal/engine/netx/dialer/errorwrapper.go index 23f2805..0d11cde 100644 --- a/internal/engine/netx/dialer/errorwrapper.go +++ b/internal/engine/netx/dialer/errorwrapper.go @@ -7,13 +7,13 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" ) -// ErrorWrapperDialer is a dialer that performs err wrapping -type ErrorWrapperDialer struct { +// errorWrapperDialer is a dialer that performs err wrapping +type errorWrapperDialer struct { Dialer } // DialContext implements Dialer.DialContext -func (d ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *errorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) err = errorx.SafeErrWrapperBuilder{ Error: err, @@ -22,16 +22,16 @@ func (d ErrorWrapperDialer) DialContext(ctx context.Context, network, address st if err != nil { return nil, err } - return &ErrorWrapperConn{Conn: conn}, nil + return &errorWrapperConn{Conn: conn}, nil } -// ErrorWrapperConn is a net.Conn that performs error wrapping. -type ErrorWrapperConn struct { +// errorWrapperConn is a net.Conn that performs error wrapping. +type errorWrapperConn struct { net.Conn } // Read implements net.Conn.Read -func (c ErrorWrapperConn) Read(b []byte) (n int, err error) { +func (c *errorWrapperConn) Read(b []byte) (n int, err error) { n, err = c.Conn.Read(b) err = errorx.SafeErrWrapperBuilder{ Error: err, @@ -41,7 +41,7 @@ func (c ErrorWrapperConn) Read(b []byte) (n int, err error) { } // Write implements net.Conn.Write -func (c ErrorWrapperConn) Write(b []byte) (n int, err error) { +func (c *errorWrapperConn) Write(b []byte) (n int, err error) { n, err = c.Conn.Write(b) err = errorx.SafeErrWrapperBuilder{ Error: err, @@ -51,7 +51,7 @@ func (c ErrorWrapperConn) Write(b []byte) (n int, err error) { } // Close implements net.Conn.Close -func (c ErrorWrapperConn) Close() (err error) { +func (c *errorWrapperConn) Close() (err error) { err = c.Conn.Close() err = errorx.SafeErrWrapperBuilder{ Error: err, diff --git a/internal/engine/netx/dialer/errorwrapper_test.go b/internal/engine/netx/dialer/errorwrapper_test.go index fc7f530..91b0ae1 100644 --- a/internal/engine/netx/dialer/errorwrapper_test.go +++ b/internal/engine/netx/dialer/errorwrapper_test.go @@ -1,4 +1,4 @@ -package dialer_test +package dialer import ( "context" @@ -7,14 +7,13 @@ import ( "net" "testing" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestErrorWrapperFailure(t *testing.T) { ctx := context.Background() - d := dialer.ErrorWrapperDialer{Dialer: mockablex.Dialer{ + d := &errorWrapperDialer{Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, @@ -44,7 +43,7 @@ func errorWrapperCheckErr(t *testing.T, err error, op string) { func TestErrorWrapperSuccess(t *testing.T) { ctx := context.Background() - d := dialer.ErrorWrapperDialer{Dialer: mockablex.Dialer{ + d := &errorWrapperDialer{Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &mockablex.Conn{ MockRead: func(b []byte) (int, error) { diff --git a/internal/engine/netx/dialer/example_test.go b/internal/engine/netx/dialer/example_test.go new file mode 100644 index 0000000..c141354 --- /dev/null +++ b/internal/engine/netx/dialer/example_test.go @@ -0,0 +1,30 @@ +package dialer_test + +import ( + "context" + "net" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" +) + +func Example() { + saver := &trace.Saver{} + + dlr := dialer.New(&dialer.Config{ + DialSaver: saver, + Logger: log.Log, + ReadWriteSaver: saver, + }, &net.Resolver{}) + + ctx := context.Background() + conn, err := dlr.DialContext(ctx, "tcp", "8.8.8.8:53") + if err != nil { + log.WithError(err).Fatal("DialContext failed") + } + + // ... use the connection ... + + conn.Close() +} diff --git a/internal/engine/netx/dialer/integration_test.go b/internal/engine/netx/dialer/integration_test.go index df3fe4c..aa26b82 100644 --- a/internal/engine/netx/dialer/integration_test.go +++ b/internal/engine/netx/dialer/integration_test.go @@ -9,19 +9,13 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" ) -func TestDNSDialerSuccess(t *testing.T) { +func TestDialerNewSuccess(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } log.SetLevel(log.DebugLevel) - dialer := dialer.DNSDialer{ - Dialer: dialer.LoggingDialer{ - Dialer: new(net.Dialer), - Logger: log.Log, - }, - Resolver: new(net.Resolver), - } - txp := &http.Transport{DialContext: dialer.DialContext} + d := dialer.New(&dialer.Config{Logger: log.Log}, &net.Resolver{}) + txp := &http.Transport{DialContext: d.DialContext} client := &http.Client{Transport: txp} resp, err := client.Get("http://www.google.com") if err != nil { diff --git a/internal/engine/netx/dialer/logging.go b/internal/engine/netx/dialer/logging.go index 7c448ed..cb53b93 100644 --- a/internal/engine/netx/dialer/logging.go +++ b/internal/engine/netx/dialer/logging.go @@ -6,20 +6,14 @@ import ( "time" ) -// Logger is the logger assumed by this package -type Logger interface { - Debugf(format string, v ...interface{}) - Debug(message string) -} - -// LoggingDialer is a Dialer with logging -type LoggingDialer struct { +// loggingDialer is a Dialer with logging +type loggingDialer struct { Dialer Logger Logger } // DialContext implements Dialer.DialContext -func (d LoggingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *loggingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { d.Logger.Debugf("dial %s/%s...", address, network) start := time.Now() conn, err := d.Dialer.DialContext(ctx, network, address) diff --git a/internal/engine/netx/dialer/logging_test.go b/internal/engine/netx/dialer/logging_test.go index bfbf1e4..f1b88b3 100644 --- a/internal/engine/netx/dialer/logging_test.go +++ b/internal/engine/netx/dialer/logging_test.go @@ -1,4 +1,4 @@ -package dialer_test +package dialer import ( "context" @@ -8,12 +8,11 @@ import ( "testing" "github.com/apex/log" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestLoggingDialerFailure(t *testing.T) { - d := dialer.LoggingDialer{ + d := &loggingDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF diff --git a/internal/engine/netx/dialer/proxy.go b/internal/engine/netx/dialer/proxy.go index 3d07320..ba4c768 100644 --- a/internal/engine/netx/dialer/proxy.go +++ b/internal/engine/netx/dialer/proxy.go @@ -9,10 +9,10 @@ import ( "golang.org/x/net/proxy" ) -// ProxyDialer is a dialer that uses a proxy. If the ProxyURL is not configured, this +// proxyDialer is a dialer that uses a proxy. If the ProxyURL is not configured, this // dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally // create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer. -type ProxyDialer struct { +type proxyDialer struct { Dialer ProxyURL *url.URL } @@ -21,7 +21,7 @@ type ProxyDialer struct { var ErrProxyUnsupportedScheme = errors.New("proxy: unsupported scheme") // DialContext implements Dialer.DialContext -func (d ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { url := d.ProxyURL if url == nil { return d.Dialer.DialContext(ctx, network, address) @@ -31,11 +31,11 @@ func (d ProxyDialer) DialContext(ctx context.Context, network, address string) ( } // the code at proxy/socks5.go never fails; see https://git.io/JfJ4g child, _ := proxy.SOCKS5( - network, url.Host, nil, proxyDialerWrapper{d.Dialer}) + network, url.Host, nil, &proxyDialerWrapper{d.Dialer}) return d.dial(ctx, child, network, address) } -func (d ProxyDialer) dial( +func (d *proxyDialer) dial( ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) { cd := child.(proxy.ContextDialer) // will work return cd.DialContext(ctx, network, address) @@ -50,6 +50,6 @@ type proxyDialerWrapper struct { Dialer } -func (d proxyDialerWrapper) Dial(network, address string) (net.Conn, error) { +func (d *proxyDialerWrapper) Dial(network, address string) (net.Conn, error) { panic(errors.New("proxyDialerWrapper.Dial should not be called directly")) } diff --git a/internal/engine/netx/dialer/proxy_test.go b/internal/engine/netx/dialer/proxy_test.go index 7172ee9..8ad7ee3 100644 --- a/internal/engine/netx/dialer/proxy_test.go +++ b/internal/engine/netx/dialer/proxy_test.go @@ -13,7 +13,7 @@ import ( func TestProxyDialerDialContextNoProxyURL(t *testing.T) { expected := errors.New("mocked error") - d := ProxyDialer{ + d := &proxyDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, expected @@ -30,7 +30,7 @@ func TestProxyDialerDialContextNoProxyURL(t *testing.T) { } func TestProxyDialerDialContextInvalidScheme(t *testing.T) { - d := ProxyDialer{ + d := &proxyDialer{ ProxyURL: &url.URL{Scheme: "antani"}, } conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") @@ -44,7 +44,7 @@ func TestProxyDialerDialContextInvalidScheme(t *testing.T) { func TestProxyDialerDialContextWithEOF(t *testing.T) { const expect = "10.0.0.1:9050" - d := ProxyDialer{ + d := &proxyDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { if address != expect { diff --git a/internal/engine/netx/dialer/saver.go b/internal/engine/netx/dialer/saver.go index 82592ef..682fab2 100644 --- a/internal/engine/netx/dialer/saver.go +++ b/internal/engine/netx/dialer/saver.go @@ -9,14 +9,14 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) -// SaverDialer saves events occurring during the dial -type SaverDialer struct { +// saverDialer saves events occurring during the dial +type saverDialer struct { Dialer Saver *trace.Saver } // DialContext implements Dialer.DialContext -func (d SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *saverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { start := time.Now() conn, err := d.Dialer.DialContext(ctx, network, address) stop := time.Now() @@ -31,20 +31,20 @@ func (d SaverDialer) DialContext(ctx context.Context, network, address string) ( return conn, err } -// SaverConnDialer wraps the returned connection such that we +// saverConnDialer wraps the returned connection such that we // collect all the read/write events that occur. -type SaverConnDialer struct { +type saverConnDialer struct { Dialer Saver *trace.Saver } // DialContext implements Dialer.DialContext -func (d SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *saverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { return nil, err } - return saverConn{saver: d.Saver, Conn: conn}, nil + return &saverConn{saver: d.Saver, Conn: conn}, nil } type saverConn struct { @@ -52,7 +52,7 @@ type saverConn struct { saver *trace.Saver } -func (c saverConn) Read(p []byte) (int, error) { +func (c *saverConn) Read(p []byte) (int, error) { start := time.Now() count, err := c.Conn.Read(p) stop := time.Now() @@ -67,7 +67,7 @@ func (c saverConn) Read(p []byte) (int, error) { return count, err } -func (c saverConn) Write(p []byte) (int, error) { +func (c *saverConn) Write(p []byte) (int, error) { start := time.Now() count, err := c.Conn.Write(p) stop := time.Now() @@ -82,5 +82,6 @@ func (c saverConn) Write(p []byte) (int, error) { return count, err } -var _ Dialer = SaverDialer{} -var _ net.Conn = saverConn{} +var _ Dialer = &saverDialer{} +var _ Dialer = &saverConnDialer{} +var _ net.Conn = &saverConn{} diff --git a/internal/engine/netx/dialer/saver_test.go b/internal/engine/netx/dialer/saver_test.go index 3dd59eb..4dbee07 100644 --- a/internal/engine/netx/dialer/saver_test.go +++ b/internal/engine/netx/dialer/saver_test.go @@ -1,4 +1,4 @@ -package dialer_test +package dialer import ( "context" @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" @@ -17,7 +16,7 @@ import ( func TestSaverDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} - dlr := dialer.SaverDialer{ + dlr := &saverDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, expected @@ -59,7 +58,7 @@ func TestSaverDialerFailure(t *testing.T) { func TestSaverConnDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} - dlr := dialer.SaverConnDialer{ + dlr := &saverConnDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, expected @@ -78,8 +77,8 @@ func TestSaverConnDialerFailure(t *testing.T) { func TestSaverConnDialerSuccess(t *testing.T) { saver := &trace.Saver{} - dlr := dialer.SaverConnDialer{ - Dialer: dialer.SaverDialer{ + dlr := &saverConnDialer{ + Dialer: &saverDialer{ Dialer: mockablex.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &mockablex.Conn{ diff --git a/internal/engine/netx/dialer/shaping_disabled.go b/internal/engine/netx/dialer/shaping_disabled.go index 0d07bf6..1562921 100644 --- a/internal/engine/netx/dialer/shaping_disabled.go +++ b/internal/engine/netx/dialer/shaping_disabled.go @@ -7,15 +7,15 @@ import ( "net" ) -// ShapingDialer ensures we don't use too much bandwidth +// shapingDialer ensures we don't use too much bandwidth // when using integration tests at GitHub. To select // the implementation with shaping use `-tags shaping`. -type ShapingDialer struct { +type shapingDialer struct { Dialer } // DialContext implements Dialer.DialContext -func (d ShapingDialer) DialContext( +func (d *shapingDialer) DialContext( ctx context.Context, network, address string) (net.Conn, error) { return d.Dialer.DialContext(ctx, network, address) } diff --git a/internal/engine/netx/dialer/shaping_enabled.go b/internal/engine/netx/dialer/shaping_enabled.go index 00b27ca..7e0917c 100644 --- a/internal/engine/netx/dialer/shaping_enabled.go +++ b/internal/engine/netx/dialer/shaping_enabled.go @@ -8,15 +8,15 @@ import ( "time" ) -// ShapingDialer ensures we don't use too much bandwidth +// shapingDialer ensures we don't use too much bandwidth // when using integration tests at GitHub. To select // the implementation with shaping use `-tags shaping`. -type ShapingDialer struct { +type shapingDialer struct { Dialer } // DialContext implements Dialer.DialContext -func (d ShapingDialer) DialContext( +func (d *shapingDialer) DialContext( ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { @@ -29,12 +29,12 @@ type shapingConn struct { net.Conn } -func (c shapingConn) Read(p []byte) (int, error) { +func (c *shapingConn) Read(p []byte) (int, error) { time.Sleep(100 * time.Millisecond) return c.Conn.Read(p) } -func (c shapingConn) Write(p []byte) (int, error) { +func (c *shapingConn) Write(p []byte) (int, error) { time.Sleep(100 * time.Millisecond) return c.Conn.Write(p) } diff --git a/internal/engine/netx/dialer/shaping_test.go b/internal/engine/netx/dialer/shaping_test.go index e910c04..ffad493 100644 --- a/internal/engine/netx/dialer/shaping_test.go +++ b/internal/engine/netx/dialer/shaping_test.go @@ -1,20 +1,14 @@ -package dialer_test +package dialer import ( "net" "net/http" "testing" - - "github.com/ooni/probe-cli/v3/internal/engine/netx" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" ) -func TestGood(t *testing.T) { - txp := netx.NewHTTPTransport(netx.Config{ - Dialer: dialer.ShapingDialer{ - Dialer: new(net.Dialer), - }, - }) +func TestShapingDialerGood(t *testing.T) { + d := &shapingDialer{Dialer: &net.Dialer{}} + txp := &http.Transport{DialContext: d.DialContext} client := &http.Client{Transport: txp} resp, err := client.Get("https://www.google.com/") if err != nil { diff --git a/internal/engine/netx/dialer/system.go b/internal/engine/netx/dialer/system.go index ea22148..37f2258 100644 --- a/internal/engine/netx/dialer/system.go +++ b/internal/engine/netx/dialer/system.go @@ -1,24 +1,12 @@ package dialer import ( - "context" "net" "time" ) -// underlyingDialer is the underlying net.Dialer. -var underlyingDialer = &net.Dialer{ +// systemDialer is the underlying net.Dialer. +var systemDialer = &net.Dialer{ Timeout: 15 * time.Second, KeepAlive: 15 * time.Second, } - -// SystemDialer is the system dialer. -type SystemDialer struct{} - -// DialContext implements Dialer.DialContext -func (d SystemDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return underlyingDialer.DialContext(ctx, network, address) -} - -// Default is the dialer we use by default. -var Default = SystemDialer{} diff --git a/internal/engine/netx/dialer/system_test.go b/internal/engine/netx/dialer/system_test.go index 08bbbcf..a21073a 100644 --- a/internal/engine/netx/dialer/system_test.go +++ b/internal/engine/netx/dialer/system_test.go @@ -11,7 +11,7 @@ import ( func TestSystemDialerWorks(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // fail immediately - conn, err := Default.DialContext(ctx, "tcp", "8.8.8.8:853") + conn, err := systemDialer.DialContext(ctx, "tcp", "8.8.8.8:853") if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { t.Fatal("not the error we expected", err) } @@ -20,9 +20,9 @@ func TestSystemDialerWorks(t *testing.T) { } } -func TestUnderlyingDialerHasTimeout(t *testing.T) { +func TestSystemDialerHasTimeout(t *testing.T) { expected := 15 * time.Second - if underlyingDialer.Timeout != expected { + if systemDialer.Timeout != expected { t.Fatal("unexpected timeout value") } } diff --git a/internal/engine/netx/httptransport/http3transport_test.go b/internal/engine/netx/httptransport/http3transport_test.go index 7042385..3683c0b 100644 --- a/internal/engine/netx/httptransport/http3transport_test.go +++ b/internal/engine/netx/httptransport/http3transport_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "errors" + "net" "net/http" "strings" "testing" @@ -42,7 +43,9 @@ func TestHTTP3TransportSNI(t *testing.T) { namech := make(chan string, 1) sni := "sni.org" txp := httptransport.NewHTTP3Transport(httptransport.Config{ - Dialer: dialer.Default, QUICDialer: MockSNIQUICDialer{namech: namech}, TLSConfig: &tls.Config{ServerName: sni}}) + Dialer: dialer.New(&dialer.Config{}, &net.Resolver{}), + QUICDialer: MockSNIQUICDialer{namech: namech}, + TLSConfig: &tls.Config{ServerName: sni}}) req, err := http.NewRequest("GET", "https://www.google.com", nil) if err != nil { t.Fatal(err) @@ -67,7 +70,9 @@ func TestHTTP3TransportSNINoVerify(t *testing.T) { namech := make(chan string, 1) sni := "sni.org" txp := httptransport.NewHTTP3Transport(httptransport.Config{ - Dialer: dialer.Default, QUICDialer: MockSNIQUICDialer{namech: namech}, TLSConfig: &tls.Config{ServerName: sni, InsecureSkipVerify: true}}) + Dialer: dialer.New(&dialer.Config{}, &net.Resolver{}), + QUICDialer: MockSNIQUICDialer{namech: namech}, + TLSConfig: &tls.Config{ServerName: sni, InsecureSkipVerify: true}}) req, err := http.NewRequest("GET", "https://www.google.com", nil) if err != nil { t.Fatal(err) @@ -89,7 +94,9 @@ func TestHTTP3TransportCABundle(t *testing.T) { certch := make(chan *x509.CertPool, 1) certpool := x509.NewCertPool() txp := httptransport.NewHTTP3Transport(httptransport.Config{ - Dialer: dialer.Default, QUICDialer: MockCertQUICDialer{certch: certch}, TLSConfig: &tls.Config{RootCAs: certpool}}) + Dialer: dialer.New(&dialer.Config{}, &net.Resolver{}), + QUICDialer: MockCertQUICDialer{certch: certch}, + TLSConfig: &tls.Config{RootCAs: certpool}}) req, err := http.NewRequest("GET", "https://www.google.com", nil) if err != nil { t.Fatal(err) @@ -114,7 +121,8 @@ func TestHTTP3TransportCABundle(t *testing.T) { func TestUnitHTTP3TransportSuccess(t *testing.T) { txp := httptransport.NewHTTP3Transport(httptransport.Config{ - Dialer: dialer.Default, QUICDialer: MockQUICDialer{}}) + Dialer: dialer.New(&dialer.Config{}, &net.Resolver{}), + QUICDialer: MockQUICDialer{}}) req, err := http.NewRequest("GET", "https://www.google.com", nil) if err != nil { @@ -134,7 +142,8 @@ func TestUnitHTTP3TransportSuccess(t *testing.T) { func TestUnitHTTP3TransportFailure(t *testing.T) { txp := httptransport.NewHTTP3Transport(httptransport.Config{ - Dialer: dialer.Default, QUICDialer: MockQUICDialer{}}) + Dialer: dialer.New(&dialer.Config{}, &net.Resolver{}), + QUICDialer: MockQUICDialer{}}) ctx, cancel := context.WithCancel(context.Background()) cancel() // so that the request immediately fails diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 8951da7..3a228c6 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -146,24 +146,13 @@ func NewDialer(config Config) Dialer { if config.FullResolver == nil { config.FullResolver = NewResolver(config) } - var d Dialer = dialer.Default - d = dialer.ErrorWrapperDialer{Dialer: d} - if config.Logger != nil { - d = dialer.LoggingDialer{Dialer: d, Logger: config.Logger} - } - if config.DialSaver != nil { - d = dialer.SaverDialer{Dialer: d, Saver: config.DialSaver} - } - if config.ReadWriteSaver != nil { - d = dialer.SaverConnDialer{Dialer: d, Saver: config.ReadWriteSaver} - } - d = dialer.DNSDialer{Resolver: config.FullResolver, Dialer: d} - d = dialer.ProxyDialer{ProxyURL: config.ProxyURL, Dialer: d} - if config.ContextByteCounting { - d = dialer.ByteCounterDialer{Dialer: d} - } - d = dialer.ShapingDialer{Dialer: d} - return d + return dialer.New(&dialer.Config{ + ContextByteCounting: config.ContextByteCounting, + DialSaver: config.DialSaver, + Logger: config.Logger, + ProxyURL: config.ProxyURL, + ReadWriteSaver: config.ReadWriteSaver, + }, config.FullResolver) } // NewQUICDialer creates a new DNS Dialer for QUIC, with the resolver from the specified config diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 6504277..973db12 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -10,7 +10,6 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport" "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" @@ -210,257 +209,6 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { } } -func TestNewDialerVanilla(t *testing.T) { - d := netx.NewDialer(netx.Config{}) - sd, ok := d.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - pd, ok := sd.Dialer.(dialer.ProxyDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if pd.ProxyURL != nil { - t.Fatal("not the proxy URL we expected") - } - dnsd, ok := pd.Dialer.(dialer.DNSDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if dnsd.Resolver == nil { - t.Fatal("not the resolver we expected") - } - ir, ok := dnsd.Resolver.(resolver.IDNAResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { - t.Fatal("not the resolver we expected") - } - ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := ewd.Dialer.(dialer.SystemDialer); !ok { - t.Fatal("not the dialer we expected") - } -} - -func TestNewDialerWithResolver(t *testing.T) { - d := netx.NewDialer(netx.Config{ - FullResolver: resolver.BogonResolver{ - // not initialized because it doesn't matter in this context - }, - }) - sd, ok := d.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - pd, ok := sd.Dialer.(dialer.ProxyDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if pd.ProxyURL != nil { - t.Fatal("not the proxy URL we expected") - } - dnsd, ok := pd.Dialer.(dialer.DNSDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if dnsd.Resolver == nil { - t.Fatal("not the resolver we expected") - } - if _, ok := dnsd.Resolver.(resolver.BogonResolver); !ok { - t.Fatal("not the resolver we expected") - } - ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := ewd.Dialer.(dialer.SystemDialer); !ok { - t.Fatal("not the dialer we expected") - } -} - -func TestNewDialerWithLogger(t *testing.T) { - d := netx.NewDialer(netx.Config{ - Logger: log.Log, - }) - sd, ok := d.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - pd, ok := sd.Dialer.(dialer.ProxyDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if pd.ProxyURL != nil { - t.Fatal("not the proxy URL we expected") - } - dnsd, ok := pd.Dialer.(dialer.DNSDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if dnsd.Resolver == nil { - t.Fatal("not the resolver we expected") - } - ir, ok := dnsd.Resolver.(resolver.IDNAResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - if _, ok := ir.Resolver.(resolver.LoggingResolver); !ok { - t.Fatal("not the resolver we expected") - } - ld, ok := dnsd.Dialer.(dialer.LoggingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if ld.Logger != log.Log { - t.Fatal("not the logger we expected") - } - ewd, ok := ld.Dialer.(dialer.ErrorWrapperDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := ewd.Dialer.(dialer.SystemDialer); !ok { - t.Fatal("not the dialer we expected") - } -} - -func TestNewDialerWithDialSaver(t *testing.T) { - saver := new(trace.Saver) - d := netx.NewDialer(netx.Config{ - DialSaver: saver, - }) - sd, ok := d.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - pd, ok := sd.Dialer.(dialer.ProxyDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if pd.ProxyURL != nil { - t.Fatal("not the proxy URL we expected") - } - dnsd, ok := pd.Dialer.(dialer.DNSDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if dnsd.Resolver == nil { - t.Fatal("not the resolver we expected") - } - ir, ok := dnsd.Resolver.(resolver.IDNAResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { - t.Fatal("not the resolver we expected") - } - sad, ok := dnsd.Dialer.(dialer.SaverDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if sad.Saver != saver { - t.Fatal("not the logger we expected") - } - ewd, ok := sad.Dialer.(dialer.ErrorWrapperDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := ewd.Dialer.(dialer.SystemDialer); !ok { - t.Fatal("not the dialer we expected") - } -} - -func TestNewDialerWithReadWriteSaver(t *testing.T) { - saver := new(trace.Saver) - d := netx.NewDialer(netx.Config{ - ReadWriteSaver: saver, - }) - sd, ok := d.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - pd, ok := sd.Dialer.(dialer.ProxyDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if pd.ProxyURL != nil { - t.Fatal("not the proxy URL we expected") - } - dnsd, ok := pd.Dialer.(dialer.DNSDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if dnsd.Resolver == nil { - t.Fatal("not the resolver we expected") - } - ir, ok := dnsd.Resolver.(resolver.IDNAResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { - t.Fatal("not the resolver we expected") - } - scd, ok := dnsd.Dialer.(dialer.SaverConnDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if scd.Saver != saver { - t.Fatal("not the logger we expected") - } - ewd, ok := scd.Dialer.(dialer.ErrorWrapperDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := ewd.Dialer.(dialer.SystemDialer); !ok { - t.Fatal("not the dialer we expected") - } -} - -func TestNewDialerWithContextByteCounting(t *testing.T) { - d := netx.NewDialer(netx.Config{ - ContextByteCounting: true, - }) - sd, ok := d.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - bcd, ok := sd.Dialer.(dialer.ByteCounterDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - pd, ok := bcd.Dialer.(dialer.ProxyDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if pd.ProxyURL != nil { - t.Fatal("not the proxy URL we expected") - } - dnsd, ok := pd.Dialer.(dialer.DNSDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if dnsd.Resolver == nil { - t.Fatal("not the resolver we expected") - } - ir, ok := dnsd.Resolver.(resolver.IDNAResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { - t.Fatal("not the resolver we expected") - } - ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := ewd.Dialer.(dialer.SystemDialer); !ok { - t.Fatal("not the dialer we expected") - } -} - func TestNewTLSDialerVanilla(t *testing.T) { td := netx.NewTLSDialer(netx.Config{}) rtd, ok := td.(tlsdialer.TLSDialer) @@ -479,13 +227,6 @@ func TestNewTLSDialerVanilla(t *testing.T) { if rtd.Dialer == nil { t.Fatal("invalid Dialer") } - sd, ok := rtd.Dialer.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { - t.Fatal("not the Dialer we expected") - } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } @@ -519,13 +260,6 @@ func TestNewTLSDialerWithConfig(t *testing.T) { if rtd.Dialer == nil { t.Fatal("invalid Dialer") } - sd, ok := rtd.Dialer.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { - t.Fatal("not the Dialer we expected") - } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } @@ -562,13 +296,6 @@ func TestNewTLSDialerWithLogging(t *testing.T) { if rtd.Dialer == nil { t.Fatal("invalid Dialer") } - sd, ok := rtd.Dialer.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { - t.Fatal("not the Dialer we expected") - } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } @@ -613,13 +340,6 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if rtd.Dialer == nil { t.Fatal("invalid Dialer") } - sd, ok := rtd.Dialer.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { - t.Fatal("not the Dialer we expected") - } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } @@ -664,13 +384,6 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { if rtd.Dialer == nil { t.Fatal("invalid Dialer") } - sd, ok := rtd.Dialer.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { - t.Fatal("not the Dialer we expected") - } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } @@ -710,13 +423,6 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { if rtd.Dialer == nil { t.Fatal("invalid Dialer") } - sd, ok := rtd.Dialer.(dialer.ShapingDialer) - if !ok { - t.Fatal("not the dialer we expected") - } - if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { - t.Fatal("not the Dialer we expected") - } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 7cbaa16..8ad79c4 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -23,10 +23,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { saver := &trace.Saver{} tlsdlr := tlsdialer.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, - Dialer: dialer.SaverConnDialer{ - Dialer: new(net.Dialer), - Saver: saver, - }, + Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, Saver: saver,