feat: tlsping and tcpping using step-by-step (#815)

## Checklist

- [x] I have read the [contribution guidelines](https://github.com/ooni/probe-cli/blob/master/CONTRIBUTING.md)
- [x] reference issue for this pull request: https://github.com/ooni/probe/issues/2158
- [x] if you changed anything related how experiments work and you need to reflect these changes in the ooni/spec repository, please link to the related ooni/spec pull request: https://github.com/ooni/spec/pull/250

## Description

This diff refactors the codebase to reimplement tlsping and tcpping
to use the step-by-step measurements style.

See docs/design/dd-003-step-by-step.md for more information on the
step-by-step measurement style.
This commit is contained in:
Simone Basso 2022-07-01 12:22:22 +02:00 committed by GitHub
parent 5371c7f486
commit 5ebdeb56ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 2825 additions and 299 deletions

View File

@ -874,7 +874,7 @@ mechanisms if the context is too bad):
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
qtype uint16, out chan\<- \*parallelResolverResult) {
qtype uint16, out chan<- *parallelResolverResult) {
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
+ started := time.Now()
@ -1329,7 +1329,7 @@ const webDomain = "web.telegram.org"
// This method does not return any value and writes results directly inside
// the test keys, which have thread safe methods for that.
func (mx *Measurer) measureWebEndpointHTTPS(ctx context.Context, wg *sync.WaitGroup,
logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
logger model.Logger, zeroTime time.Time, tk *TestKeys, address string) {
// 0. setup
const webTimeout = 7 * time.Second
ctx, cancel := context.WithTimeout(ctx, webTimeout)
@ -1344,8 +1344,8 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
// 1. establish a TCP connection with the endpoint
// dialer := nextlite.NewDialerwithoutResolver(logger) // --- (removed line)
trace := measurexlite.NewTrace(index, logger, zeroTime) // +++ (added line)
dialer := trace.NewDialerWithoutResolver() // +++ (...)
trace := measurexlite.NewTrace(index, zeroTime) // +++ (added line)
dialer := trace.NewDialerWithoutResolver(logger) // +++ (...)
defer tk.addTCPConnectResults(trace.TCPConnectResults()) // +++ (...)
conn, err := dialer.DialContext(ctx, "tcp", endpoint)
@ -1366,7 +1366,7 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
// thx := netxlite.NewTLSHandshakerStdlib(logger) // ---
conn = trace.WrapConn(conn) // +++
defer tk.addNetworkEvents(trace.NetworkEvents()) // +++
thx := trace.NewTLSHandshakerStdlib() // +++
thx := trace.NewTLSHandshakerStdlib(logger) // +++
defer tk.addTLSHandshakeResult(trace.TLSHandshakeResults()) // +++
config := &tls.Config{
@ -1498,23 +1498,22 @@ type Trace struct { /* ... */ }
var _ model.Trace = &Trace{}
func NewTrace(index int64, logger model.Logger, zeroTime time.Time) *Trace {
func NewTrace(index int64, zeroTime time.Time) *Trace {
const (
tlsHandshakeBuffer = 16,
// ...
)
return &Trace{
Index: index,
Logger: logger,
TLS: make(chan *model.ArchivalTLSOrQUICHandshakeResult, tlsHandshakeBuffer),
ZeroTime: zeroTime,
// ...
}
}
func (tx *Trace) NewTLSHandshakerStdlib() model.TLSHandshaker {
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
return &tlsHandshakerTrace{
TLSHandshaker: netxlite.NewTLSHandshakerStdlib(tx.Logger),
TLSHandshaker: netxlite.NewTLSHandshakerStdlib(dl),
Trace: tx,
}
}
@ -1524,8 +1523,8 @@ type tlsHandshakerTrace { /* ... */ }
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
func (thx *tlsHandshakerTrace) Handshake(ctx context.Context,
conn net.Conn, config \*tls.Config) (net.Conn, tls.ConnectionState, error) {
ctx = netxlite.ContextWithTrace(ctx, thx.Trace) // <- here we setup the context magic
conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
ctx = netxlite.ContextWithTraceOrDefault(ctx, thx.Trace) // <- here we setup the context magic
return thx.TLSHandshaker.Handshake(ctx, conn, config)
}
@ -1540,7 +1539,7 @@ func (tx *Trace) OnTLSHandshake(started time.Time, remoteAddr string,
}
}
func (tx *Trace) TLSHandshakeResults() (out []\*model.ArchivalTLSOrQUICHandshakeResult) {
func (tx *Trace) TLSHandshakeResults() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
for {
select {
case ev := <-tx.TLS:
@ -1575,7 +1574,7 @@ func (thx *tlsHandshakerConfigurable) Handshake(ctx context.Context,
state := tlsconn.connectionState()
trace.OnTLSHandshake(started, remoteAddr, config, // +++
state, nil, finished) // +++
return tlsconn, state, nil
return tlsconn, state, nil
}
func ContextTraceOrDefault(ctx context.Context) model.Trace { /* ... */ }

View File

@ -10,13 +10,13 @@ import (
"net/url"
"time"
"github.com/ooni/probe-cli/v3/internal/measurex"
"github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/model"
)
const (
testName = "tcpping"
testVersion = "0.1.0"
testVersion = "0.2.0"
)
// Config contains the experiment configuration.
@ -49,7 +49,7 @@ type TestKeys struct {
// SinglePing contains the results of a single ping.
type SinglePing struct {
TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"`
TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"`
}
// Measurer performs the measurement.
@ -103,42 +103,47 @@ func (m *Measurer) Run(
}
tk := new(TestKeys)
measurement.TestKeys = tk
out := make(chan *measurex.EndpointMeasurement)
mxmx := measurex.NewMeasurerWithDefaultSettings()
go m.tcpPingLoop(ctx, mxmx, parsed.Host, out)
out := make(chan *SinglePing)
go m.tcpPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out)
for len(tk.Pings) < int(m.config.repetitions()) {
meas := <-out
tk.Pings = append(tk.Pings, &SinglePing{
TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect),
})
tk.Pings = append(tk.Pings, <-out)
}
return nil // return nil so we always submit the measurement
}
// tcpPingLoop sends all the ping requests and emits the results onto the out channel.
func (m *Measurer) tcpPingLoop(ctx context.Context, mxmx *measurex.Measurer,
address string, out chan<- *measurex.EndpointMeasurement) {
func (m *Measurer) tcpPingLoop(ctx context.Context, zeroTime time.Time,
logger model.Logger, address string, out chan<- *SinglePing) {
ticker := time.NewTicker(m.config.delay())
defer ticker.Stop()
for i := int64(0); i < m.config.repetitions(); i++ {
go m.tcpPingAsync(ctx, mxmx, address, out)
go m.tcpPingAsync(ctx, i, zeroTime, logger, address, out)
<-ticker.C
}
}
// tcpPingAsync performs a TCP ping and emits the result onto the out channel.
func (m *Measurer) tcpPingAsync(ctx context.Context, mxmx *measurex.Measurer,
address string, out chan<- *measurex.EndpointMeasurement) {
out <- m.tcpConnect(ctx, mxmx, address)
func (m *Measurer) tcpPingAsync(ctx context.Context, index int64,
zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) {
out <- m.tcpConnect(ctx, index, zeroTime, logger, address)
}
// tcpConnect performs a TCP connect and returns the result to the caller.
func (m *Measurer) tcpConnect(ctx context.Context, mxmx *measurex.Measurer,
address string) *measurex.EndpointMeasurement {
func (m *Measurer) tcpConnect(ctx context.Context, index int64,
zeroTime time.Time, logger model.Logger, address string) *SinglePing {
// TODO(bassosimone): make the timeout user-configurable
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
return mxmx.TCPConnect(ctx, address)
trace := measurexlite.NewTrace(index, zeroTime)
dialer := trace.NewDialerWithoutResolver(logger)
ol := measurexlite.NewOperationLogger(logger, "TCPPing #%d %s", index, address)
conn, err := dialer.DialContext(ctx, "tcp", address)
ol.Stop(err)
measurexlite.MaybeClose(conn)
sp := &SinglePing{
TCPConnect: <-trace.TCPConnect,
}
return sp
}
// NewExperimentMeasurer creates a new ExperimentMeasurer.

View File

@ -40,14 +40,16 @@ func TestMeasurer_run(t *testing.T) {
if m.ExperimentName() != "tcpping" {
t.Fatal("invalid experiment name")
}
if m.ExperimentVersion() != "0.1.0" {
if m.ExperimentVersion() != "0.2.0" {
t.Fatal("invalid experiment version")
}
ctx := context.Background()
meas := &model.Measurement{
Input: model.MeasurementTarget(input),
}
sess := &mockable.Session{}
sess := &mockable.Session{
MockableLogger: model.DiscardLogger,
}
callbacks := model.NewPrinterCallbacks(model.DiscardLogger)
err := m.Run(ctx, sess, meas, callbacks)
return meas, m, err

View File

@ -13,14 +13,14 @@ import (
"strings"
"time"
"github.com/ooni/probe-cli/v3/internal/measurex"
"github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
const (
testName = "tlsping"
testVersion = "0.1.0"
testVersion = "0.2.0"
)
// Config contains the experiment configuration.
@ -77,9 +77,9 @@ type TestKeys struct {
// SinglePing contains the results of a single ping.
type SinglePing struct {
NetworkEvents []*measurex.ArchivalNetworkEvent `json:"network_events"`
TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"`
TLSHandshakes []*measurex.ArchivalQUICTLSHandshakeEvent `json:"tls_handshakes"`
NetworkEvents []*model.ArchivalNetworkEvent `json:"network_events"`
TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"`
TLSHandshake *model.ArchivalTLSOrQUICHandshakeResult `json:"tls_handshake"`
}
// Measurer performs the measurement.
@ -133,49 +133,67 @@ func (m *Measurer) Run(
}
tk := new(TestKeys)
measurement.TestKeys = tk
out := make(chan *measurex.EndpointMeasurement)
mxmx := measurex.NewMeasurerWithDefaultSettings()
go m.tlsPingLoop(ctx, mxmx, parsed.Host, out)
out := make(chan *SinglePing)
go m.tlsPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out)
for len(tk.Pings) < int(m.config.repetitions()) {
meas := <-out
tk.Pings = append(tk.Pings, &SinglePing{
NetworkEvents: measurex.NewArchivalNetworkEventList(meas.ReadWrite),
TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect),
TLSHandshakes: measurex.NewArchivalQUICTLSHandshakeEventList(meas.TLSHandshake),
})
tk.Pings = append(tk.Pings, <-out)
}
return nil // return nil so we always submit the measurement
}
// tlsPingLoop sends all the ping requests and emits the results onto the out channel.
func (m *Measurer) tlsPingLoop(ctx context.Context, mxmx *measurex.Measurer,
address string, out chan<- *measurex.EndpointMeasurement) {
func (m *Measurer) tlsPingLoop(ctx context.Context, zeroTime time.Time,
logger model.Logger, address string, out chan<- *SinglePing) {
ticker := time.NewTicker(m.config.delay())
defer ticker.Stop()
for i := int64(0); i < m.config.repetitions(); i++ {
go m.tlsPingAsync(ctx, mxmx, address, out)
go m.tlsPingAsync(ctx, i, zeroTime, logger, address, out)
<-ticker.C
}
}
// tlsPingAsync performs a TLS ping and emits the result onto the out channel.
func (m *Measurer) tlsPingAsync(ctx context.Context, mxmx *measurex.Measurer,
address string, out chan<- *measurex.EndpointMeasurement) {
out <- m.tlsConnectAndHandshake(ctx, mxmx, address)
func (m *Measurer) tlsPingAsync(ctx context.Context, index int64,
zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) {
out <- m.tlsConnectAndHandshake(ctx, index, zeroTime, logger, address)
}
// tlsConnectAndHandshake performs a TCP connect followed by a TLS handshake
// and returns the results of these operations to the caller.
func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, mxmx *measurex.Measurer,
address string) *measurex.EndpointMeasurement {
func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, index int64,
zeroTime time.Time, logger model.Logger, address string) *SinglePing {
// TODO(bassosimone): make the timeout user-configurable
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
return mxmx.TLSConnectAndHandshake(ctx, address, &tls.Config{
NextProtos: strings.Split(m.config.alpn(), " "),
sp := &SinglePing{
NetworkEvents: []*model.ArchivalNetworkEvent{},
TCPConnect: nil,
TLSHandshake: nil,
}
trace := measurexlite.NewTrace(index, zeroTime)
dialer := trace.NewDialerWithoutResolver(logger)
alpn := strings.Split(m.config.alpn(), " ")
sni := m.config.sni(address)
ol := measurexlite.NewOperationLogger(logger, "TLSPing #%d %s %s %v", index, address, sni, alpn)
conn, err := dialer.DialContext(ctx, "tcp", address)
sp.TCPConnect = <-trace.TCPConnect
if err != nil {
ol.Stop(err)
return sp
}
defer conn.Close()
conn = trace.WrapNetConn(conn)
thx := trace.NewTLSHandshakerStdlib(logger)
config := &tls.Config{
NextProtos: alpn,
RootCAs: netxlite.NewDefaultCertPool(),
ServerName: m.config.sni(address),
})
ServerName: sni,
}
_, _, err = thx.Handshake(ctx, conn, config)
ol.Stop(err)
sp.TLSHandshake = <-trace.TLSHandshake
sp.NetworkEvents = trace.NetworkEvents()
return sp
}
// NewExperimentMeasurer creates a new ExperimentMeasurer.

View File

@ -39,7 +39,7 @@ func TestMeasurer_run(t *testing.T) {
const expectedPings = 4
// runHelper is an helper function to run this set of tests.
runHelper := func(input string) (*model.Measurement, model.ExperimentMeasurer, error) {
runHelper := func(ctx context.Context, input string) (*model.Measurement, model.ExperimentMeasurer, error) {
m := NewExperimentMeasurer(Config{
ALPN: "http/1.1",
Delay: 1, // millisecond
@ -48,10 +48,9 @@ func TestMeasurer_run(t *testing.T) {
if m.ExperimentName() != "tlsping" {
t.Fatal("invalid experiment name")
}
if m.ExperimentVersion() != "0.1.0" {
if m.ExperimentVersion() != "0.2.0" {
t.Fatal("invalid experiment version")
}
ctx := context.Background()
meas := &model.Measurement{
Input: model.MeasurementTarget(input),
}
@ -64,34 +63,34 @@ func TestMeasurer_run(t *testing.T) {
}
t.Run("with empty input", func(t *testing.T) {
_, _, err := runHelper("")
_, _, err := runHelper(context.Background(), "")
if !errors.Is(err, errNoInputProvided) {
t.Fatal("unexpected error", err)
}
})
t.Run("with invalid URL", func(t *testing.T) {
_, _, err := runHelper("\t")
_, _, err := runHelper(context.Background(), "\t")
if !errors.Is(err, errInputIsNotAnURL) {
t.Fatal("unexpected error", err)
}
})
t.Run("with invalid scheme", func(t *testing.T) {
_, _, err := runHelper("https://8.8.8.8:443/")
_, _, err := runHelper(context.Background(), "https://8.8.8.8:443/")
if !errors.Is(err, errInvalidScheme) {
t.Fatal("unexpected error", err)
}
})
t.Run("with missing port", func(t *testing.T) {
_, _, err := runHelper("tlshandshake://8.8.8.8")
_, _, err := runHelper(context.Background(), "tlshandshake://8.8.8.8")
if !errors.Is(err, errMissingPort) {
t.Fatal("unexpected error", err)
}
})
t.Run("with local listener", func(t *testing.T) {
t.Run("with local listener and successful outcome", func(t *testing.T) {
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
@ -101,7 +100,37 @@ func TestMeasurer_run(t *testing.T) {
t.Fatal(err)
}
URL.Scheme = "tlshandshake"
meas, m, err := runHelper(URL.String())
meas, m, err := runHelper(context.Background(), URL.String())
if err != nil {
t.Fatal(err)
}
tk := meas.TestKeys.(*TestKeys)
if len(tk.Pings) != expectedPings {
t.Fatal("unexpected number of pings")
}
ask, err := m.GetSummaryKeys(meas)
if err != nil {
t.Fatal("cannot obtain summary")
}
summary := ask.(SummaryKeys)
if summary.IsAnomaly {
t.Fatal("expected no anomaly")
}
})
t.Run("with local listener and connect issues", func(t *testing.T) {
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer srvr.Close()
URL, err := url.Parse(srvr.URL)
if err != nil {
t.Fatal(err)
}
URL.Scheme = "tlshandshake"
ctx, cancel := context.WithCancel(context.Background())
cancel() // so we cannot dial any connection
meas, m, err := runHelper(ctx, URL.String())
if err != nil {
t.Fatal(err)
}

View File

@ -1,6 +1,11 @@
// Package urlgetter implements a nettest that fetches a URL.
//
// See https://github.com/ooni/spec/blob/master/nettests/ts-027-urlgetter.md.
//
// This package is now frozen. Please, use measurexlite for new code. New
// network experiments should not depend on this package. Please see
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
// for details about this.
package urlgetter
import (

View File

@ -46,4 +46,8 @@
//
// See docs/design/dd-002-nets.md in the probe-cli repository for
// the design document describing this package.
//
// This package is now frozen. Please, use measurexlite for new code. See
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
// for details about this.
package netx

View File

@ -13,10 +13,10 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/fakefill"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
"github.com/ooni/probe-cli/v3/internal/version"
)
@ -47,7 +47,7 @@ func TestAPIClientTemplate(t *testing.T) {
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}
ff := &fakefill.Filler{}
ff := &testingx.FakeFiller{}
ff.Fill(tmpl)
ac := tmpl.Build()
orig := apiClient(*tmpl)
@ -64,7 +64,7 @@ func TestAPIClientTemplate(t *testing.T) {
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}
ff := &fakefill.Filler{}
ff := &testingx.FakeFiller{}
ff.Fill(tmpl)
tok := ""
ff.Fill(&tok)
@ -188,7 +188,7 @@ func TestAPIClient(t *testing.T) {
t.Run("sets the content-type properly", func(t *testing.T) {
var jsonReq fakeRequest
ff := &fakefill.Filler{}
ff := &testingx.FakeFiller{}
ff.Fill(&jsonReq)
client := newAPIClient()
req, err := client.newRequestWithJSONBody(

View File

@ -1,2 +1,6 @@
// Package measurex contains measurement extensions.
//
// This package is now frozen. Please, use measurexlite for new code. See
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
// for details about this.
package measurex

View File

@ -0,0 +1,103 @@
package measurexlite
//
// Conn tracing
//
import (
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
// MaybeClose is a convenience function for closing a conn only when such a conn isn't nil.
func MaybeClose(conn net.Conn) (err error) {
if conn != nil {
err = conn.Close()
}
return
}
// WrapNetConn returns a wrapped conn that saves network events into this trace.
func (tx *Trace) WrapNetConn(conn net.Conn) net.Conn {
return &connTrace{
Conn: conn,
tx: tx,
}
}
// connTrace is a trace-aware net.Conn.
type connTrace struct {
// Implementation note: it seems safe to use embedding here because net.Conn
// is an interface from the standard library that we don't control
net.Conn
tx *Trace
}
var _ net.Conn = &connTrace{}
// Read implements net.Conn.Read and saves network events.
func (c *connTrace) Read(b []byte) (int, error) {
network := c.RemoteAddr().Network()
addr := c.RemoteAddr().String()
started := c.tx.TimeSince(c.tx.ZeroTime)
count, err := c.Conn.Read(b)
finished := c.tx.TimeSince(c.tx.ZeroTime)
select {
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
c.tx.Index, started, netxlite.ReadOperation, network, addr, count, err, finished):
default: // buffer is full
}
return count, err
}
// Write implements net.Conn.Write and saves network events.
func (c *connTrace) Write(b []byte) (int, error) {
network := c.RemoteAddr().Network()
addr := c.RemoteAddr().String()
started := c.tx.TimeSince(c.tx.ZeroTime)
count, err := c.Conn.Write(b)
finished := c.tx.TimeSince(c.tx.ZeroTime)
select {
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
c.tx.Index, started, netxlite.WriteOperation, network, addr, count, err, finished):
default: // buffer is full
}
return count, err
}
// NewArchivalNetworkEvent creates a new model.ArchivalNetworkEvent.
func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, network string,
address string, count int, err error, finished time.Duration) *model.ArchivalNetworkEvent {
return &model.ArchivalNetworkEvent{
Address: address,
Failure: tracex.NewFailure(err),
NumBytes: int64(count),
Operation: operation,
Proto: network,
T: finished.Seconds(),
Tags: []string{},
}
}
// NewAnnotationArchivalNetworkEvent is a simplified NewArchivalNetworkEvent
// where we create a simple annotation without attached I/O info.
func NewAnnotationArchivalNetworkEvent(
index int64, time time.Duration, operation string) *model.ArchivalNetworkEvent {
return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time)
}
// NetworkEvents drains the network events buffered inside the NetworkEvent channel.
func (tx *Trace) NetworkEvents() (out []*model.ArchivalNetworkEvent) {
for {
select {
case ev := <-tx.NetworkEvent:
out = append(out, ev)
default:
return // done
}
}
}

View File

@ -0,0 +1,243 @@
package measurexlite
import (
"net"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestMaybeClose(t *testing.T) {
t.Run("with nil conn", func(t *testing.T) {
var conn net.Conn = nil
MaybeClose(conn)
})
t.Run("with nonnil conn", func(t *testing.T) {
var called bool
conn := &mocks.Conn{
MockClose: func() error {
called = true
return nil
},
}
if err := MaybeClose(conn); err != nil {
t.Fatal(err)
}
if !called {
t.Fatal("not called")
}
})
}
func TestWrapNetConn(t *testing.T) {
t.Run("WrapNetConn wraps the conn", func(t *testing.T) {
underlying := &mocks.Conn{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
conn := trace.WrapNetConn(underlying)
ct := conn.(*connTrace)
if ct.Conn != underlying {
t.Fatal("invalid underlying")
}
if ct.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("Read saves a trace", func(t *testing.T) {
underlying := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time counting
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Read(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 1 {
t.Fatal("did not save network events")
}
expect := &model.ArchivalNetworkEvent{
Address: "1.1.1.1:443",
Failure: nil,
NumBytes: bufsiz,
Operation: netxlite.ReadOperation,
Proto: "tcp",
T: 1.0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Read discards the event when the buffer is full", func(t *testing.T) {
underlying := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Read(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 0 {
t.Fatal("expected no network events")
}
})
t.Run("Write saves a trace", func(t *testing.T) {
underlying := &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time tracking
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Write(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 1 {
t.Fatal("did not save network events")
}
expect := &model.ArchivalNetworkEvent{
Address: "1.1.1.1:443",
Failure: nil,
NumBytes: bufsiz,
Operation: netxlite.WriteOperation,
Proto: "tcp",
T: 1.0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Write discards the event when the buffer is full", func(t *testing.T) {
underlying := &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Write(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 0 {
t.Fatal("expected no network events")
}
})
}
func TestNewAnnotationArchivalNetworkEvent(t *testing.T) {
var (
index int64 = 3
duration = 250 * time.Millisecond
operation = "tls_handshake_start"
)
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: operation,
Proto: "",
T: duration.Seconds(),
Tags: []string{},
}
got := NewAnnotationArchivalNetworkEvent(
index, duration, operation,
)
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
}

View File

@ -0,0 +1,121 @@
package measurexlite
//
// Dialer tracing
//
import (
"context"
"log"
"math"
"net"
"strconv"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
// NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver
// except that it returns a model.Dialer that uses this trace.
//
// Note: unlike code in netx or measurex, this factory DOES NOT return you a
// dialer that also performs wrapping of a net.Conn in case of success. If you
// want to wrap the conn, you need to wrap it explicitly using WrapNetConn.
func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
return &dialerTrace{
d: tx.newDialerWithoutResolver(dl),
tx: tx,
}
}
// dialerTrace is a trace-aware model.Dialer.
type dialerTrace struct {
d model.Dialer
tx *Trace
}
var _ model.Dialer = &dialerTrace{}
// DialContext implements model.Dialer.DialContext.
func (d *dialerTrace) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.d.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address)
}
// CloseIdleConnections implements model.Dialer.CloseIdleConnections.
func (d *dialerTrace) CloseIdleConnections() {
d.d.CloseIdleConnections()
}
// OnTCPConnectDone implements model.Trace.OnTCPConnectDone.
func (tx *Trace) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
switch network {
case "tcp", "tcp4", "tcp6":
select {
case tx.TCPConnect <- NewArchivalTCPConnectResult(
tx.Index,
started.Sub(tx.ZeroTime),
remoteAddr,
err,
finished.Sub(tx.ZeroTime),
):
default: // buffer is full
}
default:
// ignore UDP connect attempts because they cannot fail
// in interesting ways that make sense for censorship
}
}
// NewArchivalTCPConnectResult generates a model.ArchivalTCPConnectResult
// from the available information right after connect returns.
func NewArchivalTCPConnectResult(index int64, started time.Duration, address string,
err error, finished time.Duration) *model.ArchivalTCPConnectResult {
ip, port := archivalSplitHostPort(address)
return &model.ArchivalTCPConnectResult{
IP: ip,
Port: archivalPortToString(port),
Status: model.ArchivalTCPConnectStatus{
Blocked: nil,
Failure: tracex.NewFailure(err),
Success: err == nil,
},
T: finished.Seconds(),
}
}
// archivalSplitHostPort is like net.SplitHostPort but does not return an error. This
// function returns two empty strings in case of any failure.
func archivalSplitHostPort(endpoint string) (string, string) {
addr, port, err := net.SplitHostPort(endpoint)
if err != nil {
log.Printf("BUG: archivalSplitHostPort: invalid endpoint: %s", endpoint)
return "", ""
}
return addr, port
}
// archivalPortToString is like strconv.Atoi but does not return an error. This
// function returns a zero port number in case of any failure.
func archivalPortToString(sport string) int {
port, err := strconv.Atoi(sport)
if err != nil || port < 0 || port > math.MaxUint16 {
log.Printf("BUG: archivalStrconvAtoi: invalid port: %s", sport)
return 0
}
return port
}
// TCPConnects drains the network events buffered inside the TCPConnect channel.
func (tx *Trace) TCPConnects() (out []*model.ArchivalTCPConnectResult) {
for {
select {
case ev := <-tx.TCPConnect:
out = append(out, ev)
default:
return // done
}
}
}

View File

@ -0,0 +1,211 @@
package measurexlite
import (
"context"
"errors"
"math"
"net"
"strconv"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewDialerWithoutResolver(t *testing.T) {
t.Run("NewDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) {
underlying := &mocks.Dialer{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
return underlying
}
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
dt := dialer.(*dialerTrace)
if dt.d != underlying {
t.Fatal("invalid dialer")
}
if dt.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("DialContext calls the underlying dialer with context-based tracing", func(t *testing.T) {
expectedErr := errors.New("mocked err")
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
var hasCorrectTrace bool
underlying := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
gotTrace := netxlite.ContextTraceOrDefault(ctx)
hasCorrectTrace = (gotTrace == trace)
return nil, expectedErr
},
}
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
return underlying
}
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !hasCorrectTrace {
t.Fatal("does not have the correct trace")
}
})
t.Run("CloseIdleConnection is correctly forwarded", func(t *testing.T) {
var called bool
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
underlying := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
return underlying
}
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("DialContext saves into the trace", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time tracking
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 1 {
t.Fatal("expected to see single TCPConnect event")
}
expectedFailure := netxlite.FailureInterrupted
expect := &model.ArchivalTCPConnectResult{
IP: "1.1.1.1",
Port: 443,
Status: model.ArchivalTCPConnectStatus{
Blocked: nil,
Failure: &expectedFailure,
Success: false,
},
T: time.Second.Seconds(),
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("DialContext discards events when buffer is full", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.TCPConnect = make(chan *model.ArchivalTCPConnectResult) // no buffer
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 0 {
t.Fatal("expected to see no TCPConnect events")
}
})
t.Run("DialContext ignores UDP connect attempts", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "udp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 0 {
t.Fatal("expected to see no TCPConnect events")
}
})
t.Run("DialContext uses a dialer without a resolver", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "udp", "dns.google:443") // domain
if !errors.Is(err, netxlite.ErrNoResolver) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 0 {
t.Fatal("expected to see no TCPConnect events")
}
})
}
func TestArchivalSplitHostPort(t *testing.T) {
addr, port := archivalSplitHostPort("1.1.1.1") // missing port
if addr != "" {
t.Fatal("invalid addr", addr)
}
if port != "" {
t.Fatal("invalid port", port)
}
}
func TestArchivalPortToString(t *testing.T) {
t.Run("with invalid number", func(t *testing.T) {
port := archivalPortToString("antani")
if port != 0 {
t.Fatal("invalid port")
}
})
t.Run("with negative number", func(t *testing.T) {
port := archivalPortToString("-1")
if port != 0 {
t.Fatal("invalid port")
}
})
t.Run("with too-large positive number", func(t *testing.T) {
port := archivalPortToString(strconv.Itoa(math.MaxUint16 + 1))
if port != 0 {
t.Fatal("invalid port")
}
})
}

View File

@ -0,0 +1,10 @@
// Package measurexlite contains measurement extensions.
//
// See docs/design/dd-003-step-by-step.md in the ooni/probe-cli
// repository for the design document.
//
// This implementation features a Trace that saves events in
// buffered channels as proposed by df-003-step-by-step.md. We
// have reasonable default buffers for channels. But, if you
// are not draining them, eventually we stop collecting events.
package measurexlite

View File

@ -0,0 +1,16 @@
package measurexlite
//
// Logging support
//
import "github.com/ooni/probe-cli/v3/internal/measurex"
// TODO(bassosimone): we should eventually remove measurex and
// move the logging code from measurex to this package.
// NewOperationLogger is an alias for measurex.NewOperationLogger.
var NewOperationLogger = measurex.NewOperationLogger
// OperationLogger is an alias for measurex.OperationLogger.
type OperationLogger = measurex.OperationLogger

View File

@ -0,0 +1,144 @@
package measurexlite
//
// TLS tracing
//
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
// NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib
// except that it returns a model.TLSHandshaker that uses this trace.
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
return &tlsHandshakerTrace{
thx: tx.newTLSHandshakerStdlib(dl),
tx: tx,
}
}
// tlsHandshakerTrace is a trace-aware TLS handshaker.
type tlsHandshakerTrace struct {
thx model.TLSHandshaker
tx *Trace
}
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
// Handshake implements model.TLSHandshaker.Handshake.
func (thx *tlsHandshakerTrace) Handshake(
ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (net.Conn, tls.ConnectionState, error) {
return thx.thx.Handshake(netxlite.ContextWithTrace(ctx, thx.tx), conn, tlsConfig)
}
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
t := now.Sub(tx.ZeroTime)
select {
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_start"):
default: // buffer is full
}
}
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time) {
t := finished.Sub(tx.ZeroTime)
select {
case tx.TLSHandshake <- NewArchivalTLSOrQUICHandshakeResult(
tx.Index,
started.Sub(tx.ZeroTime),
remoteAddr,
config,
state,
err,
t,
):
default: // buffer is full
}
select {
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_done"):
default: // buffer is full
}
}
// NewArchivalTLSOrQUICHandshakeResult generates a model.ArchivalTLSOrQUICHandshakeResult
// from the available information right after the TLS handshake returns.
func NewArchivalTLSOrQUICHandshakeResult(
index int64, started time.Duration, address string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Duration) *model.ArchivalTLSOrQUICHandshakeResult {
return &model.ArchivalTLSOrQUICHandshakeResult{
Address: address,
CipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
Failure: tracex.NewFailure(err),
NegotiatedProtocol: state.NegotiatedProtocol,
NoTLSVerify: config.InsecureSkipVerify,
PeerCertificates: TLSPeerCerts(state, err),
ServerName: config.ServerName,
T: finished.Seconds(),
Tags: []string{},
TLSVersion: netxlite.TLSVersionString(state.Version),
}
}
// newArchivalBinaryData is a factory that adapts binary data to the
// model.ArchivalMaybeBinaryData format.
func newArchivalBinaryData(data []byte) model.ArchivalMaybeBinaryData {
// TODO(https://github.com/ooni/probe/issues/2165): we should actually extend the
// model's archival data format to have a pure-binary-data type for the cases in which
// we know in advance we're dealing with binary data.
return model.ArchivalMaybeBinaryData{
Value: string(data),
}
}
// TLSPeerCerts extracts the certificates either from the list of certificates
// in the connection state or from the error that occurred.
func TLSPeerCerts(
state tls.ConnectionState, err error) (out []model.ArchivalMaybeBinaryData) {
out = []model.ArchivalMaybeBinaryData{}
var x509HostnameError x509.HostnameError
if errors.As(err, &x509HostnameError) {
// Test case: https://wrong.host.badssl.com/
out = append(out, newArchivalBinaryData(x509HostnameError.Certificate.Raw))
return
}
var x509UnknownAuthorityError x509.UnknownAuthorityError
if errors.As(err, &x509UnknownAuthorityError) {
// Test case: https://self-signed.badssl.com/. This error has
// never been among the ones returned by MK.
out = append(out, newArchivalBinaryData(x509UnknownAuthorityError.Cert.Raw))
return
}
var x509CertificateInvalidError x509.CertificateInvalidError
if errors.As(err, &x509CertificateInvalidError) {
// Test case: https://expired.badssl.com/
out = append(out, newArchivalBinaryData(x509CertificateInvalidError.Cert.Raw))
return
}
for _, cert := range state.PeerCertificates {
out = append(out, newArchivalBinaryData(cert.Raw))
}
return
}
// TLSHandshakes drains the network events buffered inside the TLSHandshake channel.
func (tx *Trace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
for {
select {
case ev := <-tx.TLSHandshake:
out = append(out, ev)
default:
return // done
}
}
}

View File

@ -0,0 +1,418 @@
package measurexlite
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewTLSHandshakerStdlib(t *testing.T) {
t.Run("NewTLSHandshakerStdlib creates a wrapped TLSHandshaker", func(t *testing.T) {
underlying := &mocks.TLSHandshaker{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
return underlying
}
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
thxt := thx.(*tlsHandshakerTrace)
if thxt.thx != underlying {
t.Fatal("invalid TLS handshaker")
}
if thxt.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("Handshake calls the underlying dialer with context-based tracing", func(t *testing.T) {
expectedErr := errors.New("mocked err")
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
var hasCorrectTrace bool
underlying := &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
gotTrace := netxlite.ContextTraceOrDefault(ctx)
hasCorrectTrace = (gotTrace == trace)
return nil, tls.ConnectionState{}, expectedErr
},
}
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
return underlying
}
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected zero-value state")
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !hasCorrectTrace {
t.Fatal("does not have the correct trace")
}
})
t.Run("Handshake saves into the trace", func(t *testing.T) {
mockedErr := errors.New("mocked")
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic timing
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: "dns.cloudflare.com",
}
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected zero-value state")
}
if conn != nil {
t.Fatal("expected nil conn")
}
t.Run("TLSHandshake events", func(t *testing.T) {
events := trace.TLSHandshakes()
if len(events) != 1 {
t.Fatal("expected to see single TLSHandshake event")
}
expectedFailure := netxlite.FailureInterrupted
expect := &model.ArchivalTLSOrQUICHandshakeResult{
Address: "1.1.1.1:443",
CipherSuite: "",
Failure: &expectedFailure,
NegotiatedProtocol: "",
NoTLSVerify: true,
PeerCertificates: []model.ArchivalMaybeBinaryData{},
ServerName: "dns.cloudflare.com",
T: time.Second.Seconds(),
Tags: []string{},
TLSVersion: "",
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Network events", func(t *testing.T) {
events := trace.NetworkEvents()
if len(events) != 2 {
t.Fatal("expected to see two Network events")
}
t.Run("tls_handshake_start", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_start",
Proto: "",
T: 0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("tls_handshake_done", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_done",
Proto: "",
T: time.Second.Seconds(),
Tags: []string{},
}
got := events[1]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
})
})
t.Run("Handshake discards events when buffers are full", func(t *testing.T) {
mockedErr := errors.New("mocked")
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
trace.TLSHandshake = make(chan *model.ArchivalTLSOrQUICHandshakeResult) // no buffer
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: "dns.cloudflare.com",
}
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected zero-value state")
}
if conn != nil {
t.Fatal("expected nil conn")
}
t.Run("TLSHandshake events", func(t *testing.T) {
events := trace.TLSHandshakes()
if len(events) != 0 {
t.Fatal("expected to see no TLSHandshake events")
}
})
t.Run("Network events", func(t *testing.T) {
events := trace.NetworkEvents()
if len(events) != 0 {
t.Fatal("expected to see no Network events")
}
})
})
t.Run("we collect the desired data with a local TLS server", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", server.Endpoint())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
zeroTime := time.Now()
dt := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = dt.Now // deterministic timing
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
tlsConfig := &tls.Config{
RootCAs: server.CertPool(),
ServerName: "dns.google",
}
tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig)
if err != nil {
t.Fatal(err)
}
defer tlsConn.Close()
data, err := netxlite.ReadAllContext(ctx, tlsConn)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
t.Fatal("bytes should match")
}
t.Run("TLSHandshake events", func(t *testing.T) {
events := trace.TLSHandshakes()
if len(events) != 1 {
t.Fatal("expected to see a single TLSHandshake event")
}
expected := &model.ArchivalTLSOrQUICHandshakeResult{
Address: conn.RemoteAddr().String(),
CipherSuite: netxlite.TLSCipherSuiteString(connState.CipherSuite),
Failure: nil,
NegotiatedProtocol: "",
NoTLSVerify: false,
PeerCertificates: []model.ArchivalMaybeBinaryData{},
ServerName: "dns.google",
T: time.Second.Seconds(),
Tags: []string{},
TLSVersion: netxlite.TLSVersionString(connState.Version),
}
got := events[0]
// TODO(bassosimone): it's still unclear to me how to test that
// I am getting exactly the expected certificate here. I think the
// certificate is generated on the fly by google/martian. So, I'm
// just going to reduce the precision of this check.
if len(got.PeerCertificates) != 2 {
t.Fatal("expected to see two certificates")
}
got.PeerCertificates = []model.ArchivalMaybeBinaryData{} // see above
if diff := cmp.Diff(expected, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Network events", func(t *testing.T) {
events := trace.NetworkEvents()
if len(events) != 2 {
t.Fatal("expected to see two Network events")
}
t.Run("tls_handshake_start", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_start",
Proto: "",
T: 0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("tls_handshake_done", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_done",
Proto: "",
T: time.Second.Seconds(),
Tags: []string{},
}
got := events[1]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
})
})
}
func TestTLSPeerCerts(t *testing.T) {
type args struct {
state tls.ConnectionState
err error
}
tests := []struct {
name string
args args
wantOut []model.ArchivalMaybeBinaryData
}{{
name: "x509.HostnameError",
args: args{
state: tls.ConnectionState{},
err: x509.HostnameError{
Certificate: &x509.Certificate{
Raw: []byte("deadbeef"),
},
},
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}},
}, {
name: "x509.UnknownAuthorityError",
args: args{
state: tls.ConnectionState{},
err: x509.UnknownAuthorityError{
Cert: &x509.Certificate{
Raw: []byte("deadbeef"),
},
},
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}},
}, {
name: "x509.CertificateInvalidError",
args: args{
state: tls.ConnectionState{},
err: x509.CertificateInvalidError{
Cert: &x509.Certificate{
Raw: []byte("deadbeef"),
},
},
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}},
}, {
name: "successful case",
args: args{
state: tls.ConnectionState{
PeerCertificates: []*x509.Certificate{{
Raw: []byte("deadbeef"),
}, {
Raw: []byte("abad1dea"),
}},
},
err: nil,
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}, {
Value: "abad1dea",
}},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotOut := TLSPeerCerts(tt.args.state, tt.args.err)
if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" {
t.Fatal(diff)
}
})
}
}

View File

@ -0,0 +1,143 @@
package measurexlite
//
// Definition of Trace
//
import (
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// Trace implements model.Trace.
//
// The zero-value of this struct is invalid. To construct you should either
// fill all the fields marked as MANDATORY or use NewTrace.
//
// Buffered channels
//
// NewTrace uses reasonable buffer sizes for the channels used for collecting
// events. You should drain the channels used by this implementation after
// each operation you perform (i.e., we expect you to peform step-by-step
// measurements). If you want larger (or smaller) buffers, then you should
// construct this data type manually with the desired buffer sizes.
//
// We have convenience methods for extracting events from the buffered
// channels. Otherwise, you could read the channels directly. (In which
// case, remember to issue nonblocking channel reads because channels are
// never closed and they're just written when new events occur.)
type Trace struct {
// Index is the MANDATORY unique index of this trace within the
// current measurement. If you don't care about uniquely identifying
// treaces, you can use zero to indicate the "default" trace.
Index int64
// NetworkEvent is MANDATORY and buffers network events. If you create
// this channel manually, ensure it has some buffer.
NetworkEvent chan *model.ArchivalNetworkEvent
// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
// calls to the netxlite.NewDialerWithoutResolver factory.
NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer
// NewTLSHandshakerStdlibFn is OPTIONAL and can be used to overide
// calls to the netxlite.NewTLSHandshakerStdlib factory.
NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
// TCPConnect is MANDATORY and buffers TCP connect observations. If you create
// this channel manually, ensure it has some buffer.
TCPConnect chan *model.ArchivalTCPConnectResult
// TLSHandshake is MANDATORY and buffers TLS handshake observations. If you create
// this channel manually, ensure it has some buffer.
TLSHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
// to produce deterministic timing when testing.
TimeNowFn func() time.Time
// ZeroTime is the MANDATORY time when we started the current measurement.
ZeroTime time.Time
}
const (
// NetworkEventBufferSize is the buffer size for constructing
// the Trace's NetworkEvent buffered channel.
NetworkEventBufferSize = 64
// TCPConnectBufferSize is the buffer size for constructing
// the Trace's TCPConnect buffered channel.
TCPConnectBufferSize = 8
// TLSHandshakeBufferSize is the buffer for construcing
// the Trace's TLSHandshake buffered channel.
TLSHandshakeBufferSize = 8
)
// NewTrace creates a new instance of Trace using default settings.
//
// We create buffered channels using as buffer sizes the constants that
// are also defined by this package.
//
// Arguments:
//
// - index is the unique index of this trace within the current measurement (use
// zero if you don't care about giving this trace a unique ID);
//
// - zeroTime is the time when we started the current measurement.
func NewTrace(index int64, zeroTime time.Time) *Trace {
return &Trace{
Index: index,
NetworkEvent: make(
chan *model.ArchivalNetworkEvent,
NetworkEventBufferSize,
),
NewDialerWithoutResolverFn: nil, // use default
NewTLSHandshakerStdlibFn: nil, // use default
TCPConnect: make(
chan *model.ArchivalTCPConnectResult,
TCPConnectBufferSize,
),
TLSHandshake: make(
chan *model.ArchivalTLSOrQUICHandshakeResult,
TLSHandshakeBufferSize,
),
TimeNowFn: nil, // use default
ZeroTime: zeroTime,
}
}
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
// thus allows us to mock this func for testing.
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
if tx.NewDialerWithoutResolverFn != nil {
return tx.NewDialerWithoutResolverFn(dl)
}
return netxlite.NewDialerWithoutResolver(dl)
}
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
// thus allowing us to mock this func for testing.
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
if tx.NewTLSHandshakerStdlibFn != nil {
return tx.NewTLSHandshakerStdlibFn(dl)
}
return netxlite.NewTLSHandshakerStdlib(dl)
}
// TimeNow implements model.Trace.TimeNow.
func (tx *Trace) TimeNow() time.Time {
if tx.TimeNowFn != nil {
return tx.TimeNowFn()
}
return time.Now()
}
// TimeSince is equivalent to Trace.TimeNow().Sub(t0).
func (tx *Trace) TimeSince(t0 time.Time) time.Duration {
return tx.TimeNow().Sub(t0)
}
var _ model.Trace = &Trace{}

View File

@ -0,0 +1,248 @@
package measurexlite
import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewTrace(t *testing.T) {
t.Run("NewTrace correctly constructs a trace", func(t *testing.T) {
const index = 17
zeroTime := time.Now()
trace := NewTrace(index, zeroTime)
t.Run("Index", func(t *testing.T) {
if trace.Index != index {
t.Fatal("invalid index")
}
})
t.Run("NetworkEvent has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalNetworkEvent{}
ff.Fill(ev)
select {
case trace.NetworkEvent <- ev:
idx++
default:
break Loop
}
}
if idx != NetworkEventBufferSize {
t.Fatal("invalid NetworkEvent channel buffer size")
}
})
t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) {
if trace.NewDialerWithoutResolverFn != nil {
t.Fatal("expected nil NewDialerWithoutResolverFn")
}
})
t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) {
if trace.NewTLSHandshakerStdlibFn != nil {
t.Fatal("expected nil NewTLSHandshakerStdlibFn")
}
})
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalTCPConnectResult{}
ff.Fill(ev)
select {
case trace.TCPConnect <- ev:
idx++
default:
break Loop
}
}
if idx != TCPConnectBufferSize {
t.Fatal("invalid TCPConnect channel buffer size")
}
})
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalTLSOrQUICHandshakeResult{}
ff.Fill(ev)
select {
case trace.TLSHandshake <- ev:
idx++
default:
break Loop
}
}
if idx != TLSHandshakeBufferSize {
t.Fatal("invalid TLSHandshake channel buffer size")
}
})
t.Run("TimeNowFn is nil", func(t *testing.T) {
if trace.TimeNowFn != nil {
t.Fatal("expected nil TimeNowFn")
}
})
t.Run("ZeroTime", func(t *testing.T) {
if !trace.ZeroTime.Equal(zeroTime) {
t.Fatal("invalid zero time")
}
})
})
}
func TestTrace(t *testing.T) {
t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mockedErr
},
}
},
}
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("when nil", func(t *testing.T) {
tx := &Trace{
NewDialerWithoutResolverFn: nil,
}
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker {
return &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, mockedErr
},
}
},
}
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("state is not a zero value")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("when nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewTLSHandshakerStdlibFn: nil,
}
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
}
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("state is not a zero value")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("TimeNowFn works as intended", func(t *testing.T) {
fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
tx := &Trace{
TimeNowFn: func() time.Time {
return fixedTime
},
}
if !tx.TimeNow().Equal(fixedTime) {
t.Fatal("we cannot override time.Now calls")
}
})
t.Run("TimeSince works as intended", func(t *testing.T) {
t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
t1 := t0.Add(10 * time.Second)
tx := &Trace{
TimeNowFn: func() time.Time {
return t1
},
}
if tx.TimeSince(t0) != 10*time.Second {
t.Fatal("apparently Trace.Since is broken")
}
})
}

