From c00cad1382e70a7a5fba7b056e15df7d37c63c71 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Fri, 25 Jun 2021 16:20:08 +0200 Subject: [PATCH] refactor(quicdialer): separate saving from listening (#405) With this change, we will soon be able to move the creation of a QUIC session inside of the netxlite package. Part of https://github.com/ooni/probe/issues/1505. --- internal/engine/netx/netx.go | 11 +++- internal/engine/netx/quicdialer/dns_test.go | 12 +++-- .../netx/quicdialer/errorwrapper_test.go | 8 ++- internal/engine/netx/quicdialer/saver_test.go | 12 +++-- internal/engine/netx/quicdialer/system.go | 52 ++++++++++++++----- .../engine/netx/quicdialer/system_test.go | 12 ++++- 6 files changed, 83 insertions(+), 24 deletions(-) diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index ced5f12..dace1e9 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -160,7 +160,16 @@ func NewQUICDialer(config Config) QUICDialer { if config.FullResolver == nil { config.FullResolver = NewResolver(config) } - var d quicdialer.ContextDialer = &quicdialer.SystemDialer{Saver: config.ReadWriteSaver} + var ql quicdialer.QUICListener = &quicdialer.QUICListenerStdlib{} + if config.ReadWriteSaver != nil { + ql = &quicdialer.QUICListenerSaver{ + QUICListener: ql, + Saver: config.ReadWriteSaver, + } + } + var d quicdialer.ContextDialer = &quicdialer.SystemDialer{ + QUICListener: ql, + } d = quicdialer.ErrorWrapperDialer{Dialer: d} if config.TLSSaver != nil { d = quicdialer.HandshakeSaver{Saver: config.TLSSaver, Dialer: d} diff --git a/internal/engine/netx/quicdialer/dns_test.go b/internal/engine/netx/quicdialer/dns_test.go index a35c078..c62cf2a 100644 --- a/internal/engine/netx/quicdialer/dns_test.go +++ b/internal/engine/netx/quicdialer/dns_test.go @@ -25,7 +25,9 @@ func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string func TestDNSDialerSuccess(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}} + Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com:443", tlsConf, &quic.Config{}) @@ -88,7 +90,9 @@ func TestDNSDialerLookupHostFailure(t *testing.T) { func TestDNSDialerInvalidPort(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}} + Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com:0", tlsConf, &quic.Config{}) @@ -107,7 +111,9 @@ func TestDNSDialerInvalidPort(t *testing.T) { func TestDNSDialerInvalidPortSyntax(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}} + Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com:port", tlsConf, &quic.Config{}) diff --git a/internal/engine/netx/quicdialer/errorwrapper_test.go b/internal/engine/netx/quicdialer/errorwrapper_test.go index fe360b7..10c763f 100644 --- a/internal/engine/netx/quicdialer/errorwrapper_test.go +++ b/internal/engine/netx/quicdialer/errorwrapper_test.go @@ -48,7 +48,9 @@ func TestErrorWrapperInvalidCertificate(t *testing.T) { ServerName: servername, } - dlr := quicdialer.ErrorWrapperDialer{Dialer: &quicdialer.SystemDialer{}} + dlr := quicdialer.ErrorWrapperDialer{Dialer: &quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }} // use Google IP sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{}) @@ -69,7 +71,9 @@ func TestErrorWrapperSuccess(t *testing.T) { NextProtos: []string{"h3"}, ServerName: "www.google.com", } - d := quicdialer.ErrorWrapperDialer{Dialer: quicdialer.SystemDialer{}} + d := quicdialer.ErrorWrapperDialer{Dialer: quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }} sess, err := d.DialContext(ctx, "udp", "216.58.212.164:443", tlsConf, &quic.Config{}) if err != nil { t.Fatal(err) diff --git a/internal/engine/netx/quicdialer/saver_test.go b/internal/engine/netx/quicdialer/saver_test.go index 1e93657..4f8b775 100644 --- a/internal/engine/netx/quicdialer/saver_test.go +++ b/internal/engine/netx/quicdialer/saver_test.go @@ -36,8 +36,10 @@ func TestHandshakeSaverSuccess(t *testing.T) { } saver := &trace.Saver{} dlr := quicdialer.HandshakeSaver{ - Dialer: quicdialer.SystemDialer{}, - Saver: saver, + Dialer: quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }, + Saver: saver, } sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{}) @@ -92,8 +94,10 @@ func TestHandshakeSaverHostNameError(t *testing.T) { } saver := &trace.Saver{} dlr := quicdialer.HandshakeSaver{ - Dialer: quicdialer.SystemDialer{}, - Saver: saver, + Dialer: quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + }, + Saver: saver, } sess, err := dlr.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{}) diff --git a/internal/engine/netx/quicdialer/system.go b/internal/engine/netx/quicdialer/system.go index 702f7dd..cabfde1 100644 --- a/internal/engine/netx/quicdialer/system.go +++ b/internal/engine/netx/quicdialer/system.go @@ -13,13 +13,47 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) +// QUICListener listens for QUIC connections. +type QUICListener interface { + // Listen creates a new listening net.PacketConn. + Listen(addr *net.UDPAddr) (net.PacketConn, error) +} + +// QUICListenerStdlib is a QUICListener using the standard library. +type QUICListenerStdlib struct{} + +// Listen implements QUICListener.Listen. +func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error) { + return net.ListenUDP("udp", addr) +} + +// QUICListenerSaver is a QUICListener that also implements saving events. +type QUICListenerSaver struct { + // QUICListener is the underlying QUICListener. + QUICListener QUICListener + + // Saver is the underlying Saver. + Saver *trace.Saver +} + +// Listen implements QUICListener.Listen. +func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (net.PacketConn, error) { + pconn, err := qls.QUICListener.Listen(addr) + if err != nil { + return nil, err + } + // TODO(bassosimone): refactor to remove this restriction. + udpConn, ok := pconn.(*net.UDPConn) + if !ok { + return nil, errors.New("quicdialer: cannot convert to udpConn") + } + return saverUDPConn{UDPConn: udpConn, saver: qls.Saver}, nil +} + // SystemDialer is the basic dialer for QUIC type SystemDialer struct { - // Saver saves read/write events on the underlying UDP - // connection. (Implementation note: we need it here since - // this is the only part in the codebase that is able to - // observe the underlying UDP connection.) - Saver *trace.Saver + // QUICListener is the underlying QUICListener to use. + QUICListener QUICListener } // DialContext implements ContextDialer.DialContext @@ -35,20 +69,14 @@ func (d SystemDialer) DialContext(ctx context.Context, network string, } ip := net.ParseIP(onlyhost) if ip == nil { - // TODO(kelmenhorst): write test for this error condition. return nil, errors.New("quicdialer: invalid IP representation") } - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + pconn, err := d.QUICListener.Listen(&net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { return nil, err } - var pconn net.PacketConn = udpConn - if d.Saver != nil { - pconn = saverUDPConn{UDPConn: udpConn, saver: d.Saver} - } udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""} return quic.DialEarlyContext(ctx, pconn, udpAddr, host, tlsCfg, cfg) - } type saverUDPConn struct { diff --git a/internal/engine/netx/quicdialer/system_test.go b/internal/engine/netx/quicdialer/system_test.go index 309e8cf..cc4d664 100644 --- a/internal/engine/netx/quicdialer/system_test.go +++ b/internal/engine/netx/quicdialer/system_test.go @@ -18,7 +18,10 @@ func TestSystemDialerInvalidIPFailure(t *testing.T) { } saver := &trace.Saver{} systemdialer := quicdialer.SystemDialer{ - Saver: saver, + QUICListener: &quicdialer.QUICListenerSaver{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + Saver: saver, + }, } sess, err := systemdialer.DialContext(context.Background(), "udp", "a.b.c.d:0", tlsConf, &quic.Config{}) if err == nil { @@ -39,7 +42,12 @@ func TestSystemDialerSuccessWithReadWrite(t *testing.T) { ServerName: "www.google.com", } saver := &trace.Saver{} - systemdialer := quicdialer.SystemDialer{Saver: saver} + systemdialer := quicdialer.SystemDialer{ + QUICListener: &quicdialer.QUICListenerSaver{ + QUICListener: &quicdialer.QUICListenerStdlib{}, + Saver: saver, + }, + } _, err := systemdialer.DialContext(context.Background(), "udp", "216.58.212.164:443", tlsConf, &quic.Config{}) if err != nil {