refactor(netx/resolver): add CloseIdleConnections to RoundTripper (#501)

While there, also change to pointer receiver and use internal
testing for what are clearly unit tests.

Part of https://github.com/ooni/probe/issues/1591.
This commit is contained in:
Simone Basso 2021-09-09 20:49:12 +02:00 committed by GitHub
parent 5ab3c3b689
commit 1eb9e8c9b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 203 additions and 115 deletions

View File

@ -124,7 +124,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
if !ok {
t.Fatal("not the DNS transport we expected")
}
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
if !ok {
t.Fatal("not the DNS transport we expected")
}
@ -200,7 +200,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
if !ok {
t.Fatal("not the DNS transport we expected")
}
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
if !ok {
t.Fatal("not the DNS transport we expected")
}
@ -276,7 +276,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
if !ok {
t.Fatal("not the DNS transport we expected")
}
dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS)
dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS)
if !ok {
t.Fatal("not the DNS transport we expected")
}
@ -352,7 +352,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
if !ok {
t.Fatal("not the DNS transport we expected")
}
udptxp, ok := stxp.RoundTripper.(resolver.DNSOverUDP)
udptxp, ok := stxp.RoundTripper.(*resolver.DNSOverUDP)
if !ok {
t.Fatal("not the DNS transport we expected")
}

View File

