From 8f7e3803eb52ae7a198144be43fa37ac3830a257 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 1 Jun 2022 11:10:08 +0200 Subject: [PATCH] feat(netxlite): implement DNSTransport wrapping (#776) Acknowledge that transports MAY be used in isolation (i.e., outside of a Resolver) and add support for wrapping. Ensure that every factory that creates an unwrapped type is named accordingly to hopefully ensure there are no surprises. Implement DNSTransport wrapping and use a technique similar to the one used by Dialer to customize the DNSTransport while constructing more complex data types (e.g., a specific resolver). Ensure that the stdlib resolver's own "getaddrinfo" transport (1) is wrapped and (2) could be extended during construction. This work is part of my ongoing effort to bring to this repository websteps-illustrated changes relative to netxlite. Ref issue: https://github.com/ooni/probe/issues/2096 --- internal/engine/netx/netx.go | 16 +- .../engine/netx/resolver/integration_test.go | 28 ++-- internal/measurex/resolver.go | 4 +- internal/model/netx.go | 6 + internal/netxlite/dnsoverhttps.go | 13 +- internal/netxlite/dnsoverhttps_test.go | 6 +- internal/netxlite/dnsovertcp.go | 10 +- internal/netxlite/dnsovertcp_test.go | 20 +-- internal/netxlite/dnsoverudp.go | 5 +- internal/netxlite/dnsoverudp_test.go | 26 +-- internal/netxlite/dnstransport.go | 60 +++++++ internal/netxlite/dnstransport_test.go | 156 ++++++++++++++++++ internal/netxlite/integration_test.go | 8 +- internal/netxlite/operations.go | 3 + internal/netxlite/resolvercore.go | 47 +++++- internal/netxlite/resolvercore_test.go | 30 +++- internal/netxlite/resolverparallel_test.go | 2 +- internal/netxlite/resolverserial.go | 4 +- internal/netxlite/resolverserial_test.go | 14 +- internal/tutorial/netxlite/chapter06/main.go | 2 +- 20 files changed, 369 insertions(+), 91 deletions(-) create mode 100644 internal/netxlite/dnstransport.go create mode 100644 internal/netxlite/dnstransport_test.go diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 20df40a..43b3a8f 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -264,20 +264,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, case "https": config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} httpClient := &http.Client{Transport: NewHTTPTransport(config)} - var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverHTTPSTransportWithHostOverride( httpClient, URL, hostOverride) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil case "udp": dialer := NewDialer(config) endpoint, err := makeValidEndpoint(resolverURL) if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverUDPTransport( dialer, endpoint) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil case "dot": config.TLSConfig.NextProtos = []string{"dot"} tlsDialer := NewTLSDialer(config) @@ -285,20 +285,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTLSTransport( tlsDialer.DialTLSContext, endpoint) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil case "tcp": dialer := NewDialer(config) endpoint, err := makeValidEndpoint(resolverURL) if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTCPTransport( dialer.DialContext, endpoint) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil default: return nil, errors.New("unsupported resolver scheme") } diff --git a/internal/engine/netx/resolver/integration_test.go b/internal/engine/netx/resolver/integration_test.go index 1bb9606..822ce35 100644 --- a/internal/engine/netx/resolver/integration_test.go +++ b/internal/engine/netx/resolver/integration_test.go @@ -70,50 +70,50 @@ func TestNewResolverSystem(t *testing.T) { } func TestNewResolverUDPAddress(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverUDPDomain(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverTCPAddress(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverTCPDomain(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoTAddress(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoTDomain(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoH(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query")) testresolverquick(t, reso) testresolverquickidna(t, reso) } diff --git a/internal/measurex/resolver.go b/internal/measurex/resolver.go index 944e169..9cfe3f9 100644 --- a/internal/measurex/resolver.go +++ b/internal/measurex/resolver.go @@ -43,8 +43,8 @@ func (mx *Measurer) NewResolverSystem(db WritableDB, logger model.Logger) model. // - address is the resolver address (e.g., "1.1.1.1:53"). func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver { return mx.WrapResolver(db, netxlite.WrapResolver( - logger, netxlite.NewSerialResolver( - mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDPTransport( + logger, netxlite.NewUnwrappedSerialResolver( + mx.WrapDNSXRoundTripper(db, netxlite.NewUnwrappedDNSOverUDPTransport( mx.NewDialerWithSystemResolver(db, logger), address, )))), diff --git a/internal/model/netx.go b/internal/model/netx.go index 326acdc..f33240f 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -101,6 +101,12 @@ type DNSEncoder interface { Encode(domain string, qtype uint16, padding bool) DNSQuery } +// DNSTransportWrapper is a type that takes in input a DNSTransport +// and returns in output a wrapped DNSTransport. +type DNSTransportWrapper interface { + WrapDNSTransport(txp DNSTransport) DNSTransport +} + // DNSTransport represents an abstract DNS transport. type DNSTransport interface { // RoundTrip sends a DNS query and receives the reply. diff --git a/internal/netxlite/dnsoverhttps.go b/internal/netxlite/dnsoverhttps.go index 196bc28..3a7cec9 100644 --- a/internal/netxlite/dnsoverhttps.go +++ b/internal/netxlite/dnsoverhttps.go @@ -31,20 +31,21 @@ type DNSOverHTTPSTransport struct { HostOverride string } -// NewDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport instance. +// NewUnwrappedDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport +// instance that has not been wrapped yet. // // Arguments: // // - client is a model.HTTPClient type; // // - URL is the DoH resolver URL (e.g., https://dns.google/dns-query). -func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport { - return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "") +func NewUnwrappedDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport { + return NewUnwrappedDNSOverHTTPSTransportWithHostOverride(client, URL, "") } -// NewDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport -// with the given Host header override. -func NewDNSOverHTTPSTransportWithHostOverride( +// NewUnwrappedDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport +// with the given Host header override. This instance has not been wrapped yet. +func NewUnwrappedDNSOverHTTPSTransportWithHostOverride( client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport { return &DNSOverHTTPSTransport{ Client: client, diff --git a/internal/netxlite/dnsoverhttps_test.go b/internal/netxlite/dnsoverhttps_test.go index 57480f0..c6bbf23 100644 --- a/internal/netxlite/dnsoverhttps_test.go +++ b/internal/netxlite/dnsoverhttps_test.go @@ -16,7 +16,7 @@ import ( func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { t.Run("query serialization failure", func(t *testing.T) { - txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query") + txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query") expected := errors.New("mocked error") query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { @@ -34,7 +34,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("NewRequestFailure", func(t *testing.T) { const invalidURL = "\t" - txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL) + txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, invalidURL) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return make([]byte, 17), nil @@ -293,7 +293,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("other functions behave correctly", func(t *testing.T) { const queryURL = "https://cloudflare-dns.com/dns-query" - txp := NewDNSOverHTTPSTransport(http.DefaultClient, queryURL) + txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, queryURL) if txp.Network() != "doh" { t.Fatal("invalid network") } diff --git a/internal/netxlite/dnsovertcp.go b/internal/netxlite/dnsovertcp.go index 1f174c7..b3cde88 100644 --- a/internal/netxlite/dnsovertcp.go +++ b/internal/netxlite/dnsovertcp.go @@ -31,25 +31,27 @@ type DNSOverTCPTransport struct { requiresPadding bool } -// NewDNSOverTCPTransport creates a new DNSOverTCPTransport. +// NewUnwrappedDNSOverTCPTransport creates a new DNSOverTCPTransport +// that has not been wrapped yet. // // Arguments: // // - dial is a function with the net.Dialer.DialContext's signature; // // - address is the endpoint address (e.g., 8.8.8.8:53). -func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { +func NewUnwrappedDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false) } -// NewDNSOverTLSTransport creates a new DNSOverTLS transport. +// NewUnwrappedDNSOverTLSTransport creates a new DNSOverTLS transport +// that has not been wrapped yet. // // Arguments: // // - dial is a function with the net.Dialer.DialContext's signature; // // - address is the endpoint address (e.g., 8.8.8.8:853). -func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { +func NewUnwrappedDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { return newDNSOverTCPOrTLSTransport(dial, "dot", address, true) } diff --git a/internal/netxlite/dnsovertcp_test.go b/internal/netxlite/dnsovertcp_test.go index 3be1391..83f826d 100644 --- a/internal/netxlite/dnsovertcp_test.go +++ b/internal/netxlite/dnsovertcp_test.go @@ -20,7 +20,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("cannot encode query", func(t *testing.T) { expected := errors.New("mocked error") const address = "9.9.9.9:53" - txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return nil, expected @@ -37,7 +37,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("query too large", func(t *testing.T) { const address = "9.9.9.9:53" - txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return make([]byte, math.MaxUint16+1), nil @@ -65,7 +65,7 @@ func TestDNSOverTCPTransport(t *testing.T) { return nil, mocked }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -98,7 +98,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -134,7 +134,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -176,7 +176,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -211,7 +211,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) txp.decoder = &mocks.DNSDecoder{ MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { return nil, mocked @@ -250,7 +250,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) expectedResp := &mocks.DNSResponse{} txp.decoder = &mocks.DNSDecoder{ MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { @@ -269,7 +269,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("other functions okay with TCP", func(t *testing.T) { const address = "9.9.9.9:53" - txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address) if txp.RequiresPadding() != false { t.Fatal("invalid RequiresPadding") } @@ -284,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("other functions okay with TLS", func(t *testing.T) { const address = "9.9.9.9:853" - txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address) + txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, address) if txp.RequiresPadding() != true { t.Fatal("invalid RequiresPadding") } diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index fa9d0e4..4f453c9 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -51,7 +51,8 @@ type DNSOverUDPTransport struct { IOTimeout time.Duration } -// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance. +// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance +// that has not been wrapped yet. // // Arguments: // @@ -64,7 +65,7 @@ type DNSOverUDPTransport struct { // IP addresses returned by the underlying DNS lookup performed using // the dialer. This usage pattern is NOT RECOMMENDED because we'll // have less control over which IP address is being used. -func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { +func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { return &DNSOverUDPTransport{ Decoder: &DNSDecoderMiekg{}, Dialer: dialer, diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index f79b2f7..2e14a01 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -21,7 +21,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("cannot encode query", func(t *testing.T) { expected := errors.New("mocked error") const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(nil, address) + txp := NewUnwrappedDNSOverUDPTransport(nil, address) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return nil, expected @@ -39,7 +39,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("dial failure", func(t *testing.T) { mocked := errors.New("mocked error") const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(&mocks.Dialer{ + txp := NewUnwrappedDNSOverUDPTransport(&mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, mocked }, @@ -60,7 +60,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("Write failure", func(t *testing.T) { mocked := errors.New("mocked error") - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -103,7 +103,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("Read failure", func(t *testing.T) { mocked := errors.New("mocked error") - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -150,7 +150,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("decode failure", func(t *testing.T) { const expected = 17 input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -201,7 +201,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("decode success", func(t *testing.T) { const expected = 17 input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -264,7 +264,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) resp, err := txp.RoundTrip(context.Background(), query) @@ -297,7 +297,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) ctx := context.Background() @@ -332,7 +332,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) ctx := context.Background() @@ -359,7 +359,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) @@ -413,7 +413,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) @@ -440,7 +440,7 @@ func TestDNSOverUDPTransport(t *testing.T) { }, } const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(dialer, address) + txp := NewUnwrappedDNSOverUDPTransport(dialer, address) txp.CloseIdleConnections() if !called { t.Fatal("not called") @@ -449,7 +449,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("other functions okay", func(t *testing.T) { const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address) + txp := NewUnwrappedDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address) if txp.RequiresPadding() != false { t.Fatal("invalid RequiresPadding") } diff --git a/internal/netxlite/dnstransport.go b/internal/netxlite/dnstransport.go new file mode 100644 index 0000000..8a779ef --- /dev/null +++ b/internal/netxlite/dnstransport.go @@ -0,0 +1,60 @@ +package netxlite + +// +// Generic DNS transport code. +// + +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// WrapDNSTransport wraps a DNSTransport to provide error wrapping. This function will +// apply all the provided wrappers around the default transport wrapping. If any of the +// wrappers is nil, we just silently and gracefully ignore it. +func WrapDNSTransport(txp model.DNSTransport, + wrappers ...model.DNSTransportWrapper) (out model.DNSTransport) { + out = &dnsTransportErrWrapper{ + DNSTransport: txp, + } + for _, wrapper := range wrappers { + if wrapper == nil { + continue // skip as documented + } + out = wrapper.WrapDNSTransport(out) // compose with user-provided wrappers + } + return +} + +// dnsTransportErrWrapper wraps DNSTransport to provide error wrapping. +type dnsTransportErrWrapper struct { + DNSTransport model.DNSTransport +} + +var _ model.DNSTransport = &dnsTransportErrWrapper{} + +func (t *dnsTransportErrWrapper) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + resp, err := t.DNSTransport.RoundTrip(ctx, query) + if err != nil { + return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err) + } + return resp, nil +} + +func (t *dnsTransportErrWrapper) RequiresPadding() bool { + return t.DNSTransport.RequiresPadding() +} + +func (t *dnsTransportErrWrapper) Network() string { + return t.DNSTransport.Network() +} + +func (t *dnsTransportErrWrapper) Address() string { + return t.DNSTransport.Address() +} + +func (t *dnsTransportErrWrapper) CloseIdleConnections() { + t.DNSTransport.CloseIdleConnections() +} diff --git a/internal/netxlite/dnstransport_test.go b/internal/netxlite/dnstransport_test.go new file mode 100644 index 0000000..7be833c --- /dev/null +++ b/internal/netxlite/dnstransport_test.go @@ -0,0 +1,156 @@ +package netxlite + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +type dnsTransportExtensionFirst struct { + model.DNSTransport +} + +type dnsTransportWrapperFirst struct{} + +func (*dnsTransportWrapperFirst) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { + return &dnsTransportExtensionFirst{txp} +} + +type dnsTransportExtensionSecond struct { + model.DNSTransport +} + +type dnsTransportWrapperSecond struct{} + +func (*dnsTransportWrapperSecond) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { + return &dnsTransportExtensionSecond{txp} +} + +func TestWrapDNSTransport(t *testing.T) { + orig := &mocks.DNSTransport{} + extensions := []model.DNSTransportWrapper{ + &dnsTransportWrapperFirst{}, + nil, // explicitly test for documented use case + &dnsTransportWrapperSecond{}, + } + txp := WrapDNSTransport(orig, extensions...) + ext2 := txp.(*dnsTransportExtensionSecond) + ext1 := ext2.DNSTransport.(*dnsTransportExtensionFirst) + errWrapper := ext1.DNSTransport.(*dnsTransportErrWrapper) + underlying := errWrapper.DNSTransport + if orig != underlying { + t.Fatal("unexpected underlying transport") + } +} + +func TestDNSTransportErrWrapper(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedResp := &mocks.DNSResponse{} + child := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return expectedResp, nil + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + query := &mocks.DNSQuery{} + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if err != nil { + t.Fatal(err) + } + if resp != expectedResp { + t.Fatal("unexpected resp") + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + child := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expectedErr + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + query := &mocks.DNSQuery{} + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if !errors.Is(err, expectedErr) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("error has not been wrapped") + } + }) + }) + + t.Run("RequiresPadding", func(t *testing.T) { + child := &mocks.DNSTransport{ + MockRequiresPadding: func() bool { + return true + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + if !txp.RequiresPadding() { + t.Fatal("expected true") + } + }) + + t.Run("Network", func(t *testing.T) { + child := &mocks.DNSTransport{ + MockNetwork: func() string { + return "x" + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + if txp.Network() != "x" { + t.Fatal("unexpected Network") + } + }) + + t.Run("Address", func(t *testing.T) { + child := &mocks.DNSTransport{ + MockAddress: func() string { + return "x" + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + if txp.Address() != "x" { + t.Fatal("unexpected Address") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.DNSTransport{ + MockCloseIdleConnections: func() { + called = true + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + txp.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 57052ae..da172a8 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -104,7 +104,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { t.Run("on success", func(t *testing.T) { dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, "8.8.4.4:53") + r := netxlite.NewParallelResolverUDP(log.Log, dlr, "8.8.4.4:53") defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "dns.google.com") @@ -128,7 +128,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { } defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String()) defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "ooni.org") @@ -152,7 +152,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { } defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String()) defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "ooni.org") @@ -176,7 +176,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { } defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String()) defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "ooni.org") diff --git a/internal/netxlite/operations.go b/internal/netxlite/operations.go index 773659b..2ff4e09 100644 --- a/internal/netxlite/operations.go +++ b/internal/netxlite/operations.go @@ -13,6 +13,9 @@ const ( // ConnectOperation is the operation where we do a TCP connect. ConnectOperation = "connect" + // DNSRoundTripOperation is the DNS round trip. + DNSRoundTripOperation = "dns_round_trip" + // TLSHandshakeOperation is the TLS handshake. TLSHandshakeOperation = "tls_handshake" diff --git a/internal/netxlite/resolvercore.go b/internal/netxlite/resolvercore.go index c4d0009..7640d99 100644 --- a/internal/netxlite/resolvercore.go +++ b/internal/netxlite/resolvercore.go @@ -23,18 +23,23 @@ import ( var ErrNoDNSTransport = errors.New("operation requires a DNS transport") // NewResolverStdlib creates a new Resolver by combining WrapResolver -// with an internal "system" resolver type. -func NewResolverStdlib(logger model.DebugLogger) model.Resolver { - return WrapResolver(logger, newResolverSystem()) +// with an internal "system" resolver type. The list of optional wrappers +// allow to wrap the underlying getaddrinfo transport. Any nil wrapper +// will be silently ignored by the code that performs the wrapping. +func NewResolverStdlib(logger model.DebugLogger, wrappers ...model.DNSTransportWrapper) model.Resolver { + return WrapResolver(logger, newResolverSystem(wrappers...)) } -func newResolverSystem() *resolverSystem { +func newResolverSystem(wrappers ...model.DNSTransportWrapper) *resolverSystem { return &resolverSystem{ - t: &dnsOverGetaddrinfoTransport{}, + t: WrapDNSTransport(&dnsOverGetaddrinfoTransport{}, wrappers...), } } -// NewResolverUDP creates a new Resolver using DNS-over-UDP. +// NewSerialResolverUDP creates a new Resolver using DNS-over-UDP +// that performs serial A/AAAA lookups during LookupHost. +// +// Deprecated: use NewParallelResolverUDP. // // Arguments: // @@ -43,9 +48,33 @@ func newResolverSystem() *resolverSystem { // - dialer is the dialer to create and connect UDP conns // // - address is the server address (e.g., 1.1.1.1:53) -func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { - return WrapResolver(logger, NewSerialResolver( - NewDNSOverUDPTransport(dialer, address), +// +// - wrappers is the optional list of wrappers to wrap the underlying +// transport. Any nil wrapper will be silently ignored. +func NewSerialResolverUDP(logger model.DebugLogger, dialer model.Dialer, + address string, wrappers ...model.DNSTransportWrapper) model.Resolver { + return WrapResolver(logger, NewUnwrappedSerialResolver( + WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...), + )) +} + +// NewParallelResolverUDP creates a new Resolver using DNS-over-UDP +// that performs parallel A/AAAA lookups during LookupHost. +// +// Arguments: +// +// - logger is the logger to use +// +// - dialer is the dialer to create and connect UDP conns +// +// - address is the server address (e.g., 1.1.1.1:53) +// +// - wrappers is the optional list of wrappers to wrap the underlying +// transport. Any nil wrapper will be silently ignored. +func NewParallelResolverUDP(logger model.DebugLogger, dialer model.Dialer, + address string, wrappers ...model.DNSTransportWrapper) model.Resolver { + return WrapResolver(logger, NewUnwrappedParallelResolver( + WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...), )) } diff --git a/internal/netxlite/resolvercore_test.go b/internal/netxlite/resolvercore_test.go index a4ee4db..b4b0a8a 100644 --- a/internal/netxlite/resolvercore_test.go +++ b/internal/netxlite/resolvercore_test.go @@ -25,12 +25,13 @@ func TestNewResolverSystem(t *testing.T) { shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) reso := errWrapper.Resolver.(*resolverSystem) - _ = reso.t.(*dnsOverGetaddrinfoTransport) + txpErrWrapper := reso.t.(*dnsTransportErrWrapper) + _ = txpErrWrapper.DNSTransport.(*dnsOverGetaddrinfoTransport) } -func TestNewResolverUDP(t *testing.T) { +func TestNewSerialResolverUDP(t *testing.T) { d := NewDialerWithoutResolver(log.Log) - resolver := NewResolverUDP(log.Log, d, "1.1.1.1:53") + resolver := NewSerialResolverUDP(log.Log, d, "1.1.1.1:53") idna := resolver.(*resolverIDNA) logger := idna.Resolver.(*resolverLogger) if logger.Logger != log.Log { @@ -39,8 +40,27 @@ func TestNewResolverUDP(t *testing.T) { shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) serio := errWrapper.Resolver.(*SerialResolver) - txp := serio.Transport().(*DNSOverUDPTransport) - if txp.Address() != "1.1.1.1:53" { + txp := serio.Transport().(*dnsTransportErrWrapper) + dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport) + if dnsTxp.Address() != "1.1.1.1:53" { + t.Fatal("invalid address") + } +} + +func TestNewParallelResolverUDP(t *testing.T) { + d := NewDialerWithoutResolver(log.Log) + resolver := NewParallelResolverUDP(log.Log, d, "1.1.1.1:53") + idna := resolver.(*resolverIDNA) + logger := idna.Resolver.(*resolverLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) + errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) + para := errWrapper.Resolver.(*ParallelResolver) + txp := para.Transport().(*dnsTransportErrWrapper) + dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport) + if dnsTxp.Address() != "1.1.1.1:53" { t.Fatal("invalid address") } } diff --git a/internal/netxlite/resolverparallel_test.go b/internal/netxlite/resolverparallel_test.go index f208623..73504b4 100644 --- a/internal/netxlite/resolverparallel_test.go +++ b/internal/netxlite/resolverparallel_test.go @@ -14,7 +14,7 @@ import ( func TestParallelResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { - txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") + txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") r := NewUnwrappedParallelResolver(txp) rtx := r.Transport() if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { diff --git a/internal/netxlite/resolverserial.go b/internal/netxlite/resolverserial.go index 316ca88..9ce54ec 100644 --- a/internal/netxlite/resolverserial.go +++ b/internal/netxlite/resolverserial.go @@ -33,8 +33,8 @@ type SerialResolver struct { Txp model.DNSTransport } -// NewSerialResolver creates a new SerialResolver instance. -func NewSerialResolver(t model.DNSTransport) *SerialResolver { +// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance. +func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver { return &SerialResolver{ NumTimeouts: &atomicx.Int64{}, Txp: t, diff --git a/internal/netxlite/resolverserial_test.go b/internal/netxlite/resolverserial_test.go index ed9df5f..360faae 100644 --- a/internal/netxlite/resolverserial_test.go +++ b/internal/netxlite/resolverserial_test.go @@ -31,8 +31,8 @@ func (err *errorWithTimeout) Unwrap() error { func TestSerialResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { - txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") - r := NewSerialResolver(txp) + txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") + r := NewUnwrappedSerialResolver(txp) rtx := r.Transport() if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { t.Fatal("not the transport we expected") @@ -56,7 +56,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -80,7 +80,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if !errors.Is(err, ErrOODNSNoAnswer) { t.Fatal("not the error we expected", err) @@ -107,7 +107,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if err != nil { t.Fatal(err) @@ -134,7 +134,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if err != nil { t.Fatal(err) @@ -157,7 +157,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if !errors.Is(err, ETIMEDOUT) { t.Fatal("not the error we expected") diff --git a/internal/tutorial/netxlite/chapter06/main.go b/internal/tutorial/netxlite/chapter06/main.go index 65745df..0e5a9a4 100644 --- a/internal/tutorial/netxlite/chapter06/main.go +++ b/internal/tutorial/netxlite/chapter06/main.go @@ -54,7 +54,7 @@ func main() { // UDP endpoint address at which the server is listening. // // ```Go - reso := netxlite.NewResolverUDP(log.Log, dialer, *serverAddr) + reso := netxlite.NewParallelResolverUDP(log.Log, dialer, *serverAddr) // ``` // // The API we invoke is the same as in the previous chapter, though,