refactor(netxlite): add more functions to resolver (#455)

We would like to refactor the code so that a DoH resolver owns the
connections of its underlying HTTP client.

To do that, we need first to incorporate CloseIdleConnections
into the Resolver model. Then, we need to add the same function
to all netxlite types that wrap a Resolver type.

At the same time, we want the rest of the code for now to continue
with the simpler definition of a Resolver, now called ResolverLegacy.

We will eventually propagate this change to the rest of the tree
and simplify the way in which we manage Resolvers.

To make this possible, we introduce a new factory function that
adapts a ResolverLegacy to become a Resolver.

See https://github.com/ooni/probe/issues/1591.
This commit is contained in:
Simone Basso 2021-09-05 18:03:50 +02:00 committed by GitHub
parent 2e0118d1a6
commit a3654f60b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 279 additions and 119 deletions

View File

@ -27,7 +27,7 @@ type Explorer interface {
// DefaultExplorer is the default Explorer. // DefaultExplorer is the default Explorer.
type DefaultExplorer struct { type DefaultExplorer struct {
resolver netxlite.Resolver resolver netxlite.ResolverLegacy
} }
// Explore returns a list of round trips sorted so that the first // Explore returns a list of round trips sorted so that the first

View File

@ -24,7 +24,7 @@ type Generator interface {
type DefaultGenerator struct { type DefaultGenerator struct {
dialer netxlite.Dialer dialer netxlite.Dialer
quicDialer netxlite.QUICContextDialer quicDialer netxlite.QUICContextDialer
resolver netxlite.Resolver resolver netxlite.ResolverLegacy
transport http.RoundTripper transport http.RoundTripper
} }

View File

@ -31,7 +31,7 @@ type InitChecker interface {
// DefaultInitChecker is the default InitChecker. // DefaultInitChecker is the default InitChecker.
type DefaultInitChecker struct { type DefaultInitChecker struct {
resolver netxlite.Resolver resolver netxlite.ResolverLegacy
} }
// InitialChecks checks whether the URL is valid and whether the // InitialChecks checks whether the URL is valid and whether the

View File

@ -24,7 +24,7 @@ type Config struct {
checker InitChecker checker InitChecker
explorer Explorer explorer Explorer
generator Generator generator Generator
resolver netxlite.Resolver resolver netxlite.ResolverLegacy
} }
// Measure performs the three consecutive steps of the testhelper algorithm: // Measure performs the three consecutive steps of the testhelper algorithm:
@ -87,10 +87,12 @@ func newDNSFailedResponse(err error, URL string) *ControlResponse {
} }
// newResolver creates a new DNS resolver instance // newResolver creates a new DNS resolver instance
func newResolver() netxlite.Resolver { func newResolver() netxlite.ResolverLegacy {
childResolver, err := netx.NewDNSClient(netx.Config{Logger: log.Log}, "doh://google") childResolver, err := netx.NewDNSClient(netx.Config{Logger: log.Log}, "doh://google")
runtimex.PanicOnError(err, "NewDNSClient failed") runtimex.PanicOnError(err, "NewDNSClient failed")
var r netxlite.Resolver = childResolver var r netxlite.ResolverLegacy = childResolver
r = &netxlite.ResolverIDNA{Resolver: r} r = &netxlite.ResolverIDNA{
Resolver: netxlite.NewResolverLegacyAdapter(r),
}
return r return r
} }

View File

@ -35,7 +35,10 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) { func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
var reso resolver.Resolver = &netxlite.ResolverSystem{} var reso resolver.Resolver = &netxlite.ResolverSystem{}
reso = &netxlite.ResolverLogger{Resolver: reso, Logger: mgr.logger} reso = &netxlite.ResolverLogger{
Resolver: netxlite.NewResolverLegacyAdapter(reso),
Logger: mgr.logger,
}
dlr := dialer.New(&dialer.Config{ dlr := dialer.New(&dialer.Config{
ContextByteCounting: true, ContextByteCounting: true,
Logger: mgr.logger, Logger: mgr.logger,

View File

@ -11,7 +11,7 @@ import (
type DNSConfig struct { type DNSConfig struct {
Domain string Domain string
Resolver netxlite.Resolver Resolver netxlite.ResolverLegacy
} }
// DNSDo performs the DNS check. // DNSDo performs the DNS check.
@ -21,7 +21,9 @@ func DNSDo(ctx context.Context, config DNSConfig) ([]string, error) {
childResolver, err := netx.NewDNSClient(netx.Config{Logger: log.Log}, "doh://google") childResolver, err := netx.NewDNSClient(netx.Config{Logger: log.Log}, "doh://google")
runtimex.PanicOnError(err, "NewDNSClient failed") runtimex.PanicOnError(err, "NewDNSClient failed")
resolver = childResolver resolver = childResolver
resolver = &netxlite.ResolverIDNA{Resolver: resolver} resolver = &netxlite.ResolverIDNA{
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
}
} }
return resolver.LookupHost(ctx, config.Domain) return resolver.LookupHost(ctx, config.Domain)
} }

View File

@ -33,23 +33,29 @@ func NewRequest(ctx context.Context, URL *url.URL, headers http.Header) *http.Re
// NewDialerResolver contructs a new dialer for TCP connections, // NewDialerResolver contructs a new dialer for TCP connections,
// with default, errorwrapping and resolve functionalities // with default, errorwrapping and resolve functionalities
func NewDialerResolver(resolver netxlite.Resolver) netxlite.Dialer { func NewDialerResolver(resolver netxlite.ResolverLegacy) netxlite.Dialer {
var d netxlite.Dialer = netxlite.DefaultDialer var d netxlite.Dialer = netxlite.DefaultDialer
d = &errorsx.ErrorWrapperDialer{Dialer: d} d = &errorsx.ErrorWrapperDialer{Dialer: d}
d = &netxlite.DialerResolver{Resolver: resolver, Dialer: d} d = &netxlite.DialerResolver{
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
Dialer: d,
}
return d return d
} }
// NewQUICDialerResolver creates a new QUICDialerResolver // NewQUICDialerResolver creates a new QUICDialerResolver
// with default, errorwrapping and resolve functionalities // with default, errorwrapping and resolve functionalities
func NewQUICDialerResolver(resolver netxlite.Resolver) netxlite.QUICContextDialer { func NewQUICDialerResolver(resolver netxlite.ResolverLegacy) netxlite.QUICContextDialer {
var ql quicdialer.QUICListener = &netxlite.QUICListenerStdlib{} var ql quicdialer.QUICListener = &netxlite.QUICListenerStdlib{}
ql = &errorsx.ErrorWrapperQUICListener{QUICListener: ql} ql = &errorsx.ErrorWrapperQUICListener{QUICListener: ql}
var dialer netxlite.QUICContextDialer = &netxlite.QUICDialerQUICGo{ var dialer netxlite.QUICContextDialer = &netxlite.QUICDialerQUICGo{
QUICListener: ql, QUICListener: ql,
} }
dialer = &errorsx.ErrorWrapperQUICDialer{Dialer: dialer} dialer = &errorsx.ErrorWrapperQUICDialer{Dialer: dialer}
dialer = &netxlite.QUICDialerResolver{Resolver: resolver, Dialer: dialer} dialer = &netxlite.QUICDialerResolver{
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
Dialer: dialer,
}
return dialer return dialer
} }

View File

@ -11,7 +11,7 @@ import (
type QUICConfig struct { type QUICConfig struct {
Endpoint string Endpoint string
QUICDialer netxlite.QUICContextDialer QUICDialer netxlite.QUICContextDialer
Resolver netxlite.Resolver Resolver netxlite.ResolverLegacy
TLSConf *tls.Config TLSConf *tls.Config
} }

View File

@ -10,7 +10,7 @@ import (
type TCPConfig struct { type TCPConfig struct {
Dialer netxlite.Dialer Dialer netxlite.Dialer
Endpoint string Endpoint string
Resolver netxlite.Resolver Resolver netxlite.ResolverLegacy
} }
// TCPDo performs the TCP check. // TCPDo performs the TCP check.

View File

@ -80,7 +80,10 @@ func New(config *Config, resolver Resolver) Dialer {
if config.ReadWriteSaver != nil { if config.ReadWriteSaver != nil {
d = &saverConnDialer{Dialer: d, Saver: config.ReadWriteSaver} d = &saverConnDialer{Dialer: d, Saver: config.ReadWriteSaver}
} }
d = &netxlite.DialerResolver{Resolver: resolver, Dialer: d} d = &netxlite.DialerResolver{
Resolver: netxlite.NewResolverLegacyAdapter(resolver),
Dialer: d,
}
d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d} d = &proxyDialer{ProxyURL: config.ProxyURL, Dialer: d}
if config.ContextByteCounting { if config.ContextByteCounting {
d = &byteCounterDialer{Dialer: d} d = &byteCounterDialer{Dialer: d}

View File

@ -134,12 +134,15 @@ func NewResolver(config Config) Resolver {
} }
r = &errorsx.ErrorWrapperResolver{Resolver: r} r = &errorsx.ErrorWrapperResolver{Resolver: r}
if config.Logger != nil { if config.Logger != nil {
r = &netxlite.ResolverLogger{Logger: config.Logger, Resolver: r} r = &netxlite.ResolverLogger{
Logger: config.Logger,
Resolver: netxlite.NewResolverLegacyAdapter(r),
}
} }
if config.ResolveSaver != nil { if config.ResolveSaver != nil {
r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver} r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver}
} }
return &resolver.IDNAResolver{Resolver: r} return &resolver.IDNAResolver{Resolver: netxlite.NewResolverLegacyAdapter(r)}
} }
// NewDialer creates a new Dialer from the specified config // NewDialer creates a new Dialer from the specified config
@ -176,7 +179,10 @@ func NewQUICDialer(config Config) QUICDialer {
if config.TLSSaver != nil { if config.TLSSaver != nil {
d = quicdialer.HandshakeSaver{Saver: config.TLSSaver, Dialer: d} d = quicdialer.HandshakeSaver{Saver: config.TLSSaver, Dialer: d}
} }
d = &netxlite.QUICDialerResolver{Resolver: config.FullResolver, Dialer: d} d = &netxlite.QUICDialerResolver{
Resolver: netxlite.NewResolverLegacyAdapter(config.FullResolver),
Dialer: d,
}
return d return d
} }

View File

@ -24,7 +24,11 @@ func TestNewResolverVanilla(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -48,7 +52,11 @@ func TestNewResolverSpecificResolver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -70,7 +78,11 @@ func TestNewResolverWithBogonFilter(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -96,17 +108,33 @@ func TestNewResolverWithLogging(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
lr, ok := ir.Resolver.(*netxlite.ResolverLogger) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
lr, ok := rla.ResolverLegacy.(*netxlite.ResolverLogger)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
if lr.Logger != log.Log { if lr.Logger != log.Log {
t.Fatal("not the logger we expected") t.Fatal("not the logger we expected")
} }
ewr, ok := lr.Resolver.(*errorsx.ErrorWrapperResolver) rla, ok = ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
lr, ok = rla.ResolverLegacy.(*netxlite.ResolverLogger)
if !ok {
t.Fatal("not the resolver we expected")
}
rla, ok = lr.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver)
if !ok {
t.Fatalf("not the resolver we expected %T", rla.ResolverLegacy)
}
ar, ok := ewr.Resolver.(resolver.AddressResolver) ar, ok := ewr.Resolver.(resolver.AddressResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
@ -126,7 +154,11 @@ func TestNewResolverWithSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
sr, ok := ir.Resolver.(resolver.SaverResolver) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
sr, ok := rla.ResolverLegacy.(resolver.SaverResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -155,7 +187,11 @@ func TestNewResolverWithReadWriteCache(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -186,7 +222,11 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
ewr, ok := ir.Resolver.(*errorsx.ErrorWrapperResolver) rla, ok := ir.Resolver.(*netxlite.ResolverLegacyAdapter)
if !ok {
t.Fatal("not the resolver we expected")
}
ewr, ok := rla.ResolverLegacy.(*errorsx.ErrorWrapperResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }

View File

@ -19,7 +19,10 @@ func testresolverquick(t *testing.T, reso resolver.Resolver) {
if testing.Short() { if testing.Short() {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
reso = &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso} reso = &netxlite.ResolverLogger{
Logger: log.Log,
Resolver: netxlite.NewResolverLegacyAdapter(reso),
}
addrs, err := reso.LookupHost(context.Background(), "dns.google.com") addrs, err := reso.LookupHost(context.Background(), "dns.google.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -45,7 +48,10 @@ func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
t.Skip("skip test in short mode") t.Skip("skip test in short mode")
} }
reso = &resolver.IDNAResolver{ reso = &resolver.IDNAResolver{
Resolver: &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso}, Resolver: &netxlite.ResolverLogger{
Logger: log.Log,
Resolver: netxlite.NewResolverLegacyAdapter(reso),
},
} }
addrs, err := reso.LookupHost(context.Background(), "яндекс.рф") addrs, err := reso.LookupHost(context.Background(), "яндекс.рф")
if err != nil { if err != nil {

View File

@ -2,9 +2,10 @@ package netxlite
import ( import (
"crypto/tls" "crypto/tls"
"net"
"net/http" "net/http"
"testing" "testing"
"github.com/apex/log"
) )
func TestHTTP3TransportWorks(t *testing.T) { func TestHTTP3TransportWorks(t *testing.T) {
@ -12,7 +13,7 @@ func TestHTTP3TransportWorks(t *testing.T) {
Dialer: &quicDialerQUICGo{ Dialer: &quicDialerQUICGo{
QUICListener: &quicListenerStdlib{}, QUICListener: &quicListenerStdlib{},
}, },
Resolver: &net.Resolver{}, Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
} }
txp := NewHTTP3Transport(d, &tls.Config{}) txp := NewHTTP3Transport(d, &tls.Config{})
client := &http.Client{Transport: txp} client := &http.Client{Transport: txp}

View File

@ -112,7 +112,7 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
func TestHTTPTransportWorks(t *testing.T) { func TestHTTPTransportWorks(t *testing.T) {
d := &dialerResolver{ d := &dialerResolver{
Dialer: defaultDialer, Dialer: defaultDialer,
Resolver: &net.Resolver{}, Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
} }
th := &tlsHandshakerConfigurable{} th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th) txp := NewHTTPTransport(d, &tls.Config{}, th)
@ -134,7 +134,7 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
return nil, expected return nil, expected
}, },
}, },
Resolver: &net.Resolver{}, Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
} }
th := &tlsHandshakerConfigurable{} th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th) txp := NewHTTPTransport(d, &tls.Config{}, th)

View File

@ -1,6 +1,7 @@
package netxlite package netxlite
import ( import (
"context"
"errors" "errors"
"strings" "strings"
@ -59,3 +60,65 @@ type (
TLSHandshakerConfigurable = tlsHandshakerConfigurable TLSHandshakerConfigurable = tlsHandshakerConfigurable
TLSHandshakerLogger = tlsHandshakerLogger TLSHandshakerLogger = tlsHandshakerLogger
) )
// ResolverLegacy performs domain name resolutions.
//
// This definition of Resolver is DEPRECATED. New code should use
// the more complete definition in the new Resolver interface.
//
// Existing code in ooni/probe-cli is still using this definition.
type ResolverLegacy interface {
// LookupHost behaves like net.Resolver.LookupHost.
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
}
// NewResolverLegacyAdapter adapts a ResolverLegacy to
// become compatible with the Resolver definition.
func NewResolverLegacyAdapter(reso ResolverLegacy) Resolver {
return &ResolverLegacyAdapter{reso}
}
// ResolverLegacyAdapter makes a ResolverLegacy behave like
// it was a Resolver type. If ResolverLegacy is actually also
// a Resolver, this adapter will just forward missing calls,
// otherwise it will implement a sensible default action.
type ResolverLegacyAdapter struct {
ResolverLegacy
}
var _ Resolver = &ResolverLegacyAdapter{}
type resolverLegacyNetworker interface {
Network() string
}
// Network implements Resolver.Network.
func (r *ResolverLegacyAdapter) Network() string {
if rn, ok := r.ResolverLegacy.(resolverLegacyNetworker); ok {
return rn.Network()
}
return "adapter"
}
type resolverLegacyAddresser interface {
Address() string
}
// Address implements Resolver.Address.
func (r *ResolverLegacyAdapter) Address() string {
if ra, ok := r.ResolverLegacy.(resolverLegacyAddresser); ok {
return ra.Address()
}
return ""
}
type resolverLegacyIdleConnectionsCloser interface {
CloseIdleConnections()
}
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *ResolverLegacyAdapter) CloseIdleConnections() {
if ra, ok := r.ResolverLegacy.(resolverLegacyIdleConnectionsCloser); ok {
ra.CloseIdleConnections()
}
}

View File

@ -2,9 +2,11 @@ package netxlite
import ( import (
"errors" "errors"
"net"
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/errorsx" "github.com/ooni/probe-cli/v3/internal/errorsx"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestReduceErrors(t *testing.T) { func TestReduceErrors(t *testing.T) {
@ -44,3 +46,39 @@ func TestReduceErrors(t *testing.T) {
} }
}) })
} }
func TestResolverLegacyAdapterWithCompatibleType(t *testing.T) {
var called bool
r := NewResolverLegacyAdapter(&mocks.Resolver{
MockNetwork: func() string {
return "network"
},
MockAddress: func() string {
return "address"
},
MockCloseIdleConnections: func() {
called = true
},
})
if r.Network() != "network" {
t.Fatal("invalid Network")
}
if r.Address() != "address" {
t.Fatal("invalid Address")
}
r.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}
func TestResolverLegacyAdapterDefaults(t *testing.T) {
r := NewResolverLegacyAdapter(&net.Resolver{})
if r.Network() != "adapter" {
t.Fatal("invalid Network")
}
if r.Address() != "" {
t.Fatal("invalid Address")
}
r.CloseIdleConnections() // does not crash
}