@ -680,7 +680,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
t.Fatal("not the transport we expected")
}
dnsclient.CloseIdleConnections()
@ -696,7 +696,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
t.Fatal("not the transport we expected")
}
dnsclient.CloseIdleConnections()
@ -712,7 +712,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok {
if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok {
t.Fatal("not the transport we expected")
}
dnsclient.CloseIdleConnections()
@ -733,7 +733,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
if !ok {
t.Fatal("not the transport we expected")
}
if _, ok := txp.RoundTripper.(resolver.DNSOverHTTPS); !ok {
if _, ok := txp.RoundTripper.(*resolver.DNSOverHTTPS); !ok {
t.Fatal("not the transport we expected")
}
dnsclient.CloseIdleConnections()
@ -749,7 +749,7 @@ func TestNewDNSClientUDP(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
if _, ok := r.Transport().(resolver.DNSOverUDP); !ok {
if _, ok := r.Transport().(*resolver.DNSOverUDP); !ok {
t.Fatal("not the transport we expected")
}
dnsclient.CloseIdleConnections()
@ -770,7 +770,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
if !ok {
t.Fatal("not the transport we expected")
}
if _, ok := txp.RoundTripper.(resolver.DNSOverUDP); !ok {
if _, ok := txp.RoundTripper.(*resolver.DNSOverUDP); !ok {
t.Fatal("not the transport we expected")
}
dnsclient.CloseIdleConnections()
@ -786,7 +786,7 @@ func TestNewDNSClientTCP(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
txp, ok := r.Transport().(resolver.DNSOverTCP)
txp, ok := r.Transport().(*resolver.DNSOverTCP)
if !ok {
t.Fatal("not the transport we expected")
}
@ -811,7 +811,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
if !ok {
t.Fatal("not the transport we expected")
}
dotcp, ok := txp.RoundTripper.(resolver.DNSOverTCP)
dotcp, ok := txp.RoundTripper.(*resolver.DNSOverTCP)
if !ok {
t.Fatal("not the transport we expected")
}
@ -831,7 +831,7 @@ func TestNewDNSClientDoT(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
txp, ok := r.Transport().(resolver.DNSOverTCP)
txp, ok := r.Transport().(*resolver.DNSOverTCP)
if !ok {
t.Fatal("not the transport we expected")
}
@ -856,7 +856,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
if !ok {
t.Fatal("not the transport we expected")
}
dotls, ok := txp.RoundTripper.(resolver.DNSOverTCP)
dotls, ok := txp.RoundTripper.(*resolver.DNSOverTCP)
if !ok {
t.Fatal("not the transport we expected")
}

View File

@ -11,28 +11,35 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
)
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
CloseIdleConnections()
}
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
// an HTTP/HTTPS channel provided by URL using the Do function.
type DNSOverHTTPS struct {
Do func(req *http.Request) (*http.Response, error)
Client HTTPClient
URL string
HostOverride string
}
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
// specified http.Client and URL, as a convenience.
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
func NewDNSOverHTTPS(client *http.Client, URL string) *DNSOverHTTPS {
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
}
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
// it's creating a resolver where we use the specified host.
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
func NewDNSOverHTTPSWithHostOverride(
client *http.Client, URL, hostOverride string) *DNSOverHTTPS {
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
@ -43,7 +50,7 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro
req.Header.Set("user-agent", httpheader.UserAgent())
req.Header.Set("content-type", "application/dns-message")
var resp *http.Response
resp, err = t.Do(req.WithContext(ctx))
resp, err = t.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -60,18 +67,23 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro
}
// RequiresPadding returns true for DoH according to RFC8467
func (t DNSOverHTTPS) RequiresPadding() bool {
func (t *DNSOverHTTPS) RequiresPadding() bool {
return true
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverHTTPS) Network() string {
func (t *DNSOverHTTPS) Network() string {
return "doh"
}
// Address returns the upstream server address.
func (t DNSOverHTTPS) Address() string {
func (t *DNSOverHTTPS) Address() string {
return t.URL
}
var _ RoundTripper = DNSOverHTTPS{}
// CloseIdleConnections closes idle connections.
func (t *DNSOverHTTPS) CloseIdleConnections() {
t.Client.CloseIdleConnections()
}
var _ RoundTripper = &DNSOverHTTPS{}

View File

@ -1,4 +1,4 @@
package resolver_test
package resolver
import (
"bytes"
@ -10,12 +10,12 @@ import (
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
const invalidURL = "\t"
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("expected an error here")
@ -27,9 +27,11 @@ func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
expected := errors.New("mocked error")
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return nil, expected
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
@ -43,12 +45,14 @@ func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
}
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("")),
}, nil
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("")),
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
@ -62,12 +66,14 @@ func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
}
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("")),
}, nil
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("")),
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
@ -82,15 +88,17 @@ func TestDNSOverHTTPSMissingContentType(t *testing.T) {
func TestDNSOverHTTPSSuccess(t *testing.T) {
body := []byte("AAA")
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
@ -105,7 +113,7 @@ func TestDNSOverHTTPSSuccess(t *testing.T) {
func TestDNSOverHTTPTransportOK(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query"
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
if txp.Network() != "doh" {
t.Fatal("invalid network")
}
@ -120,10 +128,12 @@ func TestDNSOverHTTPTransportOK(t *testing.T) {
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
expected := errors.New("mocked error")
var correct bool
txp := resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
return nil, expected
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
@ -144,10 +154,12 @@ func TestDNSOverHTTPSHostOverride(t *testing.T) {
expected := errors.New("mocked error")
hostOverride := "test.com"
txp := resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride
return nil, expected
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride,

View File

@ -26,8 +26,8 @@ type DNSOverTCP struct {
}
// NewDNSOverTCP creates a new DNSOverTCP transport.
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
return DNSOverTCP{
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
return &DNSOverTCP{
dial: dial,
address: address,
network: "tcp",
@ -36,8 +36,8 @@ func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
}
// NewDNSOverTLS creates a new DNSOverTLS transport.
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
return DNSOverTCP{
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
return &DNSOverTCP{
dial: dial,
address: address,
network: "dot",
@ -46,7 +46,7 @@ func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
if len(query) > math.MaxUint16 {
return nil, errors.New("query too long")
}
@ -80,18 +80,23 @@ func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error)
// RequiresPadding returns true for DoT and false for TCP
// according to RFC8467.
func (t DNSOverTCP) RequiresPadding() bool {
func (t *DNSOverTCP) RequiresPadding() bool {
return t.requiresPadding
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverTCP) Network() string {
func (t *DNSOverTCP) Network() string {
return t.network
}
// Address returns the upstream server address.
func (t DNSOverTCP) Address() string {
func (t *DNSOverTCP) Address() string {
return t.address
}
var _ RoundTripper = DNSOverTCP{}
// CloseIdleConnections closes idle connections.
func (t *DNSOverTCP) CloseIdleConnections() {
// nothing to do
}
var _ RoundTripper = &DNSOverTCP{}

View File

@ -1,17 +1,15 @@
package resolver_test
package resolver
import (
"context"
"errors"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
if err == nil {
t.Fatal("expected an error here")
@ -24,8 +22,8 @@ func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Err: mocked}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
fakedialer := FakeDialer{Err: mocked}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -38,10 +36,10 @@ func TestDNSOverTCPTransportDialFailure(t *testing.T) {
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
fakedialer := FakeDialer{Conn: &FakeConn{
SetDeadlineError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -54,10 +52,10 @@ func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
fakedialer := FakeDialer{Conn: &FakeConn{
WriteError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -70,10 +68,10 @@ func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
fakedialer := FakeDialer{Conn: &FakeConn{
ReadError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -86,11 +84,11 @@ func TestDNSOverTCPTransportReadFailure(t *testing.T) {
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
fakedialer := FakeDialer{Conn: &FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(2)},
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -103,11 +101,11 @@ func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
func TestDNSOverTCPTransportAllGood(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
fakedialer := FakeDialer{Conn: &FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(1), byte(1)},
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil {
t.Fatal(err)
@ -119,7 +117,7 @@ func TestDNSOverTCPTransportAllGood(t *testing.T) {
func TestDNSOverTCPTransportOK(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
@ -133,7 +131,7 @@ func TestDNSOverTCPTransportOK(t *testing.T) {
func TestDNSOverTLSTransportOK(t *testing.T) {
const address = "9.9.9.9:853"
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
txp := NewDNSOverTLS(DialTLSContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}

View File

@ -18,12 +18,12 @@ type DNSOverUDP struct {
}
// NewDNSOverUDP creates a DNSOverUDP instance.
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
return DNSOverUDP{dialer: dialer, address: address}
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
return &DNSOverUDP{dialer: dialer, address: address}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil {
return nil, err
@ -47,18 +47,23 @@ func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error)
}
// RequiresPadding returns false for UDP according to RFC8467
func (t DNSOverUDP) RequiresPadding() bool {
func (t *DNSOverUDP) RequiresPadding() bool {
return false
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverUDP) Network() string {
func (t *DNSOverUDP) Network() string {
return "udp"
}
// Address returns the upstream server address.
func (t DNSOverUDP) Address() string {
func (t *DNSOverUDP) Address() string {
return t.address
}
var _ RoundTripper = DNSOverUDP{}
// CloseIdleConnections closes idle connections.
func (t *DNSOverUDP) CloseIdleConnections() {
// nothing to do
}
var _ RoundTripper = &DNSOverUDP{}

View File

@ -1,18 +1,16 @@
package resolver_test
package resolver
import (
"context"
"errors"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverUDPDialFailure(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
txp := NewDNSOverUDP(FakeDialer{Err: mocked}, address)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
@ -24,9 +22,9 @@ func TestDNSOverUDPDialFailure(t *testing.T) {
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
txp := NewDNSOverUDP(
FakeDialer{
Conn: &FakeConn{
SetDeadlineError: mocked,
},
}, "9.9.9.9:53",
@ -42,9 +40,9 @@ func TestDNSOverUDPSetDeadlineError(t *testing.T) {
func TestDNSOverUDPWriteFailure(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
txp := NewDNSOverUDP(
FakeDialer{
Conn: &FakeConn{
WriteError: mocked,
},
}, "9.9.9.9:53",
@ -60,9 +58,9 @@ func TestDNSOverUDPWriteFailure(t *testing.T) {
func TestDNSOverUDPReadFailure(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
txp := NewDNSOverUDP(
FakeDialer{
Conn: &FakeConn{
ReadError: mocked,
},
}, "9.9.9.9:53",
@ -78,9 +76,9 @@ func TestDNSOverUDPReadFailure(t *testing.T) {
func TestDNSOverUDPReadSuccess(t *testing.T) {
const expected = 17
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
txp := NewDNSOverUDP(
FakeDialer{
Conn: &FakeConn{ReadData: make([]byte, 17)},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
@ -94,7 +92,7 @@ func TestDNSOverUDPReadSuccess(t *testing.T) {
func TestDNSOverUDPTransportOK(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
txp := NewDNSOverUDP(&net.Dialer{}, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}

View File

@ -13,6 +13,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
func TestEmitterTransportSuccess(t *testing.T) {
@ -107,9 +108,11 @@ func TestEmitterResolverFailure(t *testing.T) {
}
ctx = modelx.WithMeasurementRoot(ctx, root)
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
return nil, io.EOF
&resolver.DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
return nil, io.EOF
},
},
URL: "https://dns.google.com/",
},

View File

@ -93,6 +93,10 @@ func (ft FakeTransport) Network() string {
return "fake"
}
func (fk FakeTransport) CloseIdleConnections() {
// nothing to do
}
type FakeEncoder struct {
Data []byte
Err error

View File

@ -22,6 +22,9 @@ type RoundTripper interface {
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
Address() string
// CloseIdleConnections closes idle connections.
CloseIdleConnections()
}
// SerialResolver is a resolver that first issues an A query and then
@ -81,7 +84,7 @@ func (r SerialResolver) roundTripWithRetry(
}
errorslist = append(errorslist, err)
var operr *net.OpError
if errors.As(err, &operr) == false || operr.Timeout() == false {
if !errors.As(err, &operr) || !operr.Timeout() {
// The first error is the one that is most likely to be caused
// by the network. Subsequent errors are more likely to be caused
// by context deadlines. So, the first error is attached to an

View File

@ -17,3 +17,20 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
func (txp *HTTPTransport) CloseIdleConnections() {
txp.MockCloseIdleConnections()
}
// HTTPClient allows mocking an http.Client.
type HTTPClient struct {
MockDo func(req *http.Request) (*http.Response, error)
MockCloseIdleConnections func()
}
// Do calls MockDo.
func (txp *HTTPClient) Do(req *http.Request) (*http.Response, error) {
return txp.MockDo(req)
}
// CloseIdleConnections calls MockCloseIdleConnections.
func (txp *HTTPClient) CloseIdleConnections() {
txp.MockCloseIdleConnections()
}

View File

@ -38,3 +38,34 @@ func TestHTTPTransport(t *testing.T) {
}
})
}
func TestHTTPClient(t *testing.T) {
t.Run("Do", func(t *testing.T) {
expected := errors.New("mocked error")
clnt := &HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
return nil, expected
},
}
resp, err := clnt.Do(&http.Request{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
called := &atomicx.Int64{}
clnt := &HTTPClient{
MockCloseIdleConnections: func() {
called.Add(1)
},
}
clnt.CloseIdleConnections()
if called.Load() != 1 {
t.Fatal("not called")
}
})
}