refactor(netx/resolver): add CloseIdleConnections to RoundTripper (#501)
While there, also change to pointer receiver and use internal testing for what are clearly unit tests. Part of https://github.com/ooni/probe/issues/1591.
This commit is contained in:
parent
5ab3c3b689
commit
1eb9e8c9b0
|
@ -124,7 +124,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
|
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -200,7 +200,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
|
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -276,7 +276,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
|
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -352,7 +352,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
udptxp, ok := stxp.RoundTripper.(resolver.DNSOverUDP)
|
udptxp, ok := stxp.RoundTripper.(*resolver.DNSOverUDP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
|
|
@ -680,7 +680,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
|
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -696,7 +696,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
|
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -712,7 +712,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
|
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -733,7 +733,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
if _, ok := txp.RoundTripper.(resolver.DNSOverHTTPS); !ok {
|
if _, ok := txp.RoundTripper.(*resolver.DNSOverHTTPS); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -749,7 +749,7 @@ func TestNewDNSClientUDP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
if _, ok := r.Transport().(resolver.DNSOverUDP); !ok {
|
if _, ok := r.Transport().(*resolver.DNSOverUDP); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -770,7 +770,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
if _, ok := txp.RoundTripper.(resolver.DNSOverUDP); !ok {
|
if _, ok := txp.RoundTripper.(*resolver.DNSOverUDP); !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dnsclient.CloseIdleConnections()
|
dnsclient.CloseIdleConnections()
|
||||||
|
@ -786,7 +786,7 @@ func TestNewDNSClientTCP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(resolver.DNSOverTCP)
|
txp, ok := r.Transport().(*resolver.DNSOverTCP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -811,7 +811,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dotcp, ok := txp.RoundTripper.(resolver.DNSOverTCP)
|
dotcp, ok := txp.RoundTripper.(*resolver.DNSOverTCP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -831,7 +831,7 @@ func TestNewDNSClientDoT(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(resolver.DNSOverTCP)
|
txp, ok := r.Transport().(*resolver.DNSOverTCP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -856,7 +856,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
dotls, ok := txp.RoundTripper.(resolver.DNSOverTCP)
|
dotls, ok := txp.RoundTripper.(*resolver.DNSOverTCP)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,28 +11,35 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
|
||||||
|
type HTTPClient interface {
|
||||||
|
Do(req *http.Request) (*http.Response, error)
|
||||||
|
CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||||
type DNSOverHTTPS struct {
|
type DNSOverHTTPS struct {
|
||||||
Do func(req *http.Request) (*http.Response, error)
|
Client HTTPClient
|
||||||
URL string
|
URL string
|
||||||
HostOverride string
|
HostOverride string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||||
// specified http.Client and URL, as a convenience.
|
// specified http.Client and URL, as a convenience.
|
||||||
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
|
func NewDNSOverHTTPS(client *http.Client, URL string) *DNSOverHTTPS {
|
||||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||||
// it's creating a resolver where we use the specified host.
|
// it's creating a resolver where we use the specified host.
|
||||||
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
|
func NewDNSOverHTTPSWithHostOverride(
|
||||||
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
|
client *http.Client, URL, hostOverride string) *DNSOverHTTPS {
|
||||||
|
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements RoundTripper.RoundTrip.
|
// RoundTrip implements RoundTripper.RoundTrip.
|
||||||
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||||
|
@ -43,7 +50,7 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro
|
||||||
req.Header.Set("user-agent", httpheader.UserAgent())
|
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||||
req.Header.Set("content-type", "application/dns-message")
|
req.Header.Set("content-type", "application/dns-message")
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
resp, err = t.Do(req.WithContext(ctx))
|
resp, err = t.Client.Do(req.WithContext(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -60,18 +67,23 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns true for DoH according to RFC8467
|
// RequiresPadding returns true for DoH according to RFC8467
|
||||||
func (t DNSOverHTTPS) RequiresPadding() bool {
|
func (t *DNSOverHTTPS) RequiresPadding() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network returns the transport network (e.g., doh, dot)
|
// Network returns the transport network (e.g., doh, dot)
|
||||||
func (t DNSOverHTTPS) Network() string {
|
func (t *DNSOverHTTPS) Network() string {
|
||||||
return "doh"
|
return "doh"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the upstream server address.
|
// Address returns the upstream server address.
|
||||||
func (t DNSOverHTTPS) Address() string {
|
func (t *DNSOverHTTPS) Address() string {
|
||||||
return t.URL
|
return t.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ RoundTripper = DNSOverHTTPS{}
|
// CloseIdleConnections closes idle connections.
|
||||||
|
func (t *DNSOverHTTPS) CloseIdleConnections() {
|
||||||
|
t.Client.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ RoundTripper = &DNSOverHTTPS{}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package resolver_test
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -10,12 +10,12 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||||
const invalidURL = "\t"
|
const invalidURL = "\t"
|
||||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
data, err := txp.RoundTrip(context.Background(), nil)
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected an error here")
|
||||||
|
@ -27,9 +27,11 @@ func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
txp := resolver.DNSOverHTTPS{
|
txp := &DNSOverHTTPS{
|
||||||
Do: func(*http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
return nil, expected
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
|
@ -43,12 +45,14 @@ func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||||
txp := resolver.DNSOverHTTPS{
|
txp := &DNSOverHTTPS{
|
||||||
Do: func(*http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
return &http.Response{
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
StatusCode: 500,
|
return &http.Response{
|
||||||
Body: io.NopCloser(strings.NewReader("")),
|
StatusCode: 500,
|
||||||
}, nil
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
|
@ -62,12 +66,14 @@ func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||||
txp := resolver.DNSOverHTTPS{
|
txp := &DNSOverHTTPS{
|
||||||
Do: func(*http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
return &http.Response{
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
StatusCode: 200,
|
return &http.Response{
|
||||||
Body: io.NopCloser(strings.NewReader("")),
|
StatusCode: 200,
|
||||||
}, nil
|
Body: io.NopCloser(strings.NewReader("")),
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
|
@ -82,15 +88,17 @@ func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||||
body := []byte("AAA")
|
body := []byte("AAA")
|
||||||
txp := resolver.DNSOverHTTPS{
|
txp := &DNSOverHTTPS{
|
||||||
Do: func(*http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
return &http.Response{
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
StatusCode: 200,
|
return &http.Response{
|
||||||
Body: io.NopCloser(bytes.NewReader(body)),
|
StatusCode: 200,
|
||||||
Header: http.Header{
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
"Content-Type": []string{"application/dns-message"},
|
Header: http.Header{
|
||||||
},
|
"Content-Type": []string{"application/dns-message"},
|
||||||
}, nil
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
|
@ -105,7 +113,7 @@ func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||||
if txp.Network() != "doh" {
|
if txp.Network() != "doh" {
|
||||||
t.Fatal("invalid network")
|
t.Fatal("invalid network")
|
||||||
}
|
}
|
||||||
|
@ -120,10 +128,12 @@ func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||||
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
var correct bool
|
var correct bool
|
||||||
txp := resolver.DNSOverHTTPS{
|
txp := &DNSOverHTTPS{
|
||||||
Do: func(req *http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||||
return nil, expected
|
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
|
@ -144,10 +154,12 @@ func TestDNSOverHTTPSHostOverride(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
|
|
||||||
hostOverride := "test.com"
|
hostOverride := "test.com"
|
||||||
txp := resolver.DNSOverHTTPS{
|
txp := &DNSOverHTTPS{
|
||||||
Do: func(req *http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
correct = req.Host == hostOverride
|
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||||
return nil, expected
|
correct = req.Host == hostOverride
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
HostOverride: hostOverride,
|
HostOverride: hostOverride,
|
||||||
|
|
|
@ -26,8 +26,8 @@ type DNSOverTCP struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||||
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
|
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
||||||
return DNSOverTCP{
|
return &DNSOverTCP{
|
||||||
dial: dial,
|
dial: dial,
|
||||||
address: address,
|
address: address,
|
||||||
network: "tcp",
|
network: "tcp",
|
||||||
|
@ -36,8 +36,8 @@ func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||||
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
|
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
||||||
return DNSOverTCP{
|
return &DNSOverTCP{
|
||||||
dial: dial,
|
dial: dial,
|
||||||
address: address,
|
address: address,
|
||||||
network: "dot",
|
network: "dot",
|
||||||
|
@ -46,7 +46,7 @@ func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements RoundTripper.RoundTrip.
|
// RoundTrip implements RoundTripper.RoundTrip.
|
||||||
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||||
if len(query) > math.MaxUint16 {
|
if len(query) > math.MaxUint16 {
|
||||||
return nil, errors.New("query too long")
|
return nil, errors.New("query too long")
|
||||||
}
|
}
|
||||||
|
@ -80,18 +80,23 @@ func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error)
|
||||||
|
|
||||||
// RequiresPadding returns true for DoT and false for TCP
|
// RequiresPadding returns true for DoT and false for TCP
|
||||||
// according to RFC8467.
|
// according to RFC8467.
|
||||||
func (t DNSOverTCP) RequiresPadding() bool {
|
func (t *DNSOverTCP) RequiresPadding() bool {
|
||||||
return t.requiresPadding
|
return t.requiresPadding
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network returns the transport network (e.g., doh, dot)
|
// Network returns the transport network (e.g., doh, dot)
|
||||||
func (t DNSOverTCP) Network() string {
|
func (t *DNSOverTCP) Network() string {
|
||||||
return t.network
|
return t.network
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the upstream server address.
|
// Address returns the upstream server address.
|
||||||
func (t DNSOverTCP) Address() string {
|
func (t *DNSOverTCP) Address() string {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ RoundTripper = DNSOverTCP{}
|
// CloseIdleConnections closes idle connections.
|
||||||
|
func (t *DNSOverTCP) CloseIdleConnections() {
|
||||||
|
// nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ RoundTripper = &DNSOverTCP{}
|
||||||
|
|
|
@ -1,17 +1,15 @@
|
||||||
package resolver_test
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected an error here")
|
||||||
|
@ -24,8 +22,8 @@ func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||||
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
fakedialer := resolver.FakeDialer{Err: mocked}
|
fakedialer := FakeDialer{Err: mocked}
|
||||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -38,10 +36,10 @@ func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||||
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||||
SetDeadlineError: mocked,
|
SetDeadlineError: mocked,
|
||||||
}}
|
}}
|
||||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -54,10 +52,10 @@ func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||||
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||||
WriteError: mocked,
|
WriteError: mocked,
|
||||||
}}
|
}}
|
||||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -70,10 +68,10 @@ func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||||
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||||
ReadError: mocked,
|
ReadError: mocked,
|
||||||
}}
|
}}
|
||||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -86,11 +84,11 @@ func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||||
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||||
ReadError: mocked,
|
ReadError: mocked,
|
||||||
ReadData: []byte{byte(0), byte(2)},
|
ReadData: []byte{byte(0), byte(2)},
|
||||||
}}
|
}}
|
||||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -103,11 +101,11 @@ func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||||
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||||
ReadError: mocked,
|
ReadError: mocked,
|
||||||
ReadData: []byte{byte(0), byte(1), byte(1)},
|
ReadData: []byte{byte(0), byte(1), byte(1)},
|
||||||
}}
|
}}
|
||||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -119,7 +117,7 @@ func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverTCPTransportOK(t *testing.T) {
|
func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||||
if txp.RequiresPadding() != false {
|
if txp.RequiresPadding() != false {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
@ -133,7 +131,7 @@ func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverTLSTransportOK(t *testing.T) {
|
func TestDNSOverTLSTransportOK(t *testing.T) {
|
||||||
const address = "9.9.9.9:853"
|
const address = "9.9.9.9:853"
|
||||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
|
txp := NewDNSOverTLS(DialTLSContext, address)
|
||||||
if txp.RequiresPadding() != true {
|
if txp.RequiresPadding() != true {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,12 +18,12 @@ type DNSOverUDP struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||||
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
|
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
|
||||||
return DNSOverUDP{dialer: dialer, address: address}
|
return &DNSOverUDP{dialer: dialer, address: address}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements RoundTripper.RoundTrip.
|
// RoundTrip implements RoundTripper.RoundTrip.
|
||||||
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -47,18 +47,23 @@ func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns false for UDP according to RFC8467
|
// RequiresPadding returns false for UDP according to RFC8467
|
||||||
func (t DNSOverUDP) RequiresPadding() bool {
|
func (t *DNSOverUDP) RequiresPadding() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Network returns the transport network (e.g., doh, dot)
|
// Network returns the transport network (e.g., doh, dot)
|
||||||
func (t DNSOverUDP) Network() string {
|
func (t *DNSOverUDP) Network() string {
|
||||||
return "udp"
|
return "udp"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the upstream server address.
|
// Address returns the upstream server address.
|
||||||
func (t DNSOverUDP) Address() string {
|
func (t *DNSOverUDP) Address() string {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ RoundTripper = DNSOverUDP{}
|
// CloseIdleConnections closes idle connections.
|
||||||
|
func (t *DNSOverUDP) CloseIdleConnections() {
|
||||||
|
// nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ RoundTripper = &DNSOverUDP{}
|
||||||
|
|
|
@ -1,18 +1,16 @@
|
||||||
package resolver_test
|
package resolver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverUDPDialFailure(t *testing.T) {
|
func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
|
txp := NewDNSOverUDP(FakeDialer{Err: mocked}, address)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
data, err := txp.RoundTrip(context.Background(), nil)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -24,9 +22,9 @@ func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := resolver.NewDNSOverUDP(
|
txp := NewDNSOverUDP(
|
||||||
resolver.FakeDialer{
|
FakeDialer{
|
||||||
Conn: &resolver.FakeConn{
|
Conn: &FakeConn{
|
||||||
SetDeadlineError: mocked,
|
SetDeadlineError: mocked,
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
|
@ -42,9 +40,9 @@ func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := resolver.NewDNSOverUDP(
|
txp := NewDNSOverUDP(
|
||||||
resolver.FakeDialer{
|
FakeDialer{
|
||||||
Conn: &resolver.FakeConn{
|
Conn: &FakeConn{
|
||||||
WriteError: mocked,
|
WriteError: mocked,
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
|
@ -60,9 +58,9 @@ func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverUDPReadFailure(t *testing.T) {
|
func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := resolver.NewDNSOverUDP(
|
txp := NewDNSOverUDP(
|
||||||
resolver.FakeDialer{
|
FakeDialer{
|
||||||
Conn: &resolver.FakeConn{
|
Conn: &FakeConn{
|
||||||
ReadError: mocked,
|
ReadError: mocked,
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
|
@ -78,9 +76,9 @@ func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||||
const expected = 17
|
const expected = 17
|
||||||
txp := resolver.NewDNSOverUDP(
|
txp := NewDNSOverUDP(
|
||||||
resolver.FakeDialer{
|
FakeDialer{
|
||||||
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
|
Conn: &FakeConn{ReadData: make([]byte, 17)},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
)
|
)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
data, err := txp.RoundTrip(context.Background(), nil)
|
||||||
|
@ -94,7 +92,7 @@ func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||||
|
|
||||||
func TestDNSOverUDPTransportOK(t *testing.T) {
|
func TestDNSOverUDPTransportOK(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
|
txp := NewDNSOverUDP(&net.Dialer{}, address)
|
||||||
if txp.RequiresPadding() != false {
|
if txp.RequiresPadding() != false {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmitterTransportSuccess(t *testing.T) {
|
func TestEmitterTransportSuccess(t *testing.T) {
|
||||||
|
@ -107,9 +108,11 @@ func TestEmitterResolverFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||||
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
|
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
|
||||||
resolver.DNSOverHTTPS{
|
&resolver.DNSOverHTTPS{
|
||||||
Do: func(req *http.Request) (*http.Response, error) {
|
Client: &mocks.HTTPClient{
|
||||||
return nil, io.EOF
|
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, io.EOF
|
||||||
|
},
|
||||||
},
|
},
|
||||||
URL: "https://dns.google.com/",
|
URL: "https://dns.google.com/",
|
||||||
},
|
},
|
||||||
|
|
|
@ -93,6 +93,10 @@ func (ft FakeTransport) Network() string {
|
||||||
return "fake"
|
return "fake"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (fk FakeTransport) CloseIdleConnections() {
|
||||||
|
// nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
type FakeEncoder struct {
|
type FakeEncoder struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
Err error
|
Err error
|
||||||
|
|
|
@ -22,6 +22,9 @@ type RoundTripper interface {
|
||||||
|
|
||||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||||
Address() string
|
Address() string
|
||||||
|
|
||||||
|
// CloseIdleConnections closes idle connections.
|
||||||
|
CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SerialResolver is a resolver that first issues an A query and then
|
// SerialResolver is a resolver that first issues an A query and then
|
||||||
|
@ -81,7 +84,7 @@ func (r SerialResolver) roundTripWithRetry(
|
||||||
}
|
}
|
||||||
errorslist = append(errorslist, err)
|
errorslist = append(errorslist, err)
|
||||||
var operr *net.OpError
|
var operr *net.OpError
|
||||||
if errors.As(err, &operr) == false || operr.Timeout() == false {
|
if !errors.As(err, &operr) || !operr.Timeout() {
|
||||||
// The first error is the one that is most likely to be caused
|
// The first error is the one that is most likely to be caused
|
||||||
// by the network. Subsequent errors are more likely to be caused
|
// by the network. Subsequent errors are more likely to be caused
|
||||||
// by context deadlines. So, the first error is attached to an
|
// by context deadlines. So, the first error is attached to an
|
||||||
|
|
|
@ -17,3 +17,20 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
func (txp *HTTPTransport) CloseIdleConnections() {
|
func (txp *HTTPTransport) CloseIdleConnections() {
|
||||||
txp.MockCloseIdleConnections()
|
txp.MockCloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HTTPClient allows mocking an http.Client.
|
||||||
|
type HTTPClient struct {
|
||||||
|
MockDo func(req *http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
MockCloseIdleConnections func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do calls MockDo.
|
||||||
|
func (txp *HTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
|
return txp.MockDo(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections calls MockCloseIdleConnections.
|
||||||
|
func (txp *HTTPClient) CloseIdleConnections() {
|
||||||
|
txp.MockCloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
|
@ -38,3 +38,34 @@ func TestHTTPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHTTPClient(t *testing.T) {
|
||||||
|
t.Run("Do", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
clnt := &HTTPClient{
|
||||||
|
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := clnt.Do(&http.Request{})
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected nil response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
called := &atomicx.Int64{}
|
||||||
|
clnt := &HTTPClient{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called.Add(1)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
clnt.CloseIdleConnections()
|
||||||
|
if called.Load() != 1 {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user