diff --git a/internal/netxlite/dialer.go b/internal/netxlite/dialer.go index 0489a58..dd734ab 100644 --- a/internal/netxlite/dialer.go +++ b/internal/netxlite/dialer.go @@ -164,7 +164,7 @@ func NewSingleUseDialer(conn net.Conn) Dialer { return &dialerSingleUse{conn: conn} } -// dialerSingleUse is the type of Dialer returned by NewSingleDialer. +// dialerSingleUse is the Dialer returned by NewSingleDialer. type dialerSingleUse struct { sync.Mutex conn net.Conn diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index abfec10..9802863 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -6,6 +6,7 @@ import ( "errors" "net" "strconv" + "sync" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/netxlite/quicx" @@ -288,3 +289,36 @@ func (d *quicDialerLogger) DialContext( func (d *quicDialerLogger) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } + +// NewSingleUseQUICDialer returns a dialer that returns the given connection +// once and after that always fails with the ErrNoConnReuse error. +func NewSingleUseQUICDialer(sess quic.EarlySession) QUICDialer { + return &quicDialerSingleUse{sess: sess} +} + +// quicDialerSingleUse is the QUICDialer returned by NewSingleQUICDialer. +type quicDialerSingleUse struct { + sync.Mutex + sess quic.EarlySession +} + +var _ QUICDialer = &quicDialerSingleUse{} + +// DialContext implements QUICDialer.DialContext. +func (s *quicDialerSingleUse) DialContext( + ctx context.Context, network, addr string, tlsCfg *tls.Config, + cfg *quic.Config) (quic.EarlySession, error) { + var sess quic.EarlySession + defer s.Unlock() + s.Lock() + if s.sess == nil { + return nil, ErrNoConnReuse + } + sess, s.sess = s.sess, nil + return sess, nil +} + +// CloseIdleConnections closes idle connections. +func (s *quicDialerSingleUse) CloseIdleConnections() { + // nothing to do +} diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index 470ac2f..8290a30 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -460,3 +460,26 @@ func TestNewQUICDialerWithoutResolverChain(t *testing.T) { t.Fatal("invalid quic listener") } } + +func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) { + sess := &mocks.QUICEarlySession{} + qd := NewSingleUseQUICDialer(sess) + outsess, err := qd.DialContext( + context.Background(), "", "", &tls.Config{}, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + if sess != outsess { + t.Fatal("invalid outsess") + } + for i := 0; i < 4; i++ { + outsess, err = qd.DialContext( + context.Background(), "", "", &tls.Config{}, &quic.Config{}) + if !errors.Is(err, ErrNoConnReuse) { + t.Fatal("not the error we expected", err) + } + if outsess != nil { + t.Fatal("expected nil outconn here") + } + } +}