View File

@ -4,7 +4,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/fakefill"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestArchivalExtSpec(t *testing.T) {
@ -301,7 +301,7 @@ func TestHTTPBody(t *testing.T) {
// we make a mistake and apply the above change (which will in turn
// break correct JSON serialization), the this test will fail.
var body ArchivalHTTPBody
ff := &fakefill.Filler{}
ff := &testingx.FakeFiller{}
ff.Fill(&body)
data := ArchivalMaybeBinaryData(body)
if diff := cmp.Diff(body, data); diff != "" {

View File

@ -0,0 +1,45 @@
package mocks
//
// Mocks for model.Trace
//
import (
"crypto/tls"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
)
// Trace allows mocking model.Trace.
type Trace struct {
MockTimeNow func() time.Time
MockOnConnectDone func(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
MockOnTLSHandshakeStart func(now time.Time, remoteAddr string, config *tls.Config)
MockOnTLSHandshakeDone func(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time)
}
var _ model.Trace = &Trace{}
func (t *Trace) TimeNow() time.Time {
return t.MockTimeNow()
}
func (t *Trace) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)
}
func (t *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
t.MockOnTLSHandshakeStart(now, remoteAddr, config)
}
func (t *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time) {
t.MockOnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
}

View File

@ -0,0 +1,74 @@
package mocks
import (
"crypto/tls"
"testing"
"time"
)
func TestTrace(t *testing.T) {
t.Run("TimeNow", func(t *testing.T) {
now := time.Now()
tx := &Trace{
MockTimeNow: func() time.Time {
return now
},
}
if !tx.TimeNow().Equal(now) {
t.Fatal("not working as intended")
}
})
t.Run("OnConnectDone", func(t *testing.T) {
var called bool
tx := &Trace{
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
called = true
},
}
tx.OnConnectDone(
time.Now(),
"tcp",
"dns.google",
"8.8.8.8:443",
nil,
time.Now(),
)
if !called {
t.Fatal("not called")
}
})
t.Run("OnTLSHandshakeStart", func(t *testing.T) {
var called bool
tx := &Trace{
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
called = true
},
}
tx.OnTLSHandshakeStart(time.Now(), "8.8.8.8:443", &tls.Config{})
if !called {
t.Fatal("not called")
}
})
t.Run("OnTLSHandshakeDone", func(t *testing.T) {
var called bool
tx := &Trace{
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
called = true
},
}
tx.OnTLSHandshakeDone(
time.Now(),
"8.8.8.8:443",
&tls.Config{},
tls.ConnectionState{},
nil,
time.Now(),
)
if !called {
t.Fatal("not called")
}
})
}

