package netx_test import ( "crypto/tls" "errors" "net/http" "strings" "testing" "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/engine/netx" "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport" "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) func TestNewResolverVanilla(t *testing.T) { r := netx.NewResolver(netx.Config{}) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewResolverSpecificResolver(t *testing.T) { r := netx.NewResolver(netx.Config{ BaseResolver: resolver.BogonResolver{ // not initialized because it doesn't matter in this context }, }) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.BogonResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewResolverWithBogonFilter(t *testing.T) { r := netx.NewResolver(netx.Config{ BogonIsError: true, }) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } br, ok := ewr.Resolver.(resolver.BogonResolver) if !ok { t.Fatal("not the resolver we expected") } ar, ok := br.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewResolverWithLogging(t *testing.T) { r := netx.NewResolver(netx.Config{ Logger: log.Log, }) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } lr, ok := ir.Resolver.(resolver.LoggingResolver) if !ok { t.Fatal("not the resolver we expected") } if lr.Logger != log.Log { t.Fatal("not the logger we expected") } ewr, ok := lr.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewResolverWithSaver(t *testing.T) { saver := new(trace.Saver) r := netx.NewResolver(netx.Config{ ResolveSaver: saver, }) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } sr, ok := ir.Resolver.(resolver.SaverResolver) if !ok { t.Fatal("not the resolver we expected") } if sr.Saver != saver { t.Fatal("not the saver we expected") } ewr, ok := sr.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewResolverWithReadWriteCache(t *testing.T) { r := netx.NewResolver(netx.Config{ CacheResolutions: true, }) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } cr, ok := ewr.Resolver.(*resolver.CacheResolver) if !ok { t.Fatal("not the resolver we expected") } if cr.ReadOnly != false { t.Fatal("expected readwrite cache here") } ar, ok := cr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { r := netx.NewResolver(netx.Config{ DNSCache: map[string][]string{ "dns.google.com": {"8.8.8.8"}, }, }) ir, ok := r.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } cr, ok := ewr.Resolver.(*resolver.CacheResolver) if !ok { t.Fatal("not the resolver we expected") } if cr.ReadOnly != true { t.Fatal("expected readonly cache here") } if cr.Get("dns.google.com")[0] != "8.8.8.8" { t.Fatal("cache not correctly prefilled") } ar, ok := cr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } } func TestNewDialerVanilla(t *testing.T) { d := netx.NewDialer(netx.Config{}) sd, ok := d.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } pd, ok := sd.Dialer.(dialer.ProxyDialer) if !ok { t.Fatal("not the dialer we expected") } if pd.ProxyURL != nil { t.Fatal("not the proxy URL we expected") } dnsd, ok := pd.Dialer.(dialer.DNSDialer) if !ok { t.Fatal("not the dialer we expected") } if dnsd.Resolver == nil { t.Fatal("not the resolver we expected") } ir, ok := dnsd.Resolver.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) if !ok { t.Fatal("not the dialer we expected") } td, ok := ewd.Dialer.(dialer.TimeoutDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := td.Dialer.(selfcensor.SystemDialer); !ok { t.Fatal("not the dialer we expected") } } func TestNewDialerWithResolver(t *testing.T) { d := netx.NewDialer(netx.Config{ FullResolver: resolver.BogonResolver{ // not initialized because it doesn't matter in this context }, }) sd, ok := d.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } pd, ok := sd.Dialer.(dialer.ProxyDialer) if !ok { t.Fatal("not the dialer we expected") } if pd.ProxyURL != nil { t.Fatal("not the proxy URL we expected") } dnsd, ok := pd.Dialer.(dialer.DNSDialer) if !ok { t.Fatal("not the dialer we expected") } if dnsd.Resolver == nil { t.Fatal("not the resolver we expected") } if _, ok := dnsd.Resolver.(resolver.BogonResolver); !ok { t.Fatal("not the resolver we expected") } ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) if !ok { t.Fatal("not the dialer we expected") } td, ok := ewd.Dialer.(dialer.TimeoutDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := td.Dialer.(selfcensor.SystemDialer); !ok { t.Fatal("not the dialer we expected") } } func TestNewDialerWithLogger(t *testing.T) { d := netx.NewDialer(netx.Config{ Logger: log.Log, }) sd, ok := d.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } pd, ok := sd.Dialer.(dialer.ProxyDialer) if !ok { t.Fatal("not the dialer we expected") } if pd.ProxyURL != nil { t.Fatal("not the proxy URL we expected") } dnsd, ok := pd.Dialer.(dialer.DNSDialer) if !ok { t.Fatal("not the dialer we expected") } if dnsd.Resolver == nil { t.Fatal("not the resolver we expected") } ir, ok := dnsd.Resolver.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := ir.Resolver.(resolver.LoggingResolver); !ok { t.Fatal("not the resolver we expected") } ld, ok := dnsd.Dialer.(dialer.LoggingDialer) if !ok { t.Fatal("not the dialer we expected") } if ld.Logger != log.Log { t.Fatal("not the logger we expected") } ewd, ok := ld.Dialer.(dialer.ErrorWrapperDialer) if !ok { t.Fatal("not the dialer we expected") } td, ok := ewd.Dialer.(dialer.TimeoutDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := td.Dialer.(selfcensor.SystemDialer); !ok { t.Fatal("not the dialer we expected") } } func TestNewDialerWithDialSaver(t *testing.T) { saver := new(trace.Saver) d := netx.NewDialer(netx.Config{ DialSaver: saver, }) sd, ok := d.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } pd, ok := sd.Dialer.(dialer.ProxyDialer) if !ok { t.Fatal("not the dialer we expected") } if pd.ProxyURL != nil { t.Fatal("not the proxy URL we expected") } dnsd, ok := pd.Dialer.(dialer.DNSDialer) if !ok { t.Fatal("not the dialer we expected") } if dnsd.Resolver == nil { t.Fatal("not the resolver we expected") } ir, ok := dnsd.Resolver.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } sad, ok := dnsd.Dialer.(dialer.SaverDialer) if !ok { t.Fatal("not the dialer we expected") } if sad.Saver != saver { t.Fatal("not the logger we expected") } ewd, ok := sad.Dialer.(dialer.ErrorWrapperDialer) if !ok { t.Fatal("not the dialer we expected") } td, ok := ewd.Dialer.(dialer.TimeoutDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := td.Dialer.(selfcensor.SystemDialer); !ok { t.Fatal("not the dialer we expected") } } func TestNewDialerWithReadWriteSaver(t *testing.T) { saver := new(trace.Saver) d := netx.NewDialer(netx.Config{ ReadWriteSaver: saver, }) sd, ok := d.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } pd, ok := sd.Dialer.(dialer.ProxyDialer) if !ok { t.Fatal("not the dialer we expected") } if pd.ProxyURL != nil { t.Fatal("not the proxy URL we expected") } dnsd, ok := pd.Dialer.(dialer.DNSDialer) if !ok { t.Fatal("not the dialer we expected") } if dnsd.Resolver == nil { t.Fatal("not the resolver we expected") } ir, ok := dnsd.Resolver.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } scd, ok := dnsd.Dialer.(dialer.SaverConnDialer) if !ok { t.Fatal("not the dialer we expected") } if scd.Saver != saver { t.Fatal("not the logger we expected") } ewd, ok := scd.Dialer.(dialer.ErrorWrapperDialer) if !ok { t.Fatal("not the dialer we expected") } td, ok := ewd.Dialer.(dialer.TimeoutDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := td.Dialer.(selfcensor.SystemDialer); !ok { t.Fatal("not the dialer we expected") } } func TestNewDialerWithContextByteCounting(t *testing.T) { d := netx.NewDialer(netx.Config{ ContextByteCounting: true, }) sd, ok := d.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } bcd, ok := sd.Dialer.(dialer.ByteCounterDialer) if !ok { t.Fatal("not the dialer we expected") } pd, ok := bcd.Dialer.(dialer.ProxyDialer) if !ok { t.Fatal("not the dialer we expected") } if pd.ProxyURL != nil { t.Fatal("not the proxy URL we expected") } dnsd, ok := pd.Dialer.(dialer.DNSDialer) if !ok { t.Fatal("not the dialer we expected") } if dnsd.Resolver == nil { t.Fatal("not the resolver we expected") } ir, ok := dnsd.Resolver.(resolver.IDNAResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) if !ok { t.Fatal("not the dialer we expected") } td, ok := ewd.Dialer.(dialer.TimeoutDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := td.Dialer.(selfcensor.SystemDialer); !ok { t.Fatal("not the dialer we expected") } } func TestNewTLSDialerVanilla(t *testing.T) { td := netx.NewTLSDialer(netx.Config{}) rtd, ok := td.(dialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } if len(rtd.Config.NextProtos) != 2 { t.Fatal("invalid len(config.NextProtos)") } if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { t.Fatal("invalid Config.NextProtos") } if rtd.Config.RootCAs != netx.DefaultCertPool() { t.Fatal("invalid Config.RootCAs") } if rtd.Dialer == nil { t.Fatal("invalid Dialer") } sd, ok := rtd.Dialer.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { t.Fatal("not the Dialer we expected") } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } func TestNewTLSDialerWithConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSConfig: new(tls.Config), }) rtd, ok := td.(dialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } if len(rtd.Config.NextProtos) != 0 { t.Fatal("invalid len(config.NextProtos)") } if rtd.Config.RootCAs != netx.DefaultCertPool() { t.Fatal("invalid Config.RootCAs") } if rtd.Dialer == nil { t.Fatal("invalid Dialer") } sd, ok := rtd.Dialer.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { t.Fatal("not the Dialer we expected") } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } func TestNewTLSDialerWithLogging(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ Logger: log.Log, }) rtd, ok := td.(dialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } if len(rtd.Config.NextProtos) != 2 { t.Fatal("invalid len(config.NextProtos)") } if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { t.Fatal("invalid Config.NextProtos") } if rtd.Config.RootCAs != netx.DefaultCertPool() { t.Fatal("invalid Config.RootCAs") } if rtd.Dialer == nil { t.Fatal("invalid Dialer") } sd, ok := rtd.Dialer.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { t.Fatal("not the Dialer we expected") } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } lth, ok := rtd.TLSHandshaker.(dialer.LoggingTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if lth.Logger != log.Log { t.Fatal("not the Logger we expected") } ewth, ok := lth.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } func TestNewTLSDialerWithSaver(t *testing.T) { saver := new(trace.Saver) td := netx.NewTLSDialer(netx.Config{ TLSSaver: saver, }) rtd, ok := td.(dialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } if len(rtd.Config.NextProtos) != 2 { t.Fatal("invalid len(config.NextProtos)") } if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { t.Fatal("invalid Config.NextProtos") } if rtd.Config.RootCAs != netx.DefaultCertPool() { t.Fatal("invalid Config.RootCAs") } if rtd.Dialer == nil { t.Fatal("invalid Dialer") } sd, ok := rtd.Dialer.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { t.Fatal("not the Dialer we expected") } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } sth, ok := rtd.TLSHandshaker.(dialer.SaverTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if sth.Saver != saver { t.Fatal("not the Logger we expected") } ewth, ok := sth.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ TLSConfig: new(tls.Config), NoTLSVerify: true, }) rtd, ok := td.(dialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } if len(rtd.Config.NextProtos) != 0 { t.Fatal("invalid len(config.NextProtos)") } if rtd.Config.InsecureSkipVerify != true { t.Fatal("expected true InsecureSkipVerify") } if rtd.Config.RootCAs != netx.DefaultCertPool() { t.Fatal("invalid Config.RootCAs") } if rtd.Dialer == nil { t.Fatal("invalid Dialer") } sd, ok := rtd.Dialer.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { t.Fatal("not the Dialer we expected") } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) { td := netx.NewTLSDialer(netx.Config{ NoTLSVerify: true, }) rtd, ok := td.(dialer.TLSDialer) if !ok { t.Fatal("not the TLSDialer we expected") } if len(rtd.Config.NextProtos) != 2 { t.Fatal("invalid len(config.NextProtos)") } if rtd.Config.NextProtos[0] != "h2" || rtd.Config.NextProtos[1] != "http/1.1" { t.Fatal("invalid Config.NextProtos") } if rtd.Config.InsecureSkipVerify != true { t.Fatal("expected true InsecureSkipVerify") } if rtd.Config.RootCAs != netx.DefaultCertPool() { t.Fatal("invalid Config.RootCAs") } if rtd.Dialer == nil { t.Fatal("invalid Dialer") } sd, ok := rtd.Dialer.(dialer.ShapingDialer) if !ok { t.Fatal("not the dialer we expected") } if _, ok := sd.Dialer.(dialer.ProxyDialer); !ok { t.Fatal("not the Dialer we expected") } if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } ewth, ok := rtd.TLSHandshaker.(dialer.ErrorWrapperTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } tth, ok := ewth.TLSHandshaker.(dialer.TimeoutTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } if _, ok := tth.TLSHandshaker.(dialer.SystemTLSHandshaker); !ok { t.Fatal("not the TLSHandshaker we expected") } } func TestNewVanilla(t *testing.T) { txp := netx.NewHTTPTransport(netx.Config{}) uatxp, ok := txp.(httptransport.UserAgentTransport) if !ok { t.Fatal("not the transport we expected") } if _, ok := uatxp.RoundTripper.(*http.Transport); !ok { t.Fatal("not the transport we expected") } } func TestNewWithDialer(t *testing.T) { expected := errors.New("mocked error") dialer := netx.FakeDialer{Err: expected} txp := netx.NewHTTPTransport(netx.Config{ Dialer: dialer, }) client := &http.Client{Transport: txp} resp, err := client.Get("http://www.google.com") if !errors.Is(err, expected) { t.Fatal("not the error we expected") } if resp != nil { t.Fatal("not the response we expected") } } func TestNewWithTLSDialer(t *testing.T) { expected := errors.New("mocked error") tlsDialer := dialer.TLSDialer{ Config: new(tls.Config), Dialer: netx.FakeDialer{Err: expected}, TLSHandshaker: dialer.SystemTLSHandshaker{}, } txp := netx.NewHTTPTransport(netx.Config{ TLSDialer: tlsDialer, }) client := &http.Client{Transport: txp} resp, err := client.Get("https://www.google.com") if !errors.Is(err, expected) { t.Fatal("not the error we expected") } if resp != nil { t.Fatal("not the response we expected") } } func TestNewWithByteCounter(t *testing.T) { counter := bytecounter.New() txp := netx.NewHTTPTransport(netx.Config{ ByteCounter: counter, }) uatxp, ok := txp.(httptransport.UserAgentTransport) if !ok { t.Fatal("not the transport we expected") } bctxp, ok := uatxp.RoundTripper.(httptransport.ByteCountingTransport) if !ok { t.Fatal("not the transport we expected") } if bctxp.Counter != counter { t.Fatal("not the byte counter we expected") } if _, ok := bctxp.RoundTripper.(*http.Transport); !ok { t.Fatal("not the transport we expected") } } func TestNewWithLogger(t *testing.T) { txp := netx.NewHTTPTransport(netx.Config{ Logger: log.Log, }) uatxp, ok := txp.(httptransport.UserAgentTransport) if !ok { t.Fatal("not the transport we expected") } ltxp, ok := uatxp.RoundTripper.(httptransport.LoggingTransport) if !ok { t.Fatal("not the transport we expected") } if ltxp.Logger != log.Log { t.Fatal("not the logger we expected") } if _, ok := ltxp.RoundTripper.(*http.Transport); !ok { t.Fatal("not the transport we expected") } } func TestNewWithSaver(t *testing.T) { saver := new(trace.Saver) txp := netx.NewHTTPTransport(netx.Config{ HTTPSaver: saver, }) uatxp, ok := txp.(httptransport.UserAgentTransport) if !ok { t.Fatal("not the transport we expected") } stxptxp, ok := uatxp.RoundTripper.(httptransport.SaverTransactionHTTPTransport) if !ok { t.Fatal("not the transport we expected") } if stxptxp.Saver != saver { t.Fatal("not the logger we expected") } sptxp, ok := stxptxp.RoundTripper.(httptransport.SaverPerformanceHTTPTransport) if !ok { t.Fatal("not the transport we expected") } if sptxp.Saver != saver { t.Fatal("not the logger we expected") } sbtxp, ok := sptxp.RoundTripper.(httptransport.SaverBodyHTTPTransport) if !ok { t.Fatal("not the transport we expected") } if sbtxp.Saver != saver { t.Fatal("not the logger we expected") } smtxp, ok := sbtxp.RoundTripper.(httptransport.SaverMetadataHTTPTransport) if !ok { t.Fatal("not the transport we expected") } if smtxp.Saver != saver { t.Fatal("not the logger we expected") } if _, ok := smtxp.RoundTripper.(*http.Transport); !ok { t.Fatal("not the transport we expected") } } func TestNewDNSClientInvalidURL(t *testing.T) { dnsclient, err := netx.NewDNSClient(netx.Config{}, "\t\t\t") if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { t.Fatal("not the error we expected") } if dnsclient.Resolver != nil { t.Fatal("expected nil resolver here") } dnsclient.CloseIdleConnections() } func TestNewDNSClientUnsupportedScheme(t *testing.T) { dnsclient, err := netx.NewDNSClient(netx.Config{}, "antani:///") if err == nil || err.Error() != "unsupported resolver scheme" { t.Fatal("not the error we expected") } if dnsclient.Resolver != nil { t.Fatal("expected nil resolver here") } dnsclient.CloseIdleConnections() } func TestNewDNSClientSystemResolver(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "system:///") if err != nil { t.Fatal(err) } if _, ok := dnsclient.Resolver.(resolver.SystemResolver); !ok { t.Fatal("not the resolver we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientEmpty(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "") if err != nil { t.Fatal(err) } if _, ok := dnsclient.Resolver.(resolver.SystemResolver); !ok { t.Fatal("not the resolver we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientPowerdnsDoH(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "doh://powerdns") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientGoogleDoH(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "doh://google") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientCloudflareDoH(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "doh://cloudflare") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientCloudflareDoHSaver(t *testing.T) { saver := new(trace.Saver) dnsclient, err := netx.NewDNSClient( netx.Config{ResolveSaver: saver}, "doh://cloudflare") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } txp, ok := r.Transport().(resolver.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } if _, ok := txp.RoundTripper.(resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientUDP(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "udp://8.8.8.8:53") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } if _, ok := r.Transport().(resolver.DNSOverUDP); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientUDPDNSSaver(t *testing.T) { saver := new(trace.Saver) dnsclient, err := netx.NewDNSClient( netx.Config{ResolveSaver: saver}, "udp://8.8.8.8:53") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } txp, ok := r.Transport().(resolver.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } if _, ok := txp.RoundTripper.(resolver.DNSOverUDP); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientTCP(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "tcp://8.8.8.8:53") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } txp, ok := r.Transport().(resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } if txp.Network() != "tcp" { t.Fatal("not the Network we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientTCPDNSSaver(t *testing.T) { saver := new(trace.Saver) dnsclient, err := netx.NewDNSClient( netx.Config{ResolveSaver: saver}, "tcp://8.8.8.8:53") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } txp, ok := r.Transport().(resolver.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } dotcp, ok := txp.RoundTripper.(resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } if dotcp.Network() != "tcp" { t.Fatal("not the Network we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientDoT(t *testing.T) { dnsclient, err := netx.NewDNSClient( netx.Config{}, "dot://8.8.8.8:53") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } txp, ok := r.Transport().(resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } if txp.Network() != "dot" { t.Fatal("not the Network we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSClientDoTDNSSaver(t *testing.T) { saver := new(trace.Saver) dnsclient, err := netx.NewDNSClient( netx.Config{ResolveSaver: saver}, "dot://8.8.8.8:53") if err != nil { t.Fatal(err) } r, ok := dnsclient.Resolver.(resolver.SerialResolver) if !ok { t.Fatal("not the resolver we expected") } txp, ok := r.Transport().(resolver.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } dotls, ok := txp.RoundTripper.(resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } if dotls.Network() != "dot" { t.Fatal("not the Network we expected") } dnsclient.CloseIdleConnections() } func TestNewDNSCLientDoTWithoutPort(t *testing.T) { c, err := netx.NewDNSClientWithOverrides( netx.Config{}, "dot://8.8.8.8", "", "8.8.8.8", "") if err != nil { t.Fatal(err) } if c.Resolver.Address() != "8.8.8.8:853" { t.Fatal("expected default port to be added") } } func TestNewDNSCLientTCPWithoutPort(t *testing.T) { c, err := netx.NewDNSClientWithOverrides( netx.Config{}, "tcp://8.8.8.8", "", "8.8.8.8", "") if err != nil { t.Fatal(err) } if c.Resolver.Address() != "8.8.8.8:53" { t.Fatal("expected default port to be added") } } func TestNewDNSCLientUDPWithoutPort(t *testing.T) { c, err := netx.NewDNSClientWithOverrides( netx.Config{}, "udp://8.8.8.8", "", "8.8.8.8", "") if err != nil { t.Fatal(err) } if c.Resolver.Address() != "8.8.8.8:53" { t.Fatal("expected default port to be added") } } func TestNewDNSClientBadDoTEndpoint(t *testing.T) { _, err := netx.NewDNSClient( netx.Config{}, "dot://bad:endpoint:53") if err == nil || !strings.Contains(err.Error(), "too many colons in address") { t.Fatal("expected error with bad endpoint") } } func TestNewDNSClientBadTCPEndpoint(t *testing.T) { _, err := netx.NewDNSClient( netx.Config{}, "tcp://bad:endpoint:853") if err == nil || !strings.Contains(err.Error(), "too many colons in address") { t.Fatal("expected error with bad endpoint") } } func TestNewDNSClientBadUDPEndpoint(t *testing.T) { _, err := netx.NewDNSClient( netx.Config{}, "udp://bad:endpoint:853") if err == nil || !strings.Contains(err.Error(), "too many colons in address") { t.Fatal("expected error with bad endpoint") } } func TestNewDNSCLientWithInvalidTLSVersion(t *testing.T) { _, err := netx.NewDNSClientWithOverrides( netx.Config{}, "dot://8.8.8.8", "", "", "TLSv999") if !errors.Is(err, netx.ErrInvalidTLSVersion) { t.Fatalf("not the error we expected: %+v", err) } } func TestConfigureTLSVersion(t *testing.T) { tests := []struct { name string version string wantErr error versionMin int versionMax int }{{ name: "with TLSv1.3", version: "TLSv1.3", wantErr: nil, versionMin: tls.VersionTLS13, versionMax: tls.VersionTLS13, }, { name: "with TLSv1.2", version: "TLSv1.2", wantErr: nil, versionMin: tls.VersionTLS12, versionMax: tls.VersionTLS12, }, { name: "with TLSv1.1", version: "TLSv1.1", wantErr: nil, versionMin: tls.VersionTLS11, versionMax: tls.VersionTLS11, }, { name: "with TLSv1.0", version: "TLSv1.0", wantErr: nil, versionMin: tls.VersionTLS10, versionMax: tls.VersionTLS10, }, { name: "with TLSv1", version: "TLSv1", wantErr: nil, versionMin: tls.VersionTLS10, versionMax: tls.VersionTLS10, }, { name: "with default", version: "", wantErr: nil, versionMin: 0, versionMax: 0, }, { name: "with invalid version", version: "TLSv999", wantErr: netx.ErrInvalidTLSVersion, versionMin: 0, versionMax: 0, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conf := new(tls.Config) err := netx.ConfigureTLSVersion(conf, tt.version) if !errors.Is(err, tt.wantErr) { t.Fatalf("not the error we expected: %+v", err) } if conf.MinVersion != uint16(tt.versionMin) { t.Fatalf("not the min version we expected: %+v", conf.MinVersion) } if conf.MaxVersion != uint16(tt.versionMax) { t.Fatalf("not the max version we expected: %+v", conf.MaxVersion) } }) } }