diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 20df40a..43b3a8f 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -264,20 +264,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, case "https": config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} httpClient := &http.Client{Transport: NewHTTPTransport(config)} - var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverHTTPSTransportWithHostOverride( httpClient, URL, hostOverride) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil case "udp": dialer := NewDialer(config) endpoint, err := makeValidEndpoint(resolverURL) if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverUDPTransport( dialer, endpoint) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil case "dot": config.TLSConfig.NextProtos = []string{"dot"} tlsDialer := NewTLSDialer(config) @@ -285,20 +285,20 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTLSTransport( tlsDialer.DialTLSContext, endpoint) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil case "tcp": dialer := NewDialer(config) endpoint, err := makeValidEndpoint(resolverURL) if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport( + var txp model.DNSTransport = netxlite.NewUnwrappedDNSOverTCPTransport( dialer.DialContext, endpoint) txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil - return netxlite.NewSerialResolver(txp), nil + return netxlite.NewUnwrappedSerialResolver(txp), nil default: return nil, errors.New("unsupported resolver scheme") } diff --git a/internal/engine/netx/resolver/integration_test.go b/internal/engine/netx/resolver/integration_test.go index 1bb9606..822ce35 100644 --- a/internal/engine/netx/resolver/integration_test.go +++ b/internal/engine/netx/resolver/integration_test.go @@ -70,50 +70,50 @@ func TestNewResolverSystem(t *testing.T) { } func TestNewResolverUDPAddress(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverUDPDomain(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverTCPAddress(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverTCPDomain(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoTAddress(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoTDomain(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoH(t *testing.T) { - reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query")) + reso := netxlite.NewUnwrappedSerialResolver( + netxlite.NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query")) testresolverquick(t, reso) testresolverquickidna(t, reso) } diff --git a/internal/measurex/resolver.go b/internal/measurex/resolver.go index 944e169..9cfe3f9 100644 --- a/internal/measurex/resolver.go +++ b/internal/measurex/resolver.go @@ -43,8 +43,8 @@ func (mx *Measurer) NewResolverSystem(db WritableDB, logger model.Logger) model. // - address is the resolver address (e.g., "1.1.1.1:53"). func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver { return mx.WrapResolver(db, netxlite.WrapResolver( - logger, netxlite.NewSerialResolver( - mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDPTransport( + logger, netxlite.NewUnwrappedSerialResolver( + mx.WrapDNSXRoundTripper(db, netxlite.NewUnwrappedDNSOverUDPTransport( mx.NewDialerWithSystemResolver(db, logger), address, )))), diff --git a/internal/model/netx.go b/internal/model/netx.go index 326acdc..f33240f 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -101,6 +101,12 @@ type DNSEncoder interface { Encode(domain string, qtype uint16, padding bool) DNSQuery } +// DNSTransportWrapper is a type that takes in input a DNSTransport +// and returns in output a wrapped DNSTransport. +type DNSTransportWrapper interface { + WrapDNSTransport(txp DNSTransport) DNSTransport +} + // DNSTransport represents an abstract DNS transport. type DNSTransport interface { // RoundTrip sends a DNS query and receives the reply. diff --git a/internal/netxlite/dnsoverhttps.go b/internal/netxlite/dnsoverhttps.go index 196bc28..3a7cec9 100644 --- a/internal/netxlite/dnsoverhttps.go +++ b/internal/netxlite/dnsoverhttps.go @@ -31,20 +31,21 @@ type DNSOverHTTPSTransport struct { HostOverride string } -// NewDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport instance. +// NewUnwrappedDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport +// instance that has not been wrapped yet. // // Arguments: // // - client is a model.HTTPClient type; // // - URL is the DoH resolver URL (e.g., https://dns.google/dns-query). -func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport { - return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "") +func NewUnwrappedDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport { + return NewUnwrappedDNSOverHTTPSTransportWithHostOverride(client, URL, "") } -// NewDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport -// with the given Host header override. -func NewDNSOverHTTPSTransportWithHostOverride( +// NewUnwrappedDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport +// with the given Host header override. This instance has not been wrapped yet. +func NewUnwrappedDNSOverHTTPSTransportWithHostOverride( client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport { return &DNSOverHTTPSTransport{ Client: client, diff --git a/internal/netxlite/dnsoverhttps_test.go b/internal/netxlite/dnsoverhttps_test.go index 57480f0..c6bbf23 100644 --- a/internal/netxlite/dnsoverhttps_test.go +++ b/internal/netxlite/dnsoverhttps_test.go @@ -16,7 +16,7 @@ import ( func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { t.Run("query serialization failure", func(t *testing.T) { - txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query") + txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query") expected := errors.New("mocked error") query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { @@ -34,7 +34,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("NewRequestFailure", func(t *testing.T) { const invalidURL = "\t" - txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL) + txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, invalidURL) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return make([]byte, 17), nil @@ -293,7 +293,7 @@ func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("other functions behave correctly", func(t *testing.T) { const queryURL = "https://cloudflare-dns.com/dns-query" - txp := NewDNSOverHTTPSTransport(http.DefaultClient, queryURL) + txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, queryURL) if txp.Network() != "doh" { t.Fatal("invalid network") } diff --git a/internal/netxlite/dnsovertcp.go b/internal/netxlite/dnsovertcp.go index 1f174c7..b3cde88 100644 --- a/internal/netxlite/dnsovertcp.go +++ b/internal/netxlite/dnsovertcp.go @@ -31,25 +31,27 @@ type DNSOverTCPTransport struct { requiresPadding bool } -// NewDNSOverTCPTransport creates a new DNSOverTCPTransport. +// NewUnwrappedDNSOverTCPTransport creates a new DNSOverTCPTransport +// that has not been wrapped yet. // // Arguments: // // - dial is a function with the net.Dialer.DialContext's signature; // // - address is the endpoint address (e.g., 8.8.8.8:53). -func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { +func NewUnwrappedDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false) } -// NewDNSOverTLSTransport creates a new DNSOverTLS transport. +// NewUnwrappedDNSOverTLSTransport creates a new DNSOverTLS transport +// that has not been wrapped yet. // // Arguments: // // - dial is a function with the net.Dialer.DialContext's signature; // // - address is the endpoint address (e.g., 8.8.8.8:853). -func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { +func NewUnwrappedDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { return newDNSOverTCPOrTLSTransport(dial, "dot", address, true) } diff --git a/internal/netxlite/dnsovertcp_test.go b/internal/netxlite/dnsovertcp_test.go index 3be1391..83f826d 100644 --- a/internal/netxlite/dnsovertcp_test.go +++ b/internal/netxlite/dnsovertcp_test.go @@ -20,7 +20,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("cannot encode query", func(t *testing.T) { expected := errors.New("mocked error") const address = "9.9.9.9:53" - txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return nil, expected @@ -37,7 +37,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("query too large", func(t *testing.T) { const address = "9.9.9.9:53" - txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return make([]byte, math.MaxUint16+1), nil @@ -65,7 +65,7 @@ func TestDNSOverTCPTransport(t *testing.T) { return nil, mocked }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -98,7 +98,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -134,7 +134,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -176,7 +176,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -211,7 +211,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) txp.decoder = &mocks.DNSDecoder{ MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { return nil, mocked @@ -250,7 +250,7 @@ func TestDNSOverTCPTransport(t *testing.T) { }, nil }, } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address) expectedResp := &mocks.DNSResponse{} txp.decoder = &mocks.DNSDecoder{ MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { @@ -269,7 +269,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("other functions okay with TCP", func(t *testing.T) { const address = "9.9.9.9:53" - txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address) if txp.RequiresPadding() != false { t.Fatal("invalid RequiresPadding") } @@ -284,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("other functions okay with TLS", func(t *testing.T) { const address = "9.9.9.9:853" - txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address) + txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, address) if txp.RequiresPadding() != true { t.Fatal("invalid RequiresPadding") } diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index fa9d0e4..4f453c9 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -51,7 +51,8 @@ type DNSOverUDPTransport struct { IOTimeout time.Duration } -// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance. +// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance +// that has not been wrapped yet. // // Arguments: // @@ -64,7 +65,7 @@ type DNSOverUDPTransport struct { // IP addresses returned by the underlying DNS lookup performed using // the dialer. This usage pattern is NOT RECOMMENDED because we'll // have less control over which IP address is being used. -func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { +func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { return &DNSOverUDPTransport{ Decoder: &DNSDecoderMiekg{}, Dialer: dialer, diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index f79b2f7..2e14a01 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -21,7 +21,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("cannot encode query", func(t *testing.T) { expected := errors.New("mocked error") const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(nil, address) + txp := NewUnwrappedDNSOverUDPTransport(nil, address) query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { return nil, expected @@ -39,7 +39,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("dial failure", func(t *testing.T) { mocked := errors.New("mocked error") const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(&mocks.Dialer{ + txp := NewUnwrappedDNSOverUDPTransport(&mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, mocked }, @@ -60,7 +60,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("Write failure", func(t *testing.T) { mocked := errors.New("mocked error") - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -103,7 +103,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("Read failure", func(t *testing.T) { mocked := errors.New("mocked error") - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -150,7 +150,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("decode failure", func(t *testing.T) { const expected = 17 input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -201,7 +201,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("decode success", func(t *testing.T) { const expected = 17 input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDPTransport( + txp := NewUnwrappedDNSOverUDPTransport( &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -264,7 +264,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) resp, err := txp.RoundTrip(context.Background(), query) @@ -297,7 +297,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) ctx := context.Background() @@ -332,7 +332,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) ctx := context.Background() @@ -359,7 +359,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) @@ -413,7 +413,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) @@ -440,7 +440,7 @@ func TestDNSOverUDPTransport(t *testing.T) { }, } const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(dialer, address) + txp := NewUnwrappedDNSOverUDPTransport(dialer, address) txp.CloseIdleConnections() if !called { t.Fatal("not called") @@ -449,7 +449,7 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("other functions okay", func(t *testing.T) { const address = "9.9.9.9:53" - txp := NewDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address) + txp := NewUnwrappedDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address) if txp.RequiresPadding() != false { t.Fatal("invalid RequiresPadding") } diff --git a/internal/netxlite/dnstransport.go b/internal/netxlite/dnstransport.go new file mode 100644 index 0000000..8a779ef --- /dev/null +++ b/internal/netxlite/dnstransport.go @@ -0,0 +1,60 @@ +package netxlite + +// +// Generic DNS transport code. +// + +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// WrapDNSTransport wraps a DNSTransport to provide error wrapping. This function will +// apply all the provided wrappers around the default transport wrapping. If any of the +// wrappers is nil, we just silently and gracefully ignore it. +func WrapDNSTransport(txp model.DNSTransport, + wrappers ...model.DNSTransportWrapper) (out model.DNSTransport) { + out = &dnsTransportErrWrapper{ + DNSTransport: txp, + } + for _, wrapper := range wrappers { + if wrapper == nil { + continue // skip as documented + } + out = wrapper.WrapDNSTransport(out) // compose with user-provided wrappers + } + return +} + +// dnsTransportErrWrapper wraps DNSTransport to provide error wrapping. +type dnsTransportErrWrapper struct { + DNSTransport model.DNSTransport +} + +var _ model.DNSTransport = &dnsTransportErrWrapper{} + +func (t *dnsTransportErrWrapper) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + resp, err := t.DNSTransport.RoundTrip(ctx, query) + if err != nil { + return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err) + } + return resp, nil +} + +func (t *dnsTransportErrWrapper) RequiresPadding() bool { + return t.DNSTransport.RequiresPadding() +} + +func (t *dnsTransportErrWrapper) Network() string { + return t.DNSTransport.Network() +} + +func (t *dnsTransportErrWrapper) Address() string { + return t.DNSTransport.Address() +} + +func (t *dnsTransportErrWrapper) CloseIdleConnections() { + t.DNSTransport.CloseIdleConnections() +} diff --git a/internal/netxlite/dnstransport_test.go b/internal/netxlite/dnstransport_test.go new file mode 100644 index 0000000..7be833c --- /dev/null +++ b/internal/netxlite/dnstransport_test.go @@ -0,0 +1,156 @@ +package netxlite + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +type dnsTransportExtensionFirst struct { + model.DNSTransport +} + +type dnsTransportWrapperFirst struct{} + +func (*dnsTransportWrapperFirst) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { + return &dnsTransportExtensionFirst{txp} +} + +type dnsTransportExtensionSecond struct { + model.DNSTransport +} + +type dnsTransportWrapperSecond struct{} + +func (*dnsTransportWrapperSecond) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { + return &dnsTransportExtensionSecond{txp} +} + +func TestWrapDNSTransport(t *testing.T) { + orig := &mocks.DNSTransport{} + extensions := []model.DNSTransportWrapper{ + &dnsTransportWrapperFirst{}, + nil, // explicitly test for documented use case + &dnsTransportWrapperSecond{}, + } + txp := WrapDNSTransport(orig, extensions...) + ext2 := txp.(*dnsTransportExtensionSecond) + ext1 := ext2.DNSTransport.(*dnsTransportExtensionFirst) + errWrapper := ext1.DNSTransport.(*dnsTransportErrWrapper) + underlying := errWrapper.DNSTransport + if orig != underlying { + t.Fatal("unexpected underlying transport") + } +} + +func TestDNSTransportErrWrapper(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expectedResp := &mocks.DNSResponse{} + child := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return expectedResp, nil + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + query := &mocks.DNSQuery{} + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if err != nil { + t.Fatal(err) + } + if resp != expectedResp { + t.Fatal("unexpected resp") + } + }) + + t.Run("on failure", func(t *testing.T) { + expectedErr := io.EOF + child := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expectedErr + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + query := &mocks.DNSQuery{} + ctx := context.Background() + resp, err := txp.RoundTrip(ctx, query) + if !errors.Is(err, expectedErr) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("error has not been wrapped") + } + }) + }) + + t.Run("RequiresPadding", func(t *testing.T) { + child := &mocks.DNSTransport{ + MockRequiresPadding: func() bool { + return true + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + if !txp.RequiresPadding() { + t.Fatal("expected true") + } + }) + + t.Run("Network", func(t *testing.T) { + child := &mocks.DNSTransport{ + MockNetwork: func() string { + return "x" + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + if txp.Network() != "x" { + t.Fatal("unexpected Network") + } + }) + + t.Run("Address", func(t *testing.T) { + child := &mocks.DNSTransport{ + MockAddress: func() string { + return "x" + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + if txp.Address() != "x" { + t.Fatal("unexpected Address") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.DNSTransport{ + MockCloseIdleConnections: func() { + called = true + }, + } + txp := &dnsTransportErrWrapper{ + DNSTransport: child, + } + txp.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 57052ae..da172a8 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -104,7 +104,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { t.Run("on success", func(t *testing.T) { dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, "8.8.4.4:53") + r := netxlite.NewParallelResolverUDP(log.Log, dlr, "8.8.4.4:53") defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "dns.google.com") @@ -128,7 +128,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { } defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String()) defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "ooni.org") @@ -152,7 +152,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { } defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String()) defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "ooni.org") @@ -176,7 +176,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { } defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + r := netxlite.NewParallelResolverUDP(log.Log, dlr, listener.LocalAddr().String()) defer r.CloseIdleConnections() ctx := context.Background() addrs, err := r.LookupHost(ctx, "ooni.org") diff --git a/internal/netxlite/operations.go b/internal/netxlite/operations.go index 773659b..2ff4e09 100644 --- a/internal/netxlite/operations.go +++ b/internal/netxlite/operations.go @@ -13,6 +13,9 @@ const ( // ConnectOperation is the operation where we do a TCP connect. ConnectOperation = "connect" + // DNSRoundTripOperation is the DNS round trip. + DNSRoundTripOperation = "dns_round_trip" + // TLSHandshakeOperation is the TLS handshake. TLSHandshakeOperation = "tls_handshake" diff --git a/internal/netxlite/resolvercore.go b/internal/netxlite/resolvercore.go index c4d0009..7640d99 100644 --- a/internal/netxlite/resolvercore.go +++ b/internal/netxlite/resolvercore.go @@ -23,18 +23,23 @@ import ( var ErrNoDNSTransport = errors.New("operation requires a DNS transport") // NewResolverStdlib creates a new Resolver by combining WrapResolver -// with an internal "system" resolver type. -func NewResolverStdlib(logger model.DebugLogger) model.Resolver { - return WrapResolver(logger, newResolverSystem()) +// with an internal "system" resolver type. The list of optional wrappers +// allow to wrap the underlying getaddrinfo transport. Any nil wrapper +// will be silently ignored by the code that performs the wrapping. +func NewResolverStdlib(logger model.DebugLogger, wrappers ...model.DNSTransportWrapper) model.Resolver { + return WrapResolver(logger, newResolverSystem(wrappers...)) } -func newResolverSystem() *resolverSystem { +func newResolverSystem(wrappers ...model.DNSTransportWrapper) *resolverSystem { return &resolverSystem{ - t: &dnsOverGetaddrinfoTransport{}, + t: WrapDNSTransport(&dnsOverGetaddrinfoTransport{}, wrappers...), } } -// NewResolverUDP creates a new Resolver using DNS-over-UDP. +// NewSerialResolverUDP creates a new Resolver using DNS-over-UDP +// that performs serial A/AAAA lookups during LookupHost. +// +// Deprecated: use NewParallelResolverUDP. // // Arguments: // @@ -43,9 +48,33 @@ func newResolverSystem() *resolverSystem { // - dialer is the dialer to create and connect UDP conns // // - address is the server address (e.g., 1.1.1.1:53) -func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { - return WrapResolver(logger, NewSerialResolver( - NewDNSOverUDPTransport(dialer, address), +// +// - wrappers is the optional list of wrappers to wrap the underlying +// transport. Any nil wrapper will be silently ignored. +func NewSerialResolverUDP(logger model.DebugLogger, dialer model.Dialer, + address string, wrappers ...model.DNSTransportWrapper) model.Resolver { + return WrapResolver(logger, NewUnwrappedSerialResolver( + WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...), + )) +} + +// NewParallelResolverUDP creates a new Resolver using DNS-over-UDP +// that performs parallel A/AAAA lookups during LookupHost. +// +// Arguments: +// +// - logger is the logger to use +// +// - dialer is the dialer to create and connect UDP conns +// +// - address is the server address (e.g., 1.1.1.1:53) +// +// - wrappers is the optional list of wrappers to wrap the underlying +// transport. Any nil wrapper will be silently ignored. +func NewParallelResolverUDP(logger model.DebugLogger, dialer model.Dialer, + address string, wrappers ...model.DNSTransportWrapper) model.Resolver { + return WrapResolver(logger, NewUnwrappedParallelResolver( + WrapDNSTransport(NewUnwrappedDNSOverUDPTransport(dialer, address), wrappers...), )) } diff --git a/internal/netxlite/resolvercore_test.go b/internal/netxlite/resolvercore_test.go index a4ee4db..b4b0a8a 100644 --- a/internal/netxlite/resolvercore_test.go +++ b/internal/netxlite/resolvercore_test.go @@ -25,12 +25,13 @@ func TestNewResolverSystem(t *testing.T) { shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) reso := errWrapper.Resolver.(*resolverSystem) - _ = reso.t.(*dnsOverGetaddrinfoTransport) + txpErrWrapper := reso.t.(*dnsTransportErrWrapper) + _ = txpErrWrapper.DNSTransport.(*dnsOverGetaddrinfoTransport) } -func TestNewResolverUDP(t *testing.T) { +func TestNewSerialResolverUDP(t *testing.T) { d := NewDialerWithoutResolver(log.Log) - resolver := NewResolverUDP(log.Log, d, "1.1.1.1:53") + resolver := NewSerialResolverUDP(log.Log, d, "1.1.1.1:53") idna := resolver.(*resolverIDNA) logger := idna.Resolver.(*resolverLogger) if logger.Logger != log.Log { @@ -39,8 +40,27 @@ func TestNewResolverUDP(t *testing.T) { shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) serio := errWrapper.Resolver.(*SerialResolver) - txp := serio.Transport().(*DNSOverUDPTransport) - if txp.Address() != "1.1.1.1:53" { + txp := serio.Transport().(*dnsTransportErrWrapper) + dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport) + if dnsTxp.Address() != "1.1.1.1:53" { + t.Fatal("invalid address") + } +} + +func TestNewParallelResolverUDP(t *testing.T) { + d := NewDialerWithoutResolver(log.Log) + resolver := NewParallelResolverUDP(log.Log, d, "1.1.1.1:53") + idna := resolver.(*resolverIDNA) + logger := idna.Resolver.(*resolverLogger) + if logger.Logger != log.Log { + t.Fatal("invalid logger") + } + shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) + errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) + para := errWrapper.Resolver.(*ParallelResolver) + txp := para.Transport().(*dnsTransportErrWrapper) + dnsTxp := txp.DNSTransport.(*DNSOverUDPTransport) + if dnsTxp.Address() != "1.1.1.1:53" { t.Fatal("invalid address") } } diff --git a/internal/netxlite/resolverparallel_test.go b/internal/netxlite/resolverparallel_test.go index f208623..73504b4 100644 --- a/internal/netxlite/resolverparallel_test.go +++ b/internal/netxlite/resolverparallel_test.go @@ -14,7 +14,7 @@ import ( func TestParallelResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { - txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") + txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") r := NewUnwrappedParallelResolver(txp) rtx := r.Transport() if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { diff --git a/internal/netxlite/resolverserial.go b/internal/netxlite/resolverserial.go index 316ca88..9ce54ec 100644 --- a/internal/netxlite/resolverserial.go +++ b/internal/netxlite/resolverserial.go @@ -33,8 +33,8 @@ type SerialResolver struct { Txp model.DNSTransport } -// NewSerialResolver creates a new SerialResolver instance. -func NewSerialResolver(t model.DNSTransport) *SerialResolver { +// NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance. +func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver { return &SerialResolver{ NumTimeouts: &atomicx.Int64{}, Txp: t, diff --git a/internal/netxlite/resolverserial_test.go b/internal/netxlite/resolverserial_test.go index ed9df5f..360faae 100644 --- a/internal/netxlite/resolverserial_test.go +++ b/internal/netxlite/resolverserial_test.go @@ -31,8 +31,8 @@ func (err *errorWithTimeout) Unwrap() error { func TestSerialResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { - txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") - r := NewSerialResolver(txp) + txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") + r := NewUnwrappedSerialResolver(txp) rtx := r.Transport() if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { t.Fatal("not the transport we expected") @@ -56,7 +56,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -80,7 +80,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if !errors.Is(err, ErrOODNSNoAnswer) { t.Fatal("not the error we expected", err) @@ -107,7 +107,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if err != nil { t.Fatal(err) @@ -134,7 +134,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if err != nil { t.Fatal(err) @@ -157,7 +157,7 @@ func TestSerialResolver(t *testing.T) { return true }, } - r := NewSerialResolver(txp) + r := NewUnwrappedSerialResolver(txp) addrs, err := r.LookupHost(context.Background(), "www.gogle.com") if !errors.Is(err, ETIMEDOUT) { t.Fatal("not the error we expected") diff --git a/internal/tutorial/netxlite/chapter06/main.go b/internal/tutorial/netxlite/chapter06/main.go index 65745df..0e5a9a4 100644 --- a/internal/tutorial/netxlite/chapter06/main.go +++ b/internal/tutorial/netxlite/chapter06/main.go @@ -54,7 +54,7 @@ func main() { // UDP endpoint address at which the server is listening. // // ```Go - reso := netxlite.NewResolverUDP(log.Log, dialer, *serverAddr) + reso := netxlite.NewParallelResolverUDP(log.Log, dialer, *serverAddr) // ``` // // The API we invoke is the same as in the previous chapter, though,