View File

@ -292,6 +292,76 @@ type TLSHandshaker interface {
net.Conn, tls.ConnectionState, error)
}
// Trace allows to collect measurement traces. A trace is injected into
// netx operations using context.WithValue. Netx code retrieves the trace
// using context.Value. See docs/design/dd-003-step-by-step.md for the
// design document explaining why we implemented context-based tracing.
type Trace interface {
// TimeNow returns the current time. Normally, this should be the same
// value returned by time.Now but you may want to manipulate the time
// returned when testing to have deterministic tests. To this end, you
// can use functionality exported by the ./internal/testingx pkg.
TimeNow() time.Time
// OnConnectDone is called when connect terminates.
//
// Arguments:
//
// - started is when we called connect;
//
// - network is the network we're using (one of "tcp" and "udp");
//
// - domain is the domain for which we're calling connect. If the user called
// connect for an IP address and a port, then domain will be an IP address;
//
// - remoteAddr is the TCP endpoint with which we are connecting: it will
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
//
// - err is the result of connect: either an error or nil;
//
// - finished is when connect returned.
//
// The error passed to this function will always be wrapped such that the
// string returned by Error is an OONI error.
OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
// OnTLSHandshakeStart is called when the TLS handshake starts.
//
// Arguments:
//
// - now is the moment before we start the handshake;
//
// - remoteAddr is the TCP endpoint with which we are connecting: it will
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
//
// - config is the non-nil TLS config we're using.
OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config)
// OnTLSHandshakeDone is called when the TLS handshake terminates.
//
// Arguments:
//
// - started is when we started the handshake;
//
// - remoteAddr is the TCP endpoint with which we are connecting: it will
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
//
// - config is the non-nil TLS config we're using;
//
// - state is the state of the TLS connection after the handshake, where all
// fields are zero-initialized if the handshake failed;
//
// - err is the result of the handshake: either an error or nil;
//
// - finished is right after the handshake.
//
// The error passed to this function will always be wrapped such that the
// string returned by Error is an OONI error.
OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time)
}
// UDPLikeConn is a net.PacketConn with some extra functions
// required to convince the QUIC library (lucas-clemente/quic-go)
// to inflate the receive buffer of the connection.

