feat(netxlite): implement DNSTransport wrapping (#776)
Acknowledge that transports MAY be used in isolation (i.e., outside of a Resolver) and add support for wrapping. Ensure that every factory that creates an unwrapped type is named accordingly to hopefully ensure there are no surprises. Implement DNSTransport wrapping and use a technique similar to the one used by Dialer to customize the DNSTransport while constructing more complex data types (e.g., a specific resolver). Ensure that the stdlib resolver's own "getaddrinfo" transport (1) is wrapped and (2) could be extended during construction. This work is part of my ongoing effort to bring to this repository websteps-illustrated changes relative to netxlite. Ref issue: https://github.com/ooni/probe/issues/2096
This commit is contained in:
parent
923d81cdee
commit
8f7e3803eb
|
@ -264,20 +264,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||||
case "https":
|
case "https":
|
||||||
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
|
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||||
httpClient := &http.Client{Transport: NewHTTPTransport(config)}
|
httpClient := &http.Client{Transport: NewHTTPTransport(config)}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride(
|
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverHTTPSTransportWithHostOverride(
|
||||||
httpClient, URL, hostOverride)
|
httpClient, URL, hostOverride)
|
||||||
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
||||||
return netxlite.NewSerialResolver(txp), nil
|
return netxlite.NewUnwrappedSerialResolver(txp), nil
|
||||||
case "udp":
|
case "udp":
|
||||||
dialer := NewDialer(config)
|
dialer := NewDialer(config)
|
||||||
endpoint, err := makeValidEndpoint(resolverURL)
|
endpoint, err := makeValidEndpoint(resolverURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport(
|
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverUDPTransport(
|
||||||
dialer, endpoint)
|
dialer, endpoint)
|
||||||
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
||||||
return netxlite.NewSerialResolver(txp), nil
|
return netxlite.NewUnwrappedSerialResolver(txp), nil
|
||||||
case "dot":
|
case "dot":
|
||||||
config.TLSConfig.NextProtos = []string{"dot"}
|
config.TLSConfig.NextProtos = []string{"dot"}
|
||||||
tlsDialer := NewTLSDialer(config)
|
tlsDialer := NewTLSDialer(config)
|
||||||
|
@ -285,20 +285,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
|
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTLSTransport(
|
||||||
tlsDialer.DialTLSContext, endpoint)
|
tlsDialer.DialTLSContext, endpoint)
|
||||||
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
||||||
return netxlite.NewSerialResolver(txp), nil
|
return netxlite.NewUnwrappedSerialResolver(txp), nil
|
||||||
case "tcp":
|
case "tcp":
|
||||||
dialer := NewDialer(config)
|
dialer := NewDialer(config)
|
||||||
endpoint, err := makeValidEndpoint(resolverURL)
|
endpoint, err := makeValidEndpoint(resolverURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport(
|
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTCPTransport(
|
||||||
dialer.DialContext, endpoint)
|
dialer.DialContext, endpoint)
|
||||||
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
|
||||||
return netxlite.NewSerialResolver(txp), nil
|
return netxlite.NewUnwrappedSerialResolver(txp), nil
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported resolver scheme")
|
return nil, errors.New("unsupported resolver scheme")
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,50 +70,50 @@ func TestNewResolverSystem(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverUDPAddress(t *testing.T) {
|
func TestNewResolverUDPAddress(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53"))
|
netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverUDPDomain(t *testing.T) {
|
func TestNewResolverUDPDomain(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53"))
|
netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverTCPAddress(t *testing.T) {
|
func TestNewResolverTCPAddress(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53"))
|
netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverTCPDomain(t *testing.T) {
|
func TestNewResolverTCPDomain(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53"))
|
netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverDoTAddress(t *testing.T) {
|
func TestNewResolverDoTAddress(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
|
netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverDoTDomain(t *testing.T) {
|
func TestNewResolverDoTDomain(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
|
netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverDoH(t *testing.T) {
|
func TestNewResolverDoH(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewUnwrappedSerialResolver(
|
||||||
netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
|
netxlite.NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,8 +43,8 @@ func (mx *Measurer) NewResolverSystem(db WritableDB, logger model.Logger) model.
|
||||||
// - address is the resolver address (e.g., "1.1.1.1:53").
|
// - address is the resolver address (e.g., "1.1.1.1:53").
|
||||||
func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver {
|
func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver {
|
||||||
return mx.WrapResolver(db, netxlite.WrapResolver(
|
return mx.WrapResolver(db, netxlite.WrapResolver(
|
||||||
logger, netxlite.NewSerialResolver(
|
logger, netxlite.NewUnwrappedSerialResolver(
|
||||||
mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDPTransport(
|
mx.WrapDNSXRoundTripper(db, netxlite.NewUnwrappedDNSOverUDPTransport(
|
||||||
mx.NewDialerWithSystemResolver(db, logger),
|
mx.NewDialerWithSystemResolver(db, logger),
|
||||||
address,
|
address,
|
||||||
)))),
|
)))),
|
||||||
|
|
|
@ -101,6 +101,12 @@ type DNSEncoder interface {
|
||||||
Encode(domain string, qtype uint16, padding bool) DNSQuery
|
Encode(domain string, qtype uint16, padding bool) DNSQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DNSTransportWrapper is a type that takes in input a DNSTransport
|
||||||
|
// and returns in output a wrapped DNSTransport.
|
||||||
|
type DNSTransportWrapper interface {
|
||||||
|
WrapDNSTransport(txp DNSTransport) DNSTransport
|
||||||
|
}
|
||||||
|
|
||||||
// DNSTransport represents an abstract DNS transport.
|
// DNSTransport represents an abstract DNS transport.
|
||||||
type DNSTransport interface {
|
type DNSTransport interface {
|
||||||
// RoundTrip sends a DNS query and receives the reply.
|
// RoundTrip sends a DNS query and receives the reply.
|
||||||
|
|
|
@ -31,20 +31,21 @@ type DNSOverHTTPSTransport struct {
|
||||||
HostOverride string
|
HostOverride string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport instance.
|
// NewUnwrappedDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport
|
||||||
|
// instance that has not been wrapped yet.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - client is a model.HTTPClient type;
|
// - client is a model.HTTPClient type;
|
||||||
//
|
//
|
||||||
// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
|
// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
|
||||||
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
|
func NewUnwrappedDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
|
||||||
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
|
return NewUnwrappedDNSOverHTTPSTransportWithHostOverride(client, URL, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport
|
// NewUnwrappedDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport
|
||||||
// with the given Host header override.
|
// with the given Host header override. This instance has not been wrapped yet.
|
||||||
func NewDNSOverHTTPSTransportWithHostOverride(
|
func NewUnwrappedDNSOverHTTPSTransportWithHostOverride(
|
||||||
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
|
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
|
||||||
return &DNSOverHTTPSTransport{
|
return &DNSOverHTTPSTransport{
|
||||||
Client: client,
|
Client: client,
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
func TestDNSOverHTTPSTransport(t *testing.T) {
|
func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
t.Run("query serialization failure", func(t *testing.T) {
|
t.Run("query serialization failure", func(t *testing.T) {
|
||||||
txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
|
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
query := &mocks.DNSQuery{
|
query := &mocks.DNSQuery{
|
||||||
MockBytes: func() ([]byte, error) {
|
MockBytes: func() ([]byte, error) {
|
||||||
|
@ -34,7 +34,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("NewRequestFailure", func(t *testing.T) {
|
t.Run("NewRequestFailure", func(t *testing.T) {
|
||||||
const invalidURL = "\t"
|
const invalidURL = "\t"
|
||||||
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
|
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
|
||||||
query := &mocks.DNSQuery{
|
query := &mocks.DNSQuery{
|
||||||
MockBytes: func() ([]byte, error) {
|
MockBytes: func() ([]byte, error) {
|
||||||
return make([]byte, 17), nil
|
return make([]byte, 17), nil
|
||||||
|
@ -293,7 +293,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions behave correctly", func(t *testing.T) {
|
t.Run("other functions behave correctly", func(t *testing.T) {
|
||||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||||
txp := NewDNSOverHTTPSTransport(http.DefaultClient, queryURL)
|
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, queryURL)
|
||||||
if txp.Network() != "doh" {
|
if txp.Network() != "doh" {
|
||||||
t.Fatal("invalid network")
|
t.Fatal("invalid network")
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,25 +31,27 @@ type DNSOverTCPTransport struct {
|
||||||
requiresPadding bool
|
requiresPadding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverTCPTransport creates a new DNSOverTCPTransport.
|
// NewUnwrappedDNSOverTCPTransport creates a new DNSOverTCPTransport
|
||||||
|
// that has not been wrapped yet.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - dial is a function with the net.Dialer.DialContext's signature;
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||||
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
func NewUnwrappedDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||||
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
|
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverTLSTransport creates a new DNSOverTLS transport.
|
// NewUnwrappedDNSOverTLSTransport creates a new DNSOverTLS transport
|
||||||
|
// that has not been wrapped yet.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - dial is a function with the net.Dialer.DialContext's signature;
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
||||||
func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
func NewUnwrappedDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||||
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
|
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
t.Run("cannot encode query", func(t *testing.T) {
|
t.Run("cannot encode query", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
query := &mocks.DNSQuery{
|
query := &mocks.DNSQuery{
|
||||||
MockBytes: func() ([]byte, error) {
|
MockBytes: func() ([]byte, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
@ -37,7 +37,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("query too large", func(t *testing.T) {
|
t.Run("query too large", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
query := &mocks.DNSQuery{
|
query := &mocks.DNSQuery{
|
||||||
MockBytes: func() ([]byte, error) {
|
MockBytes: func() ([]byte, error) {
|
||||||
return make([]byte, math.MaxUint16+1), nil
|
return make([]byte, math.MaxUint16+1), nil
|
||||||
|
@ -65,7 +65,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
resp, err := txp.RoundTrip(context.Background(), query)
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -98,7 +98,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
resp, err := txp.RoundTrip(context.Background(), query)
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -134,7 +134,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
resp, err := txp.RoundTrip(context.Background(), query)
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -176,7 +176,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
resp, err := txp.RoundTrip(context.Background(), query)
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -211,7 +211,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
txp.decoder = &mocks.DNSDecoder{
|
txp.decoder = &mocks.DNSDecoder{
|
||||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
|
@ -250,7 +250,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
expectedResp := &mocks.DNSResponse{}
|
expectedResp := &mocks.DNSResponse{}
|
||||||
txp.decoder = &mocks.DNSDecoder{
|
txp.decoder = &mocks.DNSDecoder{
|
||||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
@ -269,7 +269,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions okay with TCP", func(t *testing.T) {
|
t.Run("other functions okay with TCP", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
if txp.RequiresPadding() != false {
|
if txp.RequiresPadding() != false {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
@ -284,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions okay with TLS", func(t *testing.T) {
|
t.Run("other functions okay with TLS", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:853"
|
const address = "9.9.9.9:853"
|
||||||
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
|
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
|
||||||
if txp.RequiresPadding() != true {
|
if txp.RequiresPadding() != true {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,8 @@ type DNSOverUDPTransport struct {
|
||||||
IOTimeout time.Duration
|
IOTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance.
|
// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance
|
||||||
|
// that has not been wrapped yet.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
|
@ -64,7 +65,7 @@ type DNSOverUDPTransport struct {
|
||||||
// IP addresses returned by the underlying DNS lookup performed using
|
// IP addresses returned by the underlying DNS lookup performed using
|
||||||
// the dialer. This usage pattern is NOT RECOMMENDED because we'll
|
// the dialer. This usage pattern is NOT RECOMMENDED because we'll
|
||||||
// have less control over which IP address is being used.
|
// have less control over which IP address is being used.
|
||||||
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
||||||
return &DNSOverUDPTransport{
|
return &DNSOverUDPTransport{
|
||||||
Decoder: &DNSDecoderMiekg{},
|
Decoder: &DNSDecoderMiekg{},
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Run("cannot encode query", func(t *testing.T) {
|
t.Run("cannot encode query", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverUDPTransport(nil, address)
|
txp := NewUnwrappedDNSOverUDPTransport(nil, address)
|
||||||
query := &mocks.DNSQuery{
|
query := &mocks.DNSQuery{
|
||||||
MockBytes: func() ([]byte, error) {
|
MockBytes: func() ([]byte, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
@ -39,7 +39,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Run("dial failure", func(t *testing.T) {
|
t.Run("dial failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverUDPTransport(&mocks.Dialer{
|
txp := NewUnwrappedDNSOverUDPTransport(&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
|
@ -60,7 +60,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Write failure", func(t *testing.T) {
|
t.Run("Write failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := NewDNSOverUDPTransport(
|
txp := NewUnwrappedDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -103,7 +103,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Read failure", func(t *testing.T) {
|
t.Run("Read failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := NewDNSOverUDPTransport(
|
txp := NewUnwrappedDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -150,7 +150,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Run("decode failure", func(t *testing.T) {
|
t.Run("decode failure", func(t *testing.T) {
|
||||||
const expected = 17
|
const expected = 17
|
||||||
input := bytes.NewReader(make([]byte, expected))
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
txp := NewDNSOverUDPTransport(
|
txp := NewUnwrappedDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -201,7 +201,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Run("decode success", func(t *testing.T) {
|
t.Run("decode success", func(t *testing.T) {
|
||||||
const expected = 17
|
const expected = 17
|
||||||
input := bytes.NewReader(make([]byte, expected))
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
txp := NewDNSOverUDPTransport(
|
txp := NewUnwrappedDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -264,7 +264,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
resp, err := txp.RoundTrip(context.Background(), query)
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
@ -297,7 +297,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -332,7 +332,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -359,7 +359,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
||||||
|
@ -413,7 +413,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
|
@ -440,7 +440,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverUDPTransport(dialer, address)
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, address)
|
||||||
txp.CloseIdleConnections()
|
txp.CloseIdleConnections()
|
||||||
if !called {
|
if !called {
|
||||||
t.Fatal("not called")
|
t.Fatal("not called")
|
||||||
|
@ -449,7 +449,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions okay", func(t *testing.T) {
|
t.Run("other functions okay", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
|
txp := NewUnwrappedDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
|
||||||
if txp.RequiresPadding() != false {
|
if txp.RequiresPadding() != false {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
60
internal/netxlite/dnstransport.go
Normal file
60
internal/netxlite/dnstransport.go
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Generic DNS transport code.
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WrapDNSTransport wraps a DNSTransport to provide error wrapping. This function will
|
||||||
|
// apply all the provided wrappers around the default transport wrapping. If any of the
|
||||||
|
// wrappers is nil, we just silently and gracefully ignore it.
|
||||||
|
func WrapDNSTransport(txp model.DNSTransport,
|
||||||
|
wrappers ...model.DNSTransportWrapper) (out model.DNSTransport) {
|
||||||
|
out = &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: txp,
|
||||||
|
}
|
||||||
|
for _, wrapper := range wrappers {
|
||||||
|
if wrapper == nil {
|
||||||
|
continue // skip as documented
|
||||||
|
}
|
||||||
|
out = wrapper.WrapDNSTransport(out) // compose with user-provided wrappers
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsTransportErrWrapper wraps DNSTransport to provide error wrapping.
|
||||||
|
type dnsTransportErrWrapper struct {
|
||||||
|
DNSTransport model.DNSTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.DNSTransport = &dnsTransportErrWrapper{}
|
||||||
|
|
||||||
|
func (t *dnsTransportErrWrapper) RoundTrip(
|
||||||
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
resp, err := t.DNSTransport.RoundTrip(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dnsTransportErrWrapper) RequiresPadding() bool {
|
||||||
|
return t.DNSTransport.RequiresPadding()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dnsTransportErrWrapper) Network() string {
|
||||||
|
return t.DNSTransport.Network()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dnsTransportErrWrapper) Address() string {
|
||||||
|
return t.DNSTransport.Address()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *dnsTransportErrWrapper) CloseIdleConnections() {
|
||||||
|
t.DNSTransport.CloseIdleConnections()
|
||||||
|
}
|
156
internal/netxlite/dnstransport_test.go
Normal file
156
internal/netxlite/dnstransport_test.go
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
type dnsTransportExtensionFirst struct {
|
||||||
|
model.DNSTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
type dnsTransportWrapperFirst struct{}
|
||||||
|
|
||||||
|
func (*dnsTransportWrapperFirst) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
|
||||||
|
return &dnsTransportExtensionFirst{txp}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dnsTransportExtensionSecond struct {
|
||||||
|
model.DNSTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
type dnsTransportWrapperSecond struct{}
|
||||||
|
|
||||||
|
func (*dnsTransportWrapperSecond) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
|
||||||
|
return &dnsTransportExtensionSecond{txp}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapDNSTransport(t *testing.T) {
|
||||||
|
orig := &mocks.DNSTransport{}
|
||||||
|
extensions := []model.DNSTransportWrapper{
|
||||||
|
&dnsTransportWrapperFirst{},
|
||||||
|
nil, // explicitly test for documented use case
|
||||||
|
&dnsTransportWrapperSecond{},
|
||||||
|
}
|
||||||
|
txp := WrapDNSTransport(orig, extensions...)
|
||||||
|
ext2 := txp.(*dnsTransportExtensionSecond)
|
||||||
|
ext1 := ext2.DNSTransport.(*dnsTransportExtensionFirst)
|
||||||
|
errWrapper := ext1.DNSTransport.(*dnsTransportErrWrapper)
|
||||||
|
underlying := errWrapper.DNSTransport
|
||||||
|
if orig != underlying {
|
||||||
|
t.Fatal("unexpected underlying transport")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSTransportErrWrapper(t *testing.T) {
|
||||||
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
expectedResp := &mocks.DNSResponse{}
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return expectedResp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: child,
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{}
|
||||||
|
ctx := context.Background()
|
||||||
|
resp, err := txp.RoundTrip(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if resp != expectedResp {
|
||||||
|
t.Fatal("unexpected resp")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expectedErr := io.EOF
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return nil, expectedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: child,
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{}
|
||||||
|
ctx := context.Background()
|
||||||
|
resp, err := txp.RoundTrip(ctx, query)
|
||||||
|
if !errors.Is(err, expectedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("unexpected resp")
|
||||||
|
}
|
||||||
|
var errWrapper *ErrWrapper
|
||||||
|
if !errors.As(err, &errWrapper) {
|
||||||
|
t.Fatal("error has not been wrapped")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RequiresPadding", func(t *testing.T) {
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: child,
|
||||||
|
}
|
||||||
|
if !txp.RequiresPadding() {
|
||||||
|
t.Fatal("expected true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network", func(t *testing.T) {
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "x"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: child,
|
||||||
|
}
|
||||||
|
if txp.Network() != "x" {
|
||||||
|
t.Fatal("unexpected Network")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Address", func(t *testing.T) {
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockAddress: func() string {
|
||||||
|
return "x"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: child,
|
||||||
|
}
|
||||||
|
if txp.Address() != "x" {
|
||||||
|
t.Fatal("unexpected Address")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
child := &mocks.DNSTransport{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := &dnsTransportErrWrapper{
|
||||||
|
DNSTransport: child,
|
||||||
|
}
|
||||||
|
txp.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -104,7 +104,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("on success", func(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
||||||
r := netxlite.NewResolverUDP(log.Log, dlr, "8.8.4.4:53")
|
r := netxlite.NewParallelResolverUDP(log.Log, dlr, "8.8.4.4:53")
|
||||||
defer r.CloseIdleConnections()
|
defer r.CloseIdleConnections()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google.com")
|
addrs, err := r.LookupHost(ctx, "dns.google.com")
|
||||||
|
@ -128,7 +128,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
||||||
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
||||||
defer r.CloseIdleConnections()
|
defer r.CloseIdleConnections()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
addrs, err := r.LookupHost(ctx, "ooni.org")
|
addrs, err := r.LookupHost(ctx, "ooni.org")
|
||||||
|
@ -152,7 +152,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
||||||
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
||||||
defer r.CloseIdleConnections()
|
defer r.CloseIdleConnections()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
addrs, err := r.LookupHost(ctx, "ooni.org")
|
addrs, err := r.LookupHost(ctx, "ooni.org")
|
||||||
|
@ -176,7 +176,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
||||||
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
||||||
defer r.CloseIdleConnections()
|
defer r.CloseIdleConnections()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
addrs, err := r.LookupHost(ctx, "ooni.org")
|
addrs, err := r.LookupHost(ctx, "ooni.org")
|
||||||
|
|
|
@ -13,6 +13,9 @@ const (
|
||||||
// ConnectOperation is the operation where we do a TCP connect.
|
// ConnectOperation is the operation where we do a TCP connect.
|
||||||
ConnectOperation = "connect"
|
ConnectOperation = "connect"
|
||||||
|
|
||||||
|
// DNSRoundTripOperation is the DNS round trip.
|
||||||
|
DNSRoundTripOperation = "dns_round_trip"
|
||||||
|
|
||||||
// TLSHandshakeOperation is the TLS handshake.
|
// TLSHandshakeOperation is the TLS handshake.
|
||||||
TLSHandshakeOperation = "tls_handshake"
|
TLSHandshakeOperation = "tls_handshake"
|
||||||
|
|
||||||
|
|
|
@ -23,18 +23,23 @@ import (
|
||||||
var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
|
var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
|
||||||
|
|
||||||
// NewResolverStdlib creates a new Resolver by combining WrapResolver
|
// NewResolverStdlib creates a new Resolver by combining WrapResolver
|
||||||
// with an internal "system" resolver type.
|
// with an internal "system" resolver type. The list of optional wrappers
|
||||||
func NewResolverStdlib(logger model.DebugLogger) model.Resolver {
|
// allow to wrap the underlying getaddrinfo transport. Any nil wrapper
|
||||||
return WrapResolver(logger, newResolverSystem())
|
// will be silently ignored by the code that performs the wrapping.
|
||||||
|
func NewResolverStdlib(logger model.DebugLogger, wrappers ...model.DNSTransportWrapper) model.Resolver {
|
||||||
|
return WrapResolver(logger, newResolverSystem(wrappers...))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newResolverSystem() *resolverSystem {
|
func newResolverSystem(wrappers ...model.DNSTransportWrapper) *resolverSystem {
|
||||||
return &resolverSystem{
|
return &resolverSystem{
|
||||||
t: &dnsOverGetaddrinfoTransport{},
|
t: WrapDNSTransport(&dnsOverGetaddrinfoTransport{}, wrappers...),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolverUDP creates a new Resolver using DNS-over-UDP.
|
// NewSerialResolverUDP creates a new Resolver using DNS-over-UDP
|
||||||
|
// that performs serial A/AAAA lookups during LookupHost.
|
||||||
|
//
|
||||||
|
// Deprecated: use NewParallelResolverUDP.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
|
@ -43,9 +48,33 @@ func newResolverSystem() *resolverSystem {
|
||||||
// - dialer is the dialer to create and connect UDP conns
|
// - dialer is the dialer to create and connect UDP conns
|
||||||
//
|
//
|
||||||
// - address is the server address (e.g., 1.1.1.1:53)
|
// - address is the server address (e.g., 1.1.1.1:53)
|
||||||
func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
|
//
|
||||||
return WrapResolver(logger, NewSerialResolver(
|
// - wrappers is the optional list of wrappers to wrap the underlying
|
||||||
NewDNSOverUDPTransport(dialer, address),
|
// transport. Any nil wrapper will be silently ignored.
|
||||||
|
func NewSerialResolverUDP(logger model.DebugLogger, dialer model.Dialer,
|
||||||
|
address string, wrappers ...model.DNSTransportWrapper) model.Resolver {
|
||||||
|
return WrapResolver(logger, NewUnwrappedSerialResolver(
|
||||||
|
WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewParallelResolverUDP creates a new Resolver using DNS-over-UDP
|
||||||
|
// that performs parallel A/AAAA lookups during LookupHost.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// - logger is the logger to use
|
||||||
|
//
|
||||||
|
// - dialer is the dialer to create and connect UDP conns
|
||||||
|
//
|
||||||
|
// - address is the server address (e.g., 1.1.1.1:53)
|
||||||
|
//
|
||||||
|
// - wrappers is the optional list of wrappers to wrap the underlying
|
||||||
|
// transport. Any nil wrapper will be silently ignored.
|
||||||
|
func NewParallelResolverUDP(logger model.DebugLogger, dialer model.Dialer,
|
||||||
|
address string, wrappers ...model.DNSTransportWrapper) model.Resolver {
|
||||||
|
return WrapResolver(logger, NewUnwrappedParallelResolver(
|
||||||
|
WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,12 +25,13 @@ func TestNewResolverSystem(t *testing.T) {
|
||||||
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||||
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||||
reso := errWrapper.Resolver.(*resolverSystem)
|
reso := errWrapper.Resolver.(*resolverSystem)
|
||||||
_ = reso.t.(*dnsOverGetaddrinfoTransport)
|
txpErrWrapper := reso.t.(*dnsTransportErrWrapper)
|
||||||
|
_ = txpErrWrapper.DNSTransport.(*dnsOverGetaddrinfoTransport)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverUDP(t *testing.T) {
|
func TestNewSerialResolverUDP(t *testing.T) {
|
||||||
d := NewDialerWithoutResolver(log.Log)
|
d := NewDialerWithoutResolver(log.Log)
|
||||||
resolver := NewResolverUDP(log.Log, d, "1.1.1.1:53")
|
resolver := NewSerialResolverUDP(log.Log, d, "1.1.1.1:53")
|
||||||
idna := resolver.(*resolverIDNA)
|
idna := resolver.(*resolverIDNA)
|
||||||
logger := idna.Resolver.(*resolverLogger)
|
logger := idna.Resolver.(*resolverLogger)
|
||||||
if logger.Logger != log.Log {
|
if logger.Logger != log.Log {
|
||||||
|
@ -39,8 +40,27 @@ func TestNewResolverUDP(t *testing.T) {
|
||||||
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||||
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||||
serio := errWrapper.Resolver.(*SerialResolver)
|
serio := errWrapper.Resolver.(*SerialResolver)
|
||||||
txp := serio.Transport().(*DNSOverUDPTransport)
|
txp := serio.Transport().(*dnsTransportErrWrapper)
|
||||||
if txp.Address() != "1.1.1.1:53" {
|
dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport)
|
||||||
|
if dnsTxp.Address() != "1.1.1.1:53" {
|
||||||
|
t.Fatal("invalid address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewParallelResolverUDP(t *testing.T) {
|
||||||
|
d := NewDialerWithoutResolver(log.Log)
|
||||||
|
resolver := NewParallelResolverUDP(log.Log, d, "1.1.1.1:53")
|
||||||
|
idna := resolver.(*resolverIDNA)
|
||||||
|
logger := idna.Resolver.(*resolverLogger)
|
||||||
|
if logger.Logger != log.Log {
|
||||||
|
t.Fatal("invalid logger")
|
||||||
|
}
|
||||||
|
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||||
|
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||||
|
para := errWrapper.Resolver.(*ParallelResolver)
|
||||||
|
txp := para.Transport().(*dnsTransportErrWrapper)
|
||||||
|
dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport)
|
||||||
|
if dnsTxp.Address() != "1.1.1.1:53" {
|
||||||
t.Fatal("invalid address")
|
t.Fatal("invalid address")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
|
|
||||||
func TestParallelResolver(t *testing.T) {
|
func TestParallelResolver(t *testing.T) {
|
||||||
t.Run("transport okay", func(t *testing.T) {
|
t.Run("transport okay", func(t *testing.T) {
|
||||||
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||||
r := NewUnwrappedParallelResolver(txp)
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
rtx := r.Transport()
|
rtx := r.Transport()
|
||||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||||
|
|
|
@ -33,8 +33,8 @@ type SerialResolver struct {
|
||||||
Txp model.DNSTransport
|
Txp model.DNSTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSerialResolver creates a new SerialResolver instance.
|
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
|
||||||
func NewSerialResolver(t model.DNSTransport) *SerialResolver {
|
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
|
||||||
return &SerialResolver{
|
return &SerialResolver{
|
||||||
NumTimeouts: &atomicx.Int64{},
|
NumTimeouts: &atomicx.Int64{},
|
||||||
Txp: t,
|
Txp: t,
|
||||||
|
|
|
@ -31,8 +31,8 @@ func (err *errorWithTimeout) Unwrap() error {
|
||||||
|
|
||||||
func TestSerialResolver(t *testing.T) {
|
func TestSerialResolver(t *testing.T) {
|
||||||
t.Run("transport okay", func(t *testing.T) {
|
t.Run("transport okay", func(t *testing.T) {
|
||||||
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||||
r := NewSerialResolver(txp)
|
r := NewUnwrappedSerialResolver(txp)
|
||||||
rtx := r.Transport()
|
rtx := r.Transport()
|
||||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
|
@ -56,7 +56,7 @@ func TestSerialResolver(t *testing.T) {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
r := NewSerialResolver(txp)
|
r := NewUnwrappedSerialResolver(txp)
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -80,7 +80,7 @@ func TestSerialResolver(t *testing.T) {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
r := NewSerialResolver(txp)
|
r := NewUnwrappedSerialResolver(txp)
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
|
@ -107,7 +107,7 @@ func TestSerialResolver(t *testing.T) {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
r := NewSerialResolver(txp)
|
r := NewUnwrappedSerialResolver(txp)
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -134,7 +134,7 @@ func TestSerialResolver(t *testing.T) {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
r := NewSerialResolver(txp)
|
r := NewUnwrappedSerialResolver(txp)
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -157,7 +157,7 @@ func TestSerialResolver(t *testing.T) {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
r := NewSerialResolver(txp)
|
r := NewUnwrappedSerialResolver(txp)
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
if !errors.Is(err, ETIMEDOUT) {
|
if !errors.Is(err, ETIMEDOUT) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
|
|
@ -54,7 +54,7 @@ func main() {
|
||||||
// UDP endpoint address at which the server is listening.
|
// UDP endpoint address at which the server is listening.
|
||||||
//
|
//
|
||||||
// ```Go
|
// ```Go
|
||||||
reso := netxlite.NewResolverUDP(log.Log, dialer, *serverAddr)
|
reso := netxlite.NewParallelResolverUDP(log.Log, dialer, *serverAddr)
|
||||||
// ```
|
// ```
|
||||||
//
|
//
|
||||||
// The API we invoke is the same as in the previous chapter, though,
|
// The API we invoke is the same as in the previous chapter, though,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user