View File

@ -4,9 +4,10 @@ import "context"
// Resolver is a mockable Resolver. // Resolver is a mockable Resolver.
type Resolver struct { type Resolver struct {
MockLookupHost func(ctx context.Context, domain string) ([]string, error) MockLookupHost func(ctx context.Context, domain string) ([]string, error)
MockNetwork func() string MockNetwork func() string
MockAddress func() string MockAddress func() string
MockCloseIdleConnections func()
} }
// LookupHost calls MockLookupHost. // LookupHost calls MockLookupHost.
@ -23,3 +24,8 @@ func (r *Resolver) Address() string {
func (r *Resolver) Network() string { func (r *Resolver) Network() string {
return r.MockNetwork() return r.MockNetwork()
} }
// CloseIdleConnections calls MockCloseIdleConnections.
func (r *Resolver) CloseIdleConnections() {
r.MockCloseIdleConnections()
}

View File

@ -44,3 +44,16 @@ func TestResolverAddress(t *testing.T) {
t.Fatal("unexpected address", v) t.Fatal("unexpected address", v)
} }
} }
func TestResolverCloseIdleConnections(t *testing.T) {
var called bool
r := &Resolver{
MockCloseIdleConnections: func() {
called = true
},
}
r.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}

View File

@ -215,7 +215,8 @@ func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
func TestQUICDialerResolverSuccess(t *testing.T) { func TestQUICDialerResolverSuccess(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: &net.Resolver{}, Dialer: &quicDialerQUICGo{ Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
Dialer: &quicDialerQUICGo{
QUICListener: &quicListenerStdlib{}, QUICListener: &quicListenerStdlib{},
}} }}
sess, err := dialer.DialContext( sess, err := dialer.DialContext(
@ -233,7 +234,8 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
func TestQUICDialerResolverNoPort(t *testing.T) { func TestQUICDialerResolverNoPort(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: new(net.Resolver), Dialer: &quicDialerQUICGo{}} Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
Dialer: &quicDialerQUICGo{}}
sess, err := dialer.DialContext( sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com", context.Background(), "udp", "www.google.com",
tlsConfig, &quic.Config{}) tlsConfig, &quic.Config{})
@ -286,7 +288,8 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
// to establish a connection leads to a failure // to establish a connection leads to a failure
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: new(net.Resolver), Dialer: &quicDialerQUICGo{ Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
Dialer: &quicDialerQUICGo{
QUICListener: &quicListenerStdlib{}, QUICListener: &quicListenerStdlib{},
}} }}
sess, err := dialer.DialContext( sess, err := dialer.DialContext(
@ -309,7 +312,8 @@ func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
var gotTLSConfig *tls.Config var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: new(net.Resolver), Dialer: &mocks.QUICContextDialer{ Resolver: NewResolver(&ResolverConfig{Logger: log.Log}),
Dialer: &mocks.QUICContextDialer{
MockDialContext: func(ctx context.Context, network, address string, MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
gotTLSConfig = tlsConfig gotTLSConfig = tlsConfig

View File

@ -12,6 +12,31 @@ import (
type Resolver interface { type Resolver interface {
// LookupHost behaves like net.Resolver.LookupHost. // LookupHost behaves like net.Resolver.LookupHost.
LookupHost(ctx context.Context, hostname string) (addrs []string, err error) LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
// Network returns the resolver type (e.g., system, dot, doh).
Network() string
// Address returns the resolver address (e.g., 8.8.8.8:53).
Address() string
// CloseIdleConnections closes idle connections, if any.
CloseIdleConnections()
}
// ResolverConfig contains config for creating a resolver.
type ResolverConfig struct {
// Logger is the MANDATORY logger to use.
Logger Logger
}
// NewResolver creates a new resolver.
func NewResolver(config *ResolverConfig) Resolver {
return &resolverIDNA{
Resolver: &resolverLogger{
Resolver: &resolverSystem{},
Logger: config.Logger,
},
}
} }
// resolverSystem is the system resolver. // resolverSystem is the system resolver.
@ -34,6 +59,11 @@ func (r *resolverSystem) Address() string {
return "" return ""
} }
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *resolverSystem) CloseIdleConnections() {
// nothing
}
// DefaultResolver is the resolver we use by default. // DefaultResolver is the resolver we use by default.
var DefaultResolver = &resolverSystem{} var DefaultResolver = &resolverSystem{}
@ -59,30 +89,6 @@ func (r *resolverLogger) LookupHost(ctx context.Context, hostname string) ([]str
return addrs, nil return addrs, nil
} }
type resolverNetworker interface {
Network() string
}
// Network implements Resolver.Network.
func (r *resolverLogger) Network() string {
if rn, ok := r.Resolver.(resolverNetworker); ok {
return rn.Network()
}
return "logger"
}
type resolverAddresser interface {
Address() string
}
// Address implements Resolver.Address.
func (r *resolverLogger) Address() string {
if ra, ok := r.Resolver.(resolverAddresser); ok {
return ra.Address()
}
return ""
}
// resolverIDNA supports resolving Internationalized Domain Names. // resolverIDNA supports resolving Internationalized Domain Names.
// //
// See RFC3492 for more information. // See RFC3492 for more information.
@ -98,19 +104,3 @@ func (r *resolverIDNA) LookupHost(ctx context.Context, hostname string) ([]strin
} }
return r.Resolver.LookupHost(ctx, host) return r.Resolver.LookupHost(ctx, host)
} }
// Network implements Resolver.Network.
func (r *resolverIDNA) Network() string {
if rn, ok := r.Resolver.(resolverNetworker); ok {
return rn.Network()
}
return "idna"
}
// Address implements Resolver.Address.
func (r *resolverIDNA) Address() string {
if ra, ok := r.Resolver.(resolverAddresser); ok {
return ra.Address()
}
return ""
}