View File

@ -47,7 +47,7 @@ func (r *bogonResolver) LookupHost(ctx context.Context, hostname string) ([]stri
for _, addr := range addrs {
if IsBogon(addr) {
// wrap ErrDNSBogon as documented
return nil, newErrWrapper(classifyResolverError, ResolveOperation, ErrDNSBogon)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, ErrDNSBogon)
}
}
return addrs, nil

View File

@ -15,8 +15,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/scrubber"
)
// classifyGenericError is maps an error occurred during an operation
// to an OONI failure string. This specific classifier is the most
// ClassifyGenericError maps an error occurred during an operation to
// an OONI failure string. This specific classifier is the most
// generic one. You usually use it when mapping I/O errors. You should
// check whether there is a specific classifier for more specific
// operations (e.g., DNS resolution, TLS handshake).
@ -38,7 +38,7 @@ import (
// If everything else fails, this classifier returns a string
// like "unknown_failure: XXX" where XXX has been scrubbed
// so to remove any network endpoints from the original error string.
func classifyGenericError(err error) string {
func ClassifyGenericError(err error) string {
// The list returned here matches the values used by MK unless
// explicitly noted otherwise with a comment.
@ -139,7 +139,7 @@ const (
quicTLSUnrecognizedName = 112
)
// classifyQUICHandshakeError maps errors during a QUIC
// ClassifyQUICHandshakeError maps errors during a QUIC
// handshake to OONI failure strings.
//
// If the input error is an *ErrWrapper we don't perform
@ -147,7 +147,7 @@ const (
//
// If this classifier fails, it calls ClassifyGenericError
// and returns to the caller its return value.
func classifyQUICHandshakeError(err error) string {
func ClassifyQUICHandshakeError(err error) string {
// Robustness: handle the case where we're passed a wrapped error.
var errwrapper *ErrWrapper
@ -207,7 +207,7 @@ func classifyQUICHandshakeError(err error) string {
}
}
}
return classifyGenericError(err)
return ClassifyGenericError(err)
}
// quicIsCertificateError tells us whether a specific TLS alert error
@ -277,7 +277,7 @@ var (
// anything as explained in getaddrinfo_linux.go.
var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
// classifyResolverError maps DNS resolution errors to
// ClassifyResolverError maps DNS resolution errors to
// OONI failure strings.
//
// If the input error is an *ErrWrapper we don't perform
@ -285,7 +285,7 @@ var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
//
// If this classifier fails, it calls ClassifyGenericError and
// returns to the caller its return value.
func classifyResolverError(err error) string {
func ClassifyResolverError(err error) string {
// Robustness: handle the case where we're passed a wrapped error.
var errwrapper *ErrWrapper
@ -310,10 +310,10 @@ func classifyResolverError(err error) string {
if errors.Is(err, ErrAndroidDNSCacheNoData) {
return FailureAndroidDNSCacheNoData
}
return classifyGenericError(err)
return ClassifyGenericError(err)
}
// classifyTLSHandshakeError maps an error occurred during the TLS
// ClassifyTLSHandshakeError maps an error occurred during the TLS
// handshake to an OONI failure string.
//
// If the input error is an *ErrWrapper we don't perform
@ -321,7 +321,7 @@ func classifyResolverError(err error) string {
//
// If this classifier fails, it calls ClassifyGenericError and
// returns to the caller its return value.
func classifyTLSHandshakeError(err error) string {
func ClassifyTLSHandshakeError(err error) string {
// Robustness: handle the case where we're passed a wrapped error.
var errwrapper *ErrWrapper
@ -345,5 +345,5 @@ func classifyTLSHandshakeError(err error) string {
// Test case: https://expired.badssl.com/
return FailureSSLInvalidCertificate
}
return classifyGenericError(err)
return ClassifyGenericError(err)
}

View File

@ -18,13 +18,13 @@ func TestClassifyGenericError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyGenericError(err) != FailureEOFError {
if ClassifyGenericError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for a system call error", func(t *testing.T) {
if classifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
if ClassifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
t.Fatal("unexpected results")
}
})
@ -35,63 +35,63 @@ func TestClassifyGenericError(t *testing.T) {
// is just an implementation detail.
t.Run("for operation was canceled", func(t *testing.T) {
if classifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
if ClassifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
t.Fatal("unexpected result")
}
})
t.Run("for EOF", func(t *testing.T) {
if classifyGenericError(io.EOF) != FailureEOFError {
if ClassifyGenericError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
t.Run("for context deadline exceeded", func(t *testing.T) {
if classifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
if ClassifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for stun's transaction is timed out", func(t *testing.T) {
if classifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
if ClassifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for i/o timeout", func(t *testing.T) {
if classifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
if ClassifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for TLS handshake timeout", func(t *testing.T) {
err := errors.New("net/http: TLS handshake timeout")
if classifyGenericError(err) != FailureGenericTimeoutError {
if ClassifyGenericError(err) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for no such host", func(t *testing.T) {
if classifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
if ClassifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
t.Fatal("unexpected results")
}
})
t.Run("for dns server misbehaving", func(t *testing.T) {
if classifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
if ClassifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
t.Fatal("unexpected results")
}
})
t.Run("for no answer from DNS server", func(t *testing.T) {
if classifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
if ClassifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
t.Fatal("unexpected results")
}
})
t.Run("for use of closed network connection", func(t *testing.T) {
err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
if classifyGenericError(err) != FailureConnectionAlreadyClosed {
if ClassifyGenericError(err) != FailureConnectionAlreadyClosed {
t.Fatal("unexpected results")
}
})
@ -99,7 +99,7 @@ func TestClassifyGenericError(t *testing.T) {
// Now we're back in ClassifyGenericError
t.Run("for context.Canceled", func(t *testing.T) {
if classifyGenericError(context.Canceled) != FailureInterrupted {
if ClassifyGenericError(context.Canceled) != FailureInterrupted {
t.Fatal("unexpected result")
}
})
@ -108,7 +108,7 @@ func TestClassifyGenericError(t *testing.T) {
t.Run("with an IPv4 address", func(t *testing.T) {
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error")
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
out := classifyGenericError(input)
out := ClassifyGenericError(input)
if out != expected {
t.Fatal(cmp.Diff(expected, out))
}
@ -117,7 +117,7 @@ func TestClassifyGenericError(t *testing.T) {
t.Run("with an IPv6 address", func(t *testing.T) {
input := errors.New("read tcp [::1]:56948->[::1]:443: some error")
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
out := classifyGenericError(input)
out := ClassifyGenericError(input)
if out != expected {
t.Fatal(cmp.Diff(expected, out))
}
@ -131,100 +131,100 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyQUICHandshakeError(err) != FailureEOFError {
if ClassifyQUICHandshakeError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for incompatible quic version", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
if ClassifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
t.Fatal("unexpected results")
}
})
t.Run("for stateless reset", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
if ClassifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
t.Fatal("unexpected results")
}
})
t.Run("for handshake timeout", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
if ClassifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for idle timeout", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
if ClassifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for connection refused", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
t.Fatal("unexpected results")
}
})
t.Run("for bad certificate", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertBadCertificate
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for unsupported certificate", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for certificate expired", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertCertificateExpired
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for certificate revoked", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for certificate unknown", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for decrypt error", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertDecryptError
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
t.Fatal("unexpected results")
}
})
t.Run("for handshake failure", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
t.Fatal("unexpected results")
}
})
t.Run("for unknown CA", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertUnknownCA
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
t.Fatal("unexpected results")
}
})
t.Run("for unrecognized hostname", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSUnrecognizedName
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
t.Fatal("unexpected results")
}
})
@ -234,13 +234,13 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
ErrorCode: quic.InternalError,
ErrorMessage: FailureHostUnreachable,
}
if classifyQUICHandshakeError(err) != FailureHostUnreachable {
if ClassifyQUICHandshakeError(err) != FailureHostUnreachable {
t.Fatal("unexpected results")
}
})
t.Run("for another kind of error", func(t *testing.T) {
if classifyQUICHandshakeError(io.EOF) != FailureEOFError {
if ClassifyQUICHandshakeError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
@ -252,43 +252,43 @@ func TestClassifyResolverError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyResolverError(err) != FailureEOFError {
if ClassifyResolverError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for ErrDNSBogon", func(t *testing.T) {
if classifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
if ClassifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
t.Fatal("unexpected result")
}
})
t.Run("for refused", func(t *testing.T) {
if classifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
if ClassifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
t.Fatal("unexpected result")
}
})
t.Run("for servfail", func(t *testing.T) {
if classifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
if ClassifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
t.Fatal("unexpected result")
}
})
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
if ClassifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
t.Fatal("unexpected result")
}
})
t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) {
if classifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
if ClassifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
t.Fatal("unexpected result")
}
})
t.Run("for another kind of error", func(t *testing.T) {
if classifyResolverError(io.EOF) != FailureEOFError {
if ClassifyResolverError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
@ -300,34 +300,34 @@ func TestClassifyTLSHandshakeError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyTLSHandshakeError(err) != FailureEOFError {
if ClassifyTLSHandshakeError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for x509.HostnameError", func(t *testing.T) {
var err x509.HostnameError
if classifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
t.Fatal("unexpected result")
}
})
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
var err x509.UnknownAuthorityError
if classifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
if ClassifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
t.Fatal("unexpected result")
}
})
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
var err x509.CertificateInvalidError
if classifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
t.Fatal("unexpected result")
}
})
t.Run("for another kind of error", func(t *testing.T) {
if classifyTLSHandshakeError(io.EOF) != FailureEOFError {
if ClassifyTLSHandshakeError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})

View File

@ -125,7 +125,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver,
outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors
}
return &dialerLogger{
Dialer: &dialerResolver{
Dialer: &dialerResolverWithTracing{
Dialer: &dialerLogger{
Dialer: outDialer,
DebugLogger: logger,
@ -171,15 +171,24 @@ func (d *DialerSystem) CloseIdleConnections() {
// nothing to do here
}
// dialerResolver combines dialing with domain name resolution.
type dialerResolver struct {
// dialerResolverWithTracing combines dialing with domain name resolution and
// implements hooks to trace TCP (or UDP) connect operations.
type dialerResolverWithTracing struct {
Dialer model.Dialer
Resolver model.Resolver
}
var _ model.Dialer = &dialerResolver{}
var _ model.Dialer = &dialerResolverWithTracing{}
func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// DialContext implements model.Dialer.DialContext. Specifically this
// method performs the following operations:
//
// 1. resolve the domain inside the address using a resolver;
//
// 2. cycle through the available IP addresses and try to dial each of them;
//
// 3. trace the TCP (or UDP) connect and allow wrapping the returned conn.
func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// QUIRK: this routine and the related routines in quirks.go cannot
// be changed easily until we use events tracing to measure.
//
@ -194,10 +203,27 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
}
addrs = quirkSortIPAddrs(addrs)
var errorslist []error
trace := ContextTraceOrDefault(ctx)
for _, addr := range addrs {
target := net.JoinHostPort(addr, onlyport)
started := trace.TimeNow()
conn, err := d.Dialer.DialContext(ctx, network, target)
finished := trace.TimeNow()
// TODO(bassosimone): to make the code robust to future refactoring we have
// moved error wrapping inside this type. This change opens up the possibility
// of simplifying the dialing chain by removing dialerErrWrapper. We'll be
// able to implement this refactoring once netx is gone. We cannot complete
// this refactoring _before_ because WrapDialer inserts extra wrappers
// provided by netx in the dialers chain _before_ this dialer and the dialers
// that netx insert assume that they wrap a dialer with error wrapping.
//
// Because error wrapping should be idempotent, it should not be a problem
// to have two error wrapping dialers in the chain except that, of course, it
// would be less efficient than just having a single wrapper.
err = MaybeNewErrWrapper(ClassifyGenericError, ConnectOperation, err)
trace.OnConnectDone(started, network, onlyhost, target, err, finished)
if err == nil {
conn = &dialerErrWrapperConn{conn}
return conn, nil
}
errorslist = append(errorslist, err)
@ -206,14 +232,14 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
}
// lookupHost ensures we correctly handle IP addresses.
func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
func (d *dialerResolverWithTracing) lookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
}
return d.Resolver.LookupHost(ctx, hostname)
}
func (d *dialerResolver) CloseIdleConnections() {
func (d *dialerResolverWithTracing) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
d.Resolver.CloseIdleConnections()
}
@ -303,7 +329,7 @@ var _ model.Dialer = &dialerErrWrapper{}
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, newErrWrapper(classifyGenericError, ConnectOperation, err)
return nil, NewErrWrapper(ClassifyGenericError, ConnectOperation, err)
}
return &dialerErrWrapperConn{Conn: conn}, nil
}
@ -322,7 +348,7 @@ var _ net.Conn = &dialerErrWrapperConn{}
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b)
if err != nil {
return 0, newErrWrapper(classifyGenericError, ReadOperation, err)
return 0, NewErrWrapper(ClassifyGenericError, ReadOperation, err)
}
return count, nil
}
@ -330,7 +356,7 @@ func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
count, err := c.Conn.Write(b)
if err != nil {
return 0, newErrWrapper(classifyGenericError, WriteOperation, err)
return 0, NewErrWrapper(ClassifyGenericError, WriteOperation, err)
}
return count, nil
}
@ -338,7 +364,7 @@ func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
func (c *dialerErrWrapperConn) Close() error {
err := c.Conn.Close()
if err != nil {
return newErrWrapper(classifyGenericError, CloseOperation, err)
return NewErrWrapper(ClassifyGenericError, CloseOperation, err)
}
return nil
}

