diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 2955575..01d2baa 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -179,7 +179,7 @@ func NewTLSDialer(config Config) TLSDialer { var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{} h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h} if config.Logger != nil { - h = tlsdialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h} + h = &netxlite.TLSHandshakerLogger{Logger: config.Logger, TLSHandshaker: h} } if config.TLSSaver != nil { h = tlsdialer.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver} diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index c192f4e..bf17abd 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -292,7 +292,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - lth, ok := rtd.TLSHandshaker.(tlsdialer.LoggingTLSHandshaker) + lth, ok := rtd.TLSHandshaker.(*netxlite.TLSHandshakerLogger) if !ok { t.Fatal("not the TLSHandshaker we expected") } diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index 2f6e619..13c1768 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -17,7 +17,7 @@ func TestTLSDialerSuccess(t *testing.T) { } log.SetLevel(log.DebugLevel) dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer), - TLSHandshaker: tlsdialer.LoggingTLSHandshaker{ + TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Logger: log.Log, }, diff --git a/internal/engine/netx/tlsdialer/logging.go b/internal/engine/netx/tlsdialer/logging.go index 6945d8c..e55459d 100644 --- a/internal/engine/netx/tlsdialer/logging.go +++ b/internal/engine/netx/tlsdialer/logging.go @@ -1,39 +1,7 @@ package tlsdialer -import ( - "context" - "crypto/tls" - "net" - "time" - - "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsx" -) - // Logger is the logger assumed by this package type Logger interface { Debugf(format string, v ...interface{}) Debug(message string) } - -// LoggingTLSHandshaker is a TLSHandshaker with logging -type LoggingTLSHandshaker struct { - TLSHandshaker - Logger Logger -} - -// Handshake implements Handshaker.Handshake -func (h LoggingTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos) - start := time.Now() - tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) - stop := time.Now() - h.Logger.Debugf( - "tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName, - config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol, - tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version)) - return tlsconn, state, err -} - -var _ TLSHandshaker = LoggingTLSHandshaker{} diff --git a/internal/engine/netx/tlsdialer/logging_test.go b/internal/engine/netx/tlsdialer/logging_test.go deleted file mode 100644 index 96733ef..0000000 --- a/internal/engine/netx/tlsdialer/logging_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package tlsdialer_test - -import ( - "context" - "crypto/tls" - "errors" - "io" - "testing" - - "github.com/apex/log" - "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" -) - -func TestLoggingTLSHandshakerFailure(t *testing.T) { - h := tlsdialer.LoggingTLSHandshaker{ - TLSHandshaker: tlsdialer.EOFTLSHandshaker{}, - Logger: log.Log, - } - tlsconn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{ - ServerName: "www.google.com", - }) - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if tlsconn != nil { - t.Fatal("expected nil tlsconn here") - } -} diff --git a/internal/netxlite/tlshandshaker.go b/internal/netxlite/tlshandshaker.go index 391ea53..a80d3d5 100644 --- a/internal/netxlite/tlshandshaker.go +++ b/internal/netxlite/tlshandshaker.go @@ -5,6 +5,8 @@ import ( "crypto/tls" "net" "time" + + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsx" ) // TLSHandshaker is the generic TLS handshaker. @@ -44,3 +46,37 @@ func (h *TLSHandshakerStdlib) Handshake( // DefaultTLSHandshaker is the default TLS handshaker. var DefaultTLSHandshaker = &TLSHandshakerStdlib{} + +// TLSHandshakerLogger is a TLSHandshaker with logging. +type TLSHandshakerLogger struct { + // TLSHandshaker is the underlying handshaker. + TLSHandshaker TLSHandshaker + + // Logger is the underlying logger. + Logger Logger +} + +// Handshake implements Handshaker.Handshake +func (h *TLSHandshakerLogger) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config, +) (net.Conn, tls.ConnectionState, error) { + h.Logger.Debugf( + "tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos) + start := time.Now() + tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) + elapsed := time.Since(start) + if err != nil { + h.Logger.Debugf( + "tls {sni=%s next=%+v}... %s in %s", config.ServerName, + config.NextProtos, err, elapsed) + return nil, tls.ConnectionState{}, err + } + h.Logger.Debugf( + "tls {sni=%s next=%+v}... ok in %s {next=%s cipher=%s v=%s}", + config.ServerName, config.NextProtos, elapsed, state.NegotiatedProtocol, + tlsx.CipherSuiteString(state.CipherSuite), + tlsx.VersionString(state.Version)) + return tlsconn, state, nil +} + +var _ TLSHandshaker = &TLSHandshakerLogger{} diff --git a/internal/netxlite/tlshandshaker_test.go b/internal/netxlite/tlshandshaker_test.go index 5d9cd7a..24f1cd2 100644 --- a/internal/netxlite/tlshandshaker_test.go +++ b/internal/netxlite/tlshandshaker_test.go @@ -3,14 +3,17 @@ package netxlite import ( "context" "crypto/tls" + "errors" "io" "net" "net/http" "net/http/httptest" "net/url" + "reflect" "testing" "time" + "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/netxmocks" ) @@ -79,3 +82,60 @@ func TestTLSHandshakerStdlibSuccess(t *testing.T) { t.Fatal("unexpected TLS version") } } + +func TestTLSHandshakerLoggerSuccess(t *testing.T) { + th := &TLSHandshakerLogger{ + TLSHandshaker: &netxmocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return tls.Client(conn, config), tls.ConnectionState{}, nil + }, + }, + Logger: log.Log, + } + conn := &netxmocks.Conn{ + MockClose: func() error { + return nil + }, + } + config := &tls.Config{} + ctx := context.Background() + tlsConn, connState, err := th.Handshake(ctx, conn, config) + if err != nil { + t.Fatal(err) + } + if err := tlsConn.Close(); err != nil { + t.Fatal(err) + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero ConnectionState here") + } +} + +func TestTLSHandshakerLoggerFailure(t *testing.T) { + expected := errors.New("mocked error") + th := &TLSHandshakerLogger{ + TLSHandshaker: &netxmocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expected + }, + }, + Logger: log.Log, + } + conn := &netxmocks.Conn{ + MockClose: func() error { + return nil + }, + } + config := &tls.Config{} + ctx := context.Background() + tlsConn, connState, err := th.Handshake(ctx, conn, config) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if tlsConn != nil { + t.Fatal("expected nil conn here") + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero ConnectionState here") + } +} diff --git a/internal/netxmocks/dialer.go b/internal/netxmocks/dialer.go index f176a20..e6e0280 100644 --- a/internal/netxmocks/dialer.go +++ b/internal/netxmocks/dialer.go @@ -5,11 +5,6 @@ import ( "net" ) -// dialer is the interface we expect from a dialer -type dialer interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - // Dialer is a mockable Dialer. type Dialer struct { MockDialContext func(ctx context.Context, network, address string) (net.Conn, error) @@ -19,5 +14,3 @@ type Dialer struct { func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.MockDialContext(ctx, network, address) } - -var _ dialer = &Dialer{} diff --git a/internal/netxmocks/resolver.go b/internal/netxmocks/resolver.go index 89c996f..c488911 100644 --- a/internal/netxmocks/resolver.go +++ b/internal/netxmocks/resolver.go @@ -2,13 +2,6 @@ package netxmocks import "context" -// resolver is the interface we expect from a resolver -type resolver interface { - LookupHost(ctx context.Context, domain string) ([]string, error) - Network() string - Address() string -} - // Resolver is a mockable Resolver. type Resolver struct { MockLookupHost func(ctx context.Context, domain string) ([]string, error) @@ -30,5 +23,3 @@ func (r *Resolver) Address() string { func (r *Resolver) Network() string { return r.MockNetwork() } - -var _ resolver = &Resolver{} diff --git a/internal/netxmocks/tlshandshaker.go b/internal/netxmocks/tlshandshaker.go new file mode 100644 index 0000000..8981c74 --- /dev/null +++ b/internal/netxmocks/tlshandshaker.go @@ -0,0 +1,19 @@ +package netxmocks + +import ( + "context" + "crypto/tls" + "net" +) + +// TLSHandshaker is a mockable TLS handshaker. +type TLSHandshaker struct { + MockHandshake func(ctx context.Context, conn net.Conn, config *tls.Config) ( + net.Conn, tls.ConnectionState, error) +} + +// Handshake calls MockHandshake. +func (th *TLSHandshaker) Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( + net.Conn, tls.ConnectionState, error) { + return th.MockHandshake(ctx, conn, config) +} diff --git a/internal/netxmocks/tlshandshaker_test.go b/internal/netxmocks/tlshandshaker_test.go new file mode 100644 index 0000000..7a0d427 --- /dev/null +++ b/internal/netxmocks/tlshandshaker_test.go @@ -0,0 +1,33 @@ +package netxmocks + +import ( + "context" + "crypto/tls" + "errors" + "net" + "reflect" + "testing" +) + +func TestTLSHandshakerHandshake(t *testing.T) { + expected := errors.New("mocked error") + conn := &Conn{} + ctx := context.Background() + config := &tls.Config{} + th := &TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, + config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expected + }, + } + tlsConn, connState, err := th.Handshake(ctx, conn, config) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if !reflect.ValueOf(connState).IsZero() { + t.Fatal("expected zero ConnectionState here") + } + if tlsConn != nil { + t.Fatal("expected nil conn here") + } +}