refactor(netx/resolver): add CloseIdleConnections to RoundTripper (#501)
While there, also change to pointer receiver and use internal testing for what are clearly unit tests. Part of https://github.com/ooni/probe/issues/1591.
This commit is contained in:
parent
5ab3c3b689
commit
1eb9e8c9b0
|
@ -124,7 +124,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
|
||||
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
|
||||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
|
||||
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
|
||||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
|
@ -276,7 +276,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
|
|||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
|
||||
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
|
||||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
|
@ -352,7 +352,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
udptxp, ok := stxp.RoundTripper.(resolver.DNSOverUDP)
|
||||
udptxp, ok := stxp.RoundTripper.(*resolver.DNSOverUDP)
|
||||
if !ok {
|
||||
t.Fatal("not the DNS transport we expected")
|
||||
}
|
||||
|
|
|
@ -680,7 +680,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
|
||||
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dnsclient.CloseIdleConnections()
|
||||
|
@ -696,7 +696,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
|
||||
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dnsclient.CloseIdleConnections()
|
||||
|
@ -712,7 +712,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
|
||||
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dnsclient.CloseIdleConnections()
|
||||
|
@ -733,7 +733,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if _, ok := txp.RoundTripper.(resolver.DNSOverHTTPS); !ok {
|
||||
if _, ok := txp.RoundTripper.(*resolver.DNSOverHTTPS); !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dnsclient.CloseIdleConnections()
|
||||
|
@ -749,7 +749,7 @@ func TestNewDNSClientUDP(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
if _, ok := r.Transport().(resolver.DNSOverUDP); !ok {
|
||||
if _, ok := r.Transport().(*resolver.DNSOverUDP); !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dnsclient.CloseIdleConnections()
|
||||
|
@ -770,7 +770,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if _, ok := txp.RoundTripper.(resolver.DNSOverUDP); !ok {
|
||||
if _, ok := txp.RoundTripper.(*resolver.DNSOverUDP); !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dnsclient.CloseIdleConnections()
|
||||
|
@ -786,7 +786,7 @@ func TestNewDNSClientTCP(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
txp, ok := r.Transport().(resolver.DNSOverTCP)
|
||||
txp, ok := r.Transport().(*resolver.DNSOverTCP)
|
||||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
|
@ -811,7 +811,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dotcp, ok := txp.RoundTripper.(resolver.DNSOverTCP)
|
||||
dotcp, ok := txp.RoundTripper.(*resolver.DNSOverTCP)
|
||||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
|
@ -831,7 +831,7 @@ func TestNewDNSClientDoT(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
txp, ok := r.Transport().(resolver.DNSOverTCP)
|
||||
txp, ok := r.Transport().(*resolver.DNSOverTCP)
|
||||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
|
@ -856,7 +856,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
|
|||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
dotls, ok := txp.RoundTripper.(resolver.DNSOverTCP)
|
||||
dotls, ok := txp.RoundTripper.(*resolver.DNSOverTCP)
|
||||
if !ok {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
|
|
|
@ -11,28 +11,35 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
|
||||
)
|
||||
|
||||
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||
type DNSOverHTTPS struct {
|
||||
Do func(req *http.Request) (*http.Response, error)
|
||||
Client HTTPClient
|
||||
URL string
|
||||
HostOverride string
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||
// specified http.Client and URL, as a convenience.
|
||||
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
|
||||
func NewDNSOverHTTPS(client *http.Client, URL string) *DNSOverHTTPS {
|
||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||
// it's creating a resolver where we use the specified host.
|
||||
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
|
||||
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
|
||||
func NewDNSOverHTTPSWithHostOverride(
|
||||
client *http.Client, URL, hostOverride string) *DNSOverHTTPS {
|
||||
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
|
@ -43,7 +50,7 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro
|
|||
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Do(req.WithContext(ctx))
|
||||
resp, err = t.Client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -60,18 +67,23 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro
|
|||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467
|
||||
func (t DNSOverHTTPS) RequiresPadding() bool {
|
||||
func (t *DNSOverHTTPS) RequiresPadding() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverHTTPS) Network() string {
|
||||
func (t *DNSOverHTTPS) Network() string {
|
||||
return "doh"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverHTTPS) Address() string {
|
||||
func (t *DNSOverHTTPS) Address() string {
|
||||
return t.URL
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverHTTPS{}
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverHTTPS) CloseIdleConnections() {
|
||||
t.Client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverHTTPS{}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package resolver_test
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
@ -10,12 +10,12 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
|
@ -27,10 +27,12 @@ func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
|||
|
||||
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
|
@ -43,13 +45,15 @@ func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
|
@ -62,13 +66,15 @@ func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
|
@ -82,8 +88,9 @@ func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
|||
|
||||
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||
body := []byte("AAA")
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
|
@ -92,6 +99,7 @@ func TestDNSOverHTTPSSuccess(t *testing.T) {
|
|||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
|
@ -105,7 +113,7 @@ func TestDNSOverHTTPSSuccess(t *testing.T) {
|
|||
|
||||
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
if txp.Network() != "doh" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
|
@ -120,11 +128,13 @@ func TestDNSOverHTTPTransportOK(t *testing.T) {
|
|||
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var correct bool
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
|
@ -144,11 +154,13 @@ func TestDNSOverHTTPSHostOverride(t *testing.T) {
|
|||
expected := errors.New("mocked error")
|
||||
|
||||
hostOverride := "test.com"
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Host == hostOverride
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
|
|
|
@ -26,8 +26,8 @@ type DNSOverTCP struct {
|
|||
}
|
||||
|
||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
|
||||
return DNSOverTCP{
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
|
@ -36,8 +36,8 @@ func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
|
|||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
|
||||
return DNSOverTCP{
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "dot",
|
||||
|
@ -46,7 +46,7 @@ func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
|
|||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
}
|
||||
|
@ -80,18 +80,23 @@ func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error)
|
|||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
// according to RFC8467.
|
||||
func (t DNSOverTCP) RequiresPadding() bool {
|
||||
func (t *DNSOverTCP) RequiresPadding() bool {
|
||||
return t.requiresPadding
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverTCP) Network() string {
|
||||
func (t *DNSOverTCP) Network() string {
|
||||
return t.network
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverTCP) Address() string {
|
||||
func (t *DNSOverTCP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverTCP{}
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverTCP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverTCP{}
|
||||
|
|
|
@ -1,17 +1,15 @@
|
|||
package resolver_test
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
|
@ -24,8 +22,8 @@ func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
|||
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Err: mocked}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
fakedialer := FakeDialer{Err: mocked}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -38,10 +36,10 @@ func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
|||
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -54,10 +52,10 @@ func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
|||
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
WriteError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -70,10 +68,10 @@ func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
|||
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -86,11 +84,11 @@ func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
|||
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(2)},
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -103,11 +101,11 @@ func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
|||
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(1), byte(1)},
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -119,7 +117,7 @@ func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
|||
|
||||
func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
|
@ -133,7 +131,7 @@ func TestDNSOverTCPTransportOK(t *testing.T) {
|
|||
|
||||
func TestDNSOverTLSTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
|
||||
txp := NewDNSOverTLS(DialTLSContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
|
|
|
@ -18,12 +18,12 @@ type DNSOverUDP struct {
|
|||
}
|
||||
|
||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
|
||||
return DNSOverUDP{dialer: dialer, address: address}
|
||||
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
|
||||
return &DNSOverUDP{dialer: dialer, address: address}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -47,18 +47,23 @@ func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error)
|
|||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467
|
||||
func (t DNSOverUDP) RequiresPadding() bool {
|
||||
func (t *DNSOverUDP) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverUDP) Network() string {
|
||||
func (t *DNSOverUDP) Network() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverUDP) Address() string {
|
||||
func (t *DNSOverUDP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverUDP{}
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverUDP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverUDP{}
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
package resolver_test
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
|
||||
txp := NewDNSOverUDP(FakeDialer{Err: mocked}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -24,9 +22,9 @@ func TestDNSOverUDPDialFailure(t *testing.T) {
|
|||
|
||||
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
|
@ -42,9 +40,9 @@ func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
|||
|
||||
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{
|
||||
WriteError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
|
@ -60,9 +58,9 @@ func TestDNSOverUDPWriteFailure(t *testing.T) {
|
|||
|
||||
func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
|
@ -78,9 +76,9 @@ func TestDNSOverUDPReadFailure(t *testing.T) {
|
|||
|
||||
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||
const expected = 17
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{ReadData: make([]byte, 17)},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
|
@ -94,7 +92,7 @@ func TestDNSOverUDPReadSuccess(t *testing.T) {
|
|||
|
||||
func TestDNSOverUDPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
|
||||
txp := NewDNSOverUDP(&net.Dialer{}, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestEmitterTransportSuccess(t *testing.T) {
|
||||
|
@ -107,10 +108,12 @@ func TestEmitterResolverFailure(t *testing.T) {
|
|||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
|
||||
resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
&resolver.DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
},
|
||||
URL: "https://dns.google.com/",
|
||||
},
|
||||
)}
|
||||
|
|
|
@ -93,6 +93,10 @@ func (ft FakeTransport) Network() string {
|
|||
return "fake"
|
||||
}
|
||||
|
||||
func (fk FakeTransport) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
type FakeEncoder struct {
|
||||
Data []byte
|
||||
Err error
|
||||
|
|
|
@ -22,6 +22,9 @@ type RoundTripper interface {
|
|||
|
||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||
Address() string
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// SerialResolver is a resolver that first issues an A query and then
|
||||
|
@ -81,7 +84,7 @@ func (r SerialResolver) roundTripWithRetry(
|
|||
}
|
||||
errorslist = append(errorslist, err)
|
||||
var operr *net.OpError
|
||||
if errors.As(err, &operr) == false || operr.Timeout() == false {
|
||||
if !errors.As(err, &operr) || !operr.Timeout() {
|
||||
// The first error is the one that is most likely to be caused
|
||||
// by the network. Subsequent errors are more likely to be caused
|
||||
// by context deadlines. So, the first error is attached to an
|
||||
|
|
|
@ -17,3 +17,20 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
func (txp *HTTPTransport) CloseIdleConnections() {
|
||||
txp.MockCloseIdleConnections()
|
||||
}
|
||||
|
||||
// HTTPClient allows mocking an http.Client.
|
||||
type HTTPClient struct {
|
||||
MockDo func(req *http.Request) (*http.Response, error)
|
||||
|
||||
MockCloseIdleConnections func()
|
||||
}
|
||||
|
||||
// Do calls MockDo.
|
||||
func (txp *HTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return txp.MockDo(req)
|
||||
}
|
||||
|
||||
// CloseIdleConnections calls MockCloseIdleConnections.
|
||||
func (txp *HTTPClient) CloseIdleConnections() {
|
||||
txp.MockCloseIdleConnections()
|
||||
}
|
||||
|
|
|
@ -38,3 +38,34 @@ func TestHTTPTransport(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPClient(t *testing.T) {
|
||||
t.Run("Do", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
clnt := &HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := clnt.Do(&http.Request{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
called := &atomicx.Int64{}
|
||||
clnt := &HTTPClient{
|
||||
MockCloseIdleConnections: func() {
|
||||
called.Add(1)
|
||||
},
|
||||
}
|
||||
clnt.CloseIdleConnections()
|
||||
if called.Load() != 1 {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user