View File

@ -3,7 +3,6 @@ package netxlite
import ( import (
"context" "context"
"errors" "errors"
"net"
"strings" "strings"
"testing" "testing"
@ -71,26 +70,6 @@ func TestResolverLoggerWithFailure(t *testing.T) {
} }
} }
func TestResolverLoggerChildNetworkAddress(t *testing.T) {
r := &resolverLogger{Logger: log.Log, Resolver: DefaultResolver}
if r.Network() != "system" {
t.Fatal("invalid Network")
}
if r.Address() != "" {
t.Fatal("invalid Address")
}
}
func TestResolverLoggerNoChildNetworkAddress(t *testing.T) {
r := &resolverLogger{Logger: log.Log, Resolver: &net.Resolver{}}
if r.Network() != "logger" {
t.Fatal("invalid Network")
}
if r.Address() != "" {
t.Fatal("invalid Address")
}
}
func TestResolverIDNAWorksAsIntended(t *testing.T) { func TestResolverIDNAWorksAsIntended(t *testing.T) {
expectedIPs := []string{"77.88.55.66"} expectedIPs := []string{"77.88.55.66"}
r := &resolverIDNA{ r := &resolverIDNA{
@ -130,24 +109,22 @@ func TestResolverIDNAWithInvalidPunycode(t *testing.T) {
} }
} }
func TestResolverIDNAChildNetworkAddress(t *testing.T) { func TestNewResolverTypeChain(t *testing.T) {
r := &resolverIDNA{ r := NewResolver(&ResolverConfig{
Resolver: DefaultResolver, Logger: log.Log,
})
ridna, ok := r.(*resolverIDNA)
if !ok {
t.Fatal("invalid resolver")
} }
if v := r.Network(); v != "system" { rl, ok := ridna.Resolver.(*resolverLogger)
t.Fatal("invalid network", v) if !ok {
t.Fatal("invalid resolver")
} }
if v := r.Address(); v != "" { if rl.Logger != log.Log {
t.Fatal("invalid address", v) t.Fatal("invalid logger")
} }
} if _, ok := rl.Resolver.(*resolverSystem); !ok {
t.Fatal("invalid resolver")
func TestResolverIDNANoChildNetworkAddress(t *testing.T) {
r := &resolverIDNA{}
if v := r.Network(); v != "idna" {
t.Fatal("invalid network", v)
}
if v := r.Address(); v != "" {
t.Fatal("invalid address", v)
} }
} }