feat(netxlite): implement DNSTransport wrapping (#776)

Acknowledge that transports MAY be used in isolation (i.e., outside
of a Resolver) and add support for wrapping.

Ensure that every factory that creates an unwrapped type is named
accordingly to hopefully ensure there are no surprises.

Implement DNSTransport wrapping and use a technique similar to the
one used by Dialer to customize the DNSTransport while constructing
more complex data types (e.g., a specific resolver).

Ensure that the stdlib resolver's own "getaddrinfo" transport (1)
is wrapped and (2) could be extended during construction.

This work is part of my ongoing effort to bring to this repository
websteps-illustrated changes relative to netxlite.

Ref issue: https://github.com/ooni/probe/issues/2096
This commit is contained in:
Simone Basso 2022-06-01 11:10:08 +02:00 committed by GitHub
parent 923d81cdee
commit 8f7e3803eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 369 additions and 91 deletions

View File

@ -264,20 +264,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
httpClient := &http.Client{Transport: NewHTTPTransport(config)}
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride(
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverHTTPSTransportWithHostOverride(
httpClient, URL, hostOverride)
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
return netxlite.NewSerialResolver(txp), nil
return netxlite.NewUnwrappedSerialResolver(txp), nil
case "udp":
dialer := NewDialer(config)
endpoint, err := makeValidEndpoint(resolverURL)
if err != nil {
return nil, err
}
var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport(
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverUDPTransport(
dialer, endpoint)
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
return netxlite.NewSerialResolver(txp), nil
return netxlite.NewUnwrappedSerialResolver(txp), nil
case "dot":
config.TLSConfig.NextProtos = []string{"dot"}
tlsDialer := NewTLSDialer(config)
@ -285,20 +285,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
if err != nil {
return nil, err
}
var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTLSTransport(
tlsDialer.DialTLSContext, endpoint)
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
return netxlite.NewSerialResolver(txp), nil
return netxlite.NewUnwrappedSerialResolver(txp), nil
case "tcp":
dialer := NewDialer(config)
endpoint, err := makeValidEndpoint(resolverURL)
if err != nil {
return nil, err
}
var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport(
var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTCPTransport(
dialer.DialContext, endpoint)
txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
return netxlite.NewSerialResolver(txp), nil
return netxlite.NewUnwrappedSerialResolver(txp), nil
default:
return nil, errors.New("unsupported resolver scheme")
}

View File

@ -70,50 +70,50 @@ func TestNewResolverSystem(t *testing.T) {
}
func TestNewResolverUDPAddress(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverUDPDomain(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverTCPAddress(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverTCPDomain(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTAddress(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTDomain(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoH(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
reso := netxlite.NewUnwrappedSerialResolver(
netxlite.NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}

View File

@ -43,8 +43,8 @@ func (mx *Measurer) NewResolverSystem(db WritableDB, logger model.Logger) model.
// - address is the resolver address (e.g., "1.1.1.1:53").
func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver {
return mx.WrapResolver(db, netxlite.WrapResolver(
logger, netxlite.NewSerialResolver(
mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDPTransport(
logger, netxlite.NewUnwrappedSerialResolver(
mx.WrapDNSXRoundTripper(db, netxlite.NewUnwrappedDNSOverUDPTransport(
mx.NewDialerWithSystemResolver(db, logger),
address,
)))),

View File

@ -101,6 +101,12 @@ type DNSEncoder interface {
Encode(domain string, qtype uint16, padding bool) DNSQuery
}
// DNSTransportWrapper is a type that takes in input a DNSTransport
// and returns in output a wrapped DNSTransport.
type DNSTransportWrapper interface {
WrapDNSTransport(txp DNSTransport) DNSTransport
}
// DNSTransport represents an abstract DNS transport.
type DNSTransport interface {
// RoundTrip sends a DNS query and receives the reply.

View File

@ -31,20 +31,21 @@ type DNSOverHTTPSTransport struct {
HostOverride string
}
// NewDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport instance.
// NewUnwrappedDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport
// instance that has not been wrapped yet.
//
// Arguments:
//
// - client is a model.HTTPClient type;
//
// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
func NewUnwrappedDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
return NewUnwrappedDNSOverHTTPSTransportWithHostOverride(client, URL, "")
}
// NewDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport
// with the given Host header override.
func NewDNSOverHTTPSTransportWithHostOverride(
// NewUnwrappedDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport
// with the given Host header override. This instance has not been wrapped yet.
func NewUnwrappedDNSOverHTTPSTransportWithHostOverride(
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
return &DNSOverHTTPSTransport{
Client: client,

View File

@ -16,7 +16,7 @@ import (
func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("query serialization failure", func(t *testing.T) {
txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
expected := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
@ -34,7 +34,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("NewRequestFailure", func(t *testing.T) {
const invalidURL = "\t"
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
@ -293,7 +293,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("other functions behave correctly", func(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query"
txp := NewDNSOverHTTPSTransport(http.DefaultClient, queryURL)
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, queryURL)
if txp.Network() != "doh" {
t.Fatal("invalid network")
}

View File

@ -31,25 +31,27 @@ type DNSOverTCPTransport struct {
requiresPadding bool
}
// NewDNSOverTCPTransport creates a new DNSOverTCPTransport.
// NewUnwrappedDNSOverTCPTransport creates a new DNSOverTCPTransport
// that has not been wrapped yet.
//
// Arguments:
//
// - dial is a function with the net.Dialer.DialContext's signature;
//
// - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
func NewUnwrappedDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
}
// NewDNSOverTLSTransport creates a new DNSOverTLS transport.
// NewUnwrappedDNSOverTLSTransport creates a new DNSOverTLS transport
// that has not been wrapped yet.
//
// Arguments:
//
// - dial is a function with the net.Dialer.DialContext's signature;
//
// - address is the endpoint address (e.g., 8.8.8.8:853).
func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
func NewUnwrappedDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
}

View File

@ -20,7 +20,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
@ -37,7 +37,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
t.Run("query too large", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, math.MaxUint16+1), nil
@ -65,7 +65,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
return nil, mocked
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -98,7 +98,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -134,7 +134,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -176,7 +176,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -211,7 +211,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked
@ -250,7 +250,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
expectedResp := &mocks.DNSResponse{}
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
@ -269,7 +269,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
t.Run("other functions okay with TCP", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
@ -284,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
t.Run("other functions okay with TLS", func(t *testing.T) {
const address = "9.9.9.9:853"
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}

View File

@ -51,7 +51,8 @@ type DNSOverUDPTransport struct {
IOTimeout time.Duration
}
// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance.
// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance
// that has not been wrapped yet.
//
// Arguments:
//
@ -64,7 +65,7 @@ type DNSOverUDPTransport struct {
// IP addresses returned by the underlying DNS lookup performed using
// the dialer. This usage pattern is NOT RECOMMENDED because we'll
// have less control over which IP address is being used.
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
return &DNSOverUDPTransport{
Decoder: &DNSDecoderMiekg{},
Dialer: dialer,

View File

@ -21,7 +21,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverUDPTransport(nil, address)
txp := NewUnwrappedDNSOverUDPTransport(nil, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
@ -39,7 +39,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("dial failure", func(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverUDPTransport(&mocks.Dialer{
txp := NewUnwrappedDNSOverUDPTransport(&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
@ -60,7 +60,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("Write failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDPTransport(
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
@ -103,7 +103,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("Read failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDPTransport(
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
@ -150,7 +150,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("decode failure", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDPTransport(
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
@ -201,7 +201,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("decode success", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDPTransport(
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
@ -264,7 +264,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
resp, err := txp.RoundTrip(context.Background(), query)
@ -297,7 +297,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
ctx := context.Background()
@ -332,7 +332,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
ctx := context.Background()
@ -359,7 +359,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
@ -413,7 +413,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
@ -440,7 +440,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
},
}
const address = "9.9.9.9:53"
txp := NewDNSOverUDPTransport(dialer, address)
txp := NewUnwrappedDNSOverUDPTransport(dialer, address)
txp.CloseIdleConnections()
if !called {
t.Fatal("not called")
@ -449,7 +449,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("other functions okay", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
txp := NewUnwrappedDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}

View File

@ -0,0 +1,60 @@
package netxlite
//
// Generic DNS transport code.
//
import (
"context"
"github.com/ooni/probe-cli/v3/internal/model"
)
// WrapDNSTransport wraps a DNSTransport to provide error wrapping. This function will
// apply all the provided wrappers around the default transport wrapping. If any of the
// wrappers is nil, we just silently and gracefully ignore it.
func WrapDNSTransport(txp model.DNSTransport,
wrappers ...model.DNSTransportWrapper) (out model.DNSTransport) {
out = &dnsTransportErrWrapper{
DNSTransport: txp,
}
for _, wrapper := range wrappers {
if wrapper == nil {
continue // skip as documented
}
out = wrapper.WrapDNSTransport(out) // compose with user-provided wrappers
}
return
}
// dnsTransportErrWrapper wraps DNSTransport to provide error wrapping.
type dnsTransportErrWrapper struct {
DNSTransport model.DNSTransport
}
var _ model.DNSTransport = &dnsTransportErrWrapper{}
func (t *dnsTransportErrWrapper) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
resp, err := t.DNSTransport.RoundTrip(ctx, query)
if err != nil {
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
}
return resp, nil
}
func (t *dnsTransportErrWrapper) RequiresPadding() bool {
return t.DNSTransport.RequiresPadding()
}
func (t *dnsTransportErrWrapper) Network() string {
return t.DNSTransport.Network()
}
func (t *dnsTransportErrWrapper) Address() string {
return t.DNSTransport.Address()
}
func (t *dnsTransportErrWrapper) CloseIdleConnections() {
t.DNSTransport.CloseIdleConnections()
}

View File

@ -0,0 +1,156 @@
package netxlite
import (
"context"
"errors"
"io"
"testing"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
type dnsTransportExtensionFirst struct {
model.DNSTransport
}
type dnsTransportWrapperFirst struct{}
func (*dnsTransportWrapperFirst) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
return &dnsTransportExtensionFirst{txp}
}
type dnsTransportExtensionSecond struct {
model.DNSTransport
}
type dnsTransportWrapperSecond struct{}
func (*dnsTransportWrapperSecond) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
return &dnsTransportExtensionSecond{txp}
}
func TestWrapDNSTransport(t *testing.T) {
orig := &mocks.DNSTransport{}
extensions := []model.DNSTransportWrapper{
&dnsTransportWrapperFirst{},
nil, // explicitly test for documented use case
&dnsTransportWrapperSecond{},
}
txp := WrapDNSTransport(orig, extensions...)
ext2 := txp.(*dnsTransportExtensionSecond)
ext1 := ext2.DNSTransport.(*dnsTransportExtensionFirst)
errWrapper := ext1.DNSTransport.(*dnsTransportErrWrapper)
underlying := errWrapper.DNSTransport
if orig != underlying {
t.Fatal("unexpected underlying transport")
}
}
func TestDNSTransportErrWrapper(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedResp := &mocks.DNSResponse{}
child := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
txp := &dnsTransportErrWrapper{
DNSTransport: child,
}
query := &mocks.DNSQuery{}
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if err != nil {
t.Fatal(err)
}
if resp != expectedResp {
t.Fatal("unexpected resp")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
child := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expectedErr
},
}
txp := &dnsTransportErrWrapper{
DNSTransport: child,
}
query := &mocks.DNSQuery{}
ctx := context.Background()
resp, err := txp.RoundTrip(ctx, query)
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("unexpected resp")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("error has not been wrapped")
}
})
})
t.Run("RequiresPadding", func(t *testing.T) {
child := &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return true
},
}
txp := &dnsTransportErrWrapper{
DNSTransport: child,
}
if !txp.RequiresPadding() {
t.Fatal("expected true")
}
})
t.Run("Network", func(t *testing.T) {
child := &mocks.DNSTransport{
MockNetwork: func() string {
return "x"
},
}
txp := &dnsTransportErrWrapper{
DNSTransport: child,
}
if txp.Network() != "x" {
t.Fatal("unexpected Network")
}
})
t.Run("Address", func(t *testing.T) {
child := &mocks.DNSTransport{
MockAddress: func() string {
return "x"
},
}
txp := &dnsTransportErrWrapper{
DNSTransport: child,
}
if txp.Address() != "x" {
t.Fatal("unexpected Address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.DNSTransport{
MockCloseIdleConnections: func() {
called = true
},
}
txp := &dnsTransportErrWrapper{
DNSTransport: child,
}
txp.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}

View File

@ -104,7 +104,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
t.Run("on success", func(t *testing.T) {
dlr := netxlite.NewDialerWithoutResolver(log.Log)
r := netxlite.NewResolverUDP(log.Log, dlr, "8.8.4.4:53")
r := netxlite.NewParallelResolverUDP(log.Log, dlr, "8.8.4.4:53")
defer r.CloseIdleConnections()
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "dns.google.com")
@ -128,7 +128,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
}
defer listener.Close()
dlr := netxlite.NewDialerWithoutResolver(log.Log)
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String())
defer r.CloseIdleConnections()
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "ooni.org")
@ -152,7 +152,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
}
defer listener.Close()
dlr := netxlite.NewDialerWithoutResolver(log.Log)
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String())
defer r.CloseIdleConnections()
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "ooni.org")
@ -176,7 +176,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
}
defer listener.Close()
dlr := netxlite.NewDialerWithoutResolver(log.Log)
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String())
defer r.CloseIdleConnections()
ctx := context.Background()
addrs, err := r.LookupHost(ctx, "ooni.org")

View File

@ -13,6 +13,9 @@ const (
// ConnectOperation is the operation where we do a TCP connect.
ConnectOperation = "connect"
// DNSRoundTripOperation is the DNS round trip.
DNSRoundTripOperation = "dns_round_trip"
// TLSHandshakeOperation is the TLS handshake.
TLSHandshakeOperation = "tls_handshake"

View File

@ -23,18 +23,23 @@ import (
var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
// NewResolverStdlib creates a new Resolver by combining WrapResolver
// with an internal "system" resolver type.
func NewResolverStdlib(logger model.DebugLogger) model.Resolver {
return WrapResolver(logger, newResolverSystem())
// with an internal "system" resolver type. The list of optional wrappers
// allow to wrap the underlying getaddrinfo transport. Any nil wrapper
// will be silently ignored by the code that performs the wrapping.
func NewResolverStdlib(logger model.DebugLogger, wrappers ...model.DNSTransportWrapper) model.Resolver {
return WrapResolver(logger, newResolverSystem(wrappers...))
}
func newResolverSystem() *resolverSystem {
func newResolverSystem(wrappers ...model.DNSTransportWrapper) *resolverSystem {
return &resolverSystem{
t: &dnsOverGetaddrinfoTransport{},
t: WrapDNSTransport(&dnsOverGetaddrinfoTransport{}, wrappers...),
}
}
// NewResolverUDP creates a new Resolver using DNS-over-UDP.
// NewSerialResolverUDP creates a new Resolver using DNS-over-UDP
// that performs serial A/AAAA lookups during LookupHost.
//
// Deprecated: use NewParallelResolverUDP.
//
// Arguments:
//
@ -43,9 +48,33 @@ func newResolverSystem() *resolverSystem {
// - dialer is the dialer to create and connect UDP conns
//
// - address is the server address (e.g., 1.1.1.1:53)
func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
return WrapResolver(logger, NewSerialResolver(
NewDNSOverUDPTransport(dialer, address),
//
// - wrappers is the optional list of wrappers to wrap the underlying
// transport. Any nil wrapper will be silently ignored.
func NewSerialResolverUDP(logger model.DebugLogger, dialer model.Dialer,
address string, wrappers ...model.DNSTransportWrapper) model.Resolver {
return WrapResolver(logger, NewUnwrappedSerialResolver(
WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...),
))
}
// NewParallelResolverUDP creates a new Resolver using DNS-over-UDP
// that performs parallel A/AAAA lookups during LookupHost.
//
// Arguments:
//
// - logger is the logger to use
//
// - dialer is the dialer to create and connect UDP conns
//
// - address is the server address (e.g., 1.1.1.1:53)
//
// - wrappers is the optional list of wrappers to wrap the underlying
// transport. Any nil wrapper will be silently ignored.
func NewParallelResolverUDP(logger model.DebugLogger, dialer model.Dialer,
address string, wrappers ...model.DNSTransportWrapper) model.Resolver {
return WrapResolver(logger, NewUnwrappedParallelResolver(
WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...),
))
}

View File

@ -25,12 +25,13 @@ func TestNewResolverSystem(t *testing.T) {
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
reso := errWrapper.Resolver.(*resolverSystem)
_ = reso.t.(*dnsOverGetaddrinfoTransport)
txpErrWrapper := reso.t.(*dnsTransportErrWrapper)
_ = txpErrWrapper.DNSTransport.(*dnsOverGetaddrinfoTransport)
}
func TestNewResolverUDP(t *testing.T) {
func TestNewSerialResolverUDP(t *testing.T) {
d := NewDialerWithoutResolver(log.Log)
resolver := NewResolverUDP(log.Log, d, "1.1.1.1:53")
resolver := NewSerialResolverUDP(log.Log, d, "1.1.1.1:53")
idna := resolver.(*resolverIDNA)
logger := idna.Resolver.(*resolverLogger)
if logger.Logger != log.Log {
@ -39,8 +40,27 @@ func TestNewResolverUDP(t *testing.T) {
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
serio := errWrapper.Resolver.(*SerialResolver)
txp := serio.Transport().(*DNSOverUDPTransport)
if txp.Address() != "1.1.1.1:53" {
txp := serio.Transport().(*dnsTransportErrWrapper)
dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport)
if dnsTxp.Address() != "1.1.1.1:53" {
t.Fatal("invalid address")
}
}
func TestNewParallelResolverUDP(t *testing.T) {
d := NewDialerWithoutResolver(log.Log)
resolver := NewParallelResolverUDP(log.Log, d, "1.1.1.1:53")
idna := resolver.(*resolverIDNA)
logger := idna.Resolver.(*resolverLogger)
if logger.Logger != log.Log {
t.Fatal("invalid logger")
}
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
para := errWrapper.Resolver.(*ParallelResolver)
txp := para.Transport().(*dnsTransportErrWrapper)
dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport)
if dnsTxp.Address() != "1.1.1.1:53" {
t.Fatal("invalid address")
}
}

View File

@ -14,7 +14,7 @@ import (
func TestParallelResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewUnwrappedParallelResolver(txp)
rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {

View File

@ -33,8 +33,8 @@ type SerialResolver struct {
Txp model.DNSTransport
}
// NewSerialResolver creates a new SerialResolver instance.
func NewSerialResolver(t model.DNSTransport) *SerialResolver {
// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance.
func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver {
return &SerialResolver{
NumTimeouts: &atomicx.Int64{},
Txp: t,

View File

@ -31,8 +31,8 @@ func (err *errorWithTimeout) Unwrap() error {
func TestSerialResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewSerialResolver(txp)
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewUnwrappedSerialResolver(txp)
rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
t.Fatal("not the transport we expected")
@ -56,7 +56,7 @@ func TestSerialResolver(t *testing.T) {
return true
},
}
r := NewSerialResolver(txp)
r := NewUnwrappedSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -80,7 +80,7 @@ func TestSerialResolver(t *testing.T) {
return true
},
}
r := NewSerialResolver(txp)
r := NewUnwrappedSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
@ -107,7 +107,7 @@ func TestSerialResolver(t *testing.T) {
return true
},
}
r := NewSerialResolver(txp)
r := NewUnwrappedSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil {
t.Fatal(err)
@ -134,7 +134,7 @@ func TestSerialResolver(t *testing.T) {
return true
},
}
r := NewSerialResolver(txp)
r := NewUnwrappedSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil {
t.Fatal(err)
@ -157,7 +157,7 @@ func TestSerialResolver(t *testing.T) {
return true
},
}
r := NewSerialResolver(txp)
r := NewUnwrappedSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, ETIMEDOUT) {
t.Fatal("not the error we expected")

View File

@ -54,7 +54,7 @@ func main() {
// UDP endpoint address at which the server is listening.
//
// ```Go
reso := netxlite.NewResolverUDP(log.Log, dialer, *serverAddr)
reso := netxlite.NewParallelResolverUDP(log.Log, dialer, *serverAddr)
// ```
//
// The API we invoke is the same as in the previous chapter, though,