cleanup(netx): remove the DNSClient type (#660)

The DNSClient type existed because the Resolver type did not
include CloseIdleConnections in its signature.

Now that Resolver includes CloseIdleConnections, the DNSClient
type has become unnecessary and can be safely removed.

See https://github.com/ooni/probe/issues/1956.
This commit is contained in:
Simone Basso 2022-01-10 11:53:06 +01:00 committed by GitHub
parent 730373cc75
commit d3c6c11e48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 58 deletions

View File

@ -17,7 +17,7 @@ import (
// Client is DNS, HTTP, and TCP client. // Client is DNS, HTTP, and TCP client.
type Client struct { type Client struct {
dnsClient *netx.DNSClient dnsClient model.Resolver
httpTransport model.HTTPTransport httpTransport model.HTTPTransport
dialer model.Dialer dialer model.Dialer
} }
@ -34,7 +34,7 @@ func NewClient(resolverURL string) (*Client, error) {
return nil, err return nil, err
} }
return &Client{ return &Client{
dnsClient: &configuration.DNSClient, dnsClient: configuration.DNSClient,
httpTransport: netx.NewHTTPTransport(configuration.HTTPConfig), httpTransport: netx.NewHTTPTransport(configuration.HTTPConfig),
dialer: netx.NewDialer(configuration.HTTPConfig), dialer: netx.NewDialer(configuration.HTTPConfig),
}, nil }, nil

View File

@ -26,7 +26,7 @@ type Configurer struct {
// The Configuration is the configuration for running a measurement. // The Configuration is the configuration for running a measurement.
type Configuration struct { type Configuration struct {
HTTPConfig netx.Config HTTPConfig netx.Config
DNSClient netx.DNSClient DNSClient model.Resolver
} }
// CloseIdleConnections will close idle connections, if needed. // CloseIdleConnections will close idle connections, if needed.
@ -82,7 +82,7 @@ func (c Configurer) NewConfiguration() (Configuration, error) {
return configuration, err return configuration, err
} }
configuration.DNSClient = dnsclient configuration.DNSClient = dnsclient
configuration.HTTPConfig.BaseResolver = dnsclient.Resolver configuration.HTTPConfig.BaseResolver = dnsclient
// configure TLS // configure TLS
configuration.HTTPConfig.TLSConfig = &tls.Config{ configuration.HTTPConfig.TLSConfig = &tls.Config{
NextProtos: []string{"h2", "http/1.1"}, NextProtos: []string{"h2", "http/1.1"},

View File

@ -239,20 +239,6 @@ var allTransportsInfo = map[bool]httpTransportInfo{
}, },
} }
// DNSClient is a DNS client. It wraps a Resolver and it possibly
// also wraps an HTTP client, but only when we're using DoH.
type DNSClient struct {
model.Resolver
httpClient *http.Client
}
// CloseIdleConnections closes idle connections, if any.
func (c DNSClient) CloseIdleConnections() {
if c.httpClient != nil {
c.httpClient.CloseIdleConnections()
}
}
// NewDNSClient creates a new DNS client. The config argument is used to // NewDNSClient creates a new DNS client. The config argument is used to
// create the underlying Dialer and/or HTTP transport, if needed. The URL // create the underlying Dialer and/or HTTP transport, if needed. The URL
// argument describes the kind of client that we want to make: // argument describes the kind of client that we want to make:
@ -271,15 +257,14 @@ func (c DNSClient) CloseIdleConnections() {
// //
// If config.ResolveSaver is not nil and we're creating an underlying // If config.ResolveSaver is not nil and we're creating an underlying
// resolver where this is possible, we will also save events. // resolver where this is possible, we will also save events.
func NewDNSClient(config Config, URL string) (DNSClient, error) { func NewDNSClient(config Config, URL string) (model.Resolver, error) {
return NewDNSClientWithOverrides(config, URL, "", "", "") return NewDNSClientWithOverrides(config, URL, "", "", "")
} }
// NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient, // NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient,
// with the option to override the default Hostname and SNI. // with the option to override the default Hostname and SNI.
func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
TLSVersion string) (DNSClient, error) { TLSVersion string) (model.Resolver, error) {
var c DNSClient
switch URL { switch URL {
case "doh://powerdns": case "doh://powerdns":
URL = "https://doh.powerdns.org/" URL = "https://doh.powerdns.org/"
@ -292,34 +277,32 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
} }
resolverURL, err := url.Parse(URL) resolverURL, err := url.Parse(URL)
if err != nil { if err != nil {
return c, err return nil, err
} }
config.TLSConfig = &tls.Config{ServerName: SNIOverride} config.TLSConfig = &tls.Config{ServerName: SNIOverride}
if err := netxlite.ConfigureTLSVersion(config.TLSConfig, TLSVersion); err != nil { if err := netxlite.ConfigureTLSVersion(config.TLSConfig, TLSVersion); err != nil {
return c, err return nil, err
} }
switch resolverURL.Scheme { switch resolverURL.Scheme {
case "system": case "system":
c.Resolver = &netxlite.ResolverSystem{} return &netxlite.ResolverSystem{}, nil
return c, nil
case "https": case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
c.httpClient = &http.Client{Transport: NewHTTPTransport(config)} httpClient := &http.Client{Transport: NewHTTPTransport(config)}
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSWithHostOverride( var txp model.DNSTransport = netxlite.NewDNSOverHTTPSWithHostOverride(
c.httpClient, URL, hostOverride) httpClient, URL, hostOverride)
if config.ResolveSaver != nil { if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{ txp = resolver.SaverDNSTransport{
DNSTransport: txp, DNSTransport: txp,
Saver: config.ResolveSaver, Saver: config.ResolveSaver,
} }
} }
c.Resolver = netxlite.NewSerialResolver(txp) return netxlite.NewSerialResolver(txp), nil
return c, nil
case "udp": case "udp":
dialer := NewDialer(config) dialer := NewDialer(config)
endpoint, err := makeValidEndpoint(resolverURL) endpoint, err := makeValidEndpoint(resolverURL)
if err != nil { if err != nil {
return c, err return nil, err
} }
var txp model.DNSTransport = netxlite.NewDNSOverUDP( var txp model.DNSTransport = netxlite.NewDNSOverUDP(
dialer, endpoint) dialer, endpoint)
@ -329,14 +312,13 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
Saver: config.ResolveSaver, Saver: config.ResolveSaver,
} }
} }
c.Resolver = netxlite.NewSerialResolver(txp) return netxlite.NewSerialResolver(txp), nil
return c, nil
case "dot": case "dot":
config.TLSConfig.NextProtos = []string{"dot"} config.TLSConfig.NextProtos = []string{"dot"}
tlsDialer := NewTLSDialer(config) tlsDialer := NewTLSDialer(config)
endpoint, err := makeValidEndpoint(resolverURL) endpoint, err := makeValidEndpoint(resolverURL)
if err != nil { if err != nil {
return c, err return nil, err
} }
var txp model.DNSTransport = netxlite.NewDNSOverTLS( var txp model.DNSTransport = netxlite.NewDNSOverTLS(
tlsDialer.DialTLSContext, endpoint) tlsDialer.DialTLSContext, endpoint)
@ -346,13 +328,12 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
Saver: config.ResolveSaver, Saver: config.ResolveSaver,
} }
} }
c.Resolver = netxlite.NewSerialResolver(txp) return netxlite.NewSerialResolver(txp), nil
return c, nil
case "tcp": case "tcp":
dialer := NewDialer(config) dialer := NewDialer(config)
endpoint, err := makeValidEndpoint(resolverURL) endpoint, err := makeValidEndpoint(resolverURL)
if err != nil { if err != nil {
return c, err return nil, err
} }
var txp model.DNSTransport = netxlite.NewDNSOverTCP( var txp model.DNSTransport = netxlite.NewDNSOverTCP(
dialer.DialContext, endpoint) dialer.DialContext, endpoint)
@ -362,10 +343,9 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
Saver: config.ResolveSaver, Saver: config.ResolveSaver,
} }
} }
c.Resolver = netxlite.NewSerialResolver(txp) return netxlite.NewSerialResolver(txp), nil
return c, nil
default: default:
return c, errors.New("unsupported resolver scheme") return nil, errors.New("unsupported resolver scheme")
} }
} }

