cleanup(quic): wait for handshake completion in netxlite (#729)

See https://github.com/ooni/probe/issues/2097
This commit is contained in:
Simone Basso 2022-05-14 16:32:32 +02:00 committed by GitHub
parent 5904e6988d
commit 2238908afe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 140 additions and 102 deletions

View File

@ -64,12 +64,8 @@ func (s *Saver) QUICDialContext(ctx context.Context, dialer model.QUICDialer,
var state tls.ConnectionState
sess, err := dialer.DialContext(ctx, network, address, tlsConfig, quicConfig)
if err == nil {
select {
case <-sess.HandshakeComplete().Done():
state = sess.ConnectionState().TLS.ConnectionState
case <-ctx.Done():
sess, err = nil, ctx.Err()
}
<-sess.HandshakeComplete().Done() // robustness (the dialer already does that)
state = sess.ConnectionState().TLS.ConnectionState
}
s.appendQUICHandshake(&QUICTLSHandshakeEvent{
ALPN: tlsConfig.NextProtos,

View File

@ -264,49 +264,6 @@ func TestSaverQUICDialContext(t *testing.T) {
}
})
t.Run("on handshake timeout", func(t *testing.T) {
handshakeCtx := context.Background()
handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx)
defer handshakeCancel()
const expectedNetwork = "udp"
const mockedEndpoint = "8.8.4.4:443"
saver := NewSaver()
v := &SingleQUICTLSHandshakeValidator{
ExpectedALPN: []string{"h3"},
ExpectedSNI: "dns.google",
ExpectedSkipVerify: true,
//
ExpectedCipherSuite: 0,
ExpectedNegotiatedProtocol: "",
ExpectedPeerCerts: nil,
ExpectedVersion: 0,
//
ExpectedNetwork: "quic",
ExpectedRemoteAddr: mockedEndpoint,
//
QUICConfig: &quic.Config{},
//
ExpectedFailure: context.DeadlineExceeded,
Saver: saver,
}
qconn := newQUICConnection(handshakeCtx, tls.ConnectionState{})
dialer := newQUICDialer(qconn, nil)
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Microsecond)
defer cancel()
qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("unexpected error")
}
if qconn != nil {
t.Fatal("expected nil connection")
}
if err := v.Validate(); err != nil {
t.Fatal(err)
}
})
t.Run("on other error", func(t *testing.T) {
mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
const expectedNetwork = "udp"

View File

@ -264,49 +264,6 @@ func TestSaverQUICDialContext(t *testing.T) {
}
})
t.Run("on handshake timeout", func(t *testing.T) {
handshakeCtx := context.Background()
handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx)
defer handshakeCancel()
const expectedNetwork = "udp"
const mockedEndpoint = "8.8.4.4:443"
saver := NewSaver()
v := &SingleQUICTLSHandshakeValidator{
ExpectedALPN: []string{"h3"},
ExpectedSNI: "dns.google",
ExpectedSkipVerify: true,
//
ExpectedCipherSuite: 0,
ExpectedNegotiatedProtocol: "",
ExpectedPeerCerts: nil,
ExpectedVersion: 0,
//
ExpectedNetwork: "quic",
ExpectedRemoteAddr: mockedEndpoint,
//
QUICConfig: &quic.Config{},
//
ExpectedFailure: context.DeadlineExceeded,
Saver: saver,
}
qconn := newQUICConnection(handshakeCtx, tls.ConnectionState{})
dialer := newQUICDialer(qconn, nil)
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Microsecond)
defer cancel()
qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork,
mockedEndpoint, v.NewTLSConfig(), v.QUICConfig)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("unexpected error")
}
if qconn != nil {
t.Fatal("expected nil connection")
}
if err := v.Validate(); err != nil {
t.Fatal(err)
}
})
t.Run("on other error", func(t *testing.T) {
mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF)
const expectedNetwork = "udp"

View File

@ -118,12 +118,8 @@ func (qh *quicDialerDB) DialContext(ctx context.Context, network, address string
defer dialer.CloseIdleConnections()
sess, err := dialer.DialContext(ctx, network, address, tlsConfig, quicConfig)
if err == nil {
select {
case <-sess.HandshakeComplete().Done():
state = sess.ConnectionState().TLS.ConnectionState
case <-ctx.Done():
sess, err = nil, ctx.Err()
}
<-sess.HandshakeComplete().Done() // robustness (the dialer already does that)
state = sess.ConnectionState().TLS.ConnectionState
}
finished := time.Since(qh.begin).Seconds()
qh.db.InsertIntoQUICHandshake(&QUICTLSHandshakeEvent{

View File

@ -55,9 +55,12 @@ func NewQUICDialerWithResolver(listener model.QUICListener,
Dialer: &quicDialerResolver{
Dialer: &quicDialerLogger{
Dialer: &quicDialerErrWrapper{
QUICDialer: &quicDialerQUICGo{
QUICListener: listener,
}},
QUICDialer: &quicDialerHandshakeCompleter{
Dialer: &quicDialerQUICGo{
QUICListener: listener,
},
},
},
Logger: logger,
operationSuffix: "_address",
},
@ -164,6 +167,33 @@ func (d *quicDialerQUICGo) CloseIdleConnections() {
// nothing to do
}
// quicDialerHandshakeCompleter ensures we complete the handshake.
type quicDialerHandshakeCompleter struct {
Dialer model.QUICDialer
}
// DialContext implements model.QUICDialer.DialContext.
func (d *quicDialerHandshakeCompleter) DialContext(
ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
conn, err := d.Dialer.DialContext(ctx, network, address, tlsConfig, quicConfig)
if err != nil {
return nil, err
}
select {
case <-conn.HandshakeComplete().Done():
return conn, nil
case <-ctx.Done():
conn.CloseWithError(0, "") // we own the conn
return nil, ctx.Err()
}
}
// CloseIdleConnections implements model.QUICDialer.CloseIdleConnections.
func (d *quicDialerHandshakeCompleter) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// quicConnectionOwnsConn ensures that we close the UDPLikeConn.
type quicConnectionOwnsConn struct {
// EarlyConnection is the embedded early connection

View File

@ -38,7 +38,8 @@ func TestNewQUICDialer(t *testing.T) {
t.Fatal("invalid logger")
}
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
handshakeCompleter := errWrapper.QUICDialer.(*quicDialerHandshakeCompleter)
base := handshakeCompleter.Dialer.(*quicDialerQUICGo)
if base.QUICListener != ql {
t.Fatal("invalid quic listener")
}
@ -227,6 +228,107 @@ func TestQUICDialerQUICGo(t *testing.T) {
})
}
func TestQUICDialerHandshakeCompleter(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("in case of failure", func(t *testing.T) {
expected := errors.New("mocked error")
d := &quicDialerHandshakeCompleter{
Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return nil, expected
},
},
}
ctx := context.Background()
conn, err := d.DialContext(ctx, "udp", "8.8.8.8:443", &tls.Config{}, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("in case of context cancellation", func(t *testing.T) {
handshakeCtx, handshakeCancel := context.WithCancel(context.Background())
defer handshakeCancel()
ctx, cancel := context.WithCancel(context.Background())
var called bool
expected := &mocks.QUICEarlyConnection{
MockHandshakeComplete: func() context.Context {
cancel()
return handshakeCtx
},
MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error {
called = true
return nil
},
}
d := &quicDialerHandshakeCompleter{
Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return expected, nil
},
},
}
conn, err := d.DialContext(ctx, "udp", "8.8.8.8:443", &tls.Config{}, &quic.Config{})
if !errors.Is(err, context.Canceled) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !called {
t.Fatal("not called")
}
})
t.Run("in case of success", func(t *testing.T) {
handshakeCtx, handshakeCancel := context.WithCancel(context.Background())
defer handshakeCancel()
expected := &mocks.QUICEarlyConnection{
MockHandshakeComplete: func() context.Context {
handshakeCancel()
return handshakeCtx
},
}
d := &quicDialerHandshakeCompleter{
Dialer: &mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return expected, nil
},
},
}
conn, err := d.DialContext(
context.Background(), "udp", "8.8.8.8:443", &tls.Config{}, &quic.Config{})
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var forDialer bool
d := &quicDialerHandshakeCompleter{
Dialer: &mocks.QUICDialer{
MockCloseIdleConnections: func() {
forDialer = true
},
},
}
d.CloseIdleConnections()
if !forDialer {
t.Fatal("not called")
}
})
}
func TestQUICDialerResolver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var (