refactor(netxlite): add Transport suffix to DNS transports (#731)
This diff has been extracted from c2f7ccab0e
See https://github.com/ooni/probe/issues/2096
This commit is contained in:
parent
6c388d2c61
commit
f5b801ae95
|
@ -124,7 +124,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPS)
|
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPSTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -200,7 +200,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPS)
|
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPSTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -276,7 +276,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPS)
|
dohtxp, ok := stxp.DNSTransport.(*netxlite.DNSOverHTTPSTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -352,7 +352,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
udptxp, ok := stxp.DNSTransport.(*netxlite.DNSOverUDP)
|
udptxp, ok := stxp.DNSTransport.(*netxlite.DNSOverUDPTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
|
|
@ -368,5 +368,5 @@ const thResolverURL = "https://dns.google/dns-query"
|
||||||
// Here we're using github.com/apex/log as the logger, which
|
// Here we're using github.com/apex/log as the logger, which
|
||||||
// is fine because this is backend only code.
|
// is fine because this is backend only code.
|
||||||
var thResolver = netxlite.WrapResolver(log.Log, netxlite.NewSerialResolver(
|
var thResolver = netxlite.WrapResolver(log.Log, netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverHTTPS(http.DefaultClient, thResolverURL),
|
netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, thResolverURL),
|
||||||
))
|
))
|
||||||
|
|
|
@ -286,7 +286,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||||
case "https":
|
case "https":
|
||||||
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
|
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||||
httpClient := &http.Client{Transport: NewHTTPTransport(config)}
|
httpClient := &http.Client{Transport: NewHTTPTransport(config)}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSWithHostOverride(
|
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride(
|
||||||
httpClient, URL, hostOverride)
|
httpClient, URL, hostOverride)
|
||||||
if config.ResolveSaver != nil {
|
if config.ResolveSaver != nil {
|
||||||
txp = resolver.SaverDNSTransport{
|
txp = resolver.SaverDNSTransport{
|
||||||
|
@ -301,7 +301,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverUDP(
|
var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport(
|
||||||
dialer, endpoint)
|
dialer, endpoint)
|
||||||
if config.ResolveSaver != nil {
|
if config.ResolveSaver != nil {
|
||||||
txp = resolver.SaverDNSTransport{
|
txp = resolver.SaverDNSTransport{
|
||||||
|
@ -332,7 +332,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverTCP(
|
var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport(
|
||||||
dialer.DialContext, endpoint)
|
dialer.DialContext, endpoint)
|
||||||
if config.ResolveSaver != nil {
|
if config.ResolveSaver != nil {
|
||||||
txp = resolver.SaverDNSTransport{
|
txp = resolver.SaverDNSTransport{
|
||||||
|
|
|
@ -586,7 +586,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(*netxlite.DNSOverHTTPS); !ok {
|
if _, ok := r.Transport().(*netxlite.DNSOverHTTPSTransport); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -602,7 +602,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(*netxlite.DNSOverHTTPS); !ok {
|
if _, ok := r.Transport().(*netxlite.DNSOverHTTPSTransport); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -618,7 +618,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(*netxlite.DNSOverHTTPS); !ok {
|
if _, ok := r.Transport().(*netxlite.DNSOverHTTPSTransport); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -639,7 +639,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
if _, ok := txp.DNSTransport.(*netxlite.DNSOverHTTPS); !ok {
|
if _, ok := txp.DNSTransport.(*netxlite.DNSOverHTTPSTransport); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -655,7 +655,7 @@ func TestNewDNSClientUDP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(*netxlite.DNSOverUDP); !ok {
|
if _, ok := r.Transport().(*netxlite.DNSOverUDPTransport); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -676,7 +676,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
if _, ok := txp.DNSTransport.(*netxlite.DNSOverUDP); !ok {
|
if _, ok := txp.DNSTransport.(*netxlite.DNSOverUDPTransport); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -692,7 +692,7 @@ func TestNewDNSClientTCP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(*netxlite.DNSOverTCP)
|
txp, ok := r.Transport().(*netxlite.DNSOverTCPTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -717,7 +717,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dotcp, ok := txp.DNSTransport.(*netxlite.DNSOverTCP)
|
dotcp, ok := txp.DNSTransport.(*netxlite.DNSOverTCPTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -737,7 +737,7 @@ func TestNewDNSClientDoT(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(*netxlite.DNSOverTCP)
|
txp, ok := r.Transport().(*netxlite.DNSOverTCPTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -762,7 +762,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dotls, ok := txp.DNSTransport.(*netxlite.DNSOverTCP)
|
dotls, ok := txp.DNSTransport.(*netxlite.DNSOverTCPTransport)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,28 +71,28 @@ func TestNewResolverSystem(t *testing.T) {
|
||||||
|
|
||||||
func TestNewResolverUDPAddress(t *testing.T) {
|
func TestNewResolverUDPAddress(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverUDP(netxlite.DefaultDialer, "8.8.8.8:53"))
|
netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "8.8.8.8:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverUDPDomain(t *testing.T) {
|
func TestNewResolverUDPDomain(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverUDP(netxlite.DefaultDialer, "dns.google.com:53"))
|
netxlite.NewDNSOverUDPTransport(netxlite.DefaultDialer, "dns.google.com:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverTCPAddress(t *testing.T) {
|
func TestNewResolverTCPAddress(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverTCP(new(net.Dialer).DialContext, "8.8.8.8:53"))
|
netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "8.8.8.8:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverTCPDomain(t *testing.T) {
|
func TestNewResolverTCPDomain(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverTCP(new(net.Dialer).DialContext, "dns.google.com:53"))
|
netxlite.NewDNSOverTCPTransport(new(net.Dialer).DialContext, "dns.google.com:53"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func TestNewResolverDoTDomain(t *testing.T) {
|
||||||
|
|
||||||
func TestNewResolverDoH(t *testing.T) {
|
func TestNewResolverDoH(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverHTTPS(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
|
netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,7 +44,7 @@ func (mx *Measurer) NewResolverSystem(db WritableDB, logger model.Logger) model.
|
||||||
func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver {
|
func (mx *Measurer) NewResolverUDP(db WritableDB, logger model.Logger, address string) model.Resolver {
|
||||||
return mx.WrapResolver(db, netxlite.WrapResolver(
|
return mx.WrapResolver(db, netxlite.WrapResolver(
|
||||||
logger, netxlite.NewSerialResolver(
|
logger, netxlite.NewSerialResolver(
|
||||||
mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDP(
|
mx.WrapDNSXRoundTripper(db, netxlite.NewDNSOverUDPTransport(
|
||||||
mx.NewDialerWithSystemResolver(db, logger),
|
mx.NewDialerWithSystemResolver(db, logger),
|
||||||
address,
|
address,
|
||||||
)))),
|
)))),
|
||||||
|
|
|
@ -244,7 +244,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
// These errors are returned by custom DNSTransport instances (e.g.,
|
// These errors are returned by custom DNSTransport instances (e.g.,
|
||||||
// DNSOverHTTPS and DNSOverUDP). Their suffix matches the equivalent
|
// DNSOverHTTPSTransport and DNSOverUDPTransport). Their suffix matches the equivalent
|
||||||
// unexported errors used by the Go standard library.
|
// unexported errors used by the Go standard library.
|
||||||
var (
|
var (
|
||||||
ErrOODNSNoSuchHost = fmt.Errorf("ooniresolver: %s", DNSNoSuchHostSuffix)
|
ErrOODNSNoSuchHost = fmt.Errorf("ooniresolver: %s", DNSNoSuchHostSuffix)
|
||||||
|
|
|
@ -11,8 +11,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSOverHTTPS is a DNS-over-HTTPS DNSTransport.
|
// DNSOverHTTPSTransport is a DNS-over-HTTPS DNSTransport.
|
||||||
type DNSOverHTTPS struct {
|
type DNSOverHTTPSTransport struct {
|
||||||
// Client is the MANDATORY http client to use.
|
// Client is the MANDATORY http client to use.
|
||||||
Client model.HTTPClient
|
Client model.HTTPClient
|
||||||
|
|
||||||
|
@ -24,26 +24,26 @@ type DNSOverHTTPS struct {
|
||||||
HostOverride string
|
HostOverride string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverHTTPS creates a new DNSOverHTTPS instance.
|
// NewDNSOverHTTPSTransport creates a new DNSOverHTTPSTransport instance.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - client in http.Client-like type (e.g., http.DefaultClient);
|
// - client in http.Client-like type (e.g., http.DefaultClient);
|
||||||
//
|
//
|
||||||
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query).
|
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query).
|
||||||
func NewDNSOverHTTPS(client model.HTTPClient, URL string) *DNSOverHTTPS {
|
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
|
||||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverHTTPSWithHostOverride creates a new DNSOverHTTPS
|
// NewDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport
|
||||||
// with the given Host header override.
|
// with the given Host header override.
|
||||||
func NewDNSOverHTTPSWithHostOverride(
|
func NewDNSOverHTTPSTransportWithHostOverride(
|
||||||
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPS {
|
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
|
||||||
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
|
return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a reply.
|
||||||
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||||
|
@ -71,23 +71,23 @@ func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns true for DoH according to RFC8467.
|
// RequiresPadding returns true for DoH according to RFC8467.
|
||||||
func (t *DNSOverHTTPS) RequiresPadding() bool {
|
func (t *DNSOverHTTPSTransport) RequiresPadding() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network returns the transport network, i.e., "doh".
|
// Network returns the transport network, i.e., "doh".
|
||||||
func (t *DNSOverHTTPS) Network() string {
|
func (t *DNSOverHTTPSTransport) Network() string {
|
||||||
return "doh"
|
return "doh"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the URL we're using for the DoH server.
|
// Address returns the URL we're using for the DoH server.
|
||||||
func (t *DNSOverHTTPS) Address() string {
|
func (t *DNSOverHTTPSTransport) Address() string {
|
||||||
return t.URL
|
return t.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections closes idle connections, if any.
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
func (t *DNSOverHTTPS) CloseIdleConnections() {
|
func (t *DNSOverHTTPSTransport) CloseIdleConnections() {
|
||||||
t.Client.CloseIdleConnections()
|
t.Client.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DNSTransport = &DNSOverHTTPS{}
|
var _ model.DNSTransport = &DNSOverHTTPSTransport{}
|
||||||
|
|
|
@ -13,11 +13,11 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverHTTPS(t *testing.T) {
|
func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
t.Run("NewRequestFailure", func(t *testing.T) {
|
t.Run("NewRequestFailure", func(t *testing.T) {
|
||||||
const invalidURL = "\t"
|
const invalidURL = "\t"
|
||||||
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
data, err := txp.RoundTrip(context.Background(), nil)
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected an error here")
|
||||||
|
@ -29,7 +29,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("client.Do failure", func(t *testing.T) {
|
t.Run("client.Do failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
txp := &DNSOverHTTPS{
|
txp := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockDo: func(*http.Request) (*http.Response, error) {
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
@ -47,7 +47,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("server returns 500", func(t *testing.T) {
|
t.Run("server returns 500", func(t *testing.T) {
|
||||||
txp := &DNSOverHTTPS{
|
txp := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockDo: func(*http.Request) (*http.Response, error) {
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -68,7 +68,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("missing content type", func(t *testing.T) {
|
t.Run("missing content type", func(t *testing.T) {
|
||||||
txp := &DNSOverHTTPS{
|
txp := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockDo: func(*http.Request) (*http.Response, error) {
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -90,7 +90,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("success", func(t *testing.T) {
|
t.Run("success", func(t *testing.T) {
|
||||||
body := []byte("AAA")
|
body := []byte("AAA")
|
||||||
txp := &DNSOverHTTPS{
|
txp := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockDo: func(*http.Request) (*http.Response, error) {
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -116,7 +116,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
t.Run("sets the correct user-agent", func(t *testing.T) {
|
t.Run("sets the correct user-agent", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
var correct bool
|
var correct bool
|
||||||
txp := &DNSOverHTTPS{
|
txp := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||||
|
@ -141,7 +141,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
var correct bool
|
var correct bool
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
hostOverride := "test.com"
|
hostOverride := "test.com"
|
||||||
txp := &DNSOverHTTPS{
|
txp := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||||
correct = req.Host == hostOverride
|
correct = req.Host == hostOverride
|
||||||
|
@ -167,7 +167,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions behave correctly", func(t *testing.T) {
|
t.Run("other functions behave correctly", func(t *testing.T) {
|
||||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||||
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
txp := NewDNSOverHTTPSTransport(http.DefaultClient, queryURL)
|
||||||
if txp.Network() != "doh" {
|
if txp.Network() != "doh" {
|
||||||
t.Fatal("invalid network")
|
t.Fatal("invalid network")
|
||||||
}
|
}
|
||||||
|
@ -181,7 +181,7 @@ func TestDNSOverHTTPS(t *testing.T) {
|
||||||
|
|
||||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
var called bool
|
var called bool
|
||||||
doh := &DNSOverHTTPS{
|
doh := &DNSOverHTTPSTransport{
|
||||||
Client: &mocks.HTTPClient{
|
Client: &mocks.HTTPClient{
|
||||||
MockCloseIdleConnections: func() {
|
MockCloseIdleConnections: func() {
|
||||||
called = true
|
called = true
|
||||||
|
|
|
@ -14,25 +14,25 @@ import (
|
||||||
// DialContextFunc is the type of net.Dialer.DialContext.
|
// DialContextFunc is the type of net.Dialer.DialContext.
|
||||||
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||||
|
|
||||||
// DNSOverTCP is a DNS-over-{TCP,TLS} DNSTransport.
|
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
|
||||||
//
|
//
|
||||||
// Bug: this implementation always creates a new connection for each query.
|
// Bug: this implementation always creates a new connection for each query.
|
||||||
type DNSOverTCP struct {
|
type DNSOverTCPTransport struct {
|
||||||
dial DialContextFunc
|
dial DialContextFunc
|
||||||
address string
|
address string
|
||||||
network string
|
network string
|
||||||
requiresPadding bool
|
requiresPadding bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
// NewDNSOverTCPTransport creates a new DNSOverTCPTransport.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - dial is a function with the net.Dialer.DialContext's signature;
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||||
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||||
return &DNSOverTCP{
|
return &DNSOverTCPTransport{
|
||||||
dial: dial,
|
dial: dial,
|
||||||
address: address,
|
address: address,
|
||||||
network: "tcp",
|
network: "tcp",
|
||||||
|
@ -47,8 +47,8 @@ func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
||||||
// - dial is a function with the net.Dialer.DialContext's signature;
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
||||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||||
return &DNSOverTCP{
|
return &DNSOverTCPTransport{
|
||||||
dial: dial,
|
dial: dial,
|
||||||
address: address,
|
address: address,
|
||||||
network: "dot",
|
network: "dot",
|
||||||
|
@ -57,7 +57,7 @@ func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a reply.
|
||||||
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||||
if len(query) > math.MaxUint16 {
|
if len(query) > math.MaxUint16 {
|
||||||
return nil, errors.New("query too long")
|
return nil, errors.New("query too long")
|
||||||
}
|
}
|
||||||
|
@ -91,23 +91,23 @@ func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error
|
||||||
|
|
||||||
// RequiresPadding returns true for DoT and false for TCP
|
// RequiresPadding returns true for DoT and false for TCP
|
||||||
// according to RFC8467.
|
// according to RFC8467.
|
||||||
func (t *DNSOverTCP) RequiresPadding() bool {
|
func (t *DNSOverTCPTransport) RequiresPadding() bool {
|
||||||
return t.requiresPadding
|
return t.requiresPadding
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network returns the transport network, i.e., "dot" or "tcp".
|
// Network returns the transport network, i.e., "dot" or "tcp".
|
||||||
func (t *DNSOverTCP) Network() string {
|
func (t *DNSOverTCPTransport) Network() string {
|
||||||
return t.network
|
return t.network
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the upstream server endpoint (e.g., "1.1.1.1:853").
|
// Address returns the upstream server endpoint (e.g., "1.1.1.1:853").
|
||||||
func (t *DNSOverTCP) Address() string {
|
func (t *DNSOverTCPTransport) Address() string {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections closes idle connections, if any.
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
func (t *DNSOverTCP) CloseIdleConnections() {
|
func (t *DNSOverTCPTransport) CloseIdleConnections() {
|
||||||
// nothing to do
|
// nothing to do
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DNSTransport = &DNSOverTCP{}
|
var _ model.DNSTransport = &DNSOverTCPTransport{}
|
||||||
|
|
|
@ -13,11 +13,11 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverTCP(t *testing.T) {
|
func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
t.Run("query too large", func(t *testing.T) {
|
t.Run("query too large", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected an error here")
|
||||||
|
@ -35,7 +35,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -60,7 +60,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -88,7 +88,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -119,7 +119,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -156,7 +156,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -185,7 +185,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -198,7 +198,7 @@ func TestDNSOverTCP(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions okay with TCP", func(t *testing.T) {
|
t.Run("other functions okay with TCP", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
if txp.RequiresPadding() != false {
|
if txp.RequiresPadding() != false {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,25 +7,25 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSOverUDP is a DNS-over-UDP DNSTransport.
|
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
|
||||||
type DNSOverUDP struct {
|
type DNSOverUDPTransport struct {
|
||||||
dialer model.Dialer
|
dialer model.Dialer
|
||||||
address string
|
address string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - dialer is any type that implements the Dialer interface;
|
// - dialer is any type that implements the Dialer interface;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||||
func NewDNSOverUDP(dialer model.Dialer, address string) *DNSOverUDP {
|
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
||||||
return &DNSOverUDP{dialer: dialer, address: address}
|
return &DNSOverUDPTransport{dialer: dialer, address: address}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a reply.
|
||||||
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -49,23 +49,23 @@ func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns false for UDP according to RFC8467.
|
// RequiresPadding returns false for UDP according to RFC8467.
|
||||||
func (t *DNSOverUDP) RequiresPadding() bool {
|
func (t *DNSOverUDPTransport) RequiresPadding() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network returns the transport network, i.e., "udp".
|
// Network returns the transport network, i.e., "udp".
|
||||||
func (t *DNSOverUDP) Network() string {
|
func (t *DNSOverUDPTransport) Network() string {
|
||||||
return "udp"
|
return "udp"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the upstream server address.
|
// Address returns the upstream server address.
|
||||||
func (t *DNSOverUDP) Address() string {
|
func (t *DNSOverUDPTransport) Address() string {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections closes idle connections, if any.
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
func (t *DNSOverUDP) CloseIdleConnections() {
|
func (t *DNSOverUDPTransport) CloseIdleConnections() {
|
||||||
// nothing to do
|
// nothing to do
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DNSTransport = &DNSOverUDP{}
|
var _ model.DNSTransport = &DNSOverUDPTransport{}
|
||||||
|
|
|
@ -12,12 +12,12 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverUDP(t *testing.T) {
|
func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
t.Run("dial failure", func(t *testing.T) {
|
t.Run("dial failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverUDP(&mocks.Dialer{
|
txp := NewDNSOverUDPTransport(&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
|
@ -33,7 +33,7 @@ func TestDNSOverUDP(t *testing.T) {
|
||||||
|
|
||||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
t.Run("SetDeadline failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := NewDNSOverUDP(
|
txp := NewDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -58,7 +58,7 @@ func TestDNSOverUDP(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Write failure", func(t *testing.T) {
|
t.Run("Write failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := NewDNSOverUDP(
|
txp := NewDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -86,7 +86,7 @@ func TestDNSOverUDP(t *testing.T) {
|
||||||
|
|
||||||
t.Run("Read failure", func(t *testing.T) {
|
t.Run("Read failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := NewDNSOverUDP(
|
txp := NewDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -118,7 +118,7 @@ func TestDNSOverUDP(t *testing.T) {
|
||||||
t.Run("read success", func(t *testing.T) {
|
t.Run("read success", func(t *testing.T) {
|
||||||
const expected = 17
|
const expected = 17
|
||||||
input := bytes.NewReader(make([]byte, expected))
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
txp := NewDNSOverUDP(
|
txp := NewDNSOverUDPTransport(
|
||||||
&mocks.Dialer{
|
&mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -148,7 +148,7 @@ func TestDNSOverUDP(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions okay", func(t *testing.T) {
|
t.Run("other functions okay", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverUDP(NewDialerWithoutResolver(log.Log), address)
|
txp := NewDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
|
||||||
if txp.RequiresPadding() != false {
|
if txp.RequiresPadding() != false {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
|
@ -242,5 +242,5 @@ func (p *DNSProxy) dnstransport() DNSTransport {
|
||||||
return p.Upstream
|
return p.Upstream
|
||||||
}
|
}
|
||||||
const URL = "https://1.1.1.1/dns-query"
|
const URL = "https://1.1.1.1/dns-query"
|
||||||
return netxlite.NewDNSOverHTTPS(http.DefaultClient, URL)
|
return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL)
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNoDNSTransport is the error returned when you attempt to perform
|
// ErrNoDNSTransport is the error returned when you attempt to perform
|
||||||
// a DNS operation that requires a custom DNSTransport (e.g., DNSOverHTTPS)
|
// a DNS operation that requires a custom DNSTransport (e.g., DNSOverHTTPSTransport)
|
||||||
// but you are using the "system" resolver instead.
|
// but you are using the "system" resolver instead.
|
||||||
var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
|
var ErrNoDNSTransport = errors.New("operation requires a DNS transport")
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ func NewResolverStdlib(logger model.DebugLogger) model.Resolver {
|
||||||
// - address is the server address (e.g., 1.1.1.1:53)
|
// - address is the server address (e.g., 1.1.1.1:53)
|
||||||
func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
|
func NewResolverUDP(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
|
||||||
return WrapResolver(logger, NewSerialResolver(
|
return WrapResolver(logger, NewSerialResolver(
|
||||||
NewDNSOverUDP(dialer, address),
|
NewDNSOverUDPTransport(dialer, address),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ func TestNewResolverUDP(t *testing.T) {
|
||||||
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
shortCircuit := logger.Resolver.(*resolverShortCircuitIPAddr)
|
||||||
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
errWrapper := shortCircuit.Resolver.(*resolverErrWrapper)
|
||||||
serio := errWrapper.Resolver.(*SerialResolver)
|
serio := errWrapper.Resolver.(*SerialResolver)
|
||||||
txp := serio.Transport().(*DNSOverUDP)
|
txp := serio.Transport().(*DNSOverUDPTransport)
|
||||||
if txp.Address() != "1.1.1.1:53" {
|
if txp.Address() != "1.1.1.1:53" {
|
||||||
t.Fatal("invalid address")
|
t.Fatal("invalid address")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user