feat(netxlite): add dialer factory, simplify resolver factory (#459)

See https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-05 20:41:46 +02:00 committed by GitHub
parent b52d784f00
commit 6a1e92cace
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 107 additions and 19 deletions

View File

@ -15,6 +15,26 @@ type Dialer interface {
CloseIdleConnections() CloseIdleConnections()
} }
// NewDialerWithResolver creates a dialer using the given resolver and logger.
func NewDialerWithResolver(logger Logger, resolver Resolver) Dialer {
return &dialerLogger{
Dialer: &dialerResolver{
Dialer: &dialerLogger{
Dialer: &dialerSystem{},
Logger: logger,
},
Resolver: resolver,
},
Logger: logger,
}
}
// NewDialerWithoutResolver creates a dialer that uses the given
// logger and fails with ErrNoResolver when it is passed a domain name.
func NewDialerWithoutResolver(logger Logger) Dialer {
return NewDialerWithResolver(logger, &nullResolver{})
}
// underlyingDialer is the Dialer we use by default. // underlyingDialer is the Dialer we use by default.
var underlyingDialer = &net.Dialer{ var underlyingDialer = &net.Dialer{
Timeout: 15 * time.Second, Timeout: 15 * time.Second,

View File

@ -207,3 +207,31 @@ func TestUnderlyingDialerHasTimeout(t *testing.T) {
t.Fatal("unexpected timeout value") t.Fatal("unexpected timeout value")
} }
} }
func TestNewDialerWithoutResolverChain(t *testing.T) {
dlr := NewDialerWithoutResolver(log.Log)
dlog, okay := dlr.(*dialerLogger)
if !okay {
t.Fatal("invalid type")
}
if dlog.Logger != log.Log {
t.Fatal("invalid logger")
}
dreso, okay := dlog.Dialer.(*dialerResolver)
if !okay {
t.Fatal("invalid type")
}
if _, okay := dreso.Resolver.(*nullResolver); !okay {
t.Fatal("invalid Resolver type")
}
dlog, okay = dreso.Dialer.(*dialerLogger)
if !okay {
t.Fatal("invalid type")
}
if dlog.Logger != log.Log {
t.Fatal("invalid logger")
}
if _, okay := dlog.Dialer.(*dialerSystem); !okay {
t.Fatal("invalid type")
}
}

View File

@ -13,7 +13,7 @@ func TestHTTP3TransportWorks(t *testing.T) {
Dialer: &quicDialerQUICGo{ Dialer: &quicDialerQUICGo{
QUICListener: &quicListenerStdlib{}, QUICListener: &quicListenerStdlib{},
}, },
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
} }
txp := NewHTTP3Transport(d, &tls.Config{}) txp := NewHTTP3Transport(d, &tls.Config{})
client := &http.Client{Transport: txp} client := &http.Client{Transport: txp}

View File

@ -112,7 +112,7 @@ func TestHTTPTransportLoggerCloseIdleConnections(t *testing.T) {
func TestHTTPTransportWorks(t *testing.T) { func TestHTTPTransportWorks(t *testing.T) {
d := &dialerResolver{ d := &dialerResolver{
Dialer: defaultDialer, Dialer: defaultDialer,
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
} }
th := &tlsHandshakerConfigurable{} th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th) txp := NewHTTPTransport(d, &tls.Config{}, th)
@ -134,7 +134,7 @@ func TestHTTPTransportWithFailingDialer(t *testing.T) {
return nil, expected return nil, expected
}, },
}, },
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
} }
th := &tlsHandshakerConfigurable{} th := &tlsHandshakerConfigurable{}
txp := NewHTTPTransport(d, &tls.Config{}, th) txp := NewHTTPTransport(d, &tls.Config{}, th)

View File

