diff --git a/internal/archival/quic.go b/internal/archival/quic.go index 5e34c69..3a56976 100644 --- a/internal/archival/quic.go +++ b/internal/archival/quic.go @@ -64,12 +64,8 @@ func (s *Saver) QUICDialContext(ctx context.Context, dialer model.QUICDialer, var state tls.ConnectionState sess, err := dialer.DialContext(ctx, network, address, tlsConfig, quicConfig) if err == nil { - select { - case <-sess.HandshakeComplete().Done(): - state = sess.ConnectionState().TLS.ConnectionState - case <-ctx.Done(): - sess, err = nil, ctx.Err() - } + <-sess.HandshakeComplete().Done() // robustness (the dialer already does that) + state = sess.ConnectionState().TLS.ConnectionState } s.appendQUICHandshake(&QUICTLSHandshakeEvent{ ALPN: tlsConfig.NextProtos, diff --git a/internal/archival/quic_test.go b/internal/archival/quic_test.go index caabccf..4f9eae1 100644 --- a/internal/archival/quic_test.go +++ b/internal/archival/quic_test.go @@ -264,49 +264,6 @@ func TestSaverQUICDialContext(t *testing.T) { } }) - t.Run("on handshake timeout", func(t *testing.T) { - handshakeCtx := context.Background() - handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx) - defer handshakeCancel() - const expectedNetwork = "udp" - const mockedEndpoint = "8.8.4.4:443" - saver := NewSaver() - v := &SingleQUICTLSHandshakeValidator{ - ExpectedALPN: []string{"h3"}, - ExpectedSNI: "dns.google", - ExpectedSkipVerify: true, - // - ExpectedCipherSuite: 0, - ExpectedNegotiatedProtocol: "", - ExpectedPeerCerts: nil, - ExpectedVersion: 0, - // - ExpectedNetwork: "quic", - ExpectedRemoteAddr: mockedEndpoint, - // - QUICConfig: &quic.Config{}, - // - ExpectedFailure: context.DeadlineExceeded, - Saver: saver, - } - qconn := newQUICConnection(handshakeCtx, tls.ConnectionState{}) - dialer := newQUICDialer(qconn, nil) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Microsecond) - defer cancel() - qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork, - mockedEndpoint, v.NewTLSConfig(), v.QUICConfig) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("unexpected error") - } - if qconn != nil { - t.Fatal("expected nil connection") - } - if err := v.Validate(); err != nil { - t.Fatal(err) - } - }) - t.Run("on other error", func(t *testing.T) { mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF) const expectedNetwork = "udp" diff --git a/internal/archival/quic_test_go118.go b/internal/archival/quic_test_go118.go index 77a5ca3..3810913 100644 --- a/internal/archival/quic_test_go118.go +++ b/internal/archival/quic_test_go118.go @@ -264,49 +264,6 @@ func TestSaverQUICDialContext(t *testing.T) { } }) - t.Run("on handshake timeout", func(t *testing.T) { - handshakeCtx := context.Background() - handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx) - defer handshakeCancel() - const expectedNetwork = "udp" - const mockedEndpoint = "8.8.4.4:443" - saver := NewSaver() - v := &SingleQUICTLSHandshakeValidator{ - ExpectedALPN: []string{"h3"}, - ExpectedSNI: "dns.google", - ExpectedSkipVerify: true, - // - ExpectedCipherSuite: 0, - ExpectedNegotiatedProtocol: "", - ExpectedPeerCerts: nil, - ExpectedVersion: 0, - // - ExpectedNetwork: "quic", - ExpectedRemoteAddr: mockedEndpoint, - // - QUICConfig: &quic.Config{}, - // - ExpectedFailure: context.DeadlineExceeded, - Saver: saver, - } - qconn := newQUICConnection(handshakeCtx, tls.ConnectionState{}) - dialer := newQUICDialer(qconn, nil) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Microsecond) - defer cancel() - qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork, - mockedEndpoint, v.NewTLSConfig(), v.QUICConfig) - if !errors.Is(err, context.DeadlineExceeded) { - t.Fatal("unexpected error") - } - if qconn != nil { - t.Fatal("expected nil connection") - } - if err := v.Validate(); err != nil { - t.Fatal(err) - } - }) - t.Run("on other error", func(t *testing.T) { mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF) const expectedNetwork = "udp" diff --git a/internal/measurex/quic.go b/internal/measurex/quic.go index f03a887..64b6630 100644 --- a/internal/measurex/quic.go +++ b/internal/measurex/quic.go @@ -118,12 +118,8 @@ func (qh *quicDialerDB) DialContext(ctx context.Context, network, address string defer dialer.CloseIdleConnections() sess, err := dialer.DialContext(ctx, network, address, tlsConfig, quicConfig) if err == nil { - select { - case <-sess.HandshakeComplete().Done(): - state = sess.ConnectionState().TLS.ConnectionState - case <-ctx.Done(): - sess, err = nil, ctx.Err() - } + <-sess.HandshakeComplete().Done() // robustness (the dialer already does that) + state = sess.ConnectionState().TLS.ConnectionState } finished := time.Since(qh.begin).Seconds() qh.db.InsertIntoQUICHandshake(&QUICTLSHandshakeEvent{ diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index a8ed918..a6d1ec1 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -55,9 +55,12 @@ func NewQUICDialerWithResolver(listener model.QUICListener, Dialer: &quicDialerResolver{ Dialer: &quicDialerLogger{ Dialer: &quicDialerErrWrapper{ - QUICDialer: &quicDialerQUICGo{ - QUICListener: listener, - }}, + QUICDialer: &quicDialerHandshakeCompleter{ + Dialer: &quicDialerQUICGo{ + QUICListener: listener, + }, + }, + }, Logger: logger, operationSuffix: "_address", }, @@ -164,6 +167,33 @@ func (d *quicDialerQUICGo) CloseIdleConnections() { // nothing to do } +// quicDialerHandshakeCompleter ensures we complete the handshake. +type quicDialerHandshakeCompleter struct { + Dialer model.QUICDialer +} + +// DialContext implements model.QUICDialer.DialContext. +func (d *quicDialerHandshakeCompleter) DialContext( + ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + conn, err := d.Dialer.DialContext(ctx, network, address, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + select { + case <-conn.HandshakeComplete().Done(): + return conn, nil + case <-ctx.Done(): + conn.CloseWithError(0, "") // we own the conn + return nil, ctx.Err() + } +} + +// CloseIdleConnections implements model.QUICDialer.CloseIdleConnections. +func (d *quicDialerHandshakeCompleter) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} + // quicConnectionOwnsConn ensures that we close the UDPLikeConn. type quicConnectionOwnsConn struct { // EarlyConnection is the embedded early connection diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index cbd0495..b512e92 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -38,7 +38,8 @@ func TestNewQUICDialer(t *testing.T) { t.Fatal("invalid logger") } errWrapper := logger.Dialer.(*quicDialerErrWrapper) - base := errWrapper.QUICDialer.(*quicDialerQUICGo) + handshakeCompleter := errWrapper.QUICDialer.(*quicDialerHandshakeCompleter) + base := handshakeCompleter.Dialer.(*quicDialerQUICGo) if base.QUICListener != ql { t.Fatal("invalid quic listener") } @@ -227,6 +228,107 @@ func TestQUICDialerQUICGo(t *testing.T) { }) } +func TestQUICDialerHandshakeCompleter(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("in case of failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &quicDialerHandshakeCompleter{ + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return nil, expected + }, + }, + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "udp", "8.8.8.8:443", &tls.Config{}, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("in case of context cancellation", func(t *testing.T) { + handshakeCtx, handshakeCancel := context.WithCancel(context.Background()) + defer handshakeCancel() + ctx, cancel := context.WithCancel(context.Background()) + var called bool + expected := &mocks.QUICEarlyConnection{ + MockHandshakeComplete: func() context.Context { + cancel() + return handshakeCtx + }, + MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error { + called = true + return nil + }, + } + d := &quicDialerHandshakeCompleter{ + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return expected, nil + }, + }, + } + conn, err := d.DialContext(ctx, "udp", "8.8.8.8:443", &tls.Config{}, &quic.Config{}) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + if !called { + t.Fatal("not called") + } + }) + + t.Run("in case of success", func(t *testing.T) { + handshakeCtx, handshakeCancel := context.WithCancel(context.Background()) + defer handshakeCancel() + expected := &mocks.QUICEarlyConnection{ + MockHandshakeComplete: func() context.Context { + handshakeCancel() + return handshakeCtx + }, + } + d := &quicDialerHandshakeCompleter{ + Dialer: &mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return expected, nil + }, + }, + } + conn, err := d.DialContext( + context.Background(), "udp", "8.8.8.8:443", &tls.Config{}, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var forDialer bool + d := &quicDialerHandshakeCompleter{ + Dialer: &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + forDialer = true + }, + }, + } + d.CloseIdleConnections() + if !forDialer { + t.Fatal("not called") + } + }) +} + func TestQUICDialerResolver(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) { var (