refactor: merge dnsx and errorsx into netxlite (#517)

When preparing a tutorial for netxlite, I figured it is easier
to tell people "hey, this is the package you should use for all
low-level networking stuff" rather than introducing people to
a set of packages working together where some piece of functionality
is here and some other piece is there.

Part of https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso
2021-09-28 12:42:01 +02:00
committed by GitHub
parent de130d249c
commit 6d3a4f1db8
169 changed files with 575 additions and 671 deletions
-100
View File
@@ -1,100 +0,0 @@
package dnsx
import (
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
)
// HTTPSSvc is an HTTPSSvc reply.
type HTTPSSvc = model.HTTPSSvc
// The DNSDecoder decodes DNS replies.
type DNSDecoder interface {
// DecodeLookupHost decodes an A or AAAA reply.
DecodeLookupHost(qtype uint16, data []byte) ([]string, error)
// DecodeHTTPS decodes an HTTPS reply.
DecodeHTTPS(data []byte) (*HTTPSSvc, error)
}
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
type DNSDecoderMiekg struct{}
func (d *DNSDecoderMiekg) parseReply(data []byte) (*dns.Msg, error) {
reply := new(dns.Msg)
if err := reply.Unpack(data); err != nil {
return nil, err
}
// TODO(bassosimone): map more errors to net.DNSError names
// TODO(bassosimone): add support for lame referral.
switch reply.Rcode {
case dns.RcodeSuccess:
return reply, nil
case dns.RcodeNameError:
return nil, errorsx.ErrOODNSNoSuchHost
case dns.RcodeRefused:
return nil, errorsx.ErrOODNSRefused
default:
return nil, errorsx.ErrOODNSMisbehaving
}
}
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte) (*HTTPSSvc, error) {
reply, err := d.parseReply(data)
if err != nil {
return nil, err
}
out := &HTTPSSvc{}
for _, answer := range reply.Answer {
switch avalue := answer.(type) {
case *dns.HTTPS:
for _, v := range avalue.Value {
switch extv := v.(type) {
case *dns.SVCBAlpn:
out.ALPN = extv.Alpn
case *dns.SVCBIPv4Hint:
for _, ip := range extv.Hint {
out.IPv4 = append(out.IPv4, ip.String())
}
case *dns.SVCBIPv6Hint:
for _, ip := range extv.Hint {
out.IPv6 = append(out.IPv6, ip.String())
}
}
}
}
}
if len(out.ALPN) <= 0 {
return nil, errorsx.ErrOODNSNoAnswer
}
return out, nil
}
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte) ([]string, error) {
reply, err := d.parseReply(data)
if err != nil {
return nil, err
}
var addrs []string
for _, answer := range reply.Answer {
switch qtype {
case dns.TypeA:
if rra, ok := answer.(*dns.A); ok {
ip := rra.A
addrs = append(addrs, ip.String())
}
case dns.TypeAAAA:
if rra, ok := answer.(*dns.AAAA); ok {
ip := rra.AAAA
addrs = append(addrs, ip.String())
}
}
}
if len(addrs) <= 0 {
return nil, errorsx.ErrOODNSNoAnswer
}
return addrs, nil
}
var _ DNSDecoder = &DNSDecoderMiekg{}
-311
View File
@@ -1,311 +0,0 @@
package dnsx
import (
"errors"
"net"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
)
func TestDNSDecoder(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
t.Run("UnpackError", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(dns.TypeA, nil)
if err == nil {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("NXDOMAIN", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeNameError))
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("Refused", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeRefused))
if !errors.Is(err, errorsx.ErrOODNSRefused) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("no address", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA))
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("decode A", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(
dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if data[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
}
})
t.Run("decode AAAA", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(
dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if data[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
}
})
t.Run("unexpected A reply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(
dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("unexpected AAAA reply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(
dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
})
t.Run("parseReply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
msg := &dns.Msg{}
msg.Rcode = dns.RcodeFormatError // an rcode we don't handle
data, err := msg.Pack()
if err != nil {
t.Fatal(err)
}
reply, err := d.parseReply(data)
if !errors.Is(err, errorsx.ErrOODNSMisbehaving) { // catch all error
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
t.Run("with nil data", func(t *testing.T) {
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(nil)
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with empty answer", func(t *testing.T) {
data := dnsGenHTTPSReplySuccess(t, nil, nil, nil)
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data)
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with full answer", func(t *testing.T) {
alpn := []string{"h3"}
v4 := []string{"1.1.1.1"}
v6 := []string{"::1"}
data := dnsGenHTTPSReplySuccess(t, alpn, v4, v6)
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
t.Fatal(diff)
}
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
t.Fatal(diff)
}
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
t.Fatal(diff)
}
})
})
}
// dnsGenReplyWithError generates a DNS reply for the given
// query type (e.g., dns.TypeA) using code as the Rcode.
func dnsGenReplyWithError(t *testing.T, qtype uint16, code int) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetRcode(query, code)
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given
// qtype (e.g., dns.TypeA) containing the given ips... in the answer.
func dnsGenLookupHostReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
for _, ip := range ips {
switch qtype {
case dns.TypeA:
reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
A: net.ParseIP(ip),
})
case dns.TypeAAAA:
reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: net.ParseIP(ip),
})
}
}
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
// dnsGenHTTPSReplySuccess generates a successful HTTPS response containing
// the given (possibly nil) alpns, ipv4s, and ipv6s.
func dnsGenHTTPSReplySuccess(t *testing.T, alpns, ipv4s, ipv6s []string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: dns.TypeHTTPS,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
answer := &dns.HTTPS{
SVCB: dns.SVCB{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: dns.TypeHTTPS,
Class: dns.ClassINET,
Ttl: 100,
},
Target: dns.Fqdn("x.org"),
Value: []dns.SVCBKeyValue{},
},
}
reply.Answer = append(reply.Answer, answer)
if len(alpns) > 0 {
answer.Value = append(answer.Value, &dns.SVCBAlpn{Alpn: alpns})
}
if len(ipv4s) > 0 {
var addrs []net.IP
for _, addr := range ipv4s {
addrs = append(addrs, net.ParseIP(addr))
}
answer.Value = append(answer.Value, &dns.SVCBIPv4Hint{Hint: addrs})
}
if len(ipv6s) > 0 {
var addrs []net.IP
for _, addr := range ipv6s {
addrs = append(addrs, net.ParseIP(addr))
}
answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{Hint: addrs})
}
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
-52
View File
@@ -1,52 +0,0 @@
package dnsx
import "github.com/miekg/dns"
// The DNSEncoder encodes DNS queries to bytes
type DNSEncoder interface {
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
}
// DNSEncoderMiekg uses github.com/miekg/dns to implement the Encoder.
type DNSEncoderMiekg struct{}
const (
// dnsPaddingDesiredBlockSize is the size that the padded query should be multiple of
dnsPaddingDesiredBlockSize = 128
// dnsEDNS0MaxResponseSize is the maximum response size for EDNS0
dnsEDNS0MaxResponseSize = 4096
// dnsDNSSECEnabled turns on support for DNSSEC when using EDNS0
dnsDNSSECEnabled = true
)
// Encode implements Encoder.Encode
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
question := dns.Question{
Name: dns.Fqdn(domain),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
if padding {
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
// Clients SHOULD pad queries to the closest multiple of
// 128 octets RFC8467#section-4.1. We inflate the query
// length by the size of the option (i.e. 4 octets). The
// cast to uint is necessary to make the modulus operation
// work as intended when the desiredBlockSize is smaller
// than (query.Len()+4) ¯\_(ツ)_/¯.
remainder := (dnsPaddingDesiredBlockSize - uint(query.Len()+4)) % dnsPaddingDesiredBlockSize
opt := new(dns.EDNS0_PADDING)
opt.Padding = make([]byte, remainder)
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
}
return query.Pack()
}
var _ DNSEncoder = &DNSEncoderMiekg{}
-103
View File
@@ -1,103 +0,0 @@
package dnsx
import (
"strings"
"testing"
"github.com/miekg/dns"
)
func TestDNSEncoder(t *testing.T) {
t.Run("encode A", func(t *testing.T) {
e := &DNSEncoderMiekg{}
data, err := e.Encode("x.org", dns.TypeA, false)
if err != nil {
t.Fatal(err)
}
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
})
t.Run("encode AAAA", func(t *testing.T) {
e := &DNSEncoderMiekg{}
data, err := e.Encode("x.org", dns.TypeAAAA, false)
if err != nil {
t.Fatal(err)
}
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
})
t.Run("encode padding", func(t *testing.T) {
// The purpose of this unit test is to make sure that for a wide
// array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int {
e := &DNSEncoderMiekg{}
data, err := e.Encode(
// This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing
// us to generate such name and we are not going to send
// it on the wire. Also, we check below that the query that
// we generate is long enough, so we should be good.
dns.Fqdn(strings.Repeat("x.", domainlen)),
dns.TypeA, padding,
)
if err != nil {
t.Fatal(err)
}
return len(data)
}
for domainlen := 1; domainlen <= 4000; domainlen++ {
vanillalen := getquerylen(domainlen, false)
paddedlen := getquerylen(domainlen, true)
if vanillalen < domainlen {
t.Fatal("vanillalen is smaller than domainlen")
}
if (paddedlen % dnsPaddingDesiredBlockSize) != 0 {
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
}
if paddedlen < vanillalen {
t.Fatal("paddedlen is smaller than vanillalen")
}
}
})
}
// dnsValidateEncodedQueryBytes validates the query serialized in data
// for the given query type qtype (e.g., dns.TypeAAAA).
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) {
// skipping over the query ID
if data[2] != 1 {
t.Fatal("FLAGS should only have RD set")
}
if data[3] != 0 {
t.Fatal("RA|Z|Rcode should be zero")
}
if data[4] != 0 || data[5] != 1 {
t.Fatal("QCOUNT high should be one")
}
if data[6] != 0 || data[7] != 0 {
t.Fatal("ANCOUNT should be zero")
}
if data[8] != 0 || data[9] != 0 {
t.Fatal("NSCOUNT should be zero")
}
if data[10] != 0 || data[11] != 0 {
t.Fatal("ARCOUNT should be zero")
}
t.Log(data[12])
if data[12] != 1 || data[13] != byte('x') {
t.Fatal("The name does not contain 1:x")
}
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
t.Fatal("The name does not contain 3:org")
}
if data[18] != 0 {
t.Fatal("The name does not terminate where expected")
}
if data[19] != 0 && data[20] != qtype {
t.Fatal("The query is not for the expected type")
}
if data[21] != 0 && data[22] != 1 {
t.Fatal("The query is not IN")
}
}
-89
View File
@@ -1,89 +0,0 @@
package dnsx
import (
"bytes"
"context"
"errors"
"net/http"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
)
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
CloseIdleConnections()
}
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
// an HTTP/HTTPS channel provided by URL using the Do function.
type DNSOverHTTPS struct {
Client HTTPClient
URL string
HostOverride string
}
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
// specified http.Client and URL, as a convenience.
func NewDNSOverHTTPS(client HTTPClient, URL string) *DNSOverHTTPS {
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
}
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
// it's creating a resolver where we use the specified host.
func NewDNSOverHTTPSWithHostOverride(
client HTTPClient, URL, hostOverride string) *DNSOverHTTPS {
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
if err != nil {
return nil, err
}
req.Host = t.HostOverride
req.Header.Set("user-agent", httpheader.UserAgent())
req.Header.Set("content-type", "application/dns-message")
var resp *http.Response
resp, err = t.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
// TODO(bassosimone): we should map the status code to a
// proper Error in the DNS context.
return nil, errors.New("doh: server returned error")
}
if resp.Header.Get("content-type") != "application/dns-message" {
return nil, errors.New("doh: invalid content-type")
}
return iox.ReadAllContext(ctx, resp.Body)
}
// RequiresPadding returns true for DoH according to RFC8467
func (t *DNSOverHTTPS) RequiresPadding() bool {
return true
}
// Network returns the transport network (e.g., doh, dot)
func (t *DNSOverHTTPS) Network() string {
return "doh"
}
// Address returns the upstream server address.
func (t *DNSOverHTTPS) Address() string {
return t.URL
}
// CloseIdleConnections closes idle connections.
func (t *DNSOverHTTPS) CloseIdleConnections() {
t.Client.CloseIdleConnections()
}
var _ DNSTransport = &DNSOverHTTPS{}
-196
View File
@@ -1,196 +0,0 @@
package dnsx
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
func TestDNSOverHTTPS(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("NewRequestFailure", func(t *testing.T) {
const invalidURL = "\t"
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("client.Do failure", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("server returns 500", func(t *testing.T) {
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("")),
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("missing content type", func(t *testing.T) {
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("")),
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("success", func(t *testing.T) {
body := []byte("AAA")
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, body) {
t.Fatal("not the response we expected")
}
})
t.Run("sets the correct user-agent", func(t *testing.T) {
expected := errors.New("mocked error")
var correct bool
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct user agent")
}
})
t.Run("we can override the Host header", func(t *testing.T) {
var correct bool
expected := errors.New("mocked error")
hostOverride := "test.com"
txp := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride,
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct host override")
}
})
})
t.Run("other functions behave correctly", func(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query"
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
if txp.Network() != "doh" {
t.Fatal("invalid network")
}
if txp.RequiresPadding() != true {
t.Fatal("should require padding")
}
if txp.Address() != queryURL {
t.Fatal("invalid address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
doh := &DNSOverHTTPS{
Client: &mocks.HTTPClient{
MockCloseIdleConnections: func() {
called = true
},
},
}
doh.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}
-102
View File
@@ -1,102 +0,0 @@
package dnsx
import (
"context"
"errors"
"io"
"math"
"net"
"time"
)
// DialContextFunc is a generic function for dialing a connection.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
// and NewDNSOverTLS to create specific instances that use plaintext
// queries or encrypted queries over TLS.
//
// As a known bug, this implementation always creates a new connection
// for each incoming query, thus increasing the response delay.
type DNSOverTCP struct {
dial DialContextFunc
address string
network string
requiresPadding bool
}
// NewDNSOverTCP creates a new DNSOverTCP transport.
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
return &DNSOverTCP{
dial: dial,
address: address,
network: "tcp",
requiresPadding: false,
}
}
// NewDNSOverTLS creates a new DNSOverTLS transport.
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
return &DNSOverTCP{
dial: dial,
address: address,
network: "dot",
requiresPadding: true,
}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
if len(query) > math.MaxUint16 {
return nil, errors.New("query too long")
}
conn, err := t.dial(ctx, "tcp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
return nil, err
}
// Write request
buf := []byte{byte(len(query) >> 8)}
buf = append(buf, byte(len(query)))
buf = append(buf, query...)
if _, err = conn.Write(buf); err != nil {
return nil, err
}
// Read response
header := make([]byte, 2)
if _, err = io.ReadFull(conn, header); err != nil {
return nil, err
}
length := int(header[0])<<8 | int(header[1])
reply := make([]byte, length)
if _, err = io.ReadFull(conn, reply); err != nil {
return nil, err
}
return reply, nil
}
// RequiresPadding returns true for DoT and false for TCP
// according to RFC8467.
func (t *DNSOverTCP) RequiresPadding() bool {
return t.requiresPadding
}
// Network returns the transport network (e.g., doh, dot)
func (t *DNSOverTCP) Network() string {
return t.network
}
// Address returns the upstream server address.
func (t *DNSOverTCP) Address() string {
return t.address
}
// CloseIdleConnections closes idle connections.
func (t *DNSOverTCP) CloseIdleConnections() {
// nothing to do
}
var _ DNSTransport = &DNSOverTCP{}
-226
View File
@@ -1,226 +0,0 @@
package dnsx
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
func TestDNSOverTCP(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("query too large", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
if err == nil {
t.Fatal("expected an error here")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("dial failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("SetDeadline failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("write failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("first read fails", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("second read fails", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
input := io.MultiReader(
bytes.NewReader([]byte{byte(0), byte(2)}),
&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
},
)
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("successful case", func(t *testing.T) {
const address = "9.9.9.9:53"
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil {
t.Fatal(err)
}
if len(reply) != 1 || reply[0] != 1 {
t.Fatal("not the response we expected")
}
})
})
t.Run("other functions okay with TCP", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "tcp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
})
t.Run("other functions okay with TLS", func(t *testing.T) {
const address = "9.9.9.9:853"
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "dot" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
})
}
-69
View File
@@ -1,69 +0,0 @@
package dnsx
import (
"context"
"net"
"time"
)
// Dialer is the network dialer interface assumed by this package.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// DNSOverUDP is a DNS over UDP RoundTripper.
type DNSOverUDP struct {
dialer Dialer
address string
}
// NewDNSOverUDP creates a DNSOverUDP instance.
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
return &DNSOverUDP{dialer: dialer, address: address}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
// Use five seconds timeout like Bionic does. See
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, err
}
if _, err = conn.Write(query); err != nil {
return nil, err
}
reply := make([]byte, 1<<17)
var n int
n, err = conn.Read(reply)
if err != nil {
return nil, err
}
return reply[:n], nil
}
// RequiresPadding returns false for UDP according to RFC8467
func (t *DNSOverUDP) RequiresPadding() bool {
return false
}
// Network returns the transport network (e.g., doh, dot)
func (t *DNSOverUDP) Network() string {
return "udp"
}
// Address returns the upstream server address.
func (t *DNSOverUDP) Address() string {
return t.address
}
// CloseIdleConnections closes idle connections.
func (t *DNSOverUDP) CloseIdleConnections() {
// nothing to do
}
var _ DNSTransport = &DNSOverUDP{}
-161
View File
@@ -1,161 +0,0 @@
package dnsx
import (
"bytes"
"context"
"errors"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
func TestDNSOverUDP(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("dial failure", func(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverUDP(&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}, address)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("SetDeadline failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDP(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("Write failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDP(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("Read failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDP(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("read success", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDP(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(data) != expected {
t.Fatal("expected non nil data")
}
})
})
t.Run("other functions okay", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverUDP(&net.Dialer{}, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "udp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
})
}
-21
View File
@@ -1,21 +0,0 @@
package dnsx
import "context"
// DNSTransport represents an abstract DNS transport.
type DNSTransport interface {
// RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
// RequiresPadding return true for DoH and DoT according to RFC8467
RequiresPadding() bool
// Network is the network of the round tripper (e.g. "dot")
Network() string
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
Address() string
// CloseIdleConnections closes idle connections.
CloseIdleConnections()
}
@@ -1,5 +1,5 @@
// Package model contains the dnsx model.
package model
// Package dnsx contains the dnsx model.
package dnsx
// HTTPSSvc is an HTTPSSvc reply.
type HTTPSSvc struct {
@@ -1,23 +0,0 @@
package mocks
import "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model"
// HTTPSSvc is the result of HTTPS queries.
type HTTPSSvc = model.HTTPSSvc
// DNSDecoder allows mocking dnsx.DNSDecoder.
type DNSDecoder struct {
MockDecodeLookupHost func(qtype uint16, reply []byte) ([]string, error)
MockDecodeHTTPS func(reply []byte) (*HTTPSSvc, error)
}
// DecodeLookupHost calls MockDecodeLookupHost.
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte) ([]string, error) {
return e.MockDecodeLookupHost(qtype, reply)
}
// DecodeHTTPS calls MockDecodeHTTPS.
func (e *DNSDecoder) DecodeHTTPS(reply []byte) (*HTTPSSvc, error) {
return e.MockDecodeHTTPS(reply)
}
@@ -1,42 +0,0 @@
package mocks
import (
"errors"
"testing"
"github.com/miekg/dns"
)
func TestDNSDecoder(t *testing.T) {
t.Run("DecodeLookupHost", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeLookupHost: func(qtype uint16, reply []byte) ([]string, error) {
return nil, expected
},
}
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17))
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeHTTPS: func(reply []byte) (*HTTPSSvc, error) {
return nil, expected
},
}
out, err := e.DecodeHTTPS(make([]byte, 17))
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
}
@@ -1,11 +0,0 @@
package mocks
// DNSEncoder allows mocking dnsx.DNSEncoder.
type DNSEncoder struct {
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, error)
}
// Encode calls MockEncode.
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
return e.MockEncode(domain, qtype, padding)
}
@@ -1,26 +0,0 @@
package mocks
import (
"errors"
"testing"
"github.com/miekg/dns"
)
func TestDNSEncoder(t *testing.T) {
t.Run("Encode", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return nil, expected
},
}
out, err := e.Encode("dns.google", dns.TypeA, true)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
}
@@ -1,41 +0,0 @@
package mocks
import "context"
// DNSTransport allows mocking dnsx.DNSTransport.
type DNSTransport struct {
MockRoundTrip func(ctx context.Context, query []byte) (reply []byte, err error)
MockRequiresPadding func() bool
MockNetwork func() string
MockAddress func() string
MockCloseIdleConnections func()
}
// RoundTrip calls MockRoundTrip.
func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) {
return txp.MockRoundTrip(ctx, query)
}
// RequiresPadding calls MockRequiresPadding.
func (txp *DNSTransport) RequiresPadding() bool {
return txp.MockRequiresPadding()
}
// Network calls MockNetwork.
func (txp *DNSTransport) Network() string {
return txp.MockNetwork()
}
// Address calls MockAddress.
func (txp *DNSTransport) Address() string {
return txp.MockAddress()
}
// CloseIdleConnections calls MockCloseIdleConnections.
func (txp *DNSTransport) CloseIdleConnections() {
txp.MockCloseIdleConnections()
}
@@ -1,73 +0,0 @@
package mocks
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/atomicx"
)
func TestDNSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("RequiresPadding", func(t *testing.T) {
txp := &DNSTransport{
MockRequiresPadding: func() bool {
return true
},
}
if txp.RequiresPadding() != true {
t.Fatal("unexpected result")
}
})
t.Run("Network", func(t *testing.T) {
txp := &DNSTransport{
MockNetwork: func() string {
return "antani"
},
}
if txp.Network() != "antani" {
t.Fatal("unexpected result")
}
})
t.Run("Address", func(t *testing.T) {
txp := &DNSTransport{
MockAddress: func() string {
return "mascetti"
},
}
if txp.Address() != "mascetti" {
t.Fatal("unexpected result")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
called := &atomicx.Int64{}
txp := &DNSTransport{
MockCloseIdleConnections: func() {
called.Add(1)
},
}
txp.CloseIdleConnections()
if called.Load() != 1 {
t.Fatal("not called")
}
})
}
-2
View File
@@ -1,2 +0,0 @@
// Package mocks contains mocks for dnsx.
package mocks
-118
View File
@@ -1,118 +0,0 @@
package dnsx
import (
"context"
"errors"
"net"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
)
// SerialResolver is a resolver that first issues an A query and then
// issues an AAAA query for the requested domain.
type SerialResolver struct {
Encoder DNSEncoder
Decoder DNSDecoder
NumTimeouts *atomicx.Int64
Txp DNSTransport
}
// NewSerialResolver creates a new OONI Resolver instance.
func NewSerialResolver(t DNSTransport) *SerialResolver {
return &SerialResolver{
Encoder: &DNSEncoderMiekg{},
Decoder: &DNSDecoderMiekg{},
NumTimeouts: &atomicx.Int64{},
Txp: t,
}
}
// Transport returns the transport being used.
func (r *SerialResolver) Transport() DNSTransport {
return r.Txp
}
// Network implements Resolver.Network
func (r *SerialResolver) Network() string {
return r.Txp.Network()
}
// Address implements Resolver.Address
func (r *SerialResolver) Address() string {
return r.Txp.Address()
}
// CloseIdleConnections closes idle connections.
func (r *SerialResolver) CloseIdleConnections() {
r.Txp.CloseIdleConnections()
}
// LookupHost implements Resolver.LookupHost.
func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
var addrs []string
addrsA, errA := r.lookupHostWithRetry(ctx, hostname, dns.TypeA)
addrsAAAA, errAAAA := r.lookupHostWithRetry(ctx, hostname, dns.TypeAAAA)
if errA != nil && errAAAA != nil {
return nil, errA
}
addrs = append(addrs, addrsA...)
addrs = append(addrs, addrsAAAA...)
return addrs, nil
}
// LookupHTTPS implements Resolver.LookupHTTPS.
func (r *SerialResolver) LookupHTTPS(
ctx context.Context, hostname string) (*HTTPSSvc, error) {
querydata, err := r.Encoder.Encode(
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeHTTPS(replydata)
}
func (r *SerialResolver) lookupHostWithRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
var errorslist []error
for i := 0; i < 3; i++ {
replies, err := r.lookupHostWithoutRetry(ctx, hostname, qtype)
if err == nil {
return replies, nil
}
errorslist = append(errorslist, err)
var operr *net.OpError
if !errors.As(err, &operr) || !operr.Timeout() {
// The first error is the one that is most likely to be caused
// by the network. Subsequent errors are more likely to be caused
// by context deadlines. So, the first error is attached to an
// operation, while subsequent errors may possibly not be. If
// so, the resulting failing operation is not correct.
break
}
r.NumTimeouts.Add(1)
}
// bugfix: we MUST return one of the errors otherwise we confuse the
// mechanism in errwrap that classifies the root cause operation, since
// it would not be able to find a child with a major operation error
return nil, errorslist[0]
}
// lookupHostWithoutRetry issues a lookup host query for the specified
// qtype (dns.A or dns.AAAA) without retrying on failure.
func (r *SerialResolver) lookupHostWithoutRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeLookupHost(qtype, replydata)
}
@@ -1,257 +0,0 @@
package dnsx
import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
)
func TestSerialResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewSerialResolver(txp)
rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
t.Fatal("not the transport we expected")
}
if r.Network() != rtx.Network() {
t.Fatal("invalid network seen from the resolver")
}
if r.Address() != rtx.Address() {
t.Fatal("invalid address seen from the resolver")
}
})
t.Run("LookupHost", func(t *testing.T) {
t.Run("Encode error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return nil, mocked
},
},
Txp: txp,
}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("RoundTrip error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, mocked
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
})
t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "::1" {
t.Fatal("not the result we expected")
}
})
t.Run("with timeout", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, &net.OpError{Err: errorsx.ETIMEDOUT, Op: "dial"}
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, errorsx.ETIMEDOUT) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
if r.NumTimeouts.Load() <= 0 {
t.Fatal("we didn't actually take the timeouts")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
r := &SerialResolver{
Txp: &mocks.DNSTransport{
MockCloseIdleConnections: func() {
called = true
},
},
}
r.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("LookupHTTPS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return nil, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return make([]byte, 64), nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, expected
},
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return make([]byte, 64), nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte) (*mocks.HTTPSSvc, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
},
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
})
}