diff --git a/docs/design/dd-003-step-by-step.md b/docs/design/dd-003-step-by-step.md index 258234b..4baa56a 100644 --- a/docs/design/dd-003-step-by-step.md +++ b/docs/design/dd-003-step-by-step.md @@ -874,7 +874,7 @@ mechanisms if the context is too bad): // lookupHost issues a lookup host query for the specified qtype (e.g., dns.A). func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, - qtype uint16, out chan\<- \*parallelResolverResult) { + qtype uint16, out chan<- *parallelResolverResult) { encoder := &DNSEncoderMiekg{} query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) + started := time.Now() @@ -1329,7 +1329,7 @@ const webDomain = "web.telegram.org" // This method does not return any value and writes results directly inside // the test keys, which have thread safe methods for that. func (mx *Measurer) measureWebEndpointHTTPS(ctx context.Context, wg *sync.WaitGroup, -logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) { +logger model.Logger, zeroTime time.Time, tk *TestKeys, address string) { // 0. setup const webTimeout = 7 * time.Second ctx, cancel := context.WithTimeout(ctx, webTimeout) @@ -1344,8 +1344,8 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) { // 1. establish a TCP connection with the endpoint // dialer := nextlite.NewDialerwithoutResolver(logger) // --- (removed line) - trace := measurexlite.NewTrace(index, logger, zeroTime) // +++ (added line) - dialer := trace.NewDialerWithoutResolver() // +++ (...) + trace := measurexlite.NewTrace(index, zeroTime) // +++ (added line) + dialer := trace.NewDialerWithoutResolver(logger) // +++ (...) defer tk.addTCPConnectResults(trace.TCPConnectResults()) // +++ (...) conn, err := dialer.DialContext(ctx, "tcp", endpoint) @@ -1366,7 +1366,7 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) { // thx := netxlite.NewTLSHandshakerStdlib(logger) // --- conn = trace.WrapConn(conn) // +++ defer tk.addNetworkEvents(trace.NetworkEvents()) // +++ - thx := trace.NewTLSHandshakerStdlib() // +++ + thx := trace.NewTLSHandshakerStdlib(logger) // +++ defer tk.addTLSHandshakeResult(trace.TLSHandshakeResults()) // +++ config := &tls.Config{ @@ -1498,23 +1498,22 @@ type Trace struct { /* ... */ } var _ model.Trace = &Trace{} -func NewTrace(index int64, logger model.Logger, zeroTime time.Time) *Trace { +func NewTrace(index int64, zeroTime time.Time) *Trace { const ( tlsHandshakeBuffer = 16, // ... ) return &Trace{ Index: index, - Logger: logger, TLS: make(chan *model.ArchivalTLSOrQUICHandshakeResult, tlsHandshakeBuffer), ZeroTime: zeroTime, // ... } } -func (tx *Trace) NewTLSHandshakerStdlib() model.TLSHandshaker { +func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { return &tlsHandshakerTrace{ - TLSHandshaker: netxlite.NewTLSHandshakerStdlib(tx.Logger), + TLSHandshaker: netxlite.NewTLSHandshakerStdlib(dl), Trace: tx, } } @@ -1524,8 +1523,8 @@ type tlsHandshakerTrace { /* ... */ } var _ model.TLSHandshaker = &tlsHandshakerTrace{} func (thx *tlsHandshakerTrace) Handshake(ctx context.Context, - conn net.Conn, config \*tls.Config) (net.Conn, tls.ConnectionState, error) { - ctx = netxlite.ContextWithTrace(ctx, thx.Trace) // <- here we setup the context magic + conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + ctx = netxlite.ContextWithTraceOrDefault(ctx, thx.Trace) // <- here we setup the context magic return thx.TLSHandshaker.Handshake(ctx, conn, config) } @@ -1540,7 +1539,7 @@ func (tx *Trace) OnTLSHandshake(started time.Time, remoteAddr string, } } -func (tx *Trace) TLSHandshakeResults() (out []\*model.ArchivalTLSOrQUICHandshakeResult) { +func (tx *Trace) TLSHandshakeResults() (out []*model.ArchivalTLSOrQUICHandshakeResult) { for { select { case ev := <-tx.TLS: @@ -1575,7 +1574,7 @@ func (thx *tlsHandshakerConfigurable) Handshake(ctx context.Context, state := tlsconn.connectionState() trace.OnTLSHandshake(started, remoteAddr, config, // +++ state, nil, finished) // +++ - return tlsconn, state, nil + return tlsconn, state, nil } func ContextTraceOrDefault(ctx context.Context) model.Trace { /* ... */ } diff --git a/internal/engine/experiment/tcpping/tcpping.go b/internal/engine/experiment/tcpping/tcpping.go index ef56a01..821aea8 100644 --- a/internal/engine/experiment/tcpping/tcpping.go +++ b/internal/engine/experiment/tcpping/tcpping.go @@ -10,13 +10,13 @@ import ( "net/url" "time" - "github.com/ooni/probe-cli/v3/internal/measurex" + "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" ) const ( testName = "tcpping" - testVersion = "0.1.0" + testVersion = "0.2.0" ) // Config contains the experiment configuration. @@ -49,7 +49,7 @@ type TestKeys struct { // SinglePing contains the results of a single ping. type SinglePing struct { - TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"` + TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"` } // Measurer performs the measurement. @@ -103,42 +103,47 @@ func (m *Measurer) Run( } tk := new(TestKeys) measurement.TestKeys = tk - out := make(chan *measurex.EndpointMeasurement) - mxmx := measurex.NewMeasurerWithDefaultSettings() - go m.tcpPingLoop(ctx, mxmx, parsed.Host, out) + out := make(chan *SinglePing) + go m.tcpPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out) for len(tk.Pings) < int(m.config.repetitions()) { - meas := <-out - tk.Pings = append(tk.Pings, &SinglePing{ - TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect), - }) + tk.Pings = append(tk.Pings, <-out) } return nil // return nil so we always submit the measurement } // tcpPingLoop sends all the ping requests and emits the results onto the out channel. -func (m *Measurer) tcpPingLoop(ctx context.Context, mxmx *measurex.Measurer, - address string, out chan<- *measurex.EndpointMeasurement) { +func (m *Measurer) tcpPingLoop(ctx context.Context, zeroTime time.Time, + logger model.Logger, address string, out chan<- *SinglePing) { ticker := time.NewTicker(m.config.delay()) defer ticker.Stop() for i := int64(0); i < m.config.repetitions(); i++ { - go m.tcpPingAsync(ctx, mxmx, address, out) + go m.tcpPingAsync(ctx, i, zeroTime, logger, address, out) <-ticker.C } } // tcpPingAsync performs a TCP ping and emits the result onto the out channel. -func (m *Measurer) tcpPingAsync(ctx context.Context, mxmx *measurex.Measurer, - address string, out chan<- *measurex.EndpointMeasurement) { - out <- m.tcpConnect(ctx, mxmx, address) +func (m *Measurer) tcpPingAsync(ctx context.Context, index int64, + zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) { + out <- m.tcpConnect(ctx, index, zeroTime, logger, address) } // tcpConnect performs a TCP connect and returns the result to the caller. -func (m *Measurer) tcpConnect(ctx context.Context, mxmx *measurex.Measurer, - address string) *measurex.EndpointMeasurement { +func (m *Measurer) tcpConnect(ctx context.Context, index int64, + zeroTime time.Time, logger model.Logger, address string) *SinglePing { // TODO(bassosimone): make the timeout user-configurable ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - return mxmx.TCPConnect(ctx, address) + trace := measurexlite.NewTrace(index, zeroTime) + dialer := trace.NewDialerWithoutResolver(logger) + ol := measurexlite.NewOperationLogger(logger, "TCPPing #%d %s", index, address) + conn, err := dialer.DialContext(ctx, "tcp", address) + ol.Stop(err) + measurexlite.MaybeClose(conn) + sp := &SinglePing{ + TCPConnect: <-trace.TCPConnect, + } + return sp } // NewExperimentMeasurer creates a new ExperimentMeasurer. diff --git a/internal/engine/experiment/tcpping/tcpping_test.go b/internal/engine/experiment/tcpping/tcpping_test.go index ec7f915..faec432 100644 --- a/internal/engine/experiment/tcpping/tcpping_test.go +++ b/internal/engine/experiment/tcpping/tcpping_test.go @@ -40,14 +40,16 @@ func TestMeasurer_run(t *testing.T) { if m.ExperimentName() != "tcpping" { t.Fatal("invalid experiment name") } - if m.ExperimentVersion() != "0.1.0" { + if m.ExperimentVersion() != "0.2.0" { t.Fatal("invalid experiment version") } ctx := context.Background() meas := &model.Measurement{ Input: model.MeasurementTarget(input), } - sess := &mockable.Session{} + sess := &mockable.Session{ + MockableLogger: model.DiscardLogger, + } callbacks := model.NewPrinterCallbacks(model.DiscardLogger) err := m.Run(ctx, sess, meas, callbacks) return meas, m, err diff --git a/internal/engine/experiment/tlsping/tlsping.go b/internal/engine/experiment/tlsping/tlsping.go index 8049cc9..5e2a71a 100644 --- a/internal/engine/experiment/tlsping/tlsping.go +++ b/internal/engine/experiment/tlsping/tlsping.go @@ -13,14 +13,14 @@ import ( "strings" "time" - "github.com/ooni/probe-cli/v3/internal/measurex" + "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) const ( testName = "tlsping" - testVersion = "0.1.0" + testVersion = "0.2.0" ) // Config contains the experiment configuration. @@ -77,9 +77,9 @@ type TestKeys struct { // SinglePing contains the results of a single ping. type SinglePing struct { - NetworkEvents []*measurex.ArchivalNetworkEvent `json:"network_events"` - TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"` - TLSHandshakes []*measurex.ArchivalQUICTLSHandshakeEvent `json:"tls_handshakes"` + NetworkEvents []*model.ArchivalNetworkEvent `json:"network_events"` + TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"` + TLSHandshake *model.ArchivalTLSOrQUICHandshakeResult `json:"tls_handshake"` } // Measurer performs the measurement. @@ -133,49 +133,67 @@ func (m *Measurer) Run( } tk := new(TestKeys) measurement.TestKeys = tk - out := make(chan *measurex.EndpointMeasurement) - mxmx := measurex.NewMeasurerWithDefaultSettings() - go m.tlsPingLoop(ctx, mxmx, parsed.Host, out) + out := make(chan *SinglePing) + go m.tlsPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out) for len(tk.Pings) < int(m.config.repetitions()) { - meas := <-out - tk.Pings = append(tk.Pings, &SinglePing{ - NetworkEvents: measurex.NewArchivalNetworkEventList(meas.ReadWrite), - TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect), - TLSHandshakes: measurex.NewArchivalQUICTLSHandshakeEventList(meas.TLSHandshake), - }) + tk.Pings = append(tk.Pings, <-out) } return nil // return nil so we always submit the measurement } // tlsPingLoop sends all the ping requests and emits the results onto the out channel. -func (m *Measurer) tlsPingLoop(ctx context.Context, mxmx *measurex.Measurer, - address string, out chan<- *measurex.EndpointMeasurement) { +func (m *Measurer) tlsPingLoop(ctx context.Context, zeroTime time.Time, + logger model.Logger, address string, out chan<- *SinglePing) { ticker := time.NewTicker(m.config.delay()) defer ticker.Stop() for i := int64(0); i < m.config.repetitions(); i++ { - go m.tlsPingAsync(ctx, mxmx, address, out) + go m.tlsPingAsync(ctx, i, zeroTime, logger, address, out) <-ticker.C } } // tlsPingAsync performs a TLS ping and emits the result onto the out channel. -func (m *Measurer) tlsPingAsync(ctx context.Context, mxmx *measurex.Measurer, - address string, out chan<- *measurex.EndpointMeasurement) { - out <- m.tlsConnectAndHandshake(ctx, mxmx, address) +func (m *Measurer) tlsPingAsync(ctx context.Context, index int64, + zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) { + out <- m.tlsConnectAndHandshake(ctx, index, zeroTime, logger, address) } // tlsConnectAndHandshake performs a TCP connect followed by a TLS handshake // and returns the results of these operations to the caller. -func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, mxmx *measurex.Measurer, - address string) *measurex.EndpointMeasurement { +func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, index int64, + zeroTime time.Time, logger model.Logger, address string) *SinglePing { // TODO(bassosimone): make the timeout user-configurable ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - return mxmx.TLSConnectAndHandshake(ctx, address, &tls.Config{ - NextProtos: strings.Split(m.config.alpn(), " "), + sp := &SinglePing{ + NetworkEvents: []*model.ArchivalNetworkEvent{}, + TCPConnect: nil, + TLSHandshake: nil, + } + trace := measurexlite.NewTrace(index, zeroTime) + dialer := trace.NewDialerWithoutResolver(logger) + alpn := strings.Split(m.config.alpn(), " ") + sni := m.config.sni(address) + ol := measurexlite.NewOperationLogger(logger, "TLSPing #%d %s %s %v", index, address, sni, alpn) + conn, err := dialer.DialContext(ctx, "tcp", address) + sp.TCPConnect = <-trace.TCPConnect + if err != nil { + ol.Stop(err) + return sp + } + defer conn.Close() + conn = trace.WrapNetConn(conn) + thx := trace.NewTLSHandshakerStdlib(logger) + config := &tls.Config{ + NextProtos: alpn, RootCAs: netxlite.NewDefaultCertPool(), - ServerName: m.config.sni(address), - }) + ServerName: sni, + } + _, _, err = thx.Handshake(ctx, conn, config) + ol.Stop(err) + sp.TLSHandshake = <-trace.TLSHandshake + sp.NetworkEvents = trace.NetworkEvents() + return sp } // NewExperimentMeasurer creates a new ExperimentMeasurer. diff --git a/internal/engine/experiment/tlsping/tlsping_test.go b/internal/engine/experiment/tlsping/tlsping_test.go index eeb050e..716549f 100644 --- a/internal/engine/experiment/tlsping/tlsping_test.go +++ b/internal/engine/experiment/tlsping/tlsping_test.go @@ -39,7 +39,7 @@ func TestMeasurer_run(t *testing.T) { const expectedPings = 4 // runHelper is an helper function to run this set of tests. - runHelper := func(input string) (*model.Measurement, model.ExperimentMeasurer, error) { + runHelper := func(ctx context.Context, input string) (*model.Measurement, model.ExperimentMeasurer, error) { m := NewExperimentMeasurer(Config{ ALPN: "http/1.1", Delay: 1, // millisecond @@ -48,10 +48,9 @@ func TestMeasurer_run(t *testing.T) { if m.ExperimentName() != "tlsping" { t.Fatal("invalid experiment name") } - if m.ExperimentVersion() != "0.1.0" { + if m.ExperimentVersion() != "0.2.0" { t.Fatal("invalid experiment version") } - ctx := context.Background() meas := &model.Measurement{ Input: model.MeasurementTarget(input), } @@ -64,34 +63,34 @@ func TestMeasurer_run(t *testing.T) { } t.Run("with empty input", func(t *testing.T) { - _, _, err := runHelper("") + _, _, err := runHelper(context.Background(), "") if !errors.Is(err, errNoInputProvided) { t.Fatal("unexpected error", err) } }) t.Run("with invalid URL", func(t *testing.T) { - _, _, err := runHelper("\t") + _, _, err := runHelper(context.Background(), "\t") if !errors.Is(err, errInputIsNotAnURL) { t.Fatal("unexpected error", err) } }) t.Run("with invalid scheme", func(t *testing.T) { - _, _, err := runHelper("https://8.8.8.8:443/") + _, _, err := runHelper(context.Background(), "https://8.8.8.8:443/") if !errors.Is(err, errInvalidScheme) { t.Fatal("unexpected error", err) } }) t.Run("with missing port", func(t *testing.T) { - _, _, err := runHelper("tlshandshake://8.8.8.8") + _, _, err := runHelper(context.Background(), "tlshandshake://8.8.8.8") if !errors.Is(err, errMissingPort) { t.Fatal("unexpected error", err) } }) - t.Run("with local listener", func(t *testing.T) { + t.Run("with local listener and successful outcome", func(t *testing.T) { srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) })) @@ -101,7 +100,37 @@ func TestMeasurer_run(t *testing.T) { t.Fatal(err) } URL.Scheme = "tlshandshake" - meas, m, err := runHelper(URL.String()) + meas, m, err := runHelper(context.Background(), URL.String()) + if err != nil { + t.Fatal(err) + } + tk := meas.TestKeys.(*TestKeys) + if len(tk.Pings) != expectedPings { + t.Fatal("unexpected number of pings") + } + ask, err := m.GetSummaryKeys(meas) + if err != nil { + t.Fatal("cannot obtain summary") + } + summary := ask.(SummaryKeys) + if summary.IsAnomaly { + t.Fatal("expected no anomaly") + } + }) + + t.Run("with local listener and connect issues", func(t *testing.T) { + srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer srvr.Close() + URL, err := url.Parse(srvr.URL) + if err != nil { + t.Fatal(err) + } + URL.Scheme = "tlshandshake" + ctx, cancel := context.WithCancel(context.Background()) + cancel() // so we cannot dial any connection + meas, m, err := runHelper(ctx, URL.String()) if err != nil { t.Fatal(err) } diff --git a/internal/engine/experiment/urlgetter/urlgetter.go b/internal/engine/experiment/urlgetter/urlgetter.go index afd8ffe..8a6593c 100644 --- a/internal/engine/experiment/urlgetter/urlgetter.go +++ b/internal/engine/experiment/urlgetter/urlgetter.go @@ -1,6 +1,11 @@ // Package urlgetter implements a nettest that fetches a URL. // // See https://github.com/ooni/spec/blob/master/nettests/ts-027-urlgetter.md. +// +// This package is now frozen. Please, use measurexlite for new code. New +// network experiments should not depend on this package. Please see +// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md +// for details about this. package urlgetter import ( diff --git a/internal/engine/netx/doc.go b/internal/engine/netx/doc.go index 1e4edbf..efe4d64 100644 --- a/internal/engine/netx/doc.go +++ b/internal/engine/netx/doc.go @@ -46,4 +46,8 @@ // // See docs/design/dd-002-nets.md in the probe-cli repository for // the design document describing this package. +// +// This package is now frozen. Please, use measurexlite for new code. See +// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md +// for details about this. package netx diff --git a/internal/httpx/httpx_test.go b/internal/httpx/httpx_test.go index 399204a..fdb8b88 100644 --- a/internal/httpx/httpx_test.go +++ b/internal/httpx/httpx_test.go @@ -13,10 +13,10 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/fakefill" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/testingx" "github.com/ooni/probe-cli/v3/internal/version" ) @@ -47,7 +47,7 @@ func TestAPIClientTemplate(t *testing.T) { HTTPClient: http.DefaultClient, Logger: model.DiscardLogger, } - ff := &fakefill.Filler{} + ff := &testingx.FakeFiller{} ff.Fill(tmpl) ac := tmpl.Build() orig := apiClient(*tmpl) @@ -64,7 +64,7 @@ func TestAPIClientTemplate(t *testing.T) { HTTPClient: http.DefaultClient, Logger: model.DiscardLogger, } - ff := &fakefill.Filler{} + ff := &testingx.FakeFiller{} ff.Fill(tmpl) tok := "" ff.Fill(&tok) @@ -188,7 +188,7 @@ func TestAPIClient(t *testing.T) { t.Run("sets the content-type properly", func(t *testing.T) { var jsonReq fakeRequest - ff := &fakefill.Filler{} + ff := &testingx.FakeFiller{} ff.Fill(&jsonReq) client := newAPIClient() req, err := client.newRequestWithJSONBody( diff --git a/internal/measurex/doc.go b/internal/measurex/doc.go index f902e5e..406f43d 100644 --- a/internal/measurex/doc.go +++ b/internal/measurex/doc.go @@ -1,2 +1,6 @@ // Package measurex contains measurement extensions. +// +// This package is now frozen. Please, use measurexlite for new code. See +// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md +// for details about this. package measurex diff --git a/internal/measurexlite/conn.go b/internal/measurexlite/conn.go new file mode 100644 index 0000000..a688bd4 --- /dev/null +++ b/internal/measurexlite/conn.go @@ -0,0 +1,103 @@ +package measurexlite + +// +// Conn tracing +// + +import ( + "net" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/tracex" +) + +// MaybeClose is a convenience function for closing a conn only when such a conn isn't nil. +func MaybeClose(conn net.Conn) (err error) { + if conn != nil { + err = conn.Close() + } + return +} + +// WrapNetConn returns a wrapped conn that saves network events into this trace. +func (tx *Trace) WrapNetConn(conn net.Conn) net.Conn { + return &connTrace{ + Conn: conn, + tx: tx, + } +} + +// connTrace is a trace-aware net.Conn. +type connTrace struct { + // Implementation note: it seems safe to use embedding here because net.Conn + // is an interface from the standard library that we don't control + net.Conn + tx *Trace +} + +var _ net.Conn = &connTrace{} + +// Read implements net.Conn.Read and saves network events. +func (c *connTrace) Read(b []byte) (int, error) { + network := c.RemoteAddr().Network() + addr := c.RemoteAddr().String() + started := c.tx.TimeSince(c.tx.ZeroTime) + count, err := c.Conn.Read(b) + finished := c.tx.TimeSince(c.tx.ZeroTime) + select { + case c.tx.NetworkEvent <- NewArchivalNetworkEvent( + c.tx.Index, started, netxlite.ReadOperation, network, addr, count, err, finished): + default: // buffer is full + } + return count, err +} + +// Write implements net.Conn.Write and saves network events. +func (c *connTrace) Write(b []byte) (int, error) { + network := c.RemoteAddr().Network() + addr := c.RemoteAddr().String() + started := c.tx.TimeSince(c.tx.ZeroTime) + count, err := c.Conn.Write(b) + finished := c.tx.TimeSince(c.tx.ZeroTime) + select { + case c.tx.NetworkEvent <- NewArchivalNetworkEvent( + c.tx.Index, started, netxlite.WriteOperation, network, addr, count, err, finished): + default: // buffer is full + } + return count, err +} + +// NewArchivalNetworkEvent creates a new model.ArchivalNetworkEvent. +func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, network string, + address string, count int, err error, finished time.Duration) *model.ArchivalNetworkEvent { + return &model.ArchivalNetworkEvent{ + Address: address, + Failure: tracex.NewFailure(err), + NumBytes: int64(count), + Operation: operation, + Proto: network, + T: finished.Seconds(), + Tags: []string{}, + } +} + +// NewAnnotationArchivalNetworkEvent is a simplified NewArchivalNetworkEvent +// where we create a simple annotation without attached I/O info. +func NewAnnotationArchivalNetworkEvent( + index int64, time time.Duration, operation string) *model.ArchivalNetworkEvent { + return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time) +} + +// NetworkEvents drains the network events buffered inside the NetworkEvent channel. +func (tx *Trace) NetworkEvents() (out []*model.ArchivalNetworkEvent) { + for { + select { + case ev := <-tx.NetworkEvent: + out = append(out, ev) + default: + return // done + } + } +} diff --git a/internal/measurexlite/conn_test.go b/internal/measurexlite/conn_test.go new file mode 100644 index 0000000..792a374 --- /dev/null +++ b/internal/measurexlite/conn_test.go @@ -0,0 +1,243 @@ +package measurexlite + +import ( + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/testingx" +) + +func TestMaybeClose(t *testing.T) { + t.Run("with nil conn", func(t *testing.T) { + var conn net.Conn = nil + MaybeClose(conn) + }) + + t.Run("with nonnil conn", func(t *testing.T) { + var called bool + conn := &mocks.Conn{ + MockClose: func() error { + called = true + return nil + }, + } + if err := MaybeClose(conn); err != nil { + t.Fatal(err) + } + if !called { + t.Fatal("not called") + } + }) +} + +func TestWrapNetConn(t *testing.T) { + t.Run("WrapNetConn wraps the conn", func(t *testing.T) { + underlying := &mocks.Conn{} + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + conn := trace.WrapNetConn(underlying) + ct := conn.(*connTrace) + if ct.Conn != underlying { + t.Fatal("invalid underlying") + } + if ct.tx != trace { + t.Fatal("invalid trace") + } + }) + + t.Run("Read saves a trace", func(t *testing.T) { + underlying := &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return len(b), nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + } + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now // deterministic time counting + conn := trace.WrapNetConn(underlying) + const bufsiz = 128 + buffer := make([]byte, bufsiz) + count, err := conn.Read(buffer) + if count != bufsiz { + t.Fatal("invalid count") + } + if err != nil { + t.Fatal("invalid err") + } + events := trace.NetworkEvents() + if len(events) != 1 { + t.Fatal("did not save network events") + } + expect := &model.ArchivalNetworkEvent{ + Address: "1.1.1.1:443", + Failure: nil, + NumBytes: bufsiz, + Operation: netxlite.ReadOperation, + Proto: "tcp", + T: 1.0, + Tags: []string{}, + } + got := events[0] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("Read discards the event when the buffer is full", func(t *testing.T) { + underlying := &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return len(b), nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + } + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer + conn := trace.WrapNetConn(underlying) + const bufsiz = 128 + buffer := make([]byte, bufsiz) + count, err := conn.Read(buffer) + if count != bufsiz { + t.Fatal("invalid count") + } + if err != nil { + t.Fatal("invalid err") + } + events := trace.NetworkEvents() + if len(events) != 0 { + t.Fatal("expected no network events") + } + }) + + t.Run("Write saves a trace", func(t *testing.T) { + underlying := &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + } + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now // deterministic time tracking + conn := trace.WrapNetConn(underlying) + const bufsiz = 128 + buffer := make([]byte, bufsiz) + count, err := conn.Write(buffer) + if count != bufsiz { + t.Fatal("invalid count") + } + if err != nil { + t.Fatal("invalid err") + } + events := trace.NetworkEvents() + if len(events) != 1 { + t.Fatal("did not save network events") + } + expect := &model.ArchivalNetworkEvent{ + Address: "1.1.1.1:443", + Failure: nil, + NumBytes: bufsiz, + Operation: netxlite.WriteOperation, + Proto: "tcp", + T: 1.0, + Tags: []string{}, + } + got := events[0] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("Write discards the event when the buffer is full", func(t *testing.T) { + underlying := &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + } + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer + conn := trace.WrapNetConn(underlying) + const bufsiz = 128 + buffer := make([]byte, bufsiz) + count, err := conn.Write(buffer) + if count != bufsiz { + t.Fatal("invalid count") + } + if err != nil { + t.Fatal("invalid err") + } + events := trace.NetworkEvents() + if len(events) != 0 { + t.Fatal("expected no network events") + } + }) +} + +func TestNewAnnotationArchivalNetworkEvent(t *testing.T) { + var ( + index int64 = 3 + duration = 250 * time.Millisecond + operation = "tls_handshake_start" + ) + expect := &model.ArchivalNetworkEvent{ + Address: "", + Failure: nil, + NumBytes: 0, + Operation: operation, + Proto: "", + T: duration.Seconds(), + Tags: []string{}, + } + got := NewAnnotationArchivalNetworkEvent( + index, duration, operation, + ) + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } +} diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go new file mode 100644 index 0000000..e7da206 --- /dev/null +++ b/internal/measurexlite/dialer.go @@ -0,0 +1,121 @@ +package measurexlite + +// +// Dialer tracing +// + +import ( + "context" + "log" + "math" + "net" + "strconv" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/tracex" +) + +// NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver +// except that it returns a model.Dialer that uses this trace. +// +// Note: unlike code in netx or measurex, this factory DOES NOT return you a +// dialer that also performs wrapping of a net.Conn in case of success. If you +// want to wrap the conn, you need to wrap it explicitly using WrapNetConn. +func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer { + return &dialerTrace{ + d: tx.newDialerWithoutResolver(dl), + tx: tx, + } +} + +// dialerTrace is a trace-aware model.Dialer. +type dialerTrace struct { + d model.Dialer + tx *Trace +} + +var _ model.Dialer = &dialerTrace{} + +// DialContext implements model.Dialer.DialContext. +func (d *dialerTrace) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.d.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address) +} + +// CloseIdleConnections implements model.Dialer.CloseIdleConnections. +func (d *dialerTrace) CloseIdleConnections() { + d.d.CloseIdleConnections() +} + +// OnTCPConnectDone implements model.Trace.OnTCPConnectDone. +func (tx *Trace) OnConnectDone( + started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { + switch network { + case "tcp", "tcp4", "tcp6": + select { + case tx.TCPConnect <- NewArchivalTCPConnectResult( + tx.Index, + started.Sub(tx.ZeroTime), + remoteAddr, + err, + finished.Sub(tx.ZeroTime), + ): + default: // buffer is full + } + default: + // ignore UDP connect attempts because they cannot fail + // in interesting ways that make sense for censorship + } +} + +// NewArchivalTCPConnectResult generates a model.ArchivalTCPConnectResult +// from the available information right after connect returns. +func NewArchivalTCPConnectResult(index int64, started time.Duration, address string, + err error, finished time.Duration) *model.ArchivalTCPConnectResult { + ip, port := archivalSplitHostPort(address) + return &model.ArchivalTCPConnectResult{ + IP: ip, + Port: archivalPortToString(port), + Status: model.ArchivalTCPConnectStatus{ + Blocked: nil, + Failure: tracex.NewFailure(err), + Success: err == nil, + }, + T: finished.Seconds(), + } +} + +// archivalSplitHostPort is like net.SplitHostPort but does not return an error. This +// function returns two empty strings in case of any failure. +func archivalSplitHostPort(endpoint string) (string, string) { + addr, port, err := net.SplitHostPort(endpoint) + if err != nil { + log.Printf("BUG: archivalSplitHostPort: invalid endpoint: %s", endpoint) + return "", "" + } + return addr, port +} + +// archivalPortToString is like strconv.Atoi but does not return an error. This +// function returns a zero port number in case of any failure. +func archivalPortToString(sport string) int { + port, err := strconv.Atoi(sport) + if err != nil || port < 0 || port > math.MaxUint16 { + log.Printf("BUG: archivalStrconvAtoi: invalid port: %s", sport) + return 0 + } + return port +} + +// TCPConnects drains the network events buffered inside the TCPConnect channel. +func (tx *Trace) TCPConnects() (out []*model.ArchivalTCPConnectResult) { + for { + select { + case ev := <-tx.TCPConnect: + out = append(out, ev) + default: + return // done + } + } +} diff --git a/internal/measurexlite/dialer_test.go b/internal/measurexlite/dialer_test.go new file mode 100644 index 0000000..c6c765d --- /dev/null +++ b/internal/measurexlite/dialer_test.go @@ -0,0 +1,211 @@ +package measurexlite + +import ( + "context" + "errors" + "math" + "net" + "strconv" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/testingx" +) + +func TestNewDialerWithoutResolver(t *testing.T) { + t.Run("NewDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) { + underlying := &mocks.Dialer{} + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { + return underlying + } + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + dt := dialer.(*dialerTrace) + if dt.d != underlying { + t.Fatal("invalid dialer") + } + if dt.tx != trace { + t.Fatal("invalid trace") + } + }) + + t.Run("DialContext calls the underlying dialer with context-based tracing", func(t *testing.T) { + expectedErr := errors.New("mocked err") + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + var hasCorrectTrace bool + underlying := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + gotTrace := netxlite.ContextTraceOrDefault(ctx) + hasCorrectTrace = (gotTrace == trace) + return nil, expectedErr + }, + } + trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { + return underlying + } + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + ctx := context.Background() + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if !errors.Is(err, expectedErr) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + if !hasCorrectTrace { + t.Fatal("does not have the correct trace") + } + }) + + t.Run("CloseIdleConnection is correctly forwarded", func(t *testing.T) { + var called bool + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + underlying := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { + return underlying + } + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("DialContext saves into the trace", func(t *testing.T) { + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now // deterministic time tracking + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we cancel immediately so connect is ~instantaneous + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := trace.TCPConnects() + if len(events) != 1 { + t.Fatal("expected to see single TCPConnect event") + } + expectedFailure := netxlite.FailureInterrupted + expect := &model.ArchivalTCPConnectResult{ + IP: "1.1.1.1", + Port: 443, + Status: model.ArchivalTCPConnectStatus{ + Blocked: nil, + Failure: &expectedFailure, + Success: false, + }, + T: time.Second.Seconds(), + } + got := events[0] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("DialContext discards events when buffer is full", func(t *testing.T) { + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.TCPConnect = make(chan *model.ArchivalTCPConnectResult) // no buffer + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we cancel immediately so connect is ~instantaneous + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := trace.TCPConnects() + if len(events) != 0 { + t.Fatal("expected to see no TCPConnect events") + } + }) + + t.Run("DialContext ignores UDP connect attempts", func(t *testing.T) { + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we cancel immediately so connect is ~instantaneous + conn, err := dialer.DialContext(ctx, "udp", "1.1.1.1:443") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := trace.TCPConnects() + if len(events) != 0 { + t.Fatal("expected to see no TCPConnect events") + } + }) + + t.Run("DialContext uses a dialer without a resolver", func(t *testing.T) { + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we cancel immediately so connect is ~instantaneous + conn, err := dialer.DialContext(ctx, "udp", "dns.google:443") // domain + if !errors.Is(err, netxlite.ErrNoResolver) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := trace.TCPConnects() + if len(events) != 0 { + t.Fatal("expected to see no TCPConnect events") + } + }) +} + +func TestArchivalSplitHostPort(t *testing.T) { + addr, port := archivalSplitHostPort("1.1.1.1") // missing port + if addr != "" { + t.Fatal("invalid addr", addr) + } + if port != "" { + t.Fatal("invalid port", port) + } +} + +func TestArchivalPortToString(t *testing.T) { + t.Run("with invalid number", func(t *testing.T) { + port := archivalPortToString("antani") + if port != 0 { + t.Fatal("invalid port") + } + }) + + t.Run("with negative number", func(t *testing.T) { + port := archivalPortToString("-1") + if port != 0 { + t.Fatal("invalid port") + } + }) + + t.Run("with too-large positive number", func(t *testing.T) { + port := archivalPortToString(strconv.Itoa(math.MaxUint16 + 1)) + if port != 0 { + t.Fatal("invalid port") + } + }) +} diff --git a/internal/measurexlite/doc.go b/internal/measurexlite/doc.go new file mode 100644 index 0000000..cc52f84 --- /dev/null +++ b/internal/measurexlite/doc.go @@ -0,0 +1,10 @@ +// Package measurexlite contains measurement extensions. +// +// See docs/design/dd-003-step-by-step.md in the ooni/probe-cli +// repository for the design document. +// +// This implementation features a Trace that saves events in +// buffered channels as proposed by df-003-step-by-step.md. We +// have reasonable default buffers for channels. But, if you +// are not draining them, eventually we stop collecting events. +package measurexlite diff --git a/internal/measurexlite/logger.go b/internal/measurexlite/logger.go new file mode 100644 index 0000000..1729346 --- /dev/null +++ b/internal/measurexlite/logger.go @@ -0,0 +1,16 @@ +package measurexlite + +// +// Logging support +// + +import "github.com/ooni/probe-cli/v3/internal/measurex" + +// TODO(bassosimone): we should eventually remove measurex and +// move the logging code from measurex to this package. + +// NewOperationLogger is an alias for measurex.NewOperationLogger. +var NewOperationLogger = measurex.NewOperationLogger + +// OperationLogger is an alias for measurex.OperationLogger. +type OperationLogger = measurex.OperationLogger diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go new file mode 100644 index 0000000..2d881a7 --- /dev/null +++ b/internal/measurexlite/tls.go @@ -0,0 +1,144 @@ +package measurexlite + +// +// TLS tracing +// + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "net" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/tracex" +) + +// NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib +// except that it returns a model.TLSHandshaker that uses this trace. +func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { + return &tlsHandshakerTrace{ + thx: tx.newTLSHandshakerStdlib(dl), + tx: tx, + } +} + +// tlsHandshakerTrace is a trace-aware TLS handshaker. +type tlsHandshakerTrace struct { + thx model.TLSHandshaker + tx *Trace +} + +var _ model.TLSHandshaker = &tlsHandshakerTrace{} + +// Handshake implements model.TLSHandshaker.Handshake. +func (thx *tlsHandshakerTrace) Handshake( + ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (net.Conn, tls.ConnectionState, error) { + return thx.thx.Handshake(netxlite.ContextWithTrace(ctx, thx.tx), conn, tlsConfig) +} + +// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart. +func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) { + t := now.Sub(tx.ZeroTime) + select { + case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_start"): + default: // buffer is full + } +} + +// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone. +func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, + state tls.ConnectionState, err error, finished time.Time) { + t := finished.Sub(tx.ZeroTime) + select { + case tx.TLSHandshake <- NewArchivalTLSOrQUICHandshakeResult( + tx.Index, + started.Sub(tx.ZeroTime), + remoteAddr, + config, + state, + err, + t, + ): + default: // buffer is full + } + select { + case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_done"): + default: // buffer is full + } +} + +// NewArchivalTLSOrQUICHandshakeResult generates a model.ArchivalTLSOrQUICHandshakeResult +// from the available information right after the TLS handshake returns. +func NewArchivalTLSOrQUICHandshakeResult( + index int64, started time.Duration, address string, config *tls.Config, + state tls.ConnectionState, err error, finished time.Duration) *model.ArchivalTLSOrQUICHandshakeResult { + return &model.ArchivalTLSOrQUICHandshakeResult{ + Address: address, + CipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), + Failure: tracex.NewFailure(err), + NegotiatedProtocol: state.NegotiatedProtocol, + NoTLSVerify: config.InsecureSkipVerify, + PeerCertificates: TLSPeerCerts(state, err), + ServerName: config.ServerName, + T: finished.Seconds(), + Tags: []string{}, + TLSVersion: netxlite.TLSVersionString(state.Version), + } +} + +// newArchivalBinaryData is a factory that adapts binary data to the +// model.ArchivalMaybeBinaryData format. +func newArchivalBinaryData(data []byte) model.ArchivalMaybeBinaryData { + // TODO(https://github.com/ooni/probe/issues/2165): we should actually extend the + // model's archival data format to have a pure-binary-data type for the cases in which + // we know in advance we're dealing with binary data. + return model.ArchivalMaybeBinaryData{ + Value: string(data), + } +} + +// TLSPeerCerts extracts the certificates either from the list of certificates +// in the connection state or from the error that occurred. +func TLSPeerCerts( + state tls.ConnectionState, err error) (out []model.ArchivalMaybeBinaryData) { + out = []model.ArchivalMaybeBinaryData{} + var x509HostnameError x509.HostnameError + if errors.As(err, &x509HostnameError) { + // Test case: https://wrong.host.badssl.com/ + out = append(out, newArchivalBinaryData(x509HostnameError.Certificate.Raw)) + return + } + var x509UnknownAuthorityError x509.UnknownAuthorityError + if errors.As(err, &x509UnknownAuthorityError) { + // Test case: https://self-signed.badssl.com/. This error has + // never been among the ones returned by MK. + out = append(out, newArchivalBinaryData(x509UnknownAuthorityError.Cert.Raw)) + return + } + var x509CertificateInvalidError x509.CertificateInvalidError + if errors.As(err, &x509CertificateInvalidError) { + // Test case: https://expired.badssl.com/ + out = append(out, newArchivalBinaryData(x509CertificateInvalidError.Cert.Raw)) + return + } + for _, cert := range state.PeerCertificates { + out = append(out, newArchivalBinaryData(cert.Raw)) + } + return +} + +// TLSHandshakes drains the network events buffered inside the TLSHandshake channel. +func (tx *Trace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) { + for { + select { + case ev := <-tx.TLSHandshake: + out = append(out, ev) + default: + return // done + } + } +} diff --git a/internal/measurexlite/tls_test.go b/internal/measurexlite/tls_test.go new file mode 100644 index 0000000..cd0f462 --- /dev/null +++ b/internal/measurexlite/tls_test.go @@ -0,0 +1,418 @@ +package measurexlite + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "errors" + "net" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" + "github.com/ooni/probe-cli/v3/internal/testingx" +) + +func TestNewTLSHandshakerStdlib(t *testing.T) { + t.Run("NewTLSHandshakerStdlib creates a wrapped TLSHandshaker", func(t *testing.T) { + underlying := &mocks.TLSHandshaker{} + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { + return underlying + } + thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + thxt := thx.(*tlsHandshakerTrace) + if thxt.thx != underlying { + t.Fatal("invalid TLS handshaker") + } + if thxt.tx != trace { + t.Fatal("invalid trace") + } + }) + + t.Run("Handshake calls the underlying dialer with context-based tracing", func(t *testing.T) { + expectedErr := errors.New("mocked err") + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + var hasCorrectTrace bool + underlying := &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + gotTrace := netxlite.ContextTraceOrDefault(ctx) + hasCorrectTrace = (gotTrace == trace) + return nil, tls.ConnectionState{}, expectedErr + }, + } + trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { + return underlying + } + thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) + if !errors.Is(err, expectedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("expected zero-value state") + } + if conn != nil { + t.Fatal("expected nil conn") + } + if !hasCorrectTrace { + t.Fatal("does not have the correct trace") + } + }) + + t.Run("Handshake saves into the trace", func(t *testing.T) { + mockedErr := errors.New("mocked") + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now // deterministic timing + thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we cancel immediately so connect is ~instantaneous + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: "dns.cloudflare.com", + } + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("expected zero-value state") + } + if conn != nil { + t.Fatal("expected nil conn") + } + + t.Run("TLSHandshake events", func(t *testing.T) { + events := trace.TLSHandshakes() + if len(events) != 1 { + t.Fatal("expected to see single TLSHandshake event") + } + expectedFailure := netxlite.FailureInterrupted + expect := &model.ArchivalTLSOrQUICHandshakeResult{ + Address: "1.1.1.1:443", + CipherSuite: "", + Failure: &expectedFailure, + NegotiatedProtocol: "", + NoTLSVerify: true, + PeerCertificates: []model.ArchivalMaybeBinaryData{}, + ServerName: "dns.cloudflare.com", + T: time.Second.Seconds(), + Tags: []string{}, + TLSVersion: "", + } + got := events[0] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("Network events", func(t *testing.T) { + events := trace.NetworkEvents() + if len(events) != 2 { + t.Fatal("expected to see two Network events") + } + + t.Run("tls_handshake_start", func(t *testing.T) { + expect := &model.ArchivalNetworkEvent{ + Address: "", + Failure: nil, + NumBytes: 0, + Operation: "tls_handshake_start", + Proto: "", + T: 0, + Tags: []string{}, + } + got := events[0] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("tls_handshake_done", func(t *testing.T) { + expect := &model.ArchivalNetworkEvent{ + Address: "", + Failure: nil, + NumBytes: 0, + Operation: "tls_handshake_done", + Proto: "", + T: time.Second.Seconds(), + Tags: []string{}, + } + got := events[1] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + }) + }) + + t.Run("Handshake discards events when buffers are full", func(t *testing.T) { + mockedErr := errors.New("mocked") + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer + trace.TLSHandshake = make(chan *model.ArchivalTLSOrQUICHandshakeResult) // no buffer + thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we cancel immediately so connect is ~instantaneous + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: "dns.cloudflare.com", + } + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("expected zero-value state") + } + if conn != nil { + t.Fatal("expected nil conn") + } + + t.Run("TLSHandshake events", func(t *testing.T) { + events := trace.TLSHandshakes() + if len(events) != 0 { + t.Fatal("expected to see no TLSHandshake events") + } + }) + + t.Run("Network events", func(t *testing.T) { + events := trace.NetworkEvents() + if len(events) != 0 { + t.Fatal("expected to see no Network events") + } + }) + }) + + t.Run("we collect the desired data with a local TLS server", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionBlockText) + dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) + ctx := context.Background() + conn, err := dialer.DialContext(ctx, "tcp", server.Endpoint()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + zeroTime := time.Now() + dt := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = dt.Now // deterministic timing + thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + tlsConfig := &tls.Config{ + RootCAs: server.CertPool(), + ServerName: "dns.google", + } + tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig) + if err != nil { + t.Fatal(err) + } + defer tlsConn.Close() + data, err := netxlite.ReadAllContext(ctx, tlsConn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, filtering.HTTPBlockpage451) { + t.Fatal("bytes should match") + } + + t.Run("TLSHandshake events", func(t *testing.T) { + events := trace.TLSHandshakes() + if len(events) != 1 { + t.Fatal("expected to see a single TLSHandshake event") + } + expected := &model.ArchivalTLSOrQUICHandshakeResult{ + Address: conn.RemoteAddr().String(), + CipherSuite: netxlite.TLSCipherSuiteString(connState.CipherSuite), + Failure: nil, + NegotiatedProtocol: "", + NoTLSVerify: false, + PeerCertificates: []model.ArchivalMaybeBinaryData{}, + ServerName: "dns.google", + T: time.Second.Seconds(), + Tags: []string{}, + TLSVersion: netxlite.TLSVersionString(connState.Version), + } + got := events[0] + // TODO(bassosimone): it's still unclear to me how to test that + // I am getting exactly the expected certificate here. I think the + // certificate is generated on the fly by google/martian. So, I'm + // just going to reduce the precision of this check. + if len(got.PeerCertificates) != 2 { + t.Fatal("expected to see two certificates") + } + got.PeerCertificates = []model.ArchivalMaybeBinaryData{} // see above + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("Network events", func(t *testing.T) { + events := trace.NetworkEvents() + if len(events) != 2 { + t.Fatal("expected to see two Network events") + } + + t.Run("tls_handshake_start", func(t *testing.T) { + expect := &model.ArchivalNetworkEvent{ + Address: "", + Failure: nil, + NumBytes: 0, + Operation: "tls_handshake_start", + Proto: "", + T: 0, + Tags: []string{}, + } + got := events[0] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("tls_handshake_done", func(t *testing.T) { + expect := &model.ArchivalNetworkEvent{ + Address: "", + Failure: nil, + NumBytes: 0, + Operation: "tls_handshake_done", + Proto: "", + T: time.Second.Seconds(), + Tags: []string{}, + } + got := events[1] + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + }) + }) +} + +func TestTLSPeerCerts(t *testing.T) { + type args struct { + state tls.ConnectionState + err error + } + tests := []struct { + name string + args args + wantOut []model.ArchivalMaybeBinaryData + }{{ + name: "x509.HostnameError", + args: args{ + state: tls.ConnectionState{}, + err: x509.HostnameError{ + Certificate: &x509.Certificate{ + Raw: []byte("deadbeef"), + }, + }, + }, + wantOut: []model.ArchivalMaybeBinaryData{{ + Value: "deadbeef", + }}, + }, { + name: "x509.UnknownAuthorityError", + args: args{ + state: tls.ConnectionState{}, + err: x509.UnknownAuthorityError{ + Cert: &x509.Certificate{ + Raw: []byte("deadbeef"), + }, + }, + }, + wantOut: []model.ArchivalMaybeBinaryData{{ + Value: "deadbeef", + }}, + }, { + name: "x509.CertificateInvalidError", + args: args{ + state: tls.ConnectionState{}, + err: x509.CertificateInvalidError{ + Cert: &x509.Certificate{ + Raw: []byte("deadbeef"), + }, + }, + }, + wantOut: []model.ArchivalMaybeBinaryData{{ + Value: "deadbeef", + }}, + }, { + name: "successful case", + args: args{ + state: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{{ + Raw: []byte("deadbeef"), + }, { + Raw: []byte("abad1dea"), + }}, + }, + err: nil, + }, + wantOut: []model.ArchivalMaybeBinaryData{{ + Value: "deadbeef", + }, { + Value: "abad1dea", + }}, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOut := TLSPeerCerts(tt.args.state, tt.args.err) + if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" { + t.Fatal(diff) + } + }) + } +} diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go new file mode 100644 index 0000000..93ad677 --- /dev/null +++ b/internal/measurexlite/trace.go @@ -0,0 +1,143 @@ +package measurexlite + +// +// Definition of Trace +// + +import ( + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +// Trace implements model.Trace. +// +// The zero-value of this struct is invalid. To construct you should either +// fill all the fields marked as MANDATORY or use NewTrace. +// +// Buffered channels +// +// NewTrace uses reasonable buffer sizes for the channels used for collecting +// events. You should drain the channels used by this implementation after +// each operation you perform (i.e., we expect you to peform step-by-step +// measurements). If you want larger (or smaller) buffers, then you should +// construct this data type manually with the desired buffer sizes. +// +// We have convenience methods for extracting events from the buffered +// channels. Otherwise, you could read the channels directly. (In which +// case, remember to issue nonblocking channel reads because channels are +// never closed and they're just written when new events occur.) +type Trace struct { + // Index is the MANDATORY unique index of this trace within the + // current measurement. If you don't care about uniquely identifying + // treaces, you can use zero to indicate the "default" trace. + Index int64 + + // NetworkEvent is MANDATORY and buffers network events. If you create + // this channel manually, ensure it has some buffer. + NetworkEvent chan *model.ArchivalNetworkEvent + + // NewDialerWithoutResolverFn is OPTIONAL and can be used to override + // calls to the netxlite.NewDialerWithoutResolver factory. + NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer + + // NewTLSHandshakerStdlibFn is OPTIONAL and can be used to overide + // calls to the netxlite.NewTLSHandshakerStdlib factory. + NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker + + // TCPConnect is MANDATORY and buffers TCP connect observations. If you create + // this channel manually, ensure it has some buffer. + TCPConnect chan *model.ArchivalTCPConnectResult + + // TLSHandshake is MANDATORY and buffers TLS handshake observations. If you create + // this channel manually, ensure it has some buffer. + TLSHandshake chan *model.ArchivalTLSOrQUICHandshakeResult + + // TimeNowFn is OPTIONAL and can be used to override calls to time.Now + // to produce deterministic timing when testing. + TimeNowFn func() time.Time + + // ZeroTime is the MANDATORY time when we started the current measurement. + ZeroTime time.Time +} + +const ( + // NetworkEventBufferSize is the buffer size for constructing + // the Trace's NetworkEvent buffered channel. + NetworkEventBufferSize = 64 + + // TCPConnectBufferSize is the buffer size for constructing + // the Trace's TCPConnect buffered channel. + TCPConnectBufferSize = 8 + + // TLSHandshakeBufferSize is the buffer for construcing + // the Trace's TLSHandshake buffered channel. + TLSHandshakeBufferSize = 8 +) + +// NewTrace creates a new instance of Trace using default settings. +// +// We create buffered channels using as buffer sizes the constants that +// are also defined by this package. +// +// Arguments: +// +// - index is the unique index of this trace within the current measurement (use +// zero if you don't care about giving this trace a unique ID); +// +// - zeroTime is the time when we started the current measurement. +func NewTrace(index int64, zeroTime time.Time) *Trace { + return &Trace{ + Index: index, + NetworkEvent: make( + chan *model.ArchivalNetworkEvent, + NetworkEventBufferSize, + ), + NewDialerWithoutResolverFn: nil, // use default + NewTLSHandshakerStdlibFn: nil, // use default + TCPConnect: make( + chan *model.ArchivalTCPConnectResult, + TCPConnectBufferSize, + ), + TLSHandshake: make( + chan *model.ArchivalTLSOrQUICHandshakeResult, + TLSHandshakeBufferSize, + ), + TimeNowFn: nil, // use default + ZeroTime: zeroTime, + } +} + +// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver +// thus allows us to mock this func for testing. +func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { + if tx.NewDialerWithoutResolverFn != nil { + return tx.NewDialerWithoutResolverFn(dl) + } + return netxlite.NewDialerWithoutResolver(dl) +} + +// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib +// thus allowing us to mock this func for testing. +func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { + if tx.NewTLSHandshakerStdlibFn != nil { + return tx.NewTLSHandshakerStdlibFn(dl) + } + return netxlite.NewTLSHandshakerStdlib(dl) +} + +// TimeNow implements model.Trace.TimeNow. +func (tx *Trace) TimeNow() time.Time { + if tx.TimeNowFn != nil { + return tx.TimeNowFn() + } + return time.Now() +} + +// TimeSince is equivalent to Trace.TimeNow().Sub(t0). +func (tx *Trace) TimeSince(t0 time.Time) time.Duration { + return tx.TimeNow().Sub(t0) +} + +var _ model.Trace = &Trace{} diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go new file mode 100644 index 0000000..aba42f0 --- /dev/null +++ b/internal/measurexlite/trace_test.go @@ -0,0 +1,248 @@ +package measurexlite + +import ( + "context" + "crypto/tls" + "errors" + "net" + "reflect" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/testingx" +) + +func TestNewTrace(t *testing.T) { + t.Run("NewTrace correctly constructs a trace", func(t *testing.T) { + const index = 17 + zeroTime := time.Now() + trace := NewTrace(index, zeroTime) + + t.Run("Index", func(t *testing.T) { + if trace.Index != index { + t.Fatal("invalid index") + } + }) + + t.Run("NetworkEvent has the expected buffer size", func(t *testing.T) { + ff := &testingx.FakeFiller{} + var idx int + Loop: + for { + ev := &model.ArchivalNetworkEvent{} + ff.Fill(ev) + select { + case trace.NetworkEvent <- ev: + idx++ + default: + break Loop + } + } + if idx != NetworkEventBufferSize { + t.Fatal("invalid NetworkEvent channel buffer size") + } + }) + + t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) { + if trace.NewDialerWithoutResolverFn != nil { + t.Fatal("expected nil NewDialerWithoutResolverFn") + } + }) + + t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) { + if trace.NewTLSHandshakerStdlibFn != nil { + t.Fatal("expected nil NewTLSHandshakerStdlibFn") + } + }) + + t.Run("TCPConnect has the expected buffer size", func(t *testing.T) { + ff := &testingx.FakeFiller{} + var idx int + Loop: + for { + ev := &model.ArchivalTCPConnectResult{} + ff.Fill(ev) + select { + case trace.TCPConnect <- ev: + idx++ + default: + break Loop + } + } + if idx != TCPConnectBufferSize { + t.Fatal("invalid TCPConnect channel buffer size") + } + }) + + t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) { + ff := &testingx.FakeFiller{} + var idx int + Loop: + for { + ev := &model.ArchivalTLSOrQUICHandshakeResult{} + ff.Fill(ev) + select { + case trace.TLSHandshake <- ev: + idx++ + default: + break Loop + } + } + if idx != TLSHandshakeBufferSize { + t.Fatal("invalid TLSHandshake channel buffer size") + } + }) + + t.Run("TimeNowFn is nil", func(t *testing.T) { + if trace.TimeNowFn != nil { + t.Fatal("expected nil TimeNowFn") + } + }) + + t.Run("ZeroTime", func(t *testing.T) { + if !trace.ZeroTime.Equal(zeroTime) { + t.Fatal("invalid zero time") + } + }) + }) +} + +func TestTrace(t *testing.T) { + t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) { + t.Run("when not nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, mockedErr + }, + } + }, + } + dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) + ctx := context.Background() + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("when nil", func(t *testing.T) { + tx := &Trace{ + NewDialerWithoutResolverFn: nil, + } + dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + }) + + t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) { + t.Run("when not nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker { + return &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, mockedErr + }, + } + }, + } + thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("when nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewTLSHandshakerStdlibFn: nil, + } + thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + }) + + t.Run("TimeNowFn works as intended", func(t *testing.T) { + fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC) + tx := &Trace{ + TimeNowFn: func() time.Time { + return fixedTime + }, + } + if !tx.TimeNow().Equal(fixedTime) { + t.Fatal("we cannot override time.Now calls") + } + }) + + t.Run("TimeSince works as intended", func(t *testing.T) { + t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC) + t1 := t0.Add(10 * time.Second) + tx := &Trace{ + TimeNowFn: func() time.Time { + return t1 + }, + } + if tx.TimeSince(t0) != 10*time.Second { + t.Fatal("apparently Trace.Since is broken") + } + }) +} diff --git a/internal/model/archival_test.go b/internal/model/archival_test.go index eec139a..74341f4 100644 --- a/internal/model/archival_test.go +++ b/internal/model/archival_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/fakefill" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func TestArchivalExtSpec(t *testing.T) { @@ -301,7 +301,7 @@ func TestHTTPBody(t *testing.T) { // we make a mistake and apply the above change (which will in turn // break correct JSON serialization), the this test will fail. var body ArchivalHTTPBody - ff := &fakefill.Filler{} + ff := &testingx.FakeFiller{} ff.Fill(&body) data := ArchivalMaybeBinaryData(body) if diff := cmp.Diff(body, data); diff != "" { diff --git a/internal/model/mocks/trace.go b/internal/model/mocks/trace.go new file mode 100644 index 0000000..1049b51 --- /dev/null +++ b/internal/model/mocks/trace.go @@ -0,0 +1,45 @@ +package mocks + +// +// Mocks for model.Trace +// + +import ( + "crypto/tls" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// Trace allows mocking model.Trace. +type Trace struct { + MockTimeNow func() time.Time + + MockOnConnectDone func( + started time.Time, network, domain, remoteAddr string, err error, finished time.Time) + + MockOnTLSHandshakeStart func(now time.Time, remoteAddr string, config *tls.Config) + + MockOnTLSHandshakeDone func(started time.Time, remoteAddr string, config *tls.Config, + state tls.ConnectionState, err error, finished time.Time) +} + +var _ model.Trace = &Trace{} + +func (t *Trace) TimeNow() time.Time { + return t.MockTimeNow() +} + +func (t *Trace) OnConnectDone( + started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { + t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished) +} + +func (t *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) { + t.MockOnTLSHandshakeStart(now, remoteAddr, config) +} + +func (t *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, + state tls.ConnectionState, err error, finished time.Time) { + t.MockOnTLSHandshakeDone(started, remoteAddr, config, state, err, finished) +} diff --git a/internal/model/mocks/trace_test.go b/internal/model/mocks/trace_test.go new file mode 100644 index 0000000..d7ebcc8 --- /dev/null +++ b/internal/model/mocks/trace_test.go @@ -0,0 +1,74 @@ +package mocks + +import ( + "crypto/tls" + "testing" + "time" +) + +func TestTrace(t *testing.T) { + t.Run("TimeNow", func(t *testing.T) { + now := time.Now() + tx := &Trace{ + MockTimeNow: func() time.Time { + return now + }, + } + if !tx.TimeNow().Equal(now) { + t.Fatal("not working as intended") + } + }) + + t.Run("OnConnectDone", func(t *testing.T) { + var called bool + tx := &Trace{ + MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { + called = true + }, + } + tx.OnConnectDone( + time.Now(), + "tcp", + "dns.google", + "8.8.8.8:443", + nil, + time.Now(), + ) + if !called { + t.Fatal("not called") + } + }) + + t.Run("OnTLSHandshakeStart", func(t *testing.T) { + var called bool + tx := &Trace{ + MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) { + called = true + }, + } + tx.OnTLSHandshakeStart(time.Now(), "8.8.8.8:443", &tls.Config{}) + if !called { + t.Fatal("not called") + } + }) + + t.Run("OnTLSHandshakeDone", func(t *testing.T) { + var called bool + tx := &Trace{ + MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) { + called = true + }, + } + tx.OnTLSHandshakeDone( + time.Now(), + "8.8.8.8:443", + &tls.Config{}, + tls.ConnectionState{}, + nil, + time.Now(), + ) + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/model/netx.go b/internal/model/netx.go index f33240f..a2fab89 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -292,6 +292,76 @@ type TLSHandshaker interface { net.Conn, tls.ConnectionState, error) } +// Trace allows to collect measurement traces. A trace is injected into +// netx operations using context.WithValue. Netx code retrieves the trace +// using context.Value. See docs/design/dd-003-step-by-step.md for the +// design document explaining why we implemented context-based tracing. +type Trace interface { + // TimeNow returns the current time. Normally, this should be the same + // value returned by time.Now but you may want to manipulate the time + // returned when testing to have deterministic tests. To this end, you + // can use functionality exported by the ./internal/testingx pkg. + TimeNow() time.Time + + // OnConnectDone is called when connect terminates. + // + // Arguments: + // + // - started is when we called connect; + // + // - network is the network we're using (one of "tcp" and "udp"); + // + // - domain is the domain for which we're calling connect. If the user called + // connect for an IP address and a port, then domain will be an IP address; + // + // - remoteAddr is the TCP endpoint with which we are connecting: it will + // consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421); + // + // - err is the result of connect: either an error or nil; + // + // - finished is when connect returned. + // + // The error passed to this function will always be wrapped such that the + // string returned by Error is an OONI error. + OnConnectDone( + started time.Time, network, domain, remoteAddr string, err error, finished time.Time) + + // OnTLSHandshakeStart is called when the TLS handshake starts. + // + // Arguments: + // + // - now is the moment before we start the handshake; + // + // - remoteAddr is the TCP endpoint with which we are connecting: it will + // consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421); + // + // - config is the non-nil TLS config we're using. + OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) + + // OnTLSHandshakeDone is called when the TLS handshake terminates. + // + // Arguments: + // + // - started is when we started the handshake; + // + // - remoteAddr is the TCP endpoint with which we are connecting: it will + // consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421); + // + // - config is the non-nil TLS config we're using; + // + // - state is the state of the TLS connection after the handshake, where all + // fields are zero-initialized if the handshake failed; + // + // - err is the result of the handshake: either an error or nil; + // + // - finished is right after the handshake. + // + // The error passed to this function will always be wrapped such that the + // string returned by Error is an OONI error. + OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, + state tls.ConnectionState, err error, finished time.Time) +} + // UDPLikeConn is a net.PacketConn with some extra functions // required to convince the QUIC library (lucas-clemente/quic-go) // to inflate the receive buffer of the connection. diff --git a/internal/netxlite/bogon.go b/internal/netxlite/bogon.go index 33ce86b..a31f571 100644 --- a/internal/netxlite/bogon.go +++ b/internal/netxlite/bogon.go @@ -47,7 +47,7 @@ func (r *bogonResolver) LookupHost(ctx context.Context, hostname string) ([]stri for _, addr := range addrs { if IsBogon(addr) { // wrap ErrDNSBogon as documented - return nil, newErrWrapper(classifyResolverError, ResolveOperation, ErrDNSBogon) + return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, ErrDNSBogon) } } return addrs, nil diff --git a/internal/netxlite/classify.go b/internal/netxlite/classify.go index bfd7f25..854225b 100644 --- a/internal/netxlite/classify.go +++ b/internal/netxlite/classify.go @@ -15,8 +15,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/scrubber" ) -// classifyGenericError is maps an error occurred during an operation -// to an OONI failure string. This specific classifier is the most +// ClassifyGenericError maps an error occurred during an operation to +// an OONI failure string. This specific classifier is the most // generic one. You usually use it when mapping I/O errors. You should // check whether there is a specific classifier for more specific // operations (e.g., DNS resolution, TLS handshake). @@ -38,7 +38,7 @@ import ( // If everything else fails, this classifier returns a string // like "unknown_failure: XXX" where XXX has been scrubbed // so to remove any network endpoints from the original error string. -func classifyGenericError(err error) string { +func ClassifyGenericError(err error) string { // The list returned here matches the values used by MK unless // explicitly noted otherwise with a comment. @@ -139,7 +139,7 @@ const ( quicTLSUnrecognizedName = 112 ) -// classifyQUICHandshakeError maps errors during a QUIC +// ClassifyQUICHandshakeError maps errors during a QUIC // handshake to OONI failure strings. // // If the input error is an *ErrWrapper we don't perform @@ -147,7 +147,7 @@ const ( // // If this classifier fails, it calls ClassifyGenericError // and returns to the caller its return value. -func classifyQUICHandshakeError(err error) string { +func ClassifyQUICHandshakeError(err error) string { // Robustness: handle the case where we're passed a wrapped error. var errwrapper *ErrWrapper @@ -207,7 +207,7 @@ func classifyQUICHandshakeError(err error) string { } } } - return classifyGenericError(err) + return ClassifyGenericError(err) } // quicIsCertificateError tells us whether a specific TLS alert error @@ -277,7 +277,7 @@ var ( // anything as explained in getaddrinfo_linux.go. var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData) -// classifyResolverError maps DNS resolution errors to +// ClassifyResolverError maps DNS resolution errors to // OONI failure strings. // // If the input error is an *ErrWrapper we don't perform @@ -285,7 +285,7 @@ var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData) // // If this classifier fails, it calls ClassifyGenericError and // returns to the caller its return value. -func classifyResolverError(err error) string { +func ClassifyResolverError(err error) string { // Robustness: handle the case where we're passed a wrapped error. var errwrapper *ErrWrapper @@ -310,10 +310,10 @@ func classifyResolverError(err error) string { if errors.Is(err, ErrAndroidDNSCacheNoData) { return FailureAndroidDNSCacheNoData } - return classifyGenericError(err) + return ClassifyGenericError(err) } -// classifyTLSHandshakeError maps an error occurred during the TLS +// ClassifyTLSHandshakeError maps an error occurred during the TLS // handshake to an OONI failure string. // // If the input error is an *ErrWrapper we don't perform @@ -321,7 +321,7 @@ func classifyResolverError(err error) string { // // If this classifier fails, it calls ClassifyGenericError and // returns to the caller its return value. -func classifyTLSHandshakeError(err error) string { +func ClassifyTLSHandshakeError(err error) string { // Robustness: handle the case where we're passed a wrapped error. var errwrapper *ErrWrapper @@ -345,5 +345,5 @@ func classifyTLSHandshakeError(err error) string { // Test case: https://expired.badssl.com/ return FailureSSLInvalidCertificate } - return classifyGenericError(err) + return ClassifyGenericError(err) } diff --git a/internal/netxlite/classify_test.go b/internal/netxlite/classify_test.go index d5d976e..a06f9b2 100644 --- a/internal/netxlite/classify_test.go +++ b/internal/netxlite/classify_test.go @@ -18,13 +18,13 @@ func TestClassifyGenericError(t *testing.T) { t.Run("for input being already an ErrWrapper", func(t *testing.T) { err := &ErrWrapper{Failure: FailureEOFError} - if classifyGenericError(err) != FailureEOFError { + if ClassifyGenericError(err) != FailureEOFError { t.Fatal("did not classify existing ErrWrapper correctly") } }) t.Run("for a system call error", func(t *testing.T) { - if classifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock { + if ClassifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock { t.Fatal("unexpected results") } }) @@ -35,63 +35,63 @@ func TestClassifyGenericError(t *testing.T) { // is just an implementation detail. t.Run("for operation was canceled", func(t *testing.T) { - if classifyGenericError(errors.New("operation was canceled")) != FailureInterrupted { + if ClassifyGenericError(errors.New("operation was canceled")) != FailureInterrupted { t.Fatal("unexpected result") } }) t.Run("for EOF", func(t *testing.T) { - if classifyGenericError(io.EOF) != FailureEOFError { + if ClassifyGenericError(io.EOF) != FailureEOFError { t.Fatal("unexpected result") } }) t.Run("for context deadline exceeded", func(t *testing.T) { - if classifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError { + if ClassifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError { t.Fatal("unexpected results") } }) t.Run("for stun's transaction is timed out", func(t *testing.T) { - if classifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError { + if ClassifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError { t.Fatal("unexpected results") } }) t.Run("for i/o timeout", func(t *testing.T) { - if classifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError { + if ClassifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError { t.Fatal("unexpected results") } }) t.Run("for TLS handshake timeout", func(t *testing.T) { err := errors.New("net/http: TLS handshake timeout") - if classifyGenericError(err) != FailureGenericTimeoutError { + if ClassifyGenericError(err) != FailureGenericTimeoutError { t.Fatal("unexpected results") } }) t.Run("for no such host", func(t *testing.T) { - if classifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError { + if ClassifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError { t.Fatal("unexpected results") } }) t.Run("for dns server misbehaving", func(t *testing.T) { - if classifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving { + if ClassifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving { t.Fatal("unexpected results") } }) t.Run("for no answer from DNS server", func(t *testing.T) { - if classifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer { + if ClassifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer { t.Fatal("unexpected results") } }) t.Run("for use of closed network connection", func(t *testing.T) { err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection") - if classifyGenericError(err) != FailureConnectionAlreadyClosed { + if ClassifyGenericError(err) != FailureConnectionAlreadyClosed { t.Fatal("unexpected results") } }) @@ -99,7 +99,7 @@ func TestClassifyGenericError(t *testing.T) { // Now we're back in ClassifyGenericError t.Run("for context.Canceled", func(t *testing.T) { - if classifyGenericError(context.Canceled) != FailureInterrupted { + if ClassifyGenericError(context.Canceled) != FailureInterrupted { t.Fatal("unexpected result") } }) @@ -108,7 +108,7 @@ func TestClassifyGenericError(t *testing.T) { t.Run("with an IPv4 address", func(t *testing.T) { input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error") expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error" - out := classifyGenericError(input) + out := ClassifyGenericError(input) if out != expected { t.Fatal(cmp.Diff(expected, out)) } @@ -117,7 +117,7 @@ func TestClassifyGenericError(t *testing.T) { t.Run("with an IPv6 address", func(t *testing.T) { input := errors.New("read tcp [::1]:56948->[::1]:443: some error") expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error" - out := classifyGenericError(input) + out := ClassifyGenericError(input) if out != expected { t.Fatal(cmp.Diff(expected, out)) } @@ -131,100 +131,100 @@ func TestClassifyQUICHandshakeError(t *testing.T) { t.Run("for input being already an ErrWrapper", func(t *testing.T) { err := &ErrWrapper{Failure: FailureEOFError} - if classifyQUICHandshakeError(err) != FailureEOFError { + if ClassifyQUICHandshakeError(err) != FailureEOFError { t.Fatal("did not classify existing ErrWrapper correctly") } }) t.Run("for incompatible quic version", func(t *testing.T) { - if classifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion { + if ClassifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion { t.Fatal("unexpected results") } }) t.Run("for stateless reset", func(t *testing.T) { - if classifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset { + if ClassifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset { t.Fatal("unexpected results") } }) t.Run("for handshake timeout", func(t *testing.T) { - if classifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError { + if ClassifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError { t.Fatal("unexpected results") } }) t.Run("for idle timeout", func(t *testing.T) { - if classifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError { + if ClassifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError { t.Fatal("unexpected results") } }) t.Run("for connection refused", func(t *testing.T) { - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused { t.Fatal("unexpected results") } }) t.Run("for bad certificate", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertBadCertificate - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { t.Fatal("unexpected results") } }) t.Run("for unsupported certificate", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { t.Fatal("unexpected results") } }) t.Run("for certificate expired", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertCertificateExpired - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { t.Fatal("unexpected results") } }) t.Run("for certificate revoked", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { t.Fatal("unexpected results") } }) t.Run("for certificate unknown", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate { t.Fatal("unexpected results") } }) t.Run("for decrypt error", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertDecryptError - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake { t.Fatal("unexpected results") } }) t.Run("for handshake failure", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake { t.Fatal("unexpected results") } }) t.Run("for unknown CA", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSAlertUnknownCA - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority { t.Fatal("unexpected results") } }) t.Run("for unrecognized hostname", func(t *testing.T) { var err quic.TransportErrorCode = quicTLSUnrecognizedName - if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname { + if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname { t.Fatal("unexpected results") } }) @@ -234,13 +234,13 @@ func TestClassifyQUICHandshakeError(t *testing.T) { ErrorCode: quic.InternalError, ErrorMessage: FailureHostUnreachable, } - if classifyQUICHandshakeError(err) != FailureHostUnreachable { + if ClassifyQUICHandshakeError(err) != FailureHostUnreachable { t.Fatal("unexpected results") } }) t.Run("for another kind of error", func(t *testing.T) { - if classifyQUICHandshakeError(io.EOF) != FailureEOFError { + if ClassifyQUICHandshakeError(io.EOF) != FailureEOFError { t.Fatal("unexpected result") } }) @@ -252,43 +252,43 @@ func TestClassifyResolverError(t *testing.T) { t.Run("for input being already an ErrWrapper", func(t *testing.T) { err := &ErrWrapper{Failure: FailureEOFError} - if classifyResolverError(err) != FailureEOFError { + if ClassifyResolverError(err) != FailureEOFError { t.Fatal("did not classify existing ErrWrapper correctly") } }) t.Run("for ErrDNSBogon", func(t *testing.T) { - if classifyResolverError(ErrDNSBogon) != FailureDNSBogonError { + if ClassifyResolverError(ErrDNSBogon) != FailureDNSBogonError { t.Fatal("unexpected result") } }) t.Run("for refused", func(t *testing.T) { - if classifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError { + if ClassifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError { t.Fatal("unexpected result") } }) t.Run("for servfail", func(t *testing.T) { - if classifyResolverError(ErrOODNSServfail) != FailureDNSServfailError { + if ClassifyResolverError(ErrOODNSServfail) != FailureDNSServfailError { t.Fatal("unexpected result") } }) t.Run("for dns reply with wrong queryID", func(t *testing.T) { - if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID { + if ClassifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID { t.Fatal("unexpected result") } }) t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) { - if classifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData { + if ClassifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData { t.Fatal("unexpected result") } }) t.Run("for another kind of error", func(t *testing.T) { - if classifyResolverError(io.EOF) != FailureEOFError { + if ClassifyResolverError(io.EOF) != FailureEOFError { t.Fatal("unexpected result") } }) @@ -300,34 +300,34 @@ func TestClassifyTLSHandshakeError(t *testing.T) { t.Run("for input being already an ErrWrapper", func(t *testing.T) { err := &ErrWrapper{Failure: FailureEOFError} - if classifyTLSHandshakeError(err) != FailureEOFError { + if ClassifyTLSHandshakeError(err) != FailureEOFError { t.Fatal("did not classify existing ErrWrapper correctly") } }) t.Run("for x509.HostnameError", func(t *testing.T) { var err x509.HostnameError - if classifyTLSHandshakeError(err) != FailureSSLInvalidHostname { + if ClassifyTLSHandshakeError(err) != FailureSSLInvalidHostname { t.Fatal("unexpected result") } }) t.Run("for x509.UnknownAuthorityError", func(t *testing.T) { var err x509.UnknownAuthorityError - if classifyTLSHandshakeError(err) != FailureSSLUnknownAuthority { + if ClassifyTLSHandshakeError(err) != FailureSSLUnknownAuthority { t.Fatal("unexpected result") } }) t.Run("for x509.CertificateInvalidError", func(t *testing.T) { var err x509.CertificateInvalidError - if classifyTLSHandshakeError(err) != FailureSSLInvalidCertificate { + if ClassifyTLSHandshakeError(err) != FailureSSLInvalidCertificate { t.Fatal("unexpected result") } }) t.Run("for another kind of error", func(t *testing.T) { - if classifyTLSHandshakeError(io.EOF) != FailureEOFError { + if ClassifyTLSHandshakeError(io.EOF) != FailureEOFError { t.Fatal("unexpected result") } }) diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index 91ac454..c766fe3 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -125,7 +125,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver, outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors } return &dialerLogger{ - Dialer: &dialerResolver{ + Dialer: &dialerResolverWithTracing{ Dialer: &dialerLogger{ Dialer: outDialer, DebugLogger: logger, @@ -171,15 +171,24 @@ func (d *DialerSystem) CloseIdleConnections() { // nothing to do here } -// dialerResolver combines dialing with domain name resolution. -type dialerResolver struct { +// dialerResolverWithTracing combines dialing with domain name resolution and +// implements hooks to trace TCP (or UDP) connect operations. +type dialerResolverWithTracing struct { Dialer model.Dialer Resolver model.Resolver } -var _ model.Dialer = &dialerResolver{} +var _ model.Dialer = &dialerResolverWithTracing{} -func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +// DialContext implements model.Dialer.DialContext. Specifically this +// method performs the following operations: +// +// 1. resolve the domain inside the address using a resolver; +// +// 2. cycle through the available IP addresses and try to dial each of them; +// +// 3. trace the TCP (or UDP) connect and allow wrapping the returned conn. +func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, address string) (net.Conn, error) { // QUIRK: this routine and the related routines in quirks.go cannot // be changed easily until we use events tracing to measure. // @@ -194,10 +203,27 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin } addrs = quirkSortIPAddrs(addrs) var errorslist []error + trace := ContextTraceOrDefault(ctx) for _, addr := range addrs { target := net.JoinHostPort(addr, onlyport) + started := trace.TimeNow() conn, err := d.Dialer.DialContext(ctx, network, target) + finished := trace.TimeNow() + // TODO(bassosimone): to make the code robust to future refactoring we have + // moved error wrapping inside this type. This change opens up the possibility + // of simplifying the dialing chain by removing dialerErrWrapper. We'll be + // able to implement this refactoring once netx is gone. We cannot complete + // this refactoring _before_ because WrapDialer inserts extra wrappers + // provided by netx in the dialers chain _before_ this dialer and the dialers + // that netx insert assume that they wrap a dialer with error wrapping. + // + // Because error wrapping should be idempotent, it should not be a problem + // to have two error wrapping dialers in the chain except that, of course, it + // would be less efficient than just having a single wrapper. + err = MaybeNewErrWrapper(ClassifyGenericError, ConnectOperation, err) + trace.OnConnectDone(started, network, onlyhost, target, err, finished) if err == nil { + conn = &dialerErrWrapperConn{conn} return conn, nil } errorslist = append(errorslist, err) @@ -206,14 +232,14 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin } // lookupHost ensures we correctly handle IP addresses. -func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) { +func (d *dialerResolverWithTracing) lookupHost(ctx context.Context, hostname string) ([]string, error) { if net.ParseIP(hostname) != nil { return []string{hostname}, nil } return d.Resolver.LookupHost(ctx, hostname) } -func (d *dialerResolver) CloseIdleConnections() { +func (d *dialerResolverWithTracing) CloseIdleConnections() { d.Dialer.CloseIdleConnections() d.Resolver.CloseIdleConnections() } @@ -303,7 +329,7 @@ var _ model.Dialer = &dialerErrWrapper{} func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { - return nil, newErrWrapper(classifyGenericError, ConnectOperation, err) + return nil, NewErrWrapper(ClassifyGenericError, ConnectOperation, err) } return &dialerErrWrapperConn{Conn: conn}, nil } @@ -322,7 +348,7 @@ var _ net.Conn = &dialerErrWrapperConn{} func (c *dialerErrWrapperConn) Read(b []byte) (int, error) { count, err := c.Conn.Read(b) if err != nil { - return 0, newErrWrapper(classifyGenericError, ReadOperation, err) + return 0, NewErrWrapper(ClassifyGenericError, ReadOperation, err) } return count, nil } @@ -330,7 +356,7 @@ func (c *dialerErrWrapperConn) Read(b []byte) (int, error) { func (c *dialerErrWrapperConn) Write(b []byte) (int, error) { count, err := c.Conn.Write(b) if err != nil { - return 0, newErrWrapper(classifyGenericError, WriteOperation, err) + return 0, NewErrWrapper(ClassifyGenericError, WriteOperation, err) } return count, nil } @@ -338,7 +364,7 @@ func (c *dialerErrWrapperConn) Write(b []byte) (int, error) { func (c *dialerErrWrapperConn) Close() error { err := c.Conn.Close() if err != nil { - return newErrWrapper(classifyGenericError, CloseOperation, err) + return NewErrWrapper(ClassifyGenericError, CloseOperation, err) } return nil } diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 701bb29..6e394bd 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -13,6 +13,7 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func TestNewDialerWithStdlibResolver(t *testing.T) { @@ -22,7 +23,7 @@ func TestNewDialerWithStdlibResolver(t *testing.T) { t.Fatal("invalid logger") } // typecheck the resolver - reso := logger.Dialer.(*dialerResolver) + reso := logger.Dialer.(*dialerResolverWithTracing) typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger) // typecheck the dialer logger = reso.Dialer.(*dialerLogger) @@ -64,7 +65,7 @@ func TestNewDialer(t *testing.T) { if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } - reso := logger.Dialer.(*dialerResolver) + reso := logger.Dialer.(*dialerResolverWithTracing) if _, okay := reso.Resolver.(*NullResolver); !okay { t.Fatal("invalid Resolver type") } @@ -136,10 +137,10 @@ func TestDialerSystem(t *testing.T) { }) } -func TestDialerResolver(t *testing.T) { +func TestDialerResolverWithTracing(t *testing.T) { t.Run("DialContext", func(t *testing.T) { t.Run("fails without a port", func(t *testing.T) { - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &DialerSystem{}, Resolver: NewUnwrappedStdlibResolver(), } @@ -154,7 +155,7 @@ func TestDialerResolver(t *testing.T) { }) t.Run("handles dialing error correctly for single IP address", func(t *testing.T) { - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF @@ -166,13 +167,26 @@ func TestDialerResolver(t *testing.T) { if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected") } + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("the error has not been wrapped") + } + if errWrapper.Failure != FailureEOFError { + t.Fatal("invalid wrapped error's failure") + } + if errWrapper.Operation != ConnectOperation { + t.Fatal("invalid wrapped error's operation") + } + if !errors.Is(errWrapper.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error's underlying error") + } if conn != nil { t.Fatal("expected nil conn") } }) t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) { - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF @@ -188,6 +202,19 @@ func TestDialerResolver(t *testing.T) { if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected") } + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("the error has not been wrapped") + } + if errWrapper.Failure != FailureEOFError { + t.Fatal("invalid wrapped error's failure") + } + if errWrapper.Operation != ConnectOperation { + t.Fatal("invalid wrapped error's operation") + } + if !errors.Is(errWrapper.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error's underlying error") + } if conn != nil { t.Fatal("expected nil conn") } @@ -199,7 +226,7 @@ func TestDialerResolver(t *testing.T) { return nil }, } - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return expectedConn, nil @@ -215,7 +242,10 @@ func TestDialerResolver(t *testing.T) { if err != nil { t.Fatal(err) } - if conn != expectedConn { + // Ensure that the dialer returns a connection that is already wrapping errors, + // which is a new behavior since https://github.com/ooni/probe-cli/pull/815 + errWrapperConn := conn.(*dialerErrWrapperConn) + if errWrapperConn.Conn != expectedConn { t.Fatal("unexpected conn") } conn.Close() @@ -225,7 +255,7 @@ func TestDialerResolver(t *testing.T) { // This test is fundamental to the following // TODO(https://github.com/ooni/probe/issues/1779) mu := &sync.Mutex{} - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { // It should not happen to have parallel dials with @@ -257,7 +287,7 @@ func TestDialerResolver(t *testing.T) { // TODO(https://github.com/ooni/probe/issues/1779) mu := &sync.Mutex{} var attempts []string - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { // It should not happen to have parallel dials with @@ -298,14 +328,14 @@ func TestDialerResolver(t *testing.T) { mu := &sync.Mutex{} errorsList := []error{ errors.New("a mocked error"), - newErrWrapper( - classifyGenericError, + NewErrWrapper( + ClassifyGenericError, CloseOperation, io.EOF, ), } var errorIdx int - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { // It should not happen to have parallel dials with @@ -337,17 +367,18 @@ func TestDialerResolver(t *testing.T) { t.Run("though ignores the unknown failures", func(t *testing.T) { // This test is fundamental to the following // TODO(https://github.com/ooni/probe/issues/1779) + expectedErr := errors.New("a mocked error") mu := &sync.Mutex{} errorsList := []error{ - errors.New("a mocked error"), - newErrWrapper( - classifyGenericError, + expectedErr, + NewErrWrapper( + ClassifyGenericError, CloseOperation, - errors.New("antani"), + errors.New("antani"), // this is an unknown failure and we should not return it ), } var errorIdx int - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { // It should not happen to have parallel dials with @@ -368,18 +399,95 @@ func TestDialerResolver(t *testing.T) { }, } conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853") - if err == nil || err.Error() != "a mocked error" { + if !errors.Is(err, expectedErr) { t.Fatal("unexpected err", err) } + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("error has not been wrapped") + } + if errWrapper.Failure != "unknown_failure: a mocked error" { + t.Fatal("unexpected wrapped error's failure") + } + if errWrapper.Operation != ConnectOperation { + t.Fatal("unexpected wrapped error's operation") + } + if !errors.Is(errWrapper.WrappedErr, expectedErr) { + t.Fatal("unexpected wrapped error's underlying error") + } if conn != nil { t.Fatal("expected nil conn") } }) + + t.Run("uses a context-injected custom trace", func(t *testing.T) { + var ( + called bool + domainOK bool + networkOK bool + remoteAddrOK bool + startTimeOK bool + finishTimeOK bool + wrappedErr bool + ) + zeroTime := time.Now() + deterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deterministicTime.Now, + MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { + var ew *ErrWrapper + called = true + domainOK = (domain == "1.1.1.1") + networkOK = (network == "tcp") + remoteAddrOK = (remoteAddr == "1.1.1.1:853") + startTimeOK = (started.Sub(zeroTime) == 0) + finishTimeOK = (finished.Sub(zeroTime) == time.Second) + wrappedErr = errors.As(err, &ew) && ew.Failure == FailureEOFError + }, + } + ctx := ContextWithTrace(context.Background(), tx) + d := &dialerResolverWithTracing{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }, + Resolver: &NullResolver{}, + } + conn, err := d.DialContext(ctx, "tcp", "1.1.1.1:853") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn") + } + if !called { + t.Fatal("not called") + } + if !domainOK { + t.Fatal("domain was not okay") + } + if !networkOK { + t.Fatal("network was not okay") + } + if !remoteAddrOK { + t.Fatal("remoteAddr was not okay") + } + if !startTimeOK { + t.Fatal("start time was not okay") + } + if !finishTimeOK { + t.Fatal("finish time was not okay") + } + if !wrappedErr { + t.Fatal("not wrapped") + } + }) }) t.Run("lookupHost", func(t *testing.T) { t.Run("handles addresses correctly", func(t *testing.T) { - dialer := &dialerResolver{ + dialer := &dialerResolverWithTracing{ Dialer: &DialerSystem{}, Resolver: &NullResolver{}, } @@ -393,7 +501,7 @@ func TestDialerResolver(t *testing.T) { }) t.Run("fails correctly on lookup error", func(t *testing.T) { - dialer := &dialerResolver{ + dialer := &dialerResolverWithTracing{ Dialer: &DialerSystem{}, Resolver: &NullResolver{}, } @@ -413,7 +521,7 @@ func TestDialerResolver(t *testing.T) { calledDialer bool calledResolver bool ) - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockCloseIdleConnections: func() { calledDialer = true diff --git a/internal/netxlite/dnstransport.go b/internal/netxlite/dnstransport.go index 8a779ef..4a8ed39 100644 --- a/internal/netxlite/dnstransport.go +++ b/internal/netxlite/dnstransport.go @@ -38,7 +38,7 @@ func (t *dnsTransportErrWrapper) RoundTrip( ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { resp, err := t.DNSTransport.RoundTrip(ctx, query) if err != nil { - return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err) + return nil, NewErrWrapper(ClassifyResolverError, DNSRoundTripOperation, err) } return resp, nil } diff --git a/internal/netxlite/doc.go b/internal/netxlite/doc.go index ba738fd..678e06b 100644 --- a/internal/netxlite/doc.go +++ b/internal/netxlite/doc.go @@ -34,6 +34,10 @@ // // We want to have reasonable watchdog timeouts for each operation. // +// We also want lightweight support for tracing network events. To this end, we +// use context.WithValue and context.Value to inject, and retrieve, a model.Trace +// implementation OPTIONALLY configured by the user. +// // See also the design document at docs/design/dd-003-step-by-step.md, // which provides an overview of netxlite's main concerns. // diff --git a/internal/netxlite/errwrapper.go b/internal/netxlite/errwrapper.go index 9af52c1..0c34f90 100644 --- a/internal/netxlite/errwrapper.go +++ b/internal/netxlite/errwrapper.go @@ -71,7 +71,7 @@ func (e *ErrWrapper) MarshalJSON() ([]byte, error) { // https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md. type classifier func(err error) string -// newErrWrapper creates a new ErrWrapper using the given +// NewErrWrapper creates a new ErrWrapper using the given // classifier, operation name, and underlying error. // // This function panics if classifier is nil, or operation @@ -81,7 +81,7 @@ type classifier func(err error) string // error wrapper will use the same classification string and // will determine whether to keep the major operation as documented // in the ErrWrapper.Operation documentation. -func newErrWrapper(c classifier, op string, err error) *ErrWrapper { +func NewErrWrapper(c classifier, op string, err error) *ErrWrapper { var wrapper *ErrWrapper if errors.As(err, &wrapper) { return &ErrWrapper{ @@ -106,6 +106,19 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper { } } +// TODO(https://github.com/ooni/probe/issues/2163): we can really +// simplify the error wrapping situation here by just dropping +// NewErrWrapper and always using MaybeNewErrWrapper. + +// MaybeNewErrWrapper is like NewErrWrapper except that this +// function won't panic if passed a nil error. +func MaybeNewErrWrapper(c classifier, op string, err error) error { + if err != nil { + return NewErrWrapper(c, op, err) + } + return nil +} + // NewTopLevelGenericErrWrapper wraps an error occurring at top // level using a generic classifier as classifier. This is the // function you should call when you suspect a given error hasn't @@ -115,7 +128,7 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper { // error wrapper will use the same classification string and // failed operation of the original error. func NewTopLevelGenericErrWrapper(err error) *ErrWrapper { - return newErrWrapper(classifyGenericError, TopLevelOperation, err) + return NewErrWrapper(ClassifyGenericError, TopLevelOperation, err) } func classifyOperation(ew *ErrWrapper, operation string) string { diff --git a/internal/netxlite/errwrapper_test.go b/internal/netxlite/errwrapper_test.go index 2960dcc..7b75ac2 100644 --- a/internal/netxlite/errwrapper_test.go +++ b/internal/netxlite/errwrapper_test.go @@ -53,7 +53,7 @@ func TestNewErrWrapper(t *testing.T) { recovered.Add(1) } }() - newErrWrapper(nil, CloseOperation, io.EOF) + NewErrWrapper(nil, CloseOperation, io.EOF) }() if recovered.Load() != 1 { t.Fatal("did not panic") @@ -68,7 +68,7 @@ func TestNewErrWrapper(t *testing.T) { recovered.Add(1) } }() - newErrWrapper(classifyGenericError, "", io.EOF) + NewErrWrapper(ClassifyGenericError, "", io.EOF) }() if recovered.Load() != 1 { t.Fatal("did not panic") @@ -83,7 +83,7 @@ func TestNewErrWrapper(t *testing.T) { recovered.Add(1) } }() - newErrWrapper(classifyGenericError, CloseOperation, nil) + NewErrWrapper(ClassifyGenericError, CloseOperation, nil) }() if recovered.Load() != 1 { t.Fatal("did not panic") @@ -91,7 +91,7 @@ func TestNewErrWrapper(t *testing.T) { }) t.Run("otherwise, works as intended", func(t *testing.T) { - ew := newErrWrapper(classifyGenericError, CloseOperation, io.EOF) + ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF) if ew.Failure != FailureEOFError { t.Fatal("unexpected failure") } @@ -104,10 +104,10 @@ func TestNewErrWrapper(t *testing.T) { }) t.Run("when the underlying error is already a wrapped error", func(t *testing.T) { - ew := newErrWrapper(classifySyscallError, ReadOperation, ECONNRESET) + ew := NewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET) var err1 error = ew err2 := fmt.Errorf("cannot read: %w", err1) - ew2 := newErrWrapper(classifyGenericError, HTTPRoundTripOperation, err2) + ew2 := NewErrWrapper(ClassifyGenericError, HTTPRoundTripOperation, err2) if ew2.Failure != ew.Failure { t.Fatal("not the same failure") } @@ -117,6 +117,34 @@ func TestNewErrWrapper(t *testing.T) { if ew2.WrappedErr != err2 { t.Fatal("invalid underlying error") } + // Make sure we can still use errors.Is with two layers of wrapping + if !errors.Is(ew2, ECONNRESET) { + t.Fatal("we cannot use errors.Is to retrieve the real syscall error") + } + }) +} + +func TestMaybeNewErrWrapper(t *testing.T) { + // TODO(https://github.com/ooni/probe/issues/2163): we can really + // simplify the error wrapping situation here by just dropping + // NewErrWrapper and always using MaybeNewErrWrapper. + + t.Run("when we pass a nil error to this function", func(t *testing.T) { + err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, nil) + if err != nil { + t.Fatal("unexpected output", err) + } + }) + + t.Run("when we pass a non-nil error to this function", func(t *testing.T) { + err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET) + if !errors.Is(err, ECONNRESET) { + t.Fatal("unexpected output", err) + } + var ew *ErrWrapper + if !errors.As(err, &ew) { + t.Fatal("not an instance of ErrWrapper") + } }) } diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index e50ea8c..a8c5bdf 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -31,7 +31,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) { dialer := txpCc.Dialer dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout) dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger) - dialerReso := dialerLog.Dialer.(*dialerResolver) + dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing) if dialerReso.Resolver != resolver { t.Fatal("invalid resolver") } @@ -52,7 +52,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) { dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout) dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer) dialerLog := dialerProxy.Dialer.(*dialerLogger) - dialerReso := dialerLog.Dialer.(*dialerResolver) + dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing) if dialerReso.Resolver != resolver { t.Fatal("invalid resolver") } @@ -269,7 +269,7 @@ func TestNewHTTPTransport(t *testing.T) { t.Run("works as intended with failing dialer", func(t *testing.T) { called := &atomicx.Int64{} expected := errors.New("mocked error") - d := &dialerResolver{ + d := &dialerResolverWithTracing{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -612,7 +612,7 @@ func TestNewHTTPClientWithResolver(t *testing.T) { txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser) dialer := txpCc.Dialer.(*httpDialerWithReadTimeout) dialerLogger := dialer.Dialer.(*dialerLogger) - dialerReso := dialerLogger.Dialer.(*dialerResolver) + dialerReso := dialerLogger.Dialer.(*dialerResolverWithTracing) if dialerReso.Resolver != reso { t.Fatal("invalid resolver") } diff --git a/internal/netxlite/iox_test.go b/internal/netxlite/iox_test.go index 268b8dd..554f243 100644 --- a/internal/netxlite/iox_test.go +++ b/internal/netxlite/iox_test.go @@ -41,7 +41,7 @@ func TestReadAllContext(t *testing.T) { // // Note: Returning a wrapped error to ensure we address // https://github.com/ooni/probe/issues/1965 - return len(b), newErrWrapper(classifyGenericError, + return len(b), NewErrWrapper(ClassifyGenericError, ReadOperation, io.EOF) }, } @@ -171,7 +171,7 @@ func TestCopyContext(t *testing.T) { // // Note: Returning a wrapped error to ensure we address // https://github.com/ooni/probe/issues/1965 - return len(b), newErrWrapper(classifyGenericError, + return len(b), NewErrWrapper(ClassifyGenericError, ReadOperation, io.EOF) }, } diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index f989771..24ddabc 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -380,7 +380,7 @@ var _ model.QUICListener = &quicListenerErrWrapper{} func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) { pconn, err := qls.QUICListener.Listen(addr) if err != nil { - return nil, newErrWrapper(classifyGenericError, QUICListenOperation, err) + return nil, NewErrWrapper(ClassifyGenericError, QUICListenOperation, err) } return &quicErrWrapperUDPLikeConn{pconn}, nil } @@ -397,7 +397,7 @@ var _ model.UDPLikeConn = &quicErrWrapperUDPLikeConn{} func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) { count, err := c.UDPLikeConn.WriteTo(p, addr) if err != nil { - return 0, newErrWrapper(classifyGenericError, WriteToOperation, err) + return 0, NewErrWrapper(ClassifyGenericError, WriteToOperation, err) } return count, nil } @@ -406,7 +406,7 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) { n, addr, err := c.UDPLikeConn.ReadFrom(b) if err != nil { - return 0, nil, newErrWrapper(classifyGenericError, ReadFromOperation, err) + return 0, nil, NewErrWrapper(ClassifyGenericError, ReadFromOperation, err) } return n, addr, nil } @@ -415,7 +415,7 @@ func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *quicErrWrapperUDPLikeConn) Close() error { err := c.UDPLikeConn.Close() if err != nil { - return newErrWrapper(classifyGenericError, ReadFromOperation, err) + return NewErrWrapper(ClassifyGenericError, ReadFromOperation, err) } return nil } @@ -433,8 +433,8 @@ func (d *quicDialerErrWrapper) DialContext( tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg) if err != nil { - return nil, newErrWrapper( - classifyQUICHandshakeError, QUICHandshakeOperation, err) + return nil, NewErrWrapper( + ClassifyQUICHandshakeError, QUICHandshakeOperation, err) } return qconn, nil } diff --git a/internal/netxlite/quirks_test.go b/internal/netxlite/quirks_test.go index 298d640..ae32696 100644 --- a/internal/netxlite/quirks_test.go +++ b/internal/netxlite/quirks_test.go @@ -34,13 +34,13 @@ func TestQuirkReduceErrors(t *testing.T) { t.Run("multiple errors with meaningful ones", func(t *testing.T) { err1 := errors.New("mocked error #1") - err2 := newErrWrapper( - classifyGenericError, + err2 := NewErrWrapper( + ClassifyGenericError, CloseOperation, errors.New("antani"), ) - err3 := newErrWrapper( - classifyGenericError, + err3 := NewErrWrapper( + ClassifyGenericError, CloseOperation, ECONNREFUSED, ) diff --git a/internal/netxlite/resolvercore.go b/internal/netxlite/resolvercore.go index 08c81f4..12ec5a9 100644 --- a/internal/netxlite/resolvercore.go +++ b/internal/netxlite/resolvercore.go @@ -387,7 +387,7 @@ var _ model.Resolver = &resolverErrWrapper{} func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) { addrs, err := r.Resolver.LookupHost(ctx, hostname) if err != nil { - return nil, newErrWrapper(classifyResolverError, ResolveOperation, err) + return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err) } return addrs, nil } @@ -396,7 +396,7 @@ func (r *resolverErrWrapper) LookupHTTPS( ctx context.Context, domain string) (*model.HTTPSSvc, error) { out, err := r.Resolver.LookupHTTPS(ctx, domain) if err != nil { - return nil, newErrWrapper(classifyResolverError, ResolveOperation, err) + return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err) } return out, nil } @@ -417,7 +417,7 @@ func (r *resolverErrWrapper) LookupNS( ctx context.Context, domain string) ([]*net.NS, error) { out, err := r.Resolver.LookupNS(ctx, domain) if err != nil { - return nil, newErrWrapper(classifyResolverError, ResolveOperation, err) + return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err) } return out, nil } diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 314626e..31ff4d6 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -18,9 +18,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) -// TODO(bassosimone): check whether there's now equivalent functionality -// inside the standard library allowing us to map numbers to names. - var ( tlsVersionString = map[uint16]string{ tls.VersionTLS10: "TLSv1", @@ -85,6 +82,13 @@ func TLSVersionString(value uint16) string { // the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd` // where `ddd` is the numeric value passed to this function. func TLSCipherSuiteString(value uint16) string { + // TODO(https://github.com/ooni/probe/issues/2166): the standard library has a + // function for mapping a cipher suite to a string, but the value returned in case of + // missing cipher suite is different from the one we would return + // here. We could consider simplifying this code anyway because + // in most, if not all, cases we have a valid cipher suite and we + // just need to make sure what the spec says we should do when + // passed an unknown cipher suite. if str, found := tlsCipherSuiteString[value]; found { return str } @@ -158,15 +162,15 @@ func NewTLSHandshakerStdlib(logger model.DebugLogger) model.TLSHandshaker { // newTLSHandshaker is the common factory for creating a new TLSHandshaker func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker { return &tlsHandshakerLogger{ - TLSHandshaker: &tlsHandshakerErrWrapper{ - TLSHandshaker: th, - }, - DebugLogger: logger, + TLSHandshaker: th, + DebugLogger: logger, } } // tlsHandshakerConfigurable is a configurable TLS handshaker that // uses by default the standard library's TLS implementation. +// +// This type also implements error wrapping and events tracing. type tlsHandshakerConfigurable struct { // NewConn is the OPTIONAL factory for creating a new connection. If // this factory is not set, we'll use the stdlib. @@ -183,9 +187,20 @@ var _ model.TLSHandshaker = &tlsHandshakerConfigurable{} // value into a private variable to enable for unit testing. var defaultCertPool = NewDefaultCertPool() +// tlsMaybeConnectionState returns the connection state if error is nil +// and otherwise just returns an empty state to the caller. +func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState { + if err != nil { + return tls.ConnectionState{} + } + return conn.ConnectionState() +} + // Handshake implements Handshaker.Handshake. This function will // configure the code to use the built-in Mozilla CA if the config // field contains a nil RootCAs field. +// +// This function will also emit TLS-handshake-related tracing events. func (h *tlsHandshakerConfigurable) Handshake( ctx context.Context, conn net.Conn, config *tls.Config, ) (net.Conn, tls.ConnectionState, error) { @@ -203,10 +218,19 @@ func (h *tlsHandshakerConfigurable) Handshake( if err != nil { return nil, tls.ConnectionState{}, err } - if err := tlsconn.HandshakeContext(ctx); err != nil { + remoteAddr := conn.RemoteAddr().String() + trace := ContextTraceOrDefault(ctx) + started := trace.TimeNow() + trace.OnTLSHandshakeStart(started, remoteAddr, config) + err = tlsconn.HandshakeContext(ctx) + err = MaybeNewErrWrapper(ClassifyTLSHandshakeError, TLSHandshakeOperation, err) + finished := trace.TimeNow() + state := tlsMaybeConnectionState(tlsconn, err) + trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished) + if err != nil { return nil, tls.ConnectionState{}, err } - return tlsconn, tlsconn.ConnectionState(), nil + return tlsconn, state, nil } // newConn creates a new TLSConn. @@ -352,23 +376,6 @@ func (d *tlsDialerSingleUseAdapter) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -// tlsHandshakerErrWrapper wraps the returned error to be an OONI error -type tlsHandshakerErrWrapper struct { - TLSHandshaker model.TLSHandshaker -} - -// Handshake implements TLSHandshaker.Handshake -func (h *tlsHandshakerErrWrapper) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) - if err != nil { - return nil, tls.ConnectionState{}, newErrWrapper( - classifyTLSHandshakeError, TLSHandshakeOperation, err) - } - return tlsconn, state, nil -} - // ErrNoTLSDialer is the type of error returned by "null" TLS dialers // when you attempt to dial with them. var ErrNoTLSDialer = errors.New("no configured TLS dialer") diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 6d87af0..5075c55 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -16,7 +16,10 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func TestVersionString(t *testing.T) { @@ -123,8 +126,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } - errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper) - configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable) + configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable) if configurable.NewConn != nil { t.Fatal("expected nil NewConn") } @@ -132,7 +134,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { func TestTLSHandshakerConfigurable(t *testing.T) { t.Run("Handshake", func(t *testing.T) { - t.Run("with error", func(t *testing.T) { + t.Run("with handshake I/O error", func(t *testing.T) { var times []time.Time h := &tlsHandshakerConfigurable{} tcpConn := &mocks.Conn{ @@ -143,14 +145,37 @@ func TestTLSHandshakerConfigurable(t *testing.T) { times = append(times, t) return nil }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "1.1.1.1:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, } ctx := context.Background() - conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{ + conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{ ServerName: "x.org", }) - if err != io.EOF { + if !errors.Is(err, io.EOF) { t.Fatal("not the error that we expected") } + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("the error has not been wrapped") + } + if errWrapper.Failure != FailureEOFError { + t.Fatal("invalid wrapped error's failure") + } + if errWrapper.Operation != TLSHandshakeOperation { + t.Fatal("invalid wrapped error's operation") + } + if !errors.Is(errWrapper.WrappedErr, io.EOF) { + t.Fatal("invalid wrapped error's underlying error") + } if conn != nil { t.Fatal("expected nil con here") } @@ -163,6 +188,9 @@ func TestTLSHandshakerConfigurable(t *testing.T) { if !times[1].IsZero() { t.Fatal("did not clear timeout on exit") } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("the returned connection state is not a zero value") + } }) t.Run("with success", func(t *testing.T) { @@ -217,6 +245,16 @@ func TestTLSHandshakerConfigurable(t *testing.T) { MockSetDeadline: func(t time.Time) error { return nil }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "1.1.1.1:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, } tlsConn, connState, err := handshaker.Handshake(ctx, conn, config) if !errors.Is(err, expected) { @@ -236,7 +274,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) { } }) - t.Run("we cannot create a new conn", func(t *testing.T) { + t.Run("h.newConn fails", func(t *testing.T) { expected := errors.New("mocked error") handshaker := &tlsHandshakerConfigurable{ NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) { @@ -261,6 +299,222 @@ func TestTLSHandshakerConfigurable(t *testing.T) { t.Fatal("expected nil tlsConn here") } }) + + t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) { + var ( + expectedSNI = "dns.google" + goodStartStartTime bool + goodStartInsecureSkipVerify bool + goodDoneInsecureSkipVerify bool + goodStartServerName bool + goodDoneServerName bool + goodDoneStartTime bool + goodDoneDoneTime bool + goodStartRemoteAddr bool + goodDoneRemoteAddr bool + goodDoneError bool + goodConnectionState bool + startCalled bool + doneCalled bool + ) + server := filtering.NewTLSServer(filtering.TLSActionBlockText) + defer server.Close() + zeroTime := time.Now() + deterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deterministicTime.Now, + MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) { + startCalled = true + goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true) + goodStartServerName = (config.ServerName == expectedSNI) + goodStartStartTime = (now.Sub(zeroTime) == 0) + goodStartRemoteAddr = (remoteAddr == server.Endpoint()) + }, + MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) { + doneCalled = true + goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true) + goodDoneServerName = (config.ServerName == expectedSNI) + goodDoneStartTime = (started.Sub(zeroTime) == 0) + goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second) + goodDoneRemoteAddr = (remoteAddr == server.Endpoint()) + goodDoneError = (err == nil) + goodConnectionState = (!reflect.ValueOf(state).IsZero()) + }, + } + ctx := ContextWithTrace(context.Background(), tx) + tcpConn, err := net.Dial("tcp", server.Endpoint()) + if err != nil { + t.Fatal(err) + } + thx := NewTLSHandshakerStdlib(model.DiscardLogger) + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: expectedSNI, + } + tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if err != nil { + t.Fatal(err) + } + tlsConn.Close() + if reflect.ValueOf(connState).IsZero() { + t.Fatal("expected nonzero connState") + } + if !startCalled { + t.Fatal("start not called") + } + if !doneCalled { + t.Fatal("done not called") + } + if !goodStartInsecureSkipVerify { + t.Fatal("invalid start-event's InsecureSkipVerify") + } + if !goodDoneInsecureSkipVerify { + t.Fatal("invalid done-event's InsecureSkipVerify") + } + if !goodStartServerName { + t.Fatal("invalid start-event's ServerName") + } + if !goodDoneServerName { + t.Fatal("invalid done-event's ServerName") + } + if !goodStartStartTime { + t.Fatal("invalid start-event's start time") + } + if !goodDoneStartTime { + t.Fatal("invalid done-event's start time") + } + if !goodDoneDoneTime { + t.Fatal("invalid done-event's done time") + } + if !goodStartRemoteAddr { + t.Fatal("invalid start-event's remoteAddr") + } + if !goodDoneRemoteAddr { + t.Fatal("invalid done-event's remoteAddr") + } + if !goodDoneError { + t.Fatal("invalid done-event's error") + } + if !goodConnectionState { + t.Fatal("invalid done-event's connState") + } + }) + + t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) { + var ( + expectedEndpoint = "8.8.8.8:443" + expectedSNI = "dns.google" + goodStartStartTime bool + goodStartInsecureSkipVerify bool + goodDoneInsecureSkipVerify bool + goodStartServerName bool + goodDoneServerName bool + goodDoneStartTime bool + goodDoneDoneTime bool + goodStartRemoteAddr bool + goodDoneRemoteAddr bool + goodDoneError bool + goodConnectionState bool + startCalled bool + doneCalled bool + ) + zeroTime := time.Now() + deterministicTime := testingx.NewTimeDeterministic(zeroTime) + tx := &mocks.Trace{ + MockTimeNow: deterministicTime.Now, + MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) { + startCalled = true + goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true) + goodStartServerName = (config.ServerName == expectedSNI) + goodStartStartTime = (now.Sub(zeroTime) == 0) + goodStartRemoteAddr = (remoteAddr == expectedEndpoint) + }, + MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) { + doneCalled = true + goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true) + goodDoneServerName = (config.ServerName == expectedSNI) + goodDoneStartTime = (started.Sub(zeroTime) == 0) + goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second) + goodDoneRemoteAddr = (remoteAddr == expectedEndpoint) + var ew *ErrWrapper + goodDoneError = (errors.As(err, &ew) && ew.Error() == FailureEOFError) + goodConnectionState = (reflect.ValueOf(state).IsZero()) + }, + } + ctx := ContextWithTrace(context.Background(), tx) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return expectedEndpoint + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + thx := NewTLSHandshakerStdlib(model.DiscardLogger) + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: expectedSNI, + } + tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, io.EOF) { + t.Fatal("unexpected err", err) + } + if tlsConn != nil { + t.Fatal("expected nil tlsConn") + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero connState") + } + if !startCalled { + t.Fatal("start not called") + } + if !doneCalled { + t.Fatal("done not called") + } + if !goodStartInsecureSkipVerify { + t.Fatal("invalid start-event's InsecureSkipVerify") + } + if !goodDoneInsecureSkipVerify { + t.Fatal("invalid done-event's InsecureSkipVerify") + } + if !goodStartServerName { + t.Fatal("invalid start-event's ServerName") + } + if !goodDoneServerName { + t.Fatal("invalid done-event's ServerName") + } + if !goodStartStartTime { + t.Fatal("invalid start-event's start time") + } + if !goodDoneStartTime { + t.Fatal("invalid done-event's start time") + } + if !goodDoneDoneTime { + t.Fatal("invalid done-event's done time") + } + if !goodStartRemoteAddr { + t.Fatal("invalid start-event's remoteAddr") + } + if !goodDoneRemoteAddr { + t.Fatal("invalid done-event's remoteAddr") + } + if !goodDoneError { + t.Fatal("invalid done-event's error") + } + if !goodConnectionState { + t.Fatal("invalid done-event's connState") + } + }) }) } @@ -413,6 +667,15 @@ func TestTLSDialer(t *testing.T) { return nil }, MockSetDeadline: func(t time.Time) error { return nil + }, MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "1.1.1.1:443" + }, + MockString: func() string { + return "tcp" + }, + } }}, nil }}, TLSHandshaker: &tlsHandshakerConfigurable{}, @@ -532,54 +795,6 @@ func TestNewSingleUseTLSDialer(t *testing.T) { } } -func TestTLSHandshakerErrWrapper(t *testing.T) { - t.Run("Handshake", func(t *testing.T) { - t.Run("on success", func(t *testing.T) { - expectedConn := &mocks.TLSConn{} - expectedState := tls.ConnectionState{ - Version: tls.VersionTLS12, - } - th := &tlsHandshakerErrWrapper{ - TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return expectedConn, expectedState, nil - }, - }, - } - ctx := context.Background() - conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) - if err != nil { - t.Fatal(err) - } - if expectedState.Version != state.Version { - t.Fatal("unexpected state") - } - if expectedConn != conn { - t.Fatal("unexpected conn") - } - }) - - t.Run("on failure", func(t *testing.T) { - expectedErr := io.EOF - th := &tlsHandshakerErrWrapper{ - TLSHandshaker: &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, expectedErr - }, - }, - } - ctx := context.Background() - conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) - if err == nil || err.Error() != FailureEOFError { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("unexpected conn") - } - }) - }) -} - func TestNewNullTLSDialer(t *testing.T) { dialer := NewNullTLSDialer() conn, err := dialer.DialTLSContext(context.Background(), "", "") @@ -618,3 +833,35 @@ func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) { } }) } + +func TestMaybeConnectionState(t *testing.T) { + t.Run("with an error", func(t *testing.T) { + returned := tls.ConnectionState{ + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + } + conn := &mocks.TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return returned + }, + } + state := tlsMaybeConnectionState(conn, errors.New("mocked error")) + if !reflect.ValueOf(state).IsZero() { + t.Fatal("expected to see a zero connection state") + } + }) + + t.Run("without an error", func(t *testing.T) { + returned := tls.ConnectionState{ + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + } + conn := &mocks.TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return returned + }, + } + state := tlsMaybeConnectionState(conn, nil) + if reflect.ValueOf(state).IsZero() { + t.Fatal("expected to see a nonzero connection state") + } + }) +} diff --git a/internal/netxlite/trace.go b/internal/netxlite/trace.go new file mode 100644 index 0000000..9efc175 --- /dev/null +++ b/internal/netxlite/trace.go @@ -0,0 +1,67 @@ +package netxlite + +// +// Context-based tracing +// + +import ( + "context" + "crypto/tls" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// traceKey is the private type used to set/retrieve the context's trace. +type traceKey struct{} + +// ContextTraceOrDefault retrieves the trace bound to the context or returns +// a default implementation of the trace in case no tracing was configured. +func ContextTraceOrDefault(ctx context.Context) model.Trace { + t, _ := ctx.Value(traceKey{}).(model.Trace) + return traceOrDefault(t) +} + +// ContextWithTrace returns a new context that binds to the given trace. If the +// given trace is nil, this function will call panic. +func ContextWithTrace(ctx context.Context, trace model.Trace) context.Context { + runtimex.PanicIfTrue(trace == nil, "netxlite.WithTrace passed a nil trace") + return context.WithValue(ctx, traceKey{}, trace) +} + +// traceOrDefault takes in input a trace and returns in output the +// given trace, if not nil, or a default trace implementation. +func traceOrDefault(trace model.Trace) model.Trace { + if trace != nil { + return trace + } + return &traceDefault{} +} + +// traceDefault is a default model.Trace implementation where each method is a no-op. +type traceDefault struct{} + +var _ model.Trace = &traceDefault{} + +// TimeNow implements model.Trace.TimeNow +func (*traceDefault) TimeNow() time.Time { + return time.Now() +} + +// OnConnectDone implements model.Trace.OnConnectDone. +func (*traceDefault) OnConnectDone( + started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { + // nothing +} + +// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart. +func (*traceDefault) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) { + // nothing +} + +// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone. +func (*traceDefault) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, + state tls.ConnectionState, err error, finished time.Time) { + // nothing +} diff --git a/internal/netxlite/trace_test.go b/internal/netxlite/trace_test.go new file mode 100644 index 0000000..69a3c2b --- /dev/null +++ b/internal/netxlite/trace_test.go @@ -0,0 +1,40 @@ +package netxlite + +import ( + "context" + "testing" + + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestContextTraceOrDefault(t *testing.T) { + t.Run("without a configured trace we get a default", func(t *testing.T) { + ctx := context.Background() + tx := ContextTraceOrDefault(ctx) + _ = tx.(*traceDefault) // panic if cannot cast + }) + + t.Run("with a configured trace we get the expected trace", func(t *testing.T) { + realTrace := &mocks.Trace{} + ctx := ContextWithTrace(context.Background(), realTrace) + tx := ContextTraceOrDefault(ctx) + if tx != realTrace { + t.Fatal("not the trace we expected") + } + }) +} + +func TestContextWithTrace(t *testing.T) { + t.Run("panics if passed a nil trace", func(t *testing.T) { + var called bool + func() { + defer func() { + called = (recover() != nil) + }() + _ = ContextWithTrace(context.Background(), nil) + }() + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/netxlite/utls_test.go b/internal/netxlite/utls_test.go index e6d35fc..54cabab 100644 --- a/internal/netxlite/utls_test.go +++ b/internal/netxlite/utls_test.go @@ -19,8 +19,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) { if logger.DebugLogger != log.Log { t.Fatal("invalid logger") } - errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper) - configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable) + configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable) if configurable.NewConn == nil { t.Fatal("expected non-nil NewConn") } diff --git a/internal/testingx/doc.go b/internal/testingx/doc.go new file mode 100644 index 0000000..6b7e944 --- /dev/null +++ b/internal/testingx/doc.go @@ -0,0 +1,2 @@ +// Package testingx contains code useful for testing. +package testingx diff --git a/internal/fakefill/fakefill.go b/internal/testingx/fakefill.go similarity index 85% rename from internal/fakefill/fakefill.go rename to internal/testingx/fakefill.go index c5d0d9b..1adea3e 100644 --- a/internal/fakefill/fakefill.go +++ b/internal/testingx/fakefill.go @@ -1,11 +1,4 @@ -// Package fakefill contains code to fill structs for testing. -// -// This package is quite limited in scope and we can fill only the -// structures you typically send over as JSONs. -// -// As part of future work, we aim to investigate whether we can -// replace this implementation with https://go.dev/blog/fuzz-beta. -package fakefill +package testingx import ( "math/rand" @@ -14,7 +7,7 @@ import ( "time" ) -// Filler fills specific data structures with random data. The only +// FakeFiller fills specific data structures with random data. The only // exception to this behaviour is time.Time, which is instead filled // with the current time plus a small random number of seconds. // @@ -25,7 +18,13 @@ import ( // Caveat: this kind of fillter does not support filling interfaces // and channels and other complex types. The current behavior when this // kind of data types is encountered is to just ignore them. -type Filler struct { +// +// This struct is quite limited in scope and we can fill only the +// structures you typically send over as JSONs. +// +// As part of future work, we aim to investigate whether we can +// replace this implementation with https://go.dev/blog/fuzz-beta. +type FakeFiller struct { // mu provides mutual exclusion mu sync.Mutex @@ -37,7 +36,7 @@ type Filler struct { rnd *rand.Rand } -func (ff *Filler) getRandLocked() *rand.Rand { +func (ff *FakeFiller) getRandLocked() *rand.Rand { if ff.rnd == nil { now := time.Now if ff.Now != nil { @@ -48,7 +47,7 @@ func (ff *Filler) getRandLocked() *rand.Rand { return ff.rnd } -func (ff *Filler) getRandomString() string { +func (ff *FakeFiller) getRandomString() string { defer ff.mu.Unlock() ff.mu.Lock() rnd := ff.getRandLocked() @@ -62,28 +61,28 @@ func (ff *Filler) getRandomString() string { return string(b) } -func (ff *Filler) getRandomInt64() int64 { +func (ff *FakeFiller) getRandomInt64() int64 { defer ff.mu.Unlock() ff.mu.Lock() rnd := ff.getRandLocked() return rnd.Int63() } -func (ff *Filler) getRandomBool() bool { +func (ff *FakeFiller) getRandomBool() bool { defer ff.mu.Unlock() ff.mu.Lock() rnd := ff.getRandLocked() return rnd.Float64() >= 0.5 } -func (ff *Filler) getRandomSmallPositiveInt() int { +func (ff *FakeFiller) getRandomSmallPositiveInt() int { defer ff.mu.Unlock() ff.mu.Lock() rnd := ff.getRandLocked() return int(rnd.Int63n(8)) + 1 // safe cast } -func (ff *Filler) doFill(v reflect.Value) { +func (ff *FakeFiller) doFill(v reflect.Value) { for v.Type().Kind() == reflect.Ptr { if v.IsNil() { // if the pointer is nil, allocate an element @@ -134,6 +133,6 @@ func (ff *Filler) doFill(v reflect.Value) { } // Fill fills the input structure or pointer with random data. -func (ff *Filler) Fill(in interface{}) { +func (ff *FakeFiller) Fill(in interface{}) { ff.doFill(reflect.ValueOf(in)) } diff --git a/internal/fakefill/fakefill_test.go b/internal/testingx/fakefill_test.go similarity index 93% rename from internal/fakefill/fakefill_test.go rename to internal/testingx/fakefill_test.go index c31c593..7e3f7f1 100644 --- a/internal/fakefill/fakefill_test.go +++ b/internal/testingx/fakefill_test.go @@ -1,4 +1,4 @@ -package fakefill +package testingx import ( "testing" @@ -16,7 +16,7 @@ type exampleStructure struct { func TestFakeFillWorksWithCustomTime(t *testing.T) { var req *exampleStructure - ff := &Filler{ + ff := &FakeFiller{ Now: func() time.Time { return time.Date(1992, time.January, 24, 17, 53, 0, 0, time.UTC) }, @@ -29,7 +29,7 @@ func TestFakeFillWorksWithCustomTime(t *testing.T) { func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) { var req *exampleStructure - ff := &Filler{} + ff := &FakeFiller{} ff.Fill(&req) if req == nil { t.Fatal("we expected non nil here") @@ -38,7 +38,7 @@ func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) { func TestFakeFillAllocatesIntoAMapLikeWithStringKeys(t *testing.T) { var resp map[string]*exampleStructure - ff := &Filler{} + ff := &FakeFiller{} ff.Fill(&resp) if resp == nil { t.Fatal("we expected non nil here") @@ -62,7 +62,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) { } }() var resp map[int64]*exampleStructure - ff := &Filler{} + ff := &FakeFiller{} ff.Fill(&resp) if resp != nil { t.Fatal("we expected nil here") @@ -75,7 +75,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) { func TestFakeFillAllocatesIntoASlice(t *testing.T) { var resp *[]*exampleStructure - ff := &Filler{} + ff := &FakeFiller{} ff.Fill(&resp) if resp == nil { t.Fatal("we expected non nil here") diff --git a/internal/testingx/time.go b/internal/testingx/time.go new file mode 100644 index 0000000..68abdf8 --- /dev/null +++ b/internal/testingx/time.go @@ -0,0 +1,48 @@ +package testingx + +import ( + "sync" + "time" +) + +// TimeDeterministic implements time.Now in a deterministic fashion +// such that every time.Time call returns a moment in time that occurs +// one second after the configured zeroTime. +// +// It's safe to use this struct from multiple goroutine contexts. +type TimeDeterministic struct { + // counter counts the number of "ticks" passed since the zero time: each + // call to Now increments this counter by one second. + counter time.Duration + + // mu protects fields in this structure from concurrent access. + mu sync.Mutex + + // zeroTime is the lazy-initialized zero time. The first call to Now + // will initialize this field with the current time. + zeroTime time.Time +} + +// NewTimeDeterministic creates a new instance using the given zeroTime value. +func NewTimeDeterministic(zeroTime time.Time) *TimeDeterministic { + return &TimeDeterministic{ + counter: 0, + mu: sync.Mutex{}, + zeroTime: zeroTime, + } +} + +// Now is like time.Now but more deterministic. The first call returns the +// configured zeroTime and subsequent calls return moments in time that occur +// exactly one second after the time returned by the previous call. +func (td *TimeDeterministic) Now() time.Time { + td.mu.Lock() + if td.zeroTime.IsZero() { + td.zeroTime = time.Now() + } + offset := td.counter + td.counter += time.Second + res := td.zeroTime.Add(offset) + td.mu.Unlock() + return res +} diff --git a/internal/testingx/time_test.go b/internal/testingx/time_test.go new file mode 100644 index 0000000..fd00e99 --- /dev/null +++ b/internal/testingx/time_test.go @@ -0,0 +1,22 @@ +package testingx + +import ( + "testing" + "time" +) + +func TestTimeDeterministic(t *testing.T) { + td := &TimeDeterministic{} + t0 := td.Now() + if !t0.Equal(td.zeroTime) { + t.Fatal("invalid t0 value") + } + t1 := td.Now() + if t1.Sub(t0) != time.Second { + t.Fatal("invalid t1 value") + } + t2 := td.Now() + if t2.Sub(t1) != time.Second { + t.Fatal("invalid t2 value") + } +} diff --git a/internal/tracex/doc.go b/internal/tracex/doc.go index 7a2b467..c6ce3f7 100644 --- a/internal/tracex/doc.go +++ b/internal/tracex/doc.go @@ -5,4 +5,8 @@ // resulting connection). When done, we will extract the trace containing // all the events that occurred from the saver and process it to determine // what happened during the measurement itself. +// +// This package is now frozen. Please, use measurexlite for new code. See +// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md +// for details about this. package tracex