View File

@ -13,6 +13,7 @@ import (
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewDialerWithStdlibResolver(t *testing.T) {
@ -22,7 +23,7 @@ func TestNewDialerWithStdlibResolver(t *testing.T) {
t.Fatal("invalid logger")
}
// typecheck the resolver
reso := logger.Dialer.(*dialerResolver)
reso := logger.Dialer.(*dialerResolverWithTracing)
typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger)
// typecheck the dialer
logger = reso.Dialer.(*dialerLogger)
@ -64,7 +65,7 @@ func TestNewDialer(t *testing.T) {
if logger.DebugLogger != log.Log {
t.Fatal("invalid logger")
}
reso := logger.Dialer.(*dialerResolver)
reso := logger.Dialer.(*dialerResolverWithTracing)
if _, okay := reso.Resolver.(*NullResolver); !okay {
t.Fatal("invalid Resolver type")
}
@ -136,10 +137,10 @@ func TestDialerSystem(t *testing.T) {
})
}
func TestDialerResolver(t *testing.T) {
func TestDialerResolverWithTracing(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("fails without a port", func(t *testing.T) {
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &DialerSystem{},
Resolver: NewUnwrappedStdlibResolver(),
}
@ -154,7 +155,7 @@ func TestDialerResolver(t *testing.T) {
})
t.Run("handles dialing error correctly for single IP address", func(t *testing.T) {
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
@ -166,13 +167,26 @@ func TestDialerResolver(t *testing.T) {
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("the error has not been wrapped")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("invalid wrapped error's failure")
}
if errWrapper.Operation != ConnectOperation {
t.Fatal("invalid wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) {
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
@ -188,6 +202,19 @@ func TestDialerResolver(t *testing.T) {
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("the error has not been wrapped")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("invalid wrapped error's failure")
}
if errWrapper.Operation != ConnectOperation {
t.Fatal("invalid wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
@ -199,7 +226,7 @@ func TestDialerResolver(t *testing.T) {
return nil
},
}
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return expectedConn, nil
@ -215,7 +242,10 @@ func TestDialerResolver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if conn != expectedConn {
// Ensure that the dialer returns a connection that is already wrapping errors,
// which is a new behavior since https://github.com/ooni/probe-cli/pull/815
errWrapperConn := conn.(*dialerErrWrapperConn)
if errWrapperConn.Conn != expectedConn {
t.Fatal("unexpected conn")
}
conn.Close()
@ -225,7 +255,7 @@ func TestDialerResolver(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@ -257,7 +287,7 @@ func TestDialerResolver(t *testing.T) {
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
var attempts []string
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@ -298,14 +328,14 @@ func TestDialerResolver(t *testing.T) {
mu := &sync.Mutex{}
errorsList := []error{
errors.New("a mocked error"),
newErrWrapper(
classifyGenericError,
NewErrWrapper(
ClassifyGenericError,
CloseOperation,
io.EOF,
),
}
var errorIdx int
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@ -337,17 +367,18 @@ func TestDialerResolver(t *testing.T) {
t.Run("though ignores the unknown failures", func(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
expectedErr := errors.New("a mocked error")
mu := &sync.Mutex{}
errorsList := []error{
errors.New("a mocked error"),
newErrWrapper(
classifyGenericError,
expectedErr,
NewErrWrapper(
ClassifyGenericError,
CloseOperation,
errors.New("antani"),
errors.New("antani"), // this is an unknown failure and we should not return it
),
}
var errorIdx int
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@ -368,18 +399,95 @@ func TestDialerResolver(t *testing.T) {
},
}
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if err == nil || err.Error() != "a mocked error" {
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("error has not been wrapped")
}
if errWrapper.Failure != "unknown_failure: a mocked error" {
t.Fatal("unexpected wrapped error's failure")
}
if errWrapper.Operation != ConnectOperation {
t.Fatal("unexpected wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, expectedErr) {
t.Fatal("unexpected wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("uses a context-injected custom trace", func(t *testing.T) {
var (
called bool
domainOK bool
networkOK bool
remoteAddrOK bool
startTimeOK bool
finishTimeOK bool
wrappedErr bool
)
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
var ew *ErrWrapper
called = true
domainOK = (domain == "1.1.1.1")
networkOK = (network == "tcp")
remoteAddrOK = (remoteAddr == "1.1.1.1:853")
startTimeOK = (started.Sub(zeroTime) == 0)
finishTimeOK = (finished.Sub(zeroTime) == time.Second)
wrappedErr = errors.As(err, &ew) && ew.Failure == FailureEOFError
},
}
ctx := ContextWithTrace(context.Background(), tx)
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
},
},
Resolver: &NullResolver{},
}
conn, err := d.DialContext(ctx, "tcp", "1.1.1.1:853")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !called {
t.Fatal("not called")
}
if !domainOK {
t.Fatal("domain was not okay")
}
if !networkOK {
t.Fatal("network was not okay")
}
if !remoteAddrOK {
t.Fatal("remoteAddr was not okay")
}
if !startTimeOK {
t.Fatal("start time was not okay")
}
if !finishTimeOK {
t.Fatal("finish time was not okay")
}
if !wrappedErr {
t.Fatal("not wrapped")
}
})
})
t.Run("lookupHost", func(t *testing.T) {
t.Run("handles addresses correctly", func(t *testing.T) {
dialer := &dialerResolver{
dialer := &dialerResolverWithTracing{
Dialer: &DialerSystem{},
Resolver: &NullResolver{},
}
@ -393,7 +501,7 @@ func TestDialerResolver(t *testing.T) {
})
t.Run("fails correctly on lookup error", func(t *testing.T) {
dialer := &dialerResolver{
dialer := &dialerResolverWithTracing{
Dialer: &DialerSystem{},
Resolver: &NullResolver{},
}
@ -413,7 +521,7 @@ func TestDialerResolver(t *testing.T) {
calledDialer bool
calledResolver bool
)
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockCloseIdleConnections: func() {
calledDialer = true

View File

@ -38,7 +38,7 @@ func (t *dnsTransportErrWrapper) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
resp, err := t.DNSTransport.RoundTrip(ctx, query)
if err != nil {
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, DNSRoundTripOperation, err)
}
return resp, nil
}

View File

@ -34,6 +34,10 @@
//
// We want to have reasonable watchdog timeouts for each operation.
//
// We also want lightweight support for tracing network events. To this end, we
// use context.WithValue and context.Value to inject, and retrieve, a model.Trace
// implementation OPTIONALLY configured by the user.
//
// See also the design document at docs/design/dd-003-step-by-step.md,
// which provides an overview of netxlite's main concerns.
//

View File

@ -71,7 +71,7 @@ func (e *ErrWrapper) MarshalJSON() ([]byte, error) {
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
type classifier func(err error) string
// newErrWrapper creates a new ErrWrapper using the given
// NewErrWrapper creates a new ErrWrapper using the given
// classifier, operation name, and underlying error.
//
// This function panics if classifier is nil, or operation
@ -81,7 +81,7 @@ type classifier func(err error) string
// error wrapper will use the same classification string and
// will determine whether to keep the major operation as documented
// in the ErrWrapper.Operation documentation.
func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
func NewErrWrapper(c classifier, op string, err error) *ErrWrapper {
var wrapper *ErrWrapper
if errors.As(err, &wrapper) {
return &ErrWrapper{
@ -106,6 +106,19 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
}
}
// TODO(https://github.com/ooni/probe/issues/2163): we can really
// simplify the error wrapping situation here by just dropping
// NewErrWrapper and always using MaybeNewErrWrapper.
// MaybeNewErrWrapper is like NewErrWrapper except that this
// function won't panic if passed a nil error.
func MaybeNewErrWrapper(c classifier, op string, err error) error {
if err != nil {
return NewErrWrapper(c, op, err)
}
return nil
}
// NewTopLevelGenericErrWrapper wraps an error occurring at top
// level using a generic classifier as classifier. This is the
// function you should call when you suspect a given error hasn't
@ -115,7 +128,7 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
// error wrapper will use the same classification string and
// failed operation of the original error.
func NewTopLevelGenericErrWrapper(err error) *ErrWrapper {
return newErrWrapper(classifyGenericError, TopLevelOperation, err)
return NewErrWrapper(ClassifyGenericError, TopLevelOperation, err)
}
func classifyOperation(ew *ErrWrapper, operation string) string {

View File

@ -53,7 +53,7 @@ func TestNewErrWrapper(t *testing.T) {
recovered.Add(1)
}
}()
newErrWrapper(nil, CloseOperation, io.EOF)
NewErrWrapper(nil, CloseOperation, io.EOF)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
@ -68,7 +68,7 @@ func TestNewErrWrapper(t *testing.T) {
recovered.Add(1)
}
}()
newErrWrapper(classifyGenericError, "", io.EOF)
NewErrWrapper(ClassifyGenericError, "", io.EOF)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
@ -83,7 +83,7 @@ func TestNewErrWrapper(t *testing.T) {
recovered.Add(1)
}
}()
newErrWrapper(classifyGenericError, CloseOperation, nil)
NewErrWrapper(ClassifyGenericError, CloseOperation, nil)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
@ -91,7 +91,7 @@ func TestNewErrWrapper(t *testing.T) {
})
t.Run("otherwise, works as intended", func(t *testing.T) {
ew := newErrWrapper(classifyGenericError, CloseOperation, io.EOF)
ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF)
if ew.Failure != FailureEOFError {
t.Fatal("unexpected failure")
}
@ -104,10 +104,10 @@ func TestNewErrWrapper(t *testing.T) {
})
t.Run("when the underlying error is already a wrapped error", func(t *testing.T) {
ew := newErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
ew := NewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
var err1 error = ew
err2 := fmt.Errorf("cannot read: %w", err1)
ew2 := newErrWrapper(classifyGenericError, HTTPRoundTripOperation, err2)
ew2 := NewErrWrapper(ClassifyGenericError, HTTPRoundTripOperation, err2)
if ew2.Failure != ew.Failure {
t.Fatal("not the same failure")
}
@ -117,6 +117,34 @@ func TestNewErrWrapper(t *testing.T) {
if ew2.WrappedErr != err2 {
t.Fatal("invalid underlying error")
}
// Make sure we can still use errors.Is with two layers of wrapping
if !errors.Is(ew2, ECONNRESET) {
t.Fatal("we cannot use errors.Is to retrieve the real syscall error")
}
})
}
func TestMaybeNewErrWrapper(t *testing.T) {
// TODO(https://github.com/ooni/probe/issues/2163): we can really
// simplify the error wrapping situation here by just dropping
// NewErrWrapper and always using MaybeNewErrWrapper.
t.Run("when we pass a nil error to this function", func(t *testing.T) {
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, nil)
if err != nil {
t.Fatal("unexpected output", err)
}
})
t.Run("when we pass a non-nil error to this function", func(t *testing.T) {
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
if !errors.Is(err, ECONNRESET) {
t.Fatal("unexpected output", err)
}
var ew *ErrWrapper
if !errors.As(err, &ew) {
t.Fatal("not an instance of ErrWrapper")
}
})
}

View File

@ -31,7 +31,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
dialer := txpCc.Dialer
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger)
dialerReso := dialerLog.Dialer.(*dialerResolver)
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
if dialerReso.Resolver != resolver {
t.Fatal("invalid resolver")
}
@ -52,7 +52,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer)
dialerLog := dialerProxy.Dialer.(*dialerLogger)
dialerReso := dialerLog.Dialer.(*dialerResolver)
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
if dialerReso.Resolver != resolver {
t.Fatal("invalid resolver")
}
@ -269,7 +269,7 @@ func TestNewHTTPTransport(t *testing.T) {
t.Run("works as intended with failing dialer", func(t *testing.T) {
called := &atomicx.Int64{}
expected := errors.New("mocked error")
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context,
network, address string) (net.Conn, error) {
@ -612,7 +612,7 @@ func TestNewHTTPClientWithResolver(t *testing.T) {
txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser)
dialer := txpCc.Dialer.(*httpDialerWithReadTimeout)
dialerLogger := dialer.Dialer.(*dialerLogger)
dialerReso := dialerLogger.Dialer.(*dialerResolver)
dialerReso := dialerLogger.Dialer.(*dialerResolverWithTracing)
if dialerReso.Resolver != reso {
t.Fatal("invalid resolver")
}

View File

@ -41,7 +41,7 @@ func TestReadAllContext(t *testing.T) {
//
// Note: Returning a wrapped error to ensure we address
// https://github.com/ooni/probe/issues/1965
return len(b), newErrWrapper(classifyGenericError,
return len(b), NewErrWrapper(ClassifyGenericError,
ReadOperation, io.EOF)
},
}
@ -171,7 +171,7 @@ func TestCopyContext(t *testing.T) {
//
// Note: Returning a wrapped error to ensure we address
// https://github.com/ooni/probe/issues/1965
return len(b), newErrWrapper(classifyGenericError,
return len(b), NewErrWrapper(ClassifyGenericError,
ReadOperation, io.EOF)
},
}

View File

@ -380,7 +380,7 @@ var _ model.QUICListener = &quicListenerErrWrapper{}
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr)
if err != nil {
return nil, newErrWrapper(classifyGenericError, QUICListenOperation, err)
return nil, NewErrWrapper(ClassifyGenericError, QUICListenOperation, err)
}
return &quicErrWrapperUDPLikeConn{pconn}, nil
}
@ -397,7 +397,7 @@ var _ model.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
count, err := c.UDPLikeConn.WriteTo(p, addr)
if err != nil {
return 0, newErrWrapper(classifyGenericError, WriteToOperation, err)
return 0, NewErrWrapper(ClassifyGenericError, WriteToOperation, err)
}
return count, nil
}
@ -406,7 +406,7 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPLikeConn.ReadFrom(b)
if err != nil {
return 0, nil, newErrWrapper(classifyGenericError, ReadFromOperation, err)
return 0, nil, NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
}
return n, addr, nil
}
@ -415,7 +415,7 @@ func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
func (c *quicErrWrapperUDPLikeConn) Close() error {
err := c.UDPLikeConn.Close()
if err != nil {
return newErrWrapper(classifyGenericError, ReadFromOperation, err)
return NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
}
return nil
}
@ -433,8 +433,8 @@ func (d *quicDialerErrWrapper) DialContext(
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
if err != nil {
return nil, newErrWrapper(
classifyQUICHandshakeError, QUICHandshakeOperation, err)
return nil, NewErrWrapper(
ClassifyQUICHandshakeError, QUICHandshakeOperation, err)
}
return qconn, nil
}

View File

@ -34,13 +34,13 @@ func TestQuirkReduceErrors(t *testing.T) {
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
err1 := errors.New("mocked error #1")
err2 := newErrWrapper(
classifyGenericError,
err2 := NewErrWrapper(
ClassifyGenericError,
CloseOperation,
errors.New("antani"),
)
err3 := newErrWrapper(
classifyGenericError,
err3 := NewErrWrapper(
ClassifyGenericError,
CloseOperation,
ECONNREFUSED,
)

View File

@ -387,7 +387,7 @@ var _ model.Resolver = &resolverErrWrapper{}
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
}
return addrs, nil
}
@ -396,7 +396,7 @@ func (r *resolverErrWrapper) LookupHTTPS(
ctx context.Context, domain string) (*model.HTTPSSvc, error) {
out, err := r.Resolver.LookupHTTPS(ctx, domain)
if err != nil {
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
}
return out, nil
}
@ -417,7 +417,7 @@ func (r *resolverErrWrapper) LookupNS(
ctx context.Context, domain string) ([]*net.NS, error) {
out, err := r.Resolver.LookupNS(ctx, domain)
if err != nil {
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
}
return out, nil
}

View File

@ -18,9 +18,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
// TODO(bassosimone): check whether there's now equivalent functionality
// inside the standard library allowing us to map numbers to names.
var (
tlsVersionString = map[uint16]string{
tls.VersionTLS10: "TLSv1",
@ -85,6 +82,13 @@ func TLSVersionString(value uint16) string {
// the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd`
// where `ddd` is the numeric value passed to this function.
func TLSCipherSuiteString(value uint16) string {
// TODO(https://github.com/ooni/probe/issues/2166): the standard library has a
// function for mapping a cipher suite to a string, but the value returned in case of
// missing cipher suite is different from the one we would return
// here. We could consider simplifying this code anyway because
// in most, if not all, cases we have a valid cipher suite and we
// just need to make sure what the spec says we should do when
// passed an unknown cipher suite.
if str, found := tlsCipherSuiteString[value]; found {
return str
}
@ -158,15 +162,15 @@ func NewTLSHandshakerStdlib(logger model.DebugLogger) model.TLSHandshaker {
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker {
return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: th,
},
DebugLogger: logger,
TLSHandshaker: th,
DebugLogger: logger,
}
}
// tlsHandshakerConfigurable is a configurable TLS handshaker that
// uses by default the standard library's TLS implementation.
//
// This type also implements error wrapping and events tracing.
type tlsHandshakerConfigurable struct {
// NewConn is the OPTIONAL factory for creating a new connection. If
// this factory is not set, we'll use the stdlib.
@ -183,9 +187,20 @@ var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}
// value into a private variable to enable for unit testing.
var defaultCertPool = NewDefaultCertPool()
// tlsMaybeConnectionState returns the connection state if error is nil
// and otherwise just returns an empty state to the caller.
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
if err != nil {
return tls.ConnectionState{}
}
return conn.ConnectionState()
}
// Handshake implements Handshaker.Handshake. This function will
// configure the code to use the built-in Mozilla CA if the config
// field contains a nil RootCAs field.
//
// This function will also emit TLS-handshake-related tracing events.
func (h *tlsHandshakerConfigurable) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
@ -203,10 +218,19 @@ func (h *tlsHandshakerConfigurable) Handshake(
if err != nil {
return nil, tls.ConnectionState{}, err
}
if err := tlsconn.HandshakeContext(ctx); err != nil {
remoteAddr := conn.RemoteAddr().String()
trace := ContextTraceOrDefault(ctx)
started := trace.TimeNow()
trace.OnTLSHandshakeStart(started, remoteAddr, config)
err = tlsconn.HandshakeContext(ctx)
err = MaybeNewErrWrapper(ClassifyTLSHandshakeError, TLSHandshakeOperation, err)
finished := trace.TimeNow()
state := tlsMaybeConnectionState(tlsconn, err)
trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
if err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
return tlsconn, state, nil
}
// newConn creates a new TLSConn.
@ -352,23 +376,6 @@ func (d *tlsDialerSingleUseAdapter) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
type tlsHandshakerErrWrapper struct {
TLSHandshaker model.TLSHandshaker
}
// Handshake implements TLSHandshaker.Handshake
func (h *tlsHandshakerErrWrapper) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
return nil, tls.ConnectionState{}, newErrWrapper(
classifyTLSHandshakeError, TLSHandshakeOperation, err)
}
return tlsconn, state, nil
}
// ErrNoTLSDialer is the type of error returned by "null" TLS dialers
// when you attempt to dial with them.
var ErrNoTLSDialer = errors.New("no configured TLS dialer")

View File

@ -16,7 +16,10 @@ import (
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestVersionString(t *testing.T) {
@ -123,8 +126,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
if logger.DebugLogger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn != nil {
t.Fatal("expected nil NewConn")
}
@ -132,7 +134,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
func TestTLSHandshakerConfigurable(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("with error", func(t *testing.T) {
t.Run("with handshake I/O error", func(t *testing.T) {
var times []time.Time
h := &tlsHandshakerConfigurable{}
tcpConn := &mocks.Conn{
@ -143,14 +145,37 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
times = append(times, t)
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "1.1.1.1:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
ctx := context.Background()
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("the error has not been wrapped")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("invalid wrapped error's failure")
}
if errWrapper.Operation != TLSHandshakeOperation {
t.Fatal("invalid wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil con here")
}
@ -163,6 +188,9 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
if !times[1].IsZero() {
t.Fatal("did not clear timeout on exit")
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("the returned connection state is not a zero value")
}
})
t.Run("with success", func(t *testing.T) {
@ -217,6 +245,16 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "1.1.1.1:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
@ -236,7 +274,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
}
})
t.Run("we cannot create a new conn", func(t *testing.T) {
t.Run("h.newConn fails", func(t *testing.T) {
expected := errors.New("mocked error")
handshaker := &tlsHandshakerConfigurable{
NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) {
@ -261,6 +299,222 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
t.Fatal("expected nil tlsConn here")
}
})
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
var (
expectedSNI = "dns.google"
goodStartStartTime bool
goodStartInsecureSkipVerify bool
goodDoneInsecureSkipVerify bool
goodStartServerName bool
goodDoneServerName bool
goodDoneStartTime bool
goodDoneDoneTime bool
goodStartRemoteAddr bool
goodDoneRemoteAddr bool
goodDoneError bool
goodConnectionState bool
startCalled bool
doneCalled bool
)
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
defer server.Close()
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
startCalled = true
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodStartServerName = (config.ServerName == expectedSNI)
goodStartStartTime = (now.Sub(zeroTime) == 0)
goodStartRemoteAddr = (remoteAddr == server.Endpoint())
},
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
doneCalled = true
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodDoneServerName = (config.ServerName == expectedSNI)
goodDoneStartTime = (started.Sub(zeroTime) == 0)
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
goodDoneRemoteAddr = (remoteAddr == server.Endpoint())
goodDoneError = (err == nil)
goodConnectionState = (!reflect.ValueOf(state).IsZero())
},
}
ctx := ContextWithTrace(context.Background(), tx)
tcpConn, err := net.Dial("tcp", server.Endpoint())
if err != nil {
t.Fatal(err)
}
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: expectedSNI,
}
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err != nil {
t.Fatal(err)
}
tlsConn.Close()
if reflect.ValueOf(connState).IsZero() {
t.Fatal("expected nonzero connState")
}
if !startCalled {
t.Fatal("start not called")
}
if !doneCalled {
t.Fatal("done not called")
}
if !goodStartInsecureSkipVerify {
t.Fatal("invalid start-event's InsecureSkipVerify")
}
if !goodDoneInsecureSkipVerify {
t.Fatal("invalid done-event's InsecureSkipVerify")
}
if !goodStartServerName {
t.Fatal("invalid start-event's ServerName")
}
if !goodDoneServerName {
t.Fatal("invalid done-event's ServerName")
}
if !goodStartStartTime {
t.Fatal("invalid start-event's start time")
}
if !goodDoneStartTime {
t.Fatal("invalid done-event's start time")
}
if !goodDoneDoneTime {
t.Fatal("invalid done-event's done time")
}
if !goodStartRemoteAddr {
t.Fatal("invalid start-event's remoteAddr")
}
if !goodDoneRemoteAddr {
t.Fatal("invalid done-event's remoteAddr")
}
if !goodDoneError {
t.Fatal("invalid done-event's error")
}
if !goodConnectionState {
t.Fatal("invalid done-event's connState")
}
})
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
var (
expectedEndpoint = "8.8.8.8:443"
expectedSNI = "dns.google"
goodStartStartTime bool
goodStartInsecureSkipVerify bool
goodDoneInsecureSkipVerify bool
goodStartServerName bool
goodDoneServerName bool
goodDoneStartTime bool
goodDoneDoneTime bool
goodStartRemoteAddr bool
goodDoneRemoteAddr bool
goodDoneError bool
goodConnectionState bool
startCalled bool
doneCalled bool
)
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
startCalled = true
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodStartServerName = (config.ServerName == expectedSNI)
goodStartStartTime = (now.Sub(zeroTime) == 0)
goodStartRemoteAddr = (remoteAddr == expectedEndpoint)
},
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
doneCalled = true
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodDoneServerName = (config.ServerName == expectedSNI)
goodDoneStartTime = (started.Sub(zeroTime) == 0)
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
goodDoneRemoteAddr = (remoteAddr == expectedEndpoint)
var ew *ErrWrapper
goodDoneError = (errors.As(err, &ew) && ew.Error() == FailureEOFError)
goodConnectionState = (reflect.ValueOf(state).IsZero())
},
}
ctx := ContextWithTrace(context.Background(), tx)
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return expectedEndpoint
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: expectedSNI,
}
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if tlsConn != nil {
t.Fatal("expected nil tlsConn")
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero connState")
}
if !startCalled {
t.Fatal("start not called")
}
if !doneCalled {
t.Fatal("done not called")
}
if !goodStartInsecureSkipVerify {
t.Fatal("invalid start-event's InsecureSkipVerify")
}
if !goodDoneInsecureSkipVerify {
t.Fatal("invalid done-event's InsecureSkipVerify")
}
if !goodStartServerName {
t.Fatal("invalid start-event's ServerName")
}
if !goodDoneServerName {
t.Fatal("invalid done-event's ServerName")
}
if !goodStartStartTime {
t.Fatal("invalid start-event's start time")
}
if !goodDoneStartTime {
t.Fatal("invalid done-event's start time")
}
if !goodDoneDoneTime {
t.Fatal("invalid done-event's done time")
}
if !goodStartRemoteAddr {
t.Fatal("invalid start-event's remoteAddr")
}
if !goodDoneRemoteAddr {
t.Fatal("invalid done-event's remoteAddr")
}
if !goodDoneError {
t.Fatal("invalid done-event's error")
}
if !goodConnectionState {
t.Fatal("invalid done-event's connState")
}
})
})
}
@ -413,6 +667,15 @@ func TestTLSDialer(t *testing.T) {
return nil
}, MockSetDeadline: func(t time.Time) error {
return nil
}, MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "1.1.1.1:443"
},
MockString: func() string {
return "tcp"
},
}
}}, nil
}},
TLSHandshaker: &tlsHandshakerConfigurable{},
@ -532,54 +795,6 @@ func TestNewSingleUseTLSDialer(t *testing.T) {
}
}
func TestTLSHandshakerErrWrapper(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.TLSConn{}
expectedState := tls.ConnectionState{
Version: tls.VersionTLS12,
}
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return expectedConn, expectedState, nil
},
},
}
ctx := context.Background()
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err != nil {
t.Fatal(err)
}
if expectedState.Version != state.Version {
t.Fatal("unexpected state")
}
if expectedConn != conn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expectedErr
},
},
}
ctx := context.Background()
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err == nil || err.Error() != FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("unexpected conn")
}
})
})
}
func TestNewNullTLSDialer(t *testing.T) {
dialer := NewNullTLSDialer()
conn, err := dialer.DialTLSContext(context.Background(), "", "")
@ -618,3 +833,35 @@ func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
}
})
}
func TestMaybeConnectionState(t *testing.T) {
t.Run("with an error", func(t *testing.T) {
returned := tls.ConnectionState{
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
}
conn := &mocks.TLSConn{
MockConnectionState: func() tls.ConnectionState {
return returned
},
}
state := tlsMaybeConnectionState(conn, errors.New("mocked error"))
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected to see a zero connection state")
}
})
t.Run("without an error", func(t *testing.T) {
returned := tls.ConnectionState{
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
}
conn := &mocks.TLSConn{
MockConnectionState: func() tls.ConnectionState {
return returned
},
}
state := tlsMaybeConnectionState(conn, nil)
if reflect.ValueOf(state).IsZero() {
t.Fatal("expected to see a nonzero connection state")
}
})
}

