refactor(netx): use netxlite to build TLSDialer (#790)

This diff modifies netx to use netxlite to build the TLSDialer.

Building the TLSDialer entails building a TLSHandshaker.

While there, hide netxlite names we don't want to be public
and change netx tests to test for functionality.

To this end, refactor filtering to provide an easier to
use TLS server. We don't need the complexity of proxying
rather we need to provoke specific errors.

Part of https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso 2022-06-02 17:39:48 +02:00 committed by GitHub
parent ae24ba644c
commit e9ed733f07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 363 additions and 690 deletions

View File

@ -64,8 +64,6 @@ type Config struct {
TLSSaver *tracex.Saver // default: not saving TLS
}
var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool()
// NewResolver creates a new resolver from the specified config
func NewResolver(config Config) model.Resolver {
if config.BaseResolver == nil {
@ -132,25 +130,16 @@ func NewTLSDialer(config Config) model.TLSDialer {
if config.Dialer == nil {
config.Dialer = NewDialer(config)
}
var h model.TLSHandshaker = &netxlite.TLSHandshakerConfigurable{}
h = &netxlite.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
if config.Logger != nil {
h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h}
}
h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver
if config.TLSConfig == nil {
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
}
if config.CertPool == nil {
config.CertPool = defaultCertPool
}
config.TLSConfig.RootCAs = config.CertPool
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
return &netxlite.TLSDialerLegacy{
Config: config.TLSConfig,
Dialer: config.Dialer,
TLSHandshaker: h,
}
logger := model.ValidLoggerOrDefault(config.Logger)
thx := netxlite.NewTLSHandshakerStdlib(logger)
thx = config.TLSSaver.WrapTLSHandshaker(thx) // WAI when TLSSaver is nil
tlsConfig := netxlite.ClonedTLSConfigOrNewEmptyConfig(config.TLSConfig)
// TODO(bassosimone): we should not provide confusing options and
// so we should drop CertPool and NoTLSVerify in favour of encouraging
// the users of this library to always use a TLSConfig.
tlsConfig.RootCAs = config.CertPool // netxlite uses default cert pool if this is nil
tlsConfig.InsecureSkipVerify = config.NoTLSVerify
return netxlite.NewTLSDialerWithConfig(config.Dialer, thx, tlsConfig)
}
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned

View File

@ -13,6 +13,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/tracex"
)
@ -208,210 +209,103 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
}
}
func TestNewTLSDialerVanilla(t *testing.T) {
td := NewTLSDialer(Config{})
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
func TestNewTLSDialer(t *testing.T) {
t.Run("we always have error wrapping", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionReset)
defer server.Close()
tdx := NewTLSDialer(Config{})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
if conn != nil {
t.Fatal("expected nil conn")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithConfig(t *testing.T) {
td := NewTLSDialer(Config{
TLSConfig: new(tls.Config),
})
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 0 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithLogging(t *testing.T) {
td := NewTLSDialer(Config{
Logger: log.Log,
})
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
lth, ok := rtd.TLSHandshaker.(*netxlite.TLSHandshakerLogger)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if lth.DebugLogger != log.Log {
t.Fatal("not the Logger we expected")
}
ewth, ok := lth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithSaver(t *testing.T) {
saver := new(tracex.Saver)
td := NewTLSDialer(Config{
t.Run("we can collect TLS measurements", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionReset)
defer server.Close()
saver := &tracex.Saver{}
tdx := NewTLSDialer(Config{
TLSSaver: saver,
})
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
if conn != nil {
t.Fatal("expected nil conn")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
if len(saver.Read()) <= 0 {
t.Fatal("did not read any event")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if sth.Saver != saver {
t.Fatal("not the Logger we expected")
}
ewth, ok := sth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
td := NewTLSDialer(Config{
TLSConfig: new(tls.Config),
NoTLSVerify: true,
})
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
}
if len(rtd.Config.NextProtos) != 0 {
t.Fatal("invalid len(config.NextProtos)")
}
if rtd.Config.InsecureSkipVerify != true {
t.Fatal("expected true InsecureSkipVerify")
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
td := NewTLSDialer(Config{
NoTLSVerify: true,
t.Run("we can collect dial measurements", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionReset)
defer server.Close()
saver := &tracex.Saver{}
tdx := NewTLSDialer(Config{
DialSaver: saver,
})
rtd, ok := td.(*netxlite.TLSDialerLegacy)
if !ok {
t.Fatal("not the TLSDialer we expected")
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if len(rtd.Config.NextProtos) != 2 {
t.Fatal("invalid len(config.NextProtos)")
if conn != nil {
t.Fatal("expected nil conn")
}
if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" {
t.Fatal("invalid Config.NextProtos")
if len(saver.Read()) <= 0 {
t.Fatal("did not read any event")
}
if rtd.Config.InsecureSkipVerify != true {
t.Fatal("expected true InsecureSkipVerify")
})
t.Run("we can collect I/O measurements", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionReset)
defer server.Close()
saver := &tracex.Saver{}
tdx := NewTLSDialer(Config{
ReadWriteSaver: saver,
})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if rtd.Config.RootCAs != defaultCertPool {
t.Fatal("invalid Config.RootCAs")
if conn != nil {
t.Fatal("expected nil conn")
}
if rtd.Dialer == nil {
t.Fatal("invalid Dialer")
if len(saver.Read()) <= 0 {
t.Fatal("did not read any event")
}
if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker")
})
t.Run("we can skip TLS verification", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
defer server.Close()
tdx := NewTLSDialer(Config{NoTLSVerify: true})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err != nil {
t.Fatal(err.(*netxlite.ErrWrapper).WrappedErr)
}
ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok {
t.Fatal("not the TLSHandshaker we expected")
conn.Close()
})
t.Run("we can set the cert pool", func(t *testing.T) {
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
defer server.Close()
tdx := NewTLSDialer(Config{
CertPool: server.CertPool(),
TLSConfig: &tls.Config{
ServerName: "dns.google",
},
})
conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint())
if err != nil {
t.Fatal(err)
}
conn.Close()
})
}
func TestNewVanilla(t *testing.T) {
@ -441,33 +335,6 @@ func TestNewWithDialer(t *testing.T) {
}
}
func TestNewWithTLSDialer(t *testing.T) {
expected := errors.New("mocked error")
tlsDialer := &netxlite.TLSDialerLegacy{
Config: new(tls.Config),
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
MockCloseIdleConnections: func() {
// nothing
},
},
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
}
txp := NewHTTPTransport(Config{
TLSDialer: tlsDialer,
})
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("not the response we expected")
}
}
func TestNewWithByteCounter(t *testing.T) {
counter := bytecounter.New()
txp := NewHTTPTransport(Config{

View File

@ -1,24 +1,22 @@
package filtering
import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"strings"
"sync"
)
"time"
// TODO(bassosimone): remove TLSActionPass since we want integration tests
// to only run locally to make them much more predictable.
"github.com/google/martian/v3/mitm"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
// TLSAction is a TLS filtering action that this proxy should take.
type TLSAction string
const (
// TLSActionPass passes the traffic to the destination.
TLSActionPass = TLSAction("pass")
// TLSActionReset resets the connection.
TLSActionReset = TLSAction("reset")
@ -35,48 +33,98 @@ const (
// TLSActionAlertUnrecognizedName tells the client that
// it's handshaking with an unknown SNI.
TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name")
// TLSActionBlockText returns a static piece of text
// to the client saying this website is blocked.
TLSActionBlockText = TLSAction("block-text")
)
// TLSProxy is a TLS proxy that routes the traffic depending
// on the SNI value and may implement filtering policies.
type TLSProxy struct {
// OnIncomingSNI is the MANDATORY hook called whenever we have
// successfully received a ClientHello message.
OnIncomingSNI func(sni string) TLSAction
// TLSServer is a TLS server implementing filtering policies.
type TLSServer struct {
// action is the action to perform.
action TLSAction
// cancel allows to cancel background operations.
cancel context.CancelFunc
// cert is the fake CA certificate.
cert *x509.Certificate
// config is the config to generate certificates on the fly.
config *mitm.Config
// done is closed when the background goroutine has terminated.
done chan bool
// endpoint is the endpoint where we're listening.
endpoint string
// listener is the TCP listener.
listener net.Listener
// privkey is the private key that signed the cert.
privkey *rsa.PrivateKey
}
// Start starts the proxy.
func (p *TLSProxy) Start(address string) (net.Listener, error) {
listener, _, err := p.start(address)
return listener, err
// NewTLSServer creates and starts a new TLSServer that executes
// the given action during the TLS handshake.
func NewTLSServer(action TLSAction) *TLSServer {
done := make(chan bool)
cert, privkey, err := mitm.NewAuthority("jafar", "OONI", 24*time.Hour)
runtimex.PanicOnError(err, "mitm.NewAuthority failed")
config, err := mitm.NewConfig(cert, privkey)
runtimex.PanicOnError(err, "mitm.NewConfig failed")
listener, err := net.Listen("tcp", "127.0.0.1:0")
runtimex.PanicOnError(err, "net.Listen failed")
ctx, cancel := context.WithCancel(context.Background())
endpoint := listener.Addr().String()
server := &TLSServer{
action: action,
cancel: cancel,
cert: cert,
config: config,
done: done,
endpoint: endpoint,
listener: listener,
privkey: privkey,
}
go server.mainloop(ctx)
return server
}
func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) {
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, nil, err
}
done := make(chan interface{})
go p.mainloop(listener, done)
return listener, done, nil
// CertPool returns the internal CA as a cert pool.
func (p *TLSServer) CertPool() *x509.CertPool {
o := x509.NewCertPool()
o.AddCert(p.cert)
return o
}
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
defer close(done)
for p.oneloop(listener) {
// Endpoint returns the endpoint where the server is listening.
func (p *TLSServer) Endpoint() string {
return p.endpoint
}
// Close closes this server as soon as possible.
func (p *TLSServer) Close() error {
p.cancel()
err := p.listener.Close()
<-p.done
return err
}
func (p *TLSServer) mainloop(ctx context.Context) {
defer close(p.done)
for p.oneloop(ctx) {
// nothing
}
}
func (p *TLSProxy) oneloop(listener net.Listener) bool {
conn, err := listener.Accept()
if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
return false // we need to stop
}
func (p *TLSServer) oneloop(ctx context.Context) bool {
conn, err := p.listener.Accept()
if err != nil {
return true // we can continue running
return !errors.Is(err, net.ErrClosed)
}
go p.handle(conn)
go p.handle(ctx, conn)
return true // we can continue running
}
@ -85,102 +133,55 @@ const (
tlsAlertUnrecognizedName = byte(112)
)
func (p *TLSProxy) handle(conn net.Conn) {
defer conn.Close()
sni, hello, err := p.readClientHello(conn)
if err != nil {
p.reset(conn)
func (p *TLSServer) handle(ctx context.Context, tcpConn net.Conn) {
defer tcpConn.Close()
tlsConn := tls.Server(tcpConn, &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
switch p.action {
case TLSActionTimeout:
select {
case <-time.After(300 * time.Second):
return nil, errors.New("timing out the connection")
case <-ctx.Done():
p.reset(tcpConn)
return nil, ctx.Err()
}
case TLSActionAlertInternalError:
p.alert(tcpConn, tlsAlertInternalError)
return nil, errors.New("already sent alert")
case TLSActionAlertUnrecognizedName:
p.alert(tcpConn, tlsAlertUnrecognizedName)
return nil, errors.New("already sent alert")
case TLSActionEOF:
p.eof(tcpConn)
return nil, errors.New("already closed the connection")
case TLSActionBlockText:
return p.config.TLSForHost(info.ServerName).GetCertificate(info)
default:
p.reset(tcpConn)
return nil, errors.New("already RST the connection")
}
},
})
if err := tlsConn.Handshake(); err != nil {
return
}
switch p.OnIncomingSNI(sni) {
case TLSActionPass:
p.proxy(conn, sni, hello)
case TLSActionTimeout:
p.timeout(conn)
case TLSActionAlertInternalError:
p.alert(conn, tlsAlertInternalError)
case TLSActionAlertUnrecognizedName:
p.alert(conn, tlsAlertUnrecognizedName)
case TLSActionEOF:
p.eof(conn)
default:
p.reset(conn)
}
p.blockText(tlsConn)
tlsConn.Close()
}
// readClientHello reads the incoming ClientHello message.
//
// Arguments:
//
// - conn is the connection from which to read the ClientHello.
//
// Returns:
//
// - a string containing the SNI (empty on error);
//
// - bytes from the original ClientHello (nil on error);
//
// - an error (nil on success).
func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) {
connWrapper := &tlsClientHelloReader{Conn: conn}
var (
expectedErr = errors.New("cannot continue handhake")
sni string
mutex sync.Mutex // just for safety
)
err := tls.Server(connWrapper, &tls.Config{
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
mutex.Lock()
sni = info.ServerName
mutex.Unlock()
return nil, expectedErr
},
}).Handshake()
if !errors.Is(err, expectedErr) {
return "", nil, err
}
return sni, connWrapper.clientHello, nil
}
// tlsClientHelloReader wraps a net.Conn for the purpose of
// saving the bytes of the ClientHello message.
type tlsClientHelloReader struct {
net.Conn
clientHello []byte
}
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b)
if err != nil {
return 0, err
}
c.clientHello = append(c.clientHello, b[:count]...)
return count, nil
}
// Write prevents writing on the real connection
func (c *tlsClientHelloReader) Write(b []byte) (int, error) {
return 0, errors.New("cannot write on this connection")
}
func (p *TLSProxy) reset(conn net.Conn) {
if tc, ok := conn.(*net.TCPConn); ok {
func (p *TLSServer) reset(conn net.Conn) {
if tc, good := conn.(*net.TCPConn); good {
tc.SetLinger(0)
}
conn.Close()
}
func (p *TLSProxy) timeout(conn net.Conn) {
buffer := make([]byte, 1<<14)
conn.Read(buffer)
func (p *TLSServer) eof(conn net.Conn) {
conn.Close()
}
func (p *TLSProxy) eof(conn net.Conn) {
conn.Close()
}
func (p *TLSProxy) alert(conn net.Conn, code byte) {
func (p *TLSServer) alert(conn net.Conn, code byte) {
alertdata := []byte{
21, // alert
3, // version[0]
@ -194,55 +195,6 @@ func (p *TLSProxy) alert(conn net.Conn, code byte) {
conn.Close()
}
func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) {
p.proxydial(conn, sni, hello, net.Dial)
}
func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte,
dial func(network, address string) (net.Conn, error)) {
if sni == "" { // don't know the destination host
p.reset(conn)
return
}
serverconn, err := dial("tcp", net.JoinHostPort(sni, "443"))
if err != nil {
p.reset(conn)
return
}
if p.connectingToMyself(serverconn) {
p.reset(conn)
return
}
if _, err := serverconn.Write(hello); err != nil {
p.reset(conn)
return
}
defer serverconn.Close() // conn is owned by the caller
wg := &sync.WaitGroup{}
wg.Add(2)
go p.forward(wg, conn, serverconn)
go p.forward(wg, serverconn, conn)
wg.Wait()
}
// connectingToMyself returns true when the proxy has been somehow
// forced to create a connection to itself.
func (p *TLSProxy) connectingToMyself(conn net.Conn) bool {
local := conn.LocalAddr().String()
localAddr, _, localErr := net.SplitHostPort(local)
remote := conn.RemoteAddr().String()
remoteAddr, _, remoteErr := net.SplitHostPort(remote)
return localErr != nil || remoteErr != nil || localAddr == remoteAddr
}
// forward will forward the traffic.
func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) {
defer wg.Done()
// We cannot use netxlite.CopyContext here because we want netxlite to
// use filtering inside its test suite, so this package cannot depend on
// netxlite. In general, we don't want to use io.Copy or io.ReadAll
// directly because they may cause the code to block as documented in
// internal/netxlite/iox.go. However, this package is only used for
// testing, so it's completely okay to make an exception here.
io.Copy(left, right)
func (p *TLSServer) blockText(tlsConn net.Conn) {
tlsConn.Write(HTTPBlockpage451)
}

View File

@ -1,297 +1,146 @@
package filtering
import (
"bytes"
"context"
"crypto/tls"
"errors"
"net"
"io"
"strings"
"testing"
"time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestTLSProxy(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) {
p := &TLSProxy{
OnIncomingSNI: func(sni string) TLSAction {
return action
},
}
return p.start("127.0.0.1:0")
}
dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) {
d := netxlite.NewDialerWithoutResolver(log.Log)
th := netxlite.NewTLSHandshakerStdlib(log.Log)
tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{
ServerName: sni,
NextProtos: []string{"h2", "http/1.1"},
RootCAs: netxlite.NewDefaultCertPool(),
})
return tdx.DialTLSContext(ctx, "tcp", endpoint)
}
t.Run("TLSActionPass", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionPass)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err != nil {
t.Fatal(err)
}
conn.Close()
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionTimeout", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionTimeout)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
func TestTLSServer(t *testing.T) {
t.Run("TLSActionReset", func(t *testing.T) {
srv := NewTLSServer(TLSActionReset)
defer srv.Close()
config := &tls.Config{ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if netxlite.NewTopLevelGenericErrWrapper(err).Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("TLSActionTimeout", func(t *testing.T) {
srv := NewTLSServer(TLSActionTimeout)
defer srv.Close()
config := &tls.Config{ServerName: "dns.google"}
d := &tls.Dialer{Config: config}
ctx, cancel := context.WithTimeout(context.Background(), 70*time.Millisecond)
defer cancel()
conn, err := d.DialContext(ctx, "tcp", srv.Endpoint())
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionAlertInternalError", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionAlertInternalError)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
srv := NewTLSServer(TLSActionAlertInternalError)
defer srv.Close()
config := &tls.Config{ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionAlertUnrecognizedName)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
srv := NewTLSServer(TLSActionAlertUnrecognizedName)
defer srv.Close()
config := &tls.Config{ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionEOF", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionEOF)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || err.Error() != netxlite.FailureEOFError {
srv := NewTLSServer(TLSActionEOF)
defer srv.Close()
config := &tls.Config{ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionReset", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionReset)
if err != nil {
t.Fatal(err)
t.Run("TLSActionBlockText", func(t *testing.T) {
t.Run("certificate error when we're validating", func(t *testing.T) {
srv := NewTLSServer(TLSActionBlockText)
defer srv.Close()
// Certificate.Verify now uses platform APIs to verify certificate validity
// on macOS and iOS when it is called with a nil VerifyOpts.Roots or when using
// the root pool returned from SystemCertPool. "
//
// -- https://tip.golang.org/doc/go1.18
//
// So we need to explicitly use our default cert pool otherwise we will
// see this test failing with a different error string here.
config := &tls.Config{
ServerName: "dns.google",
RootCAs: netxlite.NewDefaultCertPool(),
}
conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google")
if err == nil || err.Error() != netxlite.FailureConnectionReset {
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err == nil || !strings.HasSuffix(err.Error(), "certificate signed by unknown authority") {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
dial := func(ctx context.Context, endpoint string) (net.Conn, error) {
d := netxlite.NewDialerWithoutResolver(log.Log)
return d.DialContext(ctx, "tcp", endpoint)
}
t.Run("handle cannot read ClientHello", func(t *testing.T) {
listener, done, err := newproxy(TLSActionPass)
t.Run("blocktext when we skip validation", func(t *testing.T) {
srv := NewTLSServer(TLSActionBlockText)
defer srv.Close()
config := &tls.Config{InsecureSkipVerify: true, ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err != nil {
t.Fatal(err)
}
conn, err := dial(context.Background(), listener.Addr().String())
defer conn.Close()
data, err := io.ReadAll(conn)
if err != nil {
t.Fatal(err)
}
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
buff := make([]byte, 1<<17)
_, err = conn.Read(buff)
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
if !bytes.Equal(HTTPBlockpage451, data) {
t.Fatal("unexpected block text")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionPass fails because we don't have SNI", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionPass)
t.Run("blocktext when we configure the cert pool", func(t *testing.T) {
srv := NewTLSServer(TLSActionBlockText)
defer srv.Close()
config := &tls.Config{RootCAs: srv.CertPool(), ServerName: "dns.google"}
conn, err := tls.Dial("tcp", srv.Endpoint(), config)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1")
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("TLSActionPass fails because we can't dial", func(t *testing.T) {
ctx := context.Background()
listener, done, err := newproxy(TLSActionPass)
defer conn.Close()
data, err := io.ReadAll(conn)
if err != nil {
t.Fatal(err)
}
conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org")
if err == nil || err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
listener.Close()
<-done // wait for background goroutine to exit
})
t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) {
p := &TLSProxy{}
conn := &mocks.Conn{
MockClose: func() error {
return nil
},
}
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv6loopback,
}
},
MockRemoteAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv6loopback,
}
},
}, nil
})
})
t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) {
p := &TLSProxy{}
conn := &mocks.Conn{
MockClose: func() error {
return nil
},
}
p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) {
return &mocks.Conn{
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv6loopback,
}
},
MockRemoteAddr: func() net.Addr {
return &net.TCPAddr{
IP: net.IPv4(10, 0, 0, 1),
}
},
MockWrite: func(b []byte) (int, error) {
return 0, errors.New("mocked error")
},
}, nil
})
})
t.Run("Start fails on an invalid address", func(t *testing.T) {
p := &TLSProxy{}
listener, err := p.Start("127.0.0.1")
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("unexpected err", err)
}
if listener != nil {
t.Fatal("expected nil listener")
if !bytes.Equal(HTTPBlockpage451, data) {
t.Fatal("unexpected block text")
}
})
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
listener := &mocks.Listener{
MockAccept: func() (net.Conn, error) {
return nil, errors.New("mocked error")
},
}
p := &TLSProxy{}
if !p.oneloop(listener) {
t.Fatal("should return true here")
}
})
}
func TestTLSClientHelloReader(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
chr := &tlsClientHelloReader{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
},
clientHello: []byte{},
}
buf := make([]byte, 128)
count, err := chr.Read(buf)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("invalid count")
}
})
}

View File

@ -303,18 +303,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
}
connectionResetFlow := func(th model.TLSHandshaker) error {
tlsProxy := &filtering.TLSProxy{
OnIncomingSNI: func(sni string) filtering.TLSAction {
return filtering.TLSActionReset
},
}
listener, err := tlsProxy.Start("127.0.0.1:0")
if err != nil {
return fmt.Errorf("cannot start proxy: %w", err)
}
defer listener.Close()
server := filtering.NewTLSServer(filtering.TLSActionReset)
defer server.Close()
ctx := context.Background()
conn, err := dial(ctx, listener.Addr().String())
conn, err := dial(ctx, server.Endpoint())
if err != nil {
return fmt.Errorf("dial failed: %w", err)
}
@ -338,18 +330,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) {
}
timeoutFlow := func(th model.TLSHandshaker) error {
tlsProxy := &filtering.TLSProxy{
OnIncomingSNI: func(sni string) filtering.TLSAction {
return filtering.TLSActionTimeout
},
}
listener, err := tlsProxy.Start("127.0.0.1:0")
if err != nil {
return fmt.Errorf("cannot start proxy: %w", err)
}
defer listener.Close()
server := filtering.NewTLSServer(filtering.TLSActionTimeout)
defer server.Close()
ctx := context.Background()
conn, err := dial(ctx, listener.Addr().String())
conn, err := dial(ctx, server.Endpoint())
if err != nil {
return fmt.Errorf("dial failed: %w", err)
}

View File

@ -20,12 +20,8 @@ type (
HTTPTransportWrapper = httpTransportConnectionsCloser
HTTPTransportLogger = httpTransportLogger
ErrorWrapperResolver = resolverErrWrapper
ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper
ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport
ResolverLogger = resolverLogger
ResolverIDNA = resolverIDNA
TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger
TLSDialerLegacy = tlsDialer
AddressResolver = resolverShortCircuitIPAddr
)

View File

@ -60,6 +60,15 @@ var (
}
)
// ClonedTLSConfigOrNewEmptyConfig returns a clone of the provided config,
// if not nil, or a fresh and completely empty *tls.Config.
func ClonedTLSConfigOrNewEmptyConfig(config *tls.Config) *tls.Config {
if config != nil {
return config.Clone()
}
return &tls.Config{}
}
// TLSVersionString returns a TLS version string. If value is zero, we
// return the empty string. If the value is unknown, we return
// `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed

View File

@ -591,3 +591,30 @@ func TestNewNullTLSDialer(t *testing.T) {
}
dialer.CloseIdleConnections() // does not crash
}
func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
t.Run("with nil config", func(t *testing.T) {
var input *tls.Config
output := ClonedTLSConfigOrNewEmptyConfig(input)
if output == nil {
t.Fatal("expected non-nil result")
}
v := reflect.ValueOf(*output)
if !v.IsZero() {
t.Fatal("expected zero config")
}
})
t.Run("", func(t *testing.T) {
input := &tls.Config{
ServerName: "dns.google",
}
output := ClonedTLSConfigOrNewEmptyConfig(input)
if output == input {
t.Fatal("expected two distinct objects")
}
if !reflect.DeepEqual(input, output) {
t.Fatal("apparently the two objects have different values")
}
})
}

View File

@ -7,7 +7,7 @@ for file in $(find . -type f -name \*.go); do
# implement safer wrappers for these functions.
continue
fi
if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then
if [ "$file" = "./internal/netxlite/filtering/tls_test.go" ]; then
# We're allowed to use ReadAll and Copy in this file to
# avoid depending on netxlite, so we can use filtering
# inside of netxlite's own test suite.