refactor(netxlite): finish grouping tests (#488)
They are now more readable. I'll do another pass and start separating integration testing from unit testing. I think we need to have some always on integration testing for netxlite that runs on macOS, linux, and windows. See https://github.com/ooni/probe/issues/1591
This commit is contained in:
parent
493b72b170
commit
f2e3e5cc08
@ -1,12 +1,15 @@
|
|||||||
package netxlite_test
|
package netxlite_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
|
utls "gitlab.com/yawning/utls.git"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHTTPTransport(t *testing.T) {
|
func TestHTTPTransport(t *testing.T) {
|
||||||
@ -49,3 +52,21 @@ func TestHTTP3Transport(t *testing.T) {
|
|||||||
txp.CloseIdleConnections()
|
txp.CloseIdleConnections()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUTLSHandshaker(t *testing.T) {
|
||||||
|
t.Run("with chrome fingerprint", func(t *testing.T) {
|
||||||
|
h := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_Auto)
|
||||||
|
cfg := &tls.Config{ServerName: "google.com"}
|
||||||
|
conn, err := net.Dial("tcp", "google.com:443")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("unexpected error", err)
|
||||||
|
}
|
||||||
|
conn, _, err = h.Handshake(context.Background(), conn, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("unexpected error", err)
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
t.Fatal("nil connection")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -7,74 +7,80 @@ import (
|
|||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResolverLegacyAdapterWithCompatibleType(t *testing.T) {
|
func TestResolverLegacyAdapter(t *testing.T) {
|
||||||
var called bool
|
t.Run("with compatible type", func(t *testing.T) {
|
||||||
r := NewResolverLegacyAdapter(&mocks.Resolver{
|
var called bool
|
||||||
MockNetwork: func() string {
|
r := NewResolverLegacyAdapter(&mocks.Resolver{
|
||||||
return "network"
|
MockNetwork: func() string {
|
||||||
},
|
return "network"
|
||||||
MockAddress: func() string {
|
},
|
||||||
return "address"
|
MockAddress: func() string {
|
||||||
},
|
return "address"
|
||||||
MockCloseIdleConnections: func() {
|
},
|
||||||
called = true
|
MockCloseIdleConnections: func() {
|
||||||
},
|
called = true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if r.Network() != "network" {
|
||||||
|
t.Fatal("invalid Network")
|
||||||
|
}
|
||||||
|
if r.Address() != "address" {
|
||||||
|
t.Fatal("invalid Address")
|
||||||
|
}
|
||||||
|
r.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
if r.Network() != "network" {
|
|
||||||
t.Fatal("invalid Network")
|
|
||||||
}
|
|
||||||
if r.Address() != "address" {
|
|
||||||
t.Fatal("invalid Address")
|
|
||||||
}
|
|
||||||
r.CloseIdleConnections()
|
|
||||||
if !called {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverLegacyAdapterDefaults(t *testing.T) {
|
t.Run("with incompatible type", func(t *testing.T) {
|
||||||
r := NewResolverLegacyAdapter(&net.Resolver{})
|
r := NewResolverLegacyAdapter(&net.Resolver{})
|
||||||
if r.Network() != "adapter" {
|
if r.Network() != "adapter" {
|
||||||
t.Fatal("invalid Network")
|
t.Fatal("invalid Network")
|
||||||
}
|
}
|
||||||
if r.Address() != "" {
|
if r.Address() != "" {
|
||||||
t.Fatal("invalid Address")
|
t.Fatal("invalid Address")
|
||||||
}
|
}
|
||||||
r.CloseIdleConnections() // does not crash
|
r.CloseIdleConnections() // does not crash
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialerLegacyAdapterWithCompatibleType(t *testing.T) {
|
|
||||||
var called bool
|
|
||||||
r := NewDialerLegacyAdapter(&mocks.Dialer{
|
|
||||||
MockCloseIdleConnections: func() {
|
|
||||||
called = true
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
r.CloseIdleConnections()
|
|
||||||
if !called {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialerLegacyAdapterDefaults(t *testing.T) {
|
func TestDialerLegacyAdapter(t *testing.T) {
|
||||||
r := NewDialerLegacyAdapter(&net.Dialer{})
|
t.Run("with compatible type", func(t *testing.T) {
|
||||||
r.CloseIdleConnections() // does not crash
|
var called bool
|
||||||
}
|
r := NewDialerLegacyAdapter(&mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
func TestQUICContextDialerAdapterWithCompatibleType(t *testing.T) {
|
called = true
|
||||||
var called bool
|
},
|
||||||
d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICDialer{
|
})
|
||||||
MockCloseIdleConnections: func() {
|
r.CloseIdleConnections()
|
||||||
called = true
|
if !called {
|
||||||
},
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with incompatible type", func(t *testing.T) {
|
||||||
|
r := NewDialerLegacyAdapter(&net.Dialer{})
|
||||||
|
r.CloseIdleConnections() // does not crash
|
||||||
})
|
})
|
||||||
d.CloseIdleConnections()
|
|
||||||
if !called {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICContextDialerAdapterDefaults(t *testing.T) {
|
func TestQUICContextDialerAdapter(t *testing.T) {
|
||||||
d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICContextDialer{})
|
t.Run("with compatible type", func(t *testing.T) {
|
||||||
d.CloseIdleConnections() // does not crash
|
var called bool
|
||||||
|
d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICDialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
})
|
||||||
|
d.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with incompatible type", func(t *testing.T) {
|
||||||
|
d := NewQUICDialerFromContextDialerAdapter(&mocks.QUICContextDialer{})
|
||||||
|
d.CloseIdleConnections() // does not crash
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -17,457 +17,459 @@ import (
|
|||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/quicx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestQUICDialerQUICGoCannotSplitHostPort(t *testing.T) {
|
func TestQUICDialerQUICGo(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
ServerName: "www.google.com",
|
t.Run("cannot split host port", func(t *testing.T) {
|
||||||
}
|
tlsConfig := &tls.Config{
|
||||||
systemdialer := quicDialerQUICGo{
|
ServerName: "www.google.com",
|
||||||
QUICListener: &quicListenerStdlib{},
|
}
|
||||||
}
|
systemdialer := quicDialerQUICGo{
|
||||||
defer systemdialer.CloseIdleConnections() // just to see it running
|
QUICListener: &quicListenerStdlib{},
|
||||||
ctx := context.Background()
|
}
|
||||||
sess, err := systemdialer.DialContext(
|
defer systemdialer.CloseIdleConnections() // just to see it running
|
||||||
ctx, "udp", "a.b.c.d", tlsConfig, &quic.Config{})
|
ctx := context.Background()
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
sess, err := systemdialer.DialContext(
|
||||||
t.Fatal("not the error we expected", err)
|
ctx, "udp", "a.b.c.d", tlsConfig, &quic.Config{})
|
||||||
}
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
if sess != nil {
|
t.Fatal("not the error we expected", err)
|
||||||
t.Fatal("expected nil sess here")
|
}
|
||||||
}
|
if sess != nil {
|
||||||
}
|
t.Fatal("expected nil sess here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
func TestQUICDialerQUICGoInvalidPort(t *testing.T) {
|
t.Run("with invalid port", func(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := quicDialerQUICGo{
|
systemdialer := quicDialerQUICGo{
|
||||||
QUICListener: &quicListenerStdlib{},
|
QUICListener: &quicListenerStdlib{},
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
sess, err := systemdialer.DialContext(
|
sess, err := systemdialer.DialContext(
|
||||||
ctx, "udp", "8.8.4.4:xyz", tlsConfig, &quic.Config{})
|
ctx, "udp", "8.8.4.4:xyz", tlsConfig, &quic.Config{})
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid syntax") {
|
if err == nil || !strings.HasSuffix(err.Error(), "invalid syntax") {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
t.Fatal("expected nil sess here")
|
t.Fatal("expected nil sess here")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestQUICDialerQUICGoInvalidIP(t *testing.T) {
|
t.Run("with invalid IP", func(t *testing.T) {
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := quicDialerQUICGo{
|
systemdialer := quicDialerQUICGo{
|
||||||
QUICListener: &quicListenerStdlib{},
|
QUICListener: &quicListenerStdlib{},
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
sess, err := systemdialer.DialContext(
|
sess, err := systemdialer.DialContext(
|
||||||
ctx, "udp", "a.b.c.d:0", tlsConfig, &quic.Config{})
|
ctx, "udp", "a.b.c.d:0", tlsConfig, &quic.Config{})
|
||||||
if !errors.Is(err, errInvalidIP) {
|
if !errors.Is(err, errInvalidIP) {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
t.Fatal("expected nil sess here")
|
t.Fatal("expected nil sess here")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestQUICDialerQUICGoCannotListen(t *testing.T) {
|
t.Run("with listen error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
ServerName: "www.google.com",
|
ServerName: "www.google.com",
|
||||||
}
|
}
|
||||||
systemdialer := quicDialerQUICGo{
|
systemdialer := quicDialerQUICGo{
|
||||||
QUICListener: &mocks.QUICListener{
|
QUICListener: &mocks.QUICListener{
|
||||||
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
|
MockListen: func(addr *net.UDPAddr) (quicx.UDPLikeConn, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
sess, err := systemdialer.DialContext(
|
|
||||||
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected nil sess here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerQUICGoCannotPerformHandshake(t *testing.T) {
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
ServerName: "dns.google",
|
|
||||||
}
|
|
||||||
systemdialer := quicDialerQUICGo{
|
|
||||||
QUICListener: &quicListenerStdlib{},
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // fail immediately
|
|
||||||
sess, err := systemdialer.DialContext(
|
|
||||||
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
|
||||||
if !errors.Is(err, context.Canceled) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
log.Fatal("expected nil session here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerQUICGoWorksAsIntended(t *testing.T) {
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
ServerName: "dns.google",
|
|
||||||
}
|
|
||||||
systemdialer := quicDialerQUICGo{
|
|
||||||
QUICListener: &quicListenerStdlib{},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
sess, err := systemdialer.DialContext(
|
|
||||||
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
<-sess.HandshakeComplete().Done()
|
|
||||||
if err := sess.CloseWithError(0, ""); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerQUICGoTLSDefaultsForWeb(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
var gotTLSConfig *tls.Config
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
ServerName: "dns.google",
|
|
||||||
}
|
|
||||||
systemdialer := quicDialerQUICGo{
|
|
||||||
QUICListener: &quicListenerStdlib{},
|
|
||||||
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
|
||||||
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
|
||||||
quicConfig *quic.Config) (quic.EarlySession, error) {
|
|
||||||
gotTLSConfig = tlsConfig
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
sess, err := systemdialer.DialContext(
|
|
||||||
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected nil session here")
|
|
||||||
}
|
|
||||||
if tlsConfig.RootCAs != nil {
|
|
||||||
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
|
||||||
}
|
|
||||||
if gotTLSConfig.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid gotTLSConfig.RootCAs")
|
|
||||||
}
|
|
||||||
if tlsConfig.NextProtos != nil {
|
|
||||||
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" {
|
|
||||||
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
|
||||||
}
|
|
||||||
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
|
||||||
t.Fatal("the ServerName field must match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
var gotTLSConfig *tls.Config
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
ServerName: "dns.google",
|
|
||||||
}
|
|
||||||
systemdialer := quicDialerQUICGo{
|
|
||||||
QUICListener: &quicListenerStdlib{},
|
|
||||||
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
|
||||||
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
|
||||||
quicConfig *quic.Config) (quic.EarlySession, error) {
|
|
||||||
gotTLSConfig = tlsConfig
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
sess, err := systemdialer.DialContext(
|
|
||||||
ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{})
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected nil session here")
|
|
||||||
}
|
|
||||||
if tlsConfig.RootCAs != nil {
|
|
||||||
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
|
||||||
}
|
|
||||||
if gotTLSConfig.RootCAs != defaultCertPool {
|
|
||||||
t.Fatal("invalid gotTLSConfig.RootCAs")
|
|
||||||
}
|
|
||||||
if tlsConfig.NextProtos != nil {
|
|
||||||
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" {
|
|
||||||
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
|
||||||
}
|
|
||||||
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
|
||||||
t.Fatal("the ServerName field must match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverCloseIdleConnections(t *testing.T) {
|
|
||||||
var (
|
|
||||||
forDialer bool
|
|
||||||
forResolver bool
|
|
||||||
)
|
|
||||||
d := &quicDialerResolver{
|
|
||||||
Dialer: &mocks.QUICDialer{
|
|
||||||
MockCloseIdleConnections: func() {
|
|
||||||
forDialer = true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Resolver: &mocks.Resolver{
|
|
||||||
MockCloseIdleConnections: func() {
|
|
||||||
forResolver = true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
d.CloseIdleConnections()
|
|
||||||
if !forDialer || !forResolver {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverSuccess(t *testing.T) {
|
|
||||||
tlsConfig := &tls.Config{}
|
|
||||||
dialer := &quicDialerResolver{
|
|
||||||
Resolver: NewResolverSystem(log.Log),
|
|
||||||
Dialer: &quicDialerQUICGo{
|
|
||||||
QUICListener: &quicListenerStdlib{},
|
|
||||||
}}
|
|
||||||
sess, err := dialer.DialContext(
|
|
||||||
context.Background(), "udp", "www.google.com:443",
|
|
||||||
tlsConfig, &quic.Config{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
<-sess.HandshakeComplete().Done()
|
|
||||||
if err := sess.CloseWithError(0, ""); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverNoPort(t *testing.T) {
|
|
||||||
tlsConfig := &tls.Config{}
|
|
||||||
dialer := &quicDialerResolver{
|
|
||||||
Resolver: NewResolverSystem(log.Log),
|
|
||||||
Dialer: &quicDialerQUICGo{}}
|
|
||||||
sess, err := dialer.DialContext(
|
|
||||||
context.Background(), "udp", "www.google.com",
|
|
||||||
tlsConfig, &quic.Config{})
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected a nil sess here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverLookupHostAddress(t *testing.T) {
|
|
||||||
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
// We should not arrive here and call this function but if we do then
|
|
||||||
// there is going to be an error that fails this test.
|
|
||||||
return nil, errors.New("mocked error")
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
|
|
||||||
t.Fatal("not the result we expected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverLookupHostFailure(t *testing.T) {
|
|
||||||
tlsConfig := &tls.Config{}
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
sess, err := dialer.DialContext(
|
|
||||||
context.Background(), "udp", "dns.google.com:853",
|
|
||||||
tlsConfig, &quic.Config{})
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected nil sess")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverInvalidPort(t *testing.T) {
|
|
||||||
// This test allows us to check for the case where every attempt
|
|
||||||
// to establish a connection leads to a failure
|
|
||||||
tlsConf := &tls.Config{}
|
|
||||||
dialer := &quicDialerResolver{
|
|
||||||
Resolver: NewResolverSystem(log.Log),
|
|
||||||
Dialer: &quicDialerQUICGo{
|
|
||||||
QUICListener: &quicListenerStdlib{},
|
|
||||||
}}
|
|
||||||
sess, err := dialer.DialContext(
|
|
||||||
context.Background(), "udp", "www.google.com:0",
|
|
||||||
tlsConf, &quic.Config{})
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected an error here")
|
|
||||||
}
|
|
||||||
if !strings.HasSuffix(err.Error(), "sendto: invalid argument") &&
|
|
||||||
!strings.HasSuffix(err.Error(), "sendto: can't assign requested address") {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected nil sess")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
var gotTLSConfig *tls.Config
|
|
||||||
tlsConfig := &tls.Config{}
|
|
||||||
dialer := &quicDialerResolver{
|
|
||||||
Resolver: NewResolverSystem(log.Log),
|
|
||||||
Dialer: &mocks.QUICDialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network, address string,
|
|
||||||
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
|
||||||
gotTLSConfig = tlsConfig
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
sess, err := dialer.DialContext(
|
|
||||||
context.Background(), "udp", "www.google.com:443",
|
|
||||||
tlsConfig, &quic.Config{})
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if sess != nil {
|
|
||||||
t.Fatal("expected nil session here")
|
|
||||||
}
|
|
||||||
if tlsConfig.ServerName != "" {
|
|
||||||
t.Fatal("should not have changed tlsConfig.ServerName")
|
|
||||||
}
|
|
||||||
if gotTLSConfig.ServerName != "www.google.com" {
|
|
||||||
t.Fatal("gotTLSConfig.ServerName has not been set")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerLoggerCloseIdleConnections(t *testing.T) {
|
|
||||||
var forDialer bool
|
|
||||||
d := &quicDialerLogger{
|
|
||||||
Dialer: &mocks.QUICDialer{
|
|
||||||
MockCloseIdleConnections: func() {
|
|
||||||
forDialer = true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
d.CloseIdleConnections()
|
|
||||||
if !forDialer {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQUICDialerLoggerSuccess(t *testing.T) {
|
|
||||||
d := &quicDialerLogger{
|
|
||||||
Dialer: &mocks.QUICDialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network string,
|
|
||||||
address string, tlsConfig *tls.Config,
|
|
||||||
quicConfig *quic.Config) (quic.EarlySession, error) {
|
|
||||||
return &mocks.QUICEarlySession{
|
|
||||||
MockCloseWithError: func(
|
|
||||||
code quic.ApplicationErrorCode, reason string) error {
|
|
||||||
return nil
|
|
||||||
},
|
},
|
||||||
}, nil
|
},
|
||||||
},
|
}
|
||||||
},
|
ctx := context.Background()
|
||||||
Logger: log.Log,
|
sess, err := systemdialer.DialContext(
|
||||||
}
|
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||||
ctx := context.Background()
|
if !errors.Is(err, expected) {
|
||||||
tlsConfig := &tls.Config{}
|
t.Fatal("not the error we expected", err)
|
||||||
quicConfig := &quic.Config{}
|
}
|
||||||
sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
|
if sess != nil {
|
||||||
if err != nil {
|
t.Fatal("expected nil sess here")
|
||||||
t.Fatal(err)
|
}
|
||||||
}
|
})
|
||||||
if err := sess.CloseWithError(0, ""); err != nil {
|
|
||||||
t.Fatal(err)
|
t.Run("with handshake failure", func(t *testing.T) {
|
||||||
}
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
systemdialer := quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // fail immediately
|
||||||
|
sess, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
log.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works as intended", func(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skip test in short mode")
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
systemdialer := quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
sess, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
<-sess.HandshakeComplete().Done()
|
||||||
|
if err := sess.CloseWithError(0, ""); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TLS defaults for web", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
systemdialer := quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
gotTLSConfig = tlsConfig
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
sess, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
if tlsConfig.RootCAs != nil {
|
||||||
|
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||||
|
t.Fatal("invalid gotTLSConfig.RootCAs")
|
||||||
|
}
|
||||||
|
if tlsConfig.NextProtos != nil {
|
||||||
|
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"h3"}); diff != "" {
|
||||||
|
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
||||||
|
}
|
||||||
|
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
||||||
|
t.Fatal("the ServerName field must match")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TLS defaults for DoQ", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
systemdialer := quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn,
|
||||||
|
remoteAddr net.Addr, host string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
gotTLSConfig = tlsConfig
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
sess, err := systemdialer.DialContext(
|
||||||
|
ctx, "udp", "8.8.8.8:8853", tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
if tlsConfig.RootCAs != nil {
|
||||||
|
t.Fatal("tlsConfig.RootCAs should not have been changed")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||||
|
t.Fatal("invalid gotTLSConfig.RootCAs")
|
||||||
|
}
|
||||||
|
if tlsConfig.NextProtos != nil {
|
||||||
|
t.Fatal("tlsConfig.NextProtos should not have been changed")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotTLSConfig.NextProtos, []string{"dq"}); diff != "" {
|
||||||
|
t.Fatal("invalid gotTLSConfig.NextProtos", diff)
|
||||||
|
}
|
||||||
|
if tlsConfig.ServerName != gotTLSConfig.ServerName {
|
||||||
|
t.Fatal("the ServerName field must match")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICDialerLoggerFailure(t *testing.T) {
|
func TestQUICDialerResolver(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
|
||||||
d := &quicDialerLogger{
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
Dialer: &mocks.QUICDialer{
|
var (
|
||||||
MockDialContext: func(ctx context.Context, network string,
|
forDialer bool
|
||||||
address string, tlsConfig *tls.Config,
|
forResolver bool
|
||||||
quicConfig *quic.Config) (quic.EarlySession, error) {
|
)
|
||||||
return nil, expected
|
d := &quicDialerResolver{
|
||||||
|
Dialer: &mocks.QUICDialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
forDialer = true
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
Resolver: &mocks.Resolver{
|
||||||
Logger: log.Log,
|
MockCloseIdleConnections: func() {
|
||||||
}
|
forResolver = true
|
||||||
ctx := context.Background()
|
},
|
||||||
tlsConfig := &tls.Config{}
|
},
|
||||||
quicConfig := &quic.Config{}
|
}
|
||||||
sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
|
d.CloseIdleConnections()
|
||||||
if !errors.Is(err, expected) {
|
if !forDialer || !forResolver {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not called")
|
||||||
}
|
}
|
||||||
if sess != nil {
|
})
|
||||||
t.Fatal("expected nil session")
|
|
||||||
}
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
dialer := &quicDialerResolver{
|
||||||
|
Resolver: NewResolverSystem(log.Log),
|
||||||
|
Dialer: &quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
}}
|
||||||
|
sess, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "www.google.com:443",
|
||||||
|
tlsConfig, &quic.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
<-sess.HandshakeComplete().Done()
|
||||||
|
if err := sess.CloseWithError(0, ""); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with missing port", func(t *testing.T) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
dialer := &quicDialerResolver{
|
||||||
|
Resolver: NewResolverSystem(log.Log),
|
||||||
|
Dialer: &quicDialerQUICGo{}}
|
||||||
|
sess, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "www.google.com",
|
||||||
|
tlsConfig, &quic.Config{})
|
||||||
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected a nil sess here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with lookup host failure", func(t *testing.T) {
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
|
||||||
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
sess, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "dns.google.com:853",
|
||||||
|
tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil sess")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with invalid port", func(t *testing.T) {
|
||||||
|
// This test allows us to check for the case where every attempt
|
||||||
|
// to establish a connection leads to a failure
|
||||||
|
tlsConf := &tls.Config{}
|
||||||
|
dialer := &quicDialerResolver{
|
||||||
|
Resolver: NewResolverSystem(log.Log),
|
||||||
|
Dialer: &quicDialerQUICGo{
|
||||||
|
QUICListener: &quicListenerStdlib{},
|
||||||
|
}}
|
||||||
|
sess, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "www.google.com:0",
|
||||||
|
tlsConf, &quic.Config{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected an error here")
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(err.Error(), "sendto: invalid argument") &&
|
||||||
|
!strings.HasSuffix(err.Error(), "sendto: can't assign requested address") {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil sess")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("we apply TLS defaults", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
dialer := &quicDialerResolver{
|
||||||
|
Resolver: NewResolverSystem(log.Log),
|
||||||
|
Dialer: &mocks.QUICDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
gotTLSConfig = tlsConfig
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
sess, err := dialer.DialContext(
|
||||||
|
context.Background(), "udp", "www.google.com:443",
|
||||||
|
tlsConfig, &quic.Config{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session here")
|
||||||
|
}
|
||||||
|
if tlsConfig.ServerName != "" {
|
||||||
|
t.Fatal("should not have changed tlsConfig.ServerName")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.ServerName != "www.google.com" {
|
||||||
|
t.Fatal("gotTLSConfig.ServerName has not been set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("lookup host with address", func(t *testing.T) {
|
||||||
|
dialer := &quicDialerResolver{Resolver: &mocks.Resolver{
|
||||||
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
// We should not arrive here and call this function but if we do then
|
||||||
|
// there is going to be an error that fails this test.
|
||||||
|
return nil, errors.New("mocked error")
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
addrs, err := dialer.lookupHost(context.Background(), "1.1.1.1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
|
||||||
|
t.Fatal("not the result we expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewQUICDialerWithoutResolverChain(t *testing.T) {
|
func TestQUICLoggerDialer(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var forDialer bool
|
||||||
|
d := &quicDialerLogger{
|
||||||
|
Dialer: &mocks.QUICDialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
forDialer = true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
d.CloseIdleConnections()
|
||||||
|
if !forDialer {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
d := &quicDialerLogger{
|
||||||
|
Dialer: &mocks.QUICDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network string,
|
||||||
|
address string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
return &mocks.QUICEarlySession{
|
||||||
|
MockCloseWithError: func(
|
||||||
|
code quic.ApplicationErrorCode, reason string) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Logger: log.Log,
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
quicConfig := &quic.Config{}
|
||||||
|
sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := sess.CloseWithError(0, ""); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
d := &quicDialerLogger{
|
||||||
|
Dialer: &mocks.QUICDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network string,
|
||||||
|
address string, tlsConfig *tls.Config,
|
||||||
|
quicConfig *quic.Config) (quic.EarlySession, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Logger: log.Log,
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
quicConfig := &quic.Config{}
|
||||||
|
sess, err := d.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if sess != nil {
|
||||||
|
t.Fatal("expected nil session")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewQUICDialer(t *testing.T) {
|
||||||
ql := NewQUICListener()
|
ql := NewQUICListener()
|
||||||
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
|
dlr := NewQUICDialerWithoutResolver(ql, log.Log)
|
||||||
dlog, okay := dlr.(*quicDialerLogger)
|
logger := dlr.(*quicDialerLogger)
|
||||||
if !okay {
|
if logger.Logger != log.Log {
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if dlog.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
dr, okay := dlog.Dialer.(*quicDialerResolver)
|
resolver := logger.Dialer.(*quicDialerResolver)
|
||||||
if !okay {
|
if _, okay := resolver.Resolver.(*nullResolver); !okay {
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if _, okay := dr.Resolver.(*nullResolver); !okay {
|
|
||||||
t.Fatal("invalid resolver type")
|
t.Fatal("invalid resolver type")
|
||||||
}
|
}
|
||||||
dlog, okay = dr.Dialer.(*quicDialerLogger)
|
logger = resolver.Dialer.(*quicDialerLogger)
|
||||||
if !okay {
|
if logger.Logger != log.Log {
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if dlog.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
ew, okay := dlog.Dialer.(*quicDialerErrWrapper)
|
errWrapper := logger.Dialer.(*quicDialerErrWrapper)
|
||||||
if !okay {
|
base := errWrapper.QUICDialer.(*quicDialerQUICGo)
|
||||||
t.Fatal("invalid type")
|
if base.QUICListener != ql {
|
||||||
}
|
|
||||||
dgo, okay := ew.QUICDialer.(*quicDialerQUICGo)
|
|
||||||
if !okay {
|
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if dgo.QUICListener != ql {
|
|
||||||
t.Fatal("invalid quic listener")
|
t.Fatal("invalid quic listener")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewSingleUseQUICDialerWorksAsIntended(t *testing.T) {
|
func TestNewSingleUseQUICDialer(t *testing.T) {
|
||||||
sess := &mocks.QUICEarlySession{}
|
sess := &mocks.QUICEarlySession{}
|
||||||
qd := NewSingleUseQUICDialer(sess)
|
qd := NewSingleUseQUICDialer(sess)
|
||||||
outsess, err := qd.DialContext(
|
outsess, err := qd.DialContext(
|
||||||
|
@ -15,6 +15,7 @@ func TestQuirkReduceErrors(t *testing.T) {
|
|||||||
t.Fatal("wrong result")
|
t.Fatal("wrong result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("single error", func(t *testing.T) {
|
t.Run("single error", func(t *testing.T) {
|
||||||
err := errors.New("mocked error")
|
err := errors.New("mocked error")
|
||||||
result := quirkReduceErrors([]error{err})
|
result := quirkReduceErrors([]error{err})
|
||||||
@ -22,6 +23,7 @@ func TestQuirkReduceErrors(t *testing.T) {
|
|||||||
t.Fatal("wrong result")
|
t.Fatal("wrong result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("multiple errors", func(t *testing.T) {
|
t.Run("multiple errors", func(t *testing.T) {
|
||||||
err1 := errors.New("mocked error #1")
|
err1 := errors.New("mocked error #1")
|
||||||
err2 := errors.New("mocked error #2")
|
err2 := errors.New("mocked error #2")
|
||||||
@ -30,6 +32,7 @@ func TestQuirkReduceErrors(t *testing.T) {
|
|||||||
t.Fatal("wrong result")
|
t.Fatal("wrong result")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
||||||
err1 := errors.New("mocked error #1")
|
err1 := errors.New("mocked error #1")
|
||||||
err2 := &errorsx.ErrWrapper{
|
err2 := &errorsx.ErrWrapper{
|
||||||
|
@ -15,235 +15,237 @@ import (
|
|||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResolverSystemNetworkAddress(t *testing.T) {
|
func TestResolverSystem(t *testing.T) {
|
||||||
r := &resolverSystem{}
|
t.Run("Network and Address", func(t *testing.T) {
|
||||||
if r.Network() != "system" {
|
r := &resolverSystem{}
|
||||||
t.Fatal("invalid Network")
|
if r.Network() != "system" {
|
||||||
}
|
t.Fatal("invalid Network")
|
||||||
if r.Address() != "" {
|
}
|
||||||
t.Fatal("invalid Address")
|
if r.Address() != "" {
|
||||||
}
|
t.Fatal("invalid Address")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("works as intended", func(t *testing.T) {
|
||||||
|
r := &resolverSystem{}
|
||||||
|
defer r.CloseIdleConnections()
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if addrs == nil {
|
||||||
|
t.Fatal("expected non-nil result here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("check default timeout", func(t *testing.T) {
|
||||||
|
r := &resolverSystem{}
|
||||||
|
if r.timeout() != 15*time.Second {
|
||||||
|
t.Fatal("unexpected default timeout")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
|
t.Run("with timeout and success", func(t *testing.T) {
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
r := &resolverSystem{
|
||||||
|
testableTimeout: 1 * time.Microsecond,
|
||||||
|
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
defer wg.Done()
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
return []string{"8.8.8.8"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "example.antani")
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("invalid addrs")
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with timeout and failure", func(t *testing.T) {
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(1)
|
||||||
|
r := &resolverSystem{
|
||||||
|
testableTimeout: 1 * time.Microsecond,
|
||||||
|
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
defer wg.Done()
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
return nil, errors.New("no such host")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "example.antani")
|
||||||
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("invalid addrs")
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with NXDOMAIN", func(t *testing.T) {
|
||||||
|
r := &resolverSystem{
|
||||||
|
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
return nil, errors.New("no such host")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "example.antani")
|
||||||
|
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("invalid addrs")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolverSystemWorksAsIntended(t *testing.T) {
|
func TestResolverLogger(t *testing.T) {
|
||||||
r := &resolverSystem{}
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
defer r.CloseIdleConnections()
|
t.Run("with success", func(t *testing.T) {
|
||||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
expected := []string{"1.1.1.1"}
|
||||||
if err != nil {
|
r := resolverLogger{
|
||||||
t.Fatal(err)
|
Logger: log.Log,
|
||||||
}
|
Resolver: &mocks.Resolver{
|
||||||
if addrs == nil {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
t.Fatal("expected non-nil result here")
|
return expected, nil
|
||||||
}
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "dns.google")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(expected, addrs); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := resolverLogger{
|
||||||
|
Logger: log.Log,
|
||||||
|
Resolver: &mocks.Resolver{
|
||||||
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "dns.google")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("expected nil addr here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolverSystemDefaultTimeout(t *testing.T) {
|
func TestResolverIDNA(t *testing.T) {
|
||||||
r := &resolverSystem{}
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
if r.timeout() != 15*time.Second {
|
t.Run("with valid IDNA in input", func(t *testing.T) {
|
||||||
t.Fatal("unexpected default timeout")
|
expectedIPs := []string{"77.88.55.66"}
|
||||||
}
|
r := &resolverIDNA{
|
||||||
|
Resolver: &mocks.Resolver{
|
||||||
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
if domain != "xn--d1acpjx3f.xn--p1ai" {
|
||||||
|
return nil, errors.New("passed invalid domain")
|
||||||
|
}
|
||||||
|
return expectedIPs, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "яндекс.рф")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with invalid punycode", func(t *testing.T) {
|
||||||
|
r := &resolverIDNA{Resolver: &mocks.Resolver{
|
||||||
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
return nil, errors.New("should not happen")
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "xn--0000h")
|
||||||
|
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("expected no response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolverSystemWithTimeoutAndSuccess(t *testing.T) {
|
func TestNewResolverSystem(t *testing.T) {
|
||||||
wg := &sync.WaitGroup{}
|
resolver := NewResolverSystem(log.Log)
|
||||||
wg.Add(1)
|
idna := resolver.(*resolverIDNA)
|
||||||
r := &resolverSystem{
|
logger := idna.Resolver.(*resolverLogger)
|
||||||
testableTimeout: 1 * time.Microsecond,
|
if logger.Logger != log.Log {
|
||||||
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
defer wg.Done()
|
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
return []string{"8.8.8.8"}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
addrs, err := r.LookupHost(ctx, "example.antani")
|
|
||||||
if !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("invalid addrs")
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverSystemWithTimeoutAndFailure(t *testing.T) {
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
wg.Add(1)
|
|
||||||
r := &resolverSystem{
|
|
||||||
testableTimeout: 1 * time.Microsecond,
|
|
||||||
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
defer wg.Done()
|
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
return nil, errors.New("no such host")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
addrs, err := r.LookupHost(ctx, "example.antani")
|
|
||||||
if !errors.Is(err, context.DeadlineExceeded) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("invalid addrs")
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverSystemWithNXDOMAIN(t *testing.T) {
|
|
||||||
r := &resolverSystem{
|
|
||||||
testableLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
return nil, errors.New("no such host")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
addrs, err := r.LookupHost(ctx, "example.antani")
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("invalid addrs")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverLoggerWithSuccess(t *testing.T) {
|
|
||||||
expected := []string{"1.1.1.1"}
|
|
||||||
r := resolverLogger{
|
|
||||||
Logger: log.Log,
|
|
||||||
Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
return expected, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
addrs, err := r.LookupHost(context.Background(), "dns.google")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(expected, addrs); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverLoggerWithFailure(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
r := resolverLogger{
|
|
||||||
Logger: log.Log,
|
|
||||||
Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
addrs, err := r.LookupHost(context.Background(), "dns.google")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("expected nil addr here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverIDNAWorksAsIntended(t *testing.T) {
|
|
||||||
expectedIPs := []string{"77.88.55.66"}
|
|
||||||
r := &resolverIDNA{
|
|
||||||
Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
if domain != "xn--d1acpjx3f.xn--p1ai" {
|
|
||||||
return nil, errors.New("passed invalid domain")
|
|
||||||
}
|
|
||||||
return expectedIPs, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
addrs, err := r.LookupHost(ctx, "яндекс.рф")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolverIDNAWithInvalidPunycode(t *testing.T) {
|
|
||||||
r := &resolverIDNA{Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
return nil, errors.New("should not happen")
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
|
|
||||||
ctx := context.Background()
|
|
||||||
addrs, err := r.LookupHost(ctx, "xn--0000h")
|
|
||||||
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("expected no response here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewResolverTypeChain(t *testing.T) {
|
|
||||||
r := NewResolverSystem(log.Log)
|
|
||||||
ridna, ok := r.(*resolverIDNA)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("invalid resolver")
|
|
||||||
}
|
|
||||||
rl, ok := ridna.Resolver.(*resolverLogger)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("invalid resolver")
|
|
||||||
}
|
|
||||||
if rl.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
scia, ok := rl.Resolver.(*resolverShortCircuitIPAddr)
|
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||||
if !ok {
|
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||||
t.Fatal("invalid resolver")
|
_ = errWrapper.Resolver.(*resolverSystem)
|
||||||
}
|
|
||||||
ew, ok := scia.Resolver.(*resolverErrWrapper)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("invalid resolver")
|
|
||||||
}
|
|
||||||
if _, ok := ew.Resolver.(*resolverSystem); !ok {
|
|
||||||
t.Fatal("invalid resolver")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolverShortCircuitIPAddrWithIPAddr(t *testing.T) {
|
func TestResolverShortCircuitIPAddr(t *testing.T) {
|
||||||
r := &resolverShortCircuitIPAddr{
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
Resolver: &mocks.Resolver{
|
t.Run("with IP addr", func(t *testing.T) {
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
r := &resolverShortCircuitIPAddr{
|
||||||
return nil, errors.New("mocked error")
|
Resolver: &mocks.Resolver{
|
||||||
},
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
},
|
return nil, errors.New("mocked error")
|
||||||
}
|
},
|
||||||
ctx := context.Background()
|
},
|
||||||
addrs, err := r.LookupHost(ctx, "8.8.8.8")
|
}
|
||||||
if err != nil {
|
ctx := context.Background()
|
||||||
t.Fatal(err)
|
addrs, err := r.LookupHost(ctx, "8.8.8.8")
|
||||||
}
|
if err != nil {
|
||||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
t.Fatal(err)
|
||||||
t.Fatal("invalid result")
|
}
|
||||||
}
|
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||||
|
t.Fatal("invalid result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with domain", func(t *testing.T) {
|
||||||
|
r := &resolverShortCircuitIPAddr{
|
||||||
|
Resolver: &mocks.Resolver{
|
||||||
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
return nil, errors.New("mocked error")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
addrs, err := r.LookupHost(ctx, "dns.google")
|
||||||
|
if err == nil || err.Error() != "mocked error" {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("invalid result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolverShortCircuitIPAddrWithDomain(t *testing.T) {
|
func TestNullResolver(t *testing.T) {
|
||||||
r := &resolverShortCircuitIPAddr{
|
|
||||||
Resolver: &mocks.Resolver{
|
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
|
||||||
return nil, errors.New("mocked error")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
|
||||||
if err == nil || err.Error() != "mocked error" {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("invalid result")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNullResolverWorksAsIntended(t *testing.T) {
|
|
||||||
r := &nullResolver{}
|
r := &nullResolver{}
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
addrs, err := r.LookupHost(ctx, "dns.google")
|
||||||
|
@ -118,354 +118,357 @@ func TestConfigureTLSVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerConfigurableWithError(t *testing.T) {
|
func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||||
var times []time.Time
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
h := &tlsHandshakerConfigurable{}
|
t.Run("with error", func(t *testing.T) {
|
||||||
tcpConn := &mocks.Conn{
|
|
||||||
MockWrite: func(b []byte) (int, error) {
|
|
||||||
return 0, io.EOF
|
|
||||||
},
|
|
||||||
MockSetDeadline: func(t time.Time) error {
|
|
||||||
times = append(times, t)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
|
||||||
ServerName: "x.org",
|
|
||||||
})
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Fatal("not the error that we expected")
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil con here")
|
|
||||||
}
|
|
||||||
if len(times) != 2 {
|
|
||||||
t.Fatal("expected two time entries")
|
|
||||||
}
|
|
||||||
if !times[0].After(time.Now()) {
|
|
||||||
t.Fatal("timeout not in the future")
|
|
||||||
}
|
|
||||||
if !times[1].IsZero() {
|
|
||||||
t.Fatal("did not clear timeout on exit")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSHandshakerConfigurableSuccess(t *testing.T) {
|
var times []time.Time
|
||||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
h := &tlsHandshakerConfigurable{}
|
||||||
rw.WriteHeader(200)
|
tcpConn := &mocks.Conn{
|
||||||
})
|
MockWrite: func(b []byte) (int, error) {
|
||||||
srvr := httptest.NewTLSServer(handler)
|
return 0, io.EOF
|
||||||
defer srvr.Close()
|
},
|
||||||
URL, err := url.Parse(srvr.URL)
|
MockSetDeadline: func(t time.Time) error {
|
||||||
if err != nil {
|
times = append(times, t)
|
||||||
t.Fatal(err)
|
return nil
|
||||||
}
|
|
||||||
conn, err := net.Dial("tcp", URL.Host)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
handshaker := &tlsHandshakerConfigurable{}
|
|
||||||
ctx := context.Background()
|
|
||||||
config := &tls.Config{
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
MinVersion: tls.VersionTLS13,
|
|
||||||
MaxVersion: tls.VersionTLS13,
|
|
||||||
ServerName: URL.Hostname(),
|
|
||||||
}
|
|
||||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tlsConn.Close()
|
|
||||||
if connState.Version != tls.VersionTLS13 {
|
|
||||||
t.Fatal("unexpected TLS version")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSHandshakerConfigurableSetsDefaultRootCAs(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
var gotTLSConfig *tls.Config
|
|
||||||
handshaker := &tlsHandshakerConfigurable{
|
|
||||||
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
|
|
||||||
gotTLSConfig = config
|
|
||||||
return &mocks.TLSConn{
|
|
||||||
MockHandshakeContext: func(ctx context.Context) error {
|
|
||||||
return expected
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
ctx := context.Background()
|
||||||
}
|
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||||
ctx := context.Background()
|
ServerName: "x.org",
|
||||||
config := &tls.Config{}
|
})
|
||||||
conn := &mocks.Conn{
|
if err != io.EOF {
|
||||||
MockSetDeadline: func(t time.Time) error {
|
t.Fatal("not the error that we expected")
|
||||||
return nil
|
}
|
||||||
},
|
if conn != nil {
|
||||||
}
|
t.Fatal("expected nil con here")
|
||||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
}
|
||||||
if !errors.Is(err, expected) {
|
if len(times) != 2 {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("expected two time entries")
|
||||||
}
|
}
|
||||||
if !reflect.ValueOf(connState).IsZero() {
|
if !times[0].After(time.Now()) {
|
||||||
t.Fatal("expected zero connState here")
|
t.Fatal("timeout not in the future")
|
||||||
}
|
}
|
||||||
if tlsConn != nil {
|
if !times[1].IsZero() {
|
||||||
t.Fatal("expected nil tlsConn here")
|
t.Fatal("did not clear timeout on exit")
|
||||||
}
|
}
|
||||||
if config.RootCAs != nil {
|
})
|
||||||
t.Fatal("config.RootCAs should still be nil")
|
|
||||||
}
|
t.Run("with success", func(t *testing.T) {
|
||||||
if gotTLSConfig.RootCAs != defaultCertPool {
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
|
rw.WriteHeader(200)
|
||||||
}
|
})
|
||||||
|
srvr := httptest.NewTLSServer(handler)
|
||||||
|
defer srvr.Close()
|
||||||
|
URL, err := url.Parse(srvr.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
conn, err := net.Dial("tcp", URL.Host)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
handshaker := &tlsHandshakerConfigurable{}
|
||||||
|
ctx := context.Background()
|
||||||
|
config := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
MaxVersion: tls.VersionTLS13,
|
||||||
|
ServerName: URL.Hostname(),
|
||||||
|
}
|
||||||
|
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tlsConn.Close()
|
||||||
|
if connState.Version != tls.VersionTLS13 {
|
||||||
|
t.Fatal("unexpected TLS version")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets default root CA", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
var gotTLSConfig *tls.Config
|
||||||
|
handshaker := &tlsHandshakerConfigurable{
|
||||||
|
NewConn: func(conn net.Conn, config *tls.Config) TLSConn {
|
||||||
|
gotTLSConfig = config
|
||||||
|
return &mocks.TLSConn{
|
||||||
|
MockHandshakeContext: func(ctx context.Context) error {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
config := &tls.Config{}
|
||||||
|
conn := &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
|
t.Fatal("expected zero connState here")
|
||||||
|
}
|
||||||
|
if tlsConn != nil {
|
||||||
|
t.Fatal("expected nil tlsConn here")
|
||||||
|
}
|
||||||
|
if config.RootCAs != nil {
|
||||||
|
t.Fatal("config.RootCAs should still be nil")
|
||||||
|
}
|
||||||
|
if gotTLSConfig.RootCAs != defaultCertPool {
|
||||||
|
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerLoggerSuccess(t *testing.T) {
|
func TestTLSHandshakerLogger(t *testing.T) {
|
||||||
th := &tlsHandshakerLogger{
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
t.Run("on success", func(t *testing.T) {
|
||||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
th := &tlsHandshakerLogger{
|
||||||
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
TLSHandshaker: &mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Logger: log.Log,
|
||||||
|
}
|
||||||
|
conn := &mocks.Conn{
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config := &tls.Config{}
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConn, connState, err := th.Handshake(ctx, conn, config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := tlsConn.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
|
t.Fatal("expected zero ConnectionState here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
th := &tlsHandshakerLogger{
|
||||||
|
TLSHandshaker: &mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return nil, tls.ConnectionState{}, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Logger: log.Log,
|
||||||
|
}
|
||||||
|
conn := &mocks.Conn{
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config := &tls.Config{}
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConn, connState, err := th.Handshake(ctx, conn, config)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if tlsConn != nil {
|
||||||
|
t.Fatal("expected nil conn here")
|
||||||
|
}
|
||||||
|
if !reflect.ValueOf(connState).IsZero() {
|
||||||
|
t.Fatal("expected zero ConnectionState here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSDialer(t *testing.T) {
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
dialer := &tlsDialer{
|
||||||
|
Dialer: &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
}
|
||||||
Logger: log.Log,
|
dialer.CloseIdleConnections()
|
||||||
}
|
if !called {
|
||||||
conn := &mocks.Conn{
|
t.Fatal("not called")
|
||||||
MockClose: func() error {
|
}
|
||||||
return nil
|
})
|
||||||
},
|
|
||||||
}
|
t.Run("DialTLSContext", func(t *testing.T) {
|
||||||
config := &tls.Config{}
|
t.Run("failure to split host and port", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
dialer := &tlsDialer{}
|
||||||
tlsConn, connState, err := th.Handshake(ctx, conn, config)
|
ctx := context.Background()
|
||||||
if err != nil {
|
const address = "www.google.com" // missing port
|
||||||
t.Fatal(err)
|
conn, err := dialer.DialTLSContext(ctx, "tcp", address)
|
||||||
}
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
if err := tlsConn.Close(); err != nil {
|
t.Fatal("not the error we expected", err)
|
||||||
t.Fatal(err)
|
}
|
||||||
}
|
if conn != nil {
|
||||||
if !reflect.ValueOf(connState).IsZero() {
|
t.Fatal("connection is not nil")
|
||||||
t.Fatal("expected zero ConnectionState here")
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
|
t.Run("failure dialing", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // immediately fail
|
||||||
|
dialer := tlsDialer{Dialer: &dialerSystem{}}
|
||||||
|
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||||
|
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("connection is not nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("failure handshaking", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
dialer := tlsDialer{
|
||||||
|
Config: &tls.Config{},
|
||||||
|
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}, MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
}, MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}}, nil
|
||||||
|
}},
|
||||||
|
TLSHandshaker: &tlsHandshakerConfigurable{},
|
||||||
|
}
|
||||||
|
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("connection is not nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success handshaking", func(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
dialer := tlsDialer{
|
||||||
|
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
}, MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
}, MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}}, nil
|
||||||
|
}},
|
||||||
|
TLSHandshaker: &mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
t.Fatal("connection is nil")
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("config", func(t *testing.T) {
|
||||||
|
t.Run("from empty config for web", func(t *testing.T) {
|
||||||
|
d := &tlsDialer{}
|
||||||
|
config := d.config("www.google.com", "443")
|
||||||
|
if config.ServerName != "www.google.com" {
|
||||||
|
t.Fatal("invalid server name")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("from empty config for dot", func(t *testing.T) {
|
||||||
|
d := &tlsDialer{}
|
||||||
|
config := d.config("dns.google", "853")
|
||||||
|
if config.ServerName != "dns.google" {
|
||||||
|
t.Fatal("invalid server name")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with server name", func(t *testing.T) {
|
||||||
|
d := &tlsDialer{
|
||||||
|
Config: &tls.Config{
|
||||||
|
ServerName: "example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config := d.config("dns.google", "853")
|
||||||
|
if config.ServerName != "example.com" {
|
||||||
|
t.Fatal("invalid server name")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with alpn", func(t *testing.T) {
|
||||||
|
d := &tlsDialer{
|
||||||
|
Config: &tls.Config{
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
config := d.config("dns.google", "853")
|
||||||
|
if config.ServerName != "dns.google" {
|
||||||
|
t.Fatal("invalid server name")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLSHandshakerLoggerFailure(t *testing.T) {
|
func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
|
||||||
th := &tlsHandshakerLogger{
|
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
|
||||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
|
||||||
return nil, tls.ConnectionState{}, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Logger: log.Log,
|
|
||||||
}
|
|
||||||
conn := &mocks.Conn{
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config := &tls.Config{}
|
|
||||||
ctx := context.Background()
|
|
||||||
tlsConn, connState, err := th.Handshake(ctx, conn, config)
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if tlsConn != nil {
|
|
||||||
t.Fatal("expected nil conn here")
|
|
||||||
}
|
|
||||||
if !reflect.ValueOf(connState).IsZero() {
|
|
||||||
t.Fatal("expected zero ConnectionState here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerCloseIdleConnections(t *testing.T) {
|
|
||||||
var called bool
|
|
||||||
dialer := &tlsDialer{
|
|
||||||
Dialer: &mocks.Dialer{
|
|
||||||
MockCloseIdleConnections: func() {
|
|
||||||
called = true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
dialer.CloseIdleConnections()
|
|
||||||
if !called {
|
|
||||||
t.Fatal("not called")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextFailureSplitHostPort(t *testing.T) {
|
|
||||||
dialer := &tlsDialer{}
|
|
||||||
ctx := context.Background()
|
|
||||||
const address = "www.google.com" // missing port
|
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", address)
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("connection is not nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextFailureDialing(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel() // immediately fail
|
|
||||||
dialer := tlsDialer{Dialer: &dialerSystem{}}
|
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "operation was canceled") {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("connection is not nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextFailureHandshaking(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
dialer := tlsDialer{
|
|
||||||
Config: &tls.Config{},
|
|
||||||
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
|
||||||
return 0, io.EOF
|
|
||||||
}, MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
}, MockSetDeadline: func(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}}, nil
|
|
||||||
}},
|
|
||||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
|
||||||
}
|
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
|
||||||
if !errors.Is(err, io.EOF) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("connection is not nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerDialTLSContextSuccessHandshaking(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
dialer := tlsDialer{
|
|
||||||
Dialer: &mocks.Dialer{MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return &mocks.Conn{MockWrite: func(b []byte) (int, error) {
|
|
||||||
return 0, io.EOF
|
|
||||||
}, MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
}, MockSetDeadline: func(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
}}, nil
|
|
||||||
}},
|
|
||||||
TLSHandshaker: &mocks.TLSHandshaker{
|
|
||||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
|
||||||
return tls.Client(conn, config), tls.ConnectionState{}, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn, err := dialer.DialTLSContext(ctx, "tcp", "www.google.com:443")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
t.Fatal("connection is nil")
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerConfigFromEmptyConfigForWeb(t *testing.T) {
|
|
||||||
d := &tlsDialer{}
|
|
||||||
config := d.config("www.google.com", "443")
|
|
||||||
if config.ServerName != "www.google.com" {
|
|
||||||
t.Fatal("invalid server name")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"h2", "http/1.1"}); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerConfigFromEmptyConfigForDoT(t *testing.T) {
|
|
||||||
d := &tlsDialer{}
|
|
||||||
config := d.config("dns.google", "853")
|
|
||||||
if config.ServerName != "dns.google" {
|
|
||||||
t.Fatal("invalid server name")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerConfigWithServerName(t *testing.T) {
|
|
||||||
d := &tlsDialer{
|
|
||||||
Config: &tls.Config{
|
|
||||||
ServerName: "example.com",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config := d.config("dns.google", "853")
|
|
||||||
if config.ServerName != "example.com" {
|
|
||||||
t.Fatal("invalid server name")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"dot"}); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTLSDialerConfigWithALPN(t *testing.T) {
|
|
||||||
d := &tlsDialer{
|
|
||||||
Config: &tls.Config{
|
|
||||||
NextProtos: []string{"h2"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
config := d.config("dns.google", "853")
|
|
||||||
if config.ServerName != "dns.google" {
|
|
||||||
t.Fatal("invalid server name")
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(config.NextProtos, []string{"h2"}); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSHandshakerStdlibTypes(t *testing.T) {
|
|
||||||
th := NewTLSHandshakerStdlib(log.Log)
|
th := NewTLSHandshakerStdlib(log.Log)
|
||||||
thl, okay := th.(*tlsHandshakerLogger)
|
logger := th.(*tlsHandshakerLogger)
|
||||||
if !okay {
|
if logger.Logger != log.Log {
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if thl.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||||
if !okay {
|
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||||
t.Fatal("invalid type")
|
if configurable.NewConn != nil {
|
||||||
}
|
|
||||||
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
|
|
||||||
if !okay {
|
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if thc.NewConn != nil {
|
|
||||||
t.Fatal("expected nil NewConn")
|
t.Fatal("expected nil NewConn")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTLSDialerWorksAsIntended(t *testing.T) {
|
func TestNewTLSDialer(t *testing.T) {
|
||||||
d := &mocks.Dialer{}
|
d := &mocks.Dialer{}
|
||||||
tlsh := &mocks.TLSHandshaker{}
|
th := &mocks.TLSHandshaker{}
|
||||||
td := NewTLSDialer(d, tlsh)
|
dialer := NewTLSDialer(d, th)
|
||||||
tdut, okay := td.(*tlsDialer)
|
tlsd := dialer.(*tlsDialer)
|
||||||
if !okay {
|
if tlsd.Config == nil {
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if tdut.Config == nil {
|
|
||||||
t.Fatal("unexpected config")
|
t.Fatal("unexpected config")
|
||||||
}
|
}
|
||||||
if tdut.Dialer != d {
|
if tlsd.Dialer != d {
|
||||||
t.Fatal("unexpected dialer")
|
t.Fatal("unexpected dialer")
|
||||||
}
|
}
|
||||||
if tdut.TLSHandshaker != tlsh {
|
if tlsd.TLSHandshaker != th {
|
||||||
t.Fatal("invalid handshaker")
|
t.Fatal("invalid handshaker")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewSingleUseTLSDialerWorksAsIntended(t *testing.T) {
|
func TestNewSingleUseTLSDialer(t *testing.T) {
|
||||||
conn := &mocks.TLSConn{}
|
conn := &mocks.TLSConn{}
|
||||||
d := NewSingleUseTLSDialer(conn)
|
d := NewSingleUseTLSDialer(conn)
|
||||||
outconn, err := d.DialTLSContext(context.Background(), "", "")
|
outconn, err := d.DialTLSContext(context.Background(), "", "")
|
||||||
|
@ -2,9 +2,7 @@ package netxlite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -13,107 +11,84 @@ import (
|
|||||||
utls "gitlab.com/yawning/utls.git"
|
utls "gitlab.com/yawning/utls.git"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUTLSHandshakerChrome(t *testing.T) {
|
func TestNewTLSHandshakerUTLS(t *testing.T) {
|
||||||
h := &tlsHandshakerConfigurable{
|
|
||||||
NewConn: newConnUTLS(&utls.HelloChrome_Auto),
|
|
||||||
}
|
|
||||||
cfg := &tls.Config{ServerName: "google.com"}
|
|
||||||
conn, err := net.Dial("tcp", "google.com:443")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("unexpected error", err)
|
|
||||||
}
|
|
||||||
conn, _, err = h.Handshake(context.Background(), conn, cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("unexpected error", err)
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
t.Fatal("nil connection")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTLSHandshakerUTLSTypes(t *testing.T) {
|
|
||||||
th := NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_83)
|
th := NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_83)
|
||||||
thl, okay := th.(*tlsHandshakerLogger)
|
logger := th.(*tlsHandshakerLogger)
|
||||||
if !okay {
|
if logger.Logger != log.Log {
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if thl.Logger != log.Log {
|
|
||||||
t.Fatal("invalid logger")
|
t.Fatal("invalid logger")
|
||||||
}
|
}
|
||||||
ew, okay := thl.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||||
if !okay {
|
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||||
t.Fatal("invalid type")
|
if configurable.NewConn == nil {
|
||||||
}
|
|
||||||
thc, okay := ew.TLSHandshaker.(*tlsHandshakerConfigurable)
|
|
||||||
if !okay {
|
|
||||||
t.Fatal("invalid type")
|
|
||||||
}
|
|
||||||
if thc.NewConn == nil {
|
|
||||||
t.Fatal("expected non-nil NewConn")
|
t.Fatal("expected non-nil NewConn")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUTLSConnHandshakeNotInterruptedSuccess(t *testing.T) {
|
func TestUTLSConn(t *testing.T) {
|
||||||
ctx := context.Background()
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
conn := &utlsConn{
|
t.Run("not interrupted with success", func(t *testing.T) {
|
||||||
testableHandshake: func() error {
|
ctx := context.Background()
|
||||||
return nil
|
conn := &utlsConn{
|
||||||
},
|
testableHandshake: func() error {
|
||||||
}
|
return nil
|
||||||
err := conn.HandshakeContext(ctx)
|
},
|
||||||
if err != nil {
|
}
|
||||||
t.Fatal(err)
|
err := conn.HandshakeContext(ctx)
|
||||||
}
|
if err != nil {
|
||||||
}
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
func TestUTLSConnHandshakeNotInterruptedFailure(t *testing.T) {
|
t.Run("not interrupted with failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn := &utlsConn{
|
conn := &utlsConn{
|
||||||
testableHandshake: func() error {
|
testableHandshake: func() error {
|
||||||
return expected
|
return expected
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := conn.HandshakeContext(ctx)
|
err := conn.HandshakeContext(ctx)
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestUTLSConnHandshakeInterrupted(t *testing.T) {
|
t.Run("interrupted", func(t *testing.T) {
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
sigch := make(chan interface{})
|
sigch := make(chan interface{})
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
conn := &utlsConn{
|
conn := &utlsConn{
|
||||||
testableHandshake: func() error {
|
testableHandshake: func() error {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
<-sigch
|
<-sigch
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := conn.HandshakeContext(ctx)
|
err := conn.HandshakeContext(ctx)
|
||||||
if !errors.Is(err, context.DeadlineExceeded) {
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
close(sigch)
|
close(sigch)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestUTLSConnHandshakePanic(t *testing.T) {
|
t.Run("with panic", func(t *testing.T) {
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn := &utlsConn{
|
conn := &utlsConn{
|
||||||
testableHandshake: func() error {
|
testableHandshake: func() error {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
panic("mascetti")
|
panic("mascetti")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := conn.HandshakeContext(ctx)
|
err := conn.HandshakeContext(ctx)
|
||||||
if !errors.Is(err, ErrUTLSHandshakePanic) {
|
if !errors.Is(err, ErrUTLSHandshakePanic) {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user