refactor(netx): move dns transports in netxlite/dnsx (#503)
While there, modernize the way in which we run tests to avoid depending on the fake files scattered around the tree and to use some well defined mock structures instead. Part of https://github.com/ooni/probe/issues/1591
This commit is contained in:
@@ -1,54 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The Decoder decodes a DNS reply into A or AAAA entries. It will use the
|
||||
// provided qtype and only look for mathing entries. It will return error if
|
||||
// there are no entries for the requested qtype inside the reply.
|
||||
type Decoder interface {
|
||||
Decode(qtype uint16, data []byte) ([]string, error)
|
||||
}
|
||||
|
||||
// MiekgDecoder uses github.com/miekg/dns to implement the Decoder.
|
||||
type MiekgDecoder struct{}
|
||||
|
||||
// Decode implements Decoder.Decode.
|
||||
func (d *MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
||||
reply := new(dns.Msg)
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bassosimone): map more errors to net.DNSError names
|
||||
switch reply.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
case dns.RcodeNameError:
|
||||
return nil, errors.New("ooniresolver: no such host")
|
||||
default:
|
||||
return nil, errors.New("ooniresolver: query failed")
|
||||
}
|
||||
var addrs []string
|
||||
for _, answer := range reply.Answer {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if rra, ok := answer.(*dns.AAAA); ok {
|
||||
ip := rra.AAAA
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(addrs) <= 0 {
|
||||
return nil, errors.New("ooniresolver: no response returned")
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
var _ Decoder = &MiekgDecoder{}
|
||||
@@ -1,112 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDecoderUnpackError(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNXDOMAIN(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeNameError))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderOtherError(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeRefused))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNoAddress(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, GenReplySuccess(t, dns.TypeA))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeA(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if data[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeAAAA(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if data[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAReply(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
|
||||
)
|
||||
|
||||
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||
type DNSOverHTTPS struct {
|
||||
Client HTTPClient
|
||||
URL string
|
||||
HostOverride string
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||
// specified http.Client and URL, as a convenience.
|
||||
func NewDNSOverHTTPS(client *http.Client, URL string) *DNSOverHTTPS {
|
||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||
// it's creating a resolver where we use the specified host.
|
||||
func NewDNSOverHTTPSWithHostOverride(
|
||||
client *http.Client, URL, hostOverride string) *DNSOverHTTPS {
|
||||
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = t.HostOverride
|
||||
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
// TODO(bassosimone): we should map the status code to a
|
||||
// proper Error in the DNS context.
|
||||
return nil, errors.New("doh: server returned error")
|
||||
}
|
||||
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||
return nil, errors.New("doh: invalid content-type")
|
||||
}
|
||||
return iox.ReadAllContext(ctx, resp.Body)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467
|
||||
func (t *DNSOverHTTPS) RequiresPadding() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverHTTPS) Network() string {
|
||||
return "doh"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverHTTPS) Address() string {
|
||||
return t.URL
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverHTTPS) CloseIdleConnections() {
|
||||
t.Client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverHTTPS{}
|
||||
@@ -1,177 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||
body := []byte("AAA")
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, body) {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
if txp.Network() != "doh" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("should require padding")
|
||||
}
|
||||
if txp.Address() != queryURL {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var correct bool
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct user agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHostOverride(t *testing.T) {
|
||||
var correct bool
|
||||
expected := errors.New("mocked error")
|
||||
|
||||
hostOverride := "test.com"
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Host == hostOverride
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct host override")
|
||||
}
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialContextFunc is a generic function for dialing a connection.
|
||||
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
|
||||
// and NewDNSOverTLS to create specific instances that use plaintext
|
||||
// queries or encrypted queries over TLS.
|
||||
//
|
||||
// As a known bug, this implementation always creates a new connection
|
||||
// for each incoming query, thus increasing the response delay.
|
||||
type DNSOverTCP struct {
|
||||
dial DialContextFunc
|
||||
address string
|
||||
network string
|
||||
requiresPadding bool
|
||||
}
|
||||
|
||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
requiresPadding: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "dot",
|
||||
requiresPadding: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
}
|
||||
conn, err := t.dial(ctx, "tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write request
|
||||
buf := []byte{byte(len(query) >> 8)}
|
||||
buf = append(buf, byte(len(query)))
|
||||
buf = append(buf, query...)
|
||||
if _, err = conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
header := make([]byte, 2)
|
||||
if _, err = io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := int(header[0])<<8 | int(header[1])
|
||||
reply := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
// according to RFC8467.
|
||||
func (t *DNSOverTCP) RequiresPadding() bool {
|
||||
return t.requiresPadding
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverTCP) Network() string {
|
||||
return t.network
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverTCP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverTCP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverTCP{}
|
||||
@@ -1,144 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := FakeDialer{Err: mocked}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
}}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
WriteError: mocked,
|
||||
}}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
}}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(2)},
|
||||
}}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := FakeDialer{Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(1), byte(1)},
|
||||
}}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 || reply[0] != 1 {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "tcp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTLSTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := NewDNSOverTLS(DialTLSContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "dot" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer is the network dialer interface assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// DNSOverUDP is a DNS over UDP RoundTripper.
|
||||
type DNSOverUDP struct {
|
||||
dialer Dialer
|
||||
address string
|
||||
}
|
||||
|
||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
|
||||
return &DNSOverUDP{dialer: dialer, address: address}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
// Use five seconds timeout like Bionic does. See
|
||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply := make([]byte, 1<<17)
|
||||
var n int
|
||||
n, err = conn.Read(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply[:n], nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467
|
||||
func (t *DNSOverUDP) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverUDP) Network() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverUDP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverUDP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverUDP{}
|
||||
@@ -1,105 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDP(FakeDialer{Err: mocked}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{
|
||||
WriteError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{
|
||||
ReadError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||
const expected = 17
|
||||
txp := NewDNSOverUDP(
|
||||
FakeDialer{
|
||||
Conn: &FakeConn{ReadData: make([]byte, 17)},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != expected {
|
||||
t.Fatal("expected non nil data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDP(&net.Dialer{}, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "udp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// The Encoder encodes DNS queries to bytes
|
||||
type Encoder interface {
|
||||
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// MiekgEncoder uses github.com/miekg/dns to implement the Encoder.
|
||||
type MiekgEncoder struct{}
|
||||
|
||||
const (
|
||||
// PaddingDesiredBlockSize is the size that the padded query should be multiple of
|
||||
PaddingDesiredBlockSize = 128
|
||||
|
||||
// EDNS0MaxResponseSize is the maximum response size for EDNS0
|
||||
EDNS0MaxResponseSize = 4096
|
||||
|
||||
// DNSSECEnabled turns on support for DNSSEC when using EDNS0
|
||||
DNSSECEnabled = true
|
||||
)
|
||||
|
||||
// Encode implements Encoder.Encode
|
||||
func (e *MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
if padding {
|
||||
query.SetEdns0(EDNS0MaxResponseSize, DNSSECEnabled)
|
||||
// Clients SHOULD pad queries to the closest multiple of
|
||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||
// length by the size of the option (i.e. 4 octets). The
|
||||
// cast to uint is necessary to make the modulus operation
|
||||
// work as intended when the desiredBlockSize is smaller
|
||||
// than (query.Len()+4) ¯\_(ツ)_/¯.
|
||||
remainder := (PaddingDesiredBlockSize - uint(query.Len()+4)) % PaddingDesiredBlockSize
|
||||
opt := new(dns.EDNS0_PADDING)
|
||||
opt.Padding = make([]byte, remainder)
|
||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||
}
|
||||
return query.Pack()
|
||||
}
|
||||
|
||||
var _ Encoder = &MiekgEncoder{}
|
||||
@@ -1,98 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestEncoderEncodeA(t *testing.T) {
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func TestEncoderEncodeAAAA(t *testing.T) {
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func validate(t *testing.T, data []byte, qtype byte) {
|
||||
// skipping over the query ID
|
||||
if data[2] != 1 {
|
||||
t.Fatal("FLAGS should only have RD set")
|
||||
}
|
||||
if data[3] != 0 {
|
||||
t.Fatal("RA|Z|Rcode should be zero")
|
||||
}
|
||||
if data[4] != 0 || data[5] != 1 {
|
||||
t.Fatal("QCOUNT high should be one")
|
||||
}
|
||||
if data[6] != 0 || data[7] != 0 {
|
||||
t.Fatal("ANCOUNT should be zero")
|
||||
}
|
||||
if data[8] != 0 || data[9] != 0 {
|
||||
t.Fatal("NSCOUNT should be zero")
|
||||
}
|
||||
if data[10] != 0 || data[11] != 0 {
|
||||
t.Fatal("ARCOUNT should be zero")
|
||||
}
|
||||
t.Log(data[12])
|
||||
if data[12] != 1 || data[13] != byte('x') {
|
||||
t.Fatal("The name does not contain 1:x")
|
||||
}
|
||||
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
|
||||
t.Fatal("The name does not contain 3:org")
|
||||
}
|
||||
if data[18] != 0 {
|
||||
t.Fatal("The name does not terminate where expected")
|
||||
}
|
||||
if data[19] != 0 && data[20] != qtype {
|
||||
t.Fatal("The query is not for the expected type")
|
||||
}
|
||||
if data[21] != 0 && data[22] != 1 {
|
||||
t.Fatal("The query is not IN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderPadding(t *testing.T) {
|
||||
// The purpose of this unit test is to make sure that for a wide
|
||||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
// us to generate such name and we are not going to send
|
||||
// it on the wire. Also, we check below that the query that
|
||||
// we generate is long enough, so we should be good.
|
||||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||
dns.TypeA, padding,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
for domainlen := 1; domainlen <= 4000; domainlen++ {
|
||||
vanillalen := getquerylen(domainlen, false)
|
||||
paddedlen := getquerylen(domainlen, true)
|
||||
if vanillalen < domainlen {
|
||||
t.Fatal("vanillalen is smaller than domainlen")
|
||||
}
|
||||
if (paddedlen % PaddingDesiredBlockSize) != 0 {
|
||||
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
|
||||
}
|
||||
if paddedlen < vanillalen {
|
||||
t.Fatal("paddedlen is smaller than vanillalen")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package resolver
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx"
|
||||
|
||||
// Variables that other packages expect to find here but have been
|
||||
// moved into the internal/netxlite/dnsx package.
|
||||
var (
|
||||
NewSerialResolver = dnsx.NewSerialResolver
|
||||
NewDNSOverUDP = dnsx.NewDNSOverUDP
|
||||
NewDNSOverTCP = dnsx.NewDNSOverTCP
|
||||
NewDNSOverTLS = dnsx.NewDNSOverTLS
|
||||
NewDNSOverHTTPS = dnsx.NewDNSOverHTTPS
|
||||
NewDNSOverHTTPSWithHostOverride = dnsx.NewDNSOverHTTPSWithHostOverride
|
||||
)
|
||||
|
||||
// Types that other packages expect to find here but have been
|
||||
// moved into the internal/netxlite/dnsx package.
|
||||
type (
|
||||
DNSOverHTTPS = dnsx.DNSOverHTTPS
|
||||
DNSOverTCP = dnsx.DNSOverTCP
|
||||
DNSOverUDP = dnsx.DNSOverUDP
|
||||
MiekgEncoder = dnsx.MiekgEncoder
|
||||
MiekgDecoder = dnsx.MiekgDecoder
|
||||
RoundTripper = dnsx.RoundTripper
|
||||
SerialResolver = dnsx.SerialResolver
|
||||
Dialer = dnsx.Dialer
|
||||
DialContextFunc = dnsx.DialContextFunc
|
||||
)
|
||||
@@ -1,121 +0,0 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
)
|
||||
|
||||
// RoundTripper represents an abstract DNS transport.
|
||||
type RoundTripper interface {
|
||||
// RoundTrip sends a DNS query and receives the reply.
|
||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
// RequiresPadding return true for DoH and DoT according to RFC8467
|
||||
RequiresPadding() bool
|
||||
|
||||
// Network is the network of the round tripper (e.g. "dot")
|
||||
Network() string
|
||||
|
||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||
Address() string
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// SerialResolver is a resolver that first issues an A query and then
|
||||
// issues an AAAA query for the requested domain.
|
||||
type SerialResolver struct {
|
||||
Encoder Encoder
|
||||
Decoder Decoder
|
||||
NumTimeouts *atomicx.Int64
|
||||
Txp RoundTripper
|
||||
}
|
||||
|
||||
// NewSerialResolver creates a new OONI Resolver instance.
|
||||
func NewSerialResolver(t RoundTripper) *SerialResolver {
|
||||
return &SerialResolver{
|
||||
Encoder: &MiekgEncoder{},
|
||||
Decoder: &MiekgDecoder{},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the transport being used.
|
||||
func (r *SerialResolver) Transport() RoundTripper {
|
||||
return r.Txp
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r *SerialResolver) Network() string {
|
||||
return r.Txp.Network()
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (r *SerialResolver) Address() string {
|
||||
return r.Txp.Address()
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (r *SerialResolver) CloseIdleConnections() {
|
||||
r.Txp.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var addrs []string
|
||||
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
|
||||
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
|
||||
if errA != nil && errAAAA != nil {
|
||||
return nil, errA
|
||||
}
|
||||
addrs = append(addrs, addrsA...)
|
||||
addrs = append(addrs, addrsAAAA...)
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (r *SerialResolver) roundTripWithRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
var errorslist []error
|
||||
for i := 0; i < 3; i++ {
|
||||
replies, err := r.roundTrip(ctx, hostname, qtype)
|
||||
if err == nil {
|
||||
return replies, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
var operr *net.OpError
|
||||
if !errors.As(err, &operr) || !operr.Timeout() {
|
||||
// The first error is the one that is most likely to be caused
|
||||
// by the network. Subsequent errors are more likely to be caused
|
||||
// by context deadlines. So, the first error is attached to an
|
||||
// operation, while subsequent errors may possibly not be. If
|
||||
// so, the resulting failing operation is not correct.
|
||||
break
|
||||
}
|
||||
r.NumTimeouts.Add(1)
|
||||
}
|
||||
// bugfix: we MUST return one of the errors otherwise we confuse the
|
||||
// mechanism in errwrap that classifies the root cause operation, since
|
||||
// it would not be able to find a child with a major operation error
|
||||
return nil, errorslist[0]
|
||||
}
|
||||
|
||||
func (r *SerialResolver) roundTrip(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.Decode(qtype, replydata)
|
||||
}
|
||||
|
||||
var _ Resolver = &SerialResolver{}
|
||||
@@ -1,111 +0,0 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestOONIGettingTransport(t *testing.T) {
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853")
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
rtx := r.Transport()
|
||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if r.Network() != rtx.Network() {
|
||||
t.Fatal("invalid network seen from the resolver")
|
||||
}
|
||||
if r.Address() != rtx.Address() {
|
||||
t.Fatal("invalid address seen from the resolver")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIEncodeError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853")
|
||||
r := resolver.SerialResolver{Encoder: resolver.FakeEncoder{Err: mocked}, Txp: txp}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIRoundTripError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.FakeTransport{Err: mocked}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithEmptyReply(t *testing.T) {
|
||||
txp := resolver.FakeTransport{Data: resolver.GenReplySuccess(t, dns.TypeA)}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithAReply(t *testing.T) {
|
||||
txp := resolver.FakeTransport{
|
||||
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
|
||||
}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithAAAAReply(t *testing.T) {
|
||||
txp := resolver.FakeTransport{
|
||||
Data: resolver.GenReplySuccess(t, dns.TypeAAAA, "::1"),
|
||||
}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "::1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithTimeout(t *testing.T) {
|
||||
txp := resolver.FakeTransport{
|
||||
Err: &net.OpError{Err: syscall.ETIMEDOUT, Op: "dial"},
|
||||
}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, syscall.ETIMEDOUT) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
if r.NumTimeouts.Load() <= 0 {
|
||||
t.Fatal("we didn't actually take the timeouts")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user