refactor(netx): use netxlite to build TLSDialer (#790)

This diff modifies netx to use netxlite to build the TLSDialer.

Building the TLSDialer entails building a TLSHandshaker.

While there, hide netxlite names we don't want to be public
and change netx tests to test for functionality.

To this end, refactor filtering to provide an easier to
use TLS server. We don't need the complexity of proxying
rather we need to provoke specific errors.

Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso 2022-06-02 17:39:48 +02:00 committed by GitHub
parent ae24ba644c
commit e9ed733f07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 363 additions and 690 deletions

View File

@ -64,8 +64,6 @@ type Config struct {
TLSSaver *tracex.Saver // default: not saving TLS TLSSaver *tracex.Saver // default: not saving TLS
} }
var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool()
// NewResolver creates a new resolver from the specified config // NewResolver creates a new resolver from the specified config
func NewResolver(config Config) model.Resolver { func NewResolver(config Config) model.Resolver {
if config.BaseResolver == nil { if config.BaseResolver == nil {
@ -132,25 +130,16 @@ func NewTLSDialer(config Config) model.TLSDialer {
if config.Dialer == nil { if config.Dialer == nil {
config.Dialer = NewDialer(config) config.Dialer = NewDialer(config)
} }
var h model.TLSHandshaker = &netxlite.TLSHandshakerConfigurable{} logger := model.ValidLoggerOrDefault(config.Logger)
h = &netxlite.ErrorWrapperTLSHandshaker{TLSHandshaker: h} thx := netxlite.NewTLSHandshakerStdlib(logger)
if config.Logger != nil { thx = config.TLSSaver.WrapTLSHandshaker(thx) // WAI when TLSSaver is nil
h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h} tlsConfig := netxlite.ClonedTLSConfigOrNewEmptyConfig(config.TLSConfig)
} // TODO(bassosimone): we should not provide confusing options and
h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver // so we should drop CertPool and NoTLSVerify in favour of encouraging
if config.TLSConfig == nil { // the users of this library to always use a TLSConfig.
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}} tlsConfig.RootCAs = config.CertPool // netxlite uses default cert pool if this is nil
} tlsConfig.InsecureSkipVerify = config.NoTLSVerify
if config.CertPool == nil { return netxlite.NewTLSDialerWithConfig(config.Dialer, thx, tlsConfig)
config.CertPool = defaultCertPool
}
config.TLSConfig.RootCAs = config.CertPool
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
return &netxlite.TLSDialerLegacy{
Config: config.TLSConfig,
Dialer: config.Dialer,
TLSHandshaker: h,
}
} }
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned // NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned

View File

