feat: upgrade oohttp and propagate changes (#461)
Part of https://github.com/ooni/probe/issues/1506
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
package mocks
|
||||
|
||||
import "crypto/tls"
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
)
|
||||
|
||||
// TLSConn allows to mock netxlite.TLSConn.
|
||||
type TLSConn struct {
|
||||
@@ -10,8 +13,8 @@ type TLSConn struct {
|
||||
// MockConnectionState allows to mock the ConnectionState method.
|
||||
MockConnectionState func() tls.ConnectionState
|
||||
|
||||
// MockHandshake allows to mock the Handshake method.
|
||||
MockHandshake func() error
|
||||
// MockHandshakeContext allows to mock the HandshakeContext method.
|
||||
MockHandshakeContext func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// ConnectionState calls MockConnectionState.
|
||||
@@ -19,7 +22,7 @@ func (c *TLSConn) ConnectionState() tls.ConnectionState {
|
||||
return c.MockConnectionState()
|
||||
}
|
||||
|
||||
// Handshake calls MockHandshake.
|
||||
func (c *TLSConn) Handshake() error {
|
||||
return c.MockHandshake()
|
||||
// HandshakeContext calls MockHandshakeContext.
|
||||
func (c *TLSConn) HandshakeContext(ctx context.Context) error {
|
||||
return c.MockHandshakeContext(ctx)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"reflect"
|
||||
@@ -20,14 +21,14 @@ func TestTLSConnConnectionState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConnHandshake(t *testing.T) {
|
||||
func TestTLSConnHandshakeContext(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
c := &TLSConn{
|
||||
MockHandshake: func() error {
|
||||
MockHandshakeContext: func(ctx context.Context) error {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
err := c.Handshake()
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
oohttp "github.com/ooni/oohttp"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -103,17 +105,12 @@ func ConfigureTLSVersion(config *tls.Config, version string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TLSConn is any tls.Conn-like structure.
|
||||
type TLSConn interface {
|
||||
// net.Conn is the embedded conn.
|
||||
net.Conn
|
||||
// TLSConn is the type of connection that oohttp expects from
|
||||
// any library that implements TLS functionality.
|
||||
type TLSConn = oohttp.TLSConn
|
||||
|
||||
// ConnectionState returns the TLS connection state.
|
||||
ConnectionState() tls.ConnectionState
|
||||
|
||||
// Handshake performs the handshake.
|
||||
Handshake() error
|
||||
}
|
||||
// Ensures that a tls.Conn implements the TLSConn interface.
|
||||
var _ TLSConn = &tls.Conn{}
|
||||
|
||||
// TLSHandshaker is the generic TLS handshaker.
|
||||
type TLSHandshaker interface {
|
||||
@@ -154,11 +151,6 @@ var defaultCertPool = NewDefaultCertPool()
|
||||
// Handshake implements Handshaker.Handshake. This function will
|
||||
// configure the code to use the built-in Mozilla CA if the config
|
||||
// field contains a nil RootCAs field.
|
||||
//
|
||||
// Bug
|
||||
//
|
||||
// Until Go 1.17 is released, this function will not honour
|
||||
// the context. We'll however always enforce an overall timeout.
|
||||
func (h *tlsHandshakerConfigurable) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
@@ -173,7 +165,7 @@ func (h *tlsHandshakerConfigurable) Handshake(
|
||||
config.RootCAs = defaultCertPool
|
||||
}
|
||||
tlsconn := h.newConn(conn, config)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
if err := tlsconn.HandshakeContext(ctx); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
|
||||
@@ -190,7 +190,7 @@ func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
|
||||
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
|
||||
gotTLSConfig = config
|
||||
return &mocks.TLSConn{
|
||||
MockHandshake: func() error {
|
||||
MockHandshakeContext: func(ctx context.Context) error {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package netxlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
@@ -21,9 +22,13 @@ func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
|
||||
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection
|
||||
type utlsConn struct {
|
||||
*utls.UConn
|
||||
testableHandshake func() error
|
||||
}
|
||||
|
||||
// newConnUTLS creates a NewConn function creating a utls connection with a specified ClientHelloID
|
||||
// Ensures that a utlsConn implements the TLSConn interface.
|
||||
var _ TLSConn = &utlsConn{}
|
||||
|
||||
// newConnUTLS returns a NewConn function for creating utlsConn instances.
|
||||
func newConnUTLS(clientHello *utls.ClientHelloID) func(conn net.Conn, config *tls.Config) TLSConn {
|
||||
return func(conn net.Conn, config *tls.Config) TLSConn {
|
||||
uConfig := &utls.Config{
|
||||
@@ -34,10 +39,30 @@ func newConnUTLS(clientHello *utls.ClientHelloID) func(conn net.Conn, config *tl
|
||||
DynamicRecordSizingDisabled: config.DynamicRecordSizingDisabled,
|
||||
}
|
||||
tlsConn := utls.UClient(conn, uConfig, *clientHello)
|
||||
return &utlsConn{tlsConn}
|
||||
return &utlsConn{UConn: tlsConn}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *utlsConn) HandshakeContext(ctx context.Context) error {
|
||||
errch := make(chan error, 1)
|
||||
go func() {
|
||||
errch <- c.handshakefn()()
|
||||
}()
|
||||
select {
|
||||
case err := <-errch:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *utlsConn) handshakefn() func() error {
|
||||
if c.testableHandshake != nil {
|
||||
return c.testableHandshake
|
||||
}
|
||||
return c.UConn.Handshake
|
||||
}
|
||||
|
||||
func (c *utlsConn) ConnectionState() tls.ConnectionState {
|
||||
uState := c.Conn.ConnectionState()
|
||||
return tls.ConnectionState{
|
||||
|
||||
@@ -3,8 +3,11 @@ package netxlite
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apex/log"
|
||||
utls "gitlab.com/yawning/utls.git"
|
||||
@@ -45,3 +48,37 @@ func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
|
||||
t.Fatal("expected non-nil NewConn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTLSConnHandshakeNotInterrupted(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conn := &utlsConn{
|
||||
testableHandshake: func() error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
err := conn.HandshakeContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUTLSConnHandshakeInterrupted(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
sigch := make(chan interface{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
conn := &utlsConn{
|
||||
testableHandshake: func() error {
|
||||
defer wg.Done()
|
||||
<-sigch
|
||||
return nil
|
||||
},
|
||||
}
|
||||
err := conn.HandshakeContext(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
close(sigch)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user