diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index b631cb6..5d651f7 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -9,6 +9,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + utls "gitlab.com/yawning/utls.git" ) // Trace implements model.Trace. @@ -57,6 +58,10 @@ type Trace struct { // calls to the netxlite.NewTLSHandshakerStdlib factory. NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker + // NewTLSHandshakerUTLSFn is OPTIONAL and can be used to overide + // calls to the netxlite.NewTLSHandshakerUTLS factory. + NewTLSHandshakerUTLSFn func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker + // NewDialerWithoutResolverFn is OPTIONAL and can be used to override // calls to the netxlite.NewQUICDialerWithoutResolver factory. NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer @@ -200,7 +205,16 @@ func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshake return netxlite.NewTLSHandshakerStdlib(dl) } -// newWUICDialerWithoutResolver indirectly calls netxlite.NewQUICDialerWithoutResolver +// newTLSHandshakerUTLS indirectly calls netxlite.NewTLSHandshakerUTLS +// thus allowing us to mock this func for testing. +func (tx *Trace) newTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + if tx.NewTLSHandshakerUTLSFn != nil { + return tx.NewTLSHandshakerUTLSFn(dl, id) + } + return netxlite.NewTLSHandshakerUTLS(dl, id) +} + +// newQUICDialerWithoutResolver indirectly calls netxlite.NewQUICDialerWithoutResolver // thus allowing us to mock this func for testing. func (tx *Trace) newQUICDialerWithoutResolver(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { if tx.NewQUICDialerWithoutResolverFn != nil { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index 2e8396e..6562bca 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -15,6 +15,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/testingx" + utls "gitlab.com/yawning/utls.git" ) func TestNewTrace(t *testing.T) { @@ -78,6 +79,12 @@ func TestNewTrace(t *testing.T) { } }) + t.Run("newTLShandshakerUTLSFn is nil", func(t *testing.T) { + if trace.NewTLSHandshakerUTLSFn != nil { + t.Fatal("expected nil NewTLSHandshakerUTLSfn") + } + }) + t.Run("NewQUICDialerWithoutResolverFn is nil", func(t *testing.T) { if trace.NewQUICDialerWithoutResolverFn != nil { t.Fatal("expected nil NewQUICDialerQithoutResolverFn") @@ -426,6 +433,76 @@ func TestTrace(t *testing.T) { }) }) + t.Run("NewTLSHandshakerUTLSFn works as intended", func(t *testing.T) { + t.Run("when not nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewTLSHandshakerUTLSFn: func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + return &mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, mockedErr + }, + } + }, + } + thx := tx.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("when nil", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := &Trace{ + NewTLSHandshakerStdlibFn: nil, + } + thx := tx.newTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + }) + t.Run("NewQUICDialerWithoutResolverFn works as intended", func(t *testing.T) { t.Run("when not nil", func(t *testing.T) { mockedErr := errors.New("mocked") diff --git a/internal/measurexlite/utls.go b/internal/measurexlite/utls.go new file mode 100644 index 0000000..9e7160d --- /dev/null +++ b/internal/measurexlite/utls.go @@ -0,0 +1,15 @@ +package measurexlite + +import ( + "github.com/ooni/probe-cli/v3/internal/model" + utls "gitlab.com/yawning/utls.git" +) + +// NewTLSHandshakerUTLS is equivalent to netxlite.NewTLSHandshakerUTLS +// except that it returns a model.TLSHandshaker that uses this trace. +func (tx *Trace) NewTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + return &tlsHandshakerTrace{ + thx: tx.newTLSHandshakerUTLS(dl, id), + tx: tx, + } +} diff --git a/internal/measurexlite/utls_test.go b/internal/measurexlite/utls_test.go new file mode 100644 index 0000000..2308981 --- /dev/null +++ b/internal/measurexlite/utls_test.go @@ -0,0 +1,29 @@ +package measurexlite + +import ( + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + utls "gitlab.com/yawning/utls.git" +) + +func TestNewTLSHandshakerUTLS(t *testing.T) { + t.Run("NewTLSHandshakerUTLS creates a wrapped TLSHandshaker", func(t *testing.T) { + underlying := &mocks.TLSHandshaker{} + zeroTime := time.Now() + trace := NewTrace(0, zeroTime) + trace.NewTLSHandshakerUTLSFn = func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + return underlying + } + thx := trace.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) + thxt := thx.(*tlsHandshakerTrace) + if thxt.thx != underlying { + t.Fatal("invalid TLS handshaker") + } + if thxt.tx != trace { + t.Fatal("invalid trace") + } + }) +}