@ -13,6 +13,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/tracex" "github.com/ooni/probe-cli/v3/internal/tracex"
) )
@ -208,210 +209,103 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
} }
} }
func TestNewTLSDialerVanilla(t *testing.T) { func TestNewTLSDialer(t *testing.T) {
td := NewTLSDialer(Config{}) t.Run("we always have error wrapping", func(t *testing.T) {
rtd, ok := td.(*netxlite.TLSDialerLegacy) server := filtering.NewTLSServer(filtering.TLSActionReset)
if !ok { defer server.Close()
t.Fatal("not the TLSDialer we expected") tdx := NewTLSDialer(Config{})
} conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if len(rtd.Config.NextProtos) != 2 { if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("invalid len(config.NextProtos)") t.Fatal("unexpected err", err)
} }
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { if conn != nil {
t.Fatal("invalid Config.NextProtos") t.Fatal("expected nil conn")
} }
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithConfig(t *testing.T) {
td := NewTLSDialer(Config{
TLSConfig: new(tls.Config),
}) })
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 0 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithLogging(t *testing.T) { t.Run("we can collect TLS measurements", func(t *testing.T) {
td := NewTLSDialer(Config{ server := filtering.NewTLSServer(filtering.TLSActionReset)
Logger: log.Log, defer server.Close()
saver := &tracex.Saver{}
tdx := NewTLSDialer(Config{
TLSSaver: saver,
})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
if len(saver.Read()) <= 0 {
t.Fatal("did not read any event")
}
}) })
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
lth, ok := rtd.TLSHandshaker.(*netxlite.TLSHandshakerLogger)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if lth.DebugLogger != log.Log {
t.Fatal("not the Logger we expected")
}
ewth, ok := lth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithSaver(t *testing.T) { t.Run("we can collect dial measurements", func(t *testing.T) {
saver := new(tracex.Saver) server := filtering.NewTLSServer(filtering.TLSActionReset)
td := NewTLSDialer(Config{ defer server.Close()
TLSSaver: saver, saver := &tracex.Saver{}
tdx := NewTLSDialer(Config{
DialSaver: saver,
})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
if len(saver.Read()) <= 0 {
t.Fatal("did not read any event")
}
}) })
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if sth.Saver != saver {
t.Fatal("not the Logger we expected")
}
ewth, ok := sth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { t.Run("we can collect I/O measurements", func(t *testing.T) {
td := NewTLSDialer(Config{ server := filtering.NewTLSServer(filtering.TLSActionReset)
TLSConfig: new(tls.Config), defer server.Close()
NoTLSVerify: true, saver := &tracex.Saver{}
tdx := NewTLSDialer(Config{
ReadWriteSaver: saver,
})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
if len(saver.Read()) <= 0 {
t.Fatal("did not read any event")
}
}) })
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 0 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.InsecureSkipVerify != true {
t.Fatal("expected true InsecureSkipVerify")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { t.Run("we can skip TLS verification", func(t *testing.T) {
td := NewTLSDialer(Config{ server := filtering.NewTLSServer(filtering.TLSActionBlockText)
NoTLSVerify: true, defer server.Close()
tdx := NewTLSDialer(Config{NoTLSVerify: true})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err != nil {
t.Fatal(err.(*netxlite.ErrWrapper).WrappedErr)
}
conn.Close()
})
t.Run("we can set the cert pool", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
defer server.Close()
tdx := NewTLSDialer(Config{
CertPool: server.CertPool(),
TLSConfig: &tls.Config{
ServerName: "dns.google",
},
})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err != nil {
t.Fatal(err)
}
conn.Close()
}) })
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
}
if rtd.Config.InsecureSkipVerify != true {
t.Fatal("expected true InsecureSkipVerify")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
} }
func TestNewVanilla(t *testing.T) { func TestNewVanilla(t *testing.T) {
@ -441,33 +335,6 @@ func TestNewWithDialer(t *testing.T) {
} }
} }
func TestNewWithTLSDialer(t *testing.T) {
expected := errors.New("mocked error")
tlsDialer := &netxlite.TLSDialerLegacy{
Config: new(tls.Config),
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
MockCloseIdleConnections: func() {
// nothing
},
},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
}
txp := NewHTTPTransport(Config{
TLSDialer: tlsDialer,
})
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("not the response we expected")
}
}
func TestNewWithByteCounter(t *testing.T) { func TestNewWithByteCounter(t *testing.T) {
counter := bytecounter.New() counter := bytecounter.New()
txp := NewHTTPTransport(Config{ txp := NewHTTPTransport(Config{

View File

@ -1,24 +1,22 @@
package filtering package filtering
import ( import (
"context"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors" "errors"
"io"
"net" "net"
"strings" "time"
"sync"
)
// TODO(bassosimone): remove TLSActionPass since we want integration tests "github.com/google/martian/v3/mitm"
// to only run locally to make them much more predictable. "github.com/ooni/probe-cli/v3/internal/runtimex"
)
// TLSAction is a TLS filtering action that this proxy should take. // TLSAction is a TLS filtering action that this proxy should take.
type TLSAction string type TLSAction string
const ( const (
// TLSActionPass passes the traffic to the destination.
TLSActionPass = TLSAction("pass")
// TLSActionReset resets the connection. // TLSActionReset resets the connection.
TLSActionReset = TLSAction("reset") TLSActionReset = TLSAction("reset")
@ -35,48 +33,98 @@ const (
// TLSActionAlertUnrecognizedName tells the client that // TLSActionAlertUnrecognizedName tells the client that
// it's handshaking with an unknown SNI. // it's handshaking with an unknown SNI.
TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name") TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name")
// TLSActionBlockText returns a static piece of text
// to the client saying this website is blocked.
TLSActionBlockText = TLSAction("block-text")
) )
// TLSProxy is a TLS proxy that routes the traffic depending // TLSServer is a TLS server implementing filtering policies.
// on the SNI value and may implement filtering policies. type TLSServer struct {
type TLSProxy struct { // action is the action to perform.
// OnIncomingSNI is the MANDATORY hook called whenever we have action TLSAction
// successfully received a ClientHello message.
OnIncomingSNI func(sni string) TLSAction // cancel allows to cancel background operations.
cancel context.CancelFunc
// cert is the fake CA certificate.
cert *x509.Certificate
// config is the config to generate certificates on the fly.
config *mitm.Config
// done is closed when the background goroutine has terminated.
done chan bool
// endpoint is the endpoint where we're listening.
endpoint string
// listener is the TCP listener.
listener net.Listener
// privkey is the private key that signed the cert.
privkey *rsa.PrivateKey
} }
// Start starts the proxy. // NewTLSServer creates and starts a new TLSServer that executes
func (p *TLSProxy) Start(address string) (net.Listener, error) { // the given action during the TLS handshake.
listener, _, err := p.start(address) func NewTLSServer(action TLSAction) *TLSServer {
return listener, err done := make(chan bool)
} cert, privkey, err := mitm.NewAuthority("jafar", "OONI", 24*time.Hour)
runtimex.PanicOnError(err, "mitm.NewAuthority failed")
func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) { config, err := mitm.NewConfig(cert, privkey)
listener, err := net.Listen("tcp", address) runtimex.PanicOnError(err, "mitm.NewConfig failed")
if err != nil { listener, err := net.Listen("tcp", "127.0.0.1:0")
return nil, nil, err runtimex.PanicOnError(err, "net.Listen failed")
ctx, cancel := context.WithCancel(context.Background())
endpoint := listener.Addr().String()
server := &TLSServer{
action: action,
cancel: cancel,
cert: cert,
config: config,
done: done,
endpoint: endpoint,
listener: listener,
privkey: privkey,
} }
done := make(chan interface{}) go server.mainloop(ctx)
go p.mainloop(listener, done) return server
return listener, done, nil
} }
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) { // CertPool returns the internal CA as a cert pool.
defer close(done) func (p *TLSServer) CertPool() *x509.CertPool {
for p.oneloop(listener) { o := x509.NewCertPool()
o.AddCert(p.cert)
return o
}
// Endpoint returns the endpoint where the server is listening.
func (p *TLSServer) Endpoint() string {
return p.endpoint
}
// Close closes this server as soon as possible.
func (p *TLSServer) Close() error {
p.cancel()
err := p.listener.Close()
<-p.done
return err
}
func (p *TLSServer) mainloop(ctx context.Context) {
defer close(p.done)
for p.oneloop(ctx) {
// nothing // nothing
} }
} }
func (p *TLSProxy) oneloop(listener net.Listener) bool { func (p *TLSServer) oneloop(ctx context.Context) bool {
conn, err := listener.Accept() conn, err := p.listener.Accept()
if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
return false // we need to stop
}
if err != nil { if err != nil {
return true // we can continue running return !errors.Is(err, net.ErrClosed)
} }
go p.handle(conn) go p.handle(ctx, conn)
return true // we can continue running return true // we can continue running
} }
@ -85,102 +133,55 @@ const (
tlsAlertUnrecognizedName = byte(112) tlsAlertUnrecognizedName = byte(112)
) )
func (p *TLSProxy) handle(conn net.Conn) { func (p *TLSServer) handle(ctx context.Context, tcpConn net.Conn) {
defer conn.Close() defer tcpConn.Close()
sni, hello, err := p.readClientHello(conn) tlsConn := tls.Server(tcpConn, &tls.Config{
if err != nil { GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
p.reset(conn) switch p.action {
case TLSActionTimeout:
select {
case <-time.After(300 * time.Second):
return nil, errors.New("timing out the connection")
case <-ctx.Done():
p.reset(tcpConn)
return nil, ctx.Err()
}
case TLSActionAlertInternalError:
p.alert(tcpConn, tlsAlertInternalError)
return nil, errors.New("already sent alert")
case TLSActionAlertUnrecognizedName:
p.alert(tcpConn, tlsAlertUnrecognizedName)
return nil, errors.New("already sent alert")
case TLSActionEOF:
p.eof(tcpConn)
return nil, errors.New("already closed the connection")
case TLSActionBlockText:
return p.config.TLSForHost(info.ServerName).GetCertificate(info)
default:
p.reset(tcpConn)
return nil, errors.New("already RST the connection")
}
},
})
if err := tlsConn.Handshake(); err != nil {
return return
} }
switch p.OnIncomingSNI(sni) { p.blockText(tlsConn)
case TLSActionPass: tlsConn.Close()
p.proxy(conn, sni, hello)
case TLSActionTimeout:
p.timeout(conn)
case TLSActionAlertInternalError:
p.alert(conn, tlsAlertInternalError)
case TLSActionAlertUnrecognizedName:
p.alert(conn, tlsAlertUnrecognizedName)
case TLSActionEOF:
p.eof(conn)
default:
p.reset(conn)
}
} }
// readClientHello reads the incoming ClientHello message. func (p *TLSServer) reset(conn net.Conn) {
// if tc, good := conn.(*net.TCPConn); good {
// Arguments:
//
// - conn is the connection from which to read the ClientHello.
//
// Returns:
//
// - a string containing the SNI (empty on error);
//
// - bytes from the original ClientHello (nil on error);
//
// - an error (nil on success).
func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) {
connWrapper := &tlsClientHelloReader{Conn: conn}
var (
expectedErr = errors.New("cannot continue handhake")
sni string
mutex sync.Mutex // just for safety
)
err := tls.Server(connWrapper, &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
mutex.Lock()
sni = info.ServerName
mutex.Unlock()
return nil, expectedErr
},
}).Handshake()
if !errors.Is(err, expectedErr) {
return "", nil, err
}
return sni, connWrapper.clientHello, nil
}
// tlsClientHelloReader wraps a net.Conn for the purpose of
// saving the bytes of the ClientHello message.
type tlsClientHelloReader struct {
net.Conn
clientHello []byte
}
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b)
if err != nil {
return 0, err
}
c.clientHello = append(c.clientHello, b[:count]...)
return count, nil
}
// Write prevents writing on the real connection
func (c *tlsClientHelloReader) Write(b []byte) (int, error) {
return 0, errors.New("cannot write on this connection")
}
func (p *TLSProxy) reset(conn net.Conn) {
if tc, ok := conn.(*net.TCPConn); ok {
tc.SetLinger(0) tc.SetLinger(0)
} }
conn.Close() conn.Close()
} }
func (p *TLSProxy) timeout(conn net.Conn) { func (p *TLSServer) eof(conn net.Conn) {
buffer := make([]byte, 1<<14)
conn.Read(buffer)
conn.Close() conn.Close()
} }
func (p *TLSProxy) eof(conn net.Conn) { func (p *TLSServer) alert(conn net.Conn, code byte) {
conn.Close()
}
func (p *TLSProxy) alert(conn net.Conn, code byte) {
alertdata := []byte{ alertdata := []byte{
21, // alert 21, // alert
3, // version[0] 3, // version[0]
@ -194,55 +195,6 @@ func (p *TLSProxy) alert(conn net.Conn, code byte) {
conn.Close() conn.Close()
} }
func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) { func (p *TLSServer) blockText(tlsConn net.Conn) {
p.proxydial(conn, sni, hello, net.Dial) tlsConn.Write(HTTPBlockpage451)
}
func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte,
dial func(network, address string) (net.Conn, error)) {
if sni == "" { // don't know the destination host
p.reset(conn)
return
}
serverconn, err := dial("tcp", net.JoinHostPort(sni, "443"))
if err != nil {
p.reset(conn)
return
}
if p.connectingToMyself(serverconn) {
p.reset(conn)
return
}
if _, err := serverconn.Write(hello); err != nil {
p.reset(conn)
return
}
defer serverconn.Close() // conn is owned by the caller
wg := &sync.WaitGroup{}
wg.Add(2)
go p.forward(wg, conn, serverconn)
go p.forward(wg, serverconn, conn)
wg.Wait()
}
// connectingToMyself returns true when the proxy has been somehow
// forced to create a connection to itself.
func (p *TLSProxy) connectingToMyself(conn net.Conn) bool {
local := conn.LocalAddr().String()
localAddr, _, localErr := net.SplitHostPort(local)
remote := conn.RemoteAddr().String()
remoteAddr, _, remoteErr := net.SplitHostPort(remote)
return localErr != nil || remoteErr != nil || localAddr == remoteAddr
}
// forward will forward the traffic.
func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) {
defer wg.Done()
// We cannot use netxlite.CopyContext here because we want netxlite to
// use filtering inside its test suite, so this package cannot depend on
// netxlite. In general, we don't want to use io.Copy or io.ReadAll
// directly because they may cause the code to block as documented in
// internal/netxlite/iox.go. However, this package is only used for
// testing, so it's completely okay to make an exception here.
io.Copy(left, right)
} }

