feat: tlsping and tcpping using step-by-step (#815)

## Checklist

- [x] I have read the [contribution guidelines](https://github.com/ooni/probe-cli/blob/master/CONTRIBUTING.md)
- [x] reference issue for this pull request: https://github.com/ooni/probe/issues/2158
- [x] if you changed anything related how experiments work and you need to reflect these changes in the ooni/spec repository, please link to the related ooni/spec pull request: https://github.com/ooni/spec/pull/250

## Description

This diff refactors the codebase to reimplement tlsping and tcpping
to use the step-by-step measurements style.

See docs/design/dd-003-step-by-step.md for more information on the
step-by-step measurement style.
This commit is contained in:
Simone Basso
2022-07-01 12:22:22 +02:00
committed by GitHub
parent 5371c7f486
commit 5ebdeb56ca
48 changed files with 2825 additions and 299 deletions
+103
View File
@@ -0,0 +1,103 @@
package measurexlite
//
// Conn tracing
//
import (
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
// MaybeClose is a convenience function for closing a conn only when such a conn isn't nil.
func MaybeClose(conn net.Conn) (err error) {
if conn != nil {
err = conn.Close()
}
return
}
// WrapNetConn returns a wrapped conn that saves network events into this trace.
func (tx *Trace) WrapNetConn(conn net.Conn) net.Conn {
return &connTrace{
Conn: conn,
tx: tx,
}
}
// connTrace is a trace-aware net.Conn.
type connTrace struct {
// Implementation note: it seems safe to use embedding here because net.Conn
// is an interface from the standard library that we don't control
net.Conn
tx *Trace
}
var _ net.Conn = &connTrace{}
// Read implements net.Conn.Read and saves network events.
func (c *connTrace) Read(b []byte) (int, error) {
network := c.RemoteAddr().Network()
addr := c.RemoteAddr().String()
started := c.tx.TimeSince(c.tx.ZeroTime)
count, err := c.Conn.Read(b)
finished := c.tx.TimeSince(c.tx.ZeroTime)
select {
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
c.tx.Index, started, netxlite.ReadOperation, network, addr, count, err, finished):
default: // buffer is full
}
return count, err
}
// Write implements net.Conn.Write and saves network events.
func (c *connTrace) Write(b []byte) (int, error) {
network := c.RemoteAddr().Network()
addr := c.RemoteAddr().String()
started := c.tx.TimeSince(c.tx.ZeroTime)
count, err := c.Conn.Write(b)
finished := c.tx.TimeSince(c.tx.ZeroTime)
select {
case c.tx.NetworkEvent <- NewArchivalNetworkEvent(
c.tx.Index, started, netxlite.WriteOperation, network, addr, count, err, finished):
default: // buffer is full
}
return count, err
}
// NewArchivalNetworkEvent creates a new model.ArchivalNetworkEvent.
func NewArchivalNetworkEvent(index int64, started time.Duration, operation string, network string,
address string, count int, err error, finished time.Duration) *model.ArchivalNetworkEvent {
return &model.ArchivalNetworkEvent{
Address: address,
Failure: tracex.NewFailure(err),
NumBytes: int64(count),
Operation: operation,
Proto: network,
T: finished.Seconds(),
Tags: []string{},
}
}
// NewAnnotationArchivalNetworkEvent is a simplified NewArchivalNetworkEvent
// where we create a simple annotation without attached I/O info.
func NewAnnotationArchivalNetworkEvent(
index int64, time time.Duration, operation string) *model.ArchivalNetworkEvent {
return NewArchivalNetworkEvent(index, time, operation, "", "", 0, nil, time)
}
// NetworkEvents drains the network events buffered inside the NetworkEvent channel.
func (tx *Trace) NetworkEvents() (out []*model.ArchivalNetworkEvent) {
for {
select {
case ev := <-tx.NetworkEvent:
out = append(out, ev)
default:
return // done
}
}
}
+243
View File
@@ -0,0 +1,243 @@
package measurexlite
import (
"net"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestMaybeClose(t *testing.T) {
t.Run("with nil conn", func(t *testing.T) {
var conn net.Conn = nil
MaybeClose(conn)
})
t.Run("with nonnil conn", func(t *testing.T) {
var called bool
conn := &mocks.Conn{
MockClose: func() error {
called = true
return nil
},
}
if err := MaybeClose(conn); err != nil {
t.Fatal(err)
}
if !called {
t.Fatal("not called")
}
})
}
func TestWrapNetConn(t *testing.T) {
t.Run("WrapNetConn wraps the conn", func(t *testing.T) {
underlying := &mocks.Conn{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
conn := trace.WrapNetConn(underlying)
ct := conn.(*connTrace)
if ct.Conn != underlying {
t.Fatal("invalid underlying")
}
if ct.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("Read saves a trace", func(t *testing.T) {
underlying := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time counting
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Read(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 1 {
t.Fatal("did not save network events")
}
expect := &model.ArchivalNetworkEvent{
Address: "1.1.1.1:443",
Failure: nil,
NumBytes: bufsiz,
Operation: netxlite.ReadOperation,
Proto: "tcp",
T: 1.0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Read discards the event when the buffer is full", func(t *testing.T) {
underlying := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Read(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 0 {
t.Fatal("expected no network events")
}
})
t.Run("Write saves a trace", func(t *testing.T) {
underlying := &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time tracking
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Write(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 1 {
t.Fatal("did not save network events")
}
expect := &model.ArchivalNetworkEvent{
Address: "1.1.1.1:443",
Failure: nil,
NumBytes: bufsiz,
Operation: netxlite.WriteOperation,
Proto: "tcp",
T: 1.0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Write discards the event when the buffer is full", func(t *testing.T) {
underlying := &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapNetConn(underlying)
const bufsiz = 128
buffer := make([]byte, bufsiz)
count, err := conn.Write(buffer)
if count != bufsiz {
t.Fatal("invalid count")
}
if err != nil {
t.Fatal("invalid err")
}
events := trace.NetworkEvents()
if len(events) != 0 {
t.Fatal("expected no network events")
}
})
}
func TestNewAnnotationArchivalNetworkEvent(t *testing.T) {
var (
index int64 = 3
duration = 250 * time.Millisecond
operation = "tls_handshake_start"
)
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: operation,
Proto: "",
T: duration.Seconds(),
Tags: []string{},
}
got := NewAnnotationArchivalNetworkEvent(
index, duration, operation,
)
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
}
+121
View File
@@ -0,0 +1,121 @@
package measurexlite
//
// Dialer tracing
//
import (
"context"
"log"
"math"
"net"
"strconv"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
// NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver
// except that it returns a model.Dialer that uses this trace.
//
// Note: unlike code in netx or measurex, this factory DOES NOT return you a
// dialer that also performs wrapping of a net.Conn in case of success. If you
// want to wrap the conn, you need to wrap it explicitly using WrapNetConn.
func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
return &dialerTrace{
d: tx.newDialerWithoutResolver(dl),
tx: tx,
}
}
// dialerTrace is a trace-aware model.Dialer.
type dialerTrace struct {
d model.Dialer
tx *Trace
}
var _ model.Dialer = &dialerTrace{}
// DialContext implements model.Dialer.DialContext.
func (d *dialerTrace) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.d.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address)
}
// CloseIdleConnections implements model.Dialer.CloseIdleConnections.
func (d *dialerTrace) CloseIdleConnections() {
d.d.CloseIdleConnections()
}
// OnTCPConnectDone implements model.Trace.OnTCPConnectDone.
func (tx *Trace) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
switch network {
case "tcp", "tcp4", "tcp6":
select {
case tx.TCPConnect <- NewArchivalTCPConnectResult(
tx.Index,
started.Sub(tx.ZeroTime),
remoteAddr,
err,
finished.Sub(tx.ZeroTime),
):
default: // buffer is full
}
default:
// ignore UDP connect attempts because they cannot fail
// in interesting ways that make sense for censorship
}
}
// NewArchivalTCPConnectResult generates a model.ArchivalTCPConnectResult
// from the available information right after connect returns.
func NewArchivalTCPConnectResult(index int64, started time.Duration, address string,
err error, finished time.Duration) *model.ArchivalTCPConnectResult {
ip, port := archivalSplitHostPort(address)
return &model.ArchivalTCPConnectResult{
IP: ip,
Port: archivalPortToString(port),
Status: model.ArchivalTCPConnectStatus{
Blocked: nil,
Failure: tracex.NewFailure(err),
Success: err == nil,
},
T: finished.Seconds(),
}
}
// archivalSplitHostPort is like net.SplitHostPort but does not return an error. This
// function returns two empty strings in case of any failure.
func archivalSplitHostPort(endpoint string) (string, string) {
addr, port, err := net.SplitHostPort(endpoint)
if err != nil {
log.Printf("BUG: archivalSplitHostPort: invalid endpoint: %s", endpoint)
return "", ""
}
return addr, port
}
// archivalPortToString is like strconv.Atoi but does not return an error. This
// function returns a zero port number in case of any failure.
func archivalPortToString(sport string) int {
port, err := strconv.Atoi(sport)
if err != nil || port < 0 || port > math.MaxUint16 {
log.Printf("BUG: archivalStrconvAtoi: invalid port: %s", sport)
return 0
}
return port
}
// TCPConnects drains the network events buffered inside the TCPConnect channel.
func (tx *Trace) TCPConnects() (out []*model.ArchivalTCPConnectResult) {
for {
select {
case ev := <-tx.TCPConnect:
out = append(out, ev)
default:
return // done
}
}
}
+211
View File
@@ -0,0 +1,211 @@
package measurexlite
import (
"context"
"errors"
"math"
"net"
"strconv"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewDialerWithoutResolver(t *testing.T) {
t.Run("NewDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) {
underlying := &mocks.Dialer{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
return underlying
}
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
dt := dialer.(*dialerTrace)
if dt.d != underlying {
t.Fatal("invalid dialer")
}
if dt.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("DialContext calls the underlying dialer with context-based tracing", func(t *testing.T) {
expectedErr := errors.New("mocked err")
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
var hasCorrectTrace bool
underlying := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
gotTrace := netxlite.ContextTraceOrDefault(ctx)
hasCorrectTrace = (gotTrace == trace)
return nil, expectedErr
},
}
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
return underlying
}
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !hasCorrectTrace {
t.Fatal("does not have the correct trace")
}
})
t.Run("CloseIdleConnection is correctly forwarded", func(t *testing.T) {
var called bool
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
underlying := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
trace.NewDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer {
return underlying
}
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("DialContext saves into the trace", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time tracking
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 1 {
t.Fatal("expected to see single TCPConnect event")
}
expectedFailure := netxlite.FailureInterrupted
expect := &model.ArchivalTCPConnectResult{
IP: "1.1.1.1",
Port: 443,
Status: model.ArchivalTCPConnectStatus{
Blocked: nil,
Failure: &expectedFailure,
Success: false,
},
T: time.Second.Seconds(),
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("DialContext discards events when buffer is full", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.TCPConnect = make(chan *model.ArchivalTCPConnectResult) // no buffer
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 0 {
t.Fatal("expected to see no TCPConnect events")
}
})
t.Run("DialContext ignores UDP connect attempts", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "udp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 0 {
t.Fatal("expected to see no TCPConnect events")
}
})
t.Run("DialContext uses a dialer without a resolver", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
dialer := trace.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
conn, err := dialer.DialContext(ctx, "udp", "dns.google:443") // domain
if !errors.Is(err, netxlite.ErrNoResolver) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := trace.TCPConnects()
if len(events) != 0 {
t.Fatal("expected to see no TCPConnect events")
}
})
}
func TestArchivalSplitHostPort(t *testing.T) {
addr, port := archivalSplitHostPort("1.1.1.1") // missing port
if addr != "" {
t.Fatal("invalid addr", addr)
}
if port != "" {
t.Fatal("invalid port", port)
}
}
func TestArchivalPortToString(t *testing.T) {
t.Run("with invalid number", func(t *testing.T) {
port := archivalPortToString("antani")
if port != 0 {
t.Fatal("invalid port")
}
})
t.Run("with negative number", func(t *testing.T) {
port := archivalPortToString("-1")
if port != 0 {
t.Fatal("invalid port")
}
})
t.Run("with too-large positive number", func(t *testing.T) {
port := archivalPortToString(strconv.Itoa(math.MaxUint16 + 1))
if port != 0 {
t.Fatal("invalid port")
}
})
}
+10
View File
@@ -0,0 +1,10 @@
// Package measurexlite contains measurement extensions.
//
// See docs/design/dd-003-step-by-step.md in the ooni/probe-cli
// repository for the design document.
//
// This implementation features a Trace that saves events in
// buffered channels as proposed by df-003-step-by-step.md. We
// have reasonable default buffers for channels. But, if you
// are not draining them, eventually we stop collecting events.
package measurexlite
+16
View File
@@ -0,0 +1,16 @@
package measurexlite
//
// Logging support
//
import "github.com/ooni/probe-cli/v3/internal/measurex"
// TODO(bassosimone): we should eventually remove measurex and
// move the logging code from measurex to this package.
// NewOperationLogger is an alias for measurex.NewOperationLogger.
var NewOperationLogger = measurex.NewOperationLogger
// OperationLogger is an alias for measurex.OperationLogger.
type OperationLogger = measurex.OperationLogger
+144
View File
@@ -0,0 +1,144 @@
package measurexlite
//
// TLS tracing
//
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
// NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib
// except that it returns a model.TLSHandshaker that uses this trace.
func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
return &tlsHandshakerTrace{
thx: tx.newTLSHandshakerStdlib(dl),
tx: tx,
}
}
// tlsHandshakerTrace is a trace-aware TLS handshaker.
type tlsHandshakerTrace struct {
thx model.TLSHandshaker
tx *Trace
}
var _ model.TLSHandshaker = &tlsHandshakerTrace{}
// Handshake implements model.TLSHandshaker.Handshake.
func (thx *tlsHandshakerTrace) Handshake(
ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (net.Conn, tls.ConnectionState, error) {
return thx.thx.Handshake(netxlite.ContextWithTrace(ctx, thx.tx), conn, tlsConfig)
}
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
t := now.Sub(tx.ZeroTime)
select {
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_start"):
default: // buffer is full
}
}
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time) {
t := finished.Sub(tx.ZeroTime)
select {
case tx.TLSHandshake <- NewArchivalTLSOrQUICHandshakeResult(
tx.Index,
started.Sub(tx.ZeroTime),
remoteAddr,
config,
state,
err,
t,
):
default: // buffer is full
}
select {
case tx.NetworkEvent <- NewAnnotationArchivalNetworkEvent(tx.Index, t, "tls_handshake_done"):
default: // buffer is full
}
}
// NewArchivalTLSOrQUICHandshakeResult generates a model.ArchivalTLSOrQUICHandshakeResult
// from the available information right after the TLS handshake returns.
func NewArchivalTLSOrQUICHandshakeResult(
index int64, started time.Duration, address string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Duration) *model.ArchivalTLSOrQUICHandshakeResult {
return &model.ArchivalTLSOrQUICHandshakeResult{
Address: address,
CipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
Failure: tracex.NewFailure(err),
NegotiatedProtocol: state.NegotiatedProtocol,
NoTLSVerify: config.InsecureSkipVerify,
PeerCertificates: TLSPeerCerts(state, err),
ServerName: config.ServerName,
T: finished.Seconds(),
Tags: []string{},
TLSVersion: netxlite.TLSVersionString(state.Version),
}
}
// newArchivalBinaryData is a factory that adapts binary data to the
// model.ArchivalMaybeBinaryData format.
func newArchivalBinaryData(data []byte) model.ArchivalMaybeBinaryData {
// TODO(https://github.com/ooni/probe/issues/2165): we should actually extend the
// model's archival data format to have a pure-binary-data type for the cases in which
// we know in advance we're dealing with binary data.
return model.ArchivalMaybeBinaryData{
Value: string(data),
}
}
// TLSPeerCerts extracts the certificates either from the list of certificates
// in the connection state or from the error that occurred.
func TLSPeerCerts(
state tls.ConnectionState, err error) (out []model.ArchivalMaybeBinaryData) {
out = []model.ArchivalMaybeBinaryData{}
var x509HostnameError x509.HostnameError
if errors.As(err, &x509HostnameError) {
// Test case: https://wrong.host.badssl.com/
out = append(out, newArchivalBinaryData(x509HostnameError.Certificate.Raw))
return
}
var x509UnknownAuthorityError x509.UnknownAuthorityError
if errors.As(err, &x509UnknownAuthorityError) {
// Test case: https://self-signed.badssl.com/. This error has
// never been among the ones returned by MK.
out = append(out, newArchivalBinaryData(x509UnknownAuthorityError.Cert.Raw))
return
}
var x509CertificateInvalidError x509.CertificateInvalidError
if errors.As(err, &x509CertificateInvalidError) {
// Test case: https://expired.badssl.com/
out = append(out, newArchivalBinaryData(x509CertificateInvalidError.Cert.Raw))
return
}
for _, cert := range state.PeerCertificates {
out = append(out, newArchivalBinaryData(cert.Raw))
}
return
}
// TLSHandshakes drains the network events buffered inside the TLSHandshake channel.
func (tx *Trace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) {
for {
select {
case ev := <-tx.TLSHandshake:
out = append(out, ev)
default:
return // done
}
}
}
+418
View File
@@ -0,0 +1,418 @@
package measurexlite
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewTLSHandshakerStdlib(t *testing.T) {
t.Run("NewTLSHandshakerStdlib creates a wrapped TLSHandshaker", func(t *testing.T) {
underlying := &mocks.TLSHandshaker{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
return underlying
}
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
thxt := thx.(*tlsHandshakerTrace)
if thxt.thx != underlying {
t.Fatal("invalid TLS handshaker")
}
if thxt.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("Handshake calls the underlying dialer with context-based tracing", func(t *testing.T) {
expectedErr := errors.New("mocked err")
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
var hasCorrectTrace bool
underlying := &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
gotTrace := netxlite.ContextTraceOrDefault(ctx)
hasCorrectTrace = (gotTrace == trace)
return nil, tls.ConnectionState{}, expectedErr
},
}
trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
return underlying
}
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected zero-value state")
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !hasCorrectTrace {
t.Fatal("does not have the correct trace")
}
})
t.Run("Handshake saves into the trace", func(t *testing.T) {
mockedErr := errors.New("mocked")
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic timing
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: "dns.cloudflare.com",
}
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected zero-value state")
}
if conn != nil {
t.Fatal("expected nil conn")
}
t.Run("TLSHandshake events", func(t *testing.T) {
events := trace.TLSHandshakes()
if len(events) != 1 {
t.Fatal("expected to see single TLSHandshake event")
}
expectedFailure := netxlite.FailureInterrupted
expect := &model.ArchivalTLSOrQUICHandshakeResult{
Address: "1.1.1.1:443",
CipherSuite: "",
Failure: &expectedFailure,
NegotiatedProtocol: "",
NoTLSVerify: true,
PeerCertificates: []model.ArchivalMaybeBinaryData{},
ServerName: "dns.cloudflare.com",
T: time.Second.Seconds(),
Tags: []string{},
TLSVersion: "",
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Network events", func(t *testing.T) {
events := trace.NetworkEvents()
if len(events) != 2 {
t.Fatal("expected to see two Network events")
}
t.Run("tls_handshake_start", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_start",
Proto: "",
T: 0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("tls_handshake_done", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_done",
Proto: "",
T: time.Second.Seconds(),
Tags: []string{},
}
got := events[1]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
})
})
t.Run("Handshake discards events when buffers are full", func(t *testing.T) {
mockedErr := errors.New("mocked")
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NetworkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
trace.TLSHandshake = make(chan *model.ArchivalTLSOrQUICHandshakeResult) // no buffer
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // we cancel immediately so connect is ~instantaneous
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: "dns.cloudflare.com",
}
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected zero-value state")
}
if conn != nil {
t.Fatal("expected nil conn")
}
t.Run("TLSHandshake events", func(t *testing.T) {
events := trace.TLSHandshakes()
if len(events) != 0 {
t.Fatal("expected to see no TLSHandshake events")
}
})
t.Run("Network events", func(t *testing.T) {
events := trace.NetworkEvents()
if len(events) != 0 {
t.Fatal("expected to see no Network events")
}
})
})
t.Run("we collect the desired data with a local TLS server", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", server.Endpoint())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
zeroTime := time.Now()
dt := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = dt.Now // deterministic timing
thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
tlsConfig := &tls.Config{
RootCAs: server.CertPool(),
ServerName: "dns.google",
}
tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig)
if err != nil {
t.Fatal(err)
}
defer tlsConn.Close()
data, err := netxlite.ReadAllContext(ctx, tlsConn)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
t.Fatal("bytes should match")
}
t.Run("TLSHandshake events", func(t *testing.T) {
events := trace.TLSHandshakes()
if len(events) != 1 {
t.Fatal("expected to see a single TLSHandshake event")
}
expected := &model.ArchivalTLSOrQUICHandshakeResult{
Address: conn.RemoteAddr().String(),
CipherSuite: netxlite.TLSCipherSuiteString(connState.CipherSuite),
Failure: nil,
NegotiatedProtocol: "",
NoTLSVerify: false,
PeerCertificates: []model.ArchivalMaybeBinaryData{},
ServerName: "dns.google",
T: time.Second.Seconds(),
Tags: []string{},
TLSVersion: netxlite.TLSVersionString(connState.Version),
}
got := events[0]
// TODO(bassosimone): it's still unclear to me how to test that
// I am getting exactly the expected certificate here. I think the
// certificate is generated on the fly by google/martian. So, I'm
// just going to reduce the precision of this check.
if len(got.PeerCertificates) != 2 {
t.Fatal("expected to see two certificates")
}
got.PeerCertificates = []model.ArchivalMaybeBinaryData{} // see above
if diff := cmp.Diff(expected, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("Network events", func(t *testing.T) {
events := trace.NetworkEvents()
if len(events) != 2 {
t.Fatal("expected to see two Network events")
}
t.Run("tls_handshake_start", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_start",
Proto: "",
T: 0,
Tags: []string{},
}
got := events[0]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
t.Run("tls_handshake_done", func(t *testing.T) {
expect := &model.ArchivalNetworkEvent{
Address: "",
Failure: nil,
NumBytes: 0,
Operation: "tls_handshake_done",
Proto: "",
T: time.Second.Seconds(),
Tags: []string{},
}
got := events[1]
if diff := cmp.Diff(expect, got); diff != "" {
t.Fatal(diff)
}
})
})
})
}
func TestTLSPeerCerts(t *testing.T) {
type args struct {
state tls.ConnectionState
err error
}
tests := []struct {
name string
args args
wantOut []model.ArchivalMaybeBinaryData
}{{
name: "x509.HostnameError",
args: args{
state: tls.ConnectionState{},
err: x509.HostnameError{
Certificate: &x509.Certificate{
Raw: []byte("deadbeef"),
},
},
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}},
}, {
name: "x509.UnknownAuthorityError",
args: args{
state: tls.ConnectionState{},
err: x509.UnknownAuthorityError{
Cert: &x509.Certificate{
Raw: []byte("deadbeef"),
},
},
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}},
}, {
name: "x509.CertificateInvalidError",
args: args{
state: tls.ConnectionState{},
err: x509.CertificateInvalidError{
Cert: &x509.Certificate{
Raw: []byte("deadbeef"),
},
},
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}},
}, {
name: "successful case",
args: args{
state: tls.ConnectionState{
PeerCertificates: []*x509.Certificate{{
Raw: []byte("deadbeef"),
}, {
Raw: []byte("abad1dea"),
}},
},
err: nil,
},
wantOut: []model.ArchivalMaybeBinaryData{{
Value: "deadbeef",
}, {
Value: "abad1dea",
}},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotOut := TLSPeerCerts(tt.args.state, tt.args.err)
if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" {
t.Fatal(diff)
}
})
}
}
+143
View File
@@ -0,0 +1,143 @@
package measurexlite
//
// Definition of Trace
//
import (
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// Trace implements model.Trace.
//
// The zero-value of this struct is invalid. To construct you should either
// fill all the fields marked as MANDATORY or use NewTrace.
//
// Buffered channels
//
// NewTrace uses reasonable buffer sizes for the channels used for collecting
// events. You should drain the channels used by this implementation after
// each operation you perform (i.e., we expect you to peform step-by-step
// measurements). If you want larger (or smaller) buffers, then you should
// construct this data type manually with the desired buffer sizes.
//
// We have convenience methods for extracting events from the buffered
// channels. Otherwise, you could read the channels directly. (In which
// case, remember to issue nonblocking channel reads because channels are
// never closed and they're just written when new events occur.)
type Trace struct {
// Index is the MANDATORY unique index of this trace within the
// current measurement. If you don't care about uniquely identifying
// treaces, you can use zero to indicate the "default" trace.
Index int64
// NetworkEvent is MANDATORY and buffers network events. If you create
// this channel manually, ensure it has some buffer.
NetworkEvent chan *model.ArchivalNetworkEvent
// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
// calls to the netxlite.NewDialerWithoutResolver factory.
NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer
// NewTLSHandshakerStdlibFn is OPTIONAL and can be used to overide
// calls to the netxlite.NewTLSHandshakerStdlib factory.
NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
// TCPConnect is MANDATORY and buffers TCP connect observations. If you create
// this channel manually, ensure it has some buffer.
TCPConnect chan *model.ArchivalTCPConnectResult
// TLSHandshake is MANDATORY and buffers TLS handshake observations. If you create
// this channel manually, ensure it has some buffer.
TLSHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
// to produce deterministic timing when testing.
TimeNowFn func() time.Time
// ZeroTime is the MANDATORY time when we started the current measurement.
ZeroTime time.Time
}
const (
// NetworkEventBufferSize is the buffer size for constructing
// the Trace's NetworkEvent buffered channel.
NetworkEventBufferSize = 64
// TCPConnectBufferSize is the buffer size for constructing
// the Trace's TCPConnect buffered channel.
TCPConnectBufferSize = 8
// TLSHandshakeBufferSize is the buffer for construcing
// the Trace's TLSHandshake buffered channel.
TLSHandshakeBufferSize = 8
)
// NewTrace creates a new instance of Trace using default settings.
//
// We create buffered channels using as buffer sizes the constants that
// are also defined by this package.
//
// Arguments:
//
// - index is the unique index of this trace within the current measurement (use
// zero if you don't care about giving this trace a unique ID);
//
// - zeroTime is the time when we started the current measurement.
func NewTrace(index int64, zeroTime time.Time) *Trace {
return &Trace{
Index: index,
NetworkEvent: make(
chan *model.ArchivalNetworkEvent,
NetworkEventBufferSize,
),
NewDialerWithoutResolverFn: nil, // use default
NewTLSHandshakerStdlibFn: nil, // use default
TCPConnect: make(
chan *model.ArchivalTCPConnectResult,
TCPConnectBufferSize,
),
TLSHandshake: make(
chan *model.ArchivalTLSOrQUICHandshakeResult,
TLSHandshakeBufferSize,
),
TimeNowFn: nil, // use default
ZeroTime: zeroTime,
}
}
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
// thus allows us to mock this func for testing.
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
if tx.NewDialerWithoutResolverFn != nil {
return tx.NewDialerWithoutResolverFn(dl)
}
return netxlite.NewDialerWithoutResolver(dl)
}
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
// thus allowing us to mock this func for testing.
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
if tx.NewTLSHandshakerStdlibFn != nil {
return tx.NewTLSHandshakerStdlibFn(dl)
}
return netxlite.NewTLSHandshakerStdlib(dl)
}
// TimeNow implements model.Trace.TimeNow.
func (tx *Trace) TimeNow() time.Time {
if tx.TimeNowFn != nil {
return tx.TimeNowFn()
}
return time.Now()
}
// TimeSince is equivalent to Trace.TimeNow().Sub(t0).
func (tx *Trace) TimeSince(t0 time.Time) time.Duration {
return tx.TimeNow().Sub(t0)
}
var _ model.Trace = &Trace{}
+248
View File
@@ -0,0 +1,248 @@
package measurexlite
import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewTrace(t *testing.T) {
t.Run("NewTrace correctly constructs a trace", func(t *testing.T) {
const index = 17
zeroTime := time.Now()
trace := NewTrace(index, zeroTime)
t.Run("Index", func(t *testing.T) {
if trace.Index != index {
t.Fatal("invalid index")
}
})
t.Run("NetworkEvent has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalNetworkEvent{}
ff.Fill(ev)
select {
case trace.NetworkEvent <- ev:
idx++
default:
break Loop
}
}
if idx != NetworkEventBufferSize {
t.Fatal("invalid NetworkEvent channel buffer size")
}
})
t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) {
if trace.NewDialerWithoutResolverFn != nil {
t.Fatal("expected nil NewDialerWithoutResolverFn")
}
})
t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) {
if trace.NewTLSHandshakerStdlibFn != nil {
t.Fatal("expected nil NewTLSHandshakerStdlibFn")
}
})
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalTCPConnectResult{}
ff.Fill(ev)
select {
case trace.TCPConnect <- ev:
idx++
default:
break Loop
}
}
if idx != TCPConnectBufferSize {
t.Fatal("invalid TCPConnect channel buffer size")
}
})
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalTLSOrQUICHandshakeResult{}
ff.Fill(ev)
select {
case trace.TLSHandshake <- ev:
idx++
default:
break Loop
}
}
if idx != TLSHandshakeBufferSize {
t.Fatal("invalid TLSHandshake channel buffer size")
}
})
t.Run("TimeNowFn is nil", func(t *testing.T) {
if trace.TimeNowFn != nil {
t.Fatal("expected nil TimeNowFn")
}
})
t.Run("ZeroTime", func(t *testing.T) {
if !trace.ZeroTime.Equal(zeroTime) {
t.Fatal("invalid zero time")
}
})
})
}
func TestTrace(t *testing.T) {
t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mockedErr
},
}
},
}
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("when nil", func(t *testing.T) {
tx := &Trace{
NewDialerWithoutResolverFn: nil,
}
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker {
return &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, mockedErr
},
}
},
}
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("state is not a zero value")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("when nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewTLSHandshakerStdlibFn: nil,
}
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
}
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("state is not a zero value")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("TimeNowFn works as intended", func(t *testing.T) {
fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
tx := &Trace{
TimeNowFn: func() time.Time {
return fixedTime
},
}
if !tx.TimeNow().Equal(fixedTime) {
t.Fatal("we cannot override time.Now calls")
}
})
t.Run("TimeSince works as intended", func(t *testing.T) {
t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
t1 := t0.Add(10 * time.Second)
tx := &Trace{
TimeNowFn: func() time.Time {
return t1
},
}
if tx.TimeSince(t0) != 10*time.Second {
t.Fatal("apparently Trace.Since is broken")
}
})
}