diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 1b0d048..508a655 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -13,6 +13,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers" "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" ) // Dialer performs measurements while dialing. @@ -101,14 +102,14 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { // - SystemTLSHandshaker // // If you have others needs, manually build the chain you need. -func newTLSDialer(d dialer.Dialer, config *tls.Config) dialer.TLSDialer { - return dialer.TLSDialer{ +func newTLSDialer(d dialer.Dialer, config *tls.Config) tlsdialer.TLSDialer { + return tlsdialer.TLSDialer{ Config: config, Dialer: d, - TLSHandshaker: dialer.EmitterTLSHandshaker{ - TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{ - TLSHandshaker: dialer.TimeoutTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, + TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ + TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ + TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, }, }, }, diff --git a/internal/engine/netx/dialer/eof_test.go b/internal/engine/netx/dialer/eof_test.go index 9016658..c629a69 100644 --- a/internal/engine/netx/dialer/eof_test.go +++ b/internal/engine/netx/dialer/eof_test.go @@ -2,7 +2,6 @@ package dialer import ( "context" - "crypto/tls" "io" "net" "time" @@ -69,12 +68,3 @@ func (EOFAddr) Network() string { func (EOFAddr) String() string { return "127.0.0.1:1234" } - -type EOFTLSHandshaker struct{} - -func (EOFTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - time.Sleep(10 * time.Microsecond) - return nil, tls.ConnectionState{}, io.EOF -} diff --git a/internal/engine/netx/dialer/integration_test.go b/internal/engine/netx/dialer/integration_test.go index 2072a01..df3fe4c 100644 --- a/internal/engine/netx/dialer/integration_test.go +++ b/internal/engine/netx/dialer/integration_test.go @@ -1,7 +1,6 @@ package dialer_test import ( - "context" "net" "net/http" "testing" @@ -10,31 +9,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" ) -func TestTLSDialerSuccess(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - log.SetLevel(log.DebugLevel) - dialer := dialer.TLSDialer{Dialer: new(net.Dialer), - TLSHandshaker: dialer.LoggingTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Logger: log.Log, - }, - } - txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) { - // AlpineLinux edge is still using Go 1.13. We cannot switch to - // using DialTLSContext here as we'd like to until either Alpine - // switches to Go 1.14 or we drop the MK dependency. - return dialer.DialTLSContext(context.Background(), network, address) - }} - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if err != nil { - t.Fatal(err) - } - resp.Body.Close() -} - func TestDNSDialerSuccess(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") diff --git a/internal/engine/netx/dialer/logging.go b/internal/engine/netx/dialer/logging.go index c998122..7c448ed 100644 --- a/internal/engine/netx/dialer/logging.go +++ b/internal/engine/netx/dialer/logging.go @@ -2,11 +2,8 @@ package dialer import ( "context" - "crypto/tls" "net" "time" - - "github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx" ) // Logger is the logger assumed by this package @@ -30,27 +27,3 @@ func (d LoggingDialer) DialContext(ctx context.Context, network, address string) d.Logger.Debugf("dial %s/%s... %+v in %s", address, network, err, stop.Sub(start)) return conn, err } - -// LoggingTLSHandshaker is a TLSHandshaker with logging -type LoggingTLSHandshaker struct { - TLSHandshaker - Logger Logger -} - -// Handshake implements Handshaker.Handshake -func (h LoggingTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos) - start := time.Now() - tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) - stop := time.Now() - h.Logger.Debugf( - "tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName, - config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol, - tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version)) - return tlsconn, state, err -} - -var _ Dialer = LoggingDialer{} -var _ TLSHandshaker = LoggingTLSHandshaker{} diff --git a/internal/engine/netx/dialer/logging_test.go b/internal/engine/netx/dialer/logging_test.go index 9f58836..ba262e4 100644 --- a/internal/engine/netx/dialer/logging_test.go +++ b/internal/engine/netx/dialer/logging_test.go @@ -2,7 +2,6 @@ package dialer_test import ( "context" - "crypto/tls" "errors" "io" "testing" @@ -24,19 +23,3 @@ func TestLoggingDialerFailure(t *testing.T) { t.Fatal("expected nil conn here") } } - -func TestLoggingTLSHandshakerFailure(t *testing.T) { - h := dialer.LoggingTLSHandshaker{ - TLSHandshaker: dialer.EOFTLSHandshaker{}, - Logger: log.Log, - } - tlsconn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{ - ServerName: "www.google.com", - }) - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if tlsconn != nil { - t.Fatal("expected nil tlsconn here") - } -} diff --git a/internal/engine/netx/dialer/saver.go b/internal/engine/netx/dialer/saver.go index db18076..82592ef 100644 --- a/internal/engine/netx/dialer/saver.go +++ b/internal/engine/netx/dialer/saver.go @@ -2,11 +2,9 @@ package dialer import ( "context" - "crypto/tls" "net" "time" - "github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) @@ -33,42 +31,6 @@ func (d SaverDialer) DialContext(ctx context.Context, network, address string) ( return conn, err } -// SaverTLSHandshaker saves events occurring during the handshake -type SaverTLSHandshaker struct { - TLSHandshaker - Saver *trace.Saver -} - -// Handshake implements TLSHandshaker.Handshake -func (h SaverTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - start := time.Now() - h.Saver.Write(trace.Event{ - Name: "tls_handshake_start", - NoTLSVerify: config.InsecureSkipVerify, - TLSNextProtos: config.NextProtos, - TLSServerName: config.ServerName, - Time: start, - }) - tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) - stop := time.Now() - h.Saver.Write(trace.Event{ - Duration: stop.Sub(start), - Err: err, - Name: "tls_handshake_done", - NoTLSVerify: config.InsecureSkipVerify, - TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite), - TLSNegotiatedProto: state.NegotiatedProtocol, - TLSNextProtos: config.NextProtos, - TLSPeerCerts: trace.PeerCerts(state, err), - TLSServerName: config.ServerName, - TLSVersion: tlsx.VersionString(state.Version), - Time: stop, - }) - return tlsconn, state, err -} - // SaverConnDialer wraps the returned connection such that we // collect all the read/write events that occur. type SaverConnDialer struct { @@ -121,5 +83,4 @@ func (c saverConn) Write(p []byte) (int, error) { } var _ Dialer = SaverDialer{} -var _ TLSHandshaker = SaverTLSHandshaker{} var _ net.Conn = saverConn{} diff --git a/internal/engine/netx/dialer/saver_test.go b/internal/engine/netx/dialer/saver_test.go index 8d30323..a256b3a 100644 --- a/internal/engine/netx/dialer/saver_test.go +++ b/internal/engine/netx/dialer/saver_test.go @@ -2,10 +2,7 @@ package dialer_test import ( "context" - "crypto/tls" "errors" - "net" - "reflect" "testing" "time" @@ -71,301 +68,3 @@ func TestSaverConnDialerFailure(t *testing.T) { t.Fatal("expected nil conn here") } } - -func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { - // This is the most common use case for collecting reads, writes - if testing.Short() { - t.Skip("skip test in short mode") - } - nextprotos := []string{"h2"} - saver := &trace.Saver{} - tlsdlr := dialer.TLSDialer{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: dialer.SaverConnDialer{ - Dialer: new(net.Dialer), - Saver: saver, - }, - TLSHandshaker: dialer.SaverTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Saver: saver, - }, - } - // Implementation note: we don't close the connection here because it is - // very handy to have the last event being the end of the handshake - _, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - ev := saver.Read() - if len(ev) < 4 { - // it's a bit tricky to be sure about the right number of - // events because network conditions may influence that - t.Fatal("unexpected number of events") - } - if ev[0].Name != "tls_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - last := len(ev) - 1 - for idx := 1; idx < last; idx++ { - if ev[idx].Data == nil { - t.Fatal("unexpected Data") - } - if ev[idx].Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[idx].Err != nil { - t.Fatal("unexpected Err") - } - if ev[idx].NumBytes <= 0 { - t.Fatal("unexpected NumBytes") - } - switch ev[idx].Name { - case errorx.ReadOperation, errorx.WriteOperation: - default: - t.Fatal("unexpected Name") - } - if ev[idx].Time.Before(ev[idx-1].Time) { - t.Fatal("unexpected Time") - } - } - if ev[last].Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[last].Err != nil { - t.Fatal("unexpected Err") - } - if ev[last].Name != "tls_handshake_done" { - t.Fatal("unexpected Name") - } - if ev[last].TLSCipherSuite == "" { - t.Fatal("unexpected TLSCipherSuite") - } - if ev[last].TLSNegotiatedProto != "h2" { - t.Fatal("unexpected TLSNegotiatedProto") - } - if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[last].TLSPeerCerts == nil { - t.Fatal("unexpected TLSPeerCerts") - } - if ev[last].TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if ev[last].TLSVersion == "" { - t.Fatal("unexpected TLSVersion") - } - if ev[last].Time.Before(ev[last-1].Time) { - t.Fatal("unexpected Time") - } -} - -func TestSaverTLSHandshakerSuccess(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - nextprotos := []string{"h2"} - saver := &trace.Saver{} - tlsdlr := dialer.TLSDialer{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: new(net.Dialer), - TLSHandshaker: dialer.SaverTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Saver: saver, - }, - } - conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - conn.Close() - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("unexpected number of events") - } - if ev[0].Name != "tls_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - if ev[1].Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Name != "tls_handshake_done" { - t.Fatal("unexpected Name") - } - if ev[1].TLSCipherSuite == "" { - t.Fatal("unexpected TLSCipherSuite") - } - if ev[1].TLSNegotiatedProto != "h2" { - t.Fatal("unexpected TLSNegotiatedProto") - } - if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[1].TLSPeerCerts == nil { - t.Fatal("unexpected TLSPeerCerts") - } - if ev[1].TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if ev[1].TLSVersion == "" { - t.Fatal("unexpected TLSVersion") - } - if ev[1].Time.Before(ev[0].Time) { - t.Fatal("unexpected Time") - } -} - -func TestSaverTLSHandshakerHostnameError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &trace.Saver{} - tlsdlr := dialer.TLSDialer{ - Dialer: new(net.Dialer), - TLSHandshaker: dialer.SaverTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Saver: saver, - }, - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "wrong.host.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name != "tls_handshake_done" { - continue - } - if ev.NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &trace.Saver{} - tlsdlr := dialer.TLSDialer{ - Dialer: new(net.Dialer), - TLSHandshaker: dialer.SaverTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Saver: saver, - }, - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "expired.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name != "tls_handshake_done" { - continue - } - if ev.NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerAuthorityError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &trace.Saver{} - tlsdlr := dialer.TLSDialer{ - Dialer: new(net.Dialer), - TLSHandshaker: dialer.SaverTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Saver: saver, - }, - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "self-signed.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name != "tls_handshake_done" { - continue - } - if ev.NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &trace.Saver{} - tlsdlr := dialer.TLSDialer{ - Config: &tls.Config{InsecureSkipVerify: true}, - Dialer: new(net.Dialer), - TLSHandshaker: dialer.SaverTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, - Saver: saver, - }, - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "self-signed.badssl.com:443") - if err != nil { - t.Fatal(err) - } - if conn == nil { - t.Fatal("expected non-nil conn here") - } - conn.Close() - for _, ev := range saver.Read() { - if ev.Name != "tls_handshake_done" { - continue - } - if ev.NoTLSVerify != true { - t.Fatal("expected NoTLSVerify to be true") - } - if len(ev.TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 5a2175a..da10f18 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -38,6 +38,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" "github.com/ooni/probe-cli/v3/internal/runtimex" ) @@ -196,14 +197,14 @@ func NewTLSDialer(config Config) TLSDialer { if config.Dialer == nil { config.Dialer = NewDialer(config) } - var h tlsHandshaker = dialer.SystemTLSHandshaker{} - h = dialer.TimeoutTLSHandshaker{TLSHandshaker: h} - h = dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h} + var h tlsHandshaker = tlsdialer.SystemTLSHandshaker{} + h = tlsdialer.TimeoutTLSHandshaker{TLSHandshaker: h} + h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h} if config.Logger != nil { - h = dialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h} + h = tlsdialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h} } if config.TLSSaver != nil { - h = dialer.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver} + h = tlsdialer.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver} } if config.TLSConfig == nil { config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}} @@ -213,7 +214,7 @@ func NewTLSDialer(config Config) TLSDialer { } config.TLSConfig.RootCAs = config.CertPool config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify - return dialer.TLSDialer{ + return tlsdialer.TLSDialer{ Config: config.TLSConfig, Dialer: config.Dialer, TLSHandshaker: h, diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 9526ee8..e04d2c4 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -14,6 +14,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport" "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) @@ -486,7 +487,7 @@ func TestNewDialerWithContextByteCounting(t *testing.T) { func TestNewTLSDialerVanilla(t *testing.T) { td := netx.NewTLSDialer(netx.Config{}) - rtd, ok := td.(dialer.TLSDialer) + rtd, ok := td.(tlsdialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -512,15 +513,15 @@ func TestNewTLSDialerVanilla(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) + ewth, ok := rtd.TLSHandshaker.(tlsdialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) + tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { + if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -529,7 +530,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSConfig: new(tls.Config), }) - rtd, ok := td.(dialer.TLSDialer) + rtd, ok := td.(tlsdialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -552,15 +553,15 @@ func TestNewTLSDialerWithConfig(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) + ewth, ok := rtd.TLSHandshaker.(tlsdialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) + tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { + if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -569,7 +570,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ Logger: log.Log, }) - rtd, ok := td.(dialer.TLSDialer) + rtd, ok := td.(tlsdialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -595,22 +596,22 @@ func TestNewTLSDialerWithLogging(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - lth, ok := rtd.TLSHandshaker.(dialer.LoggingTLSHandshaker) + lth, ok := rtd.TLSHandshaker.(tlsdialer.LoggingTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if lth.Logger != log.Log { t.Fatal("not the Logger we expected") } - ewth, ok := lth.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) + ewth, ok := lth.TLSHandshaker.(tlsdialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) + tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { + if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -620,7 +621,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSSaver: saver, }) - rtd, ok := td.(dialer.TLSDialer) + rtd, ok := td.(tlsdialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -646,22 +647,22 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - sth, ok := rtd.TLSHandshaker.(dialer.SaverTLSHandshaker) + sth, ok := rtd.TLSHandshaker.(tlsdialer.SaverTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if sth.Saver != saver { t.Fatal("not the Logger we expected") } - ewth, ok := sth.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) + ewth, ok := sth.TLSHandshaker.(tlsdialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) + tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { + if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -671,7 +672,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { TLSConfig: new(tls.Config), NoTLSVerify: true, }) - rtd, ok := td.(dialer.TLSDialer) + rtd, ok := td.(tlsdialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -697,15 +698,15 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) + ewth, ok := rtd.TLSHandshaker.(tlsdialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) + tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { + if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -714,7 +715,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ NoTLSVerify: true, }) - rtd, ok := td.(dialer.TLSDialer) + rtd, ok := td.(tlsdialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -743,15 +744,15 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) + ewth, ok := rtd.TLSHandshaker.(tlsdialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) + tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } - if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { + if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } @@ -785,10 +786,10 @@ func TestNewWithDialer(t *testing.T) { func TestNewWithTLSDialer(t *testing.T) { expected := errors.New("mocked error") - tlsDialer := dialer.TLSDialer{ + tlsDialer := tlsdialer.TLSDialer{ Config: new(tls.Config), Dialer: netx.FakeDialer{Err: expected}, - TLSHandshaker: dialer.SystemTLSHandshaker{}, + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, } txp := netx.NewHTTPTransport(netx.Config{ TLSDialer: tlsDialer, diff --git a/internal/engine/netx/tlsdialer/eof_test.go b/internal/engine/netx/tlsdialer/eof_test.go new file mode 100644 index 0000000..56612a3 --- /dev/null +++ b/internal/engine/netx/tlsdialer/eof_test.go @@ -0,0 +1,80 @@ +package tlsdialer + +import ( + "context" + "crypto/tls" + "io" + "net" + "time" +) + +type EOFDialer struct{} + +func (EOFDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + time.Sleep(10 * time.Microsecond) + return nil, io.EOF +} + +type EOFConnDialer struct{} + +func (EOFConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return EOFConn{}, nil +} + +type EOFConn struct { + net.Conn +} + +func (EOFConn) Read(p []byte) (int, error) { + time.Sleep(10 * time.Microsecond) + return 0, io.EOF +} + +func (EOFConn) Write(p []byte) (int, error) { + time.Sleep(10 * time.Microsecond) + return 0, io.EOF +} + +func (EOFConn) Close() error { + time.Sleep(10 * time.Microsecond) + return io.EOF +} + +func (EOFConn) LocalAddr() net.Addr { + return EOFAddr{} +} + +func (EOFConn) RemoteAddr() net.Addr { + return EOFAddr{} +} + +func (EOFConn) SetDeadline(t time.Time) error { + return nil +} + +func (EOFConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (EOFConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type EOFAddr struct{} + +func (EOFAddr) Network() string { + return "tcp" +} + +func (EOFAddr) String() string { + return "127.0.0.1:1234" +} + +type EOFTLSHandshaker struct{} + +func (EOFTLSHandshaker) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config, +) (net.Conn, tls.ConnectionState, error) { + time.Sleep(10 * time.Microsecond) + return nil, tls.ConnectionState{}, io.EOF +} diff --git a/internal/engine/netx/tlsdialer/fake_test.go b/internal/engine/netx/tlsdialer/fake_test.go new file mode 100644 index 0000000..1b4e0b3 --- /dev/null +++ b/internal/engine/netx/tlsdialer/fake_test.go @@ -0,0 +1,71 @@ +package tlsdialer + +import ( + "context" + "io" + "net" + "time" +) + +type FakeDialer struct { + Conn net.Conn + Err error +} + +func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + time.Sleep(10 * time.Microsecond) + return d.Conn, d.Err +} + +type FakeConn struct { + ReadError error + ReadData []byte + SetDeadlineError error + SetReadDeadlineError error + SetWriteDeadlineError error + WriteError error +} + +func (c *FakeConn) Read(b []byte) (int, error) { + if len(c.ReadData) > 0 { + n := copy(b, c.ReadData) + c.ReadData = c.ReadData[n:] + return n, nil + } + if c.ReadError != nil { + return 0, c.ReadError + } + return 0, io.EOF +} + +func (c *FakeConn) Write(b []byte) (n int, err error) { + if c.WriteError != nil { + return 0, c.WriteError + } + n = len(b) + return +} + +func (*FakeConn) Close() (err error) { + return +} + +func (*FakeConn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} + +func (*FakeConn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} + +func (c *FakeConn) SetDeadline(t time.Time) (err error) { + return c.SetDeadlineError +} + +func (c *FakeConn) SetReadDeadline(t time.Time) (err error) { + return c.SetReadDeadlineError +} + +func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) { + return c.SetWriteDeadlineError +} diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go new file mode 100644 index 0000000..694d204 --- /dev/null +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -0,0 +1,36 @@ +package tlsdialer_test + +import ( + "context" + "net" + "net/http" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" +) + +func TestTLSDialerSuccess(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + log.SetLevel(log.DebugLevel) + dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer), + TLSHandshaker: tlsdialer.LoggingTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Logger: log.Log, + }, + } + txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) { + // AlpineLinux edge is still using Go 1.13. We cannot switch to + // using DialTLSContext here as we'd like to until either Alpine + // switches to Go 1.14 or we drop the MK dependency. + return dialer.DialTLSContext(context.Background(), network, address) + }} + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} diff --git a/internal/engine/netx/tlsdialer/logging.go b/internal/engine/netx/tlsdialer/logging.go new file mode 100644 index 0000000..0cd7e7e --- /dev/null +++ b/internal/engine/netx/tlsdialer/logging.go @@ -0,0 +1,39 @@ +package tlsdialer + +import ( + "context" + "crypto/tls" + "net" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx" +) + +// Logger is the logger assumed by this package +type Logger interface { + Debugf(format string, v ...interface{}) + Debug(message string) +} + +// LoggingTLSHandshaker is a TLSHandshaker with logging +type LoggingTLSHandshaker struct { + TLSHandshaker + Logger Logger +} + +// Handshake implements Handshaker.Handshake +func (h LoggingTLSHandshaker) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config, +) (net.Conn, tls.ConnectionState, error) { + h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos) + start := time.Now() + tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) + stop := time.Now() + h.Logger.Debugf( + "tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName, + config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol, + tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version)) + return tlsconn, state, err +} + +var _ TLSHandshaker = LoggingTLSHandshaker{} diff --git a/internal/engine/netx/tlsdialer/logging_test.go b/internal/engine/netx/tlsdialer/logging_test.go new file mode 100644 index 0000000..96733ef --- /dev/null +++ b/internal/engine/netx/tlsdialer/logging_test.go @@ -0,0 +1,28 @@ +package tlsdialer_test + +import ( + "context" + "crypto/tls" + "errors" + "io" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" +) + +func TestLoggingTLSHandshakerFailure(t *testing.T) { + h := tlsdialer.LoggingTLSHandshaker{ + TLSHandshaker: tlsdialer.EOFTLSHandshaker{}, + Logger: log.Log, + } + tlsconn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{ + ServerName: "www.google.com", + }) + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if tlsconn != nil { + t.Fatal("expected nil tlsconn here") + } +} diff --git a/internal/engine/netx/tlsdialer/saver.go b/internal/engine/netx/tlsdialer/saver.go new file mode 100644 index 0000000..d660f85 --- /dev/null +++ b/internal/engine/netx/tlsdialer/saver.go @@ -0,0 +1,49 @@ +package tlsdialer + +import ( + "context" + "crypto/tls" + "net" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" +) + +// SaverTLSHandshaker saves events occurring during the handshake +type SaverTLSHandshaker struct { + TLSHandshaker + Saver *trace.Saver +} + +// Handshake implements TLSHandshaker.Handshake +func (h SaverTLSHandshaker) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config, +) (net.Conn, tls.ConnectionState, error) { + start := time.Now() + h.Saver.Write(trace.Event{ + Name: "tls_handshake_start", + NoTLSVerify: config.InsecureSkipVerify, + TLSNextProtos: config.NextProtos, + TLSServerName: config.ServerName, + Time: start, + }) + tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) + stop := time.Now() + h.Saver.Write(trace.Event{ + Duration: stop.Sub(start), + Err: err, + Name: "tls_handshake_done", + NoTLSVerify: config.InsecureSkipVerify, + TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite), + TLSNegotiatedProto: state.NegotiatedProtocol, + TLSNextProtos: config.NextProtos, + TLSPeerCerts: trace.PeerCerts(state, err), + TLSServerName: config.ServerName, + TLSVersion: tlsx.VersionString(state.Version), + Time: stop, + }) + return tlsconn, state, err +} + +var _ TLSHandshaker = SaverTLSHandshaker{} diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go new file mode 100644 index 0000000..7cbaa16 --- /dev/null +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -0,0 +1,313 @@ +package tlsdialer_test + +import ( + "context" + "crypto/tls" + "net" + "reflect" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" +) + +func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { + // This is the most common use case for collecting reads, writes + if testing.Short() { + t.Skip("skip test in short mode") + } + nextprotos := []string{"h2"} + saver := &trace.Saver{} + tlsdlr := tlsdialer.TLSDialer{ + Config: &tls.Config{NextProtos: nextprotos}, + Dialer: dialer.SaverConnDialer{ + Dialer: new(net.Dialer), + Saver: saver, + }, + TLSHandshaker: tlsdialer.SaverTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Saver: saver, + }, + } + // Implementation note: we don't close the connection here because it is + // very handy to have the last event being the end of the handshake + _, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + ev := saver.Read() + if len(ev) < 4 { + // it's a bit tricky to be sure about the right number of + // events because network conditions may influence that + t.Fatal("unexpected number of events") + } + if ev[0].Name != "tls_handshake_start" { + t.Fatal("unexpected Name") + } + if ev[0].TLSServerName != "www.google.com" { + t.Fatal("unexpected TLSServerName") + } + if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) { + t.Fatal("unexpected TLSNextProtos") + } + if ev[0].Time.After(time.Now()) { + t.Fatal("unexpected Time") + } + last := len(ev) - 1 + for idx := 1; idx < last; idx++ { + if ev[idx].Data == nil { + t.Fatal("unexpected Data") + } + if ev[idx].Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[idx].Err != nil { + t.Fatal("unexpected Err") + } + if ev[idx].NumBytes <= 0 { + t.Fatal("unexpected NumBytes") + } + switch ev[idx].Name { + case errorx.ReadOperation, errorx.WriteOperation: + default: + t.Fatal("unexpected Name") + } + if ev[idx].Time.Before(ev[idx-1].Time) { + t.Fatal("unexpected Time") + } + } + if ev[last].Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[last].Err != nil { + t.Fatal("unexpected Err") + } + if ev[last].Name != "tls_handshake_done" { + t.Fatal("unexpected Name") + } + if ev[last].TLSCipherSuite == "" { + t.Fatal("unexpected TLSCipherSuite") + } + if ev[last].TLSNegotiatedProto != "h2" { + t.Fatal("unexpected TLSNegotiatedProto") + } + if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) { + t.Fatal("unexpected TLSNextProtos") + } + if ev[last].TLSPeerCerts == nil { + t.Fatal("unexpected TLSPeerCerts") + } + if ev[last].TLSServerName != "www.google.com" { + t.Fatal("unexpected TLSServerName") + } + if ev[last].TLSVersion == "" { + t.Fatal("unexpected TLSVersion") + } + if ev[last].Time.Before(ev[last-1].Time) { + t.Fatal("unexpected Time") + } +} + +func TestSaverTLSHandshakerSuccess(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + nextprotos := []string{"h2"} + saver := &trace.Saver{} + tlsdlr := tlsdialer.TLSDialer{ + Config: &tls.Config{NextProtos: nextprotos}, + Dialer: new(net.Dialer), + TLSHandshaker: tlsdialer.SaverTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Saver: saver, + }, + } + conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + conn.Close() + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("unexpected number of events") + } + if ev[0].Name != "tls_handshake_start" { + t.Fatal("unexpected Name") + } + if ev[0].TLSServerName != "www.google.com" { + t.Fatal("unexpected TLSServerName") + } + if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) { + t.Fatal("unexpected TLSNextProtos") + } + if ev[0].Time.After(time.Now()) { + t.Fatal("unexpected Time") + } + if ev[1].Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Err != nil { + t.Fatal("unexpected Err") + } + if ev[1].Name != "tls_handshake_done" { + t.Fatal("unexpected Name") + } + if ev[1].TLSCipherSuite == "" { + t.Fatal("unexpected TLSCipherSuite") + } + if ev[1].TLSNegotiatedProto != "h2" { + t.Fatal("unexpected TLSNegotiatedProto") + } + if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) { + t.Fatal("unexpected TLSNextProtos") + } + if ev[1].TLSPeerCerts == nil { + t.Fatal("unexpected TLSPeerCerts") + } + if ev[1].TLSServerName != "www.google.com" { + t.Fatal("unexpected TLSServerName") + } + if ev[1].TLSVersion == "" { + t.Fatal("unexpected TLSVersion") + } + if ev[1].Time.Before(ev[0].Time) { + t.Fatal("unexpected Time") + } +} + +func TestSaverTLSHandshakerHostnameError(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + saver := &trace.Saver{} + tlsdlr := tlsdialer.TLSDialer{ + Dialer: new(net.Dialer), + TLSHandshaker: tlsdialer.SaverTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Saver: saver, + }, + } + conn, err := tlsdlr.DialTLSContext( + context.Background(), "tcp", "wrong.host.badssl.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + for _, ev := range saver.Read() { + if ev.Name != "tls_handshake_done" { + continue + } + if ev.NoTLSVerify == true { + t.Fatal("expected NoTLSVerify to be false") + } + if len(ev.TLSPeerCerts) < 1 { + t.Fatal("expected at least a certificate here") + } + } +} + +func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + saver := &trace.Saver{} + tlsdlr := tlsdialer.TLSDialer{ + Dialer: new(net.Dialer), + TLSHandshaker: tlsdialer.SaverTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Saver: saver, + }, + } + conn, err := tlsdlr.DialTLSContext( + context.Background(), "tcp", "expired.badssl.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + for _, ev := range saver.Read() { + if ev.Name != "tls_handshake_done" { + continue + } + if ev.NoTLSVerify == true { + t.Fatal("expected NoTLSVerify to be false") + } + if len(ev.TLSPeerCerts) < 1 { + t.Fatal("expected at least a certificate here") + } + } +} + +func TestSaverTLSHandshakerAuthorityError(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + saver := &trace.Saver{} + tlsdlr := tlsdialer.TLSDialer{ + Dialer: new(net.Dialer), + TLSHandshaker: tlsdialer.SaverTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Saver: saver, + }, + } + conn, err := tlsdlr.DialTLSContext( + context.Background(), "tcp", "self-signed.badssl.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + for _, ev := range saver.Read() { + if ev.Name != "tls_handshake_done" { + continue + } + if ev.NoTLSVerify == true { + t.Fatal("expected NoTLSVerify to be false") + } + if len(ev.TLSPeerCerts) < 1 { + t.Fatal("expected at least a certificate here") + } + } +} + +func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + saver := &trace.Saver{} + tlsdlr := tlsdialer.TLSDialer{ + Config: &tls.Config{InsecureSkipVerify: true}, + Dialer: new(net.Dialer), + TLSHandshaker: tlsdialer.SaverTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, + Saver: saver, + }, + } + conn, err := tlsdlr.DialTLSContext( + context.Background(), "tcp", "self-signed.badssl.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn here") + } + conn.Close() + for _, ev := range saver.Read() { + if ev.Name != "tls_handshake_done" { + continue + } + if ev.NoTLSVerify != true { + t.Fatal("expected NoTLSVerify to be true") + } + if len(ev.TLSPeerCerts) < 1 { + t.Fatal("expected at least a certificate here") + } + } +} diff --git a/internal/engine/netx/dialer/tls.go b/internal/engine/netx/tlsdialer/tls.go similarity index 93% rename from internal/engine/netx/dialer/tls.go rename to internal/engine/netx/tlsdialer/tls.go index cf958d8..ccd76fa 100644 --- a/internal/engine/netx/dialer/tls.go +++ b/internal/engine/netx/tlsdialer/tls.go @@ -1,4 +1,5 @@ -package dialer +// Package tlsdialer contains code to establish TLS connections. +package tlsdialer import ( "context" @@ -11,6 +12,11 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" ) +// UnderlyingDialer is the underlying dialer type. +type UnderlyingDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + // TLSHandshaker is the generic TLS handshaker type TLSHandshaker interface { Handshake(ctx context.Context, conn net.Conn, config *tls.Config) ( @@ -105,7 +111,7 @@ func (h EmitterTLSHandshaker) Handshake( // TLSDialer is the TLS dialer type TLSDialer struct { Config *tls.Config - Dialer Dialer + Dialer UnderlyingDialer TLSHandshaker TLSHandshaker } diff --git a/internal/engine/netx/dialer/tls_test.go b/internal/engine/netx/tlsdialer/tls_test.go similarity index 78% rename from internal/engine/netx/dialer/tls_test.go rename to internal/engine/netx/tlsdialer/tls_test.go index d7b16ac..9e093fe 100644 --- a/internal/engine/netx/dialer/tls_test.go +++ b/internal/engine/netx/tlsdialer/tls_test.go @@ -1,4 +1,4 @@ -package dialer_test +package tlsdialer_test import ( "context" @@ -11,13 +11,13 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers" "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" - "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" ) func TestSystemTLSHandshakerEOFError(t *testing.T) { - h := dialer.SystemTLSHandshaker{} - conn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{ + h := tlsdialer.SystemTLSHandshaker{} + conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{ ServerName: "x.org", }) if err != io.EOF { @@ -29,13 +29,13 @@ func TestSystemTLSHandshakerEOFError(t *testing.T) { } func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) { - h := dialer.TimeoutTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, + h := tlsdialer.TimeoutTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, HandshakeTimeout: 200 * time.Millisecond, } expected := errors.New("mocked error") conn, _, err := h.Handshake( - context.Background(), &dialer.FakeConn{SetDeadlineError: expected}, + context.Background(), &tlsdialer.FakeConn{SetDeadlineError: expected}, new(tls.Config)) if !errors.Is(err, expected) { t.Fatal("not the error that we expected") @@ -46,12 +46,12 @@ func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) { } func TestTimeoutTLSHandshakerEOFError(t *testing.T) { - h := dialer.TimeoutTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, + h := tlsdialer.TimeoutTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, HandshakeTimeout: 200 * time.Millisecond, } conn, _, err := h.Handshake( - context.Background(), dialer.EOFConn{}, &tls.Config{ServerName: "x.org"}) + context.Background(), tlsdialer.EOFConn{}, &tls.Config{ServerName: "x.org"}) if !errors.Is(err, io.EOF) { t.Fatal("not the error that we expected") } @@ -61,8 +61,8 @@ func TestTimeoutTLSHandshakerEOFError(t *testing.T) { } func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) { - h := dialer.TimeoutTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, + h := tlsdialer.TimeoutTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, HandshakeTimeout: 200 * time.Millisecond, } underlying := &SetDeadlineConn{} @@ -86,7 +86,7 @@ func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) { } type SetDeadlineConn struct { - dialer.EOFConn + tlsdialer.EOFConn deadlines []time.Time } @@ -96,9 +96,9 @@ func (c *SetDeadlineConn) SetDeadline(t time.Time) error { } func TestErrorWrapperTLSHandshakerFailure(t *testing.T) { - h := dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}} + h := tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: tlsdialer.EOFTLSHandshaker{}} conn, _, err := h.Handshake( - context.Background(), dialer.EOFConn{}, new(tls.Config)) + context.Background(), tlsdialer.EOFConn{}, new(tls.Config)) if !errors.Is(err, io.EOF) { t.Fatal("not the error that we expected") } @@ -126,8 +126,8 @@ func TestEmitterTLSHandshakerFailure(t *testing.T) { Beginning: time.Now(), Handler: saver, }) - h := dialer.EmitterTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}} - conn, _, err := h.Handshake(ctx, dialer.EOFConn{}, &tls.Config{ + h := tlsdialer.EmitterTLSHandshaker{TLSHandshaker: tlsdialer.EOFTLSHandshaker{}} + conn, _, err := h.Handshake(ctx, tlsdialer.EOFConn{}, &tls.Config{ ServerName: "www.kernel.org", }) if !errors.Is(err, io.EOF) { @@ -164,7 +164,7 @@ func TestEmitterTLSHandshakerFailure(t *testing.T) { } func TestTLSDialerFailureSplitHostPort(t *testing.T) { - dialer := dialer.TLSDialer{} + dialer := tlsdialer.TLSDialer{} conn, err := dialer.DialTLSContext( context.Background(), "tcp", "www.google.com") // missing port if err == nil { @@ -176,7 +176,7 @@ func TestTLSDialerFailureSplitHostPort(t *testing.T) { } func TestTLSDialerFailureDialing(t *testing.T) { - dialer := dialer.TLSDialer{Dialer: dialer.EOFDialer{}} + dialer := tlsdialer.TLSDialer{Dialer: tlsdialer.EOFDialer{}} conn, err := dialer.DialTLSContext( context.Background(), "tcp", "www.google.com:443") if !errors.Is(err, io.EOF) { @@ -188,9 +188,9 @@ func TestTLSDialerFailureDialing(t *testing.T) { } func TestTLSDialerFailureHandshaking(t *testing.T) { - rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}} - dialer := dialer.TLSDialer{ - Dialer: dialer.EOFConnDialer{}, + rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}} + dialer := tlsdialer.TLSDialer{ + Dialer: tlsdialer.EOFConnDialer{}, TLSHandshaker: rec, } conn, err := dialer.DialTLSContext( @@ -207,12 +207,12 @@ func TestTLSDialerFailureHandshaking(t *testing.T) { } func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) { - rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}} - dialer := dialer.TLSDialer{ + rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}} + dialer := tlsdialer.TLSDialer{ Config: &tls.Config{ ServerName: "x.org", }, - Dialer: dialer.EOFConnDialer{}, + Dialer: tlsdialer.EOFConnDialer{}, TLSHandshaker: rec, } conn, err := dialer.DialTLSContext( @@ -229,7 +229,7 @@ func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) { } type RecorderTLSHandshaker struct { - dialer.TLSHandshaker + tlsdialer.TLSHandshaker SNI string } @@ -241,10 +241,10 @@ func (h *RecorderTLSHandshaker) Handshake( } func TestDialTLSContextGood(t *testing.T) { - dialer := dialer.TLSDialer{ + dialer := tlsdialer.TLSDialer{ Config: &tls.Config{ServerName: "google.com"}, Dialer: new(net.Dialer), - TLSHandshaker: dialer.SystemTLSHandshaker{}, + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, } conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443") if err != nil { @@ -257,12 +257,12 @@ func TestDialTLSContextGood(t *testing.T) { } func TestDialTLSContextTimeout(t *testing.T) { - dialer := dialer.TLSDialer{ + dialer := tlsdialer.TLSDialer{ Config: &tls.Config{ServerName: "google.com"}, Dialer: new(net.Dialer), - TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{ - TLSHandshaker: dialer.TimeoutTLSHandshaker{ - TLSHandshaker: dialer.SystemTLSHandshaker{}, + TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ + TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{ + TLSHandshaker: tlsdialer.SystemTLSHandshaker{}, HandshakeTimeout: 10 * time.Microsecond, }, },