View File

@ -1,297 +1,146 @@
package filtering package filtering
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "io"
"strings" "strings"
"testing" "testing"
"time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
func TestTLSProxy(t *testing.T) { func TestTLSServer(t *testing.T) {
if testing.Short() { t.Run("TLSActionReset", func(t *testing.T) {
t.Skip("skip test in short mode") srv := NewTLSServer(TLSActionReset)
} defer srv.Close()
newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) { config := &tls.Config{ServerName: "dns.google"}
p := &TLSProxy{ conn, err := tls.Dial("tcp", srv.Endpoint(), config)
OnIncomingSNI: func(sni string) TLSAction { if netxlite.NewTopLevelGenericErrWrapper(err).Error() != netxlite.FailureConnectionReset {
return action t.Fatal("unexpected err", err)
}, }
} if conn != nil {
return p.start("127.0.0.1:0") t.Fatal("expected nil conn")
} }
})
dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) {
d := netxlite.NewDialerWithoutResolver(log.Log) t.Run("TLSActionTimeout", func(t *testing.T) {
th := netxlite.NewTLSHandshakerStdlib(log.Log) srv := NewTLSServer(TLSActionTimeout)
tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{ defer srv.Close()
ServerName: sni, config := &tls.Config{ServerName: "dns.google"}
NextProtos: []string{"h2", "http/1.1"}, d := &tls.Dialer{Config: config}
RootCAs: netxlite.NewDefaultCertPool(), ctx, cancel := context.WithTimeout(context.Background(), 70*time.Millisecond)
}) defer cancel()
return tdx.DialTLSContext(ctx, "tcp", endpoint) conn, err := d.DialContext(ctx, "tcp", srv.Endpoint())
} if !errors.Is(err, context.DeadlineExceeded) {
t.Run("TLSActionPass", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionPass)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err != nil {
t.Fatal(err)
}
conn.Close()
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionTimeout", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionTimeout)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if conn != nil { if conn != nil {
t.Fatal("expected nil conn") t.Fatal("expected nil conn")
} }
listener.Close()
<-done // wait for background goroutine to exit
}) })
t.Run("TLSActionAlertInternalError", func(t *testing.T) { t.Run("TLSActionAlertInternalError", func(t *testing.T) {
ctx := context.Background() srv := NewTLSServer(TLSActionAlertInternalError)
listener, done, err := newproxy(TLSActionAlertInternalError) defer srv.Close()
if err != nil { config := &tls.Config{ServerName: "dns.google"}
t.Fatal(err) conn, err := tls.Dial("tcp", srv.Endpoint(), config)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") { if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if conn != nil { if conn != nil {
t.Fatal("expected nil conn") t.Fatal("expected nil conn")
} }
listener.Close()
<-done // wait for background goroutine to exit
}) })
t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) { t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) {
ctx := context.Background() srv := NewTLSServer(TLSActionAlertUnrecognizedName)
listener, done, err := newproxy(TLSActionAlertUnrecognizedName) defer srv.Close()
if err != nil { config := &tls.Config{ServerName: "dns.google"}
t.Fatal(err) conn, err := tls.Dial("tcp", srv.Endpoint(), config)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") { if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if conn != nil { if conn != nil {
t.Fatal("expected nil conn") t.Fatal("expected nil conn")
} }
listener.Close()
<-done // wait for background goroutine to exit
}) })
t.Run("TLSActionEOF", func(t *testing.T) { t.Run("TLSActionEOF", func(t *testing.T) {
ctx := context.Background() srv := NewTLSServer(TLSActionEOF)
listener, done, err := newproxy(TLSActionEOF) defer srv.Close()
if err != nil { config := &tls.Config{ServerName: "dns.google"}
t.Fatal(err) conn, err := tls.Dial("tcp", srv.Endpoint(), config)
} if !errors.Is(err, io.EOF) {
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || err.Error() != netxlite.FailureEOFError {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if conn != nil { if conn != nil {
t.Fatal("expected nil conn") t.Fatal("expected nil conn")
} }
listener.Close()
<-done // wait for background goroutine to exit
}) })
t.Run("TLSActionReset", func(t *testing.T) { t.Run("TLSActionBlockText", func(t *testing.T) {
ctx := context.Background() t.Run("certificate error when we're validating", func(t *testing.T) {
listener, done, err := newproxy(TLSActionReset) srv := NewTLSServer(TLSActionBlockText)
if err != nil { defer srv.Close()
t.Fatal(err) // Certificate.Verify now uses platform APIs to verify certificate validity
} // on macOS and iOS when it is called with a nil VerifyOpts.Roots or when using
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") // the root pool returned from SystemCertPool. "
if err == nil || err.Error() != netxlite.FailureConnectionReset { //
t.Fatal("unexpected err", err) // -- https://tip.golang.org/doc/go1.18
} //
if conn != nil { // So we need to explicitly use our default cert pool otherwise we will
t.Fatal("expected nil conn") // see this test failing with a different error string here.
} config := &tls.Config{
listener.Close() ServerName: "dns.google",
<-done // wait for background goroutine to exit RootCAs: netxlite.NewDefaultCertPool(),
}) }
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err == nil || !strings.HasSuffix(err.Error(), "certificate signed by unknown authority") {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
dial := func(ctx context.Context, endpoint string) (net.Conn, error) { t.Run("blocktext when we skip validation", func(t *testing.T) {
d := netxlite.NewDialerWithoutResolver(log.Log) srv := NewTLSServer(TLSActionBlockText)
return d.DialContext(ctx, "tcp", endpoint) defer srv.Close()
} config := &tls.Config{InsecureSkipVerify: true, ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
data, err := io.ReadAll(conn)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(HTTPBlockpage451, data) {
t.Fatal("unexpected block text")
}
})
t.Run("handle cannot read ClientHello", func(t *testing.T) { t.Run("blocktext when we configure the cert pool", func(t *testing.T) {
listener, done, err := newproxy(TLSActionPass) srv := NewTLSServer(TLSActionBlockText)
if err != nil { defer srv.Close()
t.Fatal(err) config := &tls.Config{RootCAs: srv.CertPool(), ServerName: "dns.google"}
} conn, err := tls.Dial("tcp", srv.Endpoint(), config)
conn, err := dial(context.Background(), listener.Addr().String()) if err != nil {
if err != nil { t.Fatal(err)
t.Fatal(err) }
} defer conn.Close()
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n")) data, err := io.ReadAll(conn)
buff := make([]byte, 1<<17) if err != nil {
_, err = conn.Read(buff) t.Fatal(err)
if err == nil || err.Error() != netxlite.FailureConnectionReset { }
t.Fatal("unexpected err", err) if !bytes.Equal(HTTPBlockpage451, data) {
} t.Fatal("unexpected block text")
listener.Close() }
<-done // wait for background goroutine to exit
})
t.Run("TLSActionPass fails because we don't have SNI", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionPass)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1")
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionPass fails because we can't dial", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionPass)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org")
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) {
p := &TLSProxy{}
conn := &mocks.Conn{
MockClose: func() error {
return nil
},
}
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv6loopback,
}
},
MockRemoteAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv6loopback,
}
},
}, nil
}) })
}) })
t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) {
p := &TLSProxy{}
conn := &mocks.Conn{
MockClose: func() error {
return nil
},
}
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv6loopback,
}
},
MockRemoteAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(10, 0, 0, 1),
}
},
MockWrite: func(b []byte) (int, error) {
return 0, errors.New("mocked error")
},
}, nil
})
})
t.Run("Start fails on an invalid address", func(t *testing.T) {
p := &TLSProxy{}
listener, err := p.Start("127.0.0.1")
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("unexpected err", err)
}
if listener != nil {
t.Fatal("expected nil listener")
}
})
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
listener := &mocks.Listener{
MockAccept: func() (net.Conn, error) {
return nil, errors.New("mocked error")
},
}
p := &TLSProxy{}
if !p.oneloop(listener) {
t.Fatal("should return true here")
}
})
}
func TestTLSClientHelloReader(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
chr := &tlsClientHelloReader{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
},
clientHello: []byte{},
}
buf := make([]byte, 128)
count, err := chr.Read(buf)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("invalid count")
}
})
} }

