diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index dace1e9..111359d 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -160,14 +160,14 @@ func NewQUICDialer(config Config) QUICDialer { if config.FullResolver == nil { config.FullResolver = NewResolver(config) } - var ql quicdialer.QUICListener = &quicdialer.QUICListenerStdlib{} + var ql quicdialer.QUICListener = &netxlite.QUICListenerStdlib{} if config.ReadWriteSaver != nil { ql = &quicdialer.QUICListenerSaver{ QUICListener: ql, Saver: config.ReadWriteSaver, } } - var d quicdialer.ContextDialer = &quicdialer.SystemDialer{ + var d quicdialer.ContextDialer = &netxlite.QUICDialerQUICGo{ QUICListener: ql, } d = quicdialer.ErrorWrapperDialer{Dialer: d} diff --git a/internal/engine/netx/quicdialer/dns_test.go b/internal/engine/netx/quicdialer/dns_test.go index c62cf2a..9d485fb 100644 --- a/internal/engine/netx/quicdialer/dns_test.go +++ b/internal/engine/netx/quicdialer/dns_test.go @@ -11,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) type MockableResolver struct { @@ -25,8 +26,8 @@ func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string func TestDNSDialerSuccess(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + Resolver: new(net.Resolver), Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com:443", @@ -42,7 +43,7 @@ func TestDNSDialerSuccess(t *testing.T) { func TestDNSDialerNoPort(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}} + Resolver: new(net.Resolver), Dialer: &netxlite.QUICDialerQUICGo{}} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com", tlsConf, &quic.Config{}) @@ -90,8 +91,8 @@ func TestDNSDialerLookupHostFailure(t *testing.T) { func TestDNSDialerInvalidPort(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + Resolver: new(net.Resolver), Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com:0", @@ -111,8 +112,8 @@ func TestDNSDialerInvalidPort(t *testing.T) { func TestDNSDialerInvalidPortSyntax(t *testing.T) { tlsConf := &tls.Config{NextProtos: []string{"h3"}} dialer := quicdialer.DNSDialer{ - Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + Resolver: new(net.Resolver), Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com:port", diff --git a/internal/engine/netx/quicdialer/errorwrapper_test.go b/internal/engine/netx/quicdialer/errorwrapper_test.go index 10c763f..21896b8 100644 --- a/internal/engine/netx/quicdialer/errorwrapper_test.go +++ b/internal/engine/netx/quicdialer/errorwrapper_test.go @@ -10,6 +10,7 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) func TestErrorWrapperFailure(t *testing.T) { @@ -48,8 +49,8 @@ func TestErrorWrapperInvalidCertificate(t *testing.T) { ServerName: servername, } - dlr := quicdialer.ErrorWrapperDialer{Dialer: &quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + dlr := quicdialer.ErrorWrapperDialer{Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }} // use Google IP sess, err := dlr.DialContext(context.Background(), "udp", @@ -71,8 +72,8 @@ func TestErrorWrapperSuccess(t *testing.T) { NextProtos: []string{"h3"}, ServerName: "www.google.com", } - d := quicdialer.ErrorWrapperDialer{Dialer: quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + d := quicdialer.ErrorWrapperDialer{Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }} sess, err := d.DialContext(ctx, "udp", "216.58.212.164:443", tlsConf, &quic.Config{}) if err != nil { diff --git a/internal/engine/netx/quicdialer/saver_test.go b/internal/engine/netx/quicdialer/saver_test.go index 4f8b775..89912e0 100644 --- a/internal/engine/netx/quicdialer/saver_test.go +++ b/internal/engine/netx/quicdialer/saver_test.go @@ -11,6 +11,7 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) type MockDialer struct { @@ -36,8 +37,8 @@ func TestHandshakeSaverSuccess(t *testing.T) { } saver := &trace.Saver{} dlr := quicdialer.HandshakeSaver{ - Dialer: quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }, Saver: saver, } @@ -94,8 +95,8 @@ func TestHandshakeSaverHostNameError(t *testing.T) { } saver := &trace.Saver{} dlr := quicdialer.HandshakeSaver{ - Dialer: quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + Dialer: &netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, }, Saver: saver, } diff --git a/internal/engine/netx/quicdialer/system.go b/internal/engine/netx/quicdialer/system.go index cabfde1..392378f 100644 --- a/internal/engine/netx/quicdialer/system.go +++ b/internal/engine/netx/quicdialer/system.go @@ -1,14 +1,10 @@ package quicdialer import ( - "context" - "crypto/tls" "errors" "net" - "strconv" "time" - "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) @@ -19,14 +15,6 @@ type QUICListener interface { Listen(addr *net.UDPAddr) (net.PacketConn, error) } -// QUICListenerStdlib is a QUICListener using the standard library. -type QUICListenerStdlib struct{} - -// Listen implements QUICListener.Listen. -func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error) { - return net.ListenUDP("udp", addr) -} - // QUICListenerSaver is a QUICListener that also implements saving events. type QUICListenerSaver struct { // QUICListener is the underlying QUICListener. @@ -50,35 +38,6 @@ func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (net.PacketConn, error) return saverUDPConn{UDPConn: udpConn, saver: qls.Saver}, nil } -// SystemDialer is the basic dialer for QUIC -type SystemDialer struct { - // QUICListener is the underlying QUICListener to use. - QUICListener QUICListener -} - -// DialContext implements ContextDialer.DialContext -func (d SystemDialer) DialContext(ctx context.Context, network string, - host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { - onlyhost, onlyport, err := net.SplitHostPort(host) - if err != nil { - return nil, err - } - port, err := strconv.Atoi(onlyport) - if err != nil { - return nil, err - } - ip := net.ParseIP(onlyhost) - if ip == nil { - return nil, errors.New("quicdialer: invalid IP representation") - } - pconn, err := d.QUICListener.Listen(&net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""} - return quic.DialEarlyContext(ctx, pconn, udpAddr, host, tlsCfg, cfg) -} - type saverUDPConn struct { *net.UDPConn saver *trace.Saver diff --git a/internal/engine/netx/quicdialer/system_test.go b/internal/engine/netx/quicdialer/system_test.go index cc4d664..1ea2440 100644 --- a/internal/engine/netx/quicdialer/system_test.go +++ b/internal/engine/netx/quicdialer/system_test.go @@ -9,32 +9,9 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestSystemDialerInvalidIPFailure(t *testing.T) { - tlsConf := &tls.Config{ - NextProtos: []string{"h3"}, - ServerName: "www.google.com", - } - saver := &trace.Saver{} - systemdialer := quicdialer.SystemDialer{ - QUICListener: &quicdialer.QUICListenerSaver{ - QUICListener: &quicdialer.QUICListenerStdlib{}, - Saver: saver, - }, - } - sess, err := systemdialer.DialContext(context.Background(), "udp", "a.b.c.d:0", tlsConf, &quic.Config{}) - if err == nil { - t.Fatal("expected an error here") - } - if sess != nil { - t.Fatal("expected nil sess here") - } - if err.Error() != "quicdialer: invalid IP representation" { - t.Fatal("expected another error here") - } -} - func TestSystemDialerSuccessWithReadWrite(t *testing.T) { // This is the most common use case for collecting reads, writes tlsConf := &tls.Config{ @@ -42,9 +19,9 @@ func TestSystemDialerSuccessWithReadWrite(t *testing.T) { ServerName: "www.google.com", } saver := &trace.Saver{} - systemdialer := quicdialer.SystemDialer{ + systemdialer := &netxlite.QUICDialerQUICGo{ QUICListener: &quicdialer.QUICListenerSaver{ - QUICListener: &quicdialer.QUICListenerStdlib{}, + QUICListener: &netxlite.QUICListenerStdlib{}, Saver: saver, }, } diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go new file mode 100644 index 0000000..5ae1749 --- /dev/null +++ b/internal/netxlite/quic.go @@ -0,0 +1,81 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "errors" + "net" + "strconv" + + "github.com/lucas-clemente/quic-go" +) + +// QUICDialerContext is a dialer for QUIC using Context. +type QUICContextDialer interface { + // DialContext establishes a new QUIC session using the given + // network and address. The tlsConfig and the quicConfig arguments + // MUST NOT be nil. Returns either the session or an error. + DialContext(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) +} + +// QUICDialer dials QUIC connections. +type QUICDialer interface { + // DialContext establishes a new QUIC session using the given + // network and address. The tlsConfig and the quicConfig arguments + // MUST NOT be nil. Returns either the session or an error. + Dial(network, address string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlySession, error) +} + +// QUICListener listens for QUIC connections. +type QUICListener interface { + // Listen creates a new listening net.PacketConn. + Listen(addr *net.UDPAddr) (net.PacketConn, error) +} + +// QUICListenerStdlib is a QUICListener using the standard library. +type QUICListenerStdlib struct{} + +var _ QUICListener = &QUICListenerStdlib{} + +// Listen implements QUICListener.Listen. +func (qls *QUICListenerStdlib) Listen(addr *net.UDPAddr) (net.PacketConn, error) { + return net.ListenUDP("udp", addr) +} + +// QUICDialerQUICGo dials using the lucas-clemente/quic-go library. +type QUICDialerQUICGo struct { + // QUICListener is the underlying QUICListener to use. + QUICListener QUICListener +} + +var _ QUICContextDialer = &QUICDialerQUICGo{} + +// errInvalidIP indicates that a string is not a valid IP. +var errInvalidIP = errors.New("netxlite: invalid IP") + +// DialContext implements ContextDialer.DialContext +func (d *QUICDialerQUICGo) DialContext(ctx context.Context, network string, + address string, tlsConfig *tls.Config, quicConfig *quic.Config) ( + quic.EarlySession, error) { + onlyhost, onlyport, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(onlyport) + if err != nil { + return nil, err + } + ip := net.ParseIP(onlyhost) + if ip == nil { + return nil, errInvalidIP + } + pconn, err := d.QUICListener.Listen(&net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return nil, err + } + udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""} + return quic.DialEarlyContext( + ctx, pconn, udpAddr, address, tlsConfig, quicConfig) +} diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go new file mode 100644 index 0000000..4d6e2b2 --- /dev/null +++ b/internal/netxlite/quic_test.go @@ -0,0 +1,114 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "errors" + "log" + "net" + "strings" + "testing" + + "github.com/lucas-clemente/quic-go" + "github.com/ooni/probe-cli/v3/internal/netxmocks" +) + +func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) { + tlsConfig := &tls.Config{ + NextProtos: []string{"h3"}, + ServerName: "www.google.com", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &QUICListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "a.b.c.d", tlsConfig, &quic.Config{}) + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } +} + +func TestQUICDialerQUICGoInvalidPort(t *testing.T) { + tlsConfig := &tls.Config{ + NextProtos: []string{"h3"}, + ServerName: "www.google.com", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &QUICListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.4.4:xyz", tlsConfig, &quic.Config{}) + if err == nil || !strings.HasSuffix(err.Error(), "invalid syntax") { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } +} + +func TestQUICDialerQUICGoInvalidIP(t *testing.T) { + tlsConfig := &tls.Config{ + NextProtos: []string{"h3"}, + ServerName: "www.google.com", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &QUICListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "a.b.c.d:0", tlsConfig, &quic.Config{}) + if !errors.Is(err, errInvalidIP) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } +} + +func TestQUICDialerQUICGoCannotListen(t *testing.T) { + expected := errors.New("mocked error") + tlsConfig := &tls.Config{ + NextProtos: []string{"h3"}, + ServerName: "www.google.com", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &netxmocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (net.PacketConn, error) { + return nil, expected + }, + }, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } +} + +func TestQUICDialerWorksAsIntended(t *testing.T) { + tlsConfig := &tls.Config{ + NextProtos: []string{"h3"}, + ServerName: "dns.google", + } + systemdialer := QUICDialerQUICGo{ + QUICListener: &QUICListenerStdlib{}, + } + ctx := context.Background() + sess, err := systemdialer.DialContext( + ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) + if err != nil { + t.Fatal("not the error we expected", err) + } + if err := sess.CloseWithError(0, ""); err != nil { + log.Fatal(err) + } +} diff --git a/internal/netxmocks/quic.go b/internal/netxmocks/quic.go new file mode 100644 index 0000000..a0238ec --- /dev/null +++ b/internal/netxmocks/quic.go @@ -0,0 +1,13 @@ +package netxmocks + +import "net" + +// QUICListener is a mockable netxlite.QUICListener. +type QUICListener struct { + MockListen func(addr *net.UDPAddr) (net.PacketConn, error) +} + +// Listen calls MockListen. +func (ql *QUICListener) Listen(addr *net.UDPAddr) (net.PacketConn, error) { + return ql.MockListen(addr) +} diff --git a/internal/netxmocks/quic_test.go b/internal/netxmocks/quic_test.go new file mode 100644 index 0000000..fbddfb7 --- /dev/null +++ b/internal/netxmocks/quic_test.go @@ -0,0 +1,23 @@ +package netxmocks + +import ( + "errors" + "net" + "testing" +) + +func TestQUICListenerListen(t *testing.T) { + expected := errors.New("mocked error") + ql := &QUICListener{ + MockListen: func(addr *net.UDPAddr) (net.PacketConn, error) { + return nil, expected + }, + } + pconn, err := ql.Listen(&net.UDPAddr{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", expected) + } + if pconn != nil { + t.Fatal("expected nil conn here") + } +}