feat: upgrade oohttp and propagate changes (#461)

Part of https://github.com/ooni/probe/issues/1506
This commit is contained in:
Simone Basso
2021-09-05 21:23:47 +02:00
committed by GitHub
parent 5b8df394b1
commit b834af83ac
8 changed files with 95 additions and 35 deletions
+9 -6
View File
@@ -1,6 +1,9 @@
package mocks
import "crypto/tls"
import (
"context"
"crypto/tls"
)
// TLSConn allows to mock netxlite.TLSConn.
type TLSConn struct {
@@ -10,8 +13,8 @@ type TLSConn struct {
// MockConnectionState allows to mock the ConnectionState method.
MockConnectionState func() tls.ConnectionState
// MockHandshake allows to mock the Handshake method.
MockHandshake func() error
// MockHandshakeContext allows to mock the HandshakeContext method.
MockHandshakeContext func(ctx context.Context) error
}
// ConnectionState calls MockConnectionState.
@@ -19,7 +22,7 @@ func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}
// Handshake calls MockHandshake.
func (c *TLSConn) Handshake() error {
return c.MockHandshake()
// HandshakeContext calls MockHandshakeContext.
func (c *TLSConn) HandshakeContext(ctx context.Context) error {
return c.MockHandshakeContext(ctx)
}
+4 -3
View File
@@ -1,6 +1,7 @@
package mocks
import (
"context"
"crypto/tls"
"errors"
"reflect"
@@ -20,14 +21,14 @@ func TestTLSConnConnectionState(t *testing.T) {
}
}
func TestTLSConnHandshake(t *testing.T) {
func TestTLSConnHandshakeContext(t *testing.T) {
expected := errors.New("mocked error")
c := &TLSConn{
MockHandshake: func() error {
MockHandshakeContext: func(ctx context.Context) error {
return expected
},
}
err := c.Handshake()
err := c.HandshakeContext(context.Background())
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
+8 -16
View File
@@ -8,6 +8,8 @@ import (
"fmt"
"net"
"time"
oohttp "github.com/ooni/oohttp"
)
var (
@@ -103,17 +105,12 @@ func ConfigureTLSVersion(config *tls.Config, version string) error {
return nil
}
// TLSConn is any tls.Conn-like structure.
type TLSConn interface {
// net.Conn is the embedded conn.
net.Conn
// TLSConn is the type of connection that oohttp expects from
// any library that implements TLS functionality.
type TLSConn = oohttp.TLSConn
// ConnectionState returns the TLS connection state.
ConnectionState() tls.ConnectionState
// Handshake performs the handshake.
Handshake() error
}
// Ensures that a tls.Conn implements the TLSConn interface.
var _ TLSConn = &tls.Conn{}
// TLSHandshaker is the generic TLS handshaker.
type TLSHandshaker interface {
@@ -154,11 +151,6 @@ var defaultCertPool = NewDefaultCertPool()
// Handshake implements Handshaker.Handshake. This function will
// configure the code to use the built-in Mozilla CA if the config
// field contains a nil RootCAs field.
//
// Bug
//
// Until Go 1.17 is released, this function will not honour
// the context. We'll however always enforce an overall timeout.
func (h *tlsHandshakerConfigurable) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
@@ -173,7 +165,7 @@ func (h *tlsHandshakerConfigurable) Handshake(
config.RootCAs = defaultCertPool
}
tlsconn := h.newConn(conn, config)
if err := tlsconn.Handshake(); err != nil {
if err := tlsconn.HandshakeContext(ctx); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
+1 -1
View File
@@ -190,7 +190,7 @@ func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
gotTLSConfig = config
return &mocks.TLSConn{
MockHandshake: func() error {
MockHandshakeContext: func(ctx context.Context) error {
return expected
},
}
+27 -2
View File
@@ -1,6 +1,7 @@
package netxlite
import (
"context"
"crypto/tls"
"net"
@@ -21,9 +22,13 @@ func NewTLSHandshakerUTLS(logger Logger, id *utls.ClientHelloID) TLSHandshaker {
// utlsConn implements TLSConn and uses a utls UConn as its underlying connection
type utlsConn struct {
*utls.UConn
testableHandshake func() error
}
// newConnUTLS creates a NewConn function creating a utls connection with a specified ClientHelloID
// Ensures that a utlsConn implements the TLSConn interface.
var _ TLSConn = &utlsConn{}
// newConnUTLS returns a NewConn function for creating utlsConn instances.
func newConnUTLS(clientHello *utls.ClientHelloID) func(conn net.Conn, config *tls.Config) TLSConn {
return func(conn net.Conn, config *tls.Config) TLSConn {
uConfig := &utls.Config{
@@ -34,10 +39,30 @@ func newConnUTLS(clientHello *utls.ClientHelloID) func(conn net.Conn, config *tl
DynamicRecordSizingDisabled: config.DynamicRecordSizingDisabled,
}
tlsConn := utls.UClient(conn, uConfig, *clientHello)
return &utlsConn{tlsConn}
return &utlsConn{UConn: tlsConn}
}
}
func (c *utlsConn) HandshakeContext(ctx context.Context) error {
errch := make(chan error, 1)
go func() {
errch <- c.handshakefn()()
}()
select {
case err := <-errch:
return err
case <-ctx.Done():
return ctx.Err()
}
}
func (c *utlsConn) handshakefn() func() error {
if c.testableHandshake != nil {
return c.testableHandshake
}
return c.UConn.Handshake
}
func (c *utlsConn) ConnectionState() tls.ConnectionState {
uState := c.Conn.ConnectionState()
return tls.ConnectionState{
+37
View File
@@ -3,8 +3,11 @@ package netxlite
import (
"context"
"crypto/tls"
"errors"
"net"
"sync"
"testing"
"time"
"github.com/apex/log"
utls "gitlab.com/yawning/utls.git"
@@ -45,3 +48,37 @@ func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
t.Fatal("expected non-nil NewConn")
}
}
func TestUTLSConnHandshakeNotInterrupted(t *testing.T) {
ctx := context.Background()
conn := &utlsConn{
testableHandshake: func() error {
return nil
},
}
err := conn.HandshakeContext(ctx)
if err != nil {
t.Fatal(err)
}
}
func TestUTLSConnHandshakeInterrupted(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(1)
sigch := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
conn := &utlsConn{
testableHandshake: func() error {
defer wg.Done()
<-sigch
return nil
},
}
err := conn.HandshakeContext(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
close(sigch)
wg.Wait()
}