feat: tlsping and tcpping using step-by-step (#815)
## Checklist - [x] I have read the [contribution guidelines](https://github.com/ooni/probe-cli/blob/master/CONTRIBUTING.md) - [x] reference issue for this pull request: https://github.com/ooni/probe/issues/2158 - [x] if you changed anything related how experiments work and you need to reflect these changes in the ooni/spec repository, please link to the related ooni/spec pull request: https://github.com/ooni/spec/pull/250 ## Description This diff refactors the codebase to reimplement tlsping and tcpping to use the step-by-step measurements style. See docs/design/dd-003-step-by-step.md for more information on the step-by-step measurement style.
This commit is contained in:
parent
5371c7f486
commit
5ebdeb56ca
|
@ -874,7 +874,7 @@ mechanisms if the context is too bad):
|
||||||
|
|
||||||
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
||||||
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||||
qtype uint16, out chan\<- \*parallelResolverResult) {
|
qtype uint16, out chan<- *parallelResolverResult) {
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||||
+ started := time.Now()
|
+ started := time.Now()
|
||||||
|
@ -1329,7 +1329,7 @@ const webDomain = "web.telegram.org"
|
||||||
// This method does not return any value and writes results directly inside
|
// This method does not return any value and writes results directly inside
|
||||||
// the test keys, which have thread safe methods for that.
|
// the test keys, which have thread safe methods for that.
|
||||||
func (mx *Measurer) measureWebEndpointHTTPS(ctx context.Context, wg *sync.WaitGroup,
|
func (mx *Measurer) measureWebEndpointHTTPS(ctx context.Context, wg *sync.WaitGroup,
|
||||||
logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
|
logger model.Logger, zeroTime time.Time, tk *TestKeys, address string) {
|
||||||
// 0. setup
|
// 0. setup
|
||||||
const webTimeout = 7 * time.Second
|
const webTimeout = 7 * time.Second
|
||||||
ctx, cancel := context.WithTimeout(ctx, webTimeout)
|
ctx, cancel := context.WithTimeout(ctx, webTimeout)
|
||||||
|
@ -1344,8 +1344,8 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
|
||||||
// 1. establish a TCP connection with the endpoint
|
// 1. establish a TCP connection with the endpoint
|
||||||
|
|
||||||
// dialer := nextlite.NewDialerwithoutResolver(logger) // --- (removed line)
|
// dialer := nextlite.NewDialerwithoutResolver(logger) // --- (removed line)
|
||||||
trace := measurexlite.NewTrace(index, logger, zeroTime) // +++ (added line)
|
trace := measurexlite.NewTrace(index, zeroTime) // +++ (added line)
|
||||||
dialer := trace.NewDialerWithoutResolver() // +++ (...)
|
dialer := trace.NewDialerWithoutResolver(logger) // +++ (...)
|
||||||
defer tk.addTCPConnectResults(trace.TCPConnectResults()) // +++ (...)
|
defer tk.addTCPConnectResults(trace.TCPConnectResults()) // +++ (...)
|
||||||
|
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", endpoint)
|
conn, err := dialer.DialContext(ctx, "tcp", endpoint)
|
||||||
|
@ -1366,7 +1366,7 @@ logger model.Logger, zeroTime time.Time, tk \*TestKeys, address string) {
|
||||||
// thx := netxlite.NewTLSHandshakerStdlib(logger) // ---
|
// thx := netxlite.NewTLSHandshakerStdlib(logger) // ---
|
||||||
conn = trace.WrapConn(conn) // +++
|
conn = trace.WrapConn(conn) // +++
|
||||||
defer tk.addNetworkEvents(trace.NetworkEvents()) // +++
|
defer tk.addNetworkEvents(trace.NetworkEvents()) // +++
|
||||||
thx := trace.NewTLSHandshakerStdlib() // +++
|
thx := trace.NewTLSHandshakerStdlib(logger) // +++
|
||||||
defer tk.addTLSHandshakeResult(trace.TLSHandshakeResults()) // +++
|
defer tk.addTLSHandshakeResult(trace.TLSHandshakeResults()) // +++
|
||||||
|
|
||||||
config := &tls.Config{
|
config := &tls.Config{
|
||||||
|
@ -1498,23 +1498,22 @@ type Trace struct { /* ... */ }
|
||||||
|
|
||||||
var _ model.Trace = &Trace{}
|
var _ model.Trace = &Trace{}
|
||||||
|
|
||||||
func NewTrace(index int64, logger model.Logger, zeroTime time.Time) *Trace {
|
func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||||
const (
|
const (
|
||||||
tlsHandshakeBuffer = 16,
|
tlsHandshakeBuffer = 16,
|
||||||
// ...
|
// ...
|
||||||
)
|
)
|
||||||
return &Trace{
|
return &Trace{
|
||||||
Index: index,
|
Index: index,
|
||||||
Logger: logger,
|
|
||||||
TLS: make(chan *model.ArchivalTLSOrQUICHandshakeResult, tlsHandshakeBuffer),
|
TLS: make(chan *model.ArchivalTLSOrQUICHandshakeResult, tlsHandshakeBuffer),
|
||||||
ZeroTime: zeroTime,
|
ZeroTime: zeroTime,
|
||||||
// ...
|
// ...
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Trace) NewTLSHandshakerStdlib() model.TLSHandshaker {
|
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||||
return &tlsHandshakerTrace{
|
return &tlsHandshakerTrace{
|
||||||
TLSHandshaker: netxlite.NewTLSHandshakerStdlib(tx.Logger),
|
TLSHandshaker: netxlite.NewTLSHandshakerStdlib(dl),
|
||||||
Trace: tx,
|
Trace: tx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1524,8 +1523,8 @@ type tlsHandshakerTrace { /* ... */ }
|
||||||
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
|
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
|
||||||
|
|
||||||
func (thx *tlsHandshakerTrace) Handshake(ctx context.Context,
|
func (thx *tlsHandshakerTrace) Handshake(ctx context.Context,
|
||||||
conn net.Conn, config \*tls.Config) (net.Conn, tls.ConnectionState, error) {
|
conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
ctx = netxlite.ContextWithTrace(ctx, thx.Trace) // <- here we setup the context magic
|
ctx = netxlite.ContextWithTraceOrDefault(ctx, thx.Trace) // <- here we setup the context magic
|
||||||
return thx.TLSHandshaker.Handshake(ctx, conn, config)
|
return thx.TLSHandshaker.Handshake(ctx, conn, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1540,7 +1539,7 @@ func (tx *Trace) OnTLSHandshake(started time.Time, remoteAddr string,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Trace) TLSHandshakeResults() (out []\*model.ArchivalTLSOrQUICHandshakeResult) {
|
func (tx *Trace) TLSHandshakeResults() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case ev := <-tx.TLS:
|
case ev := <-tx.TLS:
|
||||||
|
|
|
@ -10,13 +10,13 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/measurex"
|
"github.com/ooni/probe-cli/v3/internal/measurexlite"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testName = "tcpping"
|
testName = "tcpping"
|
||||||
testVersion = "0.1.0"
|
testVersion = "0.2.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config contains the experiment configuration.
|
// Config contains the experiment configuration.
|
||||||
|
@ -49,7 +49,7 @@ type TestKeys struct {
|
||||||
|
|
||||||
// SinglePing contains the results of a single ping.
|
// SinglePing contains the results of a single ping.
|
||||||
type SinglePing struct {
|
type SinglePing struct {
|
||||||
TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"`
|
TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measurer performs the measurement.
|
// Measurer performs the measurement.
|
||||||
|
@ -103,42 +103,47 @@ func (m *Measurer) Run(
|
||||||
}
|
}
|
||||||
tk := new(TestKeys)
|
tk := new(TestKeys)
|
||||||
measurement.TestKeys = tk
|
measurement.TestKeys = tk
|
||||||
out := make(chan *measurex.EndpointMeasurement)
|
out := make(chan *SinglePing)
|
||||||
mxmx := measurex.NewMeasurerWithDefaultSettings()
|
go m.tcpPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out)
|
||||||
go m.tcpPingLoop(ctx, mxmx, parsed.Host, out)
|
|
||||||
for len(tk.Pings) < int(m.config.repetitions()) {
|
for len(tk.Pings) < int(m.config.repetitions()) {
|
||||||
meas := <-out
|
tk.Pings = append(tk.Pings, <-out)
|
||||||
tk.Pings = append(tk.Pings, &SinglePing{
|
|
||||||
TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
return nil // return nil so we always submit the measurement
|
return nil // return nil so we always submit the measurement
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpPingLoop sends all the ping requests and emits the results onto the out channel.
|
// tcpPingLoop sends all the ping requests and emits the results onto the out channel.
|
||||||
func (m *Measurer) tcpPingLoop(ctx context.Context, mxmx *measurex.Measurer,
|
func (m *Measurer) tcpPingLoop(ctx context.Context, zeroTime time.Time,
|
||||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
logger model.Logger, address string, out chan<- *SinglePing) {
|
||||||
ticker := time.NewTicker(m.config.delay())
|
ticker := time.NewTicker(m.config.delay())
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for i := int64(0); i < m.config.repetitions(); i++ {
|
for i := int64(0); i < m.config.repetitions(); i++ {
|
||||||
go m.tcpPingAsync(ctx, mxmx, address, out)
|
go m.tcpPingAsync(ctx, i, zeroTime, logger, address, out)
|
||||||
<-ticker.C
|
<-ticker.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpPingAsync performs a TCP ping and emits the result onto the out channel.
|
// tcpPingAsync performs a TCP ping and emits the result onto the out channel.
|
||||||
func (m *Measurer) tcpPingAsync(ctx context.Context, mxmx *measurex.Measurer,
|
func (m *Measurer) tcpPingAsync(ctx context.Context, index int64,
|
||||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) {
|
||||||
out <- m.tcpConnect(ctx, mxmx, address)
|
out <- m.tcpConnect(ctx, index, zeroTime, logger, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// tcpConnect performs a TCP connect and returns the result to the caller.
|
// tcpConnect performs a TCP connect and returns the result to the caller.
|
||||||
func (m *Measurer) tcpConnect(ctx context.Context, mxmx *measurex.Measurer,
|
func (m *Measurer) tcpConnect(ctx context.Context, index int64,
|
||||||
address string) *measurex.EndpointMeasurement {
|
zeroTime time.Time, logger model.Logger, address string) *SinglePing {
|
||||||
// TODO(bassosimone): make the timeout user-configurable
|
// TODO(bassosimone): make the timeout user-configurable
|
||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return mxmx.TCPConnect(ctx, address)
|
trace := measurexlite.NewTrace(index, zeroTime)
|
||||||
|
dialer := trace.NewDialerWithoutResolver(logger)
|
||||||
|
ol := measurexlite.NewOperationLogger(logger, "TCPPing #%d %s", index, address)
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||||
|
ol.Stop(err)
|
||||||
|
measurexlite.MaybeClose(conn)
|
||||||
|
sp := &SinglePing{
|
||||||
|
TCPConnect: <-trace.TCPConnect,
|
||||||
|
}
|
||||||
|
return sp
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewExperimentMeasurer creates a new ExperimentMeasurer.
|
// NewExperimentMeasurer creates a new ExperimentMeasurer.
|
||||||
|
|
|
@ -40,14 +40,16 @@ func TestMeasurer_run(t *testing.T) {
|
||||||
if m.ExperimentName() != "tcpping" {
|
if m.ExperimentName() != "tcpping" {
|
||||||
t.Fatal("invalid experiment name")
|
t.Fatal("invalid experiment name")
|
||||||
}
|
}
|
||||||
if m.ExperimentVersion() != "0.1.0" {
|
if m.ExperimentVersion() != "0.2.0" {
|
||||||
t.Fatal("invalid experiment version")
|
t.Fatal("invalid experiment version")
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
meas := &model.Measurement{
|
meas := &model.Measurement{
|
||||||
Input: model.MeasurementTarget(input),
|
Input: model.MeasurementTarget(input),
|
||||||
}
|
}
|
||||||
sess := &mockable.Session{}
|
sess := &mockable.Session{
|
||||||
|
MockableLogger: model.DiscardLogger,
|
||||||
|
}
|
||||||
callbacks := model.NewPrinterCallbacks(model.DiscardLogger)
|
callbacks := model.NewPrinterCallbacks(model.DiscardLogger)
|
||||||
err := m.Run(ctx, sess, meas, callbacks)
|
err := m.Run(ctx, sess, meas, callbacks)
|
||||||
return meas, m, err
|
return meas, m, err
|
||||||
|
|
|
@ -13,14 +13,14 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/measurex"
|
"github.com/ooni/probe-cli/v3/internal/measurexlite"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testName = "tlsping"
|
testName = "tlsping"
|
||||||
testVersion = "0.1.0"
|
testVersion = "0.2.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config contains the experiment configuration.
|
// Config contains the experiment configuration.
|
||||||
|
@ -77,9 +77,9 @@ type TestKeys struct {
|
||||||
|
|
||||||
// SinglePing contains the results of a single ping.
|
// SinglePing contains the results of a single ping.
|
||||||
type SinglePing struct {
|
type SinglePing struct {
|
||||||
NetworkEvents []*measurex.ArchivalNetworkEvent `json:"network_events"`
|
NetworkEvents []*model.ArchivalNetworkEvent `json:"network_events"`
|
||||||
TCPConnect []*measurex.ArchivalTCPConnect `json:"tcp_connect"`
|
TCPConnect *model.ArchivalTCPConnectResult `json:"tcp_connect"`
|
||||||
TLSHandshakes []*measurex.ArchivalQUICTLSHandshakeEvent `json:"tls_handshakes"`
|
TLSHandshake *model.ArchivalTLSOrQUICHandshakeResult `json:"tls_handshake"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measurer performs the measurement.
|
// Measurer performs the measurement.
|
||||||
|
@ -133,49 +133,67 @@ func (m *Measurer) Run(
|
||||||
}
|
}
|
||||||
tk := new(TestKeys)
|
tk := new(TestKeys)
|
||||||
measurement.TestKeys = tk
|
measurement.TestKeys = tk
|
||||||
out := make(chan *measurex.EndpointMeasurement)
|
out := make(chan *SinglePing)
|
||||||
mxmx := measurex.NewMeasurerWithDefaultSettings()
|
go m.tlsPingLoop(ctx, measurement.MeasurementStartTimeSaved, sess.Logger(), parsed.Host, out)
|
||||||
go m.tlsPingLoop(ctx, mxmx, parsed.Host, out)
|
|
||||||
for len(tk.Pings) < int(m.config.repetitions()) {
|
for len(tk.Pings) < int(m.config.repetitions()) {
|
||||||
meas := <-out
|
tk.Pings = append(tk.Pings, <-out)
|
||||||
tk.Pings = append(tk.Pings, &SinglePing{
|
|
||||||
NetworkEvents: measurex.NewArchivalNetworkEventList(meas.ReadWrite),
|
|
||||||
TCPConnect: measurex.NewArchivalTCPConnectList(meas.Connect),
|
|
||||||
TLSHandshakes: measurex.NewArchivalQUICTLSHandshakeEventList(meas.TLSHandshake),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
return nil // return nil so we always submit the measurement
|
return nil // return nil so we always submit the measurement
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsPingLoop sends all the ping requests and emits the results onto the out channel.
|
// tlsPingLoop sends all the ping requests and emits the results onto the out channel.
|
||||||
func (m *Measurer) tlsPingLoop(ctx context.Context, mxmx *measurex.Measurer,
|
func (m *Measurer) tlsPingLoop(ctx context.Context, zeroTime time.Time,
|
||||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
logger model.Logger, address string, out chan<- *SinglePing) {
|
||||||
ticker := time.NewTicker(m.config.delay())
|
ticker := time.NewTicker(m.config.delay())
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for i := int64(0); i < m.config.repetitions(); i++ {
|
for i := int64(0); i < m.config.repetitions(); i++ {
|
||||||
go m.tlsPingAsync(ctx, mxmx, address, out)
|
go m.tlsPingAsync(ctx, i, zeroTime, logger, address, out)
|
||||||
<-ticker.C
|
<-ticker.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsPingAsync performs a TLS ping and emits the result onto the out channel.
|
// tlsPingAsync performs a TLS ping and emits the result onto the out channel.
|
||||||
func (m *Measurer) tlsPingAsync(ctx context.Context, mxmx *measurex.Measurer,
|
func (m *Measurer) tlsPingAsync(ctx context.Context, index int64,
|
||||||
address string, out chan<- *measurex.EndpointMeasurement) {
|
zeroTime time.Time, logger model.Logger, address string, out chan<- *SinglePing) {
|
||||||
out <- m.tlsConnectAndHandshake(ctx, mxmx, address)
|
out <- m.tlsConnectAndHandshake(ctx, index, zeroTime, logger, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsConnectAndHandshake performs a TCP connect followed by a TLS handshake
|
// tlsConnectAndHandshake performs a TCP connect followed by a TLS handshake
|
||||||
// and returns the results of these operations to the caller.
|
// and returns the results of these operations to the caller.
|
||||||
func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, mxmx *measurex.Measurer,
|
func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, index int64,
|
||||||
address string) *measurex.EndpointMeasurement {
|
zeroTime time.Time, logger model.Logger, address string) *SinglePing {
|
||||||
// TODO(bassosimone): make the timeout user-configurable
|
// TODO(bassosimone): make the timeout user-configurable
|
||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return mxmx.TLSConnectAndHandshake(ctx, address, &tls.Config{
|
sp := &SinglePing{
|
||||||
NextProtos: strings.Split(m.config.alpn(), " "),
|
NetworkEvents: []*model.ArchivalNetworkEvent{},
|
||||||
|
TCPConnect: nil,
|
||||||
|
TLSHandshake: nil,
|
||||||
|
}
|
||||||
|
trace := measurexlite.NewTrace(index, zeroTime)
|
||||||
|
dialer := trace.NewDialerWithoutResolver(logger)
|
||||||
|
alpn := strings.Split(m.config.alpn(), " ")
|
||||||
|
sni := m.config.sni(address)
|
||||||
|
ol := measurexlite.NewOperationLogger(logger, "TLSPing #%d %s %s %v", index, address, sni, alpn)
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||||
|
sp.TCPConnect = <-trace.TCPConnect
|
||||||
|
if err != nil {
|
||||||
|
ol.Stop(err)
|
||||||
|
return sp
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
conn = trace.WrapNetConn(conn)
|
||||||
|
thx := trace.NewTLSHandshakerStdlib(logger)
|
||||||
|
config := &tls.Config{
|
||||||
|
NextProtos: alpn,
|
||||||
RootCAs: netxlite.NewDefaultCertPool(),
|
RootCAs: netxlite.NewDefaultCertPool(),
|
||||||
ServerName: m.config.sni(address),
|
ServerName: sni,
|
||||||
})
|
}
|
||||||
|
_, _, err = thx.Handshake(ctx, conn, config)
|
||||||
|
ol.Stop(err)
|
||||||
|
sp.TLSHandshake = <-trace.TLSHandshake
|
||||||
|
sp.NetworkEvents = trace.NetworkEvents()
|
||||||
|
return sp
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewExperimentMeasurer creates a new ExperimentMeasurer.
|
// NewExperimentMeasurer creates a new ExperimentMeasurer.
|
||||||
|
|
|
@ -39,7 +39,7 @@ func TestMeasurer_run(t *testing.T) {
|
||||||
const expectedPings = 4
|
const expectedPings = 4
|
||||||
|
|
||||||
// runHelper is an helper function to run this set of tests.
|
// runHelper is an helper function to run this set of tests.
|
||||||
runHelper := func(input string) (*model.Measurement, model.ExperimentMeasurer, error) {
|
runHelper := func(ctx context.Context, input string) (*model.Measurement, model.ExperimentMeasurer, error) {
|
||||||
m := NewExperimentMeasurer(Config{
|
m := NewExperimentMeasurer(Config{
|
||||||
ALPN: "http/1.1",
|
ALPN: "http/1.1",
|
||||||
Delay: 1, // millisecond
|
Delay: 1, // millisecond
|
||||||
|
@ -48,10 +48,9 @@ func TestMeasurer_run(t *testing.T) {
|
||||||
if m.ExperimentName() != "tlsping" {
|
if m.ExperimentName() != "tlsping" {
|
||||||
t.Fatal("invalid experiment name")
|
t.Fatal("invalid experiment name")
|
||||||
}
|
}
|
||||||
if m.ExperimentVersion() != "0.1.0" {
|
if m.ExperimentVersion() != "0.2.0" {
|
||||||
t.Fatal("invalid experiment version")
|
t.Fatal("invalid experiment version")
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
|
||||||
meas := &model.Measurement{
|
meas := &model.Measurement{
|
||||||
Input: model.MeasurementTarget(input),
|
Input: model.MeasurementTarget(input),
|
||||||
}
|
}
|
||||||
|
@ -64,34 +63,34 @@ func TestMeasurer_run(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("with empty input", func(t *testing.T) {
|
t.Run("with empty input", func(t *testing.T) {
|
||||||
_, _, err := runHelper("")
|
_, _, err := runHelper(context.Background(), "")
|
||||||
if !errors.Is(err, errNoInputProvided) {
|
if !errors.Is(err, errNoInputProvided) {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with invalid URL", func(t *testing.T) {
|
t.Run("with invalid URL", func(t *testing.T) {
|
||||||
_, _, err := runHelper("\t")
|
_, _, err := runHelper(context.Background(), "\t")
|
||||||
if !errors.Is(err, errInputIsNotAnURL) {
|
if !errors.Is(err, errInputIsNotAnURL) {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with invalid scheme", func(t *testing.T) {
|
t.Run("with invalid scheme", func(t *testing.T) {
|
||||||
_, _, err := runHelper("https://8.8.8.8:443/")
|
_, _, err := runHelper(context.Background(), "https://8.8.8.8:443/")
|
||||||
if !errors.Is(err, errInvalidScheme) {
|
if !errors.Is(err, errInvalidScheme) {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with missing port", func(t *testing.T) {
|
t.Run("with missing port", func(t *testing.T) {
|
||||||
_, _, err := runHelper("tlshandshake://8.8.8.8")
|
_, _, err := runHelper(context.Background(), "tlshandshake://8.8.8.8")
|
||||||
if !errors.Is(err, errMissingPort) {
|
if !errors.Is(err, errMissingPort) {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with local listener", func(t *testing.T) {
|
t.Run("with local listener and successful outcome", func(t *testing.T) {
|
||||||
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
}))
|
}))
|
||||||
|
@ -101,7 +100,37 @@ func TestMeasurer_run(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
URL.Scheme = "tlshandshake"
|
URL.Scheme = "tlshandshake"
|
||||||
meas, m, err := runHelper(URL.String())
|
meas, m, err := runHelper(context.Background(), URL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tk := meas.TestKeys.(*TestKeys)
|
||||||
|
if len(tk.Pings) != expectedPings {
|
||||||
|
t.Fatal("unexpected number of pings")
|
||||||
|
}
|
||||||
|
ask, err := m.GetSummaryKeys(meas)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("cannot obtain summary")
|
||||||
|
}
|
||||||
|
summary := ask.(SummaryKeys)
|
||||||
|
if summary.IsAnomaly {
|
||||||
|
t.Fatal("expected no anomaly")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with local listener and connect issues", func(t *testing.T) {
|
||||||
|
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}))
|
||||||
|
defer srvr.Close()
|
||||||
|
URL, err := url.Parse(srvr.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
URL.Scheme = "tlshandshake"
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // so we cannot dial any connection
|
||||||
|
meas, m, err := runHelper(ctx, URL.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
// Package urlgetter implements a nettest that fetches a URL.
|
// Package urlgetter implements a nettest that fetches a URL.
|
||||||
//
|
//
|
||||||
// See https://github.com/ooni/spec/blob/master/nettests/ts-027-urlgetter.md.
|
// See https://github.com/ooni/spec/blob/master/nettests/ts-027-urlgetter.md.
|
||||||
|
//
|
||||||
|
// This package is now frozen. Please, use measurexlite for new code. New
|
||||||
|
// network experiments should not depend on this package. Please see
|
||||||
|
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||||
|
// for details about this.
|
||||||
package urlgetter
|
package urlgetter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -46,4 +46,8 @@
|
||||||
//
|
//
|
||||||
// See docs/design/dd-002-nets.md in the probe-cli repository for
|
// See docs/design/dd-002-nets.md in the probe-cli repository for
|
||||||
// the design document describing this package.
|
// the design document describing this package.
|
||||||
|
//
|
||||||
|
// This package is now frozen. Please, use measurexlite for new code. See
|
||||||
|
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||||
|
// for details about this.
|
||||||
package netx
|
package netx
|
||||||
|
|
|
@ -13,10 +13,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ooni/probe-cli/v3/internal/fakefill"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/version"
|
"github.com/ooni/probe-cli/v3/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ func TestAPIClientTemplate(t *testing.T) {
|
||||||
HTTPClient: http.DefaultClient,
|
HTTPClient: http.DefaultClient,
|
||||||
Logger: model.DiscardLogger,
|
Logger: model.DiscardLogger,
|
||||||
}
|
}
|
||||||
ff := &fakefill.Filler{}
|
ff := &testingx.FakeFiller{}
|
||||||
ff.Fill(tmpl)
|
ff.Fill(tmpl)
|
||||||
ac := tmpl.Build()
|
ac := tmpl.Build()
|
||||||
orig := apiClient(*tmpl)
|
orig := apiClient(*tmpl)
|
||||||
|
@ -64,7 +64,7 @@ func TestAPIClientTemplate(t *testing.T) {
|
||||||
HTTPClient: http.DefaultClient,
|
HTTPClient: http.DefaultClient,
|
||||||
Logger: model.DiscardLogger,
|
Logger: model.DiscardLogger,
|
||||||
}
|
}
|
||||||
ff := &fakefill.Filler{}
|
ff := &testingx.FakeFiller{}
|
||||||
ff.Fill(tmpl)
|
ff.Fill(tmpl)
|
||||||
tok := ""
|
tok := ""
|
||||||
ff.Fill(&tok)
|
ff.Fill(&tok)
|
||||||
|
@ -188,7 +188,7 @@ func TestAPIClient(t *testing.T) {
|
||||||
|
|
||||||
t.Run("sets the content-type properly", func(t *testing.T) {
|
t.Run("sets the content-type properly", func(t *testing.T) {
|
||||||
var jsonReq fakeRequest
|
var jsonReq fakeRequest
|
||||||
ff := &fakefill.Filler{}
|
ff := &testingx.FakeFiller{}
|
||||||
ff.Fill(&jsonReq)
|
ff.Fill(&jsonReq)
|
||||||
client := newAPIClient()
|
client := newAPIClient()
|
||||||
req, err := client.newRequestWithJSONBody(
|
req, err := client.newRequestWithJSONBody(
|
||||||
|
|
|
@ -1,2 +1,6 @@
|
||||||
// Package measurex contains measurement extensions.
|
// Package measurex contains measurement extensions.
|
||||||
|
//
|
||||||
|
// This package is now frozen. Please, use measurexlite for new code. See
|
||||||
|
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||||
|
// for details about this.
|
||||||
package measurex
|
package measurex
|
||||||
|
|
103
internal/measurexlite/conn.go
Normal file
103
internal/measurexlite/conn.go
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Conn tracing
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaybeClose is a convenience function for closing a conn only when such a conn isn't nil.
|
||||||
|
func MaybeClose(conn net.Conn) (err error) {
|
||||||
|
if conn != nil {
|
||||||
|
err = conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapNetConn returns a wrapped conn that saves network events into this trace.
|
||||||
|
func (tx *Trace) WrapNetConn(conn net.Conn) net.Conn {
|
||||||
|
return &connTrace{
|
||||||
|
Conn: conn,
|
||||||
|
tx: tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// connTrace is a trace-aware net.Conn.
|
||||||
|
type connTrace struct {
|
||||||
|
// Implementation note: it seems safe to use embedding here because net.Conn
|
||||||
|
// is an interface from the standard library that we don't control
|
||||||
|
net.Conn
|
||||||
|
tx *Trace
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = &connTrace{}
|
||||||
|
|
||||||
|
// Read implements net.Conn.Read and saves network events.
|
||||||
|
func (c *connTrace) Read(b []byte) (int, error) {
|
||||||
|
network := c.RemoteAddr().Network()
|
||||||
|
addr := c.RemoteAddr().String()
|
||||||
|
started := c.tx.TimeSince(c.tx.ZeroTime)
|
||||||
|
count, err := c.Conn.Read(b)
|
||||||
|
finished := c.tx.TimeSince(c.tx.ZeroTime)
|
||||||
|
select {
|
||||||
|
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
|
||||||
|
c.tx.Index, started, netxlite.ReadOperation, network, addr, count, err, finished):
|
||||||
|
default: // buffer is full
|
||||||
|
}
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements net.Conn.Write and saves network events.
|
||||||
|
func (c *connTrace) Write(b []byte) (int, error) {
|
||||||
|
network := c.RemoteAddr().Network()
|
||||||
|
addr := c.RemoteAddr().String()
|
||||||
|
started := c.tx.TimeSince(c.tx.ZeroTime)
|
||||||
|
count, err := c.Conn.Write(b)
|
||||||
|
finished := c.tx.TimeSince(c.tx.ZeroTime)
|
||||||
|
select {
|
||||||
|
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
|
||||||
|
c.tx.Index, started, netxlite.WriteOperation, network, addr, count, err, finished):
|
||||||
|
default: // buffer is full
|
||||||
|
}
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArchivalNetworkEvent creates a new model.ArchivalNetworkEvent.
|
||||||
|
func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, network string,
|
||||||
|
address string, count int, err error, finished time.Duration) *model.ArchivalNetworkEvent {
|
||||||
|
return &model.ArchivalNetworkEvent{
|
||||||
|
Address: address,
|
||||||
|
Failure: tracex.NewFailure(err),
|
||||||
|
NumBytes: int64(count),
|
||||||
|
Operation: operation,
|
||||||
|
Proto: network,
|
||||||
|
T: finished.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAnnotationArchivalNetworkEvent is a simplified NewArchivalNetworkEvent
|
||||||
|
// where we create a simple annotation without attached I/O info.
|
||||||
|
func NewAnnotationArchivalNetworkEvent(
|
||||||
|
index int64, time time.Duration, operation string) *model.ArchivalNetworkEvent {
|
||||||
|
return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkEvents drains the network events buffered inside the NetworkEvent channel.
|
||||||
|
func (tx *Trace) NetworkEvents() (out []*model.ArchivalNetworkEvent) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev := <-tx.NetworkEvent:
|
||||||
|
out = append(out, ev)
|
||||||
|
default:
|
||||||
|
return // done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
243
internal/measurexlite/conn_test.go
Normal file
243
internal/measurexlite/conn_test.go
Normal file
|
@ -0,0 +1,243 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMaybeClose(t *testing.T) {
|
||||||
|
t.Run("with nil conn", func(t *testing.T) {
|
||||||
|
var conn net.Conn = nil
|
||||||
|
MaybeClose(conn)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with nonnil conn", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
conn := &mocks.Conn{
|
||||||
|
MockClose: func() error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := MaybeClose(conn); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapNetConn(t *testing.T) {
|
||||||
|
t.Run("WrapNetConn wraps the conn", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Conn{}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
conn := trace.WrapNetConn(underlying)
|
||||||
|
ct := conn.(*connTrace)
|
||||||
|
if ct.Conn != underlying {
|
||||||
|
t.Fatal("invalid underlying")
|
||||||
|
}
|
||||||
|
if ct.tx != trace {
|
||||||
|
t.Fatal("invalid trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Read saves a trace", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Conn{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now // deterministic time counting
|
||||||
|
conn := trace.WrapNetConn(underlying)
|
||||||
|
const bufsiz = 128
|
||||||
|
buffer := make([]byte, bufsiz)
|
||||||
|
count, err := conn.Read(buffer)
|
||||||
|
if count != bufsiz {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("invalid err")
|
||||||
|
}
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("did not save network events")
|
||||||
|
}
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "1.1.1.1:443",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: bufsiz,
|
||||||
|
Operation: netxlite.ReadOperation,
|
||||||
|
Proto: "tcp",
|
||||||
|
T: 1.0,
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Read discards the event when the buffer is full", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Conn{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
|
||||||
|
conn := trace.WrapNetConn(underlying)
|
||||||
|
const bufsiz = 128
|
||||||
|
buffer := make([]byte, bufsiz)
|
||||||
|
count, err := conn.Read(buffer)
|
||||||
|
if count != bufsiz {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("invalid err")
|
||||||
|
}
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected no network events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Write saves a trace", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Conn{
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now // deterministic time tracking
|
||||||
|
conn := trace.WrapNetConn(underlying)
|
||||||
|
const bufsiz = 128
|
||||||
|
buffer := make([]byte, bufsiz)
|
||||||
|
count, err := conn.Write(buffer)
|
||||||
|
if count != bufsiz {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("invalid err")
|
||||||
|
}
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("did not save network events")
|
||||||
|
}
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "1.1.1.1:443",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: bufsiz,
|
||||||
|
Operation: netxlite.WriteOperation,
|
||||||
|
Proto: "tcp",
|
||||||
|
T: 1.0,
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Write discards the event when the buffer is full", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Conn{
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
|
||||||
|
conn := trace.WrapNetConn(underlying)
|
||||||
|
const bufsiz = 128
|
||||||
|
buffer := make([]byte, bufsiz)
|
||||||
|
count, err := conn.Write(buffer)
|
||||||
|
if count != bufsiz {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("invalid err")
|
||||||
|
}
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected no network events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAnnotationArchivalNetworkEvent(t *testing.T) {
|
||||||
|
var (
|
||||||
|
index int64 = 3
|
||||||
|
duration = 250 * time.Millisecond
|
||||||
|
operation = "tls_handshake_start"
|
||||||
|
)
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: 0,
|
||||||
|
Operation: operation,
|
||||||
|
Proto: "",
|
||||||
|
T: duration.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := NewAnnotationArchivalNetworkEvent(
|
||||||
|
index, duration, operation,
|
||||||
|
)
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
}
|
121
internal/measurexlite/dialer.go
Normal file
121
internal/measurexlite/dialer.go
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Dialer tracing
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver
|
||||||
|
// except that it returns a model.Dialer that uses this trace.
|
||||||
|
//
|
||||||
|
// Note: unlike code in netx or measurex, this factory DOES NOT return you a
|
||||||
|
// dialer that also performs wrapping of a net.Conn in case of success. If you
|
||||||
|
// want to wrap the conn, you need to wrap it explicitly using WrapNetConn.
|
||||||
|
func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
|
||||||
|
return &dialerTrace{
|
||||||
|
d: tx.newDialerWithoutResolver(dl),
|
||||||
|
tx: tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialerTrace is a trace-aware model.Dialer.
|
||||||
|
type dialerTrace struct {
|
||||||
|
d model.Dialer
|
||||||
|
tx *Trace
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.Dialer = &dialerTrace{}
|
||||||
|
|
||||||
|
// DialContext implements model.Dialer.DialContext.
|
||||||
|
func (d *dialerTrace) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return d.d.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections implements model.Dialer.CloseIdleConnections.
|
||||||
|
func (d *dialerTrace) CloseIdleConnections() {
|
||||||
|
d.d.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnTCPConnectDone implements model.Trace.OnTCPConnectDone.
|
||||||
|
func (tx *Trace) OnConnectDone(
|
||||||
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
select {
|
||||||
|
case tx.TCPConnect <- NewArchivalTCPConnectResult(
|
||||||
|
tx.Index,
|
||||||
|
started.Sub(tx.ZeroTime),
|
||||||
|
remoteAddr,
|
||||||
|
err,
|
||||||
|
finished.Sub(tx.ZeroTime),
|
||||||
|
):
|
||||||
|
default: // buffer is full
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// ignore UDP connect attempts because they cannot fail
|
||||||
|
// in interesting ways that make sense for censorship
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArchivalTCPConnectResult generates a model.ArchivalTCPConnectResult
|
||||||
|
// from the available information right after connect returns.
|
||||||
|
func NewArchivalTCPConnectResult(index int64, started time.Duration, address string,
|
||||||
|
err error, finished time.Duration) *model.ArchivalTCPConnectResult {
|
||||||
|
ip, port := archivalSplitHostPort(address)
|
||||||
|
return &model.ArchivalTCPConnectResult{
|
||||||
|
IP: ip,
|
||||||
|
Port: archivalPortToString(port),
|
||||||
|
Status: model.ArchivalTCPConnectStatus{
|
||||||
|
Blocked: nil,
|
||||||
|
Failure: tracex.NewFailure(err),
|
||||||
|
Success: err == nil,
|
||||||
|
},
|
||||||
|
T: finished.Seconds(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// archivalSplitHostPort is like net.SplitHostPort but does not return an error. This
|
||||||
|
// function returns two empty strings in case of any failure.
|
||||||
|
func archivalSplitHostPort(endpoint string) (string, string) {
|
||||||
|
addr, port, err := net.SplitHostPort(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("BUG: archivalSplitHostPort: invalid endpoint: %s", endpoint)
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
return addr, port
|
||||||
|
}
|
||||||
|
|
||||||
|
// archivalPortToString is like strconv.Atoi but does not return an error. This
|
||||||
|
// function returns a zero port number in case of any failure.
|
||||||
|
func archivalPortToString(sport string) int {
|
||||||
|
port, err := strconv.Atoi(sport)
|
||||||
|
if err != nil || port < 0 || port > math.MaxUint16 {
|
||||||
|
log.Printf("BUG: archivalStrconvAtoi: invalid port: %s", sport)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return port
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPConnects drains the network events buffered inside the TCPConnect channel.
|
||||||
|
func (tx *Trace) TCPConnects() (out []*model.ArchivalTCPConnectResult) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev := <-tx.TCPConnect:
|
||||||
|
out = append(out, ev)
|
||||||
|
default:
|
||||||
|
return // done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
211
internal/measurexlite/dialer_test.go
Normal file
211
internal/measurexlite/dialer_test.go
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewDialerWithoutResolver(t *testing.T) {
|
||||||
|
t.Run("NewDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) {
|
||||||
|
underlying := &mocks.Dialer{}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
|
||||||
|
return underlying
|
||||||
|
}
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
dt := dialer.(*dialerTrace)
|
||||||
|
if dt.d != underlying {
|
||||||
|
t.Fatal("invalid dialer")
|
||||||
|
}
|
||||||
|
if dt.tx != trace {
|
||||||
|
t.Fatal("invalid trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialContext calls the underlying dialer with context-based tracing", func(t *testing.T) {
|
||||||
|
expectedErr := errors.New("mocked err")
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
var hasCorrectTrace bool
|
||||||
|
underlying := &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
gotTrace := netxlite.ContextTraceOrDefault(ctx)
|
||||||
|
hasCorrectTrace = (gotTrace == trace)
|
||||||
|
return nil, expectedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
|
||||||
|
return underlying
|
||||||
|
}
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||||
|
if !errors.Is(err, expectedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
if !hasCorrectTrace {
|
||||||
|
t.Fatal("does not have the correct trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnection is correctly forwarded", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
underlying := &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
|
||||||
|
return underlying
|
||||||
|
}
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
dialer.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialContext saves into the trace", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now // deterministic time tracking
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we cancel immediately so connect is ~instantaneous
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||||
|
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
events := trace.TCPConnects()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("expected to see single TCPConnect event")
|
||||||
|
}
|
||||||
|
expectedFailure := netxlite.FailureInterrupted
|
||||||
|
expect := &model.ArchivalTCPConnectResult{
|
||||||
|
IP: "1.1.1.1",
|
||||||
|
Port: 443,
|
||||||
|
Status: model.ArchivalTCPConnectStatus{
|
||||||
|
Blocked: nil,
|
||||||
|
Failure: &expectedFailure,
|
||||||
|
Success: false,
|
||||||
|
},
|
||||||
|
T: time.Second.Seconds(),
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialContext discards events when buffer is full", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TCPConnect = make(chan *model.ArchivalTCPConnectResult) // no buffer
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we cancel immediately so connect is ~instantaneous
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||||
|
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
events := trace.TCPConnects()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected to see no TCPConnect events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialContext ignores UDP connect attempts", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we cancel immediately so connect is ~instantaneous
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", "1.1.1.1:443")
|
||||||
|
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
events := trace.TCPConnects()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected to see no TCPConnect events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialContext uses a dialer without a resolver", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we cancel immediately so connect is ~instantaneous
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", "dns.google:443") // domain
|
||||||
|
if !errors.Is(err, netxlite.ErrNoResolver) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
events := trace.TCPConnects()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected to see no TCPConnect events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArchivalSplitHostPort(t *testing.T) {
|
||||||
|
addr, port := archivalSplitHostPort("1.1.1.1") // missing port
|
||||||
|
if addr != "" {
|
||||||
|
t.Fatal("invalid addr", addr)
|
||||||
|
}
|
||||||
|
if port != "" {
|
||||||
|
t.Fatal("invalid port", port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArchivalPortToString(t *testing.T) {
|
||||||
|
t.Run("with invalid number", func(t *testing.T) {
|
||||||
|
port := archivalPortToString("antani")
|
||||||
|
if port != 0 {
|
||||||
|
t.Fatal("invalid port")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with negative number", func(t *testing.T) {
|
||||||
|
port := archivalPortToString("-1")
|
||||||
|
if port != 0 {
|
||||||
|
t.Fatal("invalid port")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with too-large positive number", func(t *testing.T) {
|
||||||
|
port := archivalPortToString(strconv.Itoa(math.MaxUint16 + 1))
|
||||||
|
if port != 0 {
|
||||||
|
t.Fatal("invalid port")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
10
internal/measurexlite/doc.go
Normal file
10
internal/measurexlite/doc.go
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
// Package measurexlite contains measurement extensions.
|
||||||
|
//
|
||||||
|
// See docs/design/dd-003-step-by-step.md in the ooni/probe-cli
|
||||||
|
// repository for the design document.
|
||||||
|
//
|
||||||
|
// This implementation features a Trace that saves events in
|
||||||
|
// buffered channels as proposed by df-003-step-by-step.md. We
|
||||||
|
// have reasonable default buffers for channels. But, if you
|
||||||
|
// are not draining them, eventually we stop collecting events.
|
||||||
|
package measurexlite
|
16
internal/measurexlite/logger.go
Normal file
16
internal/measurexlite/logger.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Logging support
|
||||||
|
//
|
||||||
|
|
||||||
|
import "github.com/ooni/probe-cli/v3/internal/measurex"
|
||||||
|
|
||||||
|
// TODO(bassosimone): we should eventually remove measurex and
|
||||||
|
// move the logging code from measurex to this package.
|
||||||
|
|
||||||
|
// NewOperationLogger is an alias for measurex.NewOperationLogger.
|
||||||
|
var NewOperationLogger = measurex.NewOperationLogger
|
||||||
|
|
||||||
|
// OperationLogger is an alias for measurex.OperationLogger.
|
||||||
|
type OperationLogger = measurex.OperationLogger
|
144
internal/measurexlite/tls.go
Normal file
144
internal/measurexlite/tls.go
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// TLS tracing
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib
|
||||||
|
// except that it returns a model.TLSHandshaker that uses this trace.
|
||||||
|
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||||
|
return &tlsHandshakerTrace{
|
||||||
|
thx: tx.newTLSHandshakerStdlib(dl),
|
||||||
|
tx: tx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tlsHandshakerTrace is a trace-aware TLS handshaker.
|
||||||
|
type tlsHandshakerTrace struct {
|
||||||
|
thx model.TLSHandshaker
|
||||||
|
tx *Trace
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
|
||||||
|
|
||||||
|
// Handshake implements model.TLSHandshaker.Handshake.
|
||||||
|
func (thx *tlsHandshakerTrace) Handshake(
|
||||||
|
ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return thx.thx.Handshake(netxlite.ContextWithTrace(ctx, thx.tx), conn, tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
|
||||||
|
func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||||
|
t := now.Sub(tx.ZeroTime)
|
||||||
|
select {
|
||||||
|
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_start"):
|
||||||
|
default: // buffer is full
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
|
||||||
|
func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||||
|
state tls.ConnectionState, err error, finished time.Time) {
|
||||||
|
t := finished.Sub(tx.ZeroTime)
|
||||||
|
select {
|
||||||
|
case tx.TLSHandshake <- NewArchivalTLSOrQUICHandshakeResult(
|
||||||
|
tx.Index,
|
||||||
|
started.Sub(tx.ZeroTime),
|
||||||
|
remoteAddr,
|
||||||
|
config,
|
||||||
|
state,
|
||||||
|
err,
|
||||||
|
t,
|
||||||
|
):
|
||||||
|
default: // buffer is full
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_done"):
|
||||||
|
default: // buffer is full
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArchivalTLSOrQUICHandshakeResult generates a model.ArchivalTLSOrQUICHandshakeResult
|
||||||
|
// from the available information right after the TLS handshake returns.
|
||||||
|
func NewArchivalTLSOrQUICHandshakeResult(
|
||||||
|
index int64, started time.Duration, address string, config *tls.Config,
|
||||||
|
state tls.ConnectionState, err error, finished time.Duration) *model.ArchivalTLSOrQUICHandshakeResult {
|
||||||
|
return &model.ArchivalTLSOrQUICHandshakeResult{
|
||||||
|
Address: address,
|
||||||
|
CipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
|
||||||
|
Failure: tracex.NewFailure(err),
|
||||||
|
NegotiatedProtocol: state.NegotiatedProtocol,
|
||||||
|
NoTLSVerify: config.InsecureSkipVerify,
|
||||||
|
PeerCertificates: TLSPeerCerts(state, err),
|
||||||
|
ServerName: config.ServerName,
|
||||||
|
T: finished.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
TLSVersion: netxlite.TLSVersionString(state.Version),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newArchivalBinaryData is a factory that adapts binary data to the
|
||||||
|
// model.ArchivalMaybeBinaryData format.
|
||||||
|
func newArchivalBinaryData(data []byte) model.ArchivalMaybeBinaryData {
|
||||||
|
// TODO(https://github.com/ooni/probe/issues/2165): we should actually extend the
|
||||||
|
// model's archival data format to have a pure-binary-data type for the cases in which
|
||||||
|
// we know in advance we're dealing with binary data.
|
||||||
|
return model.ArchivalMaybeBinaryData{
|
||||||
|
Value: string(data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSPeerCerts extracts the certificates either from the list of certificates
|
||||||
|
// in the connection state or from the error that occurred.
|
||||||
|
func TLSPeerCerts(
|
||||||
|
state tls.ConnectionState, err error) (out []model.ArchivalMaybeBinaryData) {
|
||||||
|
out = []model.ArchivalMaybeBinaryData{}
|
||||||
|
var x509HostnameError x509.HostnameError
|
||||||
|
if errors.As(err, &x509HostnameError) {
|
||||||
|
// Test case: https://wrong.host.badssl.com/
|
||||||
|
out = append(out, newArchivalBinaryData(x509HostnameError.Certificate.Raw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var x509UnknownAuthorityError x509.UnknownAuthorityError
|
||||||
|
if errors.As(err, &x509UnknownAuthorityError) {
|
||||||
|
// Test case: https://self-signed.badssl.com/. This error has
|
||||||
|
// never been among the ones returned by MK.
|
||||||
|
out = append(out, newArchivalBinaryData(x509UnknownAuthorityError.Cert.Raw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var x509CertificateInvalidError x509.CertificateInvalidError
|
||||||
|
if errors.As(err, &x509CertificateInvalidError) {
|
||||||
|
// Test case: https://expired.badssl.com/
|
||||||
|
out = append(out, newArchivalBinaryData(x509CertificateInvalidError.Cert.Raw))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, cert := range state.PeerCertificates {
|
||||||
|
out = append(out, newArchivalBinaryData(cert.Raw))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSHandshakes drains the network events buffered inside the TLSHandshake channel.
|
||||||
|
func (tx *Trace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case ev := <-tx.TLSHandshake:
|
||||||
|
out = append(out, ev)
|
||||||
|
default:
|
||||||
|
return // done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
418
internal/measurexlite/tls_test.go
Normal file
418
internal/measurexlite/tls_test.go
Normal file
|
@ -0,0 +1,418 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||||
|
t.Run("NewTLSHandshakerStdlib creates a wrapped TLSHandshaker", func(t *testing.T) {
|
||||||
|
underlying := &mocks.TLSHandshaker{}
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
|
||||||
|
return underlying
|
||||||
|
}
|
||||||
|
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
thxt := thx.(*tlsHandshakerTrace)
|
||||||
|
if thxt.thx != underlying {
|
||||||
|
t.Fatal("invalid TLS handshaker")
|
||||||
|
}
|
||||||
|
if thxt.tx != trace {
|
||||||
|
t.Fatal("invalid trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Handshake calls the underlying dialer with context-based tracing", func(t *testing.T) {
|
||||||
|
expectedErr := errors.New("mocked err")
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
var hasCorrectTrace bool
|
||||||
|
underlying := &mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
gotTrace := netxlite.ContextTraceOrDefault(ctx)
|
||||||
|
hasCorrectTrace = (gotTrace == trace)
|
||||||
|
return nil, tls.ConnectionState{}, expectedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
|
||||||
|
return underlying
|
||||||
|
}
|
||||||
|
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||||
|
if !errors.Is(err, expectedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("expected zero-value state")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
if !hasCorrectTrace {
|
||||||
|
t.Fatal("does not have the correct trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Handshake saves into the trace", func(t *testing.T) {
|
||||||
|
mockedErr := errors.New("mocked")
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now // deterministic timing
|
||||||
|
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we cancel immediately so connect is ~instantaneous
|
||||||
|
tcpConn := &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, mockedErr
|
||||||
|
},
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: "dns.cloudflare.com",
|
||||||
|
}
|
||||||
|
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
|
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("expected zero-value state")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TLSHandshake events", func(t *testing.T) {
|
||||||
|
events := trace.TLSHandshakes()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("expected to see single TLSHandshake event")
|
||||||
|
}
|
||||||
|
expectedFailure := netxlite.FailureInterrupted
|
||||||
|
expect := &model.ArchivalTLSOrQUICHandshakeResult{
|
||||||
|
Address: "1.1.1.1:443",
|
||||||
|
CipherSuite: "",
|
||||||
|
Failure: &expectedFailure,
|
||||||
|
NegotiatedProtocol: "",
|
||||||
|
NoTLSVerify: true,
|
||||||
|
PeerCertificates: []model.ArchivalMaybeBinaryData{},
|
||||||
|
ServerName: "dns.cloudflare.com",
|
||||||
|
T: time.Second.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
TLSVersion: "",
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network events", func(t *testing.T) {
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatal("expected to see two Network events")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("tls_handshake_start", func(t *testing.T) {
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: 0,
|
||||||
|
Operation: "tls_handshake_start",
|
||||||
|
Proto: "",
|
||||||
|
T: 0,
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tls_handshake_done", func(t *testing.T) {
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: 0,
|
||||||
|
Operation: "tls_handshake_done",
|
||||||
|
Proto: "",
|
||||||
|
T: time.Second.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := events[1]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Handshake discards events when buffers are full", func(t *testing.T) {
|
||||||
|
mockedErr := errors.New("mocked")
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
|
||||||
|
trace.TLSHandshake = make(chan *model.ArchivalTLSOrQUICHandshakeResult) // no buffer
|
||||||
|
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we cancel immediately so connect is ~instantaneous
|
||||||
|
tcpConn := &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, mockedErr
|
||||||
|
},
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: "dns.cloudflare.com",
|
||||||
|
}
|
||||||
|
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
|
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("expected zero-value state")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TLSHandshake events", func(t *testing.T) {
|
||||||
|
events := trace.TLSHandshakes()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected to see no TLSHandshake events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network events", func(t *testing.T) {
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatal("expected to see no Network events")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("we collect the desired data with a local TLS server", func(t *testing.T) {
|
||||||
|
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||||
|
dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", server.Endpoint())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
zeroTime := time.Now()
|
||||||
|
dt := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = dt.Now // deterministic timing
|
||||||
|
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
RootCAs: server.CertPool(),
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tlsConn.Close()
|
||||||
|
data, err := netxlite.ReadAllContext(ctx, tlsConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
|
||||||
|
t.Fatal("bytes should match")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("TLSHandshake events", func(t *testing.T) {
|
||||||
|
events := trace.TLSHandshakes()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("expected to see a single TLSHandshake event")
|
||||||
|
}
|
||||||
|
expected := &model.ArchivalTLSOrQUICHandshakeResult{
|
||||||
|
Address: conn.RemoteAddr().String(),
|
||||||
|
CipherSuite: netxlite.TLSCipherSuiteString(connState.CipherSuite),
|
||||||
|
Failure: nil,
|
||||||
|
NegotiatedProtocol: "",
|
||||||
|
NoTLSVerify: false,
|
||||||
|
PeerCertificates: []model.ArchivalMaybeBinaryData{},
|
||||||
|
ServerName: "dns.google",
|
||||||
|
T: time.Second.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
TLSVersion: netxlite.TLSVersionString(connState.Version),
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
// TODO(bassosimone): it's still unclear to me how to test that
|
||||||
|
// I am getting exactly the expected certificate here. I think the
|
||||||
|
// certificate is generated on the fly by google/martian. So, I'm
|
||||||
|
// just going to reduce the precision of this check.
|
||||||
|
if len(got.PeerCertificates) != 2 {
|
||||||
|
t.Fatal("expected to see two certificates")
|
||||||
|
}
|
||||||
|
got.PeerCertificates = []model.ArchivalMaybeBinaryData{} // see above
|
||||||
|
if diff := cmp.Diff(expected, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network events", func(t *testing.T) {
|
||||||
|
events := trace.NetworkEvents()
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatal("expected to see two Network events")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("tls_handshake_start", func(t *testing.T) {
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: 0,
|
||||||
|
Operation: "tls_handshake_start",
|
||||||
|
Proto: "",
|
||||||
|
T: 0,
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := events[0]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tls_handshake_done", func(t *testing.T) {
|
||||||
|
expect := &model.ArchivalNetworkEvent{
|
||||||
|
Address: "",
|
||||||
|
Failure: nil,
|
||||||
|
NumBytes: 0,
|
||||||
|
Operation: "tls_handshake_done",
|
||||||
|
Proto: "",
|
||||||
|
T: time.Second.Seconds(),
|
||||||
|
Tags: []string{},
|
||||||
|
}
|
||||||
|
got := events[1]
|
||||||
|
if diff := cmp.Diff(expect, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSPeerCerts(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
state tls.ConnectionState
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantOut []model.ArchivalMaybeBinaryData
|
||||||
|
}{{
|
||||||
|
name: "x509.HostnameError",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{},
|
||||||
|
err: x509.HostnameError{
|
||||||
|
Certificate: &x509.Certificate{
|
||||||
|
Raw: []byte("deadbeef"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||||
|
Value: "deadbeef",
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
name: "x509.UnknownAuthorityError",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{},
|
||||||
|
err: x509.UnknownAuthorityError{
|
||||||
|
Cert: &x509.Certificate{
|
||||||
|
Raw: []byte("deadbeef"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||||
|
Value: "deadbeef",
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
name: "x509.CertificateInvalidError",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{},
|
||||||
|
err: x509.CertificateInvalidError{
|
||||||
|
Cert: &x509.Certificate{
|
||||||
|
Raw: []byte("deadbeef"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||||
|
Value: "deadbeef",
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
name: "successful case",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{
|
||||||
|
PeerCertificates: []*x509.Certificate{{
|
||||||
|
Raw: []byte("deadbeef"),
|
||||||
|
}, {
|
||||||
|
Raw: []byte("abad1dea"),
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
wantOut: []model.ArchivalMaybeBinaryData{{
|
||||||
|
Value: "deadbeef",
|
||||||
|
}, {
|
||||||
|
Value: "abad1dea",
|
||||||
|
}},
|
||||||
|
}}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotOut := TLSPeerCerts(tt.args.state, tt.args.err)
|
||||||
|
if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
143
internal/measurexlite/trace.go
Normal file
143
internal/measurexlite/trace.go
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Definition of Trace
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trace implements model.Trace.
|
||||||
|
//
|
||||||
|
// The zero-value of this struct is invalid. To construct you should either
|
||||||
|
// fill all the fields marked as MANDATORY or use NewTrace.
|
||||||
|
//
|
||||||
|
// Buffered channels
|
||||||
|
//
|
||||||
|
// NewTrace uses reasonable buffer sizes for the channels used for collecting
|
||||||
|
// events. You should drain the channels used by this implementation after
|
||||||
|
// each operation you perform (i.e., we expect you to peform step-by-step
|
||||||
|
// measurements). If you want larger (or smaller) buffers, then you should
|
||||||
|
// construct this data type manually with the desired buffer sizes.
|
||||||
|
//
|
||||||
|
// We have convenience methods for extracting events from the buffered
|
||||||
|
// channels. Otherwise, you could read the channels directly. (In which
|
||||||
|
// case, remember to issue nonblocking channel reads because channels are
|
||||||
|
// never closed and they're just written when new events occur.)
|
||||||
|
type Trace struct {
|
||||||
|
// Index is the MANDATORY unique index of this trace within the
|
||||||
|
// current measurement. If you don't care about uniquely identifying
|
||||||
|
// treaces, you can use zero to indicate the "default" trace.
|
||||||
|
Index int64
|
||||||
|
|
||||||
|
// NetworkEvent is MANDATORY and buffers network events. If you create
|
||||||
|
// this channel manually, ensure it has some buffer.
|
||||||
|
NetworkEvent chan *model.ArchivalNetworkEvent
|
||||||
|
|
||||||
|
// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
|
||||||
|
// calls to the netxlite.NewDialerWithoutResolver factory.
|
||||||
|
NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer
|
||||||
|
|
||||||
|
// NewTLSHandshakerStdlibFn is OPTIONAL and can be used to overide
|
||||||
|
// calls to the netxlite.NewTLSHandshakerStdlib factory.
|
||||||
|
NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
|
||||||
|
|
||||||
|
// TCPConnect is MANDATORY and buffers TCP connect observations. If you create
|
||||||
|
// this channel manually, ensure it has some buffer.
|
||||||
|
TCPConnect chan *model.ArchivalTCPConnectResult
|
||||||
|
|
||||||
|
// TLSHandshake is MANDATORY and buffers TLS handshake observations. If you create
|
||||||
|
// this channel manually, ensure it has some buffer.
|
||||||
|
TLSHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
||||||
|
|
||||||
|
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
|
||||||
|
// to produce deterministic timing when testing.
|
||||||
|
TimeNowFn func() time.Time
|
||||||
|
|
||||||
|
// ZeroTime is the MANDATORY time when we started the current measurement.
|
||||||
|
ZeroTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// NetworkEventBufferSize is the buffer size for constructing
|
||||||
|
// the Trace's NetworkEvent buffered channel.
|
||||||
|
NetworkEventBufferSize = 64
|
||||||
|
|
||||||
|
// TCPConnectBufferSize is the buffer size for constructing
|
||||||
|
// the Trace's TCPConnect buffered channel.
|
||||||
|
TCPConnectBufferSize = 8
|
||||||
|
|
||||||
|
// TLSHandshakeBufferSize is the buffer for construcing
|
||||||
|
// the Trace's TLSHandshake buffered channel.
|
||||||
|
TLSHandshakeBufferSize = 8
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewTrace creates a new instance of Trace using default settings.
|
||||||
|
//
|
||||||
|
// We create buffered channels using as buffer sizes the constants that
|
||||||
|
// are also defined by this package.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// - index is the unique index of this trace within the current measurement (use
|
||||||
|
// zero if you don't care about giving this trace a unique ID);
|
||||||
|
//
|
||||||
|
// - zeroTime is the time when we started the current measurement.
|
||||||
|
func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||||
|
return &Trace{
|
||||||
|
Index: index,
|
||||||
|
NetworkEvent: make(
|
||||||
|
chan *model.ArchivalNetworkEvent,
|
||||||
|
NetworkEventBufferSize,
|
||||||
|
),
|
||||||
|
NewDialerWithoutResolverFn: nil, // use default
|
||||||
|
NewTLSHandshakerStdlibFn: nil, // use default
|
||||||
|
TCPConnect: make(
|
||||||
|
chan *model.ArchivalTCPConnectResult,
|
||||||
|
TCPConnectBufferSize,
|
||||||
|
),
|
||||||
|
TLSHandshake: make(
|
||||||
|
chan *model.ArchivalTLSOrQUICHandshakeResult,
|
||||||
|
TLSHandshakeBufferSize,
|
||||||
|
),
|
||||||
|
TimeNowFn: nil, // use default
|
||||||
|
ZeroTime: zeroTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
|
||||||
|
// thus allows us to mock this func for testing.
|
||||||
|
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
|
||||||
|
if tx.NewDialerWithoutResolverFn != nil {
|
||||||
|
return tx.NewDialerWithoutResolverFn(dl)
|
||||||
|
}
|
||||||
|
return netxlite.NewDialerWithoutResolver(dl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
|
||||||
|
// thus allowing us to mock this func for testing.
|
||||||
|
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||||
|
if tx.NewTLSHandshakerStdlibFn != nil {
|
||||||
|
return tx.NewTLSHandshakerStdlibFn(dl)
|
||||||
|
}
|
||||||
|
return netxlite.NewTLSHandshakerStdlib(dl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeNow implements model.Trace.TimeNow.
|
||||||
|
func (tx *Trace) TimeNow() time.Time {
|
||||||
|
if tx.TimeNowFn != nil {
|
||||||
|
return tx.TimeNowFn()
|
||||||
|
}
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeSince is equivalent to Trace.TimeNow().Sub(t0).
|
||||||
|
func (tx *Trace) TimeSince(t0 time.Time) time.Duration {
|
||||||
|
return tx.TimeNow().Sub(t0)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.Trace = &Trace{}
|
248
internal/measurexlite/trace_test.go
Normal file
248
internal/measurexlite/trace_test.go
Normal file
|
@ -0,0 +1,248 @@
|
||||||
|
package measurexlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTrace(t *testing.T) {
|
||||||
|
t.Run("NewTrace correctly constructs a trace", func(t *testing.T) {
|
||||||
|
const index = 17
|
||||||
|
zeroTime := time.Now()
|
||||||
|
trace := NewTrace(index, zeroTime)
|
||||||
|
|
||||||
|
t.Run("Index", func(t *testing.T) {
|
||||||
|
if trace.Index != index {
|
||||||
|
t.Fatal("invalid index")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NetworkEvent has the expected buffer size", func(t *testing.T) {
|
||||||
|
ff := &testingx.FakeFiller{}
|
||||||
|
var idx int
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
ev := &model.ArchivalNetworkEvent{}
|
||||||
|
ff.Fill(ev)
|
||||||
|
select {
|
||||||
|
case trace.NetworkEvent <- ev:
|
||||||
|
idx++
|
||||||
|
default:
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx != NetworkEventBufferSize {
|
||||||
|
t.Fatal("invalid NetworkEvent channel buffer size")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) {
|
||||||
|
if trace.NewDialerWithoutResolverFn != nil {
|
||||||
|
t.Fatal("expected nil NewDialerWithoutResolverFn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) {
|
||||||
|
if trace.NewTLSHandshakerStdlibFn != nil {
|
||||||
|
t.Fatal("expected nil NewTLSHandshakerStdlibFn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
|
||||||
|
ff := &testingx.FakeFiller{}
|
||||||
|
var idx int
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
ev := &model.ArchivalTCPConnectResult{}
|
||||||
|
ff.Fill(ev)
|
||||||
|
select {
|
||||||
|
case trace.TCPConnect <- ev:
|
||||||
|
idx++
|
||||||
|
default:
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx != TCPConnectBufferSize {
|
||||||
|
t.Fatal("invalid TCPConnect channel buffer size")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
|
||||||
|
ff := &testingx.FakeFiller{}
|
||||||
|
var idx int
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
ev := &model.ArchivalTLSOrQUICHandshakeResult{}
|
||||||
|
ff.Fill(ev)
|
||||||
|
select {
|
||||||
|
case trace.TLSHandshake <- ev:
|
||||||
|
idx++
|
||||||
|
default:
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx != TLSHandshakeBufferSize {
|
||||||
|
t.Fatal("invalid TLSHandshake channel buffer size")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TimeNowFn is nil", func(t *testing.T) {
|
||||||
|
if trace.TimeNowFn != nil {
|
||||||
|
t.Fatal("expected nil TimeNowFn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ZeroTime", func(t *testing.T) {
|
||||||
|
if !trace.ZeroTime.Equal(zeroTime) {
|
||||||
|
t.Fatal("invalid zero time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrace(t *testing.T) {
|
||||||
|
t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) {
|
||||||
|
t.Run("when not nil", func(t *testing.T) {
|
||||||
|
mockedErr := errors.New("mocked")
|
||||||
|
tx := &Trace{
|
||||||
|
NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer {
|
||||||
|
return &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return nil, mockedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||||
|
if !errors.Is(err, mockedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when nil", func(t *testing.T) {
|
||||||
|
tx := &Trace{
|
||||||
|
NewDialerWithoutResolverFn: nil,
|
||||||
|
}
|
||||||
|
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // fail immediately
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
|
||||||
|
if err == nil || err.Error() != netxlite.FailureInterrupted {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) {
|
||||||
|
t.Run("when not nil", func(t *testing.T) {
|
||||||
|
mockedErr := errors.New("mocked")
|
||||||
|
tx := &Trace{
|
||||||
|
NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker {
|
||||||
|
return &mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return nil, tls.ConnectionState{}, mockedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||||
|
if !errors.Is(err, mockedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("state is not a zero value")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when nil", func(t *testing.T) {
|
||||||
|
mockedErr := errors.New("mocked")
|
||||||
|
tx := &Trace{
|
||||||
|
NewTLSHandshakerStdlibFn: nil,
|
||||||
|
}
|
||||||
|
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
tcpConn := &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, mockedErr
|
||||||
|
},
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
|
if !errors.Is(err, mockedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("state is not a zero value")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TimeNowFn works as intended", func(t *testing.T) {
|
||||||
|
fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
|
||||||
|
tx := &Trace{
|
||||||
|
TimeNowFn: func() time.Time {
|
||||||
|
return fixedTime
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if !tx.TimeNow().Equal(fixedTime) {
|
||||||
|
t.Fatal("we cannot override time.Now calls")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TimeSince works as intended", func(t *testing.T) {
|
||||||
|
t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
|
||||||
|
t1 := t0.Add(10 * time.Second)
|
||||||
|
tx := &Trace{
|
||||||
|
TimeNowFn: func() time.Time {
|
||||||
|
return t1
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if tx.TimeSince(t0) != 10*time.Second {
|
||||||
|
t.Fatal("apparently Trace.Since is broken")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ooni/probe-cli/v3/internal/fakefill"
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestArchivalExtSpec(t *testing.T) {
|
func TestArchivalExtSpec(t *testing.T) {
|
||||||
|
@ -301,7 +301,7 @@ func TestHTTPBody(t *testing.T) {
|
||||||
// we make a mistake and apply the above change (which will in turn
|
// we make a mistake and apply the above change (which will in turn
|
||||||
// break correct JSON serialization), the this test will fail.
|
// break correct JSON serialization), the this test will fail.
|
||||||
var body ArchivalHTTPBody
|
var body ArchivalHTTPBody
|
||||||
ff := &fakefill.Filler{}
|
ff := &testingx.FakeFiller{}
|
||||||
ff.Fill(&body)
|
ff.Fill(&body)
|
||||||
data := ArchivalMaybeBinaryData(body)
|
data := ArchivalMaybeBinaryData(body)
|
||||||
if diff := cmp.Diff(body, data); diff != "" {
|
if diff := cmp.Diff(body, data); diff != "" {
|
||||||
|
|
45
internal/model/mocks/trace.go
Normal file
45
internal/model/mocks/trace.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
//
|
||||||
|
// Mocks for model.Trace
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trace allows mocking model.Trace.
|
||||||
|
type Trace struct {
|
||||||
|
MockTimeNow func() time.Time
|
||||||
|
|
||||||
|
MockOnConnectDone func(
|
||||||
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
||||||
|
|
||||||
|
MockOnTLSHandshakeStart func(now time.Time, remoteAddr string, config *tls.Config)
|
||||||
|
|
||||||
|
MockOnTLSHandshakeDone func(started time.Time, remoteAddr string, config *tls.Config,
|
||||||
|
state tls.ConnectionState, err error, finished time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.Trace = &Trace{}
|
||||||
|
|
||||||
|
func (t *Trace) TimeNow() time.Time {
|
||||||
|
return t.MockTimeNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Trace) OnConnectDone(
|
||||||
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
|
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||||
|
t.MockOnTLSHandshakeStart(now, remoteAddr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||||
|
state tls.ConnectionState, err error, finished time.Time) {
|
||||||
|
t.MockOnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
|
||||||
|
}
|
74
internal/model/mocks/trace_test.go
Normal file
74
internal/model/mocks/trace_test.go
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTrace(t *testing.T) {
|
||||||
|
t.Run("TimeNow", func(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
tx := &Trace{
|
||||||
|
MockTimeNow: func() time.Time {
|
||||||
|
return now
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if !tx.TimeNow().Equal(now) {
|
||||||
|
t.Fatal("not working as intended")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OnConnectDone", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
tx := &Trace{
|
||||||
|
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tx.OnConnectDone(
|
||||||
|
time.Now(),
|
||||||
|
"tcp",
|
||||||
|
"dns.google",
|
||||||
|
"8.8.8.8:443",
|
||||||
|
nil,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OnTLSHandshakeStart", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
tx := &Trace{
|
||||||
|
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tx.OnTLSHandshakeStart(time.Now(), "8.8.8.8:443", &tls.Config{})
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("OnTLSHandshakeDone", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
tx := &Trace{
|
||||||
|
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tx.OnTLSHandshakeDone(
|
||||||
|
time.Now(),
|
||||||
|
"8.8.8.8:443",
|
||||||
|
&tls.Config{},
|
||||||
|
tls.ConnectionState{},
|
||||||
|
nil,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -292,6 +292,76 @@ type TLSHandshaker interface {
|
||||||
net.Conn, tls.ConnectionState, error)
|
net.Conn, tls.ConnectionState, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trace allows to collect measurement traces. A trace is injected into
|
||||||
|
// netx operations using context.WithValue. Netx code retrieves the trace
|
||||||
|
// using context.Value. See docs/design/dd-003-step-by-step.md for the
|
||||||
|
// design document explaining why we implemented context-based tracing.
|
||||||
|
type Trace interface {
|
||||||
|
// TimeNow returns the current time. Normally, this should be the same
|
||||||
|
// value returned by time.Now but you may want to manipulate the time
|
||||||
|
// returned when testing to have deterministic tests. To this end, you
|
||||||
|
// can use functionality exported by the ./internal/testingx pkg.
|
||||||
|
TimeNow() time.Time
|
||||||
|
|
||||||
|
// OnConnectDone is called when connect terminates.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// - started is when we called connect;
|
||||||
|
//
|
||||||
|
// - network is the network we're using (one of "tcp" and "udp");
|
||||||
|
//
|
||||||
|
// - domain is the domain for which we're calling connect. If the user called
|
||||||
|
// connect for an IP address and a port, then domain will be an IP address;
|
||||||
|
//
|
||||||
|
// - remoteAddr is the TCP endpoint with which we are connecting: it will
|
||||||
|
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
|
||||||
|
//
|
||||||
|
// - err is the result of connect: either an error or nil;
|
||||||
|
//
|
||||||
|
// - finished is when connect returned.
|
||||||
|
//
|
||||||
|
// The error passed to this function will always be wrapped such that the
|
||||||
|
// string returned by Error is an OONI error.
|
||||||
|
OnConnectDone(
|
||||||
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
||||||
|
|
||||||
|
// OnTLSHandshakeStart is called when the TLS handshake starts.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// - now is the moment before we start the handshake;
|
||||||
|
//
|
||||||
|
// - remoteAddr is the TCP endpoint with which we are connecting: it will
|
||||||
|
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
|
||||||
|
//
|
||||||
|
// - config is the non-nil TLS config we're using.
|
||||||
|
OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config)
|
||||||
|
|
||||||
|
// OnTLSHandshakeDone is called when the TLS handshake terminates.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// - started is when we started the handshake;
|
||||||
|
//
|
||||||
|
// - remoteAddr is the TCP endpoint with which we are connecting: it will
|
||||||
|
// consist of an IP address and a port (e.g., 8.8.8.8:443, [::1]:5421);
|
||||||
|
//
|
||||||
|
// - config is the non-nil TLS config we're using;
|
||||||
|
//
|
||||||
|
// - state is the state of the TLS connection after the handshake, where all
|
||||||
|
// fields are zero-initialized if the handshake failed;
|
||||||
|
//
|
||||||
|
// - err is the result of the handshake: either an error or nil;
|
||||||
|
//
|
||||||
|
// - finished is right after the handshake.
|
||||||
|
//
|
||||||
|
// The error passed to this function will always be wrapped such that the
|
||||||
|
// string returned by Error is an OONI error.
|
||||||
|
OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||||
|
state tls.ConnectionState, err error, finished time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
// UDPLikeConn is a net.PacketConn with some extra functions
|
// UDPLikeConn is a net.PacketConn with some extra functions
|
||||||
// required to convince the QUIC library (lucas-clemente/quic-go)
|
// required to convince the QUIC library (lucas-clemente/quic-go)
|
||||||
// to inflate the receive buffer of the connection.
|
// to inflate the receive buffer of the connection.
|
||||||
|
|
|
@ -47,7 +47,7 @@ func (r *bogonResolver) LookupHost(ctx context.Context, hostname string) ([]stri
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
if IsBogon(addr) {
|
if IsBogon(addr) {
|
||||||
// wrap ErrDNSBogon as documented
|
// wrap ErrDNSBogon as documented
|
||||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, ErrDNSBogon)
|
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, ErrDNSBogon)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
|
|
|
@ -15,8 +15,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/scrubber"
|
"github.com/ooni/probe-cli/v3/internal/scrubber"
|
||||||
)
|
)
|
||||||
|
|
||||||
// classifyGenericError is maps an error occurred during an operation
|
// ClassifyGenericError maps an error occurred during an operation to
|
||||||
// to an OONI failure string. This specific classifier is the most
|
// an OONI failure string. This specific classifier is the most
|
||||||
// generic one. You usually use it when mapping I/O errors. You should
|
// generic one. You usually use it when mapping I/O errors. You should
|
||||||
// check whether there is a specific classifier for more specific
|
// check whether there is a specific classifier for more specific
|
||||||
// operations (e.g., DNS resolution, TLS handshake).
|
// operations (e.g., DNS resolution, TLS handshake).
|
||||||
|
@ -38,7 +38,7 @@ import (
|
||||||
// If everything else fails, this classifier returns a string
|
// If everything else fails, this classifier returns a string
|
||||||
// like "unknown_failure: XXX" where XXX has been scrubbed
|
// like "unknown_failure: XXX" where XXX has been scrubbed
|
||||||
// so to remove any network endpoints from the original error string.
|
// so to remove any network endpoints from the original error string.
|
||||||
func classifyGenericError(err error) string {
|
func ClassifyGenericError(err error) string {
|
||||||
// The list returned here matches the values used by MK unless
|
// The list returned here matches the values used by MK unless
|
||||||
// explicitly noted otherwise with a comment.
|
// explicitly noted otherwise with a comment.
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ const (
|
||||||
quicTLSUnrecognizedName = 112
|
quicTLSUnrecognizedName = 112
|
||||||
)
|
)
|
||||||
|
|
||||||
// classifyQUICHandshakeError maps errors during a QUIC
|
// ClassifyQUICHandshakeError maps errors during a QUIC
|
||||||
// handshake to OONI failure strings.
|
// handshake to OONI failure strings.
|
||||||
//
|
//
|
||||||
// If the input error is an *ErrWrapper we don't perform
|
// If the input error is an *ErrWrapper we don't perform
|
||||||
|
@ -147,7 +147,7 @@ const (
|
||||||
//
|
//
|
||||||
// If this classifier fails, it calls ClassifyGenericError
|
// If this classifier fails, it calls ClassifyGenericError
|
||||||
// and returns to the caller its return value.
|
// and returns to the caller its return value.
|
||||||
func classifyQUICHandshakeError(err error) string {
|
func ClassifyQUICHandshakeError(err error) string {
|
||||||
|
|
||||||
// Robustness: handle the case where we're passed a wrapped error.
|
// Robustness: handle the case where we're passed a wrapped error.
|
||||||
var errwrapper *ErrWrapper
|
var errwrapper *ErrWrapper
|
||||||
|
@ -207,7 +207,7 @@ func classifyQUICHandshakeError(err error) string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return classifyGenericError(err)
|
return ClassifyGenericError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// quicIsCertificateError tells us whether a specific TLS alert error
|
// quicIsCertificateError tells us whether a specific TLS alert error
|
||||||
|
@ -277,7 +277,7 @@ var (
|
||||||
// anything as explained in getaddrinfo_linux.go.
|
// anything as explained in getaddrinfo_linux.go.
|
||||||
var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
||||||
|
|
||||||
// classifyResolverError maps DNS resolution errors to
|
// ClassifyResolverError maps DNS resolution errors to
|
||||||
// OONI failure strings.
|
// OONI failure strings.
|
||||||
//
|
//
|
||||||
// If the input error is an *ErrWrapper we don't perform
|
// If the input error is an *ErrWrapper we don't perform
|
||||||
|
@ -285,7 +285,7 @@ var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
||||||
//
|
//
|
||||||
// If this classifier fails, it calls ClassifyGenericError and
|
// If this classifier fails, it calls ClassifyGenericError and
|
||||||
// returns to the caller its return value.
|
// returns to the caller its return value.
|
||||||
func classifyResolverError(err error) string {
|
func ClassifyResolverError(err error) string {
|
||||||
|
|
||||||
// Robustness: handle the case where we're passed a wrapped error.
|
// Robustness: handle the case where we're passed a wrapped error.
|
||||||
var errwrapper *ErrWrapper
|
var errwrapper *ErrWrapper
|
||||||
|
@ -310,10 +310,10 @@ func classifyResolverError(err error) string {
|
||||||
if errors.Is(err, ErrAndroidDNSCacheNoData) {
|
if errors.Is(err, ErrAndroidDNSCacheNoData) {
|
||||||
return FailureAndroidDNSCacheNoData
|
return FailureAndroidDNSCacheNoData
|
||||||
}
|
}
|
||||||
return classifyGenericError(err)
|
return ClassifyGenericError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// classifyTLSHandshakeError maps an error occurred during the TLS
|
// ClassifyTLSHandshakeError maps an error occurred during the TLS
|
||||||
// handshake to an OONI failure string.
|
// handshake to an OONI failure string.
|
||||||
//
|
//
|
||||||
// If the input error is an *ErrWrapper we don't perform
|
// If the input error is an *ErrWrapper we don't perform
|
||||||
|
@ -321,7 +321,7 @@ func classifyResolverError(err error) string {
|
||||||
//
|
//
|
||||||
// If this classifier fails, it calls ClassifyGenericError and
|
// If this classifier fails, it calls ClassifyGenericError and
|
||||||
// returns to the caller its return value.
|
// returns to the caller its return value.
|
||||||
func classifyTLSHandshakeError(err error) string {
|
func ClassifyTLSHandshakeError(err error) string {
|
||||||
|
|
||||||
// Robustness: handle the case where we're passed a wrapped error.
|
// Robustness: handle the case where we're passed a wrapped error.
|
||||||
var errwrapper *ErrWrapper
|
var errwrapper *ErrWrapper
|
||||||
|
@ -345,5 +345,5 @@ func classifyTLSHandshakeError(err error) string {
|
||||||
// Test case: https://expired.badssl.com/
|
// Test case: https://expired.badssl.com/
|
||||||
return FailureSSLInvalidCertificate
|
return FailureSSLInvalidCertificate
|
||||||
}
|
}
|
||||||
return classifyGenericError(err)
|
return ClassifyGenericError(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,13 +18,13 @@ func TestClassifyGenericError(t *testing.T) {
|
||||||
|
|
||||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||||
err := &ErrWrapper{Failure: FailureEOFError}
|
err := &ErrWrapper{Failure: FailureEOFError}
|
||||||
if classifyGenericError(err) != FailureEOFError {
|
if ClassifyGenericError(err) != FailureEOFError {
|
||||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for a system call error", func(t *testing.T) {
|
t.Run("for a system call error", func(t *testing.T) {
|
||||||
if classifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
|
if ClassifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -35,63 +35,63 @@ func TestClassifyGenericError(t *testing.T) {
|
||||||
// is just an implementation detail.
|
// is just an implementation detail.
|
||||||
|
|
||||||
t.Run("for operation was canceled", func(t *testing.T) {
|
t.Run("for operation was canceled", func(t *testing.T) {
|
||||||
if classifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
|
if ClassifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for EOF", func(t *testing.T) {
|
t.Run("for EOF", func(t *testing.T) {
|
||||||
if classifyGenericError(io.EOF) != FailureEOFError {
|
if ClassifyGenericError(io.EOF) != FailureEOFError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for context deadline exceeded", func(t *testing.T) {
|
t.Run("for context deadline exceeded", func(t *testing.T) {
|
||||||
if classifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
|
if ClassifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for stun's transaction is timed out", func(t *testing.T) {
|
t.Run("for stun's transaction is timed out", func(t *testing.T) {
|
||||||
if classifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
if ClassifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for i/o timeout", func(t *testing.T) {
|
t.Run("for i/o timeout", func(t *testing.T) {
|
||||||
if classifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
|
if ClassifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for TLS handshake timeout", func(t *testing.T) {
|
t.Run("for TLS handshake timeout", func(t *testing.T) {
|
||||||
err := errors.New("net/http: TLS handshake timeout")
|
err := errors.New("net/http: TLS handshake timeout")
|
||||||
if classifyGenericError(err) != FailureGenericTimeoutError {
|
if ClassifyGenericError(err) != FailureGenericTimeoutError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for no such host", func(t *testing.T) {
|
t.Run("for no such host", func(t *testing.T) {
|
||||||
if classifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
|
if ClassifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for dns server misbehaving", func(t *testing.T) {
|
t.Run("for dns server misbehaving", func(t *testing.T) {
|
||||||
if classifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
|
if ClassifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for no answer from DNS server", func(t *testing.T) {
|
t.Run("for no answer from DNS server", func(t *testing.T) {
|
||||||
if classifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
|
if ClassifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for use of closed network connection", func(t *testing.T) {
|
t.Run("for use of closed network connection", func(t *testing.T) {
|
||||||
err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
|
err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
|
||||||
if classifyGenericError(err) != FailureConnectionAlreadyClosed {
|
if ClassifyGenericError(err) != FailureConnectionAlreadyClosed {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -99,7 +99,7 @@ func TestClassifyGenericError(t *testing.T) {
|
||||||
// Now we're back in ClassifyGenericError
|
// Now we're back in ClassifyGenericError
|
||||||
|
|
||||||
t.Run("for context.Canceled", func(t *testing.T) {
|
t.Run("for context.Canceled", func(t *testing.T) {
|
||||||
if classifyGenericError(context.Canceled) != FailureInterrupted {
|
if ClassifyGenericError(context.Canceled) != FailureInterrupted {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -108,7 +108,7 @@ func TestClassifyGenericError(t *testing.T) {
|
||||||
t.Run("with an IPv4 address", func(t *testing.T) {
|
t.Run("with an IPv4 address", func(t *testing.T) {
|
||||||
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error")
|
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error")
|
||||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
||||||
out := classifyGenericError(input)
|
out := ClassifyGenericError(input)
|
||||||
if out != expected {
|
if out != expected {
|
||||||
t.Fatal(cmp.Diff(expected, out))
|
t.Fatal(cmp.Diff(expected, out))
|
||||||
}
|
}
|
||||||
|
@ -117,7 +117,7 @@ func TestClassifyGenericError(t *testing.T) {
|
||||||
t.Run("with an IPv6 address", func(t *testing.T) {
|
t.Run("with an IPv6 address", func(t *testing.T) {
|
||||||
input := errors.New("read tcp [::1]:56948->[::1]:443: some error")
|
input := errors.New("read tcp [::1]:56948->[::1]:443: some error")
|
||||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
||||||
out := classifyGenericError(input)
|
out := ClassifyGenericError(input)
|
||||||
if out != expected {
|
if out != expected {
|
||||||
t.Fatal(cmp.Diff(expected, out))
|
t.Fatal(cmp.Diff(expected, out))
|
||||||
}
|
}
|
||||||
|
@ -131,100 +131,100 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
|
||||||
|
|
||||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||||
err := &ErrWrapper{Failure: FailureEOFError}
|
err := &ErrWrapper{Failure: FailureEOFError}
|
||||||
if classifyQUICHandshakeError(err) != FailureEOFError {
|
if ClassifyQUICHandshakeError(err) != FailureEOFError {
|
||||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for incompatible quic version", func(t *testing.T) {
|
t.Run("for incompatible quic version", func(t *testing.T) {
|
||||||
if classifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
|
if ClassifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for stateless reset", func(t *testing.T) {
|
t.Run("for stateless reset", func(t *testing.T) {
|
||||||
if classifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
|
if ClassifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for handshake timeout", func(t *testing.T) {
|
t.Run("for handshake timeout", func(t *testing.T) {
|
||||||
if classifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
|
if ClassifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for idle timeout", func(t *testing.T) {
|
t.Run("for idle timeout", func(t *testing.T) {
|
||||||
if classifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
|
if ClassifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for connection refused", func(t *testing.T) {
|
t.Run("for connection refused", func(t *testing.T) {
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for bad certificate", func(t *testing.T) {
|
t.Run("for bad certificate", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertBadCertificate
|
var err quic.TransportErrorCode = quicTLSAlertBadCertificate
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for unsupported certificate", func(t *testing.T) {
|
t.Run("for unsupported certificate", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate
|
var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for certificate expired", func(t *testing.T) {
|
t.Run("for certificate expired", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertCertificateExpired
|
var err quic.TransportErrorCode = quicTLSAlertCertificateExpired
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for certificate revoked", func(t *testing.T) {
|
t.Run("for certificate revoked", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked
|
var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for certificate unknown", func(t *testing.T) {
|
t.Run("for certificate unknown", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown
|
var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for decrypt error", func(t *testing.T) {
|
t.Run("for decrypt error", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertDecryptError
|
var err quic.TransportErrorCode = quicTLSAlertDecryptError
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for handshake failure", func(t *testing.T) {
|
t.Run("for handshake failure", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure
|
var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for unknown CA", func(t *testing.T) {
|
t.Run("for unknown CA", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSAlertUnknownCA
|
var err quic.TransportErrorCode = quicTLSAlertUnknownCA
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for unrecognized hostname", func(t *testing.T) {
|
t.Run("for unrecognized hostname", func(t *testing.T) {
|
||||||
var err quic.TransportErrorCode = quicTLSUnrecognizedName
|
var err quic.TransportErrorCode = quicTLSUnrecognizedName
|
||||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
|
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -234,13 +234,13 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
|
||||||
ErrorCode: quic.InternalError,
|
ErrorCode: quic.InternalError,
|
||||||
ErrorMessage: FailureHostUnreachable,
|
ErrorMessage: FailureHostUnreachable,
|
||||||
}
|
}
|
||||||
if classifyQUICHandshakeError(err) != FailureHostUnreachable {
|
if ClassifyQUICHandshakeError(err) != FailureHostUnreachable {
|
||||||
t.Fatal("unexpected results")
|
t.Fatal("unexpected results")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for another kind of error", func(t *testing.T) {
|
t.Run("for another kind of error", func(t *testing.T) {
|
||||||
if classifyQUICHandshakeError(io.EOF) != FailureEOFError {
|
if ClassifyQUICHandshakeError(io.EOF) != FailureEOFError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -252,43 +252,43 @@ func TestClassifyResolverError(t *testing.T) {
|
||||||
|
|
||||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||||
err := &ErrWrapper{Failure: FailureEOFError}
|
err := &ErrWrapper{Failure: FailureEOFError}
|
||||||
if classifyResolverError(err) != FailureEOFError {
|
if ClassifyResolverError(err) != FailureEOFError {
|
||||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for ErrDNSBogon", func(t *testing.T) {
|
t.Run("for ErrDNSBogon", func(t *testing.T) {
|
||||||
if classifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
|
if ClassifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for refused", func(t *testing.T) {
|
t.Run("for refused", func(t *testing.T) {
|
||||||
if classifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
|
if ClassifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for servfail", func(t *testing.T) {
|
t.Run("for servfail", func(t *testing.T) {
|
||||||
if classifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
|
if ClassifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
|
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
|
||||||
if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
|
if ClassifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) {
|
t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) {
|
||||||
if classifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
|
if ClassifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for another kind of error", func(t *testing.T) {
|
t.Run("for another kind of error", func(t *testing.T) {
|
||||||
if classifyResolverError(io.EOF) != FailureEOFError {
|
if ClassifyResolverError(io.EOF) != FailureEOFError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -300,34 +300,34 @@ func TestClassifyTLSHandshakeError(t *testing.T) {
|
||||||
|
|
||||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||||
err := &ErrWrapper{Failure: FailureEOFError}
|
err := &ErrWrapper{Failure: FailureEOFError}
|
||||||
if classifyTLSHandshakeError(err) != FailureEOFError {
|
if ClassifyTLSHandshakeError(err) != FailureEOFError {
|
||||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for x509.HostnameError", func(t *testing.T) {
|
t.Run("for x509.HostnameError", func(t *testing.T) {
|
||||||
var err x509.HostnameError
|
var err x509.HostnameError
|
||||||
if classifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
|
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
|
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
|
||||||
var err x509.UnknownAuthorityError
|
var err x509.UnknownAuthorityError
|
||||||
if classifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
|
if ClassifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
|
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
|
||||||
var err x509.CertificateInvalidError
|
var err x509.CertificateInvalidError
|
||||||
if classifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
|
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for another kind of error", func(t *testing.T) {
|
t.Run("for another kind of error", func(t *testing.T) {
|
||||||
if classifyTLSHandshakeError(io.EOF) != FailureEOFError {
|
if ClassifyTLSHandshakeError(io.EOF) != FailureEOFError {
|
||||||
t.Fatal("unexpected result")
|
t.Fatal("unexpected result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -125,7 +125,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver,
|
||||||
outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors
|
outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors
|
||||||
}
|
}
|
||||||
return &dialerLogger{
|
return &dialerLogger{
|
||||||
Dialer: &dialerResolver{
|
Dialer: &dialerResolverWithTracing{
|
||||||
Dialer: &dialerLogger{
|
Dialer: &dialerLogger{
|
||||||
Dialer: outDialer,
|
Dialer: outDialer,
|
||||||
DebugLogger: logger,
|
DebugLogger: logger,
|
||||||
|
@ -171,15 +171,24 @@ func (d *DialerSystem) CloseIdleConnections() {
|
||||||
// nothing to do here
|
// nothing to do here
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialerResolver combines dialing with domain name resolution.
|
// dialerResolverWithTracing combines dialing with domain name resolution and
|
||||||
type dialerResolver struct {
|
// implements hooks to trace TCP (or UDP) connect operations.
|
||||||
|
type dialerResolverWithTracing struct {
|
||||||
Dialer model.Dialer
|
Dialer model.Dialer
|
||||||
Resolver model.Resolver
|
Resolver model.Resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.Dialer = &dialerResolver{}
|
var _ model.Dialer = &dialerResolverWithTracing{}
|
||||||
|
|
||||||
func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
// DialContext implements model.Dialer.DialContext. Specifically this
|
||||||
|
// method performs the following operations:
|
||||||
|
//
|
||||||
|
// 1. resolve the domain inside the address using a resolver;
|
||||||
|
//
|
||||||
|
// 2. cycle through the available IP addresses and try to dial each of them;
|
||||||
|
//
|
||||||
|
// 3. trace the TCP (or UDP) connect and allow wrapping the returned conn.
|
||||||
|
func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
// QUIRK: this routine and the related routines in quirks.go cannot
|
// QUIRK: this routine and the related routines in quirks.go cannot
|
||||||
// be changed easily until we use events tracing to measure.
|
// be changed easily until we use events tracing to measure.
|
||||||
//
|
//
|
||||||
|
@ -194,10 +203,27 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
|
||||||
}
|
}
|
||||||
addrs = quirkSortIPAddrs(addrs)
|
addrs = quirkSortIPAddrs(addrs)
|
||||||
var errorslist []error
|
var errorslist []error
|
||||||
|
trace := ContextTraceOrDefault(ctx)
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
target := net.JoinHostPort(addr, onlyport)
|
target := net.JoinHostPort(addr, onlyport)
|
||||||
|
started := trace.TimeNow()
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, target)
|
conn, err := d.Dialer.DialContext(ctx, network, target)
|
||||||
|
finished := trace.TimeNow()
|
||||||
|
// TODO(bassosimone): to make the code robust to future refactoring we have
|
||||||
|
// moved error wrapping inside this type. This change opens up the possibility
|
||||||
|
// of simplifying the dialing chain by removing dialerErrWrapper. We'll be
|
||||||
|
// able to implement this refactoring once netx is gone. We cannot complete
|
||||||
|
// this refactoring _before_ because WrapDialer inserts extra wrappers
|
||||||
|
// provided by netx in the dialers chain _before_ this dialer and the dialers
|
||||||
|
// that netx insert assume that they wrap a dialer with error wrapping.
|
||||||
|
//
|
||||||
|
// Because error wrapping should be idempotent, it should not be a problem
|
||||||
|
// to have two error wrapping dialers in the chain except that, of course, it
|
||||||
|
// would be less efficient than just having a single wrapper.
|
||||||
|
err = MaybeNewErrWrapper(ClassifyGenericError, ConnectOperation, err)
|
||||||
|
trace.OnConnectDone(started, network, onlyhost, target, err, finished)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
conn = &dialerErrWrapperConn{conn}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
errorslist = append(errorslist, err)
|
errorslist = append(errorslist, err)
|
||||||
|
@ -206,14 +232,14 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupHost ensures we correctly handle IP addresses.
|
// lookupHost ensures we correctly handle IP addresses.
|
||||||
func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (d *dialerResolverWithTracing) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
if net.ParseIP(hostname) != nil {
|
if net.ParseIP(hostname) != nil {
|
||||||
return []string{hostname}, nil
|
return []string{hostname}, nil
|
||||||
}
|
}
|
||||||
return d.Resolver.LookupHost(ctx, hostname)
|
return d.Resolver.LookupHost(ctx, hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dialerResolver) CloseIdleConnections() {
|
func (d *dialerResolverWithTracing) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
d.Resolver.CloseIdleConnections()
|
d.Resolver.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
@ -303,7 +329,7 @@ var _ model.Dialer = &dialerErrWrapper{}
|
||||||
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(classifyGenericError, ConnectOperation, err)
|
return nil, NewErrWrapper(ClassifyGenericError, ConnectOperation, err)
|
||||||
}
|
}
|
||||||
return &dialerErrWrapperConn{Conn: conn}, nil
|
return &dialerErrWrapperConn{Conn: conn}, nil
|
||||||
}
|
}
|
||||||
|
@ -322,7 +348,7 @@ var _ net.Conn = &dialerErrWrapperConn{}
|
||||||
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
||||||
count, err := c.Conn.Read(b)
|
count, err := c.Conn.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, newErrWrapper(classifyGenericError, ReadOperation, err)
|
return 0, NewErrWrapper(ClassifyGenericError, ReadOperation, err)
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
@ -330,7 +356,7 @@ func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
||||||
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
||||||
count, err := c.Conn.Write(b)
|
count, err := c.Conn.Write(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, newErrWrapper(classifyGenericError, WriteOperation, err)
|
return 0, NewErrWrapper(ClassifyGenericError, WriteOperation, err)
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
@ -338,7 +364,7 @@ func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
||||||
func (c *dialerErrWrapperConn) Close() error {
|
func (c *dialerErrWrapperConn) Close() error {
|
||||||
err := c.Conn.Close()
|
err := c.Conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newErrWrapper(classifyGenericError, CloseOperation, err)
|
return NewErrWrapper(ClassifyGenericError, CloseOperation, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewDialerWithStdlibResolver(t *testing.T) {
|
func TestNewDialerWithStdlibResolver(t *testing.T) {
|
||||||
|
@ -22,7 +23,7 @@ func TestNewDialerWithStdlibResolver(t *testing.T) {
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
// typecheck the resolver
|
// typecheck the resolver
|
||||||
reso := logger.Dialer.(*dialerResolver)
|
reso := logger.Dialer.(*dialerResolverWithTracing)
|
||||||
typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger)
|
typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger)
|
||||||
// typecheck the dialer
|
// typecheck the dialer
|
||||||
logger = reso.Dialer.(*dialerLogger)
|
logger = reso.Dialer.(*dialerLogger)
|
||||||
|
@ -64,7 +65,7 @@ func TestNewDialer(t *testing.T) {
|
||||||
if logger.DebugLogger != log.Log {
|
if logger.DebugLogger != log.Log {
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
reso := logger.Dialer.(*dialerResolver)
|
reso := logger.Dialer.(*dialerResolverWithTracing)
|
||||||
if _, okay := reso.Resolver.(*NullResolver); !okay {
|
if _, okay := reso.Resolver.(*NullResolver); !okay {
|
||||||
t.Fatal("invalid Resolver type")
|
t.Fatal("invalid Resolver type")
|
||||||
}
|
}
|
||||||
|
@ -136,10 +137,10 @@ func TestDialerSystem(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialerResolver(t *testing.T) {
|
func TestDialerResolverWithTracing(t *testing.T) {
|
||||||
t.Run("DialContext", func(t *testing.T) {
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
t.Run("fails without a port", func(t *testing.T) {
|
t.Run("fails without a port", func(t *testing.T) {
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &DialerSystem{},
|
Dialer: &DialerSystem{},
|
||||||
Resolver: NewUnwrappedStdlibResolver(),
|
Resolver: NewUnwrappedStdlibResolver(),
|
||||||
}
|
}
|
||||||
|
@ -154,7 +155,7 @@ func TestDialerResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles dialing error correctly for single IP address", func(t *testing.T) {
|
t.Run("handles dialing error correctly for single IP address", func(t *testing.T) {
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
|
@ -166,13 +167,26 @@ func TestDialerResolver(t *testing.T) {
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
|
var errWrapper *ErrWrapper
|
||||||
|
if !errors.As(err, &errWrapper) {
|
||||||
|
t.Fatal("the error has not been wrapped")
|
||||||
|
}
|
||||||
|
if errWrapper.Failure != FailureEOFError {
|
||||||
|
t.Fatal("invalid wrapped error's failure")
|
||||||
|
}
|
||||||
|
if errWrapper.Operation != ConnectOperation {
|
||||||
|
t.Fatal("invalid wrapped error's operation")
|
||||||
|
}
|
||||||
|
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||||
|
t.Fatal("invalid wrapped error's underlying error")
|
||||||
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) {
|
t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) {
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return nil, io.EOF
|
return nil, io.EOF
|
||||||
|
@ -188,6 +202,19 @@ func TestDialerResolver(t *testing.T) {
|
||||||
if !errors.Is(err, io.EOF) {
|
if !errors.Is(err, io.EOF) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
|
var errWrapper *ErrWrapper
|
||||||
|
if !errors.As(err, &errWrapper) {
|
||||||
|
t.Fatal("the error has not been wrapped")
|
||||||
|
}
|
||||||
|
if errWrapper.Failure != FailureEOFError {
|
||||||
|
t.Fatal("invalid wrapped error's failure")
|
||||||
|
}
|
||||||
|
if errWrapper.Operation != ConnectOperation {
|
||||||
|
t.Fatal("invalid wrapped error's operation")
|
||||||
|
}
|
||||||
|
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||||
|
t.Fatal("invalid wrapped error's underlying error")
|
||||||
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
|
@ -199,7 +226,7 @@ func TestDialerResolver(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return expectedConn, nil
|
return expectedConn, nil
|
||||||
|
@ -215,7 +242,10 @@ func TestDialerResolver(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if conn != expectedConn {
|
// Ensure that the dialer returns a connection that is already wrapping errors,
|
||||||
|
// which is a new behavior since https://github.com/ooni/probe-cli/pull/815
|
||||||
|
errWrapperConn := conn.(*dialerErrWrapperConn)
|
||||||
|
if errWrapperConn.Conn != expectedConn {
|
||||||
t.Fatal("unexpected conn")
|
t.Fatal("unexpected conn")
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
@ -225,7 +255,7 @@ func TestDialerResolver(t *testing.T) {
|
||||||
// This test is fundamental to the following
|
// This test is fundamental to the following
|
||||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||||
mu := &sync.Mutex{}
|
mu := &sync.Mutex{}
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
// It should not happen to have parallel dials with
|
// It should not happen to have parallel dials with
|
||||||
|
@ -257,7 +287,7 @@ func TestDialerResolver(t *testing.T) {
|
||||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||||
mu := &sync.Mutex{}
|
mu := &sync.Mutex{}
|
||||||
var attempts []string
|
var attempts []string
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
// It should not happen to have parallel dials with
|
// It should not happen to have parallel dials with
|
||||||
|
@ -298,14 +328,14 @@ func TestDialerResolver(t *testing.T) {
|
||||||
mu := &sync.Mutex{}
|
mu := &sync.Mutex{}
|
||||||
errorsList := []error{
|
errorsList := []error{
|
||||||
errors.New("a mocked error"),
|
errors.New("a mocked error"),
|
||||||
newErrWrapper(
|
NewErrWrapper(
|
||||||
classifyGenericError,
|
ClassifyGenericError,
|
||||||
CloseOperation,
|
CloseOperation,
|
||||||
io.EOF,
|
io.EOF,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
var errorIdx int
|
var errorIdx int
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
// It should not happen to have parallel dials with
|
// It should not happen to have parallel dials with
|
||||||
|
@ -337,17 +367,18 @@ func TestDialerResolver(t *testing.T) {
|
||||||
t.Run("though ignores the unknown failures", func(t *testing.T) {
|
t.Run("though ignores the unknown failures", func(t *testing.T) {
|
||||||
// This test is fundamental to the following
|
// This test is fundamental to the following
|
||||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||||
|
expectedErr := errors.New("a mocked error")
|
||||||
mu := &sync.Mutex{}
|
mu := &sync.Mutex{}
|
||||||
errorsList := []error{
|
errorsList := []error{
|
||||||
errors.New("a mocked error"),
|
expectedErr,
|
||||||
newErrWrapper(
|
NewErrWrapper(
|
||||||
classifyGenericError,
|
ClassifyGenericError,
|
||||||
CloseOperation,
|
CloseOperation,
|
||||||
errors.New("antani"),
|
errors.New("antani"), // this is an unknown failure and we should not return it
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
var errorIdx int
|
var errorIdx int
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
// It should not happen to have parallel dials with
|
// It should not happen to have parallel dials with
|
||||||
|
@ -368,18 +399,95 @@ func TestDialerResolver(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
|
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||||
if err == nil || err.Error() != "a mocked error" {
|
if !errors.Is(err, expectedErr) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
|
var errWrapper *ErrWrapper
|
||||||
|
if !errors.As(err, &errWrapper) {
|
||||||
|
t.Fatal("error has not been wrapped")
|
||||||
|
}
|
||||||
|
if errWrapper.Failure != "unknown_failure: a mocked error" {
|
||||||
|
t.Fatal("unexpected wrapped error's failure")
|
||||||
|
}
|
||||||
|
if errWrapper.Operation != ConnectOperation {
|
||||||
|
t.Fatal("unexpected wrapped error's operation")
|
||||||
|
}
|
||||||
|
if !errors.Is(errWrapper.WrappedErr, expectedErr) {
|
||||||
|
t.Fatal("unexpected wrapped error's underlying error")
|
||||||
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("uses a context-injected custom trace", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
called bool
|
||||||
|
domainOK bool
|
||||||
|
networkOK bool
|
||||||
|
remoteAddrOK bool
|
||||||
|
startTimeOK bool
|
||||||
|
finishTimeOK bool
|
||||||
|
wrappedErr bool
|
||||||
|
)
|
||||||
|
zeroTime := time.Now()
|
||||||
|
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
tx := &mocks.Trace{
|
||||||
|
MockTimeNow: deterministicTime.Now,
|
||||||
|
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
|
var ew *ErrWrapper
|
||||||
|
called = true
|
||||||
|
domainOK = (domain == "1.1.1.1")
|
||||||
|
networkOK = (network == "tcp")
|
||||||
|
remoteAddrOK = (remoteAddr == "1.1.1.1:853")
|
||||||
|
startTimeOK = (started.Sub(zeroTime) == 0)
|
||||||
|
finishTimeOK = (finished.Sub(zeroTime) == time.Second)
|
||||||
|
wrappedErr = errors.As(err, &ew) && ew.Failure == FailureEOFError
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := ContextWithTrace(context.Background(), tx)
|
||||||
|
d := &dialerResolverWithTracing{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
return nil, io.EOF
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Resolver: &NullResolver{},
|
||||||
|
}
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", "1.1.1.1:853")
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
if !domainOK {
|
||||||
|
t.Fatal("domain was not okay")
|
||||||
|
}
|
||||||
|
if !networkOK {
|
||||||
|
t.Fatal("network was not okay")
|
||||||
|
}
|
||||||
|
if !remoteAddrOK {
|
||||||
|
t.Fatal("remoteAddr was not okay")
|
||||||
|
}
|
||||||
|
if !startTimeOK {
|
||||||
|
t.Fatal("start time was not okay")
|
||||||
|
}
|
||||||
|
if !finishTimeOK {
|
||||||
|
t.Fatal("finish time was not okay")
|
||||||
|
}
|
||||||
|
if !wrappedErr {
|
||||||
|
t.Fatal("not wrapped")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("lookupHost", func(t *testing.T) {
|
t.Run("lookupHost", func(t *testing.T) {
|
||||||
t.Run("handles addresses correctly", func(t *testing.T) {
|
t.Run("handles addresses correctly", func(t *testing.T) {
|
||||||
dialer := &dialerResolver{
|
dialer := &dialerResolverWithTracing{
|
||||||
Dialer: &DialerSystem{},
|
Dialer: &DialerSystem{},
|
||||||
Resolver: &NullResolver{},
|
Resolver: &NullResolver{},
|
||||||
}
|
}
|
||||||
|
@ -393,7 +501,7 @@ func TestDialerResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("fails correctly on lookup error", func(t *testing.T) {
|
t.Run("fails correctly on lookup error", func(t *testing.T) {
|
||||||
dialer := &dialerResolver{
|
dialer := &dialerResolverWithTracing{
|
||||||
Dialer: &DialerSystem{},
|
Dialer: &DialerSystem{},
|
||||||
Resolver: &NullResolver{},
|
Resolver: &NullResolver{},
|
||||||
}
|
}
|
||||||
|
@ -413,7 +521,7 @@ func TestDialerResolver(t *testing.T) {
|
||||||
calledDialer bool
|
calledDialer bool
|
||||||
calledResolver bool
|
calledResolver bool
|
||||||
)
|
)
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockCloseIdleConnections: func() {
|
MockCloseIdleConnections: func() {
|
||||||
calledDialer = true
|
calledDialer = true
|
||||||
|
|
|
@ -38,7 +38,7 @@ func (t *dnsTransportErrWrapper) RoundTrip(
|
||||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
resp, err := t.DNSTransport.RoundTrip(ctx, query)
|
resp, err := t.DNSTransport.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
|
return nil, NewErrWrapper(ClassifyResolverError, DNSRoundTripOperation, err)
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,10 @@
|
||||||
//
|
//
|
||||||
// We want to have reasonable watchdog timeouts for each operation.
|
// We want to have reasonable watchdog timeouts for each operation.
|
||||||
//
|
//
|
||||||
|
// We also want lightweight support for tracing network events. To this end, we
|
||||||
|
// use context.WithValue and context.Value to inject, and retrieve, a model.Trace
|
||||||
|
// implementation OPTIONALLY configured by the user.
|
||||||
|
//
|
||||||
// See also the design document at docs/design/dd-003-step-by-step.md,
|
// See also the design document at docs/design/dd-003-step-by-step.md,
|
||||||
// which provides an overview of netxlite's main concerns.
|
// which provides an overview of netxlite's main concerns.
|
||||||
//
|
//
|
||||||
|
|
|
@ -71,7 +71,7 @@ func (e *ErrWrapper) MarshalJSON() ([]byte, error) {
|
||||||
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
|
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
|
||||||
type classifier func(err error) string
|
type classifier func(err error) string
|
||||||
|
|
||||||
// newErrWrapper creates a new ErrWrapper using the given
|
// NewErrWrapper creates a new ErrWrapper using the given
|
||||||
// classifier, operation name, and underlying error.
|
// classifier, operation name, and underlying error.
|
||||||
//
|
//
|
||||||
// This function panics if classifier is nil, or operation
|
// This function panics if classifier is nil, or operation
|
||||||
|
@ -81,7 +81,7 @@ type classifier func(err error) string
|
||||||
// error wrapper will use the same classification string and
|
// error wrapper will use the same classification string and
|
||||||
// will determine whether to keep the major operation as documented
|
// will determine whether to keep the major operation as documented
|
||||||
// in the ErrWrapper.Operation documentation.
|
// in the ErrWrapper.Operation documentation.
|
||||||
func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
func NewErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||||
var wrapper *ErrWrapper
|
var wrapper *ErrWrapper
|
||||||
if errors.As(err, &wrapper) {
|
if errors.As(err, &wrapper) {
|
||||||
return &ErrWrapper{
|
return &ErrWrapper{
|
||||||
|
@ -106,6 +106,19 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(https://github.com/ooni/probe/issues/2163): we can really
|
||||||
|
// simplify the error wrapping situation here by just dropping
|
||||||
|
// NewErrWrapper and always using MaybeNewErrWrapper.
|
||||||
|
|
||||||
|
// MaybeNewErrWrapper is like NewErrWrapper except that this
|
||||||
|
// function won't panic if passed a nil error.
|
||||||
|
func MaybeNewErrWrapper(c classifier, op string, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return NewErrWrapper(c, op, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewTopLevelGenericErrWrapper wraps an error occurring at top
|
// NewTopLevelGenericErrWrapper wraps an error occurring at top
|
||||||
// level using a generic classifier as classifier. This is the
|
// level using a generic classifier as classifier. This is the
|
||||||
// function you should call when you suspect a given error hasn't
|
// function you should call when you suspect a given error hasn't
|
||||||
|
@ -115,7 +128,7 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||||
// error wrapper will use the same classification string and
|
// error wrapper will use the same classification string and
|
||||||
// failed operation of the original error.
|
// failed operation of the original error.
|
||||||
func NewTopLevelGenericErrWrapper(err error) *ErrWrapper {
|
func NewTopLevelGenericErrWrapper(err error) *ErrWrapper {
|
||||||
return newErrWrapper(classifyGenericError, TopLevelOperation, err)
|
return NewErrWrapper(ClassifyGenericError, TopLevelOperation, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func classifyOperation(ew *ErrWrapper, operation string) string {
|
func classifyOperation(ew *ErrWrapper, operation string) string {
|
||||||
|
|
|
@ -53,7 +53,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||||
recovered.Add(1)
|
recovered.Add(1)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
newErrWrapper(nil, CloseOperation, io.EOF)
|
NewErrWrapper(nil, CloseOperation, io.EOF)
|
||||||
}()
|
}()
|
||||||
if recovered.Load() != 1 {
|
if recovered.Load() != 1 {
|
||||||
t.Fatal("did not panic")
|
t.Fatal("did not panic")
|
||||||
|
@ -68,7 +68,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||||
recovered.Add(1)
|
recovered.Add(1)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
newErrWrapper(classifyGenericError, "", io.EOF)
|
NewErrWrapper(ClassifyGenericError, "", io.EOF)
|
||||||
}()
|
}()
|
||||||
if recovered.Load() != 1 {
|
if recovered.Load() != 1 {
|
||||||
t.Fatal("did not panic")
|
t.Fatal("did not panic")
|
||||||
|
@ -83,7 +83,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||||
recovered.Add(1)
|
recovered.Add(1)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
newErrWrapper(classifyGenericError, CloseOperation, nil)
|
NewErrWrapper(ClassifyGenericError, CloseOperation, nil)
|
||||||
}()
|
}()
|
||||||
if recovered.Load() != 1 {
|
if recovered.Load() != 1 {
|
||||||
t.Fatal("did not panic")
|
t.Fatal("did not panic")
|
||||||
|
@ -91,7 +91,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("otherwise, works as intended", func(t *testing.T) {
|
t.Run("otherwise, works as intended", func(t *testing.T) {
|
||||||
ew := newErrWrapper(classifyGenericError, CloseOperation, io.EOF)
|
ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF)
|
||||||
if ew.Failure != FailureEOFError {
|
if ew.Failure != FailureEOFError {
|
||||||
t.Fatal("unexpected failure")
|
t.Fatal("unexpected failure")
|
||||||
}
|
}
|
||||||
|
@ -104,10 +104,10 @@ func TestNewErrWrapper(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("when the underlying error is already a wrapped error", func(t *testing.T) {
|
t.Run("when the underlying error is already a wrapped error", func(t *testing.T) {
|
||||||
ew := newErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
ew := NewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||||
var err1 error = ew
|
var err1 error = ew
|
||||||
err2 := fmt.Errorf("cannot read: %w", err1)
|
err2 := fmt.Errorf("cannot read: %w", err1)
|
||||||
ew2 := newErrWrapper(classifyGenericError, HTTPRoundTripOperation, err2)
|
ew2 := NewErrWrapper(ClassifyGenericError, HTTPRoundTripOperation, err2)
|
||||||
if ew2.Failure != ew.Failure {
|
if ew2.Failure != ew.Failure {
|
||||||
t.Fatal("not the same failure")
|
t.Fatal("not the same failure")
|
||||||
}
|
}
|
||||||
|
@ -117,6 +117,34 @@ func TestNewErrWrapper(t *testing.T) {
|
||||||
if ew2.WrappedErr != err2 {
|
if ew2.WrappedErr != err2 {
|
||||||
t.Fatal("invalid underlying error")
|
t.Fatal("invalid underlying error")
|
||||||
}
|
}
|
||||||
|
// Make sure we can still use errors.Is with two layers of wrapping
|
||||||
|
if !errors.Is(ew2, ECONNRESET) {
|
||||||
|
t.Fatal("we cannot use errors.Is to retrieve the real syscall error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMaybeNewErrWrapper(t *testing.T) {
|
||||||
|
// TODO(https://github.com/ooni/probe/issues/2163): we can really
|
||||||
|
// simplify the error wrapping situation here by just dropping
|
||||||
|
// NewErrWrapper and always using MaybeNewErrWrapper.
|
||||||
|
|
||||||
|
t.Run("when we pass a nil error to this function", func(t *testing.T) {
|
||||||
|
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("unexpected output", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when we pass a non-nil error to this function", func(t *testing.T) {
|
||||||
|
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||||
|
if !errors.Is(err, ECONNRESET) {
|
||||||
|
t.Fatal("unexpected output", err)
|
||||||
|
}
|
||||||
|
var ew *ErrWrapper
|
||||||
|
if !errors.As(err, &ew) {
|
||||||
|
t.Fatal("not an instance of ErrWrapper")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
|
||||||
dialer := txpCc.Dialer
|
dialer := txpCc.Dialer
|
||||||
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
||||||
dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger)
|
dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger)
|
||||||
dialerReso := dialerLog.Dialer.(*dialerResolver)
|
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
|
||||||
if dialerReso.Resolver != resolver {
|
if dialerReso.Resolver != resolver {
|
||||||
t.Fatal("invalid resolver")
|
t.Fatal("invalid resolver")
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
|
||||||
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
||||||
dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer)
|
dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer)
|
||||||
dialerLog := dialerProxy.Dialer.(*dialerLogger)
|
dialerLog := dialerProxy.Dialer.(*dialerLogger)
|
||||||
dialerReso := dialerLog.Dialer.(*dialerResolver)
|
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
|
||||||
if dialerReso.Resolver != resolver {
|
if dialerReso.Resolver != resolver {
|
||||||
t.Fatal("invalid resolver")
|
t.Fatal("invalid resolver")
|
||||||
}
|
}
|
||||||
|
@ -269,7 +269,7 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||||
t.Run("works as intended with failing dialer", func(t *testing.T) {
|
t.Run("works as intended with failing dialer", func(t *testing.T) {
|
||||||
called := &atomicx.Int64{}
|
called := &atomicx.Int64{}
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
d := &dialerResolver{
|
d := &dialerResolverWithTracing{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context,
|
MockDialContext: func(ctx context.Context,
|
||||||
network, address string) (net.Conn, error) {
|
network, address string) (net.Conn, error) {
|
||||||
|
@ -612,7 +612,7 @@ func TestNewHTTPClientWithResolver(t *testing.T) {
|
||||||
txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser)
|
txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser)
|
||||||
dialer := txpCc.Dialer.(*httpDialerWithReadTimeout)
|
dialer := txpCc.Dialer.(*httpDialerWithReadTimeout)
|
||||||
dialerLogger := dialer.Dialer.(*dialerLogger)
|
dialerLogger := dialer.Dialer.(*dialerLogger)
|
||||||
dialerReso := dialerLogger.Dialer.(*dialerResolver)
|
dialerReso := dialerLogger.Dialer.(*dialerResolverWithTracing)
|
||||||
if dialerReso.Resolver != reso {
|
if dialerReso.Resolver != reso {
|
||||||
t.Fatal("invalid resolver")
|
t.Fatal("invalid resolver")
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ func TestReadAllContext(t *testing.T) {
|
||||||
//
|
//
|
||||||
// Note: Returning a wrapped error to ensure we address
|
// Note: Returning a wrapped error to ensure we address
|
||||||
// https://github.com/ooni/probe/issues/1965
|
// https://github.com/ooni/probe/issues/1965
|
||||||
return len(b), newErrWrapper(classifyGenericError,
|
return len(b), NewErrWrapper(ClassifyGenericError,
|
||||||
ReadOperation, io.EOF)
|
ReadOperation, io.EOF)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -171,7 +171,7 @@ func TestCopyContext(t *testing.T) {
|
||||||
//
|
//
|
||||||
// Note: Returning a wrapped error to ensure we address
|
// Note: Returning a wrapped error to ensure we address
|
||||||
// https://github.com/ooni/probe/issues/1965
|
// https://github.com/ooni/probe/issues/1965
|
||||||
return len(b), newErrWrapper(classifyGenericError,
|
return len(b), NewErrWrapper(ClassifyGenericError,
|
||||||
ReadOperation, io.EOF)
|
ReadOperation, io.EOF)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -380,7 +380,7 @@ var _ model.QUICListener = &quicListenerErrWrapper{}
|
||||||
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
|
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
|
||||||
pconn, err := qls.QUICListener.Listen(addr)
|
pconn, err := qls.QUICListener.Listen(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(classifyGenericError, QUICListenOperation, err)
|
return nil, NewErrWrapper(ClassifyGenericError, QUICListenOperation, err)
|
||||||
}
|
}
|
||||||
return &quicErrWrapperUDPLikeConn{pconn}, nil
|
return &quicErrWrapperUDPLikeConn{pconn}, nil
|
||||||
}
|
}
|
||||||
|
@ -397,7 +397,7 @@ var _ model.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
|
||||||
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||||
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, newErrWrapper(classifyGenericError, WriteToOperation, err)
|
return 0, NewErrWrapper(ClassifyGenericError, WriteToOperation, err)
|
||||||
}
|
}
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
@ -406,7 +406,7 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error
|
||||||
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, newErrWrapper(classifyGenericError, ReadFromOperation, err)
|
return 0, nil, NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
|
||||||
}
|
}
|
||||||
return n, addr, nil
|
return n, addr, nil
|
||||||
}
|
}
|
||||||
|
@ -415,7 +415,7 @@ func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
func (c *quicErrWrapperUDPLikeConn) Close() error {
|
func (c *quicErrWrapperUDPLikeConn) Close() error {
|
||||||
err := c.UDPLikeConn.Close()
|
err := c.UDPLikeConn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return newErrWrapper(classifyGenericError, ReadFromOperation, err)
|
return NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -433,8 +433,8 @@ func (d *quicDialerErrWrapper) DialContext(
|
||||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||||
qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(
|
return nil, NewErrWrapper(
|
||||||
classifyQUICHandshakeError, QUICHandshakeOperation, err)
|
ClassifyQUICHandshakeError, QUICHandshakeOperation, err)
|
||||||
}
|
}
|
||||||
return qconn, nil
|
return qconn, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,13 +34,13 @@ func TestQuirkReduceErrors(t *testing.T) {
|
||||||
|
|
||||||
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
||||||
err1 := errors.New("mocked error #1")
|
err1 := errors.New("mocked error #1")
|
||||||
err2 := newErrWrapper(
|
err2 := NewErrWrapper(
|
||||||
classifyGenericError,
|
ClassifyGenericError,
|
||||||
CloseOperation,
|
CloseOperation,
|
||||||
errors.New("antani"),
|
errors.New("antani"),
|
||||||
)
|
)
|
||||||
err3 := newErrWrapper(
|
err3 := NewErrWrapper(
|
||||||
classifyGenericError,
|
ClassifyGenericError,
|
||||||
CloseOperation,
|
CloseOperation,
|
||||||
ECONNREFUSED,
|
ECONNREFUSED,
|
||||||
)
|
)
|
||||||
|
|
|
@ -387,7 +387,7 @@ var _ model.Resolver = &resolverErrWrapper{}
|
||||||
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||||
}
|
}
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
@ -396,7 +396,7 @@ func (r *resolverErrWrapper) LookupHTTPS(
|
||||||
ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||||
out, err := r.Resolver.LookupHTTPS(ctx, domain)
|
out, err := r.Resolver.LookupHTTPS(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
@ -417,7 +417,7 @@ func (r *resolverErrWrapper) LookupNS(
|
||||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||||
out, err := r.Resolver.LookupNS(ctx, domain)
|
out, err := r.Resolver.LookupNS(ctx, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,9 +18,6 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO(bassosimone): check whether there's now equivalent functionality
|
|
||||||
// inside the standard library allowing us to map numbers to names.
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tlsVersionString = map[uint16]string{
|
tlsVersionString = map[uint16]string{
|
||||||
tls.VersionTLS10: "TLSv1",
|
tls.VersionTLS10: "TLSv1",
|
||||||
|
@ -85,6 +82,13 @@ func TLSVersionString(value uint16) string {
|
||||||
// the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd`
|
// the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd`
|
||||||
// where `ddd` is the numeric value passed to this function.
|
// where `ddd` is the numeric value passed to this function.
|
||||||
func TLSCipherSuiteString(value uint16) string {
|
func TLSCipherSuiteString(value uint16) string {
|
||||||
|
// TODO(https://github.com/ooni/probe/issues/2166): the standard library has a
|
||||||
|
// function for mapping a cipher suite to a string, but the value returned in case of
|
||||||
|
// missing cipher suite is different from the one we would return
|
||||||
|
// here. We could consider simplifying this code anyway because
|
||||||
|
// in most, if not all, cases we have a valid cipher suite and we
|
||||||
|
// just need to make sure what the spec says we should do when
|
||||||
|
// passed an unknown cipher suite.
|
||||||
if str, found := tlsCipherSuiteString[value]; found {
|
if str, found := tlsCipherSuiteString[value]; found {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
@ -158,15 +162,15 @@ func NewTLSHandshakerStdlib(logger model.DebugLogger) model.TLSHandshaker {
|
||||||
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
|
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
|
||||||
func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker {
|
func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker {
|
||||||
return &tlsHandshakerLogger{
|
return &tlsHandshakerLogger{
|
||||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
|
||||||
TLSHandshaker: th,
|
TLSHandshaker: th,
|
||||||
},
|
|
||||||
DebugLogger: logger,
|
DebugLogger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsHandshakerConfigurable is a configurable TLS handshaker that
|
// tlsHandshakerConfigurable is a configurable TLS handshaker that
|
||||||
// uses by default the standard library's TLS implementation.
|
// uses by default the standard library's TLS implementation.
|
||||||
|
//
|
||||||
|
// This type also implements error wrapping and events tracing.
|
||||||
type tlsHandshakerConfigurable struct {
|
type tlsHandshakerConfigurable struct {
|
||||||
// NewConn is the OPTIONAL factory for creating a new connection. If
|
// NewConn is the OPTIONAL factory for creating a new connection. If
|
||||||
// this factory is not set, we'll use the stdlib.
|
// this factory is not set, we'll use the stdlib.
|
||||||
|
@ -183,9 +187,20 @@ var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}
|
||||||
// value into a private variable to enable for unit testing.
|
// value into a private variable to enable for unit testing.
|
||||||
var defaultCertPool = NewDefaultCertPool()
|
var defaultCertPool = NewDefaultCertPool()
|
||||||
|
|
||||||
|
// tlsMaybeConnectionState returns the connection state if error is nil
|
||||||
|
// and otherwise just returns an empty state to the caller.
|
||||||
|
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
|
||||||
|
if err != nil {
|
||||||
|
return tls.ConnectionState{}
|
||||||
|
}
|
||||||
|
return conn.ConnectionState()
|
||||||
|
}
|
||||||
|
|
||||||
// Handshake implements Handshaker.Handshake. This function will
|
// Handshake implements Handshaker.Handshake. This function will
|
||||||
// configure the code to use the built-in Mozilla CA if the config
|
// configure the code to use the built-in Mozilla CA if the config
|
||||||
// field contains a nil RootCAs field.
|
// field contains a nil RootCAs field.
|
||||||
|
//
|
||||||
|
// This function will also emit TLS-handshake-related tracing events.
|
||||||
func (h *tlsHandshakerConfigurable) Handshake(
|
func (h *tlsHandshakerConfigurable) Handshake(
|
||||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||||
) (net.Conn, tls.ConnectionState, error) {
|
) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
@ -203,10 +218,19 @@ func (h *tlsHandshakerConfigurable) Handshake(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, tls.ConnectionState{}, err
|
return nil, tls.ConnectionState{}, err
|
||||||
}
|
}
|
||||||
if err := tlsconn.HandshakeContext(ctx); err != nil {
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
|
trace := ContextTraceOrDefault(ctx)
|
||||||
|
started := trace.TimeNow()
|
||||||
|
trace.OnTLSHandshakeStart(started, remoteAddr, config)
|
||||||
|
err = tlsconn.HandshakeContext(ctx)
|
||||||
|
err = MaybeNewErrWrapper(ClassifyTLSHandshakeError, TLSHandshakeOperation, err)
|
||||||
|
finished := trace.TimeNow()
|
||||||
|
state := tlsMaybeConnectionState(tlsconn, err)
|
||||||
|
trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
|
||||||
|
if err != nil {
|
||||||
return nil, tls.ConnectionState{}, err
|
return nil, tls.ConnectionState{}, err
|
||||||
}
|
}
|
||||||
return tlsconn, tlsconn.ConnectionState(), nil
|
return tlsconn, state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newConn creates a new TLSConn.
|
// newConn creates a new TLSConn.
|
||||||
|
@ -352,23 +376,6 @@ func (d *tlsDialerSingleUseAdapter) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
|
|
||||||
type tlsHandshakerErrWrapper struct {
|
|
||||||
TLSHandshaker model.TLSHandshaker
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handshake implements TLSHandshaker.Handshake
|
|
||||||
func (h *tlsHandshakerErrWrapper) Handshake(
|
|
||||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
|
||||||
) (net.Conn, tls.ConnectionState, error) {
|
|
||||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, tls.ConnectionState{}, newErrWrapper(
|
|
||||||
classifyTLSHandshakeError, TLSHandshakeOperation, err)
|
|
||||||
}
|
|
||||||
return tlsconn, state, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrNoTLSDialer is the type of error returned by "null" TLS dialers
|
// ErrNoTLSDialer is the type of error returned by "null" TLS dialers
|
||||||
// when you attempt to dial with them.
|
// when you attempt to dial with them.
|
||||||
var ErrNoTLSDialer = errors.New("no configured TLS dialer")
|
var ErrNoTLSDialer = errors.New("no configured TLS dialer")
|
||||||
|
|
|
@ -16,7 +16,10 @@ import (
|
||||||
|
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVersionString(t *testing.T) {
|
func TestVersionString(t *testing.T) {
|
||||||
|
@ -123,8 +126,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||||
if logger.DebugLogger != log.Log {
|
if logger.DebugLogger != log.Log {
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
|
||||||
if configurable.NewConn != nil {
|
if configurable.NewConn != nil {
|
||||||
t.Fatal("expected nil NewConn")
|
t.Fatal("expected nil NewConn")
|
||||||
}
|
}
|
||||||
|
@ -132,7 +134,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||||
|
|
||||||
func TestTLSHandshakerConfigurable(t *testing.T) {
|
func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
t.Run("Handshake", func(t *testing.T) {
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
t.Run("with error", func(t *testing.T) {
|
t.Run("with handshake I/O error", func(t *testing.T) {
|
||||||
var times []time.Time
|
var times []time.Time
|
||||||
h := &tlsHandshakerConfigurable{}
|
h := &tlsHandshakerConfigurable{}
|
||||||
tcpConn := &mocks.Conn{
|
tcpConn := &mocks.Conn{
|
||||||
|
@ -143,14 +145,37 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
times = append(times, t)
|
times = append(times, t)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||||
ServerName: "x.org",
|
ServerName: "x.org",
|
||||||
})
|
})
|
||||||
if err != io.EOF {
|
if !errors.Is(err, io.EOF) {
|
||||||
t.Fatal("not the error that we expected")
|
t.Fatal("not the error that we expected")
|
||||||
}
|
}
|
||||||
|
var errWrapper *ErrWrapper
|
||||||
|
if !errors.As(err, &errWrapper) {
|
||||||
|
t.Fatal("the error has not been wrapped")
|
||||||
|
}
|
||||||
|
if errWrapper.Failure != FailureEOFError {
|
||||||
|
t.Fatal("invalid wrapped error's failure")
|
||||||
|
}
|
||||||
|
if errWrapper.Operation != TLSHandshakeOperation {
|
||||||
|
t.Fatal("invalid wrapped error's operation")
|
||||||
|
}
|
||||||
|
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||||
|
t.Fatal("invalid wrapped error's underlying error")
|
||||||
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil con here")
|
t.Fatal("expected nil con here")
|
||||||
}
|
}
|
||||||
|
@ -163,6 +188,9 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
if !times[1].IsZero() {
|
if !times[1].IsZero() {
|
||||||
t.Fatal("did not clear timeout on exit")
|
t.Fatal("did not clear timeout on exit")
|
||||||
}
|
}
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("the returned connection state is not a zero value")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with success", func(t *testing.T) {
|
t.Run("with success", func(t *testing.T) {
|
||||||
|
@ -217,6 +245,16 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
MockSetDeadline: func(t time.Time) error {
|
MockSetDeadline: func(t time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
|
@ -236,7 +274,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("we cannot create a new conn", func(t *testing.T) {
|
t.Run("h.newConn fails", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
handshaker := &tlsHandshakerConfigurable{
|
handshaker := &tlsHandshakerConfigurable{
|
||||||
NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) {
|
NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) {
|
||||||
|
@ -261,6 +299,222 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
t.Fatal("expected nil tlsConn here")
|
t.Fatal("expected nil tlsConn here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
expectedSNI = "dns.google"
|
||||||
|
goodStartStartTime bool
|
||||||
|
goodStartInsecureSkipVerify bool
|
||||||
|
goodDoneInsecureSkipVerify bool
|
||||||
|
goodStartServerName bool
|
||||||
|
goodDoneServerName bool
|
||||||
|
goodDoneStartTime bool
|
||||||
|
goodDoneDoneTime bool
|
||||||
|
goodStartRemoteAddr bool
|
||||||
|
goodDoneRemoteAddr bool
|
||||||
|
goodDoneError bool
|
||||||
|
goodConnectionState bool
|
||||||
|
startCalled bool
|
||||||
|
doneCalled bool
|
||||||
|
)
|
||||||
|
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||||
|
defer server.Close()
|
||||||
|
zeroTime := time.Now()
|
||||||
|
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
tx := &mocks.Trace{
|
||||||
|
MockTimeNow: deterministicTime.Now,
|
||||||
|
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||||
|
startCalled = true
|
||||||
|
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||||
|
goodStartServerName = (config.ServerName == expectedSNI)
|
||||||
|
goodStartStartTime = (now.Sub(zeroTime) == 0)
|
||||||
|
goodStartRemoteAddr = (remoteAddr == server.Endpoint())
|
||||||
|
},
|
||||||
|
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||||
|
doneCalled = true
|
||||||
|
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||||
|
goodDoneServerName = (config.ServerName == expectedSNI)
|
||||||
|
goodDoneStartTime = (started.Sub(zeroTime) == 0)
|
||||||
|
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
|
||||||
|
goodDoneRemoteAddr = (remoteAddr == server.Endpoint())
|
||||||
|
goodDoneError = (err == nil)
|
||||||
|
goodConnectionState = (!reflect.ValueOf(state).IsZero())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := ContextWithTrace(context.Background(), tx)
|
||||||
|
tcpConn, err := net.Dial("tcp", server.Endpoint())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: expectedSNI,
|
||||||
|
}
|
||||||
|
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tlsConn.Close()
|
||||||
|
if reflect.ValueOf(connState).IsZero() {
|
||||||
|
t.Fatal("expected nonzero connState")
|
||||||
|
}
|
||||||
|
if !startCalled {
|
||||||
|
t.Fatal("start not called")
|
||||||
|
}
|
||||||
|
if !doneCalled {
|
||||||
|
t.Fatal("done not called")
|
||||||
|
}
|
||||||
|
if !goodStartInsecureSkipVerify {
|
||||||
|
t.Fatal("invalid start-event's InsecureSkipVerify")
|
||||||
|
}
|
||||||
|
if !goodDoneInsecureSkipVerify {
|
||||||
|
t.Fatal("invalid done-event's InsecureSkipVerify")
|
||||||
|
}
|
||||||
|
if !goodStartServerName {
|
||||||
|
t.Fatal("invalid start-event's ServerName")
|
||||||
|
}
|
||||||
|
if !goodDoneServerName {
|
||||||
|
t.Fatal("invalid done-event's ServerName")
|
||||||
|
}
|
||||||
|
if !goodStartStartTime {
|
||||||
|
t.Fatal("invalid start-event's start time")
|
||||||
|
}
|
||||||
|
if !goodDoneStartTime {
|
||||||
|
t.Fatal("invalid done-event's start time")
|
||||||
|
}
|
||||||
|
if !goodDoneDoneTime {
|
||||||
|
t.Fatal("invalid done-event's done time")
|
||||||
|
}
|
||||||
|
if !goodStartRemoteAddr {
|
||||||
|
t.Fatal("invalid start-event's remoteAddr")
|
||||||
|
}
|
||||||
|
if !goodDoneRemoteAddr {
|
||||||
|
t.Fatal("invalid done-event's remoteAddr")
|
||||||
|
}
|
||||||
|
if !goodDoneError {
|
||||||
|
t.Fatal("invalid done-event's error")
|
||||||
|
}
|
||||||
|
if !goodConnectionState {
|
||||||
|
t.Fatal("invalid done-event's connState")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
expectedEndpoint = "8.8.8.8:443"
|
||||||
|
expectedSNI = "dns.google"
|
||||||
|
goodStartStartTime bool
|
||||||
|
goodStartInsecureSkipVerify bool
|
||||||
|
goodDoneInsecureSkipVerify bool
|
||||||
|
goodStartServerName bool
|
||||||
|
goodDoneServerName bool
|
||||||
|
goodDoneStartTime bool
|
||||||
|
goodDoneDoneTime bool
|
||||||
|
goodStartRemoteAddr bool
|
||||||
|
goodDoneRemoteAddr bool
|
||||||
|
goodDoneError bool
|
||||||
|
goodConnectionState bool
|
||||||
|
startCalled bool
|
||||||
|
doneCalled bool
|
||||||
|
)
|
||||||
|
zeroTime := time.Now()
|
||||||
|
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
tx := &mocks.Trace{
|
||||||
|
MockTimeNow: deterministicTime.Now,
|
||||||
|
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||||
|
startCalled = true
|
||||||
|
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||||
|
goodStartServerName = (config.ServerName == expectedSNI)
|
||||||
|
goodStartStartTime = (now.Sub(zeroTime) == 0)
|
||||||
|
goodStartRemoteAddr = (remoteAddr == expectedEndpoint)
|
||||||
|
},
|
||||||
|
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||||
|
doneCalled = true
|
||||||
|
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||||
|
goodDoneServerName = (config.ServerName == expectedSNI)
|
||||||
|
goodDoneStartTime = (started.Sub(zeroTime) == 0)
|
||||||
|
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
|
||||||
|
goodDoneRemoteAddr = (remoteAddr == expectedEndpoint)
|
||||||
|
var ew *ErrWrapper
|
||||||
|
goodDoneError = (errors.As(err, &ew) && ew.Error() == FailureEOFError)
|
||||||
|
goodConnectionState = (reflect.ValueOf(state).IsZero())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := ContextWithTrace(context.Background(), tx)
|
||||||
|
tcpConn := &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return expectedEndpoint
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: expectedSNI,
|
||||||
|
}
|
||||||
|
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if tlsConn != nil {
|
||||||
|
t.Fatal("expected nil tlsConn")
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
|
t.Fatal("expected zero connState")
|
||||||
|
}
|
||||||
|
if !startCalled {
|
||||||
|
t.Fatal("start not called")
|
||||||
|
}
|
||||||
|
if !doneCalled {
|
||||||
|
t.Fatal("done not called")
|
||||||
|
}
|
||||||
|
if !goodStartInsecureSkipVerify {
|
||||||
|
t.Fatal("invalid start-event's InsecureSkipVerify")
|
||||||
|
}
|
||||||
|
if !goodDoneInsecureSkipVerify {
|
||||||
|
t.Fatal("invalid done-event's InsecureSkipVerify")
|
||||||
|
}
|
||||||
|
if !goodStartServerName {
|
||||||
|
t.Fatal("invalid start-event's ServerName")
|
||||||
|
}
|
||||||
|
if !goodDoneServerName {
|
||||||
|
t.Fatal("invalid done-event's ServerName")
|
||||||
|
}
|
||||||
|
if !goodStartStartTime {
|
||||||
|
t.Fatal("invalid start-event's start time")
|
||||||
|
}
|
||||||
|
if !goodDoneStartTime {
|
||||||
|
t.Fatal("invalid done-event's start time")
|
||||||
|
}
|
||||||
|
if !goodDoneDoneTime {
|
||||||
|
t.Fatal("invalid done-event's done time")
|
||||||
|
}
|
||||||
|
if !goodStartRemoteAddr {
|
||||||
|
t.Fatal("invalid start-event's remoteAddr")
|
||||||
|
}
|
||||||
|
if !goodDoneRemoteAddr {
|
||||||
|
t.Fatal("invalid done-event's remoteAddr")
|
||||||
|
}
|
||||||
|
if !goodDoneError {
|
||||||
|
t.Fatal("invalid done-event's error")
|
||||||
|
}
|
||||||
|
if !goodConnectionState {
|
||||||
|
t.Fatal("invalid done-event's connState")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,6 +667,15 @@ func TestTLSDialer(t *testing.T) {
|
||||||
return nil
|
return nil
|
||||||
}, MockSetDeadline: func(t time.Time) error {
|
}, MockSetDeadline: func(t time.Time) error {
|
||||||
return nil
|
return nil
|
||||||
|
}, MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "1.1.1.1:443"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
}}, nil
|
}}, nil
|
||||||
}},
|
}},
|
||||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
TLSHandshaker: &tlsHandshakerConfigurable{},
|
||||||
|
@ -532,54 +795,6 @@ func TestNewSingleUseTLSDialer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerErrWrapper(t *testing.T) {
|
|
||||||
t.Run("Handshake", func(t *testing.T) {
|
|
||||||
t.Run("on success", func(t *testing.T) {
|
|
||||||
expectedConn := &mocks.TLSConn{}
|
|
||||||
expectedState := tls.ConnectionState{
|
|
||||||
Version: tls.VersionTLS12,
|
|
||||||
}
|
|
||||||
th := &tlsHandshakerErrWrapper{
|
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
|
||||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
|
||||||
return expectedConn, expectedState, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if expectedState.Version != state.Version {
|
|
||||||
t.Fatal("unexpected state")
|
|
||||||
}
|
|
||||||
if expectedConn != conn {
|
|
||||||
t.Fatal("unexpected conn")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("on failure", func(t *testing.T) {
|
|
||||||
expectedErr := io.EOF
|
|
||||||
th := &tlsHandshakerErrWrapper{
|
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
|
||||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
|
||||||
return nil, tls.ConnectionState{}, expectedErr
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
|
||||||
if err == nil || err.Error() != FailureEOFError {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("unexpected conn")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewNullTLSDialer(t *testing.T) {
|
func TestNewNullTLSDialer(t *testing.T) {
|
||||||
dialer := NewNullTLSDialer()
|
dialer := NewNullTLSDialer()
|
||||||
conn, err := dialer.DialTLSContext(context.Background(), "", "")
|
conn, err := dialer.DialTLSContext(context.Background(), "", "")
|
||||||
|
@ -618,3 +833,35 @@ func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMaybeConnectionState(t *testing.T) {
|
||||||
|
t.Run("with an error", func(t *testing.T) {
|
||||||
|
returned := tls.ConnectionState{
|
||||||
|
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||||
|
}
|
||||||
|
conn := &mocks.TLSConn{
|
||||||
|
MockConnectionState: func() tls.ConnectionState {
|
||||||
|
return returned
|
||||||
|
},
|
||||||
|
}
|
||||||
|
state := tlsMaybeConnectionState(conn, errors.New("mocked error"))
|
||||||
|
if !reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("expected to see a zero connection state")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("without an error", func(t *testing.T) {
|
||||||
|
returned := tls.ConnectionState{
|
||||||
|
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||||
|
}
|
||||||
|
conn := &mocks.TLSConn{
|
||||||
|
MockConnectionState: func() tls.ConnectionState {
|
||||||
|
return returned
|
||||||
|
},
|
||||||
|
}
|
||||||
|
state := tlsMaybeConnectionState(conn, nil)
|
||||||
|
if reflect.ValueOf(state).IsZero() {
|
||||||
|
t.Fatal("expected to see a nonzero connection state")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
67
internal/netxlite/trace.go
Normal file
67
internal/netxlite/trace.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Context-based tracing
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
|
)
|
||||||
|
|
||||||
|
// traceKey is the private type used to set/retrieve the context's trace.
|
||||||
|
type traceKey struct{}
|
||||||
|
|
||||||
|
// ContextTraceOrDefault retrieves the trace bound to the context or returns
|
||||||
|
// a default implementation of the trace in case no tracing was configured.
|
||||||
|
func ContextTraceOrDefault(ctx context.Context) model.Trace {
|
||||||
|
t, _ := ctx.Value(traceKey{}).(model.Trace)
|
||||||
|
return traceOrDefault(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextWithTrace returns a new context that binds to the given trace. If the
|
||||||
|
// given trace is nil, this function will call panic.
|
||||||
|
func ContextWithTrace(ctx context.Context, trace model.Trace) context.Context {
|
||||||
|
runtimex.PanicIfTrue(trace == nil, "netxlite.WithTrace passed a nil trace")
|
||||||
|
return context.WithValue(ctx, traceKey{}, trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
// traceOrDefault takes in input a trace and returns in output the
|
||||||
|
// given trace, if not nil, or a default trace implementation.
|
||||||
|
func traceOrDefault(trace model.Trace) model.Trace {
|
||||||
|
if trace != nil {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
return &traceDefault{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// traceDefault is a default model.Trace implementation where each method is a no-op.
|
||||||
|
type traceDefault struct{}
|
||||||
|
|
||||||
|
var _ model.Trace = &traceDefault{}
|
||||||
|
|
||||||
|
// TimeNow implements model.Trace.TimeNow
|
||||||
|
func (*traceDefault) TimeNow() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConnectDone implements model.Trace.OnConnectDone.
|
||||||
|
func (*traceDefault) OnConnectDone(
|
||||||
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
|
// nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
|
||||||
|
func (*traceDefault) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||||
|
// nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
|
||||||
|
func (*traceDefault) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||||
|
state tls.ConnectionState, err error, finished time.Time) {
|
||||||
|
// nothing
|
||||||
|
}
|
40
internal/netxlite/trace_test.go
Normal file
40
internal/netxlite/trace_test.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextTraceOrDefault(t *testing.T) {
|
||||||
|
t.Run("without a configured trace we get a default", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
tx := ContextTraceOrDefault(ctx)
|
||||||
|
_ = tx.(*traceDefault) // panic if cannot cast
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with a configured trace we get the expected trace", func(t *testing.T) {
|
||||||
|
realTrace := &mocks.Trace{}
|
||||||
|
ctx := ContextWithTrace(context.Background(), realTrace)
|
||||||
|
tx := ContextTraceOrDefault(ctx)
|
||||||
|
if tx != realTrace {
|
||||||
|
t.Fatal("not the trace we expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextWithTrace(t *testing.T) {
|
||||||
|
t.Run("panics if passed a nil trace", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
called = (recover() != nil)
|
||||||
|
}()
|
||||||
|
_ = ContextWithTrace(context.Background(), nil)
|
||||||
|
}()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -19,8 +19,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) {
|
||||||
if logger.DebugLogger != log.Log {
|
if logger.DebugLogger != log.Log {
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
|
||||||
if configurable.NewConn == nil {
|
if configurable.NewConn == nil {
|
||||||
t.Fatal("expected non-nil NewConn")
|
t.Fatal("expected non-nil NewConn")
|
||||||
}
|
}
|
||||||
|
|
2
internal/testingx/doc.go
Normal file
2
internal/testingx/doc.go
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
// Package testingx contains code useful for testing.
|
||||||
|
package testingx
|
|
@ -1,11 +1,4 @@
|
||||||
// Package fakefill contains code to fill structs for testing.
|
package testingx
|
||||||
//
|
|
||||||
// This package is quite limited in scope and we can fill only the
|
|
||||||
// structures you typically send over as JSONs.
|
|
||||||
//
|
|
||||||
// As part of future work, we aim to investigate whether we can
|
|
||||||
// replace this implementation with https://go.dev/blog/fuzz-beta.
|
|
||||||
package fakefill
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
@ -14,7 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Filler fills specific data structures with random data. The only
|
// FakeFiller fills specific data structures with random data. The only
|
||||||
// exception to this behaviour is time.Time, which is instead filled
|
// exception to this behaviour is time.Time, which is instead filled
|
||||||
// with the current time plus a small random number of seconds.
|
// with the current time plus a small random number of seconds.
|
||||||
//
|
//
|
||||||
|
@ -25,7 +18,13 @@ import (
|
||||||
// Caveat: this kind of fillter does not support filling interfaces
|
// Caveat: this kind of fillter does not support filling interfaces
|
||||||
// and channels and other complex types. The current behavior when this
|
// and channels and other complex types. The current behavior when this
|
||||||
// kind of data types is encountered is to just ignore them.
|
// kind of data types is encountered is to just ignore them.
|
||||||
type Filler struct {
|
//
|
||||||
|
// This struct is quite limited in scope and we can fill only the
|
||||||
|
// structures you typically send over as JSONs.
|
||||||
|
//
|
||||||
|
// As part of future work, we aim to investigate whether we can
|
||||||
|
// replace this implementation with https://go.dev/blog/fuzz-beta.
|
||||||
|
type FakeFiller struct {
|
||||||
// mu provides mutual exclusion
|
// mu provides mutual exclusion
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
|
@ -37,7 +36,7 @@ type Filler struct {
|
||||||
rnd *rand.Rand
|
rnd *rand.Rand
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *Filler) getRandLocked() *rand.Rand {
|
func (ff *FakeFiller) getRandLocked() *rand.Rand {
|
||||||
if ff.rnd == nil {
|
if ff.rnd == nil {
|
||||||
now := time.Now
|
now := time.Now
|
||||||
if ff.Now != nil {
|
if ff.Now != nil {
|
||||||
|
@ -48,7 +47,7 @@ func (ff *Filler) getRandLocked() *rand.Rand {
|
||||||
return ff.rnd
|
return ff.rnd
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *Filler) getRandomString() string {
|
func (ff *FakeFiller) getRandomString() string {
|
||||||
defer ff.mu.Unlock()
|
defer ff.mu.Unlock()
|
||||||
ff.mu.Lock()
|
ff.mu.Lock()
|
||||||
rnd := ff.getRandLocked()
|
rnd := ff.getRandLocked()
|
||||||
|
@ -62,28 +61,28 @@ func (ff *Filler) getRandomString() string {
|
||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *Filler) getRandomInt64() int64 {
|
func (ff *FakeFiller) getRandomInt64() int64 {
|
||||||
defer ff.mu.Unlock()
|
defer ff.mu.Unlock()
|
||||||
ff.mu.Lock()
|
ff.mu.Lock()
|
||||||
rnd := ff.getRandLocked()
|
rnd := ff.getRandLocked()
|
||||||
return rnd.Int63()
|
return rnd.Int63()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *Filler) getRandomBool() bool {
|
func (ff *FakeFiller) getRandomBool() bool {
|
||||||
defer ff.mu.Unlock()
|
defer ff.mu.Unlock()
|
||||||
ff.mu.Lock()
|
ff.mu.Lock()
|
||||||
rnd := ff.getRandLocked()
|
rnd := ff.getRandLocked()
|
||||||
return rnd.Float64() >= 0.5
|
return rnd.Float64() >= 0.5
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *Filler) getRandomSmallPositiveInt() int {
|
func (ff *FakeFiller) getRandomSmallPositiveInt() int {
|
||||||
defer ff.mu.Unlock()
|
defer ff.mu.Unlock()
|
||||||
ff.mu.Lock()
|
ff.mu.Lock()
|
||||||
rnd := ff.getRandLocked()
|
rnd := ff.getRandLocked()
|
||||||
return int(rnd.Int63n(8)) + 1 // safe cast
|
return int(rnd.Int63n(8)) + 1 // safe cast
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *Filler) doFill(v reflect.Value) {
|
func (ff *FakeFiller) doFill(v reflect.Value) {
|
||||||
for v.Type().Kind() == reflect.Ptr {
|
for v.Type().Kind() == reflect.Ptr {
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
// if the pointer is nil, allocate an element
|
// if the pointer is nil, allocate an element
|
||||||
|
@ -134,6 +133,6 @@ func (ff *Filler) doFill(v reflect.Value) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill fills the input structure or pointer with random data.
|
// Fill fills the input structure or pointer with random data.
|
||||||
func (ff *Filler) Fill(in interface{}) {
|
func (ff *FakeFiller) Fill(in interface{}) {
|
||||||
ff.doFill(reflect.ValueOf(in))
|
ff.doFill(reflect.ValueOf(in))
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package fakefill
|
package testingx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -16,7 +16,7 @@ type exampleStructure struct {
|
||||||
|
|
||||||
func TestFakeFillWorksWithCustomTime(t *testing.T) {
|
func TestFakeFillWorksWithCustomTime(t *testing.T) {
|
||||||
var req *exampleStructure
|
var req *exampleStructure
|
||||||
ff := &Filler{
|
ff := &FakeFiller{
|
||||||
Now: func() time.Time {
|
Now: func() time.Time {
|
||||||
return time.Date(1992, time.January, 24, 17, 53, 0, 0, time.UTC)
|
return time.Date(1992, time.January, 24, 17, 53, 0, 0, time.UTC)
|
||||||
},
|
},
|
||||||
|
@ -29,7 +29,7 @@ func TestFakeFillWorksWithCustomTime(t *testing.T) {
|
||||||
|
|
||||||
func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
|
func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
|
||||||
var req *exampleStructure
|
var req *exampleStructure
|
||||||
ff := &Filler{}
|
ff := &FakeFiller{}
|
||||||
ff.Fill(&req)
|
ff.Fill(&req)
|
||||||
if req == nil {
|
if req == nil {
|
||||||
t.Fatal("we expected non nil here")
|
t.Fatal("we expected non nil here")
|
||||||
|
@ -38,7 +38,7 @@ func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) {
|
||||||
|
|
||||||
func TestFakeFillAllocatesIntoAMapLikeWithStringKeys(t *testing.T) {
|
func TestFakeFillAllocatesIntoAMapLikeWithStringKeys(t *testing.T) {
|
||||||
var resp map[string]*exampleStructure
|
var resp map[string]*exampleStructure
|
||||||
ff := &Filler{}
|
ff := &FakeFiller{}
|
||||||
ff.Fill(&resp)
|
ff.Fill(&resp)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
t.Fatal("we expected non nil here")
|
t.Fatal("we expected non nil here")
|
||||||
|
@ -62,7 +62,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var resp map[int64]*exampleStructure
|
var resp map[int64]*exampleStructure
|
||||||
ff := &Filler{}
|
ff := &FakeFiller{}
|
||||||
ff.Fill(&resp)
|
ff.Fill(&resp)
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
t.Fatal("we expected nil here")
|
t.Fatal("we expected nil here")
|
||||||
|
@ -75,7 +75,7 @@ func TestFakeFillAllocatesIntoAMapLikeWithNonStringKeys(t *testing.T) {
|
||||||
|
|
||||||
func TestFakeFillAllocatesIntoASlice(t *testing.T) {
|
func TestFakeFillAllocatesIntoASlice(t *testing.T) {
|
||||||
var resp *[]*exampleStructure
|
var resp *[]*exampleStructure
|
||||||
ff := &Filler{}
|
ff := &FakeFiller{}
|
||||||
ff.Fill(&resp)
|
ff.Fill(&resp)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
t.Fatal("we expected non nil here")
|
t.Fatal("we expected non nil here")
|
48
internal/testingx/time.go
Normal file
48
internal/testingx/time.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package testingx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TimeDeterministic implements time.Now in a deterministic fashion
|
||||||
|
// such that every time.Time call returns a moment in time that occurs
|
||||||
|
// one second after the configured zeroTime.
|
||||||
|
//
|
||||||
|
// It's safe to use this struct from multiple goroutine contexts.
|
||||||
|
type TimeDeterministic struct {
|
||||||
|
// counter counts the number of "ticks" passed since the zero time: each
|
||||||
|
// call to Now increments this counter by one second.
|
||||||
|
counter time.Duration
|
||||||
|
|
||||||
|
// mu protects fields in this structure from concurrent access.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// zeroTime is the lazy-initialized zero time. The first call to Now
|
||||||
|
// will initialize this field with the current time.
|
||||||
|
zeroTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTimeDeterministic creates a new instance using the given zeroTime value.
|
||||||
|
func NewTimeDeterministic(zeroTime time.Time) *TimeDeterministic {
|
||||||
|
return &TimeDeterministic{
|
||||||
|
counter: 0,
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
zeroTime: zeroTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now is like time.Now but more deterministic. The first call returns the
|
||||||
|
// configured zeroTime and subsequent calls return moments in time that occur
|
||||||
|
// exactly one second after the time returned by the previous call.
|
||||||
|
func (td *TimeDeterministic) Now() time.Time {
|
||||||
|
td.mu.Lock()
|
||||||
|
if td.zeroTime.IsZero() {
|
||||||
|
td.zeroTime = time.Now()
|
||||||
|
}
|
||||||
|
offset := td.counter
|
||||||
|
td.counter += time.Second
|
||||||
|
res := td.zeroTime.Add(offset)
|
||||||
|
td.mu.Unlock()
|
||||||
|
return res
|
||||||
|
}
|
22
internal/testingx/time_test.go
Normal file
22
internal/testingx/time_test.go
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
package testingx
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTimeDeterministic(t *testing.T) {
|
||||||
|
td := &TimeDeterministic{}
|
||||||
|
t0 := td.Now()
|
||||||
|
if !t0.Equal(td.zeroTime) {
|
||||||
|
t.Fatal("invalid t0 value")
|
||||||
|
}
|
||||||
|
t1 := td.Now()
|
||||||
|
if t1.Sub(t0) != time.Second {
|
||||||
|
t.Fatal("invalid t1 value")
|
||||||
|
}
|
||||||
|
t2 := td.Now()
|
||||||
|
if t2.Sub(t1) != time.Second {
|
||||||
|
t.Fatal("invalid t2 value")
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,4 +5,8 @@
|
||||||
// resulting connection). When done, we will extract the trace containing
|
// resulting connection). When done, we will extract the trace containing
|
||||||
// all the events that occurred from the saver and process it to determine
|
// all the events that occurred from the saver and process it to determine
|
||||||
// what happened during the measurement itself.
|
// what happened during the measurement itself.
|
||||||
|
//
|
||||||
|
// This package is now frozen. Please, use measurexlite for new code. See
|
||||||
|
// https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md
|
||||||
|
// for details about this.
|
||||||
package tracex
|
package tracex
|
||||||
|
|
Loading…
Reference in New Issue
Block a user