refactor(netx): use netxlite to build TLSDialer (#790)
This diff modifies netx to use netxlite to build the TLSDialer. Building the TLSDialer entails building a TLSHandshaker. While there, hide netxlite names we don't want to be public and change netx tests to test for functionality. To this end, refactor filtering to provide an easier to use TLS server. We don't need the complexity of proxying rather we need to provoke specific errors. Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
ae24ba644c
commit
e9ed733f07
|
@ -64,8 +64,6 @@ type Config struct {
|
|||
TLSSaver *tracex.Saver // default: not saving TLS
|
||||
}
|
||||
|
||||
var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool()
|
||||
|
||||
// NewResolver creates a new resolver from the specified config
|
||||
func NewResolver(config Config) model.Resolver {
|
||||
if config.BaseResolver == nil {
|
||||
|
@ -132,25 +130,16 @@ func NewTLSDialer(config Config) model.TLSDialer {
|
|||
if config.Dialer == nil {
|
||||
config.Dialer = NewDialer(config)
|
||||
}
|
||||
var h model.TLSHandshaker = &netxlite.TLSHandshakerConfigurable{}
|
||||
h = &netxlite.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
|
||||
if config.Logger != nil {
|
||||
h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h}
|
||||
}
|
||||
h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver
|
||||
if config.TLSConfig == nil {
|
||||
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
|
||||
}
|
||||
if config.CertPool == nil {
|
||||
config.CertPool = defaultCertPool
|
||||
}
|
||||
config.TLSConfig.RootCAs = config.CertPool
|
||||
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
||||
return &netxlite.TLSDialerLegacy{
|
||||
Config: config.TLSConfig,
|
||||
Dialer: config.Dialer,
|
||||
TLSHandshaker: h,
|
||||
}
|
||||
logger := model.ValidLoggerOrDefault(config.Logger)
|
||||
thx := netxlite.NewTLSHandshakerStdlib(logger)
|
||||
thx = config.TLSSaver.WrapTLSHandshaker(thx) // WAI when TLSSaver is nil
|
||||
tlsConfig := netxlite.ClonedTLSConfigOrNewEmptyConfig(config.TLSConfig)
|
||||
// TODO(bassosimone): we should not provide confusing options and
|
||||
// so we should drop CertPool and NoTLSVerify in favour of encouraging
|
||||
// the users of this library to always use a TLSConfig.
|
||||
tlsConfig.RootCAs = config.CertPool // netxlite uses default cert pool if this is nil
|
||||
tlsConfig.InsecureSkipVerify = config.NoTLSVerify
|
||||
return netxlite.NewTLSDialerWithConfig(config.Dialer, thx, tlsConfig)
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||
)
|
||||
|
||||
|
@ -208,210 +209,103 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialerVanilla(t *testing.T) {
|
||||
td := NewTLSDialer(Config{})
|
||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSDialer we expected")
|
||||
}
|
||||
if len(rtd.Config.NextProtos) != 2 {
|
||||
t.Fatal("invalid len(config.NextProtos)")
|
||||
}
|
||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
||||
t.Fatal("invalid Config.NextProtos")
|
||||
}
|
||||
if rtd.Config.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid Config.RootCAs")
|
||||
}
|
||||
if rtd.Dialer == nil {
|
||||
t.Fatal("invalid Dialer")
|
||||
}
|
||||
if rtd.TLSHandshaker == nil {
|
||||
t.Fatal("invalid TLSHandshaker")
|
||||
}
|
||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialerWithConfig(t *testing.T) {
|
||||
td := NewTLSDialer(Config{
|
||||
TLSConfig: new(tls.Config),
|
||||
func TestNewTLSDialer(t *testing.T) {
|
||||
t.Run("we always have error wrapping", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||
defer server.Close()
|
||||
tdx := NewTLSDialer(Config{})
|
||||
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSDialer we expected")
|
||||
}
|
||||
if len(rtd.Config.NextProtos) != 0 {
|
||||
t.Fatal("invalid len(config.NextProtos)")
|
||||
}
|
||||
if rtd.Config.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid Config.RootCAs")
|
||||
}
|
||||
if rtd.Dialer == nil {
|
||||
t.Fatal("invalid Dialer")
|
||||
}
|
||||
if rtd.TLSHandshaker == nil {
|
||||
t.Fatal("invalid TLSHandshaker")
|
||||
}
|
||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialerWithLogging(t *testing.T) {
|
||||
td := NewTLSDialer(Config{
|
||||
Logger: log.Log,
|
||||
t.Run("we can collect TLS measurements", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||
defer server.Close()
|
||||
saver := &tracex.Saver{}
|
||||
tdx := NewTLSDialer(Config{
|
||||
TLSSaver: saver,
|
||||
})
|
||||
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if len(saver.Read()) <= 0 {
|
||||
t.Fatal("did not read any event")
|
||||
}
|
||||
})
|
||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSDialer we expected")
|
||||
}
|
||||
if len(rtd.Config.NextProtos) != 2 {
|
||||
t.Fatal("invalid len(config.NextProtos)")
|
||||
}
|
||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
||||
t.Fatal("invalid Config.NextProtos")
|
||||
}
|
||||
if rtd.Config.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid Config.RootCAs")
|
||||
}
|
||||
if rtd.Dialer == nil {
|
||||
t.Fatal("invalid Dialer")
|
||||
}
|
||||
if rtd.TLSHandshaker == nil {
|
||||
t.Fatal("invalid TLSHandshaker")
|
||||
}
|
||||
lth, ok := rtd.TLSHandshaker.(*netxlite.TLSHandshakerLogger)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if lth.DebugLogger != log.Log {
|
||||
t.Fatal("not the Logger we expected")
|
||||
}
|
||||
ewth, ok := lth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialerWithSaver(t *testing.T) {
|
||||
saver := new(tracex.Saver)
|
||||
td := NewTLSDialer(Config{
|
||||
TLSSaver: saver,
|
||||
t.Run("we can collect dial measurements", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||
defer server.Close()
|
||||
saver := &tracex.Saver{}
|
||||
tdx := NewTLSDialer(Config{
|
||||
DialSaver: saver,
|
||||
})
|
||||
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if len(saver.Read()) <= 0 {
|
||||
t.Fatal("did not read any event")
|
||||
}
|
||||
})
|
||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSDialer we expected")
|
||||
}
|
||||
if len(rtd.Config.NextProtos) != 2 {
|
||||
t.Fatal("invalid len(config.NextProtos)")
|
||||
}
|
||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
||||
t.Fatal("invalid Config.NextProtos")
|
||||
}
|
||||
if rtd.Config.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid Config.RootCAs")
|
||||
}
|
||||
if rtd.Dialer == nil {
|
||||
t.Fatal("invalid Dialer")
|
||||
}
|
||||
if rtd.TLSHandshaker == nil {
|
||||
t.Fatal("invalid TLSHandshaker")
|
||||
}
|
||||
sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if sth.Saver != saver {
|
||||
t.Fatal("not the Logger we expected")
|
||||
}
|
||||
ewth, ok := sth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
|
||||
td := NewTLSDialer(Config{
|
||||
TLSConfig: new(tls.Config),
|
||||
NoTLSVerify: true,
|
||||
t.Run("we can collect I/O measurements", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||
defer server.Close()
|
||||
saver := &tracex.Saver{}
|
||||
tdx := NewTLSDialer(Config{
|
||||
ReadWriteSaver: saver,
|
||||
})
|
||||
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if len(saver.Read()) <= 0 {
|
||||
t.Fatal("did not read any event")
|
||||
}
|
||||
})
|
||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSDialer we expected")
|
||||
}
|
||||
if len(rtd.Config.NextProtos) != 0 {
|
||||
t.Fatal("invalid len(config.NextProtos)")
|
||||
}
|
||||
if rtd.Config.InsecureSkipVerify != true {
|
||||
t.Fatal("expected true InsecureSkipVerify")
|
||||
}
|
||||
if rtd.Config.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid Config.RootCAs")
|
||||
}
|
||||
if rtd.Dialer == nil {
|
||||
t.Fatal("invalid Dialer")
|
||||
}
|
||||
if rtd.TLSHandshaker == nil {
|
||||
t.Fatal("invalid TLSHandshaker")
|
||||
}
|
||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
|
||||
td := NewTLSDialer(Config{
|
||||
NoTLSVerify: true,
|
||||
t.Run("we can skip TLS verification", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||
defer server.Close()
|
||||
tdx := NewTLSDialer(Config{NoTLSVerify: true})
|
||||
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||
if err != nil {
|
||||
t.Fatal(err.(*netxlite.ErrWrapper).WrappedErr)
|
||||
}
|
||||
conn.Close()
|
||||
})
|
||||
|
||||
t.Run("we can set the cert pool", func(t *testing.T) {
|
||||
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||
defer server.Close()
|
||||
tdx := NewTLSDialer(Config{
|
||||
CertPool: server.CertPool(),
|
||||
TLSConfig: &tls.Config{
|
||||
ServerName: "dns.google",
|
||||
},
|
||||
})
|
||||
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
})
|
||||
rtd, ok := td.(*netxlite.TLSDialerLegacy)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSDialer we expected")
|
||||
}
|
||||
if len(rtd.Config.NextProtos) != 2 {
|
||||
t.Fatal("invalid len(config.NextProtos)")
|
||||
}
|
||||
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
|
||||
t.Fatal("invalid Config.NextProtos")
|
||||
}
|
||||
if rtd.Config.InsecureSkipVerify != true {
|
||||
t.Fatal("expected true InsecureSkipVerify")
|
||||
}
|
||||
if rtd.Config.RootCAs != defaultCertPool {
|
||||
t.Fatal("invalid Config.RootCAs")
|
||||
}
|
||||
if rtd.Dialer == nil {
|
||||
t.Fatal("invalid Dialer")
|
||||
}
|
||||
if rtd.TLSHandshaker == nil {
|
||||
t.Fatal("invalid TLSHandshaker")
|
||||
}
|
||||
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
|
||||
if !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
|
||||
t.Fatal("not the TLSHandshaker we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVanilla(t *testing.T) {
|
||||
|
@ -441,33 +335,6 @@ func TestNewWithDialer(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNewWithTLSDialer(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
tlsDialer := &netxlite.TLSDialerLegacy{
|
||||
Config: new(tls.Config),
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockCloseIdleConnections: func() {
|
||||
// nothing
|
||||
},
|
||||
},
|
||||
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
|
||||
}
|
||||
txp := NewHTTPTransport(Config{
|
||||
TLSDialer: tlsDialer,
|
||||
})
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewWithByteCounter(t *testing.T) {
|
||||
counter := bytecounter.New()
|
||||
txp := NewHTTPTransport(Config{
|
||||
|
|
|
@ -1,24 +1,22 @@
|
|||
package filtering
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
"time"
|
||||
|
||||
// TODO(bassosimone): remove TLSActionPass since we want integration tests
|
||||
// to only run locally to make them much more predictable.
|
||||
"github.com/google/martian/v3/mitm"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
// TLSAction is a TLS filtering action that this proxy should take.
|
||||
type TLSAction string
|
||||
|
||||
const (
|
||||
// TLSActionPass passes the traffic to the destination.
|
||||
TLSActionPass = TLSAction("pass")
|
||||
|
||||
// TLSActionReset resets the connection.
|
||||
TLSActionReset = TLSAction("reset")
|
||||
|
||||
|
@ -35,48 +33,98 @@ const (
|
|||
// TLSActionAlertUnrecognizedName tells the client that
|
||||
// it's handshaking with an unknown SNI.
|
||||
TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name")
|
||||
|
||||
// TLSActionBlockText returns a static piece of text
|
||||
// to the client saying this website is blocked.
|
||||
TLSActionBlockText = TLSAction("block-text")
|
||||
)
|
||||
|
||||
// TLSProxy is a TLS proxy that routes the traffic depending
|
||||
// on the SNI value and may implement filtering policies.
|
||||
type TLSProxy struct {
|
||||
// OnIncomingSNI is the MANDATORY hook called whenever we have
|
||||
// successfully received a ClientHello message.
|
||||
OnIncomingSNI func(sni string) TLSAction
|
||||
// TLSServer is a TLS server implementing filtering policies.
|
||||
type TLSServer struct {
|
||||
// action is the action to perform.
|
||||
action TLSAction
|
||||
|
||||
// cancel allows to cancel background operations.
|
||||
cancel context.CancelFunc
|
||||
|
||||
// cert is the fake CA certificate.
|
||||
cert *x509.Certificate
|
||||
|
||||
// config is the config to generate certificates on the fly.
|
||||
config *mitm.Config
|
||||
|
||||
// done is closed when the background goroutine has terminated.
|
||||
done chan bool
|
||||
|
||||
// endpoint is the endpoint where we're listening.
|
||||
endpoint string
|
||||
|
||||
// listener is the TCP listener.
|
||||
listener net.Listener
|
||||
|
||||
// privkey is the private key that signed the cert.
|
||||
privkey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
// Start starts the proxy.
|
||||
func (p *TLSProxy) Start(address string) (net.Listener, error) {
|
||||
listener, _, err := p.start(address)
|
||||
return listener, err
|
||||
}
|
||||
|
||||
func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) {
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
// NewTLSServer creates and starts a new TLSServer that executes
|
||||
// the given action during the TLS handshake.
|
||||
func NewTLSServer(action TLSAction) *TLSServer {
|
||||
done := make(chan bool)
|
||||
cert, privkey, err := mitm.NewAuthority("jafar", "OONI", 24*time.Hour)
|
||||
runtimex.PanicOnError(err, "mitm.NewAuthority failed")
|
||||
config, err := mitm.NewConfig(cert, privkey)
|
||||
runtimex.PanicOnError(err, "mitm.NewConfig failed")
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
runtimex.PanicOnError(err, "net.Listen failed")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
endpoint := listener.Addr().String()
|
||||
server := &TLSServer{
|
||||
action: action,
|
||||
cancel: cancel,
|
||||
cert: cert,
|
||||
config: config,
|
||||
done: done,
|
||||
endpoint: endpoint,
|
||||
listener: listener,
|
||||
privkey: privkey,
|
||||
}
|
||||
done := make(chan interface{})
|
||||
go p.mainloop(listener, done)
|
||||
return listener, done, nil
|
||||
go server.mainloop(ctx)
|
||||
return server
|
||||
}
|
||||
|
||||
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
|
||||
defer close(done)
|
||||
for p.oneloop(listener) {
|
||||
// CertPool returns the internal CA as a cert pool.
|
||||
func (p *TLSServer) CertPool() *x509.CertPool {
|
||||
o := x509.NewCertPool()
|
||||
o.AddCert(p.cert)
|
||||
return o
|
||||
}
|
||||
|
||||
// Endpoint returns the endpoint where the server is listening.
|
||||
func (p *TLSServer) Endpoint() string {
|
||||
return p.endpoint
|
||||
}
|
||||
|
||||
// Close closes this server as soon as possible.
|
||||
func (p *TLSServer) Close() error {
|
||||
p.cancel()
|
||||
err := p.listener.Close()
|
||||
<-p.done
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *TLSServer) mainloop(ctx context.Context) {
|
||||
defer close(p.done)
|
||||
for p.oneloop(ctx) {
|
||||
// nothing
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TLSProxy) oneloop(listener net.Listener) bool {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
return false // we need to stop
|
||||
}
|
||||
func (p *TLSServer) oneloop(ctx context.Context) bool {
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
return true // we can continue running
|
||||
return !errors.Is(err, net.ErrClosed)
|
||||
}
|
||||
go p.handle(conn)
|
||||
go p.handle(ctx, conn)
|
||||
return true // we can continue running
|
||||
}
|
||||
|
||||
|
@ -85,102 +133,55 @@ const (
|
|||
tlsAlertUnrecognizedName = byte(112)
|
||||
)
|
||||
|
||||
func (p *TLSProxy) handle(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
sni, hello, err := p.readClientHello(conn)
|
||||
if err != nil {
|
||||
p.reset(conn)
|
||||
func (p *TLSServer) handle(ctx context.Context, tcpConn net.Conn) {
|
||||
defer tcpConn.Close()
|
||||
tlsConn := tls.Server(tcpConn, &tls.Config{
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
switch p.action {
|
||||
case TLSActionTimeout:
|
||||
select {
|
||||
case <-time.After(300 * time.Second):
|
||||
return nil, errors.New("timing out the connection")
|
||||
case <-ctx.Done():
|
||||
p.reset(tcpConn)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
case TLSActionAlertInternalError:
|
||||
p.alert(tcpConn, tlsAlertInternalError)
|
||||
return nil, errors.New("already sent alert")
|
||||
case TLSActionAlertUnrecognizedName:
|
||||
p.alert(tcpConn, tlsAlertUnrecognizedName)
|
||||
return nil, errors.New("already sent alert")
|
||||
case TLSActionEOF:
|
||||
p.eof(tcpConn)
|
||||
return nil, errors.New("already closed the connection")
|
||||
case TLSActionBlockText:
|
||||
return p.config.TLSForHost(info.ServerName).GetCertificate(info)
|
||||
default:
|
||||
p.reset(tcpConn)
|
||||
return nil, errors.New("already RST the connection")
|
||||
}
|
||||
},
|
||||
})
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return
|
||||
}
|
||||
switch p.OnIncomingSNI(sni) {
|
||||
case TLSActionPass:
|
||||
p.proxy(conn, sni, hello)
|
||||
case TLSActionTimeout:
|
||||
p.timeout(conn)
|
||||
case TLSActionAlertInternalError:
|
||||
p.alert(conn, tlsAlertInternalError)
|
||||
case TLSActionAlertUnrecognizedName:
|
||||
p.alert(conn, tlsAlertUnrecognizedName)
|
||||
case TLSActionEOF:
|
||||
p.eof(conn)
|
||||
default:
|
||||
p.reset(conn)
|
||||
}
|
||||
p.blockText(tlsConn)
|
||||
tlsConn.Close()
|
||||
}
|
||||
|
||||
// readClientHello reads the incoming ClientHello message.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - conn is the connection from which to read the ClientHello.
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// - a string containing the SNI (empty on error);
|
||||
//
|
||||
// - bytes from the original ClientHello (nil on error);
|
||||
//
|
||||
// - an error (nil on success).
|
||||
func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) {
|
||||
connWrapper := &tlsClientHelloReader{Conn: conn}
|
||||
var (
|
||||
expectedErr = errors.New("cannot continue handhake")
|
||||
sni string
|
||||
mutex sync.Mutex // just for safety
|
||||
)
|
||||
err := tls.Server(connWrapper, &tls.Config{
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
mutex.Lock()
|
||||
sni = info.ServerName
|
||||
mutex.Unlock()
|
||||
return nil, expectedErr
|
||||
},
|
||||
}).Handshake()
|
||||
if !errors.Is(err, expectedErr) {
|
||||
return "", nil, err
|
||||
}
|
||||
return sni, connWrapper.clientHello, nil
|
||||
}
|
||||
|
||||
// tlsClientHelloReader wraps a net.Conn for the purpose of
|
||||
// saving the bytes of the ClientHello message.
|
||||
type tlsClientHelloReader struct {
|
||||
net.Conn
|
||||
clientHello []byte
|
||||
}
|
||||
|
||||
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
|
||||
count, err := c.Conn.Read(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.clientHello = append(c.clientHello, b[:count]...)
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Write prevents writing on the real connection
|
||||
func (c *tlsClientHelloReader) Write(b []byte) (int, error) {
|
||||
return 0, errors.New("cannot write on this connection")
|
||||
}
|
||||
|
||||
func (p *TLSProxy) reset(conn net.Conn) {
|
||||
if tc, ok := conn.(*net.TCPConn); ok {
|
||||
func (p *TLSServer) reset(conn net.Conn) {
|
||||
if tc, good := conn.(*net.TCPConn); good {
|
||||
tc.SetLinger(0)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func (p *TLSProxy) timeout(conn net.Conn) {
|
||||
buffer := make([]byte, 1<<14)
|
||||
conn.Read(buffer)
|
||||
func (p *TLSServer) eof(conn net.Conn) {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func (p *TLSProxy) eof(conn net.Conn) {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func (p *TLSProxy) alert(conn net.Conn, code byte) {
|
||||
func (p *TLSServer) alert(conn net.Conn, code byte) {
|
||||
alertdata := []byte{
|
||||
21, // alert
|
||||
3, // version[0]
|
||||
|
@ -194,55 +195,6 @@ func (p *TLSProxy) alert(conn net.Conn, code byte) {
|
|||
conn.Close()
|
||||
}
|
||||
|
||||
func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) {
|
||||
p.proxydial(conn, sni, hello, net.Dial)
|
||||
}
|
||||
|
||||
func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte,
|
||||
dial func(network, address string) (net.Conn, error)) {
|
||||
if sni == "" { // don't know the destination host
|
||||
p.reset(conn)
|
||||
return
|
||||
}
|
||||
serverconn, err := dial("tcp", net.JoinHostPort(sni, "443"))
|
||||
if err != nil {
|
||||
p.reset(conn)
|
||||
return
|
||||
}
|
||||
if p.connectingToMyself(serverconn) {
|
||||
p.reset(conn)
|
||||
return
|
||||
}
|
||||
if _, err := serverconn.Write(hello); err != nil {
|
||||
p.reset(conn)
|
||||
return
|
||||
}
|
||||
defer serverconn.Close() // conn is owned by the caller
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go p.forward(wg, conn, serverconn)
|
||||
go p.forward(wg, serverconn, conn)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// connectingToMyself returns true when the proxy has been somehow
|
||||
// forced to create a connection to itself.
|
||||
func (p *TLSProxy) connectingToMyself(conn net.Conn) bool {
|
||||
local := conn.LocalAddr().String()
|
||||
localAddr, _, localErr := net.SplitHostPort(local)
|
||||
remote := conn.RemoteAddr().String()
|
||||
remoteAddr, _, remoteErr := net.SplitHostPort(remote)
|
||||
return localErr != nil || remoteErr != nil || localAddr == remoteAddr
|
||||
}
|
||||
|
||||
// forward will forward the traffic.
|
||||
func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) {
|
||||
defer wg.Done()
|
||||
// We cannot use netxlite.CopyContext here because we want netxlite to
|
||||
// use filtering inside its test suite, so this package cannot depend on
|
||||
// netxlite. In general, we don't want to use io.Copy or io.ReadAll
|
||||
// directly because they may cause the code to block as documented in
|
||||
// internal/netxlite/iox.go. However, this package is only used for
|
||||
// testing, so it's completely okay to make an exception here.
|
||||
io.Copy(left, right)
|
||||
func (p *TLSServer) blockText(tlsConn net.Conn) {
|
||||
tlsConn.Write(HTTPBlockpage451)
|
||||
}
|
||||
|
|
|
@ -1,297 +1,146 @@
|
|||
package filtering
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
func TestTLSProxy(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) {
|
||||
p := &TLSProxy{
|
||||
OnIncomingSNI: func(sni string) TLSAction {
|
||||
return action
|
||||
},
|
||||
}
|
||||
return p.start("127.0.0.1:0")
|
||||
}
|
||||
|
||||
dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) {
|
||||
d := netxlite.NewDialerWithoutResolver(log.Log)
|
||||
th := netxlite.NewTLSHandshakerStdlib(log.Log)
|
||||
tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{
|
||||
ServerName: sni,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
RootCAs: netxlite.NewDefaultCertPool(),
|
||||
})
|
||||
return tdx.DialTLSContext(ctx, "tcp", endpoint)
|
||||
}
|
||||
|
||||
t.Run("TLSActionPass", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionPass)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionTimeout", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionTimeout)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
||||
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
|
||||
func TestTLSServer(t *testing.T) {
|
||||
t.Run("TLSActionReset", func(t *testing.T) {
|
||||
srv := NewTLSServer(TLSActionReset)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{ServerName: "dns.google"}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if netxlite.NewTopLevelGenericErrWrapper(err).Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLSActionTimeout", func(t *testing.T) {
|
||||
srv := NewTLSServer(TLSActionTimeout)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{ServerName: "dns.google"}
|
||||
d := &tls.Dialer{Config: config}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 70*time.Millisecond)
|
||||
defer cancel()
|
||||
conn, err := d.DialContext(ctx, "tcp", srv.Endpoint())
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionAlertInternalError", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionAlertInternalError)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
||||
srv := NewTLSServer(TLSActionAlertInternalError)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{ServerName: "dns.google"}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionAlertUnrecognizedName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
||||
srv := NewTLSServer(TLSActionAlertUnrecognizedName)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{ServerName: "dns.google"}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionEOF", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionEOF)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
||||
if err == nil || err.Error() != netxlite.FailureEOFError {
|
||||
srv := NewTLSServer(TLSActionEOF)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{ServerName: "dns.google"}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionReset", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionReset)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
t.Run("TLSActionBlockText", func(t *testing.T) {
|
||||
t.Run("certificate error when we're validating", func(t *testing.T) {
|
||||
srv := NewTLSServer(TLSActionBlockText)
|
||||
defer srv.Close()
|
||||
// Certificate.Verify now uses platform APIs to verify certificate validity
|
||||
// on macOS and iOS when it is called with a nil VerifyOpts.Roots or when using
|
||||
// the root pool returned from SystemCertPool. "
|
||||
//
|
||||
// -- https://tip.golang.org/doc/go1.18
|
||||
//
|
||||
// So we need to explicitly use our default cert pool otherwise we will
|
||||
// see this test failing with a different error string here.
|
||||
config := &tls.Config{
|
||||
ServerName: "dns.google",
|
||||
RootCAs: netxlite.NewDefaultCertPool(),
|
||||
}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "certificate signed by unknown authority") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
dial := func(ctx context.Context, endpoint string) (net.Conn, error) {
|
||||
d := netxlite.NewDialerWithoutResolver(log.Log)
|
||||
return d.DialContext(ctx, "tcp", endpoint)
|
||||
}
|
||||
t.Run("blocktext when we skip validation", func(t *testing.T) {
|
||||
srv := NewTLSServer(TLSActionBlockText)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{InsecureSkipVerify: true, ServerName: "dns.google"}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
data, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(HTTPBlockpage451, data) {
|
||||
t.Fatal("unexpected block text")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handle cannot read ClientHello", func(t *testing.T) {
|
||||
listener, done, err := newproxy(TLSActionPass)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dial(context.Background(), listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
|
||||
buff := make([]byte, 1<<17)
|
||||
_, err = conn.Read(buff)
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionPass fails because we don't have SNI", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionPass)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1")
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("TLSActionPass fails because we can't dial", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
listener, done, err := newproxy(TLSActionPass)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org")
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
listener.Close()
|
||||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) {
|
||||
p := &TLSProxy{}
|
||||
conn := &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
MockLocalAddr: func() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.IPv6loopback,
|
||||
}
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.IPv6loopback,
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
t.Run("blocktext when we configure the cert pool", func(t *testing.T) {
|
||||
srv := NewTLSServer(TLSActionBlockText)
|
||||
defer srv.Close()
|
||||
config := &tls.Config{RootCAs: srv.CertPool(), ServerName: "dns.google"}
|
||||
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
data, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(HTTPBlockpage451, data) {
|
||||
t.Fatal("unexpected block text")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) {
|
||||
p := &TLSProxy{}
|
||||
conn := &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
MockLocalAddr: func() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.IPv6loopback,
|
||||
}
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
}
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, errors.New("mocked error")
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
||||
p := &TLSProxy{}
|
||||
listener, err := p.Start("127.0.0.1")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if listener != nil {
|
||||
t.Fatal("expected nil listener")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
|
||||
listener := &mocks.Listener{
|
||||
MockAccept: func() (net.Conn, error) {
|
||||
return nil, errors.New("mocked error")
|
||||
},
|
||||
}
|
||||
p := &TLSProxy{}
|
||||
if !p.oneloop(listener) {
|
||||
t.Fatal("should return true here")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTLSClientHelloReader(t *testing.T) {
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
chr := &tlsClientHelloReader{
|
||||
Conn: &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, expected
|
||||
},
|
||||
},
|
||||
clientHello: []byte{},
|
||||
}
|
||||
buf := make([]byte, 128)
|
||||
count, err := chr.Read(buf)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -303,18 +303,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
|
|||
}
|
||||
|
||||
connectionResetFlow := func(th model.TLSHandshaker) error {
|
||||
tlsProxy := &filtering.TLSProxy{
|
||||
OnIncomingSNI: func(sni string) filtering.TLSAction {
|
||||
return filtering.TLSActionReset
|
||||
},
|
||||
}
|
||||
listener, err := tlsProxy.Start("127.0.0.1:0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot start proxy: %w", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
server := filtering.NewTLSServer(filtering.TLSActionReset)
|
||||
defer server.Close()
|
||||
ctx := context.Background()
|
||||
conn, err := dial(ctx, listener.Addr().String())
|
||||
conn, err := dial(ctx, server.Endpoint())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
|
@ -338,18 +330,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
|
|||
}
|
||||
|
||||
timeoutFlow := func(th model.TLSHandshaker) error {
|
||||
tlsProxy := &filtering.TLSProxy{
|
||||
OnIncomingSNI: func(sni string) filtering.TLSAction {
|
||||
return filtering.TLSActionTimeout
|
||||
},
|
||||
}
|
||||
listener, err := tlsProxy.Start("127.0.0.1:0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot start proxy: %w", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
server := filtering.NewTLSServer(filtering.TLSActionTimeout)
|
||||
defer server.Close()
|
||||
ctx := context.Background()
|
||||
conn, err := dial(ctx, listener.Addr().String())
|
||||
conn, err := dial(ctx, server.Endpoint())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial failed: %w", err)
|
||||
}
|
||||
|
|
|
@ -20,12 +20,8 @@ type (
|
|||
HTTPTransportWrapper = httpTransportConnectionsCloser
|
||||
HTTPTransportLogger = httpTransportLogger
|
||||
ErrorWrapperResolver = resolverErrWrapper
|
||||
ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper
|
||||
ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
|
||||
ResolverLogger = resolverLogger
|
||||
ResolverIDNA = resolverIDNA
|
||||
TLSHandshakerConfigurable = tlsHandshakerConfigurable
|
||||
TLSHandshakerLogger = tlsHandshakerLogger
|
||||
TLSDialerLegacy = tlsDialer
|
||||
AddressResolver = resolverShortCircuitIPAddr
|
||||
)
|
||||
|
|
|
@ -60,6 +60,15 @@ var (
|
|||
}
|
||||
)
|
||||
|
||||
// ClonedTLSConfigOrNewEmptyConfig returns a clone of the provided config,
|
||||
// if not nil, or a fresh and completely empty *tls.Config.
|
||||
func ClonedTLSConfigOrNewEmptyConfig(config *tls.Config) *tls.Config {
|
||||
if config != nil {
|
||||
return config.Clone()
|
||||
}
|
||||
return &tls.Config{}
|
||||
}
|
||||
|
||||
// TLSVersionString returns a TLS version string. If value is zero, we
|
||||
// return the empty string. If the value is unknown, we return
|
||||
// `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed
|
||||
|
|
|
@ -591,3 +591,30 @@ func TestNewNullTLSDialer(t *testing.T) {
|
|||
}
|
||||
dialer.CloseIdleConnections() // does not crash
|
||||
}
|
||||
|
||||
func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
|
||||
t.Run("with nil config", func(t *testing.T) {
|
||||
var input *tls.Config
|
||||
output := ClonedTLSConfigOrNewEmptyConfig(input)
|
||||
if output == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
v := reflect.ValueOf(*output)
|
||||
if !v.IsZero() {
|
||||
t.Fatal("expected zero config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("", func(t *testing.T) {
|
||||
input := &tls.Config{
|
||||
ServerName: "dns.google",
|
||||
}
|
||||
output := ClonedTLSConfigOrNewEmptyConfig(input)
|
||||
if output == input {
|
||||
t.Fatal("expected two distinct objects")
|
||||
}
|
||||
if !reflect.DeepEqual(input, output) {
|
||||
t.Fatal("apparently the two objects have different values")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ for file in $(find . -type f -name \*.go); do
|
|||
# implement safer wrappers for these functions.
|
||||
continue
|
||||
fi
|
||||
if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then
|
||||
if [ "$file" = "./internal/netxlite/filtering/tls_test.go" ]; then
|
||||
# We're allowed to use ReadAll and Copy in this file to
|
||||
# avoid depending on netxlite, so we can use filtering
|
||||
# inside of netxlite's own test suite.
|
||||
|
|
Loading…
Reference in New Issue
Block a user