From a3654f60b72bc4f361d07591ea6e9be5d369774b Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Sun, 5 Sep 2021 18:03:50 +0200 Subject: [PATCH] refactor(netxlite): add more functions to resolver (#455) We would like to refactor the code so that a DoH resolver owns the connections of its underlying HTTP client. To do that, we need first to incorporate CloseIdleConnections into the Resolver model. Then, we need to add the same function to all netxlite types that wrap a Resolver type. At the same time, we want the rest of the code for now to continue with the simpler definition of a Resolver, now called ResolverLegacy. We will eventually propagate this change to the rest of the tree and simplify the way in which we manage Resolvers. To make this possible, we introduce a new factory function that adapts a ResolverLegacy to become a Resolver. See https://github.com/ooni/probe/issues/1591. --- .../oohelperd/internal/websteps/explore.go | 2 +- .../oohelperd/internal/websteps/generate.go | 2 +- .../internal/websteps/initialchecks.go | 2 +- .../oohelperd/internal/websteps/measure.go | 10 +-- internal/engine/experiment/ndt7/dial.go | 5 +- internal/engine/experiment/websteps/dns.go | 6 +- .../engine/experiment/websteps/factory.go | 14 ++-- internal/engine/experiment/websteps/quic.go | 2 +- internal/engine/experiment/websteps/tcp.go | 2 +- internal/engine/netx/dialer/dialer.go | 5 +- internal/engine/netx/netx.go | 12 +++- internal/engine/netx/netx_test.go | 56 ++++++++++++--- .../engine/netx/resolver/integration_test.go | 10 ++- internal/netxlite/http3_test.go | 5 +- internal/netxlite/http_test.go | 4 +- internal/netxlite/legacy.go | 63 +++++++++++++++++ internal/netxlite/legacy_test.go | 38 ++++++++++ internal/netxlite/mocks/resolver.go | 12 +++- internal/netxlite/mocks/resolver_test.go | 13 ++++ internal/netxlite/quic_test.go | 12 ++-- internal/netxlite/resolver.go | 70 ++++++++----------- internal/netxlite/resolver_test.go | 53 ++++---------- 22 files changed, 279 insertions(+), 119 deletions(-) diff --git a/internal/cmd/oohelperd/internal/websteps/explore.go b/internal/cmd/oohelperd/internal/websteps/explore.go index 2ec9e5d..5f52b76 100644 --- a/internal/cmd/oohelperd/internal/websteps/explore.go +++ b/internal/cmd/oohelperd/internal/websteps/explore.go @@ -27,7 +27,7 @@ type Explorer interface { // DefaultExplorer is the default Explorer. type DefaultExplorer struct { - resolver netxlite.Resolver + resolver netxlite.ResolverLegacy } // Explore returns a list of round trips sorted so that the first diff --git a/internal/cmd/oohelperd/internal/websteps/generate.go b/internal/cmd/oohelperd/internal/websteps/generate.go index 430c239..3a7fa51 100644 --- a/internal/cmd/oohelperd/internal/websteps/generate.go +++ b/internal/cmd/oohelperd/internal/websteps/generate.go @@ -24,7 +24,7 @@ type Generator interface { type DefaultGenerator struct { dialer netxlite.Dialer quicDialer netxlite.QUICContextDialer - resolver netxlite.Resolver + resolver netxlite.ResolverLegacy transport http.RoundTripper } diff --git a/internal/cmd/oohelperd/internal/websteps/initialchecks.go b/internal/cmd/oohelperd/internal/websteps/initialchecks.go index f0c7e95..183a243 100644 --- a/internal/cmd/oohelperd/internal/websteps/initialchecks.go +++ b/internal/cmd/oohelperd/internal/websteps/initialchecks.go @@ -31,7 +31,7 @@ type InitChecker interface { // DefaultInitChecker is the default InitChecker. type DefaultInitChecker struct { - resolver netxlite.Resolver + resolver netxlite.ResolverLegacy } // InitialChecks checks whether the URL is valid and whether the diff --git a/internal/cmd/oohelperd/internal/websteps/measure.go b/internal/cmd/oohelperd/internal/websteps/measure.go index 5482bae..eb27c14 100644 --- a/internal/cmd/oohelperd/internal/websteps/measure.go +++ b/internal/cmd/oohelperd/internal/websteps/measure.go @@ -24,7 +24,7 @@ type Config struct { checker InitChecker explorer Explorer generator Generator - resolver netxlite.Resolver + resolver netxlite.ResolverLegacy } // Measure performs the three consecutive steps of the testhelper algorithm: @@ -87,10 +87,12 @@ func newDNSFailedResponse(err error, URL string) *ControlResponse { } // newResolver creates a new DNS resolver instance -func newResolver() netxlite.Resolver { +func newResolver() netxlite.ResolverLegacy { childResolver, err := netx.NewDNSClient(netx.Config{Logger: log.Log}, "doh://google") runtimex.PanicOnError(err, "NewDNSClient failed") - var r netxlite.Resolver = childResolver - r = &netxlite.ResolverIDNA{Resolver: r} + var r netxlite.ResolverLegacy = childResolver + r = &netxlite.ResolverIDNA{ + Resolver: netxlite.NewResolverLegacyAdapter(r), + } return r } diff --git a/internal/engine/experiment/ndt7/dial.go b/internal/engine/experiment/ndt7/dial.go index c627513..95e41c5 100644 --- a/internal/engine/experiment/ndt7/dial.go +++ b/internal/engine/experiment/ndt7/dial.go @@ -35,7 +35,10 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { var reso resolver.Resolver = &netxlite.ResolverSystem{} - reso = &netxlite.ResolverLogger{Resolver: reso, Logger: mgr.logger} + reso = &netxlite.ResolverLogger{ + Resolver: netxlite.NewResolverLegacyAdapter(reso), + Logger: mgr.logger, + } dlr := dialer.New(&dialer.Config{ ContextByteCounting: true, Logger: mgr.logger, diff --git a/internal/engine/experiment/websteps/dns.go b/internal/engine/experiment/websteps/dns.go index 80615e9..c07531f 100644 --- a/internal/engine/experiment/websteps/dns.go +++ b/internal/engine/experiment/websteps/dns.go @@ -11,7 +11,7 @@ import ( type DNSConfig struct { Domain string - Resolver netxlite.Resolver + Resolver netxlite.ResolverLegacy } // DNSDo performs the DNS check. @@ -21,7 +21,9 @@ func DNSDo(ctx context.Context, config DNSConfig) ([]string, error) { childResolver, err := netx.NewDNSClient(netx.Config{Logger: log.Log}, "doh://google") runtimex.PanicOnError(err, "NewDNSClient failed") resolver = childResolver - resolver = &netxlite.ResolverIDNA{Resolver: resolver} + resolver = &netxlite.ResolverIDNA{ + Resolver: netxlite.NewResolverLegacyAdapter(resolver), + } } return resolver.LookupHost(ctx, config.Domain) } diff --git a/internal/engine/experiment/websteps/factory.go b/internal/engine/experiment/websteps/factory.go index 70a6581..0b7f317 100644 --- a/internal/engine/experiment/websteps/factory.go +++ b/internal/engine/experiment/websteps/factory.go @@ -33,23 +33,29 @@ func NewRequest(ctx context.Context, URL *url.URL, headers http.Header) *http.Re // NewDialerResolver contructs a new dialer for TCP connections, // with default, errorwrapping and resolve functionalities -func NewDialerResolver(resolver netxlite.Resolver) netxlite.Dialer { +func NewDialerResolver(resolver netxlite.ResolverLegacy) netxlite.Dialer { var d netxlite.Dialer = netxlite.DefaultDialer d = &errorsx.ErrorWrapperDialer{Dialer: d} - d = &netxlite.DialerResolver{Resolver: resolver, Dialer: d} + d = &netxlite.DialerResolver{ + Resolver: netxlite.NewResolverLegacyAdapter(resolver), + Dialer: d, + } return d } // NewQUICDialerResolver creates a new QUICDialerResolver // with default, errorwrapping and resolve functionalities -func NewQUICDialerResolver(resolver netxlite.Resolver) netxlite.QUICContextDialer { +func NewQUICDialerResolver(resolver netxlite.ResolverLegacy) netxlite.QUICContextDialer { var ql quicdialer.QUICListener = &netxlite.QUICListenerStdlib{} ql = &errorsx.ErrorWrapperQUICListener{QUICListener: ql} var dialer netxlite.QUICContextDialer = &netxlite.QUICDialerQUICGo{ QUICListener: ql, } dialer = &errorsx.ErrorWrapperQUICDialer{Dialer: dialer} - dialer = &netxlite.QUICDialerResolver{Resolver: resolver, Dialer: dialer} + dialer = &netxlite.QUICDialerResolver{ + Resolver: netxlite.NewResolverLegacyAdapter(resolver), + Dialer: dialer, + } return dialer } diff --git a/internal/engine/experiment/websteps/quic.go b/internal/engine/experiment/websteps/quic.go index 8667252..70a4b04 100644 --- a/internal/engine/experiment/websteps/quic.go +++ b/internal/engine/experiment/websteps/quic.go @@ -11,7 +11,7 @@ import ( type QUICConfig struct { Endpoint string QUICDialer netxlite.QUICContextDialer - Resolver netxlite.Resolver + Resolver netxlite.ResolverLegacy TLSConf *tls.Config } diff --git a/internal/engine/experiment/websteps/tcp.go b/internal/engine/experiment/websteps/tcp.go index 3e7481d..67bea47 100644 --- a/internal/engine/experiment/websteps/tcp.go +++ b/internal/engine/experiment/websteps/tcp.go @@ -10,7 +10,7 @@ import ( type TCPConfig struct { Dialer netxlite.Dialer Endpoint string - Resolver netxlite.Resolver + Resolver netxlite.ResolverLegacy } // TCPDo performs the TCP check. diff --git a/internal/engine/netx/dialer/dialer.go b/internal/engine/netx/dialer/dialer.go index d3ab02d..31a196a 100644 --- a/internal/engine/netx/dialer/dialer.go +++ b/internal/engine/netx/dialer/dialer.go @@ -80,7 +80,10 @@ func New(config *Config, resolver Resolver) Dialer { if config.ReadWriteSaver != nil { d = &saverConnDialer{Dialer: d, Saver: config.ReadWriteSaver} } - d = &netxlite.DialerResolver{Resolver: resolver, Dialer: d} + d = &netxlite.DialerResolver{ + Resolver: netxlite.NewResolverLegacyAdapter(resolver), + Dialer: d, + } d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d} if config.ContextByteCounting { d = &byteCounterDialer{Dialer: d} diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 73ce898..354b087 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -134,12 +134,15 @@ func NewResolver(config Config) Resolver { } r = &errorsx.ErrorWrapperResolver{Resolver: r} if config.Logger != nil { - r = &netxlite.ResolverLogger{Logger: config.Logger, Resolver: r} + r = &netxlite.ResolverLogger{ + Logger: config.Logger, + Resolver: netxlite.NewResolverLegacyAdapter(r), + } } if config.ResolveSaver != nil { r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver} } - return &resolver.IDNAResolver{Resolver: r} + return &resolver.IDNAResolver{Resolver: netxlite.NewResolverLegacyAdapter(r)} } // NewDialer creates a new Dialer from the specified config @@ -176,7 +179,10 @@ func NewQUICDialer(config Config) QUICDialer { if config.TLSSaver != nil { d = quicdialer.HandshakeSaver{Saver: config.TLSSaver, Dialer: d} } - d = &netxlite.QUICDialerResolver{Resolver: config.FullResolver, Dialer: d} + d = &netxlite.QUICDialerResolver{ + Resolver: netxlite.NewResolverLegacyAdapter(config.FullResolver), + Dialer: d, + } return d } diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 3d29a85..357be66 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -24,7 +24,11 @@ func TestNewResolverVanilla(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -48,7 +52,11 @@ func TestNewResolverSpecificResolver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -70,7 +78,11 @@ func TestNewResolverWithBogonFilter(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -96,17 +108,33 @@ func TestNewResolverWithLogging(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - lr, ok := ir.Resolver.(*netxlite.ResolverLogger) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + lr, ok := rla.ResolverLegacy.(*netxlite.ResolverLogger) if !ok { t.Fatal("not the resolver we expected") } if lr.Logger != log.Log { t.Fatal("not the logger we expected") } - ewr, ok := lr.Resolver.(*errorsx.ErrorWrapperResolver) + rla, ok = ir.Resolver.(*netxlite.ResolverLegacyAdapter) if !ok { t.Fatal("not the resolver we expected") } + lr, ok = rla.ResolverLegacy.(*netxlite.ResolverLogger) + if !ok { + t.Fatal("not the resolver we expected") + } + rla, ok = lr.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver) + if !ok { + t.Fatalf("not the resolver we expected %T", rla.ResolverLegacy) + } ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") @@ -126,7 +154,11 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - sr, ok := ir.Resolver.(resolver.SaverResolver) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + sr, ok := rla.ResolverLegacy.(resolver.SaverResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -155,7 +187,11 @@ func TestNewResolverWithReadWriteCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -186,7 +222,11 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) + rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter) + if !ok { + t.Fatal("not the resolver we expected") + } + ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } diff --git a/internal/engine/netx/resolver/integration_test.go b/internal/engine/netx/resolver/integration_test.go index 7de5a7e..b8d7a75 100644 --- a/internal/engine/netx/resolver/integration_test.go +++ b/internal/engine/netx/resolver/integration_test.go @@ -19,7 +19,10 @@ func testresolverquick(t *testing.T, reso resolver.Resolver) { if testing.Short() { t.Skip("skip test in short mode") } - reso = &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso} + reso = &netxlite.ResolverLogger{ + Logger: log.Log, + Resolver: netxlite.NewResolverLegacyAdapter(reso), + } addrs, err := reso.LookupHost(context.Background(), "dns.google.com") if err != nil { t.Fatal(err) @@ -45,7 +48,10 @@ func testresolverquickidna(t *testing.T, reso resolver.Resolver) { t.Skip("skip test in short mode") } reso = &resolver.IDNAResolver{ - Resolver: &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso}, + Resolver: &netxlite.ResolverLogger{ + Logger: log.Log, + Resolver: netxlite.NewResolverLegacyAdapter(reso), + }, } addrs, err := reso.LookupHost(context.Background(), "яндекс.рф") if err != nil { diff --git a/internal/netxlite/http3_test.go b/internal/netxlite/http3_test.go index 2234bdf..e3e9434 100644 --- a/internal/netxlite/http3_test.go +++ b/internal/netxlite/http3_test.go @@ -2,9 +2,10 @@ package netxlite import ( "crypto/tls" - "net" "net/http" "testing" + + "github.com/apex/log" ) func TestHTTP3TransportWorks(t *testing.T) { @@ -12,7 +13,7 @@ func TestHTTP3TransportWorks(t *testing.T) { Dialer: &quicDialerQUICGo{ QUICListener: &quicListenerStdlib{}, }, - Resolver: &net.Resolver{}, + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), } txp := NewHTTP3Transport(d, &tls.Config{}) client := &http.Client{Transport: txp} diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index a094ff8..dae6373 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -112,7 +112,7 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) { func TestHTTPTransportWorks(t *testing.T) { d := &dialerResolver{ Dialer: defaultDialer, - Resolver: &net.Resolver{}, + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), } th := &tlsHandshakerConfigurable{} txp := NewHTTPTransport(d, &tls.Config{}, th) @@ -134,7 +134,7 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) { return nil, expected }, }, - Resolver: &net.Resolver{}, + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), } th := &tlsHandshakerConfigurable{} txp := NewHTTPTransport(d, &tls.Config{}, th) diff --git a/internal/netxlite/legacy.go b/internal/netxlite/legacy.go index d5bbb89..6e6aad1 100644 --- a/internal/netxlite/legacy.go +++ b/internal/netxlite/legacy.go @@ -1,6 +1,7 @@ package netxlite import ( + "context" "errors" "strings" @@ -59,3 +60,65 @@ type ( TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerLogger = tlsHandshakerLogger ) + +// ResolverLegacy performs domain name resolutions. +// +// This definition of Resolver is DEPRECATED. New code should use +// the more complete definition in the new Resolver interface. +// +// Existing code in ooni/probe-cli is still using this definition. +type ResolverLegacy interface { + // LookupHost behaves like net.Resolver.LookupHost. + LookupHost(ctx context.Context, hostname string) (addrs []string, err error) +} + +// NewResolverLegacyAdapter adapts a ResolverLegacy to +// become compatible with the Resolver definition. +func NewResolverLegacyAdapter(reso ResolverLegacy) Resolver { + return &ResolverLegacyAdapter{reso} +} + +// ResolverLegacyAdapter makes a ResolverLegacy behave like +// it was a Resolver type. If ResolverLegacy is actually also +// a Resolver, this adapter will just forward missing calls, +// otherwise it will implement a sensible default action. +type ResolverLegacyAdapter struct { + ResolverLegacy +} + +var _ Resolver = &ResolverLegacyAdapter{} + +type resolverLegacyNetworker interface { + Network() string +} + +// Network implements Resolver.Network. +func (r *ResolverLegacyAdapter) Network() string { + if rn, ok := r.ResolverLegacy.(resolverLegacyNetworker); ok { + return rn.Network() + } + return "adapter" +} + +type resolverLegacyAddresser interface { + Address() string +} + +// Address implements Resolver.Address. +func (r *ResolverLegacyAdapter) Address() string { + if ra, ok := r.ResolverLegacy.(resolverLegacyAddresser); ok { + return ra.Address() + } + return "" +} + +type resolverLegacyIdleConnectionsCloser interface { + CloseIdleConnections() +} + +// CloseIdleConnections implements Resolver.CloseIdleConnections. +func (r *ResolverLegacyAdapter) CloseIdleConnections() { + if ra, ok := r.ResolverLegacy.(resolverLegacyIdleConnectionsCloser); ok { + ra.CloseIdleConnections() + } +} diff --git a/internal/netxlite/legacy_test.go b/internal/netxlite/legacy_test.go index b21d356..8d44cce 100644 --- a/internal/netxlite/legacy_test.go +++ b/internal/netxlite/legacy_test.go @@ -2,9 +2,11 @@ package netxlite import ( "errors" + "net" "testing" "github.com/ooni/probe-cli/v3/internal/errorsx" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) func TestReduceErrors(t *testing.T) { @@ -44,3 +46,39 @@ func TestReduceErrors(t *testing.T) { } }) } + +func TestResolverLegacyAdapterWithCompatibleType(t *testing.T) { + var called bool + r := NewResolverLegacyAdapter(&mocks.Resolver{ + MockNetwork: func() string { + return "network" + }, + MockAddress: func() string { + return "address" + }, + MockCloseIdleConnections: func() { + called = true + }, + }) + if r.Network() != "network" { + t.Fatal("invalid Network") + } + if r.Address() != "address" { + t.Fatal("invalid Address") + } + r.CloseIdleConnections() + if !called { + t.Fatal("not called") + } +} + +func TestResolverLegacyAdapterDefaults(t *testing.T) { + r := NewResolverLegacyAdapter(&net.Resolver{}) + if r.Network() != "adapter" { + t.Fatal("invalid Network") + } + if r.Address() != "" { + t.Fatal("invalid Address") + } + r.CloseIdleConnections() // does not crash +} diff --git a/internal/netxlite/mocks/resolver.go b/internal/netxlite/mocks/resolver.go index 6ab5a40..3abf749 100644 --- a/internal/netxlite/mocks/resolver.go +++ b/internal/netxlite/mocks/resolver.go @@ -4,9 +4,10 @@ import "context" // Resolver is a mockable Resolver. type Resolver struct { - MockLookupHost func(ctx context.Context, domain string) ([]string, error) - MockNetwork func() string - MockAddress func() string + MockLookupHost func(ctx context.Context, domain string) ([]string, error) + MockNetwork func() string + MockAddress func() string + MockCloseIdleConnections func() } // LookupHost calls MockLookupHost. @@ -23,3 +24,8 @@ func (r *Resolver) Address() string { func (r *Resolver) Network() string { return r.MockNetwork() } + +// CloseIdleConnections calls MockCloseIdleConnections. +func (r *Resolver) CloseIdleConnections() { + r.MockCloseIdleConnections() +} diff --git a/internal/netxlite/mocks/resolver_test.go b/internal/netxlite/mocks/resolver_test.go index 79686c6..8d1ad85 100644 --- a/internal/netxlite/mocks/resolver_test.go +++ b/internal/netxlite/mocks/resolver_test.go @@ -44,3 +44,16 @@ func TestResolverAddress(t *testing.T) { t.Fatal("unexpected address", v) } } + +func TestResolverCloseIdleConnections(t *testing.T) { + var called bool + r := &Resolver{ + MockCloseIdleConnections: func() { + called = true + }, + } + r.CloseIdleConnections() + if !called { + t.Fatal("not called") + } +} diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index d3c2b88..1e2278c 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -215,7 +215,8 @@ func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) { func TestQUICDialerResolverSuccess(t *testing.T) { tlsConfig := &tls.Config{} dialer := &quicDialerResolver{ - Resolver: &net.Resolver{}, Dialer: &quicDialerQUICGo{ + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), + Dialer: &quicDialerQUICGo{ QUICListener: &quicListenerStdlib{}, }} sess, err := dialer.DialContext( @@ -233,7 +234,8 @@ func TestQUICDialerResolverSuccess(t *testing.T) { func TestQUICDialerResolverNoPort(t *testing.T) { tlsConfig := &tls.Config{} dialer := &quicDialerResolver{ - Resolver: new(net.Resolver), Dialer: &quicDialerQUICGo{}} + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), + Dialer: &quicDialerQUICGo{}} sess, err := dialer.DialContext( context.Background(), "udp", "www.google.com", tlsConfig, &quic.Config{}) @@ -286,7 +288,8 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) { // to establish a connection leads to a failure tlsConf := &tls.Config{} dialer := &quicDialerResolver{ - Resolver: new(net.Resolver), Dialer: &quicDialerQUICGo{ + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), + Dialer: &quicDialerQUICGo{ QUICListener: &quicListenerStdlib{}, }} sess, err := dialer.DialContext( @@ -309,7 +312,8 @@ func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) { var gotTLSConfig *tls.Config tlsConfig := &tls.Config{} dialer := &quicDialerResolver{ - Resolver: new(net.Resolver), Dialer: &mocks.QUICContextDialer{ + Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), + Dialer: &mocks.QUICContextDialer{ MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { gotTLSConfig = tlsConfig diff --git a/internal/netxlite/resolver.go b/internal/netxlite/resolver.go index 5fa7374..cd79a92 100644 --- a/internal/netxlite/resolver.go +++ b/internal/netxlite/resolver.go @@ -12,6 +12,31 @@ import ( type Resolver interface { // LookupHost behaves like net.Resolver.LookupHost. LookupHost(ctx context.Context, hostname string) (addrs []string, err error) + + // Network returns the resolver type (e.g., system, dot, doh). + Network() string + + // Address returns the resolver address (e.g., 8.8.8.8:53). + Address() string + + // CloseIdleConnections closes idle connections, if any. + CloseIdleConnections() +} + +// ResolverConfig contains config for creating a resolver. +type ResolverConfig struct { + // Logger is the MANDATORY logger to use. + Logger Logger +} + +// NewResolver creates a new resolver. +func NewResolver(config *ResolverConfig) Resolver { + return &resolverIDNA{ + Resolver: &resolverLogger{ + Resolver: &resolverSystem{}, + Logger: config.Logger, + }, + } } // resolverSystem is the system resolver. @@ -34,6 +59,11 @@ func (r *resolverSystem) Address() string { return "" } +// CloseIdleConnections implements Resolver.CloseIdleConnections. +func (r *resolverSystem) CloseIdleConnections() { + // nothing +} + // DefaultResolver is the resolver we use by default. var DefaultResolver = &resolverSystem{} @@ -59,30 +89,6 @@ func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]str return addrs, nil } -type resolverNetworker interface { - Network() string -} - -// Network implements Resolver.Network. -func (r *resolverLogger) Network() string { - if rn, ok := r.Resolver.(resolverNetworker); ok { - return rn.Network() - } - return "logger" -} - -type resolverAddresser interface { - Address() string -} - -// Address implements Resolver.Address. -func (r *resolverLogger) Address() string { - if ra, ok := r.Resolver.(resolverAddresser); ok { - return ra.Address() - } - return "" -} - // resolverIDNA supports resolving Internationalized Domain Names. // // See RFC3492 for more information. @@ -98,19 +104,3 @@ func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]strin } return r.Resolver.LookupHost(ctx, host) } - -// Network implements Resolver.Network. -func (r *resolverIDNA) Network() string { - if rn, ok := r.Resolver.(resolverNetworker); ok { - return rn.Network() - } - return "idna" -} - -// Address implements Resolver.Address. -func (r *resolverIDNA) Address() string { - if ra, ok := r.Resolver.(resolverAddresser); ok { - return ra.Address() - } - return "" -} diff --git a/internal/netxlite/resolver_test.go b/internal/netxlite/resolver_test.go index 80314f1..601350a 100644 --- a/internal/netxlite/resolver_test.go +++ b/internal/netxlite/resolver_test.go @@ -3,7 +3,6 @@ package netxlite import ( "context" "errors" - "net" "strings" "testing" @@ -71,26 +70,6 @@ func TestResolverLoggerWithFailure(t *testing.T) { } } -func TestResolverLoggerChildNetworkAddress(t *testing.T) { - r := &resolverLogger{Logger: log.Log, Resolver: DefaultResolver} - if r.Network() != "system" { - t.Fatal("invalid Network") - } - if r.Address() != "" { - t.Fatal("invalid Address") - } -} - -func TestResolverLoggerNoChildNetworkAddress(t *testing.T) { - r := &resolverLogger{Logger: log.Log, Resolver: &net.Resolver{}} - if r.Network() != "logger" { - t.Fatal("invalid Network") - } - if r.Address() != "" { - t.Fatal("invalid Address") - } -} - func TestResolverIDNAWorksAsIntended(t *testing.T) { expectedIPs := []string{"77.88.55.66"} r := &resolverIDNA{ @@ -130,24 +109,22 @@ func TestResolverIDNAWithInvalidPunycode(t *testing.T) { } } -func TestResolverIDNAChildNetworkAddress(t *testing.T) { - r := &resolverIDNA{ - Resolver: DefaultResolver, +func TestNewResolverTypeChain(t *testing.T) { + r := NewResolver(&ResolverConfig{ + Logger: log.Log, + }) + ridna, ok := r.(*resolverIDNA) + if !ok { + t.Fatal("invalid resolver") } - if v := r.Network(); v != "system" { - t.Fatal("invalid network", v) + rl, ok := ridna.Resolver.(*resolverLogger) + if !ok { + t.Fatal("invalid resolver") } - if v := r.Address(); v != "" { - t.Fatal("invalid address", v) - } -} - -func TestResolverIDNANoChildNetworkAddress(t *testing.T) { - r := &resolverIDNA{} - if v := r.Network(); v != "idna" { - t.Fatal("invalid network", v) - } - if v := r.Address(); v != "" { - t.Fatal("invalid address", v) + if rl.Logger != log.Log { + t.Fatal("invalid logger") + } + if _, ok := rl.Resolver.(*resolverSystem); !ok { + t.Fatal("invalid resolver") } }