feat: tlsping and tcpping using step-by-step (#815)
## Checklist - [x] I have read the [contribution guidelines](https://github.com/ooni/probe-cli/blob/master/CONTRIBUTING.md) - [x] reference issue for this pull request: https://github.com/ooni/probe/issues/2158 - [x] if you changed anything related how experiments work and you need to reflect these changes in the ooni/spec repository, please link to the related ooni/spec pull request: https://github.com/ooni/spec/pull/250 ## Description This diff refactors the codebase to reimplement tlsping and tcpping to use the step-by-step measurements style. See docs/design/dd-003-step-by-step.md for more information on the step-by-step measurement style.
This commit is contained in:
parent
5371c7f486
commit
5ebdeb56ca
|
@ -874,7 +874,7 @@ mechanisms if the context is too bad):
|
|||
|
||||
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
||||
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||
qtype uint16, out chan\<- \*parallelResolverResult) {
|
||||
qtype uint16, out chan<- *parallelResolverResult) {
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
+ started := time.Now()
|
||||
|
@ -1329,7 +1329,7 @@ const webDomain = "web.telegram.org"
|
|||
// This method does not return any value and writes results directly inside
|
||||
// the test keys, which have thread safe methods for that.
|
||||
func (mx *Measurer) measureWebEndpointHTTPS(ctx context.Context, wg *sync.WaitGroup,
|
||||
logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
|
||||
logger model.Logger, zeroTime time.Time, tk *TestKeys, address string) {
|
||||
// 0. setup
|
||||
const webTimeout = 7 * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, webTimeout)
|
||||
|
@ -1344,8 +1344,8 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
|
|||
// 1. establish a TCP connection with the endpoint
|
||||
|
||||
// dialer := nextlite.NewDialerwithoutResolver(logger) // --- (removed line)
|
||||
trace := measurexlite.NewTrace(index, logger, zeroTime) // +++ (added line)
|
||||
dialer := trace.NewDialerWithoutResolver() // +++ (...)
|
||||
trace := measurexlite.NewTrace(index, zeroTime) // +++ (added line)
|
||||
dialer := trace.NewDialerWithoutResolver(logger) // +++ (...)
|
||||
defer tk.addTCPConnectResults(trace.TCPConnectResults()) // +++ (...)
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", endpoint)
|
||||
|
@ -1366,7 +1366,7 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
|
|||
// thx := netxlite.NewTLSHandshakerStdlib(logger) // ---
|
||||
conn = trace.WrapConn(conn) // +++
|
||||
defer tk.addNetworkEvents(trace.NetworkEvents()) // +++
|
||||
thx := trace.NewTLSHandshakerStdlib() // +++
|
||||
thx := trace.NewTLSHandshakerStdlib(logger) // +++
|
||||
defer tk.addTLSHandshakeResult(trace.TLSHandshakeResults()) // +++
|
||||
|
||||
config := &tls.Config{
|
||||
|
@ -1498,23 +1498,22 @@ type Trace struct { /* ... */ }
|
|||
|
||||
var _ model.Trace = &Trace{}
|
||||
|
||||
func NewTrace(index int64, logger model.Logger, zeroTime time.Time) *Trace {
|
||||
func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||
const (
|
||||
tlsHandshakeBuffer = 16,
|
||||
// ...
|
||||
)
|
||||
return &Trace{
|
||||
Index: index,
|
||||
Logger: logger,
|
||||
TLS: make(chan *model.ArchivalTLSOrQUICHandshakeResult, tlsHandshakeBuffer),
|
||||
ZeroTime: zeroTime,
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func (tx *Trace) NewTLSHandshakerStdlib() model.TLSHandshaker {
|
||||
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||
return &tlsHandshakerTrace{
|
||||
TLSHandshaker: netxlite.NewTLSHandshakerStdlib(tx.Logger),
|
||||
TLSHandshaker: netxlite.NewTLSHandshakerStdlib(dl),
|
||||
Trace: tx,
|
||||
}
|
||||
}
|
||||
|
@ -1524,8 +1523,8 @@ type tlsHandshakerTrace { /* ... */ }
|
|||
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
|
||||
|
||||
func (thx *tlsHandshakerTrace) Handshake(ctx context.Context,
|
||||
conn net.Conn, config \*tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
ctx = netxlite.ContextWithTrace(ctx, thx.Trace) // <- here we setup the context magic
|
||||
conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
ctx = netxlite.ContextWithTraceOrDefault(ctx, thx.Trace) // <- here we setup the context magic
|
||||
return thx.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
}
|
||||
|
||||
|
@ -1540,7 +1539,7 @@ func (tx *Trace) OnTLSHandshake(started time.Time, remoteAddr string,
|
|||
}
|
||||
}
|
||||
|
||||
func (tx *Trace) TLSHandshakeResults() (out []\*model.ArchivalTLSOrQUICHandshakeResult) {
|
||||
func (tx *Trace) TLSHandshakeResults() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
|
||||
for {
|
||||
select {
|
||||
case ev := <-tx.TLS:
|
||||
|
@ -1575,7 +1574,7 @@ func (thx *tlsHandshakerConfigurable) Handshake(ctx context.Context,
|
|||
state := tlsconn.connectionState()
|
||||
trace.OnTLSHandshake(started, remoteAddr, config, // +++
|
||||
state, nil, finished) // +++
|
||||
return tlsconn, state, nil
|
||||
return tlsconn, state, nil
|
||||
}
|
||||
|
||||
func ContextTraceOrDefault(ctx context.Context) model.Trace { /* ... */ }
|
||||
|
|
|
@ -10,13 +10,13 @@ import (
|
|||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/measurex"
|
||||
"github.com/ooni/probe-cli/v3/internal/measurexlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
testName = "tcpping"
|
||||
testVersion = "0.1.0"
|
||||
testVersion = "0.2.0"
|
||||
)
|
||||
|
||||
// Config contains the experiment configuration.
|
||||
|
@ -49,7 +49,7 @@ type TestKeys struct {
|
|||
|
||||
// SinglePing contains the results of a single ping.
|
||||
type SinglePing struct {
|
||||
TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"`
|
||||
TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"`
|
||||
}
|
||||
|
||||
// Measurer performs the measurement.
|
||||
|
@ -103,42 +103,47 @@ func (m *Measurer) Run(
|
|||
}
|
||||
tk := new(TestKeys)
|
||||
measurement.TestKeys = tk
|
||||
out := make(chan *measurex.EndpointMeasurement)
|
||||
mxmx := measurex.NewMeasurerWithDefaultSettings()
|
||||
go m.tcpPingLoop(ctx, mxmx, parsed.Host, out)
|
||||
out := make(chan *SinglePing)
|
||||
go m.tcpPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out)
|
||||
for len(tk.Pings) < int(m.config.repetitions()) {
|
||||
meas := <-out
|
||||
tk.Pings = append(tk.Pings, &SinglePing{
|
||||
TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect),
|
||||
})
|
||||
tk.Pings = append(tk.Pings, <-out)
|
||||
}
|
||||
return nil // return nil so we always submit the measurement
|
||||
}
|
||||
|
||||
// tcpPingLoop sends all the ping requests and emits the results onto the out channel.
|
||||
func (m *Measurer) tcpPingLoop(ctx context.Context, mxmx *measurex.Measurer,
|
||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
||||
func (m *Measurer) tcpPingLoop(ctx context.Context, zeroTime time.Time,
|
||||
logger model.Logger, address string, out chan<- *SinglePing) {
|
||||
ticker := time.NewTicker(m.config.delay())
|
||||
defer ticker.Stop()
|
||||
for i := int64(0); i < m.config.repetitions(); i++ {
|
||||
go m.tcpPingAsync(ctx, mxmx, address, out)
|
||||
go m.tcpPingAsync(ctx, i, zeroTime, logger, address, out)
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
// tcpPingAsync performs a TCP ping and emits the result onto the out channel.
|
||||
func (m *Measurer) tcpPingAsync(ctx context.Context, mxmx *measurex.Measurer,
|
||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
||||
out <- m.tcpConnect(ctx, mxmx, address)
|
||||
func (m *Measurer) tcpPingAsync(ctx context.Context, index int64,
|
||||
zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) {
|
||||
out <- m.tcpConnect(ctx, index, zeroTime, logger, address)
|
||||
}
|
||||
|
||||
// tcpConnect performs a TCP connect and returns the result to the caller.
|
||||
func (m *Measurer) tcpConnect(ctx context.Context, mxmx *measurex.Measurer,
|
||||
address string) *measurex.EndpointMeasurement {
|
||||
func (m *Measurer) tcpConnect(ctx context.Context, index int64,
|
||||
zeroTime time.Time, logger model.Logger, address string) *SinglePing {
|
||||
// TODO(bassosimone): make the timeout user-configurable
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
return mxmx.TCPConnect(ctx, address)
|
||||
trace := measurexlite.NewTrace(index, zeroTime)
|
||||
dialer := trace.NewDialerWithoutResolver(logger)
|
||||
ol := measurexlite.NewOperationLogger(logger, "TCPPing #%d %s", index, address)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
ol.Stop(err)
|
||||
measurexlite.MaybeClose(conn)
|
||||
sp := &SinglePing{
|
||||
TCPConnect: <-trace.TCPConnect,
|
||||
}
|
||||
return sp
|
||||
}
|
||||
|
||||
// NewExperimentMeasurer creates a new ExperimentMeasurer.
|
||||
|
|
|
@ -40,14 +40,16 @@ func TestMeasurer_run(t *testing.T) {
|
|||
if m.ExperimentName() != "tcpping" {
|
||||
t.Fatal("invalid experiment name")
|
||||
}
|
||||
if m.ExperimentVersion() != "0.1.0" {
|
||||
if m.ExperimentVersion() != "0.2.0" {
|
||||
t.Fatal("invalid experiment version")
|
||||
}
|
||||
ctx := context.Background()
|
||||
meas := &model.Measurement{
|
||||
Input: model.MeasurementTarget(input),
|
||||
}
|
||||
sess := &mockable.Session{}
|
||||
sess := &mockable.Session{
|
||||
MockableLogger: model.DiscardLogger,
|
||||
}
|
||||
callbacks := model.NewPrinterCallbacks(model.DiscardLogger)
|
||||
err := m.Run(ctx, sess, meas, callbacks)
|
||||
return meas, m, err
|
||||
|
|
|
@ -13,14 +13,14 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/measurex"
|
||||
"github.com/ooni/probe-cli/v3/internal/measurexlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
const (
|
||||
testName = "tlsping"
|
||||
testVersion = "0.1.0"
|
||||
testVersion = "0.2.0"
|
||||
)
|
||||
|
||||
// Config contains the experiment configuration.
|
||||
|
@ -77,9 +77,9 @@ type TestKeys struct {
|
|||
|
||||
// SinglePing contains the results of a single ping.
|
||||
type SinglePing struct {
|
||||
NetworkEvents []*measurex.ArchivalNetworkEvent `json:"network_events"`
|
||||
TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"`
|
||||
TLSHandshakes []*measurex.ArchivalQUICTLSHandshakeEvent `json:"tls_handshakes"`
|
||||
NetworkEvents []*model.ArchivalNetworkEvent `json:"network_events"`
|
||||
TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"`
|
||||
TLSHandshake *model.ArchivalTLSOrQUICHandshakeResult `json:"tls_handshake"`
|
||||
}
|
||||
|
||||
// Measurer performs the measurement.
|
||||
|
@ -133,49 +133,67 @@ func (m *Measurer) Run(
|
|||
}
|
||||
tk := new(TestKeys)
|
||||
measurement.TestKeys = tk
|
||||
out := make(chan *measurex.EndpointMeasurement)
|
||||
mxmx := measurex.NewMeasurerWithDefaultSettings()
|
||||
go m.tlsPingLoop(ctx, mxmx, parsed.Host, out)
|
||||
out := make(chan *SinglePing)
|
||||
go m.tlsPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out)
|
||||
for len(tk.Pings) < int(m.config.repetitions()) {
|
||||
meas := <-out
|
||||
tk.Pings = append(tk.Pings, &SinglePing{
|
||||
NetworkEvents: measurex.NewArchivalNetworkEventList(meas.ReadWrite),
|
||||
TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect),
|
||||
TLSHandshakes: measurex.NewArchivalQUICTLSHandshakeEventList(meas.TLSHandshake),
|
||||
})
|
||||
tk.Pings = append(tk.Pings, <-out)
|
||||
}
|
||||
return nil // return nil so we always submit the measurement
|
||||
}
|
||||
|
||||
// tlsPingLoop sends all the ping requests and emits the results onto the out channel.
|
||||
func (m *Measurer) tlsPingLoop(ctx context.Context, mxmx *measurex.Measurer,
|
||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
||||
func (m *Measurer) tlsPingLoop(ctx context.Context, zeroTime time.Time,
|
||||
logger model.Logger, address string, out chan<- *SinglePing) {
|
||||
ticker := time.NewTicker(m.config.delay())
|
||||
defer ticker.Stop()
|
||||
for i := int64(0); i < m.config.repetitions(); i++ {
|
||||
go m.tlsPingAsync(ctx, mxmx, address, out)
|
||||
go m.tlsPingAsync(ctx, i, zeroTime, logger, address, out)
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
// tlsPingAsync performs a TLS ping and emits the result onto the out channel.
|
||||
func (m *Measurer) tlsPingAsync(ctx context.Context, mxmx *measurex.Measurer,
|
||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
||||
out <- m.tlsConnectAndHandshake(ctx, mxmx, address)
|
||||
func (m *Measurer) tlsPingAsync(ctx context.Context, index int64,
|
||||
zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) {
|
||||
out <- m.tlsConnectAndHandshake(ctx, index, zeroTime, logger, address)
|
||||
}
|
||||
|
||||
// tlsConnectAndHandshake performs a TCP connect followed by a TLS handshake
|
||||
// and returns the results of these operations to the caller.
|
||||
func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, mxmx *measurex.Measurer,
|
||||
address string) *measurex.EndpointMeasurement {
|
||||
func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, index int64,
|
||||
zeroTime time.Time, logger model.Logger, address string) *SinglePing {
|
||||
// TODO(bassosimone): make the timeout user-configurable
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
return mxmx.TLSConnectAndHandshake(ctx, address, &tls.Config{
|
||||
NextProtos: strings.Split(m.config.alpn(), " "),
|
||||
sp := &SinglePing{
|
||||
NetworkEvents: []*model.ArchivalNetworkEvent{},
|
||||
TCPConnect: nil,
|
||||
TLSHandshake: nil,
|
||||
}
|
||||
trace := measurexlite.NewTrace(index, zeroTime)
|
||||
dialer := trace.NewDialerWithoutResolver(logger)
|
||||
alpn := strings.Split(m.config.alpn(), " ")
|
||||
sni := m.config.sni(address)
|
||||
ol := measurexlite.NewOperationLogger(logger, "TLSPing #%d %s %s %v", index, address, sni, alpn)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
sp.TCPConnect = <-trace.TCPConnect
|
||||
if err != nil {
|
||||
ol.Stop(err)
|
||||
return sp
|
||||
}
|
||||
defer conn.Close()
|
||||
conn = trace.WrapNetConn(conn)
|
||||
thx := trace.NewTLSHandshakerStdlib(logger)
|
||||
config := &tls.Config{
|
||||
NextProtos: alpn,
|
||||
RootCAs: netxlite.NewDefaultCertPool(),
|
||||
ServerName: m.config.sni(address),
|
||||
})
|
||||
ServerName: sni,
|
||||
}
|
||||
_, _, err = thx.Handshake(ctx, conn, config)
|
||||
ol.Stop(err)
|
||||
sp.TLSHandshake = <-trace.TLSHandshake
|
||||
sp.NetworkEvents = trace.NetworkEvents()
|
||||
return sp
|
||||
}
|
||||
|
||||
// NewExperimentMeasurer creates a new ExperimentMeasurer.
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestMeasurer_run(t *testing.T) {
|
|||
const expectedPings = 4
|
||||
|
||||
// runHelper is an helper function to run this set of tests.
|
||||
runHelper := func(input string) (*model.Measurement, model.ExperimentMeasurer, error) {
|
||||
runHelper := func(ctx context.Context, input string) (*model.Measurement, model.ExperimentMeasurer, error) {
|
||||
m := NewExperimentMeasurer(Config{
|
||||
ALPN: "http/1.1",
|
||||
Delay: 1, // millisecond
|
||||
|
@ -48,10 +48,9 @@ func TestMeasurer_run(t *testing.T) {
|
|||
if m.ExperimentName() != "tlsping" {
|
||||
t.Fatal("invalid experiment name")
|
||||
}
|
||||
if m.ExperimentVersion() != "0.1.0" {
|
||||
if m.ExperimentVersion() != "0.2.0" {
|
||||
t.Fatal("invalid experiment version")
|
||||
}
|
||||
ctx := context.Background()
|
||||
meas := &model.Measurement{
|
||||
Input: model.MeasurementTarget(input),
|
||||
}
|
||||
|
@ -64,34 +63,34 @@ func TestMeasurer_run(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run("with empty input", func(t *testing.T) {
|
||||
_, _, err := runHelper("")
|
||||
_, _, err := runHelper(context.Background(), "")
|
||||
if !errors.Is(err, errNoInputProvided) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with invalid URL", func(t *testing.T) {
|
||||
_, _, err := runHelper("\t")
|
||||
_, _, err := runHelper(context.Background(), "\t")
|
||||
if !errors.Is(err, errInputIsNotAnURL) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with invalid scheme", func(t *testing.T) {
|
||||
_, _, err := runHelper("https://8.8.8.8:443/")
|
||||
_, _, err := runHelper(context.Background(), "https://8.8.8.8:443/")
|
||||
if !errors.Is(err, errInvalidScheme) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with missing port", func(t *testing.T) {
|
||||
_, _, err := runHelper("tlshandshake://8.8.8.8")
|
||||
_, _, err := runHelper(context.Background(), "tlshandshake://8.8.8.8")
|
||||
if !errors.Is(err, errMissingPort) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with local listener", func(t *testing.T) {
|
||||
t.Run("with local listener and successful outcome", func(t *testing.T) {
|
||||
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
|
@ -101,7 +100,37 @@ func TestMeasurer_run(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
URL.Scheme = "tlshandshake"
|
||||
meas, m, err := runHelper(URL.String())
|
||||
meas, m, err := runHelper(context.Background(), URL.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tk := meas.TestKeys.(*TestKeys)
|
||||
if len(tk.Pings) != expectedPings {
|
||||
t.Fatal("unexpected number of pings")
|
||||
}
|
||||
ask, err := m.GetSummaryKeys(meas)
|
||||
if err != nil {
|
||||
t.Fatal("cannot obtain summary")
|
||||
}
|
||||
summary := ask.(SummaryKeys)
|
||||
if summary.IsAnomaly {
|
||||
t.Fatal("expected no anomaly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with local listener and connect issues", func(t *testing.T) {
|
||||
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer srvr.Close()
|
||||
URL, err := url.Parse(srvr.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
URL.Scheme = "tlshandshake"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // so we cannot dial any connection
|
||||
meas, m, err := runHelper(ctx, URL.String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
// Package urlgetter implements a nettest that fetches a URL.
|
||||
//
|
||||
// See https://github.com/ooni/spec/blob/master/nettests/ts-027-urlgetter.md.
|
||||
//
|
||||
// This package is now frozen. Please, use measurexlite for new code. New
|
||||
// network experiments should not depend on this package. Please see
|
||||
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||
// for details about this.
|
||||
package urlgetter
|
||||
|
||||
import (
|
||||
|
|
|
@ -46,4 +46,8 @@
|
|||
//
|
||||
// See docs/design/dd-002-nets.md in the probe-cli repository for
|
||||
// the design document describing this package.
|
||||
//
|
||||
// This package is now frozen. Please, use measurexlite for new code. See
|
||||
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||
// for details about this.
|
||||
package netx
|
||||
|
|
|
@ -13,10 +13,10 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/fakefill"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
"github.com/ooni/probe-cli/v3/internal/version"
|
||||
)
|
||||
|
||||
|
@ -47,7 +47,7 @@ func TestAPIClientTemplate(t *testing.T) {
|
|||
HTTPClient: http.DefaultClient,
|
||||
Logger: model.DiscardLogger,
|
||||
}
|
||||
ff := &fakefill.Filler{}
|
||||
ff := &testingx.FakeFiller{}
|
||||
ff.Fill(tmpl)
|
||||
ac := tmpl.Build()
|
||||
orig := apiClient(*tmpl)
|
||||
|
@ -64,7 +64,7 @@ func TestAPIClientTemplate(t *testing.T) {
|
|||
HTTPClient: http.DefaultClient,
|
||||
Logger: model.DiscardLogger,
|
||||
}
|
||||
ff := &fakefill.Filler{}
|
||||
ff := &testingx.FakeFiller{}
|
||||
ff.Fill(tmpl)
|
||||
tok := ""
|
||||
ff.Fill(&tok)
|
||||
|
@ -188,7 +188,7 @@ func TestAPIClient(t *testing.T) {
|
|||
|
||||
t.Run("sets the content-type properly", func(t *testing.T) {
|
||||
var jsonReq fakeRequest
|
||||
ff := &fakefill.Filler{}
|
||||
ff := &testingx.FakeFiller{}
|
||||
ff.Fill(&jsonReq)
|
||||
client := newAPIClient()
|
||||
req, err := client.newRequestWithJSONBody(
|
||||
|
|
|
@ -1,2 +1,6 @@
|
|||
// Package measurex contains measurement extensions.
|
||||
//
|
||||
// This package is now frozen. Please, use measurexlite for new code. See
|
||||
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||
// for details about this.
|
||||
package measurex
|
||||
|
|
103
internal/measurexlite/conn.go
Normal file
103
internal/measurexlite/conn.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package measurexlite
|
||||
|
||||
//
|
||||
// Conn tracing
|
||||
//
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||
)
|
||||
|
||||
// MaybeClose is a convenience function for closing a conn only when such a conn isn't nil.
|
||||
func MaybeClose(conn net.Conn) (err error) {
|
||||
if conn != nil {
|
||||
err = conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// WrapNetConn returns a wrapped conn that saves network events into this trace.
|
||||
func (tx *Trace) WrapNetConn(conn net.Conn) net.Conn {
|
||||
return &connTrace{
|
||||
Conn: conn,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// connTrace is a trace-aware net.Conn.
|
||||
type connTrace struct {
|
||||
// Implementation note: it seems safe to use embedding here because net.Conn
|
||||
// is an interface from the standard library that we don't control
|
||||
net.Conn
|
||||
tx *Trace
|
||||
}
|
||||
|
||||
var _ net.Conn = &connTrace{}
|
||||
|
||||
// Read implements net.Conn.Read and saves network events.
|
||||
func (c *connTrace) Read(b []byte) (int, error) {
|
||||
network := c.RemoteAddr().Network()
|
||||
addr := c.RemoteAddr().String()
|
||||
started := c.tx.TimeSince(c.tx.ZeroTime)
|
||||
count, err := c.Conn.Read(b)
|
||||
finished := c.tx.TimeSince(c.tx.ZeroTime)
|
||||
select {
|
||||
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
|
||||
c.tx.Index, started, netxlite.ReadOperation, network, addr, count, err, finished):
|
||||
default: // buffer is full
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write and saves network events.
|
||||
func (c *connTrace) Write(b []byte) (int, error) {
|
||||
network := c.RemoteAddr().Network()
|
||||
addr := c.RemoteAddr().String()
|
||||
started := c.tx.TimeSince(c.tx.ZeroTime)
|
||||
count, err := c.Conn.Write(b)
|
||||
finished := c.tx.TimeSince(c.tx.ZeroTime)
|
||||
select {
|
||||
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
|
||||
c.tx.Index, started, netxlite.WriteOperation, network, addr, count, err, finished):
|
||||
default: // buffer is full
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
// NewArchivalNetworkEvent creates a new model.ArchivalNetworkEvent.
|
||||
func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, network string,
|
||||
address string, count int, err error, finished time.Duration) *model.ArchivalNetworkEvent {
|
||||
return &model.ArchivalNetworkEvent{
|
||||
Address: address,
|
||||
Failure: tracex.NewFailure(err),
|
||||
NumBytes: int64(count),
|
||||
Operation: operation,
|
||||
Proto: network,
|
||||
T: finished.Seconds(),
|
||||
Tags: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// NewAnnotationArchivalNetworkEvent is a simplified NewArchivalNetworkEvent
|
||||
// where we create a simple annotation without attached I/O info.
|
||||
func NewAnnotationArchivalNetworkEvent(
|
||||
index int64, time time.Duration, operation string) *model.ArchivalNetworkEvent {
|
||||
return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time)
|
||||
}
|
||||
|
||||
// NetworkEvents drains the network events buffered inside the NetworkEvent channel.
|
||||
func (tx *Trace) NetworkEvents() (out []*model.ArchivalNetworkEvent) {
|
||||
for {
|
||||
select {
|
||||
case ev := <-tx.NetworkEvent:
|
||||
out = append(out, ev)
|
||||
default:
|
||||
return // done
|
||||
}
|
||||
}
|
||||
}
|
243
internal/measurexlite/conn_test.go
Normal file
243
internal/measurexlite/conn_test.go
Normal file
|
@ -0,0 +1,243 @@
|
|||
package measurexlite
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestMaybeClose(t *testing.T) {
|
||||
t.Run("with nil conn", func(t *testing.T) {
|
||||
var conn net.Conn = nil
|
||||
MaybeClose(conn)
|
||||
})
|
||||
|
||||
t.Run("with nonnil conn", func(t *testing.T) {
|
||||
var called bool
|
||||
conn := &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
if err := MaybeClose(conn); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrapNetConn(t *testing.T) {
|
||||
t.Run("WrapNetConn wraps the conn", func(t *testing.T) {
|
||||
underlying := &mocks.Conn{}
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
conn := trace.WrapNetConn(underlying)
|
||||
ct := conn.(*connTrace)
|
||||
if ct.Conn != underlying {
|
||||
t.Fatal("invalid underlying")
|
||||
}
|
||||
if ct.tx != trace {
|
||||
t.Fatal("invalid trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Read saves a trace", func(t *testing.T) {
|
||||
underlying := &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now // deterministic time counting
|
||||
conn := trace.WrapNetConn(underlying)
|
||||
const bufsiz = 128
|
||||
buffer := make([]byte, bufsiz)
|
||||
count, err := conn.Read(buffer)
|
||||
if count != bufsiz {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("invalid err")
|
||||
}
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("did not save network events")
|
||||
}
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "1.1.1.1:443",
|
||||
Failure: nil,
|
||||
NumBytes: bufsiz,
|
||||
Operation: netxlite.ReadOperation,
|
||||
Proto: "tcp",
|
||||
T: 1.0,
|
||||
Tags: []string{},
|
||||
}
|
||||
got := events[0]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Read discards the event when the buffer is full", func(t *testing.T) {
|
||||
underlying := &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
|
||||
conn := trace.WrapNetConn(underlying)
|
||||
const bufsiz = 128
|
||||
buffer := make([]byte, bufsiz)
|
||||
count, err := conn.Read(buffer)
|
||||
if count != bufsiz {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("invalid err")
|
||||
}
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected no network events")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write saves a trace", func(t *testing.T) {
|
||||
underlying := &mocks.Conn{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now // deterministic time tracking
|
||||
conn := trace.WrapNetConn(underlying)
|
||||
const bufsiz = 128
|
||||
buffer := make([]byte, bufsiz)
|
||||
count, err := conn.Write(buffer)
|
||||
if count != bufsiz {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("invalid err")
|
||||
}
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("did not save network events")
|
||||
}
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "1.1.1.1:443",
|
||||
Failure: nil,
|
||||
NumBytes: bufsiz,
|
||||
Operation: netxlite.WriteOperation,
|
||||
Proto: "tcp",
|
||||
T: 1.0,
|
||||
Tags: []string{},
|
||||
}
|
||||
got := events[0]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write discards the event when the buffer is full", func(t *testing.T) {
|
||||
underlying := &mocks.Conn{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
|
||||
conn := trace.WrapNetConn(underlying)
|
||||
const bufsiz = 128
|
||||
buffer := make([]byte, bufsiz)
|
||||
count, err := conn.Write(buffer)
|
||||
if count != bufsiz {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("invalid err")
|
||||
}
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected no network events")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAnnotationArchivalNetworkEvent(t *testing.T) {
|
||||
var (
|
||||
index int64 = 3
|
||||
duration = 250 * time.Millisecond
|
||||
operation = "tls_handshake_start"
|
||||
)
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "",
|
||||
Failure: nil,
|
||||
NumBytes: 0,
|
||||
Operation: operation,
|
||||
Proto: "",
|
||||
T: duration.Seconds(),
|
||||
Tags: []string{},
|
||||
}
|
||||
got := NewAnnotationArchivalNetworkEvent(
|
||||
index, duration, operation,
|
||||
)
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
121
internal/measurexlite/dialer.go
Normal file
121
internal/measurexlite/dialer.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package measurexlite
|
||||
|
||||
//
|
||||
// Dialer tracing
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||
)
|
||||
|
||||
// NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver
|
||||
// except that it returns a model.Dialer that uses this trace.
|
||||
//
|
||||
// Note: unlike code in netx or measurex, this factory DOES NOT return you a
|
||||
// dialer that also performs wrapping of a net.Conn in case of success. If you
|
||||
// want to wrap the conn, you need to wrap it explicitly using WrapNetConn.
|
||||
func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
|
||||
return &dialerTrace{
|
||||
d: tx.newDialerWithoutResolver(dl),
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// dialerTrace is a trace-aware model.Dialer.
|
||||
type dialerTrace struct {
|
||||
d model.Dialer
|
||||
tx *Trace
|
||||
}
|
||||
|
||||
var _ model.Dialer = &dialerTrace{}
|
||||
|
||||
// DialContext implements model.Dialer.DialContext.
|
||||
func (d *dialerTrace) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.d.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address)
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements model.Dialer.CloseIdleConnections.
|
||||
func (d *dialerTrace) CloseIdleConnections() {
|
||||
d.d.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// OnTCPConnectDone implements model.Trace.OnTCPConnectDone.
|
||||
func (tx *Trace) OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "tcp6":
|
||||
select {
|
||||
case tx.TCPConnect <- NewArchivalTCPConnectResult(
|
||||
tx.Index,
|
||||
started.Sub(tx.ZeroTime),
|
||||
remoteAddr,
|
||||
err,
|
||||
finished.Sub(tx.ZeroTime),
|
||||
):
|
||||
default: // buffer is full
|
||||
}
|
||||
default:
|
||||
// ignore UDP connect attempts because they cannot fail
|
||||
// in interesting ways that make sense for censorship
|
||||
}
|
||||
}
|
||||
|
||||
// NewArchivalTCPConnectResult generates a model.ArchivalTCPConnectResult
|
||||
// from the available information right after connect returns.
|
||||
func NewArchivalTCPConnectResult(index int64, started time.Duration, address string,
|
||||
err error, finished time.Duration) *model.ArchivalTCPConnectResult {
|
||||
ip, port := archivalSplitHostPort(address)
|
||||
return &model.ArchivalTCPConnectResult{
|
||||
IP: ip,
|
||||
Port: archivalPortToString(port),
|
||||
Status: model.ArchivalTCPConnectStatus{
|
||||
Blocked: nil,
|
||||
Failure: tracex.NewFailure(err),
|
||||
Success: err == nil,
|
||||
},
|
||||
T: finished.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// archivalSplitHostPort is like net.SplitHostPort but does not return an error. This
|
||||
// function returns two empty strings in case of any failure.
|
||||
func archivalSplitHostPort(endpoint string) (string, string) {
|
||||
addr, port, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
log.Printf("BUG: archivalSplitHostPort: invalid endpoint: %s", endpoint)
|
||||
return "", ""
|
||||
}
|
||||
return addr, port
|
||||
}
|
||||
|
||||
// archivalPortToString is like strconv.Atoi but does not return an error. This
|
||||
// function returns a zero port number in case of any failure.
|
||||
func archivalPortToString(sport string) int {
|
||||
port, err := strconv.Atoi(sport)
|
||||
if err != nil || port < 0 || port > math.MaxUint16 {
|
||||
log.Printf("BUG: archivalStrconvAtoi: invalid port: %s", sport)
|
||||
return 0
|
||||
}
|
||||
return port
|
||||
}
|
||||
|
||||
// TCPConnects drains the network events buffered inside the TCPConnect channel.
|
||||
func (tx *Trace) TCPConnects() (out []*model.ArchivalTCPConnectResult) {
|
||||
for {
|
||||
select {
|
||||
case ev := <-tx.TCPConnect:
|
||||
out = append(out, ev)
|
||||
default:
|
||||
return // done
|
||||
}
|
||||
}
|
||||
}
|
211
internal/measurexlite/dialer_test.go
Normal file
211
internal/measurexlite/dialer_test.go
Normal file
|
@ -0,0 +1,211 @@
|
|||
package measurexlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestNewDialerWithoutResolver(t *testing.T) {
|
||||
t.Run("NewDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) {
|
||||
underlying := &mocks.Dialer{}
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
|
||||
return underlying
|
||||
}
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
dt := dialer.(*dialerTrace)
|
||||
if dt.d != underlying {
|
||||
t.Fatal("invalid dialer")
|
||||
}
|
||||
if dt.tx != trace {
|
||||
t.Fatal("invalid trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialContext calls the underlying dialer with context-based tracing", func(t *testing.T) {
|
||||
expectedErr := errors.New("mocked err")
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
var hasCorrectTrace bool
|
||||
underlying := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
gotTrace := netxlite.ContextTraceOrDefault(ctx)
|
||||
hasCorrectTrace = (gotTrace == trace)
|
||||
return nil, expectedErr
|
||||
},
|
||||
}
|
||||
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
|
||||
return underlying
|
||||
}
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx := context.Background()
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if !hasCorrectTrace {
|
||||
t.Fatal("does not have the correct trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnection is correctly forwarded", func(t *testing.T) {
|
||||
var called bool
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
underlying := &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
|
||||
return underlying
|
||||
}
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
dialer.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialContext saves into the trace", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now // deterministic time tracking
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we cancel immediately so connect is ~instantaneous
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
events := trace.TCPConnects()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("expected to see single TCPConnect event")
|
||||
}
|
||||
expectedFailure := netxlite.FailureInterrupted
|
||||
expect := &model.ArchivalTCPConnectResult{
|
||||
IP: "1.1.1.1",
|
||||
Port: 443,
|
||||
Status: model.ArchivalTCPConnectStatus{
|
||||
Blocked: nil,
|
||||
Failure: &expectedFailure,
|
||||
Success: false,
|
||||
},
|
||||
T: time.Second.Seconds(),
|
||||
}
|
||||
got := events[0]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialContext discards events when buffer is full", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TCPConnect = make(chan *model.ArchivalTCPConnectResult) // no buffer
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we cancel immediately so connect is ~instantaneous
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
events := trace.TCPConnects()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no TCPConnect events")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialContext ignores UDP connect attempts", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we cancel immediately so connect is ~instantaneous
|
||||
conn, err := dialer.DialContext(ctx, "udp", "1.1.1.1:443")
|
||||
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
events := trace.TCPConnects()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no TCPConnect events")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DialContext uses a dialer without a resolver", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we cancel immediately so connect is ~instantaneous
|
||||
conn, err := dialer.DialContext(ctx, "udp", "dns.google:443") // domain
|
||||
if !errors.Is(err, netxlite.ErrNoResolver) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
events := trace.TCPConnects()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no TCPConnect events")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestArchivalSplitHostPort(t *testing.T) {
|
||||
addr, port := archivalSplitHostPort("1.1.1.1") // missing port
|
||||
if addr != "" {
|
||||
t.Fatal("invalid addr", addr)
|
||||
}
|
||||
if port != "" {
|
||||
t.Fatal("invalid port", port)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArchivalPortToString(t *testing.T) {
|
||||
t.Run("with invalid number", func(t *testing.T) {
|
||||
port := archivalPortToString("antani")
|
||||
if port != 0 {
|
||||
t.Fatal("invalid port")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with negative number", func(t *testing.T) {
|
||||
port := archivalPortToString("-1")
|
||||
if port != 0 {
|
||||
t.Fatal("invalid port")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with too-large positive number", func(t *testing.T) {
|
||||
port := archivalPortToString(strconv.Itoa(math.MaxUint16 + 1))
|
||||
if port != 0 {
|
||||
t.Fatal("invalid port")
|
||||
}
|
||||
})
|
||||
}
|
10
internal/measurexlite/doc.go
Normal file
10
internal/measurexlite/doc.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
// Package measurexlite contains measurement extensions.
|
||||
//
|
||||
// See docs/design/dd-003-step-by-step.md in the ooni/probe-cli
|
||||
// repository for the design document.
|
||||
//
|
||||
// This implementation features a Trace that saves events in
|
||||
// buffered channels as proposed by df-003-step-by-step.md. We
|
||||
// have reasonable default buffers for channels. But, if you
|
||||
// are not draining them, eventually we stop collecting events.
|
||||
package measurexlite
|
16
internal/measurexlite/logger.go
Normal file
16
internal/measurexlite/logger.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package measurexlite
|
||||
|
||||
//
|
||||
// Logging support
|
||||
//
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/measurex"
|
||||
|
||||
// TODO(bassosimone): we should eventually remove measurex and
|
||||
// move the logging code from measurex to this package.
|
||||
|
||||
// NewOperationLogger is an alias for measurex.NewOperationLogger.
|
||||
var NewOperationLogger = measurex.NewOperationLogger
|
||||
|
||||
// OperationLogger is an alias for measurex.OperationLogger.
|
||||
type OperationLogger = measurex.OperationLogger
|
144
internal/measurexlite/tls.go
Normal file
144
internal/measurexlite/tls.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package measurexlite
|
||||
|
||||
//
|
||||
// TLS tracing
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||
)
|
||||
|
||||
// NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib
|
||||
// except that it returns a model.TLSHandshaker that uses this trace.
|
||||
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||
return &tlsHandshakerTrace{
|
||||
thx: tx.newTLSHandshakerStdlib(dl),
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// tlsHandshakerTrace is a trace-aware TLS handshaker.
|
||||
type tlsHandshakerTrace struct {
|
||||
thx model.TLSHandshaker
|
||||
tx *Trace
|
||||
}
|
||||
|
||||
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
|
||||
|
||||
// Handshake implements model.TLSHandshaker.Handshake.
|
||||
func (thx *tlsHandshakerTrace) Handshake(
|
||||
ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return thx.thx.Handshake(netxlite.ContextWithTrace(ctx, thx.tx), conn, tlsConfig)
|
||||
}
|
||||
|
||||
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
|
||||
func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
t := now.Sub(tx.ZeroTime)
|
||||
select {
|
||||
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_start"):
|
||||
default: // buffer is full
|
||||
}
|
||||
}
|
||||
|
||||
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
|
||||
func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Time) {
|
||||
t := finished.Sub(tx.ZeroTime)
|
||||
select {
|
||||
case tx.TLSHandshake <- NewArchivalTLSOrQUICHandshakeResult(
|
||||
tx.Index,
|
||||
started.Sub(tx.ZeroTime),
|
||||
remoteAddr,
|
||||
config,
|
||||
state,
|
||||
err,
|
||||
t,
|
||||
):
|
||||
default: // buffer is full
|
||||
}
|
||||
select {
|
||||
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_done"):
|
||||
default: // buffer is full
|
||||
}
|
||||
}
|
||||
|
||||
// NewArchivalTLSOrQUICHandshakeResult generates a model.ArchivalTLSOrQUICHandshakeResult
|
||||
// from the available information right after the TLS handshake returns.
|
||||
func NewArchivalTLSOrQUICHandshakeResult(
|
||||
index int64, started time.Duration, address string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Duration) *model.ArchivalTLSOrQUICHandshakeResult {
|
||||
return &model.ArchivalTLSOrQUICHandshakeResult{
|
||||
Address: address,
|
||||
CipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
|
||||
Failure: tracex.NewFailure(err),
|
||||
NegotiatedProtocol: state.NegotiatedProtocol,
|
||||
NoTLSVerify: config.InsecureSkipVerify,
|
||||
PeerCertificates: TLSPeerCerts(state, err),
|
||||
ServerName: config.ServerName,
|
||||
T: finished.Seconds(),
|
||||
Tags: []string{},
|
||||
TLSVersion: netxlite.TLSVersionString(state.Version),
|
||||
}
|
||||
}
|
||||
|
||||
// newArchivalBinaryData is a factory that adapts binary data to the
|
||||
// model.ArchivalMaybeBinaryData format.
|
||||
func newArchivalBinaryData(data []byte) model.ArchivalMaybeBinaryData {
|
||||
// TODO(https://github.com/ooni/probe/issues/2165): we should actually extend the
|
||||
// model's archival data format to have a pure-binary-data type for the cases in which
|
||||
// we know in advance we're dealing with binary data.
|
||||
return model.ArchivalMaybeBinaryData{
|
||||
Value: string(data),
|
||||
}
|
||||
}
|
||||
|
||||
// TLSPeerCerts extracts the certificates either from the list of certificates
|
||||
// in the connection state or from the error that occurred.
|
||||
func TLSPeerCerts(
|
||||
state tls.ConnectionState, err error) (out []model.ArchivalMaybeBinaryData) {
|
||||
out = []model.ArchivalMaybeBinaryData{}
|
||||
var x509HostnameError x509.HostnameError
|
||||
if errors.As(err, &x509HostnameError) {
|
||||
// Test case: https://wrong.host.badssl.com/
|
||||
out = append(out, newArchivalBinaryData(x509HostnameError.Certificate.Raw))
|
||||
return
|
||||
}
|
||||
var x509UnknownAuthorityError x509.UnknownAuthorityError
|
||||
if errors.As(err, &x509UnknownAuthorityError) {
|
||||
// Test case: https://self-signed.badssl.com/. This error has
|
||||
// never been among the ones returned by MK.
|
||||
out = append(out, newArchivalBinaryData(x509UnknownAuthorityError.Cert.Raw))
|
||||
return
|
||||
}
|
||||
var x509CertificateInvalidError x509.CertificateInvalidError
|
||||
if errors.As(err, &x509CertificateInvalidError) {
|
||||
// Test case: https://expired.badssl.com/
|
||||
out = append(out, newArchivalBinaryData(x509CertificateInvalidError.Cert.Raw))
|
||||
return
|
||||
}
|
||||
for _, cert := range state.PeerCertificates {
|
||||
out = append(out, newArchivalBinaryData(cert.Raw))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TLSHandshakes drains the network events buffered inside the TLSHandshake channel.
|
||||
func (tx *Trace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
|
||||
for {
|
||||
select {
|
||||
case ev := <-tx.TLSHandshake:
|
||||
out = append(out, ev)
|
||||
default:
|
||||
return // done
|
||||
}
|
||||
}
|
||||
}
|
418
internal/measurexlite/tls_test.go
Normal file
418
internal/measurexlite/tls_test.go
Normal file
|
@ -0,0 +1,418 @@
|
|||
package measurexlite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||
t.Run("NewTLSHandshakerStdlib creates a wrapped TLSHandshaker", func(t *testing.T) {
|
||||
underlying := &mocks.TLSHandshaker{}
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
|
||||
return underlying
|
||||
}
|
||||
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
thxt := thx.(*tlsHandshakerTrace)
|
||||
if thxt.thx != underlying {
|
||||
t.Fatal("invalid TLS handshaker")
|
||||
}
|
||||
if thxt.tx != trace {
|
||||
t.Fatal("invalid trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handshake calls the underlying dialer with context-based tracing", func(t *testing.T) {
|
||||
expectedErr := errors.New("mocked err")
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
var hasCorrectTrace bool
|
||||
underlying := &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
gotTrace := netxlite.ContextTraceOrDefault(ctx)
|
||||
hasCorrectTrace = (gotTrace == trace)
|
||||
return nil, tls.ConnectionState{}, expectedErr
|
||||
},
|
||||
}
|
||||
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
|
||||
return underlying
|
||||
}
|
||||
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
ctx := context.Background()
|
||||
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected zero-value state")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if !hasCorrectTrace {
|
||||
t.Fatal("does not have the correct trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Handshake saves into the trace", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now // deterministic timing
|
||||
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we cancel immediately so connect is ~instantaneous
|
||||
tcpConn := &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mockedErr
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "dns.cloudflare.com",
|
||||
}
|
||||
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected zero-value state")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
|
||||
t.Run("TLSHandshake events", func(t *testing.T) {
|
||||
events := trace.TLSHandshakes()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("expected to see single TLSHandshake event")
|
||||
}
|
||||
expectedFailure := netxlite.FailureInterrupted
|
||||
expect := &model.ArchivalTLSOrQUICHandshakeResult{
|
||||
Address: "1.1.1.1:443",
|
||||
CipherSuite: "",
|
||||
Failure: &expectedFailure,
|
||||
NegotiatedProtocol: "",
|
||||
NoTLSVerify: true,
|
||||
PeerCertificates: []model.ArchivalMaybeBinaryData{},
|
||||
ServerName: "dns.cloudflare.com",
|
||||
T: time.Second.Seconds(),
|
||||
Tags: []string{},
|
||||
TLSVersion: "",
|
||||
}
|
||||
got := events[0]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network events", func(t *testing.T) {
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("expected to see two Network events")
|
||||
}
|
||||
|
||||
t.Run("tls_handshake_start", func(t *testing.T) {
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "",
|
||||
Failure: nil,
|
||||
NumBytes: 0,
|
||||
Operation: "tls_handshake_start",
|
||||
Proto: "",
|
||||
T: 0,
|
||||
Tags: []string{},
|
||||
}
|
||||
got := events[0]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tls_handshake_done", func(t *testing.T) {
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "",
|
||||
Failure: nil,
|
||||
NumBytes: 0,
|
||||
Operation: "tls_handshake_done",
|
||||
Proto: "",
|
||||
T: time.Second.Seconds(),
|
||||
Tags: []string{},
|
||||
}
|
||||
got := events[1]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Handshake discards events when buffers are full", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
|
||||
trace.TLSHandshake = make(chan *model.ArchivalTLSOrQUICHandshakeResult) // no buffer
|
||||
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we cancel immediately so connect is ~instantaneous
|
||||
tcpConn := &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mockedErr
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "dns.cloudflare.com",
|
||||
}
|
||||
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected zero-value state")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
|
||||
t.Run("TLSHandshake events", func(t *testing.T) {
|
||||
events := trace.TLSHandshakes()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no TLSHandshake events")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network events", func(t *testing.T) {
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no Network events")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("we collect the desired data with a local TLS server", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||
dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx := context.Background()
|
||||
conn, err := dialer.DialContext(ctx, "tcp", server.Endpoint())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
zeroTime := time.Now()
|
||||
dt := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = dt.Now // deterministic timing
|
||||
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: server.CertPool(),
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tlsConn.Close()
|
||||
data, err := netxlite.ReadAllContext(ctx, tlsConn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
|
||||
t.Fatal("bytes should match")
|
||||
}
|
||||
|
||||
t.Run("TLSHandshake events", func(t *testing.T) {
|
||||
events := trace.TLSHandshakes()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("expected to see a single TLSHandshake event")
|
||||
}
|
||||
expected := &model.ArchivalTLSOrQUICHandshakeResult{
|
||||
Address: conn.RemoteAddr().String(),
|
||||
CipherSuite: netxlite.TLSCipherSuiteString(connState.CipherSuite),
|
||||
Failure: nil,
|
||||
NegotiatedProtocol: "",
|
||||
NoTLSVerify: false,
|
||||
PeerCertificates: []model.ArchivalMaybeBinaryData{},
|
||||
ServerName: "dns.google",
|
||||
T: time.Second.Seconds(),
|
||||
Tags: []string{},
|
||||
TLSVersion: netxlite.TLSVersionString(connState.Version),
|
||||
}
|
||||
got := events[0]
|
||||
// TODO(bassosimone): it's still unclear to me how to test that
|
||||
// I am getting exactly the expected certificate here. I think the
|
||||
// certificate is generated on the fly by google/martian. So, I'm
|
||||
// just going to reduce the precision of this check.
|
||||
if len(got.PeerCertificates) != 2 {
|
||||
t.Fatal("expected to see two certificates")
|
||||
}
|
||||
got.PeerCertificates = []model.ArchivalMaybeBinaryData{} // see above
|
||||
if diff := cmp.Diff(expected, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network events", func(t *testing.T) {
|
||||
events := trace.NetworkEvents()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("expected to see two Network events")
|
||||
}
|
||||
|
||||
t.Run("tls_handshake_start", func(t *testing.T) {
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "",
|
||||
Failure: nil,
|
||||
NumBytes: 0,
|
||||
Operation: "tls_handshake_start",
|
||||
Proto: "",
|
||||
T: 0,
|
||||
Tags: []string{},
|
||||
}
|
||||
got := events[0]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tls_handshake_done", func(t *testing.T) {
|
||||
expect := &model.ArchivalNetworkEvent{
|
||||
Address: "",
|
||||
Failure: nil,
|
||||
NumBytes: 0,
|
||||
Operation: "tls_handshake_done",
|
||||
Proto: "",
|
||||
T: time.Second.Seconds(),
|
||||
Tags: []string{},
|
||||
}
|
||||
got := events[1]
|
||||
if diff := cmp.Diff(expect, got); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestTLSPeerCerts(t *testing.T) {
|
||||
type args struct {
|
||||
state tls.ConnectionState
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantOut []model.ArchivalMaybeBinaryData
|
||||
}{{
|
||||
name: "x509.HostnameError",
|
||||
args: args{
|
||||
state: tls.ConnectionState{},
|
||||
err: x509.HostnameError{
|
||||
Certificate: &x509.Certificate{
|
||||
Raw: []byte("deadbeef"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||
Value: "deadbeef",
|
||||
}},
|
||||
}, {
|
||||
name: "x509.UnknownAuthorityError",
|
||||
args: args{
|
||||
state: tls.ConnectionState{},
|
||||
err: x509.UnknownAuthorityError{
|
||||
Cert: &x509.Certificate{
|
||||
Raw: []byte("deadbeef"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||
Value: "deadbeef",
|
||||
}},
|
||||
}, {
|
||||
name: "x509.CertificateInvalidError",
|
||||
args: args{
|
||||
state: tls.ConnectionState{},
|
||||
err: x509.CertificateInvalidError{
|
||||
Cert: &x509.Certificate{
|
||||
Raw: []byte("deadbeef"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||
Value: "deadbeef",
|
||||
}},
|
||||
}, {
|
||||
name: "successful case",
|
||||
args: args{
|
||||
state: tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{{
|
||||
Raw: []byte("deadbeef"),
|
||||
}, {
|
||||
Raw: []byte("abad1dea"),
|
||||
}},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||
Value: "deadbeef",
|
||||
}, {
|
||||
Value: "abad1dea",
|
||||
}},
|
||||
}}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotOut := TLSPeerCerts(tt.args.state, tt.args.err)
|
||||
if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
143
internal/measurexlite/trace.go
Normal file
143
internal/measurexlite/trace.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package measurexlite
|
||||
|
||||
//
|
||||
// Definition of Trace
|
||||
//
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
// Trace implements model.Trace.
|
||||
//
|
||||
// The zero-value of this struct is invalid. To construct you should either
|
||||
// fill all the fields marked as MANDATORY or use NewTrace.
|
||||
//
|
||||
// Buffered channels
|
||||
//
|
||||
// NewTrace uses reasonable buffer sizes for the channels used for collecting
|
||||
// events. You should drain the channels used by this implementation after
|
||||
// each operation you perform (i.e., we expect you to peform step-by-step
|
||||
// measurements). If you want larger (or smaller) buffers, then you should
|
||||
// construct this data type manually with the desired buffer sizes.
|
||||
//
|
||||
// We have convenience methods for extracting events from the buffered
|
||||
// channels. Otherwise, you could read the channels directly. (In which
|
||||
// case, remember to issue nonblocking channel reads because channels are
|
||||
// never closed and they're just written when new events occur.)
|
||||
type Trace struct {
|
||||
// Index is the MANDATORY unique index of this trace within the
|
||||
// current measurement. If you don't care about uniquely identifying
|
||||
// treaces, you can use zero to indicate the "default" trace.
|
||||
Index int64
|
||||
|
||||
// NetworkEvent is MANDATORY and buffers network events. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
NetworkEvent chan *model.ArchivalNetworkEvent
|
||||
|
||||
// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
|
||||
// calls to the netxlite.NewDialerWithoutResolver factory.
|
||||
NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer
|
||||
|
||||
// NewTLSHandshakerStdlibFn is OPTIONAL and can be used to overide
|
||||
// calls to the netxlite.NewTLSHandshakerStdlib factory.
|
||||
NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
|
||||
|
||||
// TCPConnect is MANDATORY and buffers TCP connect observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
TCPConnect chan *model.ArchivalTCPConnectResult
|
||||
|
||||
// TLSHandshake is MANDATORY and buffers TLS handshake observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
TLSHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
||||
|
||||
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
|
||||
// to produce deterministic timing when testing.
|
||||
TimeNowFn func() time.Time
|
||||
|
||||
// ZeroTime is the MANDATORY time when we started the current measurement.
|
||||
ZeroTime time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
// NetworkEventBufferSize is the buffer size for constructing
|
||||
// the Trace's NetworkEvent buffered channel.
|
||||
NetworkEventBufferSize = 64
|
||||
|
||||
// TCPConnectBufferSize is the buffer size for constructing
|
||||
// the Trace's TCPConnect buffered channel.
|
||||
TCPConnectBufferSize = 8
|
||||
|
||||
// TLSHandshakeBufferSize is the buffer for construcing
|
||||
// the Trace's TLSHandshake buffered channel.
|
||||
TLSHandshakeBufferSize = 8
|
||||
)
|
||||
|
||||
// NewTrace creates a new instance of Trace using default settings.
|
||||
//
|
||||
// We create buffered channels using as buffer sizes the constants that
|
||||
// are also defined by this package.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - index is the unique index of this trace within the current measurement (use
|
||||
// zero if you don't care about giving this trace a unique ID);
|
||||
//
|
||||
// - zeroTime is the time when we started the current measurement.
|
||||
func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||
return &Trace{
|
||||
Index: index,
|
||||
NetworkEvent: make(
|
||||
chan *model.ArchivalNetworkEvent,
|
||||
NetworkEventBufferSize,
|
||||
),
|
||||
NewDialerWithoutResolverFn: nil, // use default
|
||||
NewTLSHandshakerStdlibFn: nil, // use default
|
||||
TCPConnect: make(
|
||||
chan *model.ArchivalTCPConnectResult,
|
||||
TCPConnectBufferSize,
|
||||
),
|
||||
TLSHandshake: make(
|
||||
chan *model.ArchivalTLSOrQUICHandshakeResult,
|
||||
TLSHandshakeBufferSize,
|
||||
),
|
||||
TimeNowFn: nil, // use default
|
||||
ZeroTime: zeroTime,
|
||||
}
|
||||
}
|
||||
|
||||
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
|
||||
// thus allows us to mock this func for testing.
|
||||
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
|
||||
if tx.NewDialerWithoutResolverFn != nil {
|
||||
return tx.NewDialerWithoutResolverFn(dl)
|
||||
}
|
||||
return netxlite.NewDialerWithoutResolver(dl)
|
||||
}
|
||||
|
||||
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
|
||||
// thus allowing us to mock this func for testing.
|
||||
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||
if tx.NewTLSHandshakerStdlibFn != nil {
|
||||
return tx.NewTLSHandshakerStdlibFn(dl)
|
||||
}
|
||||
return netxlite.NewTLSHandshakerStdlib(dl)
|
||||
}
|
||||
|
||||
// TimeNow implements model.Trace.TimeNow.
|
||||
func (tx *Trace) TimeNow() time.Time {
|
||||
if tx.TimeNowFn != nil {
|
||||
return tx.TimeNowFn()
|
||||
}
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// TimeSince is equivalent to Trace.TimeNow().Sub(t0).
|
||||
func (tx *Trace) TimeSince(t0 time.Time) time.Duration {
|
||||
return tx.TimeNow().Sub(t0)
|
||||
}
|
||||
|
||||
var _ model.Trace = &Trace{}
|
248
internal/measurexlite/trace_test.go
Normal file
248
internal/measurexlite/trace_test.go
Normal file
|
@ -0,0 +1,248 @@
|
|||
package measurexlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestNewTrace(t *testing.T) {
|
||||
t.Run("NewTrace correctly constructs a trace", func(t *testing.T) {
|
||||
const index = 17
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(index, zeroTime)
|
||||
|
||||
t.Run("Index", func(t *testing.T) {
|
||||
if trace.Index != index {
|
||||
t.Fatal("invalid index")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NetworkEvent has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
for {
|
||||
ev := &model.ArchivalNetworkEvent{}
|
||||
ff.Fill(ev)
|
||||
select {
|
||||
case trace.NetworkEvent <- ev:
|
||||
idx++
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
if idx != NetworkEventBufferSize {
|
||||
t.Fatal("invalid NetworkEvent channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) {
|
||||
if trace.NewDialerWithoutResolverFn != nil {
|
||||
t.Fatal("expected nil NewDialerWithoutResolverFn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) {
|
||||
if trace.NewTLSHandshakerStdlibFn != nil {
|
||||
t.Fatal("expected nil NewTLSHandshakerStdlibFn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
for {
|
||||
ev := &model.ArchivalTCPConnectResult{}
|
||||
ff.Fill(ev)
|
||||
select {
|
||||
case trace.TCPConnect <- ev:
|
||||
idx++
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
if idx != TCPConnectBufferSize {
|
||||
t.Fatal("invalid TCPConnect channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
for {
|
||||
ev := &model.ArchivalTLSOrQUICHandshakeResult{}
|
||||
ff.Fill(ev)
|
||||
select {
|
||||
case trace.TLSHandshake <- ev:
|
||||
idx++
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
if idx != TLSHandshakeBufferSize {
|
||||
t.Fatal("invalid TLSHandshake channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TimeNowFn is nil", func(t *testing.T) {
|
||||
if trace.TimeNowFn != nil {
|
||||
t.Fatal("expected nil TimeNowFn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ZeroTime", func(t *testing.T) {
|
||||
if !trace.ZeroTime.Equal(zeroTime) {
|
||||
t.Fatal("invalid zero time")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestTrace(t *testing.T) {
|
||||
t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) {
|
||||
t.Run("when not nil", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
tx := &Trace{
|
||||
NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer {
|
||||
return &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, mockedErr
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx := context.Background()
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||
if !errors.Is(err, mockedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when nil", func(t *testing.T) {
|
||||
tx := &Trace{
|
||||
NewDialerWithoutResolverFn: nil,
|
||||
}
|
||||
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // fail immediately
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) {
|
||||
t.Run("when not nil", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
tx := &Trace{
|
||||
NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker {
|
||||
return &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return nil, tls.ConnectionState{}, mockedErr
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
ctx := context.Background()
|
||||
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if !errors.Is(err, mockedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("state is not a zero value")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when nil", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
tx := &Trace{
|
||||
NewTLSHandshakerStdlibFn: nil,
|
||||
}
|
||||
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
tcpConn := &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
}
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mockedErr
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if !errors.Is(err, mockedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("state is not a zero value")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("TimeNowFn works as intended", func(t *testing.T) {
|
||||
fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
|
||||
tx := &Trace{
|
||||
TimeNowFn: func() time.Time {
|
||||
return fixedTime
|
||||
},
|
||||
}
|
||||
if !tx.TimeNow().Equal(fixedTime) {
|
||||
t.Fatal("we cannot override time.Now calls")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TimeSince works as intended", func(t *testing.T) {
|
||||
t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
|
||||
t1 := t0.Add(10 * time.Second)
|
||||
tx := &Trace{
|
||||
TimeNowFn: func() time.Time {
|
||||
return t1
|
||||
},
|
||||
}
|
||||
if tx.TimeSince(t0) != 10*time.Second {
|
||||
t.Fatal("apparently Trace.Since is broken")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/fakefill"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestArchivalExtSpec(t *testing.T) {
|
||||
|
@ -301,7 +301,7 @@ func TestHTTPBody(t *testing.T) {
|
|||
// we make a mistake and apply the above change (which will in turn
|
||||
// break correct JSON serialization), the this test will fail.
|
||||
var body ArchivalHTTPBody
|
||||
ff := &fakefill.Filler{}
|
||||
ff := &testingx.FakeFiller{}
|
||||
ff.Fill(&body)
|
||||
data := ArchivalMaybeBinaryData(body)
|
||||
if diff := cmp.Diff(body, data); diff != "" {
|
||||
|
|
45
internal/model/mocks/trace.go
Normal file
45
internal/model/mocks/trace.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package mocks
|
||||
|
||||
//
|
||||
// Mocks for model.Trace
|
||||
//
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// Trace allows mocking model.Trace.
|
||||
type Trace struct {
|
||||
MockTimeNow func() time.Time
|
||||
|
||||
MockOnConnectDone func(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
||||
|
||||
MockOnTLSHandshakeStart func(now time.Time, remoteAddr string, config *tls.Config)
|
||||
|
||||
MockOnTLSHandshakeDone func(started time.Time, remoteAddr string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Time)
|
||||
}
|
||||
|
||||
var _ model.Trace = &Trace{}
|
||||
|
||||
func (t *Trace) TimeNow() time.Time {
|
||||
return t.MockTimeNow()
|
||||
}
|
||||
|
||||
func (t *Trace) OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)
|
||||
}
|
||||
|
||||
func (t *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
t.MockOnTLSHandshakeStart(now, remoteAddr, config)
|
||||
}
|
||||
|
||||
func (t *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Time) {
|
||||
t.MockOnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
|
||||
}
|
74
internal/model/mocks/trace_test.go
Normal file
74
internal/model/mocks/trace_test.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTrace(t *testing.T) {
|
||||
t.Run("TimeNow", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
tx := &Trace{
|
||||
MockTimeNow: func() time.Time {
|
||||
return now
|
||||
},
|
||||
}
|
||||
if !tx.TimeNow().Equal(now) {
|
||||
t.Fatal("not working as intended")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnConnectDone", func(t *testing.T) {
|
||||
var called bool
|
||||
tx := &Trace{
|
||||
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
tx.OnConnectDone(
|
||||
time.Now(),
|
||||
"tcp",
|
||||
"dns.google",
|
||||
"8.8.8.8:443",
|
||||
nil,
|
||||
time.Now(),
|
||||
)
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnTLSHandshakeStart", func(t *testing.T) {
|
||||
var called bool
|
||||
tx := &Trace{
|
||||
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
tx.OnTLSHandshakeStart(time.Now(), "8.8.8.8:443", &tls.Config{})
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnTLSHandshakeDone", func(t *testing.T) {
|
||||
var called bool
|
||||
tx := &Trace{
|
||||
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
tx.OnTLSHandshakeDone(
|
||||
time.Now(),
|
||||
"8.8.8.8:443",
|
||||
&tls.Config{},
|
||||
tls.ConnectionState{},
|
||||
nil,
|
||||
time.Now(),
|
||||
)
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -292,6 +292,76 @@ type TLSHandshaker interface {
|
|||
net.Conn, tls.ConnectionState, error)
|
||||
}
|
||||
|
||||
// Trace allows to collect measurement traces. A trace is injected into
|
||||
// netx operations using context.WithValue. Netx code retrieves the trace
|
||||
// using context.Value. See docs/design/dd-003-step-by-step.md for the
|
||||
// design document explaining why we implemented context-based tracing.
|
||||
type Trace interface {
|
||||
// TimeNow returns the current time. Normally, this should be the same
|
||||
// value returned by time.Now but you may want to manipulate the time
|
||||
// returned when testing to have deterministic tests. To this end, you
|
||||
// can use functionality exported by the ./internal/testingx pkg.
|
||||
TimeNow() time.Time
|
||||
|
||||
// OnConnectDone is called when connect terminates.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - started is when we called connect;
|
||||
//
|
||||
// - network is the network we're using (one of "tcp" and "udp");
|
||||
//
|
||||
// - domain is the domain for which we're calling connect. If the user called
|
||||
// connect for an IP address and a port, then domain will be an IP address;
|
||||
//
|
||||
// - remoteAddr is the TCP endpoint with which we are connecting: it will
|
||||
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
|
||||
//
|
||||
// - err is the result of connect: either an error or nil;
|
||||
//
|
||||
// - finished is when connect returned.
|
||||
//
|
||||
// The error passed to this function will always be wrapped such that the
|
||||
// string returned by Error is an OONI error.
|
||||
OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
||||
|
||||
// OnTLSHandshakeStart is called when the TLS handshake starts.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - now is the moment before we start the handshake;
|
||||
//
|
||||
// - remoteAddr is the TCP endpoint with which we are connecting: it will
|
||||
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
|
||||
//
|
||||
// - config is the non-nil TLS config we're using.
|
||||
OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config)
|
||||
|
||||
// OnTLSHandshakeDone is called when the TLS handshake terminates.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - started is when we started the handshake;
|
||||
//
|
||||
// - remoteAddr is the TCP endpoint with which we are connecting: it will
|
||||
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
|
||||
//
|
||||
// - config is the non-nil TLS config we're using;
|
||||
//
|
||||
// - state is the state of the TLS connection after the handshake, where all
|
||||
// fields are zero-initialized if the handshake failed;
|
||||
//
|
||||
// - err is the result of the handshake: either an error or nil;
|
||||
//
|
||||
// - finished is right after the handshake.
|
||||
//
|
||||
// The error passed to this function will always be wrapped such that the
|
||||
// string returned by Error is an OONI error.
|
||||
OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Time)
|
||||
}
|
||||
|
||||
// UDPLikeConn is a net.PacketConn with some extra functions
|
||||
// required to convince the QUIC library (lucas-clemente/quic-go)
|
||||
// to inflate the receive buffer of the connection.
|
||||
|
|
|
@ -47,7 +47,7 @@ func (r *bogonResolver) LookupHost(ctx context.Context, hostname string) ([]stri
|
|||
for _, addr := range addrs {
|
||||
if IsBogon(addr) {
|
||||
// wrap ErrDNSBogon as documented
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, ErrDNSBogon)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, ErrDNSBogon)
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
|
|
|
@ -15,8 +15,8 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/scrubber"
|
||||
)
|
||||
|
||||
// classifyGenericError is maps an error occurred during an operation
|
||||
// to an OONI failure string. This specific classifier is the most
|
||||
// ClassifyGenericError maps an error occurred during an operation to
|
||||
// an OONI failure string. This specific classifier is the most
|
||||
// generic one. You usually use it when mapping I/O errors. You should
|
||||
// check whether there is a specific classifier for more specific
|
||||
// operations (e.g., DNS resolution, TLS handshake).
|
||||
|
@ -38,7 +38,7 @@ import (
|
|||
// If everything else fails, this classifier returns a string
|
||||
// like "unknown_failure: XXX" where XXX has been scrubbed
|
||||
// so to remove any network endpoints from the original error string.
|
||||
func classifyGenericError(err error) string {
|
||||
func ClassifyGenericError(err error) string {
|
||||
// The list returned here matches the values used by MK unless
|
||||
// explicitly noted otherwise with a comment.
|
||||
|
||||
|
@ -139,7 +139,7 @@ const (
|
|||
quicTLSUnrecognizedName = 112
|
||||
)
|
||||
|
||||
// classifyQUICHandshakeError maps errors during a QUIC
|
||||
// ClassifyQUICHandshakeError maps errors during a QUIC
|
||||
// handshake to OONI failure strings.
|
||||
//
|
||||
// If the input error is an *ErrWrapper we don't perform
|
||||
|
@ -147,7 +147,7 @@ const (
|
|||
//
|
||||
// If this classifier fails, it calls ClassifyGenericError
|
||||
// and returns to the caller its return value.
|
||||
func classifyQUICHandshakeError(err error) string {
|
||||
func ClassifyQUICHandshakeError(err error) string {
|
||||
|
||||
// Robustness: handle the case where we're passed a wrapped error.
|
||||
var errwrapper *ErrWrapper
|
||||
|
@ -207,7 +207,7 @@ func classifyQUICHandshakeError(err error) string {
|
|||
}
|
||||
}
|
||||
}
|
||||
return classifyGenericError(err)
|
||||
return ClassifyGenericError(err)
|
||||
}
|
||||
|
||||
// quicIsCertificateError tells us whether a specific TLS alert error
|
||||
|
@ -277,7 +277,7 @@ var (
|
|||
// anything as explained in getaddrinfo_linux.go.
|
||||
var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
||||
|
||||
// classifyResolverError maps DNS resolution errors to
|
||||
// ClassifyResolverError maps DNS resolution errors to
|
||||
// OONI failure strings.
|
||||
//
|
||||
// If the input error is an *ErrWrapper we don't perform
|
||||
|
@ -285,7 +285,7 @@ var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
|||
//
|
||||
// If this classifier fails, it calls ClassifyGenericError and
|
||||
// returns to the caller its return value.
|
||||
func classifyResolverError(err error) string {
|
||||
func ClassifyResolverError(err error) string {
|
||||
|
||||
// Robustness: handle the case where we're passed a wrapped error.
|
||||
var errwrapper *ErrWrapper
|
||||
|
@ -310,10 +310,10 @@ func classifyResolverError(err error) string {
|
|||
if errors.Is(err, ErrAndroidDNSCacheNoData) {
|
||||
return FailureAndroidDNSCacheNoData
|
||||
}
|
||||
return classifyGenericError(err)
|
||||
return ClassifyGenericError(err)
|
||||
}
|
||||
|
||||
// classifyTLSHandshakeError maps an error occurred during the TLS
|
||||
// ClassifyTLSHandshakeError maps an error occurred during the TLS
|
||||
// handshake to an OONI failure string.
|
||||
//
|
||||
// If the input error is an *ErrWrapper we don't perform
|
||||
|
@ -321,7 +321,7 @@ func classifyResolverError(err error) string {
|
|||
//
|
||||
// If this classifier fails, it calls ClassifyGenericError and
|
||||
// returns to the caller its return value.
|
||||
func classifyTLSHandshakeError(err error) string {
|
||||
func ClassifyTLSHandshakeError(err error) string {
|
||||
|
||||
// Robustness: handle the case where we're passed a wrapped error.
|
||||
var errwrapper *ErrWrapper
|
||||
|
@ -345,5 +345,5 @@ func classifyTLSHandshakeError(err error) string {
|
|||
// Test case: https://expired.badssl.com/
|
||||
return FailureSSLInvalidCertificate
|
||||
}
|
||||
return classifyGenericError(err)
|
||||
return ClassifyGenericError(err)
|
||||
}
|
||||
|
|
|
@ -18,13 +18,13 @@ func TestClassifyGenericError(t *testing.T) {
|
|||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyGenericError(err) != FailureEOFError {
|
||||
if ClassifyGenericError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for a system call error", func(t *testing.T) {
|
||||
if classifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
|
||||
if ClassifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
@ -35,63 +35,63 @@ func TestClassifyGenericError(t *testing.T) {
|
|||
// is just an implementation detail.
|
||||
|
||||
t.Run("for operation was canceled", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
|
||||
if ClassifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for EOF", func(t *testing.T) {
|
||||
if classifyGenericError(io.EOF) != FailureEOFError {
|
||||
if ClassifyGenericError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for context deadline exceeded", func(t *testing.T) {
|
||||
if classifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for stun's transaction is timed out", func(t *testing.T) {
|
||||
if classifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for i/o timeout", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for TLS handshake timeout", func(t *testing.T) {
|
||||
err := errors.New("net/http: TLS handshake timeout")
|
||||
if classifyGenericError(err) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(err) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for no such host", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
|
||||
if ClassifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for dns server misbehaving", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
|
||||
if ClassifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for no answer from DNS server", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
|
||||
if ClassifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for use of closed network connection", func(t *testing.T) {
|
||||
err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
|
||||
if classifyGenericError(err) != FailureConnectionAlreadyClosed {
|
||||
if ClassifyGenericError(err) != FailureConnectionAlreadyClosed {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
@ -99,7 +99,7 @@ func TestClassifyGenericError(t *testing.T) {
|
|||
// Now we're back in ClassifyGenericError
|
||||
|
||||
t.Run("for context.Canceled", func(t *testing.T) {
|
||||
if classifyGenericError(context.Canceled) != FailureInterrupted {
|
||||
if ClassifyGenericError(context.Canceled) != FailureInterrupted {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
@ -108,7 +108,7 @@ func TestClassifyGenericError(t *testing.T) {
|
|||
t.Run("with an IPv4 address", func(t *testing.T) {
|
||||
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error")
|
||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
||||
out := classifyGenericError(input)
|
||||
out := ClassifyGenericError(input)
|
||||
if out != expected {
|
||||
t.Fatal(cmp.Diff(expected, out))
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ func TestClassifyGenericError(t *testing.T) {
|
|||
t.Run("with an IPv6 address", func(t *testing.T) {
|
||||
input := errors.New("read tcp [::1]:56948->[::1]:443: some error")
|
||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
||||
out := classifyGenericError(input)
|
||||
out := ClassifyGenericError(input)
|
||||
if out != expected {
|
||||
t.Fatal(cmp.Diff(expected, out))
|
||||
}
|
||||
|
@ -131,100 +131,100 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
|
|||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyQUICHandshakeError(err) != FailureEOFError {
|
||||
if ClassifyQUICHandshakeError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for incompatible quic version", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
|
||||
if ClassifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for stateless reset", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
|
||||
if ClassifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for handshake timeout", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
|
||||
if ClassifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for idle timeout", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
|
||||
if ClassifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for connection refused", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for bad certificate", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertBadCertificate
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for unsupported certificate", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for certificate expired", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertCertificateExpired
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for certificate revoked", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for certificate unknown", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for decrypt error", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertDecryptError
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for handshake failure", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for unknown CA", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertUnknownCA
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for unrecognized hostname", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSUnrecognizedName
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
@ -234,13 +234,13 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
|
|||
ErrorCode: quic.InternalError,
|
||||
ErrorMessage: FailureHostUnreachable,
|
||||
}
|
||||
if classifyQUICHandshakeError(err) != FailureHostUnreachable {
|
||||
if ClassifyQUICHandshakeError(err) != FailureHostUnreachable {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for another kind of error", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(io.EOF) != FailureEOFError {
|
||||
if ClassifyQUICHandshakeError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
@ -252,43 +252,43 @@ func TestClassifyResolverError(t *testing.T) {
|
|||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyResolverError(err) != FailureEOFError {
|
||||
if ClassifyResolverError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for ErrDNSBogon", func(t *testing.T) {
|
||||
if classifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
|
||||
if ClassifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for refused", func(t *testing.T) {
|
||||
if classifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
|
||||
if ClassifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for servfail", func(t *testing.T) {
|
||||
if classifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
|
||||
if ClassifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
|
||||
if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
|
||||
if ClassifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) {
|
||||
if classifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
|
||||
if ClassifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for another kind of error", func(t *testing.T) {
|
||||
if classifyResolverError(io.EOF) != FailureEOFError {
|
||||
if ClassifyResolverError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
@ -300,34 +300,34 @@ func TestClassifyTLSHandshakeError(t *testing.T) {
|
|||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyTLSHandshakeError(err) != FailureEOFError {
|
||||
if ClassifyTLSHandshakeError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for x509.HostnameError", func(t *testing.T) {
|
||||
var err x509.HostnameError
|
||||
if classifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
|
||||
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
|
||||
var err x509.UnknownAuthorityError
|
||||
if classifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
|
||||
if ClassifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
|
||||
var err x509.CertificateInvalidError
|
||||
if classifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
|
||||
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for another kind of error", func(t *testing.T) {
|
||||
if classifyTLSHandshakeError(io.EOF) != FailureEOFError {
|
||||
if ClassifyTLSHandshakeError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
|
|
@ -125,7 +125,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver,
|
|||
outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors
|
||||
}
|
||||
return &dialerLogger{
|
||||
Dialer: &dialerResolver{
|
||||
Dialer: &dialerResolverWithTracing{
|
||||
Dialer: &dialerLogger{
|
||||
Dialer: outDialer,
|
||||
DebugLogger: logger,
|
||||
|
@ -171,15 +171,24 @@ func (d *DialerSystem) CloseIdleConnections() {
|
|||
// nothing to do here
|
||||
}
|
||||
|
||||
// dialerResolver combines dialing with domain name resolution.
|
||||
type dialerResolver struct {
|
||||
// dialerResolverWithTracing combines dialing with domain name resolution and
|
||||
// implements hooks to trace TCP (or UDP) connect operations.
|
||||
type dialerResolverWithTracing struct {
|
||||
Dialer model.Dialer
|
||||
Resolver model.Resolver
|
||||
}
|
||||
|
||||
var _ model.Dialer = &dialerResolver{}
|
||||
var _ model.Dialer = &dialerResolverWithTracing{}
|
||||
|
||||
func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// DialContext implements model.Dialer.DialContext. Specifically this
|
||||
// method performs the following operations:
|
||||
//
|
||||
// 1. resolve the domain inside the address using a resolver;
|
||||
//
|
||||
// 2. cycle through the available IP addresses and try to dial each of them;
|
||||
//
|
||||
// 3. trace the TCP (or UDP) connect and allow wrapping the returned conn.
|
||||
func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// QUIRK: this routine and the related routines in quirks.go cannot
|
||||
// be changed easily until we use events tracing to measure.
|
||||
//
|
||||
|
@ -194,10 +203,27 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
|
|||
}
|
||||
addrs = quirkSortIPAddrs(addrs)
|
||||
var errorslist []error
|
||||
trace := ContextTraceOrDefault(ctx)
|
||||
for _, addr := range addrs {
|
||||
target := net.JoinHostPort(addr, onlyport)
|
||||
started := trace.TimeNow()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, target)
|
||||
finished := trace.TimeNow()
|
||||
// TODO(bassosimone): to make the code robust to future refactoring we have
|
||||
// moved error wrapping inside this type. This change opens up the possibility
|
||||
// of simplifying the dialing chain by removing dialerErrWrapper. We'll be
|
||||
// able to implement this refactoring once netx is gone. We cannot complete
|
||||
// this refactoring _before_ because WrapDialer inserts extra wrappers
|
||||
// provided by netx in the dialers chain _before_ this dialer and the dialers
|
||||
// that netx insert assume that they wrap a dialer with error wrapping.
|
||||
//
|
||||
// Because error wrapping should be idempotent, it should not be a problem
|
||||
// to have two error wrapping dialers in the chain except that, of course, it
|
||||
// would be less efficient than just having a single wrapper.
|
||||
err = MaybeNewErrWrapper(ClassifyGenericError, ConnectOperation, err)
|
||||
trace.OnConnectDone(started, network, onlyhost, target, err, finished)
|
||||
if err == nil {
|
||||
conn = &dialerErrWrapperConn{conn}
|
||||
return conn, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
|
@ -206,14 +232,14 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
|
|||
}
|
||||
|
||||
// lookupHost ensures we correctly handle IP addresses.
|
||||
func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
func (d *dialerResolverWithTracing) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return d.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
func (d *dialerResolver) CloseIdleConnections() {
|
||||
func (d *dialerResolverWithTracing) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
d.Resolver.CloseIdleConnections()
|
||||
}
|
||||
|
@ -303,7 +329,7 @@ var _ model.Dialer = &dialerErrWrapper{}
|
|||
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyGenericError, ConnectOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyGenericError, ConnectOperation, err)
|
||||
}
|
||||
return &dialerErrWrapperConn{Conn: conn}, nil
|
||||
}
|
||||
|
@ -322,7 +348,7 @@ var _ net.Conn = &dialerErrWrapperConn{}
|
|||
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
||||
count, err := c.Conn.Read(b)
|
||||
if err != nil {
|
||||
return 0, newErrWrapper(classifyGenericError, ReadOperation, err)
|
||||
return 0, NewErrWrapper(ClassifyGenericError, ReadOperation, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
@ -330,7 +356,7 @@ func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
|||
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
||||
count, err := c.Conn.Write(b)
|
||||
if err != nil {
|
||||
return 0, newErrWrapper(classifyGenericError, WriteOperation, err)
|
||||
return 0, NewErrWrapper(ClassifyGenericError, WriteOperation, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
@ -338,7 +364,7 @@ func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
|||
func (c *dialerErrWrapperConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
if err != nil {
|
||||
return newErrWrapper(classifyGenericError, CloseOperation, err)
|
||||
return NewErrWrapper(ClassifyGenericError, CloseOperation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestNewDialerWithStdlibResolver(t *testing.T) {
|
||||
|
@ -22,7 +23,7 @@ func TestNewDialerWithStdlibResolver(t *testing.T) {
|
|||
t.Fatal("invalid logger")
|
||||
}
|
||||
// typecheck the resolver
|
||||
reso := logger.Dialer.(*dialerResolver)
|
||||
reso := logger.Dialer.(*dialerResolverWithTracing)
|
||||
typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger)
|
||||
// typecheck the dialer
|
||||
logger = reso.Dialer.(*dialerLogger)
|
||||
|
@ -64,7 +65,7 @@ func TestNewDialer(t *testing.T) {
|
|||
if logger.DebugLogger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
reso := logger.Dialer.(*dialerResolver)
|
||||
reso := logger.Dialer.(*dialerResolverWithTracing)
|
||||
if _, okay := reso.Resolver.(*NullResolver); !okay {
|
||||
t.Fatal("invalid Resolver type")
|
||||
}
|
||||
|
@ -136,10 +137,10 @@ func TestDialerSystem(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDialerResolver(t *testing.T) {
|
||||
func TestDialerResolverWithTracing(t *testing.T) {
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("fails without a port", func(t *testing.T) {
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &DialerSystem{},
|
||||
Resolver: NewUnwrappedStdlibResolver(),
|
||||
}
|
||||
|
@ -154,7 +155,7 @@ func TestDialerResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("handles dialing error correctly for single IP address", func(t *testing.T) {
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
|
@ -166,13 +167,26 @@ func TestDialerResolver(t *testing.T) {
|
|||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("the error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != FailureEOFError {
|
||||
t.Fatal("invalid wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != ConnectOperation {
|
||||
t.Fatal("invalid wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||
t.Fatal("invalid wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) {
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
|
@ -188,6 +202,19 @@ func TestDialerResolver(t *testing.T) {
|
|||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("the error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != FailureEOFError {
|
||||
t.Fatal("invalid wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != ConnectOperation {
|
||||
t.Fatal("invalid wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||
t.Fatal("invalid wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
|
@ -199,7 +226,7 @@ func TestDialerResolver(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
}
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return expectedConn, nil
|
||||
|
@ -215,7 +242,10 @@ func TestDialerResolver(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn != expectedConn {
|
||||
// Ensure that the dialer returns a connection that is already wrapping errors,
|
||||
// which is a new behavior since https://github.com/ooni/probe-cli/pull/815
|
||||
errWrapperConn := conn.(*dialerErrWrapperConn)
|
||||
if errWrapperConn.Conn != expectedConn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
conn.Close()
|
||||
|
@ -225,7 +255,7 @@ func TestDialerResolver(t *testing.T) {
|
|||
// This test is fundamental to the following
|
||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||
mu := &sync.Mutex{}
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
|
@ -257,7 +287,7 @@ func TestDialerResolver(t *testing.T) {
|
|||
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||
mu := &sync.Mutex{}
|
||||
var attempts []string
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
|
@ -298,14 +328,14 @@ func TestDialerResolver(t *testing.T) {
|
|||
mu := &sync.Mutex{}
|
||||
errorsList := []error{
|
||||
errors.New("a mocked error"),
|
||||
newErrWrapper(
|
||||
classifyGenericError,
|
||||
NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
io.EOF,
|
||||
),
|
||||
}
|
||||
var errorIdx int
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
|
@ -337,17 +367,18 @@ func TestDialerResolver(t *testing.T) {
|
|||
t.Run("though ignores the unknown failures", func(t *testing.T) {
|
||||
// This test is fundamental to the following
|
||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||
expectedErr := errors.New("a mocked error")
|
||||
mu := &sync.Mutex{}
|
||||
errorsList := []error{
|
||||
errors.New("a mocked error"),
|
||||
newErrWrapper(
|
||||
classifyGenericError,
|
||||
expectedErr,
|
||||
NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
errors.New("antani"),
|
||||
errors.New("antani"), // this is an unknown failure and we should not return it
|
||||
),
|
||||
}
|
||||
var errorIdx int
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
|
@ -368,18 +399,95 @@ func TestDialerResolver(t *testing.T) {
|
|||
},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||
if err == nil || err.Error() != "a mocked error" {
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != "unknown_failure: a mocked error" {
|
||||
t.Fatal("unexpected wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != ConnectOperation {
|
||||
t.Fatal("unexpected wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, expectedErr) {
|
||||
t.Fatal("unexpected wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses a context-injected custom trace", func(t *testing.T) {
|
||||
var (
|
||||
called bool
|
||||
domainOK bool
|
||||
networkOK bool
|
||||
remoteAddrOK bool
|
||||
startTimeOK bool
|
||||
finishTimeOK bool
|
||||
wrappedErr bool
|
||||
)
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
var ew *ErrWrapper
|
||||
called = true
|
||||
domainOK = (domain == "1.1.1.1")
|
||||
networkOK = (network == "tcp")
|
||||
remoteAddrOK = (remoteAddr == "1.1.1.1:853")
|
||||
startTimeOK = (started.Sub(zeroTime) == 0)
|
||||
finishTimeOK = (finished.Sub(zeroTime) == time.Second)
|
||||
wrappedErr = errors.As(err, &ew) && ew.Failure == FailureEOFError
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
},
|
||||
Resolver: &NullResolver{},
|
||||
}
|
||||
conn, err := d.DialContext(ctx, "tcp", "1.1.1.1:853")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
if !domainOK {
|
||||
t.Fatal("domain was not okay")
|
||||
}
|
||||
if !networkOK {
|
||||
t.Fatal("network was not okay")
|
||||
}
|
||||
if !remoteAddrOK {
|
||||
t.Fatal("remoteAddr was not okay")
|
||||
}
|
||||
if !startTimeOK {
|
||||
t.Fatal("start time was not okay")
|
||||
}
|
||||
if !finishTimeOK {
|
||||
t.Fatal("finish time was not okay")
|
||||
}
|
||||
if !wrappedErr {
|
||||
t.Fatal("not wrapped")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("lookupHost", func(t *testing.T) {
|
||||
t.Run("handles addresses correctly", func(t *testing.T) {
|
||||
dialer := &dialerResolver{
|
||||
dialer := &dialerResolverWithTracing{
|
||||
Dialer: &DialerSystem{},
|
||||
Resolver: &NullResolver{},
|
||||
}
|
||||
|
@ -393,7 +501,7 @@ func TestDialerResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("fails correctly on lookup error", func(t *testing.T) {
|
||||
dialer := &dialerResolver{
|
||||
dialer := &dialerResolverWithTracing{
|
||||
Dialer: &DialerSystem{},
|
||||
Resolver: &NullResolver{},
|
||||
}
|
||||
|
@ -413,7 +521,7 @@ func TestDialerResolver(t *testing.T) {
|
|||
calledDialer bool
|
||||
calledResolver bool
|
||||
)
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
calledDialer = true
|
||||
|
|
|
@ -38,7 +38,7 @@ func (t *dnsTransportErrWrapper) RoundTrip(
|
|||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
resp, err := t.DNSTransport.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, DNSRoundTripOperation, err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
|
|
@ -34,6 +34,10 @@
|
|||
//
|
||||
// We want to have reasonable watchdog timeouts for each operation.
|
||||
//
|
||||
// We also want lightweight support for tracing network events. To this end, we
|
||||
// use context.WithValue and context.Value to inject, and retrieve, a model.Trace
|
||||
// implementation OPTIONALLY configured by the user.
|
||||
//
|
||||
// See also the design document at docs/design/dd-003-step-by-step.md,
|
||||
// which provides an overview of netxlite's main concerns.
|
||||
//
|
||||
|
|
|
@ -71,7 +71,7 @@ func (e *ErrWrapper) MarshalJSON() ([]byte, error) {
|
|||
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
|
||||
type classifier func(err error) string
|
||||
|
||||
// newErrWrapper creates a new ErrWrapper using the given
|
||||
// NewErrWrapper creates a new ErrWrapper using the given
|
||||
// classifier, operation name, and underlying error.
|
||||
//
|
||||
// This function panics if classifier is nil, or operation
|
||||
|
@ -81,7 +81,7 @@ type classifier func(err error) string
|
|||
// error wrapper will use the same classification string and
|
||||
// will determine whether to keep the major operation as documented
|
||||
// in the ErrWrapper.Operation documentation.
|
||||
func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||
func NewErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||
var wrapper *ErrWrapper
|
||||
if errors.As(err, &wrapper) {
|
||||
return &ErrWrapper{
|
||||
|
@ -106,6 +106,19 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(https://github.com/ooni/probe/issues/2163): we can really
|
||||
// simplify the error wrapping situation here by just dropping
|
||||
// NewErrWrapper and always using MaybeNewErrWrapper.
|
||||
|
||||
// MaybeNewErrWrapper is like NewErrWrapper except that this
|
||||
// function won't panic if passed a nil error.
|
||||
func MaybeNewErrWrapper(c classifier, op string, err error) error {
|
||||
if err != nil {
|
||||
return NewErrWrapper(c, op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTopLevelGenericErrWrapper wraps an error occurring at top
|
||||
// level using a generic classifier as classifier. This is the
|
||||
// function you should call when you suspect a given error hasn't
|
||||
|
@ -115,7 +128,7 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
|||
// error wrapper will use the same classification string and
|
||||
// failed operation of the original error.
|
||||
func NewTopLevelGenericErrWrapper(err error) *ErrWrapper {
|
||||
return newErrWrapper(classifyGenericError, TopLevelOperation, err)
|
||||
return NewErrWrapper(ClassifyGenericError, TopLevelOperation, err)
|
||||
}
|
||||
|
||||
func classifyOperation(ew *ErrWrapper, operation string) string {
|
||||
|
|
|
@ -53,7 +53,7 @@ func TestNewErrWrapper(t *testing.T) {
|
|||
recovered.Add(1)
|
||||
}
|
||||
}()
|
||||
newErrWrapper(nil, CloseOperation, io.EOF)
|
||||
NewErrWrapper(nil, CloseOperation, io.EOF)
|
||||
}()
|
||||
if recovered.Load() != 1 {
|
||||
t.Fatal("did not panic")
|
||||
|
@ -68,7 +68,7 @@ func TestNewErrWrapper(t *testing.T) {
|
|||
recovered.Add(1)
|
||||
}
|
||||
}()
|
||||
newErrWrapper(classifyGenericError, "", io.EOF)
|
||||
NewErrWrapper(ClassifyGenericError, "", io.EOF)
|
||||
}()
|
||||
if recovered.Load() != 1 {
|
||||
t.Fatal("did not panic")
|
||||
|
@ -83,7 +83,7 @@ func TestNewErrWrapper(t *testing.T) {
|
|||
recovered.Add(1)
|
||||
}
|
||||
}()
|
||||
newErrWrapper(classifyGenericError, CloseOperation, nil)
|
||||
NewErrWrapper(ClassifyGenericError, CloseOperation, nil)
|
||||
}()
|
||||
if recovered.Load() != 1 {
|
||||
t.Fatal("did not panic")
|
||||
|
@ -91,7 +91,7 @@ func TestNewErrWrapper(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("otherwise, works as intended", func(t *testing.T) {
|
||||
ew := newErrWrapper(classifyGenericError, CloseOperation, io.EOF)
|
||||
ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF)
|
||||
if ew.Failure != FailureEOFError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
|
@ -104,10 +104,10 @@ func TestNewErrWrapper(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("when the underlying error is already a wrapped error", func(t *testing.T) {
|
||||
ew := newErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||
ew := NewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||
var err1 error = ew
|
||||
err2 := fmt.Errorf("cannot read: %w", err1)
|
||||
ew2 := newErrWrapper(classifyGenericError, HTTPRoundTripOperation, err2)
|
||||
ew2 := NewErrWrapper(ClassifyGenericError, HTTPRoundTripOperation, err2)
|
||||
if ew2.Failure != ew.Failure {
|
||||
t.Fatal("not the same failure")
|
||||
}
|
||||
|
@ -117,6 +117,34 @@ func TestNewErrWrapper(t *testing.T) {
|
|||
if ew2.WrappedErr != err2 {
|
||||
t.Fatal("invalid underlying error")
|
||||
}
|
||||
// Make sure we can still use errors.Is with two layers of wrapping
|
||||
if !errors.Is(ew2, ECONNRESET) {
|
||||
t.Fatal("we cannot use errors.Is to retrieve the real syscall error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaybeNewErrWrapper(t *testing.T) {
|
||||
// TODO(https://github.com/ooni/probe/issues/2163): we can really
|
||||
// simplify the error wrapping situation here by just dropping
|
||||
// NewErrWrapper and always using MaybeNewErrWrapper.
|
||||
|
||||
t.Run("when we pass a nil error to this function", func(t *testing.T) {
|
||||
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, nil)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected output", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when we pass a non-nil error to this function", func(t *testing.T) {
|
||||
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||
if !errors.Is(err, ECONNRESET) {
|
||||
t.Fatal("unexpected output", err)
|
||||
}
|
||||
var ew *ErrWrapper
|
||||
if !errors.As(err, &ew) {
|
||||
t.Fatal("not an instance of ErrWrapper")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
|
|||
dialer := txpCc.Dialer
|
||||
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
||||
dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolver)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
|
||||
if dialerReso.Resolver != resolver {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
|
|||
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
||||
dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer)
|
||||
dialerLog := dialerProxy.Dialer.(*dialerLogger)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolver)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
|
||||
if dialerReso.Resolver != resolver {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
|
@ -269,7 +269,7 @@ func TestNewHTTPTransport(t *testing.T) {
|
|||
t.Run("works as intended with failing dialer", func(t *testing.T) {
|
||||
called := &atomicx.Int64{}
|
||||
expected := errors.New("mocked error")
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context,
|
||||
network, address string) (net.Conn, error) {
|
||||
|
@ -612,7 +612,7 @@ func TestNewHTTPClientWithResolver(t *testing.T) {
|
|||
txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser)
|
||||
dialer := txpCc.Dialer.(*httpDialerWithReadTimeout)
|
||||
dialerLogger := dialer.Dialer.(*dialerLogger)
|
||||
dialerReso := dialerLogger.Dialer.(*dialerResolver)
|
||||
dialerReso := dialerLogger.Dialer.(*dialerResolverWithTracing)
|
||||
if dialerReso.Resolver != reso {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ func TestReadAllContext(t *testing.T) {
|
|||
//
|
||||
// Note: Returning a wrapped error to ensure we address
|
||||
// https://github.com/ooni/probe/issues/1965
|
||||
return len(b), newErrWrapper(classifyGenericError,
|
||||
return len(b), NewErrWrapper(ClassifyGenericError,
|
||||
ReadOperation, io.EOF)
|
||||
},
|
||||
}
|
||||
|
@ -171,7 +171,7 @@ func TestCopyContext(t *testing.T) {
|
|||
//
|
||||
// Note: Returning a wrapped error to ensure we address
|
||||
// https://github.com/ooni/probe/issues/1965
|
||||
return len(b), newErrWrapper(classifyGenericError,
|
||||
return len(b), NewErrWrapper(ClassifyGenericError,
|
||||
ReadOperation, io.EOF)
|
||||
},
|
||||
}
|
||||
|
|
|
@ -380,7 +380,7 @@ var _ model.QUICListener = &quicListenerErrWrapper{}
|
|||
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
|
||||
pconn, err := qls.QUICListener.Listen(addr)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyGenericError, QUICListenOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyGenericError, QUICListenOperation, err)
|
||||
}
|
||||
return &quicErrWrapperUDPLikeConn{pconn}, nil
|
||||
}
|
||||
|
@ -397,7 +397,7 @@ var _ model.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
|
|||
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
||||
if err != nil {
|
||||
return 0, newErrWrapper(classifyGenericError, WriteToOperation, err)
|
||||
return 0, NewErrWrapper(ClassifyGenericError, WriteToOperation, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
@ -406,7 +406,7 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error
|
|||
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return 0, nil, newErrWrapper(classifyGenericError, ReadFromOperation, err)
|
||||
return 0, nil, NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
|
@ -415,7 +415,7 @@ func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
|||
func (c *quicErrWrapperUDPLikeConn) Close() error {
|
||||
err := c.UDPLikeConn.Close()
|
||||
if err != nil {
|
||||
return newErrWrapper(classifyGenericError, ReadFromOperation, err)
|
||||
return NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -433,8 +433,8 @@ func (d *quicDialerErrWrapper) DialContext(
|
|||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(
|
||||
classifyQUICHandshakeError, QUICHandshakeOperation, err)
|
||||
return nil, NewErrWrapper(
|
||||
ClassifyQUICHandshakeError, QUICHandshakeOperation, err)
|
||||
}
|
||||
return qconn, nil
|
||||
}
|
||||
|
|
|
@ -34,13 +34,13 @@ func TestQuirkReduceErrors(t *testing.T) {
|
|||
|
||||
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
||||
err1 := errors.New("mocked error #1")
|
||||
err2 := newErrWrapper(
|
||||
classifyGenericError,
|
||||
err2 := NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
errors.New("antani"),
|
||||
)
|
||||
err3 := newErrWrapper(
|
||||
classifyGenericError,
|
||||
err3 := NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
ECONNREFUSED,
|
||||
)
|
||||
|
|
|
@ -387,7 +387,7 @@ var _ model.Resolver = &resolverErrWrapper{}
|
|||
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
@ -396,7 +396,7 @@ func (r *resolverErrWrapper) LookupHTTPS(
|
|||
ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||
out, err := r.Resolver.LookupHTTPS(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
@ -417,7 +417,7 @@ func (r *resolverErrWrapper) LookupNS(
|
|||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
out, err := r.Resolver.LookupNS(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
|
|
@ -18,9 +18,6 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
// TODO(bassosimone): check whether there's now equivalent functionality
|
||||
// inside the standard library allowing us to map numbers to names.
|
||||
|
||||
var (
|
||||
tlsVersionString = map[uint16]string{
|
||||
tls.VersionTLS10: "TLSv1",
|
||||
|
@ -85,6 +82,13 @@ func TLSVersionString(value uint16) string {
|
|||
// the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd`
|
||||
// where `ddd` is the numeric value passed to this function.
|
||||
func TLSCipherSuiteString(value uint16) string {
|
||||
// TODO(https://github.com/ooni/probe/issues/2166): the standard library has a
|
||||
// function for mapping a cipher suite to a string, but the value returned in case of
|
||||
// missing cipher suite is different from the one we would return
|
||||
// here. We could consider simplifying this code anyway because
|
||||
// in most, if not all, cases we have a valid cipher suite and we
|
||||
// just need to make sure what the spec says we should do when
|
||||
// passed an unknown cipher suite.
|
||||
if str, found := tlsCipherSuiteString[value]; found {
|
||||
return str
|
||||
}
|
||||
|
@ -158,15 +162,15 @@ func NewTLSHandshakerStdlib(logger model.DebugLogger) model.TLSHandshaker {
|
|||
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
|
||||
func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker {
|
||||
return &tlsHandshakerLogger{
|
||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: th,
|
||||
},
|
||||
DebugLogger: logger,
|
||||
TLSHandshaker: th,
|
||||
DebugLogger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// tlsHandshakerConfigurable is a configurable TLS handshaker that
|
||||
// uses by default the standard library's TLS implementation.
|
||||
//
|
||||
// This type also implements error wrapping and events tracing.
|
||||
type tlsHandshakerConfigurable struct {
|
||||
// NewConn is the OPTIONAL factory for creating a new connection. If
|
||||
// this factory is not set, we'll use the stdlib.
|
||||
|
@ -183,9 +187,20 @@ var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}
|
|||
// value into a private variable to enable for unit testing.
|
||||
var defaultCertPool = NewDefaultCertPool()
|
||||
|
||||
// tlsMaybeConnectionState returns the connection state if error is nil
|
||||
// and otherwise just returns an empty state to the caller.
|
||||
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
|
||||
if err != nil {
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
return conn.ConnectionState()
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake. This function will
|
||||
// configure the code to use the built-in Mozilla CA if the config
|
||||
// field contains a nil RootCAs field.
|
||||
//
|
||||
// This function will also emit TLS-handshake-related tracing events.
|
||||
func (h *tlsHandshakerConfigurable) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
|
@ -203,10 +218,19 @@ func (h *tlsHandshakerConfigurable) Handshake(
|
|||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
if err := tlsconn.HandshakeContext(ctx); err != nil {
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
trace := ContextTraceOrDefault(ctx)
|
||||
started := trace.TimeNow()
|
||||
trace.OnTLSHandshakeStart(started, remoteAddr, config)
|
||||
err = tlsconn.HandshakeContext(ctx)
|
||||
err = MaybeNewErrWrapper(ClassifyTLSHandshakeError, TLSHandshakeOperation, err)
|
||||
finished := trace.TimeNow()
|
||||
state := tlsMaybeConnectionState(tlsconn, err)
|
||||
trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
|
||||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
return tlsconn, state, nil
|
||||
}
|
||||
|
||||
// newConn creates a new TLSConn.
|
||||
|
@ -352,23 +376,6 @@ func (d *tlsDialerSingleUseAdapter) CloseIdleConnections() {
|
|||
d.Dialer.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
|
||||
type tlsHandshakerErrWrapper struct {
|
||||
TLSHandshaker model.TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements TLSHandshaker.Handshake
|
||||
func (h *tlsHandshakerErrWrapper) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, newErrWrapper(
|
||||
classifyTLSHandshakeError, TLSHandshakeOperation, err)
|
||||
}
|
||||
return tlsconn, state, nil
|
||||
}
|
||||
|
||||
// ErrNoTLSDialer is the type of error returned by "null" TLS dialers
|
||||
// when you attempt to dial with them.
|
||||
var ErrNoTLSDialer = errors.New("no configured TLS dialer")
|
||||
|
|
|
@ -16,7 +16,10 @@ import (
|
|||
|
||||
"github.com/apex/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestVersionString(t *testing.T) {
|
||||
|
@ -123,8 +126,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
|
|||
if logger.DebugLogger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if configurable.NewConn != nil {
|
||||
t.Fatal("expected nil NewConn")
|
||||
}
|
||||
|
@ -132,7 +134,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
|
|||
|
||||
func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("with error", func(t *testing.T) {
|
||||
t.Run("with handshake I/O error", func(t *testing.T) {
|
||||
var times []time.Time
|
||||
h := &tlsHandshakerConfigurable{}
|
||||
tcpConn := &mocks.Conn{
|
||||
|
@ -143,14 +145,37 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
|||
times = append(times, t)
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||
conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||
ServerName: "x.org",
|
||||
})
|
||||
if err != io.EOF {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("the error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != FailureEOFError {
|
||||
t.Fatal("invalid wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != TLSHandshakeOperation {
|
||||
t.Fatal("invalid wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||
t.Fatal("invalid wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
|
@ -163,6 +188,9 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
|||
if !times[1].IsZero() {
|
||||
t.Fatal("did not clear timeout on exit")
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("the returned connection state is not a zero value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with success", func(t *testing.T) {
|
||||
|
@ -217,6 +245,16 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
|||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||
if !errors.Is(err, expected) {
|
||||
|
@ -236,7 +274,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("we cannot create a new conn", func(t *testing.T) {
|
||||
t.Run("h.newConn fails", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
handshaker := &tlsHandshakerConfigurable{
|
||||
NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) {
|
||||
|
@ -261,6 +299,222 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
|||
t.Fatal("expected nil tlsConn here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
|
||||
var (
|
||||
expectedSNI = "dns.google"
|
||||
goodStartStartTime bool
|
||||
goodStartInsecureSkipVerify bool
|
||||
goodDoneInsecureSkipVerify bool
|
||||
goodStartServerName bool
|
||||
goodDoneServerName bool
|
||||
goodDoneStartTime bool
|
||||
goodDoneDoneTime bool
|
||||
goodStartRemoteAddr bool
|
||||
goodDoneRemoteAddr bool
|
||||
goodDoneError bool
|
||||
goodConnectionState bool
|
||||
startCalled bool
|
||||
doneCalled bool
|
||||
)
|
||||
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||
defer server.Close()
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
startCalled = true
|
||||
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodStartServerName = (config.ServerName == expectedSNI)
|
||||
goodStartStartTime = (now.Sub(zeroTime) == 0)
|
||||
goodStartRemoteAddr = (remoteAddr == server.Endpoint())
|
||||
},
|
||||
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||
doneCalled = true
|
||||
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodDoneServerName = (config.ServerName == expectedSNI)
|
||||
goodDoneStartTime = (started.Sub(zeroTime) == 0)
|
||||
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
|
||||
goodDoneRemoteAddr = (remoteAddr == server.Endpoint())
|
||||
goodDoneError = (err == nil)
|
||||
goodConnectionState = (!reflect.ValueOf(state).IsZero())
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
tcpConn, err := net.Dial("tcp", server.Endpoint())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: expectedSNI,
|
||||
}
|
||||
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsConn.Close()
|
||||
if reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected nonzero connState")
|
||||
}
|
||||
if !startCalled {
|
||||
t.Fatal("start not called")
|
||||
}
|
||||
if !doneCalled {
|
||||
t.Fatal("done not called")
|
||||
}
|
||||
if !goodStartInsecureSkipVerify {
|
||||
t.Fatal("invalid start-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodDoneInsecureSkipVerify {
|
||||
t.Fatal("invalid done-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodStartServerName {
|
||||
t.Fatal("invalid start-event's ServerName")
|
||||
}
|
||||
if !goodDoneServerName {
|
||||
t.Fatal("invalid done-event's ServerName")
|
||||
}
|
||||
if !goodStartStartTime {
|
||||
t.Fatal("invalid start-event's start time")
|
||||
}
|
||||
if !goodDoneStartTime {
|
||||
t.Fatal("invalid done-event's start time")
|
||||
}
|
||||
if !goodDoneDoneTime {
|
||||
t.Fatal("invalid done-event's done time")
|
||||
}
|
||||
if !goodStartRemoteAddr {
|
||||
t.Fatal("invalid start-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneRemoteAddr {
|
||||
t.Fatal("invalid done-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneError {
|
||||
t.Fatal("invalid done-event's error")
|
||||
}
|
||||
if !goodConnectionState {
|
||||
t.Fatal("invalid done-event's connState")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
|
||||
var (
|
||||
expectedEndpoint = "8.8.8.8:443"
|
||||
expectedSNI = "dns.google"
|
||||
goodStartStartTime bool
|
||||
goodStartInsecureSkipVerify bool
|
||||
goodDoneInsecureSkipVerify bool
|
||||
goodStartServerName bool
|
||||
goodDoneServerName bool
|
||||
goodDoneStartTime bool
|
||||
goodDoneDoneTime bool
|
||||
goodStartRemoteAddr bool
|
||||
goodDoneRemoteAddr bool
|
||||
goodDoneError bool
|
||||
goodConnectionState bool
|
||||
startCalled bool
|
||||
doneCalled bool
|
||||
)
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
startCalled = true
|
||||
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodStartServerName = (config.ServerName == expectedSNI)
|
||||
goodStartStartTime = (now.Sub(zeroTime) == 0)
|
||||
goodStartRemoteAddr = (remoteAddr == expectedEndpoint)
|
||||
},
|
||||
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||
doneCalled = true
|
||||
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodDoneServerName = (config.ServerName == expectedSNI)
|
||||
goodDoneStartTime = (started.Sub(zeroTime) == 0)
|
||||
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
|
||||
goodDoneRemoteAddr = (remoteAddr == expectedEndpoint)
|
||||
var ew *ErrWrapper
|
||||
goodDoneError = (errors.As(err, &ew) && ew.Error() == FailureEOFError)
|
||||
goodConnectionState = (reflect.ValueOf(state).IsZero())
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
tcpConn := &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockString: func() string {
|
||||
return expectedEndpoint
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: expectedSNI,
|
||||
}
|
||||
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if tlsConn != nil {
|
||||
t.Fatal("expected nil tlsConn")
|
||||
}
|
||||
if !reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected zero connState")
|
||||
}
|
||||
if !startCalled {
|
||||
t.Fatal("start not called")
|
||||
}
|
||||
if !doneCalled {
|
||||
t.Fatal("done not called")
|
||||
}
|
||||
if !goodStartInsecureSkipVerify {
|
||||
t.Fatal("invalid start-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodDoneInsecureSkipVerify {
|
||||
t.Fatal("invalid done-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodStartServerName {
|
||||
t.Fatal("invalid start-event's ServerName")
|
||||
}
|
||||
if !goodDoneServerName {
|
||||
t.Fatal("invalid done-event's ServerName")
|
||||
}
|
||||
if !goodStartStartTime {
|
||||
t.Fatal("invalid start-event's start time")
|
||||
}
|
||||
if !goodDoneStartTime {
|
||||
t.Fatal("invalid done-event's start time")
|
||||
}
|
||||
if !goodDoneDoneTime {
|
||||
t.Fatal("invalid done-event's done time")
|
||||
}
|
||||
if !goodStartRemoteAddr {
|
||||
t.Fatal("invalid start-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneRemoteAddr {
|
||||
t.Fatal("invalid done-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneError {
|
||||
t.Fatal("invalid done-event's error")
|
||||
}
|
||||
if !goodConnectionState {
|
||||
t.Fatal("invalid done-event's connState")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -413,6 +667,15 @@ func TestTLSDialer(t *testing.T) {
|
|||
return nil
|
||||
}, MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
}, MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
}}, nil
|
||||
}},
|
||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
||||
|
@ -532,54 +795,6 @@ func TestNewSingleUseTLSDialer(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerErrWrapper(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedConn := &mocks.TLSConn{}
|
||||
expectedState := tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
}
|
||||
th := &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return expectedConn, expectedState, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if expectedState.Version != state.Version {
|
||||
t.Fatal("unexpected state")
|
||||
}
|
||||
if expectedConn != conn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
th := &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return nil, tls.ConnectionState{}, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if err == nil || err.Error() != FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewNullTLSDialer(t *testing.T) {
|
||||
dialer := NewNullTLSDialer()
|
||||
conn, err := dialer.DialTLSContext(context.Background(), "", "")
|
||||
|
@ -618,3 +833,35 @@ func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaybeConnectionState(t *testing.T) {
|
||||
t.Run("with an error", func(t *testing.T) {
|
||||
returned := tls.ConnectionState{
|
||||
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||
}
|
||||
conn := &mocks.TLSConn{
|
||||
MockConnectionState: func() tls.ConnectionState {
|
||||
return returned
|
||||
},
|
||||
}
|
||||
state := tlsMaybeConnectionState(conn, errors.New("mocked error"))
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected to see a zero connection state")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without an error", func(t *testing.T) {
|
||||
returned := tls.ConnectionState{
|
||||
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||
}
|
||||
conn := &mocks.TLSConn{
|
||||
MockConnectionState: func() tls.ConnectionState {
|
||||
return returned
|
||||
},
|
||||
}
|
||||
state := tlsMaybeConnectionState(conn, nil)
|
||||
if reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected to see a nonzero connection state")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
67
internal/netxlite/trace.go
Normal file
67
internal/netxlite/trace.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package netxlite
|
||||
|
||||
//
|
||||
// Context-based tracing
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
// traceKey is the private type used to set/retrieve the context's trace.
|
||||
type traceKey struct{}
|
||||
|
||||
// ContextTraceOrDefault retrieves the trace bound to the context or returns
|
||||
// a default implementation of the trace in case no tracing was configured.
|
||||
func ContextTraceOrDefault(ctx context.Context) model.Trace {
|
||||
t, _ := ctx.Value(traceKey{}).(model.Trace)
|
||||
return traceOrDefault(t)
|
||||
}
|
||||
|
||||
// ContextWithTrace returns a new context that binds to the given trace. If the
|
||||
// given trace is nil, this function will call panic.
|
||||
func ContextWithTrace(ctx context.Context, trace model.Trace) context.Context {
|
||||
runtimex.PanicIfTrue(trace == nil, "netxlite.WithTrace passed a nil trace")
|
||||
return context.WithValue(ctx, traceKey{}, trace)
|
||||
}
|
||||
|
||||
// traceOrDefault takes in input a trace and returns in output the
|
||||
// given trace, if not nil, or a default trace implementation.
|
||||
func traceOrDefault(trace model.Trace) model.Trace {
|
||||
if trace != nil {
|
||||
return trace
|
||||
}
|
||||
return &traceDefault{}
|
||||
}
|
||||
|
||||
// traceDefault is a default model.Trace implementation where each method is a no-op.
|
||||
type traceDefault struct{}
|
||||
|
||||
var _ model.Trace = &traceDefault{}
|
||||
|
||||
// TimeNow implements model.Trace.TimeNow
|
||||
func (*traceDefault) TimeNow() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// OnConnectDone implements model.Trace.OnConnectDone.
|
||||
func (*traceDefault) OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
// nothing
|
||||
}
|
||||
|
||||
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
|
||||
func (*traceDefault) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
// nothing
|
||||
}
|
||||
|
||||
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
|
||||
func (*traceDefault) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Time) {
|
||||
// nothing
|
||||
}
|
40
internal/netxlite/trace_test.go
Normal file
40
internal/netxlite/trace_test.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package netxlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestContextTraceOrDefault(t *testing.T) {
|
||||
t.Run("without a configured trace we get a default", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := ContextTraceOrDefault(ctx)
|
||||
_ = tx.(*traceDefault) // panic if cannot cast
|
||||
})
|
||||
|
||||
t.Run("with a configured trace we get the expected trace", func(t *testing.T) {
|
||||
realTrace := &mocks.Trace{}
|
||||
ctx := ContextWithTrace(context.Background(), realTrace)
|
||||
tx := ContextTraceOrDefault(ctx)
|
||||
if tx != realTrace {
|
||||
t.Fatal("not the trace we expected")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextWithTrace(t *testing.T) {
|
||||
t.Run("panics if passed a nil trace", func(t *testing.T) {
|
||||
var called bool
|
||||
func() {
|
||||
defer func() {
|
||||
called = (recover() != nil)
|
||||
}()
|
||||
_ = ContextWithTrace(context.Background(), nil)
|
||||
}()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -19,8 +19,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) {
|
|||
if logger.DebugLogger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if configurable.NewConn == nil {
|
||||
t.Fatal("expected non-nil NewConn")
|
||||
}
|
||||
|
|
2
internal/testingx/doc.go
Normal file
2
internal/testingx/doc.go
Normal file
|
@ -0,0 +1,2 @@
|
|||
// Package testingx contains code useful for testing.
|
||||
package testingx
|
|
@ -1,11 +1,4 @@
|
|||
// Package fakefill contains code to fill structs for testing.
|
||||
//
|
||||
// This package is quite limited in scope and we can fill only the
|
||||
// structures you typically send over as JSONs.
|
||||
//
|
||||
// As part of future work, we aim to investigate whether we can
|
||||
// replace this implementation with https://go.dev/blog/fuzz-beta.
|
||||
package fakefill
|
||||
package testingx
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
@ -14,7 +7,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// Filler fills specific data structures with random data. The only
|
||||
// FakeFiller fills specific data structures with random data. The only
|
||||
// exception to this behaviour is time.Time, which is instead filled
|
||||
// with the current time plus a small random number of seconds.
|
||||
//
|
||||
|
@ -25,7 +18,13 @@ import (
|
|||
// Caveat: this kind of fillter does not support filling interfaces
|
||||
// and channels and other complex types. The current behavior when this
|
||||
// kind of data types is encountered is to just ignore them.
|
||||
type Filler struct {
|
||||
//
|
||||
// This struct is quite limited in scope and we can fill only the
|
||||
// structures you typically send over as JSONs.
|
||||
//
|
||||
// As part of future work, we aim to investigate whether we can
|
||||
// replace this implementation with https://go.dev/blog/fuzz-beta.
|
||||
type FakeFiller struct {
|
||||
// mu provides mutual exclusion
|
||||
mu sync.Mutex
|
||||
|
||||
|
@ -37,7 +36,7 @@ type Filler struct {
|
|||
rnd *rand.Rand
|
||||
}
|
||||
|
||||
func (ff *Filler) getRandLocked() *rand.Rand {
|
||||
func (ff *FakeFiller) getRandLocked() *rand.Rand {
|
||||
if ff.rnd == nil {
|
||||
now := time.Now
|
||||
if ff.Now != nil {
|
||||
|
@ -48,7 +47,7 @@ func (ff *Filler) getRandLocked() *rand.Rand {
|
|||
return ff.rnd
|
||||
}
|
||||
|
||||
func (ff *Filler) getRandomString() string {
|
||||
func (ff *FakeFiller) getRandomString() string {
|
||||
defer ff.mu.Unlock()
|
||||
ff.mu.Lock()
|
||||
rnd := ff.getRandLocked()
|
||||
|
@ -62,28 +61,28 @@ func (ff *Filler) getRandomString() string {
|
|||
return string(b)
|
||||
}
|
||||
|
||||
func (ff *Filler) getRandomInt64() int64 {
|
||||
func (ff *FakeFiller) getRandomInt64() int64 {
|
||||
defer ff.mu.Unlock()
|
||||
ff.mu.Lock()
|
||||
rnd := ff.getRandLocked()
|
||||
return rnd.Int63()
|
||||
}
|
||||
|
||||
func (ff *Filler) getRandomBool() bool {
|
||||
func (ff *FakeFiller) getRandomBool() bool {
|
||||
defer ff.mu.Unlock()
|
||||
ff.mu.Lock()
|
||||
rnd := ff.getRandLocked()
|
||||
return rnd.Float64() >= 0.5
|
||||
}
|
||||
|
||||
func (ff *Filler) getRandomSmallPositiveInt() int {
|
||||
func (ff *FakeFiller) getRandomSmallPositiveInt() int {
|
||||
defer ff.mu.Unlock()
|
||||
ff.mu.Lock()
|
||||
rnd := ff.getRandLocked()
|
||||
return int(rnd.Int63n(8)) + 1 // safe cast
|
||||
}
|
||||
|
||||
func (ff *Filler) doFill(v reflect.Value) {
|
||||
func (ff *FakeFiller) doFill(v reflect.Value) {
|
||||
for v.Type().Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
// if the pointer is nil, allocate an element
|
||||
|
@ -134,6 +133,6 @@ func (ff *Filler) doFill(v reflect.Value) {
|
|||
}
|
||||
|
||||
// Fill fills the input structure or pointer with random data.
|
||||
func (ff *Filler) Fill(in interface{}) {
|
||||
func (ff *FakeFiller) Fill(in interface{}) {
|
||||
ff.doFill(reflect.ValueOf(in))
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package fakefill
|
||||
package testingx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
@ -16,7 +16,7 @@ type exampleStructure struct {
|
|||
|
||||
func TestFakeFillWorksWithCustomTime(t *testing.T) {
|
||||
var req *exampleStructure
|
||||
ff := &Filler{
|
||||
ff := &FakeFiller{
|
||||
Now: func() time.Time {
|
||||
return time.Date(1992, time.January, 24, 17, 53, 0, 0, time.UTC)
|
||||
},
|
||||
|
@ -29,7 +29,7 @@ func TestFakeFillWorksWithCustomTime(t *testing.T) {
|
|||
|
||||
func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
|
||||
var req *exampleStructure
|
||||
ff := &Filler{}
|
||||
ff := &FakeFiller{}
|
||||
ff.Fill(&req)
|
||||
if req == nil {
|
||||
t.Fatal("we expected non nil here")
|
||||
|
@ -38,7 +38,7 @@ func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
|
|||
|
||||
func TestFakeFillAllocatesIntoAMapLikeWithStringKeys(t *testing.T) {
|
||||
var resp map[string]*exampleStructure
|
||||
ff := &Filler{}
|
||||
ff := &FakeFiller{}
|
||||
ff.Fill(&resp)
|
||||
if resp == nil {
|
||||
t.Fatal("we expected non nil here")
|
||||
|
@ -62,7 +62,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
var resp map[int64]*exampleStructure
|
||||
ff := &Filler{}
|
||||
ff := &FakeFiller{}
|
||||
ff.Fill(&resp)
|
||||
if resp != nil {
|
||||
t.Fatal("we expected nil here")
|
||||
|
@ -75,7 +75,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) {
|
|||
|
||||
func TestFakeFillAllocatesIntoASlice(t *testing.T) {
|
||||
var resp *[]*exampleStructure
|
||||
ff := &Filler{}
|
||||
ff := &FakeFiller{}
|
||||
ff.Fill(&resp)
|
||||
if resp == nil {
|
||||
t.Fatal("we expected non nil here")
|
48
internal/testingx/time.go
Normal file
48
internal/testingx/time.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package testingx
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimeDeterministic implements time.Now in a deterministic fashion
|
||||
// such that every time.Time call returns a moment in time that occurs
|
||||
// one second after the configured zeroTime.
|
||||
//
|
||||
// It's safe to use this struct from multiple goroutine contexts.
|
||||
type TimeDeterministic struct {
|
||||
// counter counts the number of "ticks" passed since the zero time: each
|
||||
// call to Now increments this counter by one second.
|
||||
counter time.Duration
|
||||
|
||||
// mu protects fields in this structure from concurrent access.
|
||||
mu sync.Mutex
|
||||
|
||||
// zeroTime is the lazy-initialized zero time. The first call to Now
|
||||
// will initialize this field with the current time.
|
||||
zeroTime time.Time
|
||||
}
|
||||
|
||||
// NewTimeDeterministic creates a new instance using the given zeroTime value.
|
||||
func NewTimeDeterministic(zeroTime time.Time) *TimeDeterministic {
|
||||
return &TimeDeterministic{
|
||||
counter: 0,
|
||||
mu: sync.Mutex{},
|
||||
zeroTime: zeroTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Now is like time.Now but more deterministic. The first call returns the
|
||||
// configured zeroTime and subsequent calls return moments in time that occur
|
||||
// exactly one second after the time returned by the previous call.
|
||||
func (td *TimeDeterministic) Now() time.Time {
|
||||
td.mu.Lock()
|
||||
if td.zeroTime.IsZero() {
|
||||
td.zeroTime = time.Now()
|
||||
}
|
||||
offset := td.counter
|
||||
td.counter += time.Second
|
||||
res := td.zeroTime.Add(offset)
|
||||
td.mu.Unlock()
|
||||
return res
|
||||
}
|
22
internal/testingx/time_test.go
Normal file
22
internal/testingx/time_test.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package testingx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTimeDeterministic(t *testing.T) {
|
||||
td := &TimeDeterministic{}
|
||||
t0 := td.Now()
|
||||
if !t0.Equal(td.zeroTime) {
|
||||
t.Fatal("invalid t0 value")
|
||||
}
|
||||
t1 := td.Now()
|
||||
if t1.Sub(t0) != time.Second {
|
||||
t.Fatal("invalid t1 value")
|
||||
}
|
||||
t2 := td.Now()
|
||||
if t2.Sub(t1) != time.Second {
|
||||
t.Fatal("invalid t2 value")
|
||||
}
|
||||
}
|
|
@ -5,4 +5,8 @@
|
|||
// resulting connection). When done, we will extract the trace containing
|
||||
// all the events that occurred from the saver and process it to determine
|
||||
// what happened during the measurement itself.
|
||||
//
|
||||
// This package is now frozen. Please, use measurexlite for new code. See
|
||||
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||
// for details about this.
|
||||
package tracex
|
||||
|
|
Loading…
Reference in New Issue
Block a user