@ -215,7 +215,7 @@ func TestQUICDialerQUICGoTLSDefaultsForDoQ(t *testing.T) {
func TestQUICDialerResolverSuccess(t *testing.T) { func TestQUICDialerResolverSuccess(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
Dialer: &quicDialerQUICGo{ Dialer: &quicDialerQUICGo{
QUICListener: &quicListenerStdlib{}, QUICListener: &quicListenerStdlib{},
}} }}
@ -234,7 +234,7 @@ func TestQUICDialerResolverSuccess(t *testing.T) {
func TestQUICDialerResolverNoPort(t *testing.T) { func TestQUICDialerResolverNoPort(t *testing.T) {
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
Dialer: &quicDialerQUICGo{}} Dialer: &quicDialerQUICGo{}}
sess, err := dialer.DialContext( sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com", context.Background(), "udp", "www.google.com",
@ -288,7 +288,7 @@ func TestQUICDialerResolverInvalidPort(t *testing.T) {
// to establish a connection leads to a failure // to establish a connection leads to a failure
tlsConf := &tls.Config{} tlsConf := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
Dialer: &quicDialerQUICGo{ Dialer: &quicDialerQUICGo{
QUICListener: &quicListenerStdlib{}, QUICListener: &quicListenerStdlib{},
}} }}
@ -312,7 +312,7 @@ func TestQUICDialerResolverApplyTLSDefaults(t *testing.T) {
var gotTLSConfig *tls.Config var gotTLSConfig *tls.Config
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
dialer := &quicDialerResolver{ dialer := &quicDialerResolver{
Resolver: NewResolver(&ResolverConfig{Logger: log.Log}), Resolver: NewResolverSystem(log.Log),
Dialer: &mocks.QUICContextDialer{ Dialer: &mocks.QUICContextDialer{
MockDialContext: func(ctx context.Context, network, address string, MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) { tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {

View File

@ -2,6 +2,7 @@ package netxlite
import ( import (
"context" "context"
"errors"
"net" "net"
"time" "time"
@ -23,20 +24,15 @@ type Resolver interface {
CloseIdleConnections() CloseIdleConnections()
} }
// ResolverConfig contains config for creating a resolver. // NewResolverSystem creates a new resolver using system
type ResolverConfig struct { // facilities for resolving domain names (e.g., getaddrinfo).
// Logger is the MANDATORY logger to use. func NewResolverSystem(logger Logger) Resolver {
Logger Logger
}
// NewResolver creates a new resolver.
func NewResolver(config *ResolverConfig) Resolver {
return &resolverIDNA{ return &resolverIDNA{
Resolver: &resolverLogger{ Resolver: &resolverLogger{
Resolver: &resolverShortCircuitIPAddr{ Resolver: &resolverShortCircuitIPAddr{
Resolver: &resolverSystem{}, Resolver: &resolverSystem{},
}, },
Logger: config.Logger, Logger: logger,
}, },
} }
} }
@ -159,3 +155,30 @@ func (r *resolverShortCircuitIPAddr) LookupHost(ctx context.Context, hostname st
} }
return r.Resolver.LookupHost(ctx, hostname) return r.Resolver.LookupHost(ctx, hostname)
} }
// ErrNoResolver indicates you are using a dialer without a resolver.
var ErrNoResolver = errors.New("no configured resolver")
// nullResolver is a resolver that is not capable of resolving
// domain names to IP addresses and always returns ErrNoResolver.
type nullResolver struct{}
// LookupHost implements Resolver.LookupHost.
func (r *nullResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) {
return nil, ErrNoResolver
}
// Network implements Resolver.Network.
func (r *nullResolver) Network() string {
return "null"
}
// Address implements Resolver.Address.
func (r *nullResolver) Address() string {
return ""
}
// CloseIdleConnections implements Resolver.CloseIdleConnections.
func (r *nullResolver) CloseIdleConnections() {
// nothing
}

View File

@ -180,9 +180,7 @@ func TestResolverIDNAWithInvalidPunycode(t *testing.T) {
} }
func TestNewResolverTypeChain(t *testing.T) { func TestNewResolverTypeChain(t *testing.T) {
r := NewResolver(&ResolverConfig{ r := NewResolverSystem(log.Log)
Logger: log.Log,
})
ridna, ok := r.(*resolverIDNA) ridna, ok := r.(*resolverIDNA)
if !ok { if !ok {
t.Fatal("invalid resolver") t.Fatal("invalid resolver")
@ -238,3 +236,22 @@ func TestResolverShortCircuitIPAddrWithDomain(t *testing.T) {
t.Fatal("invalid result") t.Fatal("invalid result")
} }
} }
func TestNullResolverWorksAsIntended(t *testing.T) {
r := &nullResolver{}
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "dns.google")
if !errors.Is(err, ErrNoResolver) {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
t.Fatal("expected nil addr")
}
if r.Network() != "null" {
t.Fatal("invalid network")
}
if r.Address() != "" {
t.Fatal("invalid address")
}
r.CloseIdleConnections() // should not crash
}