ooni-probe-cli/internal/measurexlite/trace_test.go

327 lines
8.0 KiB
Go
Raw Normal View History

package measurexlite
import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewTrace(t *testing.T) {
t.Run("NewTrace correctly constructs a trace", func(t *testing.T) {
const index = 17
zeroTime := time.Now()
trace := NewTrace(index, zeroTime)
t.Run("Index", func(t *testing.T) {
if trace.Index != index {
t.Fatal("invalid index")
}
})
t.Run("NetworkEvent has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalNetworkEvent{}
ff.Fill(ev)
select {
case trace.NetworkEvent <- ev:
idx++
default:
break Loop
}
}
if idx != NetworkEventBufferSize {
t.Fatal("invalid NetworkEvent channel buffer size")
}
})
t.Run("NewParallelResolverFn is nil", func(t *testing.T) {
if trace.NewParallelResolverFn != nil {
t.Fatal("expected nil NewUnwrappedParallelResolverFn")
}
})
t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) {
if trace.NewDialerWithoutResolverFn != nil {
t.Fatal("expected nil NewDialerWithoutResolverFn")
}
})
t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) {
if trace.NewTLSHandshakerStdlibFn != nil {
t.Fatal("expected nil NewTLSHandshakerStdlibFn")
}
})
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
for _, qtype := range DNSQueryTypes {
var count int
Loop:
for {
ev := &model.ArchivalDNSLookupResult{}
ff.Fill(ev)
select {
case trace.DNSLookup[qtype] <- ev:
count++
default:
break Loop
}
}
if count != DNSLookupBufferSize {
t.Fatal("invalid DNSLookup A channel buffer size")
}
}
})
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalTCPConnectResult{}
ff.Fill(ev)
select {
case trace.TCPConnect <- ev:
idx++
default:
break Loop
}
}
if idx != TCPConnectBufferSize {
t.Fatal("invalid TCPConnect channel buffer size")
}
})
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalTLSOrQUICHandshakeResult{}
ff.Fill(ev)
select {
case trace.TLSHandshake <- ev:
idx++
default:
break Loop
}
}
if idx != TLSHandshakeBufferSize {
t.Fatal("invalid TLSHandshake channel buffer size")
}
})
t.Run("TimeNowFn is nil", func(t *testing.T) {
if trace.TimeNowFn != nil {
t.Fatal("expected nil TimeNowFn")
}
})
t.Run("ZeroTime", func(t *testing.T) {
if !trace.ZeroTime.Equal(zeroTime) {
t.Fatal("invalid zero time")
}
})
})
}
func TestTrace(t *testing.T) {
t.Run("NewParallelResolverFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewParallelResolverFn: func() model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{}, mockedErr
},
}
},
}
resolver := tx.newParallelResolver(func() model.Resolver {
return nil
})
ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com")
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if len(addrs) != 0 {
t.Fatal("expected array of size 0")
}
})
t.Run("when nil", func(t *testing.T) {
tx := &Trace{
NewParallelResolverFn: nil,
}
newResolver := func() model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"1.1.1.1"}, nil
},
}
}
resolver := tx.newParallelResolver(newResolver)
ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com")
if err != nil {
t.Fatal("unexpected err", err)
}
if len(addrs) != 1 {
t.Fatal("expected array of size 1")
}
if addrs[0] != "1.1.1.1" {
t.Fatal("unexpected array output", addrs)
}
})
})
t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mockedErr
},
}
},
}
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
ctx := context.Background()
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("when nil", func(t *testing.T) {
tx := &Trace{
NewDialerWithoutResolverFn: nil,
}
dialer := tx.NewDialerWithoutResolver(model.DiscardLogger)
ctx, cancel := context.WithCancel(context.Background())
cancel() // fail immediately
conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443")
if err == nil || err.Error() != netxlite.FailureInterrupted {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) {
t.Run("when not nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker {
return &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, mockedErr
},
}
},
}
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("state is not a zero value")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("when nil", func(t *testing.T) {
mockedErr := errors.New("mocked")
tx := &Trace{
NewTLSHandshakerStdlibFn: nil,
}
thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger)
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "tcp"
},
MockString: func() string {
return "1.1.1.1:443"
},
}
},
MockWrite: func(b []byte) (int, error) {
return 0, mockedErr
},
MockClose: func() error {
return nil
},
}
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
}
ctx := context.Background()
conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("state is not a zero value")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
})
t.Run("TimeNowFn works as intended", func(t *testing.T) {
fixedTime := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
tx := &Trace{
TimeNowFn: func() time.Time {
return fixedTime
},
}
if !tx.TimeNow().Equal(fixedTime) {
t.Fatal("we cannot override time.Now calls")
}
})
t.Run("TimeSince works as intended", func(t *testing.T) {
t0 := time.Date(2022, 01, 01, 00, 00, 00, 00, time.UTC)
t1 := t0.Add(10 * time.Second)
tx := &Trace{
TimeNowFn: func() time.Time {
return t1
},
}
if tx.TimeSince(t0) != 10*time.Second {
t.Fatal("apparently Trace.Since is broken")
}
})
}