View File

@ -303,18 +303,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
} }
connectionResetFlow := func(th model.TLSHandshaker) error { connectionResetFlow := func(th model.TLSHandshaker) error {
tlsProxy := &filtering.TLSProxy{ server := filtering.NewTLSServer(filtering.TLSActionReset)
OnIncomingSNI: func(sni string) filtering.TLSAction { defer server.Close()
return filtering.TLSActionReset
},
}
listener, err := tlsProxy.Start("127.0.0.1:0")
if err != nil {
return fmt.Errorf("cannot start proxy: %w", err)
}
defer listener.Close()
ctx := context.Background() ctx := context.Background()
conn, err := dial(ctx, listener.Addr().String()) conn, err := dial(ctx, server.Endpoint())
if err != nil { if err != nil {
return fmt.Errorf("dial failed: %w", err) return fmt.Errorf("dial failed: %w", err)
} }
@ -338,18 +330,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
} }
timeoutFlow := func(th model.TLSHandshaker) error { timeoutFlow := func(th model.TLSHandshaker) error {
tlsProxy := &filtering.TLSProxy{ server := filtering.NewTLSServer(filtering.TLSActionTimeout)
OnIncomingSNI: func(sni string) filtering.TLSAction { defer server.Close()
return filtering.TLSActionTimeout
},
}
listener, err := tlsProxy.Start("127.0.0.1:0")
if err != nil {
return fmt.Errorf("cannot start proxy: %w", err)
}
defer listener.Close()
ctx := context.Background() ctx := context.Background()
conn, err := dial(ctx, listener.Addr().String()) conn, err := dial(ctx, server.Endpoint())
if err != nil { if err != nil {
return fmt.Errorf("dial failed: %w", err) return fmt.Errorf("dial failed: %w", err)
} }

View File

@ -20,12 +20,8 @@ type (
HTTPTransportWrapper = httpTransportConnectionsCloser HTTPTransportWrapper = httpTransportConnectionsCloser
HTTPTransportLogger = httpTransportLogger HTTPTransportLogger = httpTransportLogger
ErrorWrapperResolver = resolverErrWrapper ErrorWrapperResolver = resolverErrWrapper
ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper
ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
ResolverLogger = resolverLogger ResolverLogger = resolverLogger
ResolverIDNA = resolverIDNA ResolverIDNA = resolverIDNA
TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger
TLSDialerLegacy = tlsDialer
AddressResolver = resolverShortCircuitIPAddr AddressResolver = resolverShortCircuitIPAddr
) )

View File

@ -60,6 +60,15 @@ var (
} }
) )
// ClonedTLSConfigOrNewEmptyConfig returns a clone of the provided config,
// if not nil, or a fresh and completely empty *tls.Config.
func ClonedTLSConfigOrNewEmptyConfig(config *tls.Config) *tls.Config {
if config != nil {
return config.Clone()
}
return &tls.Config{}
}
// TLSVersionString returns a TLS version string. If value is zero, we // TLSVersionString returns a TLS version string. If value is zero, we
// return the empty string. If the value is unknown, we return // return the empty string. If the value is unknown, we return
// `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed // `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed

