diff --git a/internal/engine/legacy/netx/dialer.go b/internal/engine/legacy/netx/dialer.go index 841a1d8..63ac676 100644 --- a/internal/engine/legacy/netx/dialer.go +++ b/internal/engine/legacy/netx/dialer.go @@ -102,8 +102,8 @@ func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { // - SystemTLSHandshaker // // If you have others needs, manually build the chain you need. -func newTLSDialer(d dialer.Dialer, config *tls.Config) tlsdialer.TLSDialer { - return tlsdialer.TLSDialer{ +func newTLSDialer(d dialer.Dialer, config *tls.Config) *netxlite.TLSDialer { + return &netxlite.TLSDialer{ Config: config, Dialer: d, TLSHandshaker: tlsdialer.EmitterTLSHandshaker{ diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 2bc68f6..ced5f12 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -191,7 +191,7 @@ func NewTLSDialer(config Config) TLSDialer { } config.TLSConfig.RootCAs = config.CertPool config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify - return tlsdialer.TLSDialer{ + return &netxlite.TLSDialer{ Config: config.TLSConfig, Dialer: config.Dialer, TLSHandshaker: h, diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 4ea9e44..569d1f5 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -211,7 +211,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { func TestNewTLSDialerVanilla(t *testing.T) { td := netx.NewTLSDialer(netx.Config{}) - rtd, ok := td.(tlsdialer.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -243,7 +243,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSConfig: new(tls.Config), }) - rtd, ok := td.(tlsdialer.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -272,7 +272,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ Logger: log.Log, }) - rtd, ok := td.(tlsdialer.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -312,7 +312,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSSaver: saver, }) - rtd, ok := td.(tlsdialer.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -352,7 +352,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { TLSConfig: new(tls.Config), NoTLSVerify: true, }) - rtd, ok := td.(tlsdialer.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -384,7 +384,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ NoTLSVerify: true, }) - rtd, ok := td.(tlsdialer.TLSDialer) + rtd, ok := td.(*netxlite.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } @@ -444,7 +444,7 @@ func TestNewWithDialer(t *testing.T) { func TestNewWithTLSDialer(t *testing.T) { expected := errors.New("mocked error") - tlsDialer := tlsdialer.TLSDialer{ + tlsDialer := &netxlite.TLSDialer{ Config: new(tls.Config), Dialer: netx.FakeDialer{Err: expected}, TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, diff --git a/internal/engine/netx/tlsdialer/integration_test.go b/internal/engine/netx/tlsdialer/integration_test.go index 13c1768..24bd2d7 100644 --- a/internal/engine/netx/tlsdialer/integration_test.go +++ b/internal/engine/netx/tlsdialer/integration_test.go @@ -1,13 +1,11 @@ package tlsdialer_test import ( - "context" "net" "net/http" "testing" "github.com/apex/log" - "github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -16,18 +14,16 @@ func TestTLSDialerSuccess(t *testing.T) { t.Skip("skip test in short mode") } log.SetLevel(log.DebugLevel) - dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer), + dialer := &netxlite.TLSDialer{Dialer: new(net.Dialer), TLSHandshaker: &netxlite.TLSHandshakerLogger{ TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, Logger: log.Log, }, } - txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) { - // AlpineLinux edge is still using Go 1.13. We cannot switch to - // using DialTLSContext here as we'd like to until either Alpine - // switches to Go 1.14 or we drop the MK dependency. - return dialer.DialTLSContext(context.Background(), network, address) - }} + txp := &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + ForceAttemptHTTP2: true, + } client := &http.Client{Transport: txp} resp, err := client.Get("https://www.google.com") if err != nil { diff --git a/internal/engine/netx/tlsdialer/saver_test.go b/internal/engine/netx/tlsdialer/saver_test.go index 37b01b7..9420fe8 100644 --- a/internal/engine/netx/tlsdialer/saver_test.go +++ b/internal/engine/netx/tlsdialer/saver_test.go @@ -22,7 +22,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { } nextprotos := []string{"h2"} saver := &trace.Saver{} - tlsdlr := tlsdialer.TLSDialer{ + tlsdlr := &netxlite.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ @@ -115,7 +115,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { } nextprotos := []string{"h2"} saver := &trace.Saver{} - tlsdlr := tlsdialer.TLSDialer{ + tlsdlr := &netxlite.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ @@ -181,7 +181,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := tlsdialer.TLSDialer{ + tlsdlr := &netxlite.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, @@ -214,7 +214,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := tlsdialer.TLSDialer{ + tlsdlr := &netxlite.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, @@ -247,7 +247,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := tlsdialer.TLSDialer{ + tlsdlr := &netxlite.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, @@ -280,7 +280,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { t.Skip("skip test in short mode") } saver := &trace.Saver{} - tlsdlr := tlsdialer.TLSDialer{ + tlsdlr := &netxlite.TLSDialer{ Config: &tls.Config{InsecureSkipVerify: true}, Dialer: new(net.Dialer), TLSHandshaker: tlsdialer.SaverTLSHandshaker{ diff --git a/internal/engine/netx/tlsdialer/tls.go b/internal/engine/netx/tlsdialer/tls.go index 4c58e01..d98f763 100644 --- a/internal/engine/netx/tlsdialer/tls.go +++ b/internal/engine/netx/tlsdialer/tls.go @@ -66,41 +66,3 @@ func (h EmitterTLSHandshaker) Handshake( }) return tlsconn, state, err } - -// TLSDialer is the TLS dialer -type TLSDialer struct { - Config *tls.Config - Dialer UnderlyingDialer - TLSHandshaker TLSHandshaker -} - -// DialTLSContext is like tls.DialTLS but with the signature of net.Dialer.DialContext -func (d TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { - // Implementation note: when DialTLS is not set, the code in - // net/http will perform the handshake. Otherwise, if DialTLS - // is set, we will end up here. This code is still used when - // performing non-HTTP TLS-enabled dial operations. - host, _, err := net.SplitHostPort(address) - if err != nil { - return nil, err - } - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, err - } - config := d.Config - if config == nil { - config = new(tls.Config) - } else { - config = config.Clone() - } - if config.ServerName == "" { - config.ServerName = host - } - tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config) - if err != nil { - conn.Close() - return nil, err - } - return tlsconn, nil -} diff --git a/internal/engine/netx/tlsdialer/tls_test.go b/internal/engine/netx/tlsdialer/tls_test.go index 3b3cd8b..bdbabab 100644 --- a/internal/engine/netx/tlsdialer/tls_test.go +++ b/internal/engine/netx/tlsdialer/tls_test.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "errors" "io" - "net" "testing" "time" @@ -97,115 +96,3 @@ func TestEmitterTLSHandshakerFailure(t *testing.T) { t.Fatal("expected nonzero DurationSinceBeginning") } } - -func TestTLSDialerFailureSplitHostPort(t *testing.T) { - dialer := tlsdialer.TLSDialer{} - conn, err := dialer.DialTLSContext( - context.Background(), "tcp", "www.google.com") // missing port - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("connection is not nil") - } -} - -func TestTLSDialerFailureDialing(t *testing.T) { - dialer := tlsdialer.TLSDialer{Dialer: tlsdialer.EOFDialer{}} - conn, err := dialer.DialTLSContext( - context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, io.EOF) { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("connection is not nil") - } -} - -func TestTLSDialerFailureHandshaking(t *testing.T) { - rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}} - dialer := tlsdialer.TLSDialer{ - Dialer: tlsdialer.EOFConnDialer{}, - TLSHandshaker: rec, - } - conn, err := dialer.DialTLSContext( - context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, io.EOF) { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("connection is not nil") - } - if rec.SNI != "www.google.com" { - t.Fatal("unexpected SNI value") - } -} - -func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) { - rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}} - dialer := tlsdialer.TLSDialer{ - Config: &tls.Config{ - ServerName: "x.org", - }, - Dialer: tlsdialer.EOFConnDialer{}, - TLSHandshaker: rec, - } - conn, err := dialer.DialTLSContext( - context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, io.EOF) { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("connection is not nil") - } - if rec.SNI != "x.org" { - t.Fatal("unexpected SNI value") - } -} - -type RecorderTLSHandshaker struct { - tlsdialer.TLSHandshaker - SNI string -} - -func (h *RecorderTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { - h.SNI = config.ServerName - return h.TLSHandshaker.Handshake(ctx, conn, config) -} - -func TestDialTLSContextGood(t *testing.T) { - dialer := tlsdialer.TLSDialer{ - Config: &tls.Config{ServerName: "google.com"}, - Dialer: new(net.Dialer), - TLSHandshaker: &netxlite.TLSHandshakerStdlib{}, - } - conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443") - if err != nil { - t.Fatal(err) - } - if conn == nil { - t.Fatal("connection is nil") - } - conn.Close() -} - -func TestDialTLSContextTimeout(t *testing.T) { - dialer := tlsdialer.TLSDialer{ - Config: &tls.Config{ServerName: "google.com"}, - Dialer: new(net.Dialer), - TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerStdlib{ - Timeout: 10 * time.Microsecond, - }, - }, - } - conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443") - if err.Error() != errorx.FailureGenericTimeoutError { - t.Fatal("not the error that we expected") - } - if conn != nil { - t.Fatal("connection is not nil") - } -} diff --git a/internal/netxlite/tlsdialer.go b/internal/netxlite/tlsdialer.go new file mode 100644 index 0000000..9a959d0 --- /dev/null +++ b/internal/netxlite/tlsdialer.go @@ -0,0 +1,69 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "net" +) + +// TLSDialer is the TLS dialer +type TLSDialer struct { + // Config is the OPTIONAL tls config. + Config *tls.Config + + // Dialer is the MANDATORY dialer. + Dialer Dialer + + // TLSHandshaker is the MANDATORY TLS handshaker. + TLSHandshaker TLSHandshaker +} + +// DialTLSContext dials a TLS connection. +func (d *TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + config := d.config(host, port) + tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config) + if err != nil { + conn.Close() + return nil, err + } + return tlsconn, nil +} + +// config creates a new config. If d.Config is nil, then we start +// from an empty config. Otherwise, we clone d.Config. +// +// We set the ServerName field if not already set. +// +// We set the ALPN if the port is 443 or 853, if not already set. +// +// We force using our root CA, unless it's already set. +func (d *TLSDialer) config(host, port string) *tls.Config { + config := d.Config + if config == nil { + config = &tls.Config{} + } + config = config.Clone() // operate on a clone + if config.ServerName == "" { + config.ServerName = host + } + if len(config.NextProtos) <= 0 { + switch port { + case "443": + config.NextProtos = []string{"h2", "http/1.1"} + case "853": + config.NextProtos = []string{"dot"} + } + } + if config.RootCAs == nil { + config.RootCAs = NewDefaultCertPool() + } + return config +} diff --git a/internal/netxlite/tlsdialer_test.go b/internal/netxlite/tlsdialer_test.go new file mode 100644 index 0000000..6e6c6ed --- /dev/null +++ b/internal/netxlite/tlsdialer_test.go @@ -0,0 +1,177 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/netxmocks" +) + +func TestTLSDialerFailureSplitHostPort(t *testing.T) { + dialer := &TLSDialer{} + ctx := context.Background() + const address = "www.google.com" // missing port + conn, err := dialer.DialTLSContext(ctx, "tcp", address) + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestTLSDialerFailureDialing(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // immediately fail + dialer := TLSDialer{Dialer: &net.Dialer{}} + conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") + if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestTLSDialerFailureHandshaking(t *testing.T) { + ctx := context.Background() + dialer := TLSDialer{ + Config: &tls.Config{}, + Dialer: &netxmocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &netxmocks.Conn{MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, MockClose: func() error { + return nil + }, MockSetDeadline: func(t time.Time) error { + return nil + }}, nil + }}, + TLSHandshaker: &TLSHandshakerStdlib{}, + } + conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestTLSDialerSuccessHandshaking(t *testing.T) { + ctx := context.Background() + dialer := TLSDialer{ + Dialer: &netxmocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &netxmocks.Conn{MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, MockClose: func() error { + return nil + }, MockSetDeadline: func(t time.Time) error { + return nil + }}, nil + }}, + TLSHandshaker: &netxmocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return tls.Client(conn, config), tls.ConnectionState{}, nil + }, + }, + } + conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("connection is nil") + } + conn.Close() +} + +func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) { + d := &TLSDialer{} + config := d.config("www.google.com", "443") + if config.ServerName != "www.google.com" { + t.Fatal("invalid server name") + } + if config.RootCAs == nil { + t.Fatal("expected non-nil root CAs") + } + if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" { + t.Fatal(diff) + } +} + +func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) { + d := &TLSDialer{} + config := d.config("dns.google", "853") + if config.ServerName != "dns.google" { + t.Fatal("invalid server name") + } + if config.RootCAs == nil { + t.Fatal("expected non-nil root CAs") + } + if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { + t.Fatal(diff) + } +} + +func TestTLSDialerConfigWithServerName(t *testing.T) { + d := &TLSDialer{ + Config: &tls.Config{ + ServerName: "example.com", + }, + } + config := d.config("dns.google", "853") + if config.ServerName != "example.com" { + t.Fatal("invalid server name") + } + if config.RootCAs == nil { + t.Fatal("expected non-nil root CAs") + } + if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { + t.Fatal(diff) + } +} + +func TestTLSDialerConfigWithALPN(t *testing.T) { + d := &TLSDialer{ + Config: &tls.Config{ + NextProtos: []string{"h2"}, + }, + } + config := d.config("dns.google", "853") + if config.ServerName != "dns.google" { + t.Fatal("invalid server name") + } + if config.RootCAs == nil { + t.Fatal("expected non-nil root CAs") + } + if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" { + t.Fatal(diff) + } +} + +func TestTLSDialerConfigWithRootCA(t *testing.T) { + pool := &x509.CertPool{} + d := &TLSDialer{ + Config: &tls.Config{ + RootCAs: pool, + }, + } + config := d.config("dns.google", "853") + if config.ServerName != "dns.google" { + t.Fatal("invalid server name") + } + if config.RootCAs != pool { + t.Fatal("not the RootCAs we expected") + } + if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" { + t.Fatal(diff) + } +}