View File

@ -544,10 +544,9 @@ func TestNewDNSClientInvalidURL(t *testing.T) {
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if dnsclient.Resolver != nil { if dnsclient != nil {
t.Fatal("expected nil resolver here") t.Fatal("expected nil resolver here")
} }
dnsclient.CloseIdleConnections()
} }
func TestNewDNSClientUnsupportedScheme(t *testing.T) { func TestNewDNSClientUnsupportedScheme(t *testing.T) {
@ -555,10 +554,9 @@ func TestNewDNSClientUnsupportedScheme(t *testing.T) {
if err == nil || err.Error() != "unsupported resolver scheme" { if err == nil || err.Error() != "unsupported resolver scheme" {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if dnsclient.Resolver != nil { if dnsclient != nil {
t.Fatal("expected nil resolver here") t.Fatal("expected nil resolver here")
} }
dnsclient.CloseIdleConnections()
} }
func TestNewDNSClientSystemResolver(t *testing.T) { func TestNewDNSClientSystemResolver(t *testing.T) {
@ -567,7 +565,7 @@ func TestNewDNSClientSystemResolver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := dnsclient.Resolver.(*netxlite.ResolverSystem); !ok { if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -579,7 +577,7 @@ func TestNewDNSClientEmpty(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if _, ok := dnsclient.Resolver.(*netxlite.ResolverSystem); !ok { if _, ok := dnsclient.(*netxlite.ResolverSystem); !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -591,7 +589,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -607,7 +605,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -623,7 +621,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -640,7 +638,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -660,7 +658,7 @@ func TestNewDNSClientUDP(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -677,7 +675,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -697,7 +695,7 @@ func TestNewDNSClientTCP(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -718,7 +716,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -742,7 +740,7 @@ func TestNewDNSClientDoT(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -763,7 +761,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r, ok := dnsclient.Resolver.(*netxlite.SerialResolver) r, ok := dnsclient.(*netxlite.SerialResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -787,7 +785,7 @@ func TestNewDNSCLientDoTWithoutPort(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if c.Resolver.Address() != "8.8.8.8:853" { if c.Address() != "8.8.8.8:853" {
t.Fatal("expected default port to be added") t.Fatal("expected default port to be added")
} }
} }
@ -798,7 +796,7 @@ func TestNewDNSCLientTCPWithoutPort(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if c.Resolver.Address() != "8.8.8.8:53" { if c.Address() != "8.8.8.8:53" {
t.Fatal("expected default port to be added") t.Fatal("expected default port to be added")
} }
} }
@ -809,7 +807,7 @@ func TestNewDNSCLientUDPWithoutPort(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if c.Resolver.Address() != "8.8.8.8:53" { if c.Address() != "8.8.8.8:53" {
t.Fatal("expected default port to be added") t.Fatal("expected default port to be added")
} }
} }