View File

@ -591,3 +591,30 @@ func TestNewNullTLSDialer(t *testing.T) {
} }
dialer.CloseIdleConnections() // does not crash dialer.CloseIdleConnections() // does not crash
} }
func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
t.Run("with nil config", func(t *testing.T) {
var input *tls.Config
output := ClonedTLSConfigOrNewEmptyConfig(input)
if output == nil {
t.Fatal("expected non-nil result")
}
v := reflect.ValueOf(*output)
if !v.IsZero() {
t.Fatal("expected zero config")
}
})
t.Run("", func(t *testing.T) {
input := &tls.Config{
ServerName: "dns.google",
}
output := ClonedTLSConfigOrNewEmptyConfig(input)
if output == input {
t.Fatal("expected two distinct objects")
}
if !reflect.DeepEqual(input, output) {
t.Fatal("apparently the two objects have different values")
}
})
}

View File

@ -7,7 +7,7 @@ for file in $(find . -type f -name \*.go); do
# implement safer wrappers for these functions. # implement safer wrappers for these functions.
continue continue
fi fi
if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then if [ "$file" = "./internal/netxlite/filtering/tls_test.go" ]; then
# We're allowed to use ReadAll and Copy in this file to # We're allowed to use ReadAll and Copy in this file to
# avoid depending on netxlite, so we can use filtering # avoid depending on netxlite, so we can use filtering
# inside of netxlite's own test suite. # inside of netxlite's own test suite.