refactor(netxlite): add Transport suffix to DNS transports (#731)

This diff has been extracted from c2f7ccab0e

See https://github.com/ooni/probe/issues/2096
This commit is contained in:
Simone Basso 2022-05-14 17:38:31 +02:00 committed by GitHub
parent 6c388d2c61
commit f5b801ae95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 94 additions and 94 deletions

View File

@ -124,7 +124,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPS) dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -200,7 +200,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPS) dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -276,7 +276,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPS) dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -352,7 +352,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
udptxp, ok := stxp.DNSTransport.(*netxlite.DNSOverUDP) udptxp, ok := stxp.DNSTransport.(*netxlite.DNSOverUDPTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }

View File

@ -368,5 +368,5 @@ const thResolverURL = "https://dns.google/dns-query"
// Here we're using github.com/apex/log as the logger, which // Here we're using github.com/apex/log as the logger, which
// is fine because this is backend only code. // is fine because this is backend only code.
var thResolver = netxlite.WrapResolver(log.Log, netxlite.NewSerialResolver( var thResolver = netxlite.WrapResolver(log.Log, netxlite.NewSerialResolver(
netxlite.NewDNSOverHTTPS(http.DefaultClient, thResolverURL), netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, thResolverURL),
)) ))

View File

@ -286,7 +286,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
case "https": case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"} config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
httpClient := &http.Client{Transport: NewHTTPTransport(config)} httpClient := &http.Client{Transport: NewHTTPTransport(config)}
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSWithHostOverride( var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride(
httpClient, URL, hostOverride) httpClient, URL, hostOverride)
if config.ResolveSaver != nil { if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{ txp = resolver.SaverDNSTransport{
@ -301,7 +301,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
if err != nil { if err != nil {
return nil, err return nil, err
} }
var txp model.DNSTransport = netxlite.NewDNSOverUDP( var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport(
dialer, endpoint) dialer, endpoint)
if config.ResolveSaver != nil { if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{ txp = resolver.SaverDNSTransport{
@ -332,7 +332,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
if err != nil { if err != nil {
return nil, err return nil, err
} }
var txp model.DNSTransport = netxlite.NewDNSOverTCP( var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport(
dialer.DialContext, endpoint) dialer.DialContext, endpoint)
if config.ResolveSaver != nil { if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{ txp = resolver.SaverDNSTransport{

View File

@ -586,7 +586,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
if _, ok := r.Transport().(*netxlite.DNSOverHTTPS); !ok { if _, ok := r.Transport().(*netxlite.DNSOverHTTPSTransport); !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -602,7 +602,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
if _, ok := r.Transport().(*netxlite.DNSOverHTTPS); !ok { if _, ok := r.Transport().(*netxlite.DNSOverHTTPSTransport); !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -618,7 +618,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
if _, ok := r.Transport().(*netxlite.DNSOverHTTPS); !ok { if _, ok := r.Transport().(*netxlite.DNSOverHTTPSTransport); !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -639,7 +639,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
if _, ok := txp.DNSTransport.(*netxlite.DNSOverHTTPS); !ok { if _, ok := txp.DNSTransport.(*netxlite.DNSOverHTTPSTransport); !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -655,7 +655,7 @@ func TestNewDNSClientUDP(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
if _, ok := r.Transport().(*netxlite.DNSOverUDP); !ok { if _, ok := r.Transport().(*netxlite.DNSOverUDPTransport); !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -676,7 +676,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
if _, ok := txp.DNSTransport.(*netxlite.DNSOverUDP); !ok { if _, ok := txp.DNSTransport.(*netxlite.DNSOverUDPTransport); !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dnsclient.CloseIdleConnections() dnsclient.CloseIdleConnections()
@ -692,7 +692,7 @@ func TestNewDNSClientTCP(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(*netxlite.DNSOverTCP) txp, ok := r.Transport().(*netxlite.DNSOverTCPTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -717,7 +717,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dotcp, ok := txp.DNSTransport.(*netxlite.DNSOverTCP) dotcp, ok := txp.DNSTransport.(*netxlite.DNSOverTCPTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -737,7 +737,7 @@ func TestNewDNSClientDoT(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(*netxlite.DNSOverTCP) txp, ok := r.Transport().(*netxlite.DNSOverTCPTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -762,7 +762,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
dotls, ok := txp.DNSTransport.(*netxlite.DNSOverTCP) dotls, ok := txp.DNSTransport.(*netxlite.DNSOverTCPTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }

View File

@ -71,28 +71,28 @@ func TestNewResolverSystem(t *testing.T) {
func TestNewResolverUDPAddress(t *testing.T) { func TestNewResolverUDPAddress(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverUDP(netxlite.DefaultDialer, "8.8.8.8:53")) netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }
func TestNewResolverUDPDomain(t *testing.T) { func TestNewResolverUDPDomain(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverUDP(netxlite.DefaultDialer, "dns.google.com:53")) netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }
func TestNewResolverTCPAddress(t *testing.T) { func TestNewResolverTCPAddress(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTCP(new(net.Dialer).DialContext, "8.8.8.8:53")) netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }
func TestNewResolverTCPDomain(t *testing.T) { func TestNewResolverTCPDomain(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTCP(new(net.Dialer).DialContext, "dns.google.com:53")) netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }
@ -113,7 +113,7 @@ func TestNewResolverDoTDomain(t *testing.T) {
func TestNewResolverDoH(t *testing.T) { func TestNewResolverDoH(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverHTTPS(http.DefaultClient, "https://cloudflare-dns.com/dns-query")) netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }

View File

@ -44,7 +44,7 @@ func (mx *Measurer) NewResolverSystem(db WritableDB, logger model.Logger) model.
func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver { func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver {
return mx.WrapResolver(db, netxlite.WrapResolver( return mx.WrapResolver(db, netxlite.WrapResolver(
logger, netxlite.NewSerialResolver( logger, netxlite.NewSerialResolver(
mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDP( mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDPTransport(
mx.NewDialerWithSystemResolver(db, logger), mx.NewDialerWithSystemResolver(db, logger),
address, address,
)))), )))),

View File

@ -244,7 +244,7 @@ const (
) )
// These errors are returned by custom DNSTransport instances (e.g., // These errors are returned by custom DNSTransport instances (e.g.,
// DNSOverHTTPS and DNSOverUDP). Their suffix matches the equivalent // DNSOverHTTPSTransport and DNSOverUDPTransport). Their suffix matches the equivalent
// unexported errors used by the Go standard library. // unexported errors used by the Go standard library.
var ( var (
ErrOODNSNoSuchHost = fmt.Errorf("ooniresolver: %s", DNSNoSuchHostSuffix) ErrOODNSNoSuchHost = fmt.Errorf("ooniresolver: %s", DNSNoSuchHostSuffix)

View File

@ -11,8 +11,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
// DNSOverHTTPS is a DNS-over-HTTPS DNSTransport. // DNSOverHTTPSTransport is a DNS-over-HTTPS DNSTransport.
type DNSOverHTTPS struct { type DNSOverHTTPSTransport struct {
// Client is the MANDATORY http client to use. // Client is the MANDATORY http client to use.
Client model.HTTPClient Client model.HTTPClient
@ -24,26 +24,26 @@ type DNSOverHTTPS struct {
HostOverride string HostOverride string
} }
// NewDNSOverHTTPS creates a new DNSOverHTTPS instance. // NewDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport instance.
// //
// Arguments: // Arguments:
// //
// - client in http.Client-like type (e.g., http.DefaultClient); // - client in http.Client-like type (e.g., http.DefaultClient);
// //
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query). // - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query).
func NewDNSOverHTTPS(client model.HTTPClient, URL string) *DNSOverHTTPS { func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
return NewDNSOverHTTPSWithHostOverride(client, URL, "") return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
} }
// NewDNSOverHTTPSWithHostOverride creates a new DNSOverHTTPS // NewDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport
// with the given Host header override. // with the given Host header override.
func NewDNSOverHTTPSWithHostOverride( func NewDNSOverHTTPSTransportWithHostOverride(
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPS { client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride} return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride}
} }
// RoundTrip sends a query and receives a reply. // RoundTrip sends a query and receives a reply.
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 45*time.Second) ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel() defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query)) req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
@ -71,23 +71,23 @@ func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, err
} }
// RequiresPadding returns true for DoH according to RFC8467. // RequiresPadding returns true for DoH according to RFC8467.
func (t *DNSOverHTTPS) RequiresPadding() bool { func (t *DNSOverHTTPSTransport) RequiresPadding() bool {
return true return true
} }
// Network returns the transport network, i.e., "doh". // Network returns the transport network, i.e., "doh".
func (t *DNSOverHTTPS) Network() string { func (t *DNSOverHTTPSTransport) Network() string {
return "doh" return "doh"
} }
// Address returns the URL we're using for the DoH server. // Address returns the URL we're using for the DoH server.
func (t *DNSOverHTTPS) Address() string { func (t *DNSOverHTTPSTransport) Address() string {
return t.URL return t.URL
} }
// CloseIdleConnections closes idle connections, if any. // CloseIdleConnections closes idle connections, if any.
func (t *DNSOverHTTPS) CloseIdleConnections() { func (t *DNSOverHTTPSTransport) CloseIdleConnections() {
t.Client.CloseIdleConnections() t.Client.CloseIdleConnections()
} }
var _ model.DNSTransport = &DNSOverHTTPS{} var _ model.DNSTransport = &DNSOverHTTPSTransport{}

View File

@ -13,11 +13,11 @@ import (
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestDNSOverHTTPS(t *testing.T) { func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("NewRequestFailure", func(t *testing.T) { t.Run("NewRequestFailure", func(t *testing.T) {
const invalidURL = "\t" const invalidURL = "\t"
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL) txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil) data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("expected an error here") t.Fatal("expected an error here")
@ -29,7 +29,7 @@ func TestDNSOverHTTPS(t *testing.T) {
t.Run("client.Do failure", func(t *testing.T) { t.Run("client.Do failure", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
txp := &DNSOverHTTPS{ txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) { MockDo: func(*http.Request) (*http.Response, error) {
return nil, expected return nil, expected
@ -47,7 +47,7 @@ func TestDNSOverHTTPS(t *testing.T) {
}) })
t.Run("server returns 500", func(t *testing.T) { t.Run("server returns 500", func(t *testing.T) {
txp := &DNSOverHTTPS{ txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) { MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{ return &http.Response{
@ -68,7 +68,7 @@ func TestDNSOverHTTPS(t *testing.T) {
}) })
t.Run("missing content type", func(t *testing.T) { t.Run("missing content type", func(t *testing.T) {
txp := &DNSOverHTTPS{ txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) { MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{ return &http.Response{
@ -90,7 +90,7 @@ func TestDNSOverHTTPS(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
body := []byte("AAA") body := []byte("AAA")
txp := &DNSOverHTTPS{ txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) { MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{ return &http.Response{
@ -116,7 +116,7 @@ func TestDNSOverHTTPS(t *testing.T) {
t.Run("sets the correct user-agent", func(t *testing.T) { t.Run("sets the correct user-agent", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
var correct bool var correct bool
txp := &DNSOverHTTPS{ txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) { MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == httpheader.UserAgent() correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
@ -141,7 +141,7 @@ func TestDNSOverHTTPS(t *testing.T) {
var correct bool var correct bool
expected := errors.New("mocked error") expected := errors.New("mocked error")
hostOverride := "test.com" hostOverride := "test.com"
txp := &DNSOverHTTPS{ txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) { MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride correct = req.Host == hostOverride
@ -167,7 +167,7 @@ func TestDNSOverHTTPS(t *testing.T) {
t.Run("other functions behave correctly", func(t *testing.T) { t.Run("other functions behave correctly", func(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query" const queryURL = "https://cloudflare-dns.com/dns-query"
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL) txp := NewDNSOverHTTPSTransport(http.DefaultClient, queryURL)
if txp.Network() != "doh" { if txp.Network() != "doh" {
t.Fatal("invalid network") t.Fatal("invalid network")
} }
@ -181,7 +181,7 @@ func TestDNSOverHTTPS(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) { t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool var called bool
doh := &DNSOverHTTPS{ doh := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{ Client: &mocks.HTTPClient{
MockCloseIdleConnections: func() { MockCloseIdleConnections: func() {
called = true called = true

View File

@ -14,25 +14,25 @@ import (
// DialContextFunc is the type of net.Dialer.DialContext. // DialContextFunc is the type of net.Dialer.DialContext.
type DialContextFunc func(context.Context, string, string) (net.Conn, error) type DialContextFunc func(context.Context, string, string) (net.Conn, error)
// DNSOverTCP is a DNS-over-{TCP,TLS} DNSTransport. // DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
// //
// Bug: this implementation always creates a new connection for each query. // Bug: this implementation always creates a new connection for each query.
type DNSOverTCP struct { type DNSOverTCPTransport struct {
dial DialContextFunc dial DialContextFunc
address string address string
network string network string
requiresPadding bool requiresPadding bool
} }
// NewDNSOverTCP creates a new DNSOverTCP transport. // NewDNSOverTCPTransport creates a new DNSOverTCPTransport.
// //
// Arguments: // Arguments:
// //
// - dial is a function with the net.Dialer.DialContext's signature; // - dial is a function with the net.Dialer.DialContext's signature;
// //
// - address is the endpoint address (e.g., 8.8.8.8:53). // - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP { func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return &DNSOverTCP{ return &DNSOverTCPTransport{
dial: dial, dial: dial,
address: address, address: address,
network: "tcp", network: "tcp",
@ -47,8 +47,8 @@ func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
// - dial is a function with the net.Dialer.DialContext's signature; // - dial is a function with the net.Dialer.DialContext's signature;
// //
// - address is the endpoint address (e.g., 8.8.8.8:853). // - address is the endpoint address (e.g., 8.8.8.8:853).
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP { func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport {
return &DNSOverTCP{ return &DNSOverTCPTransport{
dial: dial, dial: dial,
address: address, address: address,
network: "dot", network: "dot",
@ -57,7 +57,7 @@ func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
} }
// RoundTrip sends a query and receives a reply. // RoundTrip sends a query and receives a reply.
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
if len(query) > math.MaxUint16 { if len(query) > math.MaxUint16 {
return nil, errors.New("query too long") return nil, errors.New("query too long")
} }
@ -91,23 +91,23 @@ func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error
// RequiresPadding returns true for DoT and false for TCP // RequiresPadding returns true for DoT and false for TCP
// according to RFC8467. // according to RFC8467.
func (t *DNSOverTCP) RequiresPadding() bool { func (t *DNSOverTCPTransport) RequiresPadding() bool {
return t.requiresPadding return t.requiresPadding
} }
// Network returns the transport network, i.e., "dot" or "tcp". // Network returns the transport network, i.e., "dot" or "tcp".
func (t *DNSOverTCP) Network() string { func (t *DNSOverTCPTransport) Network() string {
return t.network return t.network
} }
// Address returns the upstream server endpoint (e.g., "1.1.1.1:853"). // Address returns the upstream server endpoint (e.g., "1.1.1.1:853").
func (t *DNSOverTCP) Address() string { func (t *DNSOverTCPTransport) Address() string {
return t.address return t.address
} }
// CloseIdleConnections closes idle connections, if any. // CloseIdleConnections closes idle connections, if any.
func (t *DNSOverTCP) CloseIdleConnections() { func (t *DNSOverTCPTransport) CloseIdleConnections() {
// nothing to do // nothing to do
} }
var _ model.DNSTransport = &DNSOverTCP{} var _ model.DNSTransport = &DNSOverTCPTransport{}

View File

@ -13,11 +13,11 @@ import (
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestDNSOverTCP(t *testing.T) { func TestDNSOverTCPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("query too large", func(t *testing.T) { t.Run("query too large", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
if err == nil { if err == nil {
t.Fatal("expected an error here") t.Fatal("expected an error here")
@ -35,7 +35,7 @@ func TestDNSOverTCP(t *testing.T) {
return nil, mocked return nil, mocked
}, },
} }
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -60,7 +60,7 @@ func TestDNSOverTCP(t *testing.T) {
}, nil }, nil
}, },
} }
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -88,7 +88,7 @@ func TestDNSOverTCP(t *testing.T) {
}, nil }, nil
}, },
} }
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -119,7 +119,7 @@ func TestDNSOverTCP(t *testing.T) {
}, nil }, nil
}, },
} }
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -156,7 +156,7 @@ func TestDNSOverTCP(t *testing.T) {
}, nil }, nil
}, },
} }
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -185,7 +185,7 @@ func TestDNSOverTCP(t *testing.T) {
}, nil }, nil
}, },
} }
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -198,7 +198,7 @@ func TestDNSOverTCP(t *testing.T) {
t.Run("other functions okay with TCP", func(t *testing.T) { t.Run("other functions okay with TCP", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false { if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding") t.Fatal("invalid RequiresPadding")
} }

View File

@ -7,25 +7,25 @@ import (
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
// DNSOverUDP is a DNS-over-UDP DNSTransport. // DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
type DNSOverUDP struct { type DNSOverUDPTransport struct {
dialer model.Dialer dialer model.Dialer
address string address string
} }
// NewDNSOverUDP creates a DNSOverUDP instance. // NewDNSOverUDPTransport creates a DNSOverUDPTransport instance.
// //
// Arguments: // Arguments:
// //
// - dialer is any type that implements the Dialer interface; // - dialer is any type that implements the Dialer interface;
// //
// - address is the endpoint address (e.g., 8.8.8.8:53). // - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverUDP(dialer model.Dialer, address string) *DNSOverUDP { func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
return &DNSOverUDP{dialer: dialer, address: address} return &DNSOverUDPTransport{dialer: dialer, address: address}
} }
// RoundTrip sends a query and receives a reply. // RoundTrip sends a query and receives a reply.
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
conn, err := t.dialer.DialContext(ctx, "udp", t.address) conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -49,23 +49,23 @@ func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error
} }
// RequiresPadding returns false for UDP according to RFC8467. // RequiresPadding returns false for UDP according to RFC8467.
func (t *DNSOverUDP) RequiresPadding() bool { func (t *DNSOverUDPTransport) RequiresPadding() bool {
return false return false
} }
// Network returns the transport network, i.e., "udp". // Network returns the transport network, i.e., "udp".
func (t *DNSOverUDP) Network() string { func (t *DNSOverUDPTransport) Network() string {
return "udp" return "udp"
} }
// Address returns the upstream server address. // Address returns the upstream server address.
func (t *DNSOverUDP) Address() string { func (t *DNSOverUDPTransport) Address() string {
return t.address return t.address
} }
// CloseIdleConnections closes idle connections, if any. // CloseIdleConnections closes idle connections, if any.
func (t *DNSOverUDP) CloseIdleConnections() { func (t *DNSOverUDPTransport) CloseIdleConnections() {
// nothing to do // nothing to do
} }
var _ model.DNSTransport = &DNSOverUDP{} var _ model.DNSTransport = &DNSOverUDPTransport{}

View File

@ -12,12 +12,12 @@ import (
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestDNSOverUDP(t *testing.T) { func TestDNSOverUDPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("dial failure", func(t *testing.T) { t.Run("dial failure", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
txp := NewDNSOverUDP(&mocks.Dialer{ txp := NewDNSOverUDPTransport(&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked return nil, mocked
}, },
@ -33,7 +33,7 @@ func TestDNSOverUDP(t *testing.T) {
t.Run("SetDeadline failure", func(t *testing.T) { t.Run("SetDeadline failure", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := NewDNSOverUDP( txp := NewDNSOverUDPTransport(
&mocks.Dialer{ &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
@ -58,7 +58,7 @@ func TestDNSOverUDP(t *testing.T) {
t.Run("Write failure", func(t *testing.T) { t.Run("Write failure", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := NewDNSOverUDP( txp := NewDNSOverUDPTransport(
&mocks.Dialer{ &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
@ -86,7 +86,7 @@ func TestDNSOverUDP(t *testing.T) {
t.Run("Read failure", func(t *testing.T) { t.Run("Read failure", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := NewDNSOverUDP( txp := NewDNSOverUDPTransport(
&mocks.Dialer{ &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
@ -118,7 +118,7 @@ func TestDNSOverUDP(t *testing.T) {
t.Run("read success", func(t *testing.T) { t.Run("read success", func(t *testing.T) {
const expected = 17 const expected = 17
input := bytes.NewReader(make([]byte, expected)) input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDP( txp := NewDNSOverUDPTransport(
&mocks.Dialer{ &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
@ -148,7 +148,7 @@ func TestDNSOverUDP(t *testing.T) {
t.Run("other functions okay", func(t *testing.T) { t.Run("other functions okay", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
txp := NewDNSOverUDP(NewDialerWithoutResolver(log.Log), address) txp := NewDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
if txp.RequiresPadding() != false { if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding") t.Fatal("invalid RequiresPadding")
} }

View File

@ -242,5 +242,5 @@ func (p *DNSProxy) dnstransport() DNSTransport {
return p.Upstream return p.Upstream
} }
const URL = "https://1.1.1.1/dns-query" const URL = "https://1.1.1.1/dns-query"
return netxlite.NewDNSOverHTTPS(http.DefaultClient, URL) return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL)
} }

View File

@ -13,7 +13,7 @@ import (
) )
// ErrNoDNSTransport is the error returned when you attempt to perform // ErrNoDNSTransport is the error returned when you attempt to perform
// a DNS operation that requires a custom DNSTransport (e.g., DNSOverHTTPS) // a DNS operation that requires a custom DNSTransport (e.g., DNSOverHTTPSTransport)
// but you are using the "system" resolver instead. // but you are using the "system" resolver instead.
var ErrNoDNSTransport = errors.New("operation requires a DNS transport") var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
@ -34,7 +34,7 @@ func NewResolverStdlib(logger model.DebugLogger) model.Resolver {
// - address is the server address (e.g., 1.1.1.1:53) // - address is the server address (e.g., 1.1.1.1:53)
func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
return WrapResolver(logger, NewSerialResolver( return WrapResolver(logger, NewSerialResolver(
NewDNSOverUDP(dialer, address), NewDNSOverUDPTransport(dialer, address),
)) ))
} }

View File

@ -38,7 +38,7 @@ func TestNewResolverUDP(t *testing.T) {
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr) shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper) errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
serio := errWrapper.Resolver.(*SerialResolver) serio := errWrapper.Resolver.(*SerialResolver)
txp := serio.Transport().(*DNSOverUDP) txp := serio.Transport().(*DNSOverUDPTransport)
if txp.Address() != "1.1.1.1:53" { if txp.Address() != "1.1.1.1:53" {
t.Fatal("invalid address") t.Fatal("invalid address")
} }