View File

@ -0,0 +1,67 @@
package netxlite
//
// Context-based tracing
//
import (
"context"
"crypto/tls"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
// traceKey is the private type used to set/retrieve the context's trace.
type traceKey struct{}
// ContextTraceOrDefault retrieves the trace bound to the context or returns
// a default implementation of the trace in case no tracing was configured.
func ContextTraceOrDefault(ctx context.Context) model.Trace {
t, _ := ctx.Value(traceKey{}).(model.Trace)
return traceOrDefault(t)
}
// ContextWithTrace returns a new context that binds to the given trace. If the
// given trace is nil, this function will call panic.
func ContextWithTrace(ctx context.Context, trace model.Trace) context.Context {
runtimex.PanicIfTrue(trace == nil, "netxlite.WithTrace passed a nil trace")
return context.WithValue(ctx, traceKey{}, trace)
}
// traceOrDefault takes in input a trace and returns in output the
// given trace, if not nil, or a default trace implementation.
func traceOrDefault(trace model.Trace) model.Trace {
if trace != nil {
return trace
}
return &traceDefault{}
}
// traceDefault is a default model.Trace implementation where each method is a no-op.
type traceDefault struct{}
var _ model.Trace = &traceDefault{}
// TimeNow implements model.Trace.TimeNow
func (*traceDefault) TimeNow() time.Time {
return time.Now()
}
// OnConnectDone implements model.Trace.OnConnectDone.
func (*traceDefault) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
// nothing
}
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
func (*traceDefault) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
// nothing
}
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
func (*traceDefault) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time) {
// nothing
}

