diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 476c78b..1ce0a94 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -64,8 +64,6 @@ type Config struct { TLSSaver *tracex.Saver // default: not saving TLS } -var defaultCertPool *x509.CertPool = netxlite.NewDefaultCertPool() - // NewResolver creates a new resolver from the specified config func NewResolver(config Config) model.Resolver { if config.BaseResolver == nil { @@ -132,25 +130,16 @@ func NewTLSDialer(config Config) model.TLSDialer { if config.Dialer == nil { config.Dialer = NewDialer(config) } - var h model.TLSHandshaker = &netxlite.TLSHandshakerConfigurable{} - h = &netxlite.ErrorWrapperTLSHandshaker{TLSHandshaker: h} - if config.Logger != nil { - h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h} - } - h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver - if config.TLSConfig == nil { - config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}} - } - if config.CertPool == nil { - config.CertPool = defaultCertPool - } - config.TLSConfig.RootCAs = config.CertPool - config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify - return &netxlite.TLSDialerLegacy{ - Config: config.TLSConfig, - Dialer: config.Dialer, - TLSHandshaker: h, - } + logger := model.ValidLoggerOrDefault(config.Logger) + thx := netxlite.NewTLSHandshakerStdlib(logger) + thx = config.TLSSaver.WrapTLSHandshaker(thx) // WAI when TLSSaver is nil + tlsConfig := netxlite.ClonedTLSConfigOrNewEmptyConfig(config.TLSConfig) + // TODO(bassosimone): we should not provide confusing options and + // so we should drop CertPool and NoTLSVerify in favour of encouraging + // the users of this library to always use a TLSConfig. + tlsConfig.RootCAs = config.CertPool // netxlite uses default cert pool if this is nil + tlsConfig.InsecureSkipVerify = config.NoTLSVerify + return netxlite.NewTLSDialerWithConfig(config.Dialer, thx, tlsConfig) } // NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index fc372db..655b1ff 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -13,6 +13,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" "github.com/ooni/probe-cli/v3/internal/tracex" ) @@ -208,210 +209,103 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { } } -func TestNewTLSDialerVanilla(t *testing.T) { - td := NewTLSDialer(Config{}) - rtd, ok := td.(*netxlite.TLSDialerLegacy) - if !ok { - t.Fatal("not the TLSDialer we expected") - } - if len(rtd.Config.NextProtos) != 2 { - t.Fatal("invalid len(config.NextProtos)") - } - if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { - t.Fatal("invalid Config.NextProtos") - } - if rtd.Config.RootCAs != defaultCertPool { - t.Fatal("invalid Config.RootCAs") - } - if rtd.Dialer == nil { - t.Fatal("invalid Dialer") - } - if rtd.TLSHandshaker == nil { - t.Fatal("invalid TLSHandshaker") - } - ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { - t.Fatal("not the TLSHandshaker we expected") - } -} - -func TestNewTLSDialerWithConfig(t *testing.T) { - td := NewTLSDialer(Config{ - TLSConfig: new(tls.Config), +func TestNewTLSDialer(t *testing.T) { + t.Run("we always have error wrapping", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionReset) + defer server.Close() + tdx := NewTLSDialer(Config{}) + conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint()) + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - rtd, ok := td.(*netxlite.TLSDialerLegacy) - if !ok { - t.Fatal("not the TLSDialer we expected") - } - if len(rtd.Config.NextProtos) != 0 { - t.Fatal("invalid len(config.NextProtos)") - } - if rtd.Config.RootCAs != defaultCertPool { - t.Fatal("invalid Config.RootCAs") - } - if rtd.Dialer == nil { - t.Fatal("invalid Dialer") - } - if rtd.TLSHandshaker == nil { - t.Fatal("invalid TLSHandshaker") - } - ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { - t.Fatal("not the TLSHandshaker we expected") - } -} -func TestNewTLSDialerWithLogging(t *testing.T) { - td := NewTLSDialer(Config{ - Logger: log.Log, + t.Run("we can collect TLS measurements", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionReset) + defer server.Close() + saver := &tracex.Saver{} + tdx := NewTLSDialer(Config{ + TLSSaver: saver, + }) + conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint()) + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + if len(saver.Read()) <= 0 { + t.Fatal("did not read any event") + } }) - rtd, ok := td.(*netxlite.TLSDialerLegacy) - if !ok { - t.Fatal("not the TLSDialer we expected") - } - if len(rtd.Config.NextProtos) != 2 { - t.Fatal("invalid len(config.NextProtos)") - } - if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { - t.Fatal("invalid Config.NextProtos") - } - if rtd.Config.RootCAs != defaultCertPool { - t.Fatal("invalid Config.RootCAs") - } - if rtd.Dialer == nil { - t.Fatal("invalid Dialer") - } - if rtd.TLSHandshaker == nil { - t.Fatal("invalid TLSHandshaker") - } - lth, ok := rtd.TLSHandshaker.(*netxlite.TLSHandshakerLogger) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if lth.DebugLogger != log.Log { - t.Fatal("not the Logger we expected") - } - ewth, ok := lth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { - t.Fatal("not the TLSHandshaker we expected") - } -} -func TestNewTLSDialerWithSaver(t *testing.T) { - saver := new(tracex.Saver) - td := NewTLSDialer(Config{ - TLSSaver: saver, + t.Run("we can collect dial measurements", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionReset) + defer server.Close() + saver := &tracex.Saver{} + tdx := NewTLSDialer(Config{ + DialSaver: saver, + }) + conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint()) + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + if len(saver.Read()) <= 0 { + t.Fatal("did not read any event") + } }) - rtd, ok := td.(*netxlite.TLSDialerLegacy) - if !ok { - t.Fatal("not the TLSDialer we expected") - } - if len(rtd.Config.NextProtos) != 2 { - t.Fatal("invalid len(config.NextProtos)") - } - if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { - t.Fatal("invalid Config.NextProtos") - } - if rtd.Config.RootCAs != defaultCertPool { - t.Fatal("invalid Config.RootCAs") - } - if rtd.Dialer == nil { - t.Fatal("invalid Dialer") - } - if rtd.TLSHandshaker == nil { - t.Fatal("invalid TLSHandshaker") - } - sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if sth.Saver != saver { - t.Fatal("not the Logger we expected") - } - ewth, ok := sth.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { - t.Fatal("not the TLSHandshaker we expected") - } -} -func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { - td := NewTLSDialer(Config{ - TLSConfig: new(tls.Config), - NoTLSVerify: true, + t.Run("we can collect I/O measurements", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionReset) + defer server.Close() + saver := &tracex.Saver{} + tdx := NewTLSDialer(Config{ + ReadWriteSaver: saver, + }) + conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint()) + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + if len(saver.Read()) <= 0 { + t.Fatal("did not read any event") + } }) - rtd, ok := td.(*netxlite.TLSDialerLegacy) - if !ok { - t.Fatal("not the TLSDialer we expected") - } - if len(rtd.Config.NextProtos) != 0 { - t.Fatal("invalid len(config.NextProtos)") - } - if rtd.Config.InsecureSkipVerify != true { - t.Fatal("expected true InsecureSkipVerify") - } - if rtd.Config.RootCAs != defaultCertPool { - t.Fatal("invalid Config.RootCAs") - } - if rtd.Dialer == nil { - t.Fatal("invalid Dialer") - } - if rtd.TLSHandshaker == nil { - t.Fatal("invalid TLSHandshaker") - } - ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { - t.Fatal("not the TLSHandshaker we expected") - } -} -func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { - td := NewTLSDialer(Config{ - NoTLSVerify: true, + t.Run("we can skip TLS verification", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionBlockText) + defer server.Close() + tdx := NewTLSDialer(Config{NoTLSVerify: true}) + conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint()) + if err != nil { + t.Fatal(err.(*netxlite.ErrWrapper).WrappedErr) + } + conn.Close() + }) + + t.Run("we can set the cert pool", func(t *testing.T) { + server := filtering.NewTLSServer(filtering.TLSActionBlockText) + defer server.Close() + tdx := NewTLSDialer(Config{ + CertPool: server.CertPool(), + TLSConfig: &tls.Config{ + ServerName: "dns.google", + }, + }) + conn, err := tdx.DialTLSContext(context.Background(), "tcp", server.Endpoint()) + if err != nil { + t.Fatal(err) + } + conn.Close() }) - rtd, ok := td.(*netxlite.TLSDialerLegacy) - if !ok { - t.Fatal("not the TLSDialer we expected") - } - if len(rtd.Config.NextProtos) != 2 { - t.Fatal("invalid len(config.NextProtos)") - } - if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { - t.Fatal("invalid Config.NextProtos") - } - if rtd.Config.InsecureSkipVerify != true { - t.Fatal("expected true InsecureSkipVerify") - } - if rtd.Config.RootCAs != defaultCertPool { - t.Fatal("invalid Config.RootCAs") - } - if rtd.Dialer == nil { - t.Fatal("invalid Dialer") - } - if rtd.TLSHandshaker == nil { - t.Fatal("invalid TLSHandshaker") - } - ewth, ok := rtd.TLSHandshaker.(*netxlite.ErrorWrapperTLSHandshaker) - if !ok { - t.Fatal("not the TLSHandshaker we expected") - } - if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerConfigurable); !ok { - t.Fatal("not the TLSHandshaker we expected") - } } func TestNewVanilla(t *testing.T) { @@ -441,33 +335,6 @@ func TestNewWithDialer(t *testing.T) { } } -func TestNewWithTLSDialer(t *testing.T) { - expected := errors.New("mocked error") - tlsDialer := &netxlite.TLSDialerLegacy{ - Config: new(tls.Config), - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }, - MockCloseIdleConnections: func() { - // nothing - }, - }, - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - } - txp := NewHTTPTransport(Config{ - TLSDialer: tlsDialer, - }) - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("not the response we expected") - } -} - func TestNewWithByteCounter(t *testing.T) { counter := bytecounter.New() txp := NewHTTPTransport(Config{ diff --git a/internal/netxlite/filtering/tls.go b/internal/netxlite/filtering/tls.go index 298a30b..abc5ddc 100644 --- a/internal/netxlite/filtering/tls.go +++ b/internal/netxlite/filtering/tls.go @@ -1,24 +1,22 @@ package filtering import ( + "context" + "crypto/rsa" "crypto/tls" + "crypto/x509" "errors" - "io" "net" - "strings" - "sync" -) + "time" -// TODO(bassosimone): remove TLSActionPass since we want integration tests -// to only run locally to make them much more predictable. + "github.com/google/martian/v3/mitm" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) // TLSAction is a TLS filtering action that this proxy should take. type TLSAction string const ( - // TLSActionPass passes the traffic to the destination. - TLSActionPass = TLSAction("pass") - // TLSActionReset resets the connection. TLSActionReset = TLSAction("reset") @@ -35,48 +33,98 @@ const ( // TLSActionAlertUnrecognizedName tells the client that // it's handshaking with an unknown SNI. TLSActionAlertUnrecognizedName = TLSAction("alert-unrecognized-name") + + // TLSActionBlockText returns a static piece of text + // to the client saying this website is blocked. + TLSActionBlockText = TLSAction("block-text") ) -// TLSProxy is a TLS proxy that routes the traffic depending -// on the SNI value and may implement filtering policies. -type TLSProxy struct { - // OnIncomingSNI is the MANDATORY hook called whenever we have - // successfully received a ClientHello message. - OnIncomingSNI func(sni string) TLSAction +// TLSServer is a TLS server implementing filtering policies. +type TLSServer struct { + // action is the action to perform. + action TLSAction + + // cancel allows to cancel background operations. + cancel context.CancelFunc + + // cert is the fake CA certificate. + cert *x509.Certificate + + // config is the config to generate certificates on the fly. + config *mitm.Config + + // done is closed when the background goroutine has terminated. + done chan bool + + // endpoint is the endpoint where we're listening. + endpoint string + + // listener is the TCP listener. + listener net.Listener + + // privkey is the private key that signed the cert. + privkey *rsa.PrivateKey } -// Start starts the proxy. -func (p *TLSProxy) Start(address string) (net.Listener, error) { - listener, _, err := p.start(address) - return listener, err -} - -func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) { - listener, err := net.Listen("tcp", address) - if err != nil { - return nil, nil, err +// NewTLSServer creates and starts a new TLSServer that executes +// the given action during the TLS handshake. +func NewTLSServer(action TLSAction) *TLSServer { + done := make(chan bool) + cert, privkey, err := mitm.NewAuthority("jafar", "OONI", 24*time.Hour) + runtimex.PanicOnError(err, "mitm.NewAuthority failed") + config, err := mitm.NewConfig(cert, privkey) + runtimex.PanicOnError(err, "mitm.NewConfig failed") + listener, err := net.Listen("tcp", "127.0.0.1:0") + runtimex.PanicOnError(err, "net.Listen failed") + ctx, cancel := context.WithCancel(context.Background()) + endpoint := listener.Addr().String() + server := &TLSServer{ + action: action, + cancel: cancel, + cert: cert, + config: config, + done: done, + endpoint: endpoint, + listener: listener, + privkey: privkey, } - done := make(chan interface{}) - go p.mainloop(listener, done) - return listener, done, nil + go server.mainloop(ctx) + return server } -func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) { - defer close(done) - for p.oneloop(listener) { +// CertPool returns the internal CA as a cert pool. +func (p *TLSServer) CertPool() *x509.CertPool { + o := x509.NewCertPool() + o.AddCert(p.cert) + return o +} + +// Endpoint returns the endpoint where the server is listening. +func (p *TLSServer) Endpoint() string { + return p.endpoint +} + +// Close closes this server as soon as possible. +func (p *TLSServer) Close() error { + p.cancel() + err := p.listener.Close() + <-p.done + return err +} + +func (p *TLSServer) mainloop(ctx context.Context) { + defer close(p.done) + for p.oneloop(ctx) { // nothing } } -func (p *TLSProxy) oneloop(listener net.Listener) bool { - conn, err := listener.Accept() - if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") { - return false // we need to stop - } +func (p *TLSServer) oneloop(ctx context.Context) bool { + conn, err := p.listener.Accept() if err != nil { - return true // we can continue running + return !errors.Is(err, net.ErrClosed) } - go p.handle(conn) + go p.handle(ctx, conn) return true // we can continue running } @@ -85,102 +133,55 @@ const ( tlsAlertUnrecognizedName = byte(112) ) -func (p *TLSProxy) handle(conn net.Conn) { - defer conn.Close() - sni, hello, err := p.readClientHello(conn) - if err != nil { - p.reset(conn) +func (p *TLSServer) handle(ctx context.Context, tcpConn net.Conn) { + defer tcpConn.Close() + tlsConn := tls.Server(tcpConn, &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + switch p.action { + case TLSActionTimeout: + select { + case <-time.After(300 * time.Second): + return nil, errors.New("timing out the connection") + case <-ctx.Done(): + p.reset(tcpConn) + return nil, ctx.Err() + } + case TLSActionAlertInternalError: + p.alert(tcpConn, tlsAlertInternalError) + return nil, errors.New("already sent alert") + case TLSActionAlertUnrecognizedName: + p.alert(tcpConn, tlsAlertUnrecognizedName) + return nil, errors.New("already sent alert") + case TLSActionEOF: + p.eof(tcpConn) + return nil, errors.New("already closed the connection") + case TLSActionBlockText: + return p.config.TLSForHost(info.ServerName).GetCertificate(info) + default: + p.reset(tcpConn) + return nil, errors.New("already RST the connection") + } + }, + }) + if err := tlsConn.Handshake(); err != nil { return } - switch p.OnIncomingSNI(sni) { - case TLSActionPass: - p.proxy(conn, sni, hello) - case TLSActionTimeout: - p.timeout(conn) - case TLSActionAlertInternalError: - p.alert(conn, tlsAlertInternalError) - case TLSActionAlertUnrecognizedName: - p.alert(conn, tlsAlertUnrecognizedName) - case TLSActionEOF: - p.eof(conn) - default: - p.reset(conn) - } + p.blockText(tlsConn) + tlsConn.Close() } -// readClientHello reads the incoming ClientHello message. -// -// Arguments: -// -// - conn is the connection from which to read the ClientHello. -// -// Returns: -// -// - a string containing the SNI (empty on error); -// -// - bytes from the original ClientHello (nil on error); -// -// - an error (nil on success). -func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) { - connWrapper := &tlsClientHelloReader{Conn: conn} - var ( - expectedErr = errors.New("cannot continue handhake") - sni string - mutex sync.Mutex // just for safety - ) - err := tls.Server(connWrapper, &tls.Config{ - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - mutex.Lock() - sni = info.ServerName - mutex.Unlock() - return nil, expectedErr - }, - }).Handshake() - if !errors.Is(err, expectedErr) { - return "", nil, err - } - return sni, connWrapper.clientHello, nil -} - -// tlsClientHelloReader wraps a net.Conn for the purpose of -// saving the bytes of the ClientHello message. -type tlsClientHelloReader struct { - net.Conn - clientHello []byte -} - -func (c *tlsClientHelloReader) Read(b []byte) (int, error) { - count, err := c.Conn.Read(b) - if err != nil { - return 0, err - } - c.clientHello = append(c.clientHello, b[:count]...) - return count, nil -} - -// Write prevents writing on the real connection -func (c *tlsClientHelloReader) Write(b []byte) (int, error) { - return 0, errors.New("cannot write on this connection") -} - -func (p *TLSProxy) reset(conn net.Conn) { - if tc, ok := conn.(*net.TCPConn); ok { +func (p *TLSServer) reset(conn net.Conn) { + if tc, good := conn.(*net.TCPConn); good { tc.SetLinger(0) } conn.Close() } -func (p *TLSProxy) timeout(conn net.Conn) { - buffer := make([]byte, 1<<14) - conn.Read(buffer) +func (p *TLSServer) eof(conn net.Conn) { conn.Close() } -func (p *TLSProxy) eof(conn net.Conn) { - conn.Close() -} - -func (p *TLSProxy) alert(conn net.Conn, code byte) { +func (p *TLSServer) alert(conn net.Conn, code byte) { alertdata := []byte{ 21, // alert 3, // version[0] @@ -194,55 +195,6 @@ func (p *TLSProxy) alert(conn net.Conn, code byte) { conn.Close() } -func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) { - p.proxydial(conn, sni, hello, net.Dial) -} - -func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte, - dial func(network, address string) (net.Conn, error)) { - if sni == "" { // don't know the destination host - p.reset(conn) - return - } - serverconn, err := dial("tcp", net.JoinHostPort(sni, "443")) - if err != nil { - p.reset(conn) - return - } - if p.connectingToMyself(serverconn) { - p.reset(conn) - return - } - if _, err := serverconn.Write(hello); err != nil { - p.reset(conn) - return - } - defer serverconn.Close() // conn is owned by the caller - wg := &sync.WaitGroup{} - wg.Add(2) - go p.forward(wg, conn, serverconn) - go p.forward(wg, serverconn, conn) - wg.Wait() -} - -// connectingToMyself returns true when the proxy has been somehow -// forced to create a connection to itself. -func (p *TLSProxy) connectingToMyself(conn net.Conn) bool { - local := conn.LocalAddr().String() - localAddr, _, localErr := net.SplitHostPort(local) - remote := conn.RemoteAddr().String() - remoteAddr, _, remoteErr := net.SplitHostPort(remote) - return localErr != nil || remoteErr != nil || localAddr == remoteAddr -} - -// forward will forward the traffic. -func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) { - defer wg.Done() - // We cannot use netxlite.CopyContext here because we want netxlite to - // use filtering inside its test suite, so this package cannot depend on - // netxlite. In general, we don't want to use io.Copy or io.ReadAll - // directly because they may cause the code to block as documented in - // internal/netxlite/iox.go. However, this package is only used for - // testing, so it's completely okay to make an exception here. - io.Copy(left, right) +func (p *TLSServer) blockText(tlsConn net.Conn) { + tlsConn.Write(HTTPBlockpage451) } diff --git a/internal/netxlite/filtering/tls_test.go b/internal/netxlite/filtering/tls_test.go index 7ba9e51..5096978 100644 --- a/internal/netxlite/filtering/tls_test.go +++ b/internal/netxlite/filtering/tls_test.go @@ -1,297 +1,146 @@ package filtering import ( + "bytes" "context" "crypto/tls" "errors" - "net" + "io" "strings" "testing" + "time" - "github.com/apex/log" - "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestTLSProxy(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) { - p := &TLSProxy{ - OnIncomingSNI: func(sni string) TLSAction { - return action - }, - } - return p.start("127.0.0.1:0") - } - - dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) { - d := netxlite.NewDialerWithoutResolver(log.Log) - th := netxlite.NewTLSHandshakerStdlib(log.Log) - tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{ - ServerName: sni, - NextProtos: []string{"h2", "http/1.1"}, - RootCAs: netxlite.NewDefaultCertPool(), - }) - return tdx.DialTLSContext(ctx, "tcp", endpoint) - } - - t.Run("TLSActionPass", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionPass) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") - if err != nil { - t.Fatal(err) - } - conn.Close() - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("TLSActionTimeout", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionTimeout) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") - if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { +func TestTLSServer(t *testing.T) { + t.Run("TLSActionReset", func(t *testing.T) { + srv := NewTLSServer(TLSActionReset) + defer srv.Close() + config := &tls.Config{ServerName: "dns.google"} + conn, err := tls.Dial("tcp", srv.Endpoint(), config) + if netxlite.NewTopLevelGenericErrWrapper(err).Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("TLSActionTimeout", func(t *testing.T) { + srv := NewTLSServer(TLSActionTimeout) + defer srv.Close() + config := &tls.Config{ServerName: "dns.google"} + d := &tls.Dialer{Config: config} + ctx, cancel := context.WithTimeout(context.Background(), 70*time.Millisecond) + defer cancel() + conn, err := d.DialContext(ctx, "tcp", srv.Endpoint()) + if !errors.Is(err, context.DeadlineExceeded) { t.Fatal("unexpected err", err) } if conn != nil { t.Fatal("expected nil conn") } - listener.Close() - <-done // wait for background goroutine to exit }) t.Run("TLSActionAlertInternalError", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionAlertInternalError) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + srv := NewTLSServer(TLSActionAlertInternalError) + defer srv.Close() + config := &tls.Config{ServerName: "dns.google"} + conn, err := tls.Dial("tcp", srv.Endpoint(), config) if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") { t.Fatal("unexpected err", err) } if conn != nil { t.Fatal("expected nil conn") } - listener.Close() - <-done // wait for background goroutine to exit }) t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionAlertUnrecognizedName) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + srv := NewTLSServer(TLSActionAlertUnrecognizedName) + defer srv.Close() + config := &tls.Config{ServerName: "dns.google"} + conn, err := tls.Dial("tcp", srv.Endpoint(), config) if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") { t.Fatal("unexpected err", err) } if conn != nil { t.Fatal("expected nil conn") } - listener.Close() - <-done // wait for background goroutine to exit }) t.Run("TLSActionEOF", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionEOF) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") - if err == nil || err.Error() != netxlite.FailureEOFError { + srv := NewTLSServer(TLSActionEOF) + defer srv.Close() + config := &tls.Config{ServerName: "dns.google"} + conn, err := tls.Dial("tcp", srv.Endpoint(), config) + if !errors.Is(err, io.EOF) { t.Fatal("unexpected err", err) } if conn != nil { t.Fatal("expected nil conn") } - listener.Close() - <-done // wait for background goroutine to exit }) - t.Run("TLSActionReset", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionReset) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") - if err == nil || err.Error() != netxlite.FailureConnectionReset { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - listener.Close() - <-done // wait for background goroutine to exit - }) + t.Run("TLSActionBlockText", func(t *testing.T) { + t.Run("certificate error when we're validating", func(t *testing.T) { + srv := NewTLSServer(TLSActionBlockText) + defer srv.Close() + // Certificate.Verify now uses platform APIs to verify certificate validity + // on macOS and iOS when it is called with a nil VerifyOpts.Roots or when using + // the root pool returned from SystemCertPool. " + // + // -- https://tip.golang.org/doc/go1.18 + // + // So we need to explicitly use our default cert pool otherwise we will + // see this test failing with a different error string here. + config := &tls.Config{ + ServerName: "dns.google", + RootCAs: netxlite.NewDefaultCertPool(), + } + conn, err := tls.Dial("tcp", srv.Endpoint(), config) + if err == nil || !strings.HasSuffix(err.Error(), "certificate signed by unknown authority") { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) - dial := func(ctx context.Context, endpoint string) (net.Conn, error) { - d := netxlite.NewDialerWithoutResolver(log.Log) - return d.DialContext(ctx, "tcp", endpoint) - } + t.Run("blocktext when we skip validation", func(t *testing.T) { + srv := NewTLSServer(TLSActionBlockText) + defer srv.Close() + config := &tls.Config{InsecureSkipVerify: true, ServerName: "dns.google"} + conn, err := tls.Dial("tcp", srv.Endpoint(), config) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + data, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(HTTPBlockpage451, data) { + t.Fatal("unexpected block text") + } + }) - t.Run("handle cannot read ClientHello", func(t *testing.T) { - listener, done, err := newproxy(TLSActionPass) - if err != nil { - t.Fatal(err) - } - conn, err := dial(context.Background(), listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - conn.Write([]byte("GET / HTTP/1.0\r\n\r\n")) - buff := make([]byte, 1<<17) - _, err = conn.Read(buff) - if err == nil || err.Error() != netxlite.FailureConnectionReset { - t.Fatal("unexpected err", err) - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("TLSActionPass fails because we don't have SNI", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionPass) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1") - if err == nil || err.Error() != netxlite.FailureConnectionReset { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("TLSActionPass fails because we can't dial", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newproxy(TLSActionPass) - if err != nil { - t.Fatal(err) - } - conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org") - if err == nil || err.Error() != netxlite.FailureConnectionReset { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) { - p := &TLSProxy{} - conn := &mocks.Conn{ - MockClose: func() error { - return nil - }, - } - p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockClose: func() error { - return nil - }, - MockLocalAddr: func() net.Addr { - return &net.TCPAddr{ - IP: net.IPv6loopback, - } - }, - MockRemoteAddr: func() net.Addr { - return &net.TCPAddr{ - IP: net.IPv6loopback, - } - }, - }, nil + t.Run("blocktext when we configure the cert pool", func(t *testing.T) { + srv := NewTLSServer(TLSActionBlockText) + defer srv.Close() + config := &tls.Config{RootCAs: srv.CertPool(), ServerName: "dns.google"} + conn, err := tls.Dial("tcp", srv.Endpoint(), config) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + data, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(HTTPBlockpage451, data) { + t.Fatal("unexpected block text") + } }) }) - - t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) { - p := &TLSProxy{} - conn := &mocks.Conn{ - MockClose: func() error { - return nil - }, - } - p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockClose: func() error { - return nil - }, - MockLocalAddr: func() net.Addr { - return &net.TCPAddr{ - IP: net.IPv6loopback, - } - }, - MockRemoteAddr: func() net.Addr { - return &net.TCPAddr{ - IP: net.IPv4(10, 0, 0, 1), - } - }, - MockWrite: func(b []byte) (int, error) { - return 0, errors.New("mocked error") - }, - }, nil - }) - }) - - t.Run("Start fails on an invalid address", func(t *testing.T) { - p := &TLSProxy{} - listener, err := p.Start("127.0.0.1") - if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { - t.Fatal("unexpected err", err) - } - if listener != nil { - t.Fatal("expected nil listener") - } - }) - - t.Run("oneloop correctly handles a listener error", func(t *testing.T) { - listener := &mocks.Listener{ - MockAccept: func() (net.Conn, error) { - return nil, errors.New("mocked error") - }, - } - p := &TLSProxy{} - if !p.oneloop(listener) { - t.Fatal("should return true here") - } - }) -} - -func TestTLSClientHelloReader(t *testing.T) { - t.Run("on failure", func(t *testing.T) { - expected := errors.New("mocked error") - chr := &tlsClientHelloReader{ - Conn: &mocks.Conn{ - MockRead: func(b []byte) (int, error) { - return 0, expected - }, - }, - clientHello: []byte{}, - } - buf := make([]byte, 128) - count, err := chr.Read(buf) - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if count != 0 { - t.Fatal("invalid count") - } - }) } diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index da172a8..5ff1915 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -303,18 +303,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) { } connectionResetFlow := func(th model.TLSHandshaker) error { - tlsProxy := &filtering.TLSProxy{ - OnIncomingSNI: func(sni string) filtering.TLSAction { - return filtering.TLSActionReset - }, - } - listener, err := tlsProxy.Start("127.0.0.1:0") - if err != nil { - return fmt.Errorf("cannot start proxy: %w", err) - } - defer listener.Close() + server := filtering.NewTLSServer(filtering.TLSActionReset) + defer server.Close() ctx := context.Background() - conn, err := dial(ctx, listener.Addr().String()) + conn, err := dial(ctx, server.Endpoint()) if err != nil { return fmt.Errorf("dial failed: %w", err) } @@ -338,18 +330,10 @@ func TestMeasureWithTLSHandshaker(t *testing.T) { } timeoutFlow := func(th model.TLSHandshaker) error { - tlsProxy := &filtering.TLSProxy{ - OnIncomingSNI: func(sni string) filtering.TLSAction { - return filtering.TLSActionTimeout - }, - } - listener, err := tlsProxy.Start("127.0.0.1:0") - if err != nil { - return fmt.Errorf("cannot start proxy: %w", err) - } - defer listener.Close() + server := filtering.NewTLSServer(filtering.TLSActionTimeout) + defer server.Close() ctx := context.Background() - conn, err := dial(ctx, listener.Addr().String()) + conn, err := dial(ctx, server.Endpoint()) if err != nil { return fmt.Errorf("dial failed: %w", err) } diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index 0316f79..bd284e9 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -20,12 +20,8 @@ type ( HTTPTransportWrapper = httpTransportConnectionsCloser HTTPTransportLogger = httpTransportLogger ErrorWrapperResolver = resolverErrWrapper - ErrorWrapperTLSHandshaker = tlsHandshakerErrWrapper ResolverSystemDoNotInstantiate = resolverSystem // instantiate => crash w/ nil transport ResolverLogger = resolverLogger ResolverIDNA = resolverIDNA - TLSHandshakerConfigurable = tlsHandshakerConfigurable - TLSHandshakerLogger = tlsHandshakerLogger - TLSDialerLegacy = tlsDialer AddressResolver = resolverShortCircuitIPAddr ) diff --git a/internal/netxlite/tls.go b/internal/netxlite/tls.go index 3cecd4e..314626e 100644 --- a/internal/netxlite/tls.go +++ b/internal/netxlite/tls.go @@ -60,6 +60,15 @@ var ( } ) +// ClonedTLSConfigOrNewEmptyConfig returns a clone of the provided config, +// if not nil, or a fresh and completely empty *tls.Config. +func ClonedTLSConfigOrNewEmptyConfig(config *tls.Config) *tls.Config { + if config != nil { + return config.Clone() + } + return &tls.Config{} +} + // TLSVersionString returns a TLS version string. If value is zero, we // return the empty string. If the value is unknown, we return // `TLS_VERSION_UNKNOWN_ddd` where `ddd` is the numeric value passed diff --git a/internal/netxlite/tls_test.go b/internal/netxlite/tls_test.go index 1c0cd88..6d87af0 100644 --- a/internal/netxlite/tls_test.go +++ b/internal/netxlite/tls_test.go @@ -591,3 +591,30 @@ func TestNewNullTLSDialer(t *testing.T) { } dialer.CloseIdleConnections() // does not crash } + +func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) { + t.Run("with nil config", func(t *testing.T) { + var input *tls.Config + output := ClonedTLSConfigOrNewEmptyConfig(input) + if output == nil { + t.Fatal("expected non-nil result") + } + v := reflect.ValueOf(*output) + if !v.IsZero() { + t.Fatal("expected zero config") + } + }) + + t.Run("", func(t *testing.T) { + input := &tls.Config{ + ServerName: "dns.google", + } + output := ClonedTLSConfigOrNewEmptyConfig(input) + if output == input { + t.Fatal("expected two distinct objects") + } + if !reflect.DeepEqual(input, output) { + t.Fatal("apparently the two objects have different values") + } + }) +} diff --git a/script/nocopyreadall.bash b/script/nocopyreadall.bash index cd4a234..999a894 100755 --- a/script/nocopyreadall.bash +++ b/script/nocopyreadall.bash @@ -7,7 +7,7 @@ for file in $(find . -type f -name \*.go); do # implement safer wrappers for these functions. continue fi - if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then + if [ "$file" = "./internal/netxlite/filtering/tls_test.go" ]; then # We're allowed to use ReadAll and Copy in this file to # avoid depending on netxlite, so we can use filtering # inside of netxlite's own test suite.