refactor(netx): use netxlite to build TLSDialer (#790)
This diff modifies netx to use netxlite to build the TLSDialer. Building the TLSDialer entails building a TLSHandshaker. While there, hide netxlite names we don't want to be public and change netx tests to test for functionality. To this end, refactor filtering to provide an easier to use TLS server. We don't need the complexity of proxying rather we need to provoke specific errors. Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
ae24ba644c
commit
e9ed733f07
|
@ -64,8 +64,6 @@ type Config struct {
|
||||||
TLSSaver *tracex.Saver // default: not saving TLS
|
TLSSaver *tracex.Saver // default: not saving TLS
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool()
|
|
||||||
|
|
||||||
// NewResolver creates a new resolver from the specified config
|
// NewResolver creates a new resolver from the specified config
|
||||||
func NewResolver(config Config) model.Resolver {
|
func NewResolver(config Config) model.Resolver {
|
||||||
if config.BaseResolver == nil {
|
if config.BaseResolver == nil {
|
||||||
|
@ -132,25 +130,16 @@ func NewTLSDialer(config Config) model.TLSDialer {
|
||||||
if config.Dialer == nil {
|
if config.Dialer == nil {
|
||||||
config.Dialer = NewDialer(config)
|
config.Dialer = NewDialer(config)
|
||||||
}
|
}
|
||||||
var h model.TLSHandshaker = &netxlite.TLSHandshakerConfigurable{}
|
logger := model.ValidLoggerOrDefault(config.Logger)
|
||||||
h = &netxlite.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
|
thx := netxlite.NewTLSHandshakerStdlib(logger)
|
||||||
if config.Logger != nil {
|
thx = config.TLSSaver.WrapTLSHandshaker(thx) // WAI when TLSSaver is nil
|
||||||
h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h}
|
tlsConfig := netxlite.ClonedTLSConfigOrNewEmptyConfig(config.TLSConfig)
|
||||||
}
|
// TODO(bassosimone): we should not provide confusing options and
|
||||||
h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver
|
// so we should drop CertPool and NoTLSVerify in favour of encouraging
|
||||||
if config.TLSConfig == nil {
|
// the users of this library to always use a TLSConfig.
|
||||||
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
|
tlsConfig.RootCAs = config.CertPool // netxlite uses default cert pool if this is nil
|
||||||
}
|
tlsConfig.InsecureSkipVerify = config.NoTLSVerify
|
||||||
if config.CertPool == nil {
|
return netxlite.NewTLSDialerWithConfig(config.Dialer, thx, tlsConfig)
|
||||||
config.CertPool = defaultCertPool
|
|
||||||
}
|
|
||||||
config.TLSConfig.RootCAs = config.CertPool
|
|
||||||
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
|
||||||
return &netxlite.TLSDialerLegacy{
|
|
||||||
Config: config.TLSConfig,
|
|
||||||
Dialer: config.Dialer,
|
|
||||||
TLSHandshaker: h,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned
|
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
"github.com/ooni/probe-cli/v3/internal/tracex"
|
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -208,210 +209,103 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTLSDialerVanilla(t *testing.T) {
|
func TestNewTLSDialer(t *testing.T) {
|
||||||
td := NewTLSDialer(Config{})
|
t.Run("we always have error wrapping", func(t *testing.T) {
|
||||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||||
if !ok {
|
defer server.Close()
|
||||||
t.Fatal("not the TLSDialer we expected")
|
tdx := NewTLSDialer(Config{})
|
||||||
|
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||||
|
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if len(rtd.Config.NextProtos) != 2 {
|
if conn != nil {
|
||||||
t.Fatal("invalid len(config.NextProtos)")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
|
||||||
t.Fatal("invalid Config.NextProtos")
|
|
||||||
}
|
|
||||||
if rtd.Config.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid Config.RootCAs")
|
|
||||||
}
|
|
||||||
if rtd.Dialer == nil {
|
|
||||||
t.Fatal("invalid Dialer")
|
|
||||||
}
|
|
||||||
if rtd.TLSHandshaker == nil {
|
|
||||||
t.Fatal("invalid TLSHandshaker")
|
|
||||||
}
|
|
||||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSDialerWithConfig(t *testing.T) {
|
|
||||||
td := NewTLSDialer(Config{
|
|
||||||
TLSConfig: new(tls.Config),
|
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSDialer we expected")
|
|
||||||
}
|
|
||||||
if len(rtd.Config.NextProtos) != 0 {
|
|
||||||
t.Fatal("invalid len(config.NextProtos)")
|
|
||||||
}
|
|
||||||
if rtd.Config.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid Config.RootCAs")
|
|
||||||
}
|
|
||||||
if rtd.Dialer == nil {
|
|
||||||
t.Fatal("invalid Dialer")
|
|
||||||
}
|
|
||||||
if rtd.TLSHandshaker == nil {
|
|
||||||
t.Fatal("invalid TLSHandshaker")
|
|
||||||
}
|
|
||||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSDialerWithLogging(t *testing.T) {
|
t.Run("we can collect TLS measurements", func(t *testing.T) {
|
||||||
td := NewTLSDialer(Config{
|
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||||
Logger: log.Log,
|
defer server.Close()
|
||||||
})
|
saver := &tracex.Saver{}
|
||||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
tdx := NewTLSDialer(Config{
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSDialer we expected")
|
|
||||||
}
|
|
||||||
if len(rtd.Config.NextProtos) != 2 {
|
|
||||||
t.Fatal("invalid len(config.NextProtos)")
|
|
||||||
}
|
|
||||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
|
||||||
t.Fatal("invalid Config.NextProtos")
|
|
||||||
}
|
|
||||||
if rtd.Config.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid Config.RootCAs")
|
|
||||||
}
|
|
||||||
if rtd.Dialer == nil {
|
|
||||||
t.Fatal("invalid Dialer")
|
|
||||||
}
|
|
||||||
if rtd.TLSHandshaker == nil {
|
|
||||||
t.Fatal("invalid TLSHandshaker")
|
|
||||||
}
|
|
||||||
lth, ok := rtd.TLSHandshaker.(*netxlite.TLSHandshakerLogger)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if lth.DebugLogger != log.Log {
|
|
||||||
t.Fatal("not the Logger we expected")
|
|
||||||
}
|
|
||||||
ewth, ok := lth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSDialerWithSaver(t *testing.T) {
|
|
||||||
saver := new(tracex.Saver)
|
|
||||||
td := NewTLSDialer(Config{
|
|
||||||
TLSSaver: saver,
|
TLSSaver: saver,
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||||
if !ok {
|
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if len(rtd.Config.NextProtos) != 2 {
|
if conn != nil {
|
||||||
t.Fatal("invalid len(config.NextProtos)")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
if len(saver.Read()) <= 0 {
|
||||||
t.Fatal("invalid Config.NextProtos")
|
t.Fatal("did not read any event")
|
||||||
}
|
}
|
||||||
if rtd.Config.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid Config.RootCAs")
|
|
||||||
}
|
|
||||||
if rtd.Dialer == nil {
|
|
||||||
t.Fatal("invalid Dialer")
|
|
||||||
}
|
|
||||||
if rtd.TLSHandshaker == nil {
|
|
||||||
t.Fatal("invalid TLSHandshaker")
|
|
||||||
}
|
|
||||||
sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if sth.Saver != saver {
|
|
||||||
t.Fatal("not the Logger we expected")
|
|
||||||
}
|
|
||||||
ewth, ok := sth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
|
|
||||||
td := NewTLSDialer(Config{
|
|
||||||
TLSConfig: new(tls.Config),
|
|
||||||
NoTLSVerify: true,
|
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSDialer we expected")
|
|
||||||
}
|
|
||||||
if len(rtd.Config.NextProtos) != 0 {
|
|
||||||
t.Fatal("invalid len(config.NextProtos)")
|
|
||||||
}
|
|
||||||
if rtd.Config.InsecureSkipVerify != true {
|
|
||||||
t.Fatal("expected true InsecureSkipVerify")
|
|
||||||
}
|
|
||||||
if rtd.Config.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid Config.RootCAs")
|
|
||||||
}
|
|
||||||
if rtd.Dialer == nil {
|
|
||||||
t.Fatal("invalid Dialer")
|
|
||||||
}
|
|
||||||
if rtd.TLSHandshaker == nil {
|
|
||||||
t.Fatal("invalid TLSHandshaker")
|
|
||||||
}
|
|
||||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
|
t.Run("we can collect dial measurements", func(t *testing.T) {
|
||||||
td := NewTLSDialer(Config{
|
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||||
NoTLSVerify: true,
|
defer server.Close()
|
||||||
|
saver := &tracex.Saver{}
|
||||||
|
tdx := NewTLSDialer(Config{
|
||||||
|
DialSaver: saver,
|
||||||
})
|
})
|
||||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||||
if !ok {
|
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||||
t.Fatal("not the TLSDialer we expected")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if len(rtd.Config.NextProtos) != 2 {
|
if conn != nil {
|
||||||
t.Fatal("invalid len(config.NextProtos)")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
if len(saver.Read()) <= 0 {
|
||||||
t.Fatal("invalid Config.NextProtos")
|
t.Fatal("did not read any event")
|
||||||
}
|
}
|
||||||
if rtd.Config.InsecureSkipVerify != true {
|
})
|
||||||
t.Fatal("expected true InsecureSkipVerify")
|
|
||||||
|
t.Run("we can collect I/O measurements", func(t *testing.T) {
|
||||||
|
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||||
|
defer server.Close()
|
||||||
|
saver := &tracex.Saver{}
|
||||||
|
tdx := NewTLSDialer(Config{
|
||||||
|
ReadWriteSaver: saver,
|
||||||
|
})
|
||||||
|
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||||
|
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if rtd.Config.RootCAs != defaultCertPool {
|
if conn != nil {
|
||||||
t.Fatal("invalid Config.RootCAs")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
if rtd.Dialer == nil {
|
if len(saver.Read()) <= 0 {
|
||||||
t.Fatal("invalid Dialer")
|
t.Fatal("did not read any event")
|
||||||
}
|
}
|
||||||
if rtd.TLSHandshaker == nil {
|
})
|
||||||
t.Fatal("invalid TLSHandshaker")
|
|
||||||
|
t.Run("we can skip TLS verification", func(t *testing.T) {
|
||||||
|
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||||
|
defer server.Close()
|
||||||
|
tdx := NewTLSDialer(Config{NoTLSVerify: true})
|
||||||
|
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err.(*netxlite.ErrWrapper).WrappedErr)
|
||||||
}
|
}
|
||||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
conn.Close()
|
||||||
if !ok {
|
})
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
|
||||||
}
|
t.Run("we can set the cert pool", func(t *testing.T) {
|
||||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
defer server.Close()
|
||||||
|
tdx := NewTLSDialer(Config{
|
||||||
|
CertPool: server.CertPool(),
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewVanilla(t *testing.T) {
|
func TestNewVanilla(t *testing.T) {
|
||||||
|
@ -441,33 +335,6 @@ func TestNewWithDialer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewWithTLSDialer(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
tlsDialer := &netxlite.TLSDialerLegacy{
|
|
||||||
Config: new(tls.Config),
|
|
||||||
Dialer: &mocks.Dialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
MockCloseIdleConnections: func() {
|
|
||||||
// nothing
|
|
||||||
},
|
|
||||||
},
|
|
||||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
|
||||||
}
|
|
||||||
txp := NewHTTPTransport(Config{
|
|
||||||
TLSDialer: tlsDialer,
|
|
||||||
})
|
|
||||||
client := &http.Client{Transport: txp}
|
|
||||||
resp, err := client.Get("https://www.google.com")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if resp != nil {
|
|
||||||
t.Fatal("not the response we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewWithByteCounter(t *testing.T) {
|
func TestNewWithByteCounter(t *testing.T) {
|
||||||
counter := bytecounter.New()
|
counter := bytecounter.New()
|
||||||
txp := NewHTTPTransport(Config{
|
txp := NewHTTPTransport(Config{
|
||||||
|
|
|
@ -1,24 +1,22 @@
|
||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"time"
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO(bassosimone): remove TLSActionPass since we want integration tests
|
"github.com/google/martian/v3/mitm"
|
||||||
// to only run locally to make them much more predictable.
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
|
)
|
||||||
|
|
||||||
// TLSAction is a TLS filtering action that this proxy should take.
|
// TLSAction is a TLS filtering action that this proxy should take.
|
||||||
type TLSAction string
|
type TLSAction string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// TLSActionPass passes the traffic to the destination.
|
|
||||||
TLSActionPass = TLSAction("pass")
|
|
||||||
|
|
||||||
// TLSActionReset resets the connection.
|
// TLSActionReset resets the connection.
|
||||||
TLSActionReset = TLSAction("reset")
|
TLSActionReset = TLSAction("reset")
|
||||||
|
|
||||||
|
@ -35,48 +33,98 @@ const (
|
||||||
// TLSActionAlertUnrecognizedName tells the client that
|
// TLSActionAlertUnrecognizedName tells the client that
|
||||||
// it's handshaking with an unknown SNI.
|
// it's handshaking with an unknown SNI.
|
||||||
TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name")
|
TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name")
|
||||||
|
|
||||||
|
// TLSActionBlockText returns a static piece of text
|
||||||
|
// to the client saying this website is blocked.
|
||||||
|
TLSActionBlockText = TLSAction("block-text")
|
||||||
)
|
)
|
||||||
|
|
||||||
// TLSProxy is a TLS proxy that routes the traffic depending
|
// TLSServer is a TLS server implementing filtering policies.
|
||||||
// on the SNI value and may implement filtering policies.
|
type TLSServer struct {
|
||||||
type TLSProxy struct {
|
// action is the action to perform.
|
||||||
// OnIncomingSNI is the MANDATORY hook called whenever we have
|
action TLSAction
|
||||||
// successfully received a ClientHello message.
|
|
||||||
OnIncomingSNI func(sni string) TLSAction
|
// cancel allows to cancel background operations.
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
// cert is the fake CA certificate.
|
||||||
|
cert *x509.Certificate
|
||||||
|
|
||||||
|
// config is the config to generate certificates on the fly.
|
||||||
|
config *mitm.Config
|
||||||
|
|
||||||
|
// done is closed when the background goroutine has terminated.
|
||||||
|
done chan bool
|
||||||
|
|
||||||
|
// endpoint is the endpoint where we're listening.
|
||||||
|
endpoint string
|
||||||
|
|
||||||
|
// listener is the TCP listener.
|
||||||
|
listener net.Listener
|
||||||
|
|
||||||
|
// privkey is the private key that signed the cert.
|
||||||
|
privkey *rsa.PrivateKey
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the proxy.
|
// NewTLSServer creates and starts a new TLSServer that executes
|
||||||
func (p *TLSProxy) Start(address string) (net.Listener, error) {
|
// the given action during the TLS handshake.
|
||||||
listener, _, err := p.start(address)
|
func NewTLSServer(action TLSAction) *TLSServer {
|
||||||
return listener, err
|
done := make(chan bool)
|
||||||
|
cert, privkey, err := mitm.NewAuthority("jafar", "OONI", 24*time.Hour)
|
||||||
|
runtimex.PanicOnError(err, "mitm.NewAuthority failed")
|
||||||
|
config, err := mitm.NewConfig(cert, privkey)
|
||||||
|
runtimex.PanicOnError(err, "mitm.NewConfig failed")
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
runtimex.PanicOnError(err, "net.Listen failed")
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
endpoint := listener.Addr().String()
|
||||||
|
server := &TLSServer{
|
||||||
|
action: action,
|
||||||
|
cancel: cancel,
|
||||||
|
cert: cert,
|
||||||
|
config: config,
|
||||||
|
done: done,
|
||||||
|
endpoint: endpoint,
|
||||||
|
listener: listener,
|
||||||
|
privkey: privkey,
|
||||||
|
}
|
||||||
|
go server.mainloop(ctx)
|
||||||
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) {
|
// CertPool returns the internal CA as a cert pool.
|
||||||
listener, err := net.Listen("tcp", address)
|
func (p *TLSServer) CertPool() *x509.CertPool {
|
||||||
if err != nil {
|
o := x509.NewCertPool()
|
||||||
return nil, nil, err
|
o.AddCert(p.cert)
|
||||||
}
|
return o
|
||||||
done := make(chan interface{})
|
|
||||||
go p.mainloop(listener, done)
|
|
||||||
return listener, done, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
|
// Endpoint returns the endpoint where the server is listening.
|
||||||
defer close(done)
|
func (p *TLSServer) Endpoint() string {
|
||||||
for p.oneloop(listener) {
|
return p.endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes this server as soon as possible.
|
||||||
|
func (p *TLSServer) Close() error {
|
||||||
|
p.cancel()
|
||||||
|
err := p.listener.Close()
|
||||||
|
<-p.done
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TLSServer) mainloop(ctx context.Context) {
|
||||||
|
defer close(p.done)
|
||||||
|
for p.oneloop(ctx) {
|
||||||
// nothing
|
// nothing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLSProxy) oneloop(listener net.Listener) bool {
|
func (p *TLSServer) oneloop(ctx context.Context) bool {
|
||||||
conn, err := listener.Accept()
|
conn, err := p.listener.Accept()
|
||||||
if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
|
|
||||||
return false // we need to stop
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true // we can continue running
|
return !errors.Is(err, net.ErrClosed)
|
||||||
}
|
}
|
||||||
go p.handle(conn)
|
go p.handle(ctx, conn)
|
||||||
return true // we can continue running
|
return true // we can continue running
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,102 +133,55 @@ const (
|
||||||
tlsAlertUnrecognizedName = byte(112)
|
tlsAlertUnrecognizedName = byte(112)
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *TLSProxy) handle(conn net.Conn) {
|
func (p *TLSServer) handle(ctx context.Context, tcpConn net.Conn) {
|
||||||
defer conn.Close()
|
defer tcpConn.Close()
|
||||||
sni, hello, err := p.readClientHello(conn)
|
tlsConn := tls.Server(tcpConn, &tls.Config{
|
||||||
if err != nil {
|
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
p.reset(conn)
|
switch p.action {
|
||||||
|
case TLSActionTimeout:
|
||||||
|
select {
|
||||||
|
case <-time.After(300 * time.Second):
|
||||||
|
return nil, errors.New("timing out the connection")
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.reset(tcpConn)
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
case TLSActionAlertInternalError:
|
||||||
|
p.alert(tcpConn, tlsAlertInternalError)
|
||||||
|
return nil, errors.New("already sent alert")
|
||||||
|
case TLSActionAlertUnrecognizedName:
|
||||||
|
p.alert(tcpConn, tlsAlertUnrecognizedName)
|
||||||
|
return nil, errors.New("already sent alert")
|
||||||
|
case TLSActionEOF:
|
||||||
|
p.eof(tcpConn)
|
||||||
|
return nil, errors.New("already closed the connection")
|
||||||
|
case TLSActionBlockText:
|
||||||
|
return p.config.TLSForHost(info.ServerName).GetCertificate(info)
|
||||||
|
default:
|
||||||
|
p.reset(tcpConn)
|
||||||
|
return nil, errors.New("already RST the connection")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch p.OnIncomingSNI(sni) {
|
p.blockText(tlsConn)
|
||||||
case TLSActionPass:
|
tlsConn.Close()
|
||||||
p.proxy(conn, sni, hello)
|
|
||||||
case TLSActionTimeout:
|
|
||||||
p.timeout(conn)
|
|
||||||
case TLSActionAlertInternalError:
|
|
||||||
p.alert(conn, tlsAlertInternalError)
|
|
||||||
case TLSActionAlertUnrecognizedName:
|
|
||||||
p.alert(conn, tlsAlertUnrecognizedName)
|
|
||||||
case TLSActionEOF:
|
|
||||||
p.eof(conn)
|
|
||||||
default:
|
|
||||||
p.reset(conn)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// readClientHello reads the incoming ClientHello message.
|
func (p *TLSServer) reset(conn net.Conn) {
|
||||||
//
|
if tc, good := conn.(*net.TCPConn); good {
|
||||||
// Arguments:
|
|
||||||
//
|
|
||||||
// - conn is the connection from which to read the ClientHello.
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
//
|
|
||||||
// - a string containing the SNI (empty on error);
|
|
||||||
//
|
|
||||||
// - bytes from the original ClientHello (nil on error);
|
|
||||||
//
|
|
||||||
// - an error (nil on success).
|
|
||||||
func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) {
|
|
||||||
connWrapper := &tlsClientHelloReader{Conn: conn}
|
|
||||||
var (
|
|
||||||
expectedErr = errors.New("cannot continue handhake")
|
|
||||||
sni string
|
|
||||||
mutex sync.Mutex // just for safety
|
|
||||||
)
|
|
||||||
err := tls.Server(connWrapper, &tls.Config{
|
|
||||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
mutex.Lock()
|
|
||||||
sni = info.ServerName
|
|
||||||
mutex.Unlock()
|
|
||||||
return nil, expectedErr
|
|
||||||
},
|
|
||||||
}).Handshake()
|
|
||||||
if !errors.Is(err, expectedErr) {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
return sni, connWrapper.clientHello, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// tlsClientHelloReader wraps a net.Conn for the purpose of
|
|
||||||
// saving the bytes of the ClientHello message.
|
|
||||||
type tlsClientHelloReader struct {
|
|
||||||
net.Conn
|
|
||||||
clientHello []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
|
|
||||||
count, err := c.Conn.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
c.clientHello = append(c.clientHello, b[:count]...)
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write prevents writing on the real connection
|
|
||||||
func (c *tlsClientHelloReader) Write(b []byte) (int, error) {
|
|
||||||
return 0, errors.New("cannot write on this connection")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TLSProxy) reset(conn net.Conn) {
|
|
||||||
if tc, ok := conn.(*net.TCPConn); ok {
|
|
||||||
tc.SetLinger(0)
|
tc.SetLinger(0)
|
||||||
}
|
}
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLSProxy) timeout(conn net.Conn) {
|
func (p *TLSServer) eof(conn net.Conn) {
|
||||||
buffer := make([]byte, 1<<14)
|
|
||||||
conn.Read(buffer)
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLSProxy) eof(conn net.Conn) {
|
func (p *TLSServer) alert(conn net.Conn, code byte) {
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TLSProxy) alert(conn net.Conn, code byte) {
|
|
||||||
alertdata := []byte{
|
alertdata := []byte{
|
||||||
21, // alert
|
21, // alert
|
||||||
3, // version[0]
|
3, // version[0]
|
||||||
|
@ -194,55 +195,6 @@ func (p *TLSProxy) alert(conn net.Conn, code byte) {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) {
|
func (p *TLSServer) blockText(tlsConn net.Conn) {
|
||||||
p.proxydial(conn, sni, hello, net.Dial)
|
tlsConn.Write(HTTPBlockpage451)
|
||||||
}
|
|
||||||
|
|
||||||
func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte,
|
|
||||||
dial func(network, address string) (net.Conn, error)) {
|
|
||||||
if sni == "" { // don't know the destination host
|
|
||||||
p.reset(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
serverconn, err := dial("tcp", net.JoinHostPort(sni, "443"))
|
|
||||||
if err != nil {
|
|
||||||
p.reset(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if p.connectingToMyself(serverconn) {
|
|
||||||
p.reset(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := serverconn.Write(hello); err != nil {
|
|
||||||
p.reset(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer serverconn.Close() // conn is owned by the caller
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
wg.Add(2)
|
|
||||||
go p.forward(wg, conn, serverconn)
|
|
||||||
go p.forward(wg, serverconn, conn)
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// connectingToMyself returns true when the proxy has been somehow
|
|
||||||
// forced to create a connection to itself.
|
|
||||||
func (p *TLSProxy) connectingToMyself(conn net.Conn) bool {
|
|
||||||
local := conn.LocalAddr().String()
|
|
||||||
localAddr, _, localErr := net.SplitHostPort(local)
|
|
||||||
remote := conn.RemoteAddr().String()
|
|
||||||
remoteAddr, _, remoteErr := net.SplitHostPort(remote)
|
|
||||||
return localErr != nil || remoteErr != nil || localAddr == remoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward will forward the traffic.
|
|
||||||
func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) {
|
|
||||||
defer wg.Done()
|
|
||||||
// We cannot use netxlite.CopyContext here because we want netxlite to
|
|
||||||
// use filtering inside its test suite, so this package cannot depend on
|
|
||||||
// netxlite. In general, we don't want to use io.Copy or io.ReadAll
|
|
||||||
// directly because they may cause the code to block as documented in
|
|
||||||
// internal/netxlite/iox.go. However, this package is only used for
|
|
||||||
// testing, so it's completely okay to make an exception here.
|
|
||||||
io.Copy(left, right)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,297 +1,146 @@
|
||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/apex/log"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTLSProxy(t *testing.T) {
|
func TestTLSServer(t *testing.T) {
|
||||||
if testing.Short() {
|
t.Run("TLSActionReset", func(t *testing.T) {
|
||||||
t.Skip("skip test in short mode")
|
srv := NewTLSServer(TLSActionReset)
|
||||||
}
|
defer srv.Close()
|
||||||
newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) {
|
config := &tls.Config{ServerName: "dns.google"}
|
||||||
p := &TLSProxy{
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
OnIncomingSNI: func(sni string) TLSAction {
|
if netxlite.NewTopLevelGenericErrWrapper(err).Error() != netxlite.FailureConnectionReset {
|
||||||
return action
|
t.Fatal("unexpected err", err)
|
||||||
},
|
}
|
||||||
}
|
if conn != nil {
|
||||||
return p.start("127.0.0.1:0")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) {
|
|
||||||
d := netxlite.NewDialerWithoutResolver(log.Log)
|
t.Run("TLSActionTimeout", func(t *testing.T) {
|
||||||
th := netxlite.NewTLSHandshakerStdlib(log.Log)
|
srv := NewTLSServer(TLSActionTimeout)
|
||||||
tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{
|
defer srv.Close()
|
||||||
ServerName: sni,
|
config := &tls.Config{ServerName: "dns.google"}
|
||||||
NextProtos: []string{"h2", "http/1.1"},
|
d := &tls.Dialer{Config: config}
|
||||||
RootCAs: netxlite.NewDefaultCertPool(),
|
ctx, cancel := context.WithTimeout(context.Background(), 70*time.Millisecond)
|
||||||
})
|
defer cancel()
|
||||||
return tdx.DialTLSContext(ctx, "tcp", endpoint)
|
conn, err := d.DialContext(ctx, "tcp", srv.Endpoint())
|
||||||
}
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
|
||||||
t.Run("TLSActionPass", func(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
listener, done, err := newproxy(TLSActionPass)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TLSActionTimeout", func(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
listener, done, err := newproxy(TLSActionTimeout)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
|
||||||
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
|
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TLSActionAlertInternalError", func(t *testing.T) {
|
t.Run("TLSActionAlertInternalError", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
srv := NewTLSServer(TLSActionAlertInternalError)
|
||||||
listener, done, err := newproxy(TLSActionAlertInternalError)
|
defer srv.Close()
|
||||||
if err != nil {
|
config := &tls.Config{ServerName: "dns.google"}
|
||||||
t.Fatal(err)
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
}
|
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") {
|
if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) {
|
t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
srv := NewTLSServer(TLSActionAlertUnrecognizedName)
|
||||||
listener, done, err := newproxy(TLSActionAlertUnrecognizedName)
|
defer srv.Close()
|
||||||
if err != nil {
|
config := &tls.Config{ServerName: "dns.google"}
|
||||||
t.Fatal(err)
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
}
|
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") {
|
if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TLSActionEOF", func(t *testing.T) {
|
t.Run("TLSActionEOF", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
srv := NewTLSServer(TLSActionEOF)
|
||||||
listener, done, err := newproxy(TLSActionEOF)
|
defer srv.Close()
|
||||||
if err != nil {
|
config := &tls.Config{ServerName: "dns.google"}
|
||||||
t.Fatal(err)
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
}
|
if !errors.Is(err, io.EOF) {
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
|
||||||
if err == nil || err.Error() != netxlite.FailureEOFError {
|
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TLSActionReset", func(t *testing.T) {
|
t.Run("TLSActionBlockText", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
t.Run("certificate error when we're validating", func(t *testing.T) {
|
||||||
listener, done, err := newproxy(TLSActionReset)
|
srv := NewTLSServer(TLSActionBlockText)
|
||||||
if err != nil {
|
defer srv.Close()
|
||||||
t.Fatal(err)
|
// Certificate.Verify now uses platform APIs to verify certificate validity
|
||||||
|
// on macOS and iOS when it is called with a nil VerifyOpts.Roots or when using
|
||||||
|
// the root pool returned from SystemCertPool. "
|
||||||
|
//
|
||||||
|
// -- https://tip.golang.org/doc/go1.18
|
||||||
|
//
|
||||||
|
// So we need to explicitly use our default cert pool otherwise we will
|
||||||
|
// see this test failing with a different error string here.
|
||||||
|
config := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
RootCAs: netxlite.NewDefaultCertPool(),
|
||||||
}
|
}
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
if err == nil || !strings.HasSuffix(err.Error(), "certificate signed by unknown authority") {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn")
|
t.Fatal("expected nil conn")
|
||||||
}
|
}
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
})
|
||||||
|
|
||||||
dial := func(ctx context.Context, endpoint string) (net.Conn, error) {
|
t.Run("blocktext when we skip validation", func(t *testing.T) {
|
||||||
d := netxlite.NewDialerWithoutResolver(log.Log)
|
srv := NewTLSServer(TLSActionBlockText)
|
||||||
return d.DialContext(ctx, "tcp", endpoint)
|
defer srv.Close()
|
||||||
}
|
config := &tls.Config{InsecureSkipVerify: true, ServerName: "dns.google"}
|
||||||
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
t.Run("handle cannot read ClientHello", func(t *testing.T) {
|
|
||||||
listener, done, err := newproxy(TLSActionPass)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn, err := dial(context.Background(), listener.Addr().String())
|
defer conn.Close()
|
||||||
|
data, err := io.ReadAll(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
|
if !bytes.Equal(HTTPBlockpage451, data) {
|
||||||
buff := make([]byte, 1<<17)
|
t.Fatal("unexpected block text")
|
||||||
_, err = conn.Read(buff)
|
|
||||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
}
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TLSActionPass fails because we don't have SNI", func(t *testing.T) {
|
t.Run("blocktext when we configure the cert pool", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
srv := NewTLSServer(TLSActionBlockText)
|
||||||
listener, done, err := newproxy(TLSActionPass)
|
defer srv.Close()
|
||||||
|
config := &tls.Config{RootCAs: srv.CertPool(), ServerName: "dns.google"}
|
||||||
|
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1")
|
defer conn.Close()
|
||||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
data, err := io.ReadAll(conn)
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil conn")
|
|
||||||
}
|
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("TLSActionPass fails because we can't dial", func(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
listener, done, err := newproxy(TLSActionPass)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org")
|
if !bytes.Equal(HTTPBlockpage451, data) {
|
||||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
t.Fatal("unexpected block text")
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil conn")
|
|
||||||
}
|
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) {
|
|
||||||
p := &TLSProxy{}
|
|
||||||
conn := &mocks.Conn{
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
|
|
||||||
return &mocks.Conn{
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
MockLocalAddr: func() net.Addr {
|
|
||||||
return &net.TCPAddr{
|
|
||||||
IP: net.IPv6loopback,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
MockRemoteAddr: func() net.Addr {
|
|
||||||
return &net.TCPAddr{
|
|
||||||
IP: net.IPv6loopback,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) {
|
|
||||||
p := &TLSProxy{}
|
|
||||||
conn := &mocks.Conn{
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
|
|
||||||
return &mocks.Conn{
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
MockLocalAddr: func() net.Addr {
|
|
||||||
return &net.TCPAddr{
|
|
||||||
IP: net.IPv6loopback,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
MockRemoteAddr: func() net.Addr {
|
|
||||||
return &net.TCPAddr{
|
|
||||||
IP: net.IPv4(10, 0, 0, 1),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
MockWrite: func(b []byte) (int, error) {
|
|
||||||
return 0, errors.New("mocked error")
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
|
||||||
p := &TLSProxy{}
|
|
||||||
listener, err := p.Start("127.0.0.1")
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if listener != nil {
|
|
||||||
t.Fatal("expected nil listener")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
|
|
||||||
listener := &mocks.Listener{
|
|
||||||
MockAccept: func() (net.Conn, error) {
|
|
||||||
return nil, errors.New("mocked error")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
p := &TLSProxy{}
|
|
||||||
if !p.oneloop(listener) {
|
|
||||||
t.Fatal("should return true here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSClientHelloReader(t *testing.T) {
|
|
||||||
t.Run("on failure", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
chr := &tlsClientHelloReader{
|
|
||||||
Conn: &mocks.Conn{
|
|
||||||
MockRead: func(b []byte) (int, error) {
|
|
||||||
return 0, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
clientHello: []byte{},
|
|
||||||
}
|
|
||||||
buf := make([]byte, 128)
|
|
||||||
count, err := chr.Read(buf)
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if count != 0 {
|
|
||||||
t.Fatal("invalid count")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -303,18 +303,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
connectionResetFlow := func(th model.TLSHandshaker) error {
|
connectionResetFlow := func(th model.TLSHandshaker) error {
|
||||||
tlsProxy := &filtering.TLSProxy{
|
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||||
OnIncomingSNI: func(sni string) filtering.TLSAction {
|
defer server.Close()
|
||||||
return filtering.TLSActionReset
|
|
||||||
},
|
|
||||||
}
|
|
||||||
listener, err := tlsProxy.Start("127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot start proxy: %w", err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn, err := dial(ctx, listener.Addr().String())
|
conn, err := dial(ctx, server.Endpoint())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dial failed: %w", err)
|
return fmt.Errorf("dial failed: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -338,18 +330,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
timeoutFlow := func(th model.TLSHandshaker) error {
|
timeoutFlow := func(th model.TLSHandshaker) error {
|
||||||
tlsProxy := &filtering.TLSProxy{
|
server := filtering.NewTLSServer(filtering.TLSActionTimeout)
|
||||||
OnIncomingSNI: func(sni string) filtering.TLSAction {
|
defer server.Close()
|
||||||
return filtering.TLSActionTimeout
|
|
||||||
},
|
|
||||||
}
|
|
||||||
listener, err := tlsProxy.Start("127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot start proxy: %w", err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn, err := dial(ctx, listener.Addr().String())
|
conn, err := dial(ctx, server.Endpoint())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dial failed: %w", err)
|
return fmt.Errorf("dial failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,12 +20,8 @@ type (
|
||||||
HTTPTransportWrapper = httpTransportConnectionsCloser
|
HTTPTransportWrapper = httpTransportConnectionsCloser
|
||||||
HTTPTransportLogger = httpTransportLogger
|
HTTPTransportLogger = httpTransportLogger
|
||||||
ErrorWrapperResolver = resolverErrWrapper
|
ErrorWrapperResolver = resolverErrWrapper
|
||||||
ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper
|
|
||||||
ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
|
ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
|
||||||
ResolverLogger = resolverLogger
|
ResolverLogger = resolverLogger
|
||||||
ResolverIDNA = resolverIDNA
|
ResolverIDNA = resolverIDNA
|
||||||
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
|
||||||
TLSHandshakerLogger = tlsHandshakerLogger
|
|
||||||
TLSDialerLegacy = tlsDialer
|
|
||||||
AddressResolver = resolverShortCircuitIPAddr
|
AddressResolver = resolverShortCircuitIPAddr
|
||||||
)
|
)
|
||||||
|
|
|
@ -60,6 +60,15 @@ var (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ClonedTLSConfigOrNewEmptyConfig returns a clone of the provided config,
|
||||||
|
// if not nil, or a fresh and completely empty *tls.Config.
|
||||||
|
func ClonedTLSConfigOrNewEmptyConfig(config *tls.Config) *tls.Config {
|
||||||
|
if config != nil {
|
||||||
|
return config.Clone()
|
||||||
|
}
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
|
||||||
// TLSVersionString returns a TLS version string. If value is zero, we
|
// TLSVersionString returns a TLS version string. If value is zero, we
|
||||||
// return the empty string. If the value is unknown, we return
|
// return the empty string. If the value is unknown, we return
|
||||||
// `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed
|
// `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed
|
||||||
|
|
|
@ -591,3 +591,30 @@ func TestNewNullTLSDialer(t *testing.T) {
|
||||||
}
|
}
|
||||||
dialer.CloseIdleConnections() // does not crash
|
dialer.CloseIdleConnections() // does not crash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
|
||||||
|
t.Run("with nil config", func(t *testing.T) {
|
||||||
|
var input *tls.Config
|
||||||
|
output := ClonedTLSConfigOrNewEmptyConfig(input)
|
||||||
|
if output == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
}
|
||||||
|
v := reflect.ValueOf(*output)
|
||||||
|
if !v.IsZero() {
|
||||||
|
t.Fatal("expected zero config")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
input := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
output := ClonedTLSConfigOrNewEmptyConfig(input)
|
||||||
|
if output == input {
|
||||||
|
t.Fatal("expected two distinct objects")
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(input, output) {
|
||||||
|
t.Fatal("apparently the two objects have different values")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ for file in $(find . -type f -name \*.go); do
|
||||||
# implement safer wrappers for these functions.
|
# implement safer wrappers for these functions.
|
||||||
continue
|
continue
|
||||||
fi
|
fi
|
||||||
if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then
|
if [ "$file" = "./internal/netxlite/filtering/tls_test.go" ]; then
|
||||||
# We're allowed to use ReadAll and Copy in this file to
|
# We're allowed to use ReadAll and Copy in this file to
|
||||||
# avoid depending on netxlite, so we can use filtering
|
# avoid depending on netxlite, so we can use filtering
|
||||||
# inside of netxlite's own test suite.
|
# inside of netxlite's own test suite.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user