View File

@ -0,0 +1,40 @@
package netxlite
import (
"context"
"testing"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestContextTraceOrDefault(t *testing.T) {
t.Run("without a configured trace we get a default", func(t *testing.T) {
ctx := context.Background()
tx := ContextTraceOrDefault(ctx)
_ = tx.(*traceDefault) // panic if cannot cast
})
t.Run("with a configured trace we get the expected trace", func(t *testing.T) {
realTrace := &mocks.Trace{}
ctx := ContextWithTrace(context.Background(), realTrace)
tx := ContextTraceOrDefault(ctx)
if tx != realTrace {
t.Fatal("not the trace we expected")
}
})
}
func TestContextWithTrace(t *testing.T) {
t.Run("panics if passed a nil trace", func(t *testing.T) {
var called bool
func() {
defer func() {
called = (recover() != nil)
}()
_ = ContextWithTrace(context.Background(), nil)
}()
if !called {
t.Fatal("not called")
}
})
}

View File

@ -19,8 +19,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) {
if logger.DebugLogger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn == nil {
t.Fatal("expected non-nil NewConn")
}

2
internal/testingx/doc.go Normal file
View File

@ -0,0 +1,2 @@
// Package testingx contains code useful for testing.
package testingx

View File

@ -1,11 +1,4 @@
// Package fakefill contains code to fill structs for testing.
//
// This package is quite limited in scope and we can fill only the
// structures you typically send over as JSONs.
//
// As part of future work, we aim to investigate whether we can
// replace this implementation with https://go.dev/blog/fuzz-beta.
package fakefill
package testingx
import (
"math/rand"
@ -14,7 +7,7 @@ import (
"time"
)
// Filler fills specific data structures with random data. The only
// FakeFiller fills specific data structures with random data. The only
// exception to this behaviour is time.Time, which is instead filled
// with the current time plus a small random number of seconds.
//
@ -25,7 +18,13 @@ import (
// Caveat: this kind of fillter does not support filling interfaces
// and channels and other complex types. The current behavior when this
// kind of data types is encountered is to just ignore them.
type Filler struct {
//
// This struct is quite limited in scope and we can fill only the
// structures you typically send over as JSONs.
//
// As part of future work, we aim to investigate whether we can
// replace this implementation with https://go.dev/blog/fuzz-beta.
type FakeFiller struct {
// mu provides mutual exclusion
mu sync.Mutex
@ -37,7 +36,7 @@ type Filler struct {
rnd *rand.Rand
}
func (ff *Filler) getRandLocked() *rand.Rand {
func (ff *FakeFiller) getRandLocked() *rand.Rand {
if ff.rnd == nil {
now := time.Now
if ff.Now != nil {
@ -48,7 +47,7 @@ func (ff *Filler) getRandLocked() *rand.Rand {
return ff.rnd
}
func (ff *Filler) getRandomString() string {
func (ff *FakeFiller) getRandomString() string {
defer ff.mu.Unlock()
ff.mu.Lock()
rnd := ff.getRandLocked()
@ -62,28 +61,28 @@ func (ff *Filler) getRandomString() string {
return string(b)
}
func (ff *Filler) getRandomInt64() int64 {
func (ff *FakeFiller) getRandomInt64() int64 {
defer ff.mu.Unlock()
ff.mu.Lock()
rnd := ff.getRandLocked()
return rnd.Int63()
}
func (ff *Filler) getRandomBool() bool {
func (ff *FakeFiller) getRandomBool() bool {
defer ff.mu.Unlock()
ff.mu.Lock()
rnd := ff.getRandLocked()
return rnd.Float64() >= 0.5
}
func (ff *Filler) getRandomSmallPositiveInt() int {
func (ff *FakeFiller) getRandomSmallPositiveInt() int {
defer ff.mu.Unlock()
ff.mu.Lock()
rnd := ff.getRandLocked()
return int(rnd.Int63n(8)) + 1 // safe cast
}
func (ff *Filler) doFill(v reflect.Value) {
func (ff *FakeFiller) doFill(v reflect.Value) {
for v.Type().Kind() == reflect.Ptr {
if v.IsNil() {
// if the pointer is nil, allocate an element
@ -134,6 +133,6 @@ func (ff *Filler) doFill(v reflect.Value) {
}
// Fill fills the input structure or pointer with random data.
func (ff *Filler) Fill(in interface{}) {
func (ff *FakeFiller) Fill(in interface{}) {
ff.doFill(reflect.ValueOf(in))
}

View File

@ -1,4 +1,4 @@
package fakefill
package testingx
import (
"testing"
@ -16,7 +16,7 @@ type exampleStructure struct {
func TestFakeFillWorksWithCustomTime(t *testing.T) {
var req *exampleStructure
ff := &Filler{
ff := &FakeFiller{
Now: func() time.Time {
return time.Date(1992, time.January, 24, 17, 53, 0, 0, time.UTC)
},
@ -29,7 +29,7 @@ func TestFakeFillWorksWithCustomTime(t *testing.T) {
func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
var req *exampleStructure
ff := &Filler{}
ff := &FakeFiller{}
ff.Fill(&req)
if req == nil {
t.Fatal("we expected non nil here")
@ -38,7 +38,7 @@ func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
func TestFakeFillAllocatesIntoAMapLikeWithStringKeys(t *testing.T) {
var resp map[string]*exampleStructure
ff := &Filler{}
ff := &FakeFiller{}
ff.Fill(&resp)
if resp == nil {
t.Fatal("we expected non nil here")
@ -62,7 +62,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) {
}
}()
var resp map[int64]*exampleStructure
ff := &Filler{}
ff := &FakeFiller{}
ff.Fill(&resp)
if resp != nil {
t.Fatal("we expected nil here")
@ -75,7 +75,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) {
func TestFakeFillAllocatesIntoASlice(t *testing.T) {
var resp *[]*exampleStructure
ff := &Filler{}
ff := &FakeFiller{}
ff.Fill(&resp)
if resp == nil {
t.Fatal("we expected non nil here")

48
internal/testingx/time.go Normal file
View File

@ -0,0 +1,48 @@
package testingx
import (
"sync"
"time"
)
// TimeDeterministic implements time.Now in a deterministic fashion
// such that every time.Time call returns a moment in time that occurs
// one second after the configured zeroTime.
//
// It's safe to use this struct from multiple goroutine contexts.
type TimeDeterministic struct {
// counter counts the number of "ticks" passed since the zero time: each
// call to Now increments this counter by one second.
counter time.Duration
// mu protects fields in this structure from concurrent access.
mu sync.Mutex
// zeroTime is the lazy-initialized zero time. The first call to Now
// will initialize this field with the current time.
zeroTime time.Time
}
// NewTimeDeterministic creates a new instance using the given zeroTime value.
func NewTimeDeterministic(zeroTime time.Time) *TimeDeterministic {
return &TimeDeterministic{
counter: 0,
mu: sync.Mutex{},
zeroTime: zeroTime,
}
}
// Now is like time.Now but more deterministic. The first call returns the
// configured zeroTime and subsequent calls return moments in time that occur
// exactly one second after the time returned by the previous call.
func (td *TimeDeterministic) Now() time.Time {
td.mu.Lock()
if td.zeroTime.IsZero() {
td.zeroTime = time.Now()
}
offset := td.counter
td.counter += time.Second
res := td.zeroTime.Add(offset)
td.mu.Unlock()
return res
}

View File

@ -0,0 +1,22 @@
package testingx
import (
"testing"
"time"
)
func TestTimeDeterministic(t *testing.T) {
td := &TimeDeterministic{}
t0 := td.Now()
if !t0.Equal(td.zeroTime) {
t.Fatal("invalid t0 value")
}
t1 := td.Now()
if t1.Sub(t0) != time.Second {
t.Fatal("invalid t1 value")
}
t2 := td.Now()
if t2.Sub(t1) != time.Second {
t.Fatal("invalid t2 value")
}
}

View File

@ -5,4 +5,8 @@
// resulting connection). When done, we will extract the trace containing
// all the events that occurred from the saver and process it to determine
// what happened during the measurement itself.
//
// This package is now frozen. Please, use measurexlite for new code. See
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
// for details about this.
package tracex