refactor(netx): move dns transports in netxlite/dnsx (#503)
While there, modernize the way in which we run tests to avoid depending on the fake files scattered around the tree and to use some well defined mock structures instead. Part of https://github.com/ooni/probe/issues/1591
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The Decoder decodes a DNS reply into A or AAAA entries. It will use the
|
||||
// provided qtype and only look for mathing entries. It will return error if
|
||||
// there are no entries for the requested qtype inside the reply.
|
||||
type Decoder interface {
|
||||
Decode(qtype uint16, data []byte) ([]string, error)
|
||||
}
|
||||
|
||||
// MiekgDecoder uses github.com/miekg/dns to implement the Decoder.
|
||||
type MiekgDecoder struct{}
|
||||
|
||||
// Decode implements Decoder.Decode.
|
||||
func (d *MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
||||
reply := new(dns.Msg)
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bassosimone): map more errors to net.DNSError names
|
||||
switch reply.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
case dns.RcodeNameError:
|
||||
return nil, errors.New("ooniresolver: no such host")
|
||||
default:
|
||||
return nil, errors.New("ooniresolver: query failed")
|
||||
}
|
||||
var addrs []string
|
||||
for _, answer := range reply.Answer {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if rra, ok := answer.(*dns.AAAA); ok {
|
||||
ip := rra.AAAA
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(addrs) <= 0 {
|
||||
return nil, errors.New("ooniresolver: no response returned")
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
var _ Decoder = &MiekgDecoder{}
|
||||
@@ -0,0 +1,181 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDecoderUnpackError(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNXDOMAIN(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, genReplyError(t, dns.RcodeNameError))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderOtherError(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, genReplyError(t, dns.RcodeRefused))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNoAddress(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, genReplySuccess(t, dns.TypeA))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeA(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if data[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeAAAA(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if data[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAReply(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func genReplyError(t *testing.T, code int) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetRcode(query, code)
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func genReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
for _, ip := range ips {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
reply.Answer = append(reply.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: net.ParseIP(ip),
|
||||
})
|
||||
case dns.TypeAAAA:
|
||||
reply.Answer = append(reply.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
AAAA: net.ParseIP(ip),
|
||||
})
|
||||
}
|
||||
}
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
|
||||
)
|
||||
|
||||
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||
type DNSOverHTTPS struct {
|
||||
Client HTTPClient
|
||||
URL string
|
||||
HostOverride string
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||
// specified http.Client and URL, as a convenience.
|
||||
func NewDNSOverHTTPS(client *http.Client, URL string) *DNSOverHTTPS {
|
||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||
// it's creating a resolver where we use the specified host.
|
||||
func NewDNSOverHTTPSWithHostOverride(
|
||||
client *http.Client, URL, hostOverride string) *DNSOverHTTPS {
|
||||
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = t.HostOverride
|
||||
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
// TODO(bassosimone): we should map the status code to a
|
||||
// proper Error in the DNS context.
|
||||
return nil, errors.New("doh: server returned error")
|
||||
}
|
||||
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||
return nil, errors.New("doh: invalid content-type")
|
||||
}
|
||||
return iox.ReadAllContext(ctx, resp.Body)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467
|
||||
func (t *DNSOverHTTPS) RequiresPadding() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverHTTPS) Network() string {
|
||||
return "doh"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverHTTPS) Address() string {
|
||||
return t.URL
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverHTTPS) CloseIdleConnections() {
|
||||
t.Client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverHTTPS{}
|
||||
@@ -0,0 +1,177 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||
body := []byte("AAA")
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, body) {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
if txp.Network() != "doh" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("should require padding")
|
||||
}
|
||||
if txp.Address() != queryURL {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var correct bool
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct user agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHostOverride(t *testing.T) {
|
||||
var correct bool
|
||||
expected := errors.New("mocked error")
|
||||
|
||||
hostOverride := "test.com"
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Host == hostOverride
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct host override")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialContextFunc is a generic function for dialing a connection.
|
||||
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
|
||||
// and NewDNSOverTLS to create specific instances that use plaintext
|
||||
// queries or encrypted queries over TLS.
|
||||
//
|
||||
// As a known bug, this implementation always creates a new connection
|
||||
// for each incoming query, thus increasing the response delay.
|
||||
type DNSOverTCP struct {
|
||||
dial DialContextFunc
|
||||
address string
|
||||
network string
|
||||
requiresPadding bool
|
||||
}
|
||||
|
||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
requiresPadding: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "dot",
|
||||
requiresPadding: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
}
|
||||
conn, err := t.dial(ctx, "tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write request
|
||||
buf := []byte{byte(len(query) >> 8)}
|
||||
buf = append(buf, byte(len(query)))
|
||||
buf = append(buf, query...)
|
||||
if _, err = conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
header := make([]byte, 2)
|
||||
if _, err = io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := int(header[0])<<8 | int(header[1])
|
||||
reply := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
// according to RFC8467.
|
||||
func (t *DNSOverTCP) RequiresPadding() bool {
|
||||
return t.requiresPadding
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverTCP) Network() string {
|
||||
return t.network
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverTCP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverTCP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverTCP{}
|
||||
@@ -0,0 +1,222 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
input := io.MultiReader(
|
||||
bytes.NewReader([]byte{byte(0), byte(2)}),
|
||||
&mocks.Reader{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
},
|
||||
)
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 || reply[0] != 1 {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "tcp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTLSTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "dot" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer is the network dialer interface assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// DNSOverUDP is a DNS over UDP RoundTripper.
|
||||
type DNSOverUDP struct {
|
||||
dialer Dialer
|
||||
address string
|
||||
}
|
||||
|
||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
|
||||
return &DNSOverUDP{dialer: dialer, address: address}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
// Use five seconds timeout like Bionic does. See
|
||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply := make([]byte, 1<<17)
|
||||
var n int
|
||||
n, err = conn.Read(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply[:n], nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467
|
||||
func (t *DNSOverUDP) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverUDP) Network() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverUDP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverUDP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ RoundTripper = &DNSOverUDP{}
|
||||
@@ -0,0 +1,157 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDP(&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||
const expected = 17
|
||||
input := bytes.NewReader(make([]byte, expected))
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != expected {
|
||||
t.Fatal("expected non nil data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDP(&net.Dialer{}, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "udp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package dnsx
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// The Encoder encodes DNS queries to bytes
|
||||
type Encoder interface {
|
||||
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// MiekgEncoder uses github.com/miekg/dns to implement the Encoder.
|
||||
type MiekgEncoder struct{}
|
||||
|
||||
const (
|
||||
// PaddingDesiredBlockSize is the size that the padded query should be multiple of
|
||||
PaddingDesiredBlockSize = 128
|
||||
|
||||
// EDNS0MaxResponseSize is the maximum response size for EDNS0
|
||||
EDNS0MaxResponseSize = 4096
|
||||
|
||||
// DNSSECEnabled turns on support for DNSSEC when using EDNS0
|
||||
DNSSECEnabled = true
|
||||
)
|
||||
|
||||
// Encode implements Encoder.Encode
|
||||
func (e *MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
if padding {
|
||||
query.SetEdns0(EDNS0MaxResponseSize, DNSSECEnabled)
|
||||
// Clients SHOULD pad queries to the closest multiple of
|
||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||
// length by the size of the option (i.e. 4 octets). The
|
||||
// cast to uint is necessary to make the modulus operation
|
||||
// work as intended when the desiredBlockSize is smaller
|
||||
// than (query.Len()+4) ¯\_(ツ)_/¯.
|
||||
remainder := (PaddingDesiredBlockSize - uint(query.Len()+4)) % PaddingDesiredBlockSize
|
||||
opt := new(dns.EDNS0_PADDING)
|
||||
opt.Padding = make([]byte, remainder)
|
||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||
}
|
||||
return query.Pack()
|
||||
}
|
||||
|
||||
var _ Encoder = &MiekgEncoder{}
|
||||
@@ -0,0 +1,98 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestEncoderEncodeA(t *testing.T) {
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func TestEncoderEncodeAAAA(t *testing.T) {
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func validate(t *testing.T, data []byte, qtype byte) {
|
||||
// skipping over the query ID
|
||||
if data[2] != 1 {
|
||||
t.Fatal("FLAGS should only have RD set")
|
||||
}
|
||||
if data[3] != 0 {
|
||||
t.Fatal("RA|Z|Rcode should be zero")
|
||||
}
|
||||
if data[4] != 0 || data[5] != 1 {
|
||||
t.Fatal("QCOUNT high should be one")
|
||||
}
|
||||
if data[6] != 0 || data[7] != 0 {
|
||||
t.Fatal("ANCOUNT should be zero")
|
||||
}
|
||||
if data[8] != 0 || data[9] != 0 {
|
||||
t.Fatal("NSCOUNT should be zero")
|
||||
}
|
||||
if data[10] != 0 || data[11] != 0 {
|
||||
t.Fatal("ARCOUNT should be zero")
|
||||
}
|
||||
t.Log(data[12])
|
||||
if data[12] != 1 || data[13] != byte('x') {
|
||||
t.Fatal("The name does not contain 1:x")
|
||||
}
|
||||
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
|
||||
t.Fatal("The name does not contain 3:org")
|
||||
}
|
||||
if data[18] != 0 {
|
||||
t.Fatal("The name does not terminate where expected")
|
||||
}
|
||||
if data[19] != 0 && data[20] != qtype {
|
||||
t.Fatal("The query is not for the expected type")
|
||||
}
|
||||
if data[21] != 0 && data[22] != 1 {
|
||||
t.Fatal("The query is not IN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderPadding(t *testing.T) {
|
||||
// The purpose of this unit test is to make sure that for a wide
|
||||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
// us to generate such name and we are not going to send
|
||||
// it on the wire. Also, we check below that the query that
|
||||
// we generate is long enough, so we should be good.
|
||||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||
dns.TypeA, padding,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
for domainlen := 1; domainlen <= 4000; domainlen++ {
|
||||
vanillalen := getquerylen(domainlen, false)
|
||||
paddedlen := getquerylen(domainlen, true)
|
||||
if vanillalen < domainlen {
|
||||
t.Fatal("vanillalen is smaller than domainlen")
|
||||
}
|
||||
if (paddedlen % PaddingDesiredBlockSize) != 0 {
|
||||
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
|
||||
}
|
||||
if paddedlen < vanillalen {
|
||||
t.Fatal("paddedlen is smaller than vanillalen")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package mocks
|
||||
|
||||
// Decoder allows mocking dnsx.Decoder.
|
||||
type Decoder struct {
|
||||
MockDecode func(qtype uint16, reply []byte) ([]string, error)
|
||||
}
|
||||
|
||||
// Decode calls MockDecode.
|
||||
func (e *Decoder) Decode(qtype uint16, reply []byte) ([]string, error) {
|
||||
return e.MockDecode(qtype, reply)
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDecoder(t *testing.T) {
|
||||
t.Run("Decode", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &Decoder{
|
||||
MockDecode: func(qtype uint16, reply []byte) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.Decode(dns.TypeA, make([]byte, 17))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
// Package mocks contains mocks for dnsx.
|
||||
package mocks
|
||||
@@ -0,0 +1,11 @@
|
||||
package mocks
|
||||
|
||||
// Encoder allows mocking dnsx.Encoder.
|
||||
type Encoder struct {
|
||||
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// Encode calls MockEncode.
|
||||
func (e *Encoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return e.MockEncode(domain, qtype, padding)
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestEncoder(t *testing.T) {
|
||||
t.Run("Encode", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &Encoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.Encode("dns.google", dns.TypeA, true)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package mocks
|
||||
|
||||
import "context"
|
||||
|
||||
// RoundTripper allows mocking dnsx.RoundTripper.
|
||||
type RoundTripper struct {
|
||||
MockRoundTrip func(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
MockRequiresPadding func() bool
|
||||
|
||||
MockNetwork func() string
|
||||
|
||||
MockAddress func() string
|
||||
|
||||
MockCloseIdleConnections func()
|
||||
}
|
||||
|
||||
// RoundTrip calls MockRoundTrip.
|
||||
func (txp *RoundTripper) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return txp.MockRoundTrip(ctx, query)
|
||||
}
|
||||
|
||||
// RequiresPadding calls MockRequiresPadding.
|
||||
func (txp *RoundTripper) RequiresPadding() bool {
|
||||
return txp.MockRequiresPadding()
|
||||
}
|
||||
|
||||
// Network calls MockNetwork.
|
||||
func (txp *RoundTripper) Network() string {
|
||||
return txp.MockNetwork()
|
||||
}
|
||||
|
||||
// Address calls MockAddress.
|
||||
func (txp *RoundTripper) Address() string {
|
||||
return txp.MockAddress()
|
||||
}
|
||||
|
||||
// CloseIdleConnections calls MockCloseIdleConnections.
|
||||
func (txp *RoundTripper) CloseIdleConnections() {
|
||||
txp.MockCloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
)
|
||||
|
||||
func TestRoundTripper(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &RoundTripper{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RequiresPadding", func(t *testing.T) {
|
||||
txp := &RoundTripper{
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network", func(t *testing.T) {
|
||||
txp := &RoundTripper{
|
||||
MockNetwork: func() string {
|
||||
return "antani"
|
||||
},
|
||||
}
|
||||
if txp.Network() != "antani" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Address", func(t *testing.T) {
|
||||
txp := &RoundTripper{
|
||||
MockAddress: func() string {
|
||||
return "mascetti"
|
||||
},
|
||||
}
|
||||
if txp.Address() != "mascetti" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
called := &atomicx.Int64{}
|
||||
txp := &RoundTripper{
|
||||
MockCloseIdleConnections: func() {
|
||||
called.Add(1)
|
||||
},
|
||||
}
|
||||
txp.CloseIdleConnections()
|
||||
if called.Load() != 1 {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package dnsx
|
||||
|
||||
import "context"
|
||||
|
||||
// RoundTripper represents an abstract DNS transport.
|
||||
type RoundTripper interface {
|
||||
// RoundTrip sends a DNS query and receives the reply.
|
||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
// RequiresPadding return true for DoH and DoT according to RFC8467
|
||||
RequiresPadding() bool
|
||||
|
||||
// Network is the network of the round tripper (e.g. "dot")
|
||||
Network() string
|
||||
|
||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||
Address() string
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
CloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
)
|
||||
|
||||
// SerialResolver is a resolver that first issues an A query and then
|
||||
// issues an AAAA query for the requested domain.
|
||||
type SerialResolver struct {
|
||||
Encoder Encoder
|
||||
Decoder Decoder
|
||||
NumTimeouts *atomicx.Int64
|
||||
Txp RoundTripper
|
||||
}
|
||||
|
||||
// NewSerialResolver creates a new OONI Resolver instance.
|
||||
func NewSerialResolver(t RoundTripper) *SerialResolver {
|
||||
return &SerialResolver{
|
||||
Encoder: &MiekgEncoder{},
|
||||
Decoder: &MiekgDecoder{},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the transport being used.
|
||||
func (r *SerialResolver) Transport() RoundTripper {
|
||||
return r.Txp
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r *SerialResolver) Network() string {
|
||||
return r.Txp.Network()
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (r *SerialResolver) Address() string {
|
||||
return r.Txp.Address()
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (r *SerialResolver) CloseIdleConnections() {
|
||||
r.Txp.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var addrs []string
|
||||
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
|
||||
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
|
||||
if errA != nil && errAAAA != nil {
|
||||
return nil, errA
|
||||
}
|
||||
addrs = append(addrs, addrsA...)
|
||||
addrs = append(addrs, addrsAAAA...)
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (r *SerialResolver) roundTripWithRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
var errorslist []error
|
||||
for i := 0; i < 3; i++ {
|
||||
replies, err := r.roundTrip(ctx, hostname, qtype)
|
||||
if err == nil {
|
||||
return replies, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
var operr *net.OpError
|
||||
if !errors.As(err, &operr) || !operr.Timeout() {
|
||||
// The first error is the one that is most likely to be caused
|
||||
// by the network. Subsequent errors are more likely to be caused
|
||||
// by context deadlines. So, the first error is attached to an
|
||||
// operation, while subsequent errors may possibly not be. If
|
||||
// so, the resulting failing operation is not correct.
|
||||
break
|
||||
}
|
||||
r.NumTimeouts.Add(1)
|
||||
}
|
||||
// bugfix: we MUST return one of the errors otherwise we confuse the
|
||||
// mechanism in errwrap that classifies the root cause operation, since
|
||||
// it would not be able to find a child with a major operation error
|
||||
return nil, errorslist[0]
|
||||
}
|
||||
|
||||
func (r *SerialResolver) roundTrip(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.Decode(qtype, replydata)
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
)
|
||||
|
||||
func TestOONIGettingTransport(t *testing.T) {
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := NewSerialResolver(txp)
|
||||
rtx := r.Transport()
|
||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if r.Network() != rtx.Network() {
|
||||
t.Fatal("invalid network seen from the resolver")
|
||||
}
|
||||
if r.Address() != rtx.Address() {
|
||||
t.Fatal("invalid address seen from the resolver")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIEncodeError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := SerialResolver{
|
||||
Encoder: &mocks.Encoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
},
|
||||
Txp: txp,
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIRoundTripError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := &mocks.RoundTripper{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, mocked
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithEmptyReply(t *testing.T) {
|
||||
txp := &mocks.RoundTripper{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return genReplySuccess(t, dns.TypeA), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithAReply(t *testing.T) {
|
||||
txp := &mocks.RoundTripper{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return genReplySuccess(t, dns.TypeA, "8.8.8.8"), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithAAAAReply(t *testing.T) {
|
||||
txp := &mocks.RoundTripper{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return genReplySuccess(t, dns.TypeAAAA, "::1"), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "::1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithTimeout(t *testing.T) {
|
||||
txp := &mocks.RoundTripper{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, &net.OpError{Err: errorsx.ETIMEDOUT, Op: "dial"}
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, errorsx.ETIMEDOUT) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
if r.NumTimeouts.Load() <= 0 {
|
||||
t.Fatal("we didn't actually take the timeouts")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user