refactor(netx): extract tlsdialer from dialer

This commit is contained in:
Simone Basso
2021-06-08 11:24:13 +02:00
parent e0311e8fed
commit 704e5bd870
18 changed files with 699 additions and 494 deletions
@@ -0,0 +1,80 @@
package tlsdialer
import (
"context"
"crypto/tls"
"io"
"net"
"time"
)
type EOFDialer struct{}
func (EOFDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return nil, io.EOF
}
type EOFConnDialer struct{}
func (EOFConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return EOFConn{}, nil
}
type EOFConn struct {
net.Conn
}
func (EOFConn) Read(p []byte) (int, error) {
time.Sleep(10 * time.Microsecond)
return 0, io.EOF
}
func (EOFConn) Write(p []byte) (int, error) {
time.Sleep(10 * time.Microsecond)
return 0, io.EOF
}
func (EOFConn) Close() error {
time.Sleep(10 * time.Microsecond)
return io.EOF
}
func (EOFConn) LocalAddr() net.Addr {
return EOFAddr{}
}
func (EOFConn) RemoteAddr() net.Addr {
return EOFAddr{}
}
func (EOFConn) SetDeadline(t time.Time) error {
return nil
}
func (EOFConn) SetReadDeadline(t time.Time) error {
return nil
}
func (EOFConn) SetWriteDeadline(t time.Time) error {
return nil
}
type EOFAddr struct{}
func (EOFAddr) Network() string {
return "tcp"
}
func (EOFAddr) String() string {
return "127.0.0.1:1234"
}
type EOFTLSHandshaker struct{}
func (EOFTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
time.Sleep(10 * time.Microsecond)
return nil, tls.ConnectionState{}, io.EOF
}
@@ -0,0 +1,71 @@
package tlsdialer
import (
"context"
"io"
"net"
"time"
)
type FakeDialer struct {
Conn net.Conn
Err error
}
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return d.Conn, d.Err
}
type FakeConn struct {
ReadError error
ReadData []byte
SetDeadlineError error
SetReadDeadlineError error
SetWriteDeadlineError error
WriteError error
}
func (c *FakeConn) Read(b []byte) (int, error) {
if len(c.ReadData) > 0 {
n := copy(b, c.ReadData)
c.ReadData = c.ReadData[n:]
return n, nil
}
if c.ReadError != nil {
return 0, c.ReadError
}
return 0, io.EOF
}
func (c *FakeConn) Write(b []byte) (n int, err error) {
if c.WriteError != nil {
return 0, c.WriteError
}
n = len(b)
return
}
func (*FakeConn) Close() (err error) {
return
}
func (*FakeConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (*FakeConn) RemoteAddr() net.Addr {
return &net.TCPAddr{}
}
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
return c.SetDeadlineError
}
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
return c.SetReadDeadlineError
}
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
return c.SetWriteDeadlineError
}
@@ -0,0 +1,36 @@
package tlsdialer_test
import (
"context"
"net"
"net/http"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
)
func TestTLSDialerSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
log.SetLevel(log.DebugLevel)
dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.LoggingTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Logger: log.Log,
},
}
txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) {
// AlpineLinux edge is still using Go 1.13. We cannot switch to
// using DialTLSContext here as we'd like to until either Alpine
// switches to Go 1.14 or we drop the MK dependency.
return dialer.DialTLSContext(context.Background(), network, address)
}}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
}
+39
View File
@@ -0,0 +1,39 @@
package tlsdialer
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
)
// Logger is the logger assumed by this package
type Logger interface {
Debugf(format string, v ...interface{})
Debug(message string)
}
// LoggingTLSHandshaker is a TLSHandshaker with logging
type LoggingTLSHandshaker struct {
TLSHandshaker
Logger Logger
}
// Handshake implements Handshaker.Handshake
func (h LoggingTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos)
start := time.Now()
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
stop := time.Now()
h.Logger.Debugf(
"tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName,
config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol,
tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version))
return tlsconn, state, err
}
var _ TLSHandshaker = LoggingTLSHandshaker{}
@@ -0,0 +1,28 @@
package tlsdialer_test
import (
"context"
"crypto/tls"
"errors"
"io"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
)
func TestLoggingTLSHandshakerFailure(t *testing.T) {
h := tlsdialer.LoggingTLSHandshaker{
TLSHandshaker: tlsdialer.EOFTLSHandshaker{},
Logger: log.Log,
}
tlsconn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
ServerName: "www.google.com",
})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if tlsconn != nil {
t.Fatal("expected nil tlsconn here")
}
}
+49
View File
@@ -0,0 +1,49 @@
package tlsdialer
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// SaverTLSHandshaker saves events occurring during the handshake
type SaverTLSHandshaker struct {
TLSHandshaker
Saver *trace.Saver
}
// Handshake implements TLSHandshaker.Handshake
func (h SaverTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
start := time.Now()
h.Saver.Write(trace.Event{
Name: "tls_handshake_start",
NoTLSVerify: config.InsecureSkipVerify,
TLSNextProtos: config.NextProtos,
TLSServerName: config.ServerName,
Time: start,
})
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
stop := time.Now()
h.Saver.Write(trace.Event{
Duration: stop.Sub(start),
Err: err,
Name: "tls_handshake_done",
NoTLSVerify: config.InsecureSkipVerify,
TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: config.NextProtos,
TLSPeerCerts: trace.PeerCerts(state, err),
TLSServerName: config.ServerName,
TLSVersion: tlsx.VersionString(state.Version),
Time: stop,
})
return tlsconn, state, err
}
var _ TLSHandshaker = SaverTLSHandshaker{}
@@ -0,0 +1,313 @@
package tlsdialer_test
import (
"context"
"crypto/tls"
"net"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
// This is the most common use case for collecting reads, writes
if testing.Short() {
t.Skip("skip test in short mode")
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: dialer.SaverConnDialer{
Dialer: new(net.Dialer),
Saver: saver,
},
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Saver: saver,
},
}
// Implementation note: we don't close the connection here because it is
// very handy to have the last event being the end of the handshake
_, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
if err != nil {
t.Fatal(err)
}
ev := saver.Read()
if len(ev) < 4 {
// it's a bit tricky to be sure about the right number of
// events because network conditions may influence that
t.Fatal("unexpected number of events")
}
if ev[0].Name != "tls_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
last := len(ev) - 1
for idx := 1; idx < last; idx++ {
if ev[idx].Data == nil {
t.Fatal("unexpected Data")
}
if ev[idx].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[idx].Err != nil {
t.Fatal("unexpected Err")
}
if ev[idx].NumBytes <= 0 {
t.Fatal("unexpected NumBytes")
}
switch ev[idx].Name {
case errorx.ReadOperation, errorx.WriteOperation:
default:
t.Fatal("unexpected Name")
}
if ev[idx].Time.Before(ev[idx-1].Time) {
t.Fatal("unexpected Time")
}
}
if ev[last].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[last].Err != nil {
t.Fatal("unexpected Err")
}
if ev[last].Name != "tls_handshake_done" {
t.Fatal("unexpected Name")
}
if ev[last].TLSCipherSuite == "" {
t.Fatal("unexpected TLSCipherSuite")
}
if ev[last].TLSNegotiatedProto != "h2" {
t.Fatal("unexpected TLSNegotiatedProto")
}
if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[last].TLSPeerCerts == nil {
t.Fatal("unexpected TLSPeerCerts")
}
if ev[last].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if ev[last].TLSVersion == "" {
t.Fatal("unexpected TLSVersion")
}
if ev[last].Time.Before(ev[last-1].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverTLSHandshakerSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
if err != nil {
t.Fatal(err)
}
conn.Close()
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("unexpected number of events")
}
if ev[0].Name != "tls_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name != "tls_handshake_done" {
t.Fatal("unexpected Name")
}
if ev[1].TLSCipherSuite == "" {
t.Fatal("unexpected TLSCipherSuite")
}
if ev[1].TLSNegotiatedProto != "h2" {
t.Fatal("unexpected TLSNegotiatedProto")
}
if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[1].TLSPeerCerts == nil {
t.Fatal("unexpected TLSPeerCerts")
}
if ev[1].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if ev[1].TLSVersion == "" {
t.Fatal("unexpected TLSVersion")
}
if ev[1].Time.Before(ev[0].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverTLSHandshakerHostnameError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "wrong.host.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "expired.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := tlsdialer.TLSDialer{
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn here")
}
conn.Close()
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify != true {
t.Fatal("expected NoTLSVerify to be true")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
+147
View File
@@ -0,0 +1,147 @@
// Package tlsdialer contains code to establish TLS connections.
package tlsdialer
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// UnderlyingDialer is the underlying dialer type.
type UnderlyingDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// TLSHandshaker is the generic TLS handshaker
type TLSHandshaker interface {
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}
// SystemTLSHandshaker is the system TLS handshaker.
type SystemTLSHandshaker struct{}
// Handshake implements Handshaker.Handshake
func (h SystemTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn := tls.Client(conn, config)
if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
}
// TimeoutTLSHandshaker is a TLSHandshaker with timeout
type TimeoutTLSHandshaker struct {
TLSHandshaker
HandshakeTimeout time.Duration // default: 10 second
}
// Handshake implements Handshaker.Handshake
func (h TimeoutTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
timeout := 10 * time.Second
if h.HandshakeTimeout != 0 {
timeout = h.HandshakeTimeout
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, tls.ConnectionState{}, err
}
tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config)
conn.SetDeadline(time.Time{})
return tlsconn, connstate, err
}
// ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error
type ErrorWrapperTLSHandshaker struct {
TLSHandshaker
}
// Handshake implements Handshaker.Handshake
func (h ErrorWrapperTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
err = errorx.SafeErrWrapperBuilder{
ConnID: connID,
Error: err,
Operation: errorx.TLSHandshakeOperation,
}.MaybeBuild()
return tlsconn, state, err
}
// EmitterTLSHandshaker emits events using the MeasurementRoot
type EmitterTLSHandshaker struct {
TLSHandshaker
}
// Handshake implements Handshaker.Handshake
func (h EmitterTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{
ConnID: connID,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
SNI: config.ServerName,
},
})
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
root.Handler.OnMeasurement(modelx.Measurement{
TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{
ConnID: connID,
ConnectionState: modelx.NewTLSConnectionState(state),
Error: err,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
},
})
return tlsconn, state, err
}
// TLSDialer is the TLS dialer
type TLSDialer struct {
Config *tls.Config
Dialer UnderlyingDialer
TLSHandshaker TLSHandshaker
}
// DialTLSContext is like tls.DialTLS but with the signature of net.Dialer.DialContext
func (d TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
// Implementation note: when DialTLS is not set, the code in
// net/http will perform the handshake. Otherwise, if DialTLS
// is set, we will end up here. This code is still used when
// performing non-HTTP TLS-enabled dial operations.
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
config := d.Config
if config == nil {
config = new(tls.Config)
} else {
config = config.Clone()
}
if config.ServerName == "" {
config.ServerName = host
}
tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
conn.Close()
return nil, err
}
return tlsconn, nil
}
+277
View File
@@ -0,0 +1,277 @@
package tlsdialer_test
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
)
func TestSystemTLSHandshakerEOFError(t *testing.T) {
h := tlsdialer.SystemTLSHandshaker{}
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) {
h := tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
expected := errors.New("mocked error")
conn, _, err := h.Handshake(
context.Background(), &tlsdialer.FakeConn{SetDeadlineError: expected},
new(tls.Config))
if !errors.Is(err, expected) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerEOFError(t *testing.T) {
h := tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
conn, _, err := h.Handshake(
context.Background(), tlsdialer.EOFConn{}, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) {
h := tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
underlying := &SetDeadlineConn{}
conn, _, err := h.Handshake(
context.Background(), underlying, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
if len(underlying.deadlines) != 2 {
t.Fatal("SetDeadline not called twice")
}
if underlying.deadlines[0].Before(time.Now()) {
t.Fatal("the first SetDeadline call was incorrect")
}
if !underlying.deadlines[1].IsZero() {
t.Fatal("the second SetDeadline call was incorrect")
}
}
type SetDeadlineConn struct {
tlsdialer.EOFConn
deadlines []time.Time
}
func (c *SetDeadlineConn) SetDeadline(t time.Time) error {
c.deadlines = append(c.deadlines, t)
return nil
}
func TestErrorWrapperTLSHandshakerFailure(t *testing.T) {
h := tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: tlsdialer.EOFTLSHandshaker{}}
conn, _, err := h.Handshake(
context.Background(), tlsdialer.EOFConn{}, new(tls.Config))
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
var errWrapper *errorx.ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot cast to ErrWrapper")
}
if errWrapper.ConnID == 0 {
t.Fatal("unexpected ConnID")
}
if errWrapper.Failure != errorx.FailureEOFError {
t.Fatal("unexpected Failure")
}
if errWrapper.Operation != errorx.TLSHandshakeOperation {
t.Fatal("unexpected Operation")
}
}
func TestEmitterTLSHandshakerFailure(t *testing.T) {
saver := &handlers.SavingHandler{}
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: saver,
})
h := tlsdialer.EmitterTLSHandshaker{TLSHandshaker: tlsdialer.EOFTLSHandshaker{}}
conn, _, err := h.Handshake(ctx, tlsdialer.EOFConn{}, &tls.Config{
ServerName: "www.kernel.org",
})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("Wrong number of events")
}
if events[0].TLSHandshakeStart == nil {
t.Fatal("missing TLSHandshakeStart event")
}
if events[0].TLSHandshakeStart.ConnID == 0 {
t.Fatal("expected nonzero ConnID")
}
if events[0].TLSHandshakeStart.DurationSinceBeginning == 0 {
t.Fatal("expected nonzero DurationSinceBeginning")
}
if events[0].TLSHandshakeStart.SNI != "www.kernel.org" {
t.Fatal("expected nonzero SNI")
}
if events[1].TLSHandshakeDone == nil {
t.Fatal("missing TLSHandshakeDone event")
}
if events[1].TLSHandshakeDone.ConnID == 0 {
t.Fatal("expected nonzero ConnID")
}
if events[1].TLSHandshakeDone.DurationSinceBeginning == 0 {
t.Fatal("expected nonzero DurationSinceBeginning")
}
}
func TestTLSDialerFailureSplitHostPort(t *testing.T) {
dialer := tlsdialer.TLSDialer{}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com") // missing port
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
func TestTLSDialerFailureDialing(t *testing.T) {
dialer := tlsdialer.TLSDialer{Dialer: tlsdialer.EOFDialer{}}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
func TestTLSDialerFailureHandshaking(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}}
dialer := tlsdialer.TLSDialer{
Dialer: tlsdialer.EOFConnDialer{},
TLSHandshaker: rec,
}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
if rec.SNI != "www.google.com" {
t.Fatal("unexpected SNI value")
}
}
func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}}
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{
ServerName: "x.org",
},
Dialer: tlsdialer.EOFConnDialer{},
TLSHandshaker: rec,
}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
if rec.SNI != "x.org" {
t.Fatal("unexpected SNI value")
}
}
type RecorderTLSHandshaker struct {
tlsdialer.TLSHandshaker
SNI string
}
func (h *RecorderTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
h.SNI = config.ServerName
return h.TLSHandshaker.Handshake(ctx, conn, config)
}
func TestDialTLSContextGood(t *testing.T) {
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("connection is nil")
}
conn.Close()
}
func TestDialTLSContextTimeout(t *testing.T) {
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 10 * time.Microsecond,
},
},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err.Error() != errorx.FailureGenericTimeoutError {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}