diff --git a/internal/netxlite/quic.go b/internal/netxlite/quic.go index a6d1ec1..86a44e6 100644 --- a/internal/netxlite/quic.go +++ b/internal/netxlite/quic.go @@ -94,6 +94,29 @@ var _ model.QUICDialer = &quicDialerQUICGo{} // ErrInvalidIP indicates that a string is not a valid IP. var ErrInvalidIP = errors.New("netxlite: invalid IP") +// ParseUDPAddr maps the string representation of an UDP endpoint to the +// corresponding *net.UDPAddr representation. +func ParseUDPAddr(address string) (*net.UDPAddr, error) { + addr, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + ipAddr := net.ParseIP(addr) + if ipAddr == nil { + return nil, ErrInvalidIP + } + dport, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + udpAddr := &net.UDPAddr{ + IP: ipAddr, + Port: dport, + Zone: "", + } + return udpAddr, nil +} + // DialContext implements QUICDialer.DialContext. This function will // apply the following TLS defaults: // @@ -105,24 +128,15 @@ var ErrInvalidIP = errors.New("netxlite: invalid IP") func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string, address string, tlsConfig *tls.Config, quicConfig *quic.Config) ( quic.EarlyConnection, error) { - onlyhost, onlyport, err := net.SplitHostPort(address) + udpAddr, err := ParseUDPAddr(address) if err != nil { return nil, err } - port, err := strconv.Atoi(onlyport) + pconn, err := d.QUICListener.Listen(&net.UDPAddr{IP: net.IPv4zero, Port: 0, Zone: ""}) if err != nil { return nil, err } - ip := net.ParseIP(onlyhost) - if ip == nil { - return nil, ErrInvalidIP - } - pconn, err := d.QUICListener.Listen(&net.UDPAddr{IP: net.IPv4zero, Port: 0}) - if err != nil { - return nil, err - } - udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""} - tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, port) + tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, udpAddr.Port) qconn, err := d.dialEarlyContext( ctx, pconn, udpAddr, address, tlsConfig, quicConfig) if err != nil { diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index b512e92..f974911 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -45,6 +45,54 @@ func TestNewQUICDialer(t *testing.T) { } } +func TestParseUDPAddr(t *testing.T) { + t.Run("cannot split host and port", func(t *testing.T) { + addr, err := ParseUDPAddr("1.2.3.4") + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("unexpected error", err) + } + if addr != nil { + t.Fatal("expected nil addr") + } + }) + + t.Run("with invalid IP addr", func(t *testing.T) { + addr, err := ParseUDPAddr("www.google.com:80") + if !errors.Is(err, ErrInvalidIP) { + t.Fatal("unexpected error", err) + } + if addr != nil { + t.Fatal("expected nil addr") + } + }) + + t.Run("with invalid port", func(t *testing.T) { + addr, err := ParseUDPAddr("8.8.8.8:www") + if err == nil || !strings.HasSuffix(err.Error(), "invalid syntax") { + t.Fatal("unexpected error", err) + } + if addr != nil { + t.Fatal("expected nil addr") + } + }) + + t.Run("with valid input", func(t *testing.T) { + addr, err := ParseUDPAddr("8.8.8.8:80") + if err != nil { + t.Fatal(err) + } + if addr.IP.String() != "8.8.8.8" { + t.Fatal("invalid IP") + } + if addr.Port != 80 { + t.Fatal("invalid port") + } + if addr.Zone != "" { + t.Fatal("invalid zone") + } + }) +} + func TestQUICDialerQUICGo(t *testing.T) { t.Run("DialContext", func(t *testing.T) { t.Run("cannot split host port", func(t *testing.T) {