refactor: merge dnsx and errorsx into netxlite (#517)
When preparing a tutorial for netxlite, I figured it is easier to tell people "hey, this is the package you should use for all low-level networking stuff" rather than introducing people to a set of packages working together where some piece of functionality is here and some other piece is there. Part of https://github.com/ooni/probe/issues/1591
This commit is contained in:
@@ -1,100 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
)
|
||||
|
||||
// HTTPSSvc is an HTTPSSvc reply.
|
||||
type HTTPSSvc = model.HTTPSSvc
|
||||
|
||||
// The DNSDecoder decodes DNS replies.
|
||||
type DNSDecoder interface {
|
||||
// DecodeLookupHost decodes an A or AAAA reply.
|
||||
DecodeLookupHost(qtype uint16, data []byte) ([]string, error)
|
||||
|
||||
// DecodeHTTPS decodes an HTTPS reply.
|
||||
DecodeHTTPS(data []byte) (*HTTPSSvc, error)
|
||||
}
|
||||
|
||||
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
|
||||
type DNSDecoderMiekg struct{}
|
||||
|
||||
func (d *DNSDecoderMiekg) parseReply(data []byte) (*dns.Msg, error) {
|
||||
reply := new(dns.Msg)
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bassosimone): map more errors to net.DNSError names
|
||||
// TODO(bassosimone): add support for lame referral.
|
||||
switch reply.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
return reply, nil
|
||||
case dns.RcodeNameError:
|
||||
return nil, errorsx.ErrOODNSNoSuchHost
|
||||
case dns.RcodeRefused:
|
||||
return nil, errorsx.ErrOODNSRefused
|
||||
default:
|
||||
return nil, errorsx.ErrOODNSMisbehaving
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte) (*HTTPSSvc, error) {
|
||||
reply, err := d.parseReply(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &HTTPSSvc{}
|
||||
for _, answer := range reply.Answer {
|
||||
switch avalue := answer.(type) {
|
||||
case *dns.HTTPS:
|
||||
for _, v := range avalue.Value {
|
||||
switch extv := v.(type) {
|
||||
case *dns.SVCBAlpn:
|
||||
out.ALPN = extv.Alpn
|
||||
case *dns.SVCBIPv4Hint:
|
||||
for _, ip := range extv.Hint {
|
||||
out.IPv4 = append(out.IPv4, ip.String())
|
||||
}
|
||||
case *dns.SVCBIPv6Hint:
|
||||
for _, ip := range extv.Hint {
|
||||
out.IPv6 = append(out.IPv6, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(out.ALPN) <= 0 {
|
||||
return nil, errorsx.ErrOODNSNoAnswer
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte) ([]string, error) {
|
||||
reply, err := d.parseReply(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var addrs []string
|
||||
for _, answer := range reply.Answer {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if rra, ok := answer.(*dns.AAAA); ok {
|
||||
ip := rra.AAAA
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(addrs) <= 0 {
|
||||
return nil, errorsx.ErrOODNSNoAnswer
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
var _ DNSDecoder = &DNSDecoderMiekg{}
|
||||
@@ -1,311 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
)
|
||||
|
||||
func TestDNSDecoder(t *testing.T) {
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("UnpackError", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NXDOMAIN", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(
|
||||
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeNameError))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Refused", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(
|
||||
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeRefused))
|
||||
if !errors.Is(err, errorsx.ErrOODNSRefused) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no address", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA))
|
||||
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode A", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(
|
||||
dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if data[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode AAAA", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(
|
||||
dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if data[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected A reply", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(
|
||||
dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected AAAA reply", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(
|
||||
dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("parseReply", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
msg := &dns.Msg{}
|
||||
msg.Rcode = dns.RcodeFormatError // an rcode we don't handle
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
reply, err := d.parseReply(data)
|
||||
if !errors.Is(err, errorsx.ErrOODNSMisbehaving) { // catch all error
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
||||
t.Run("with nil data", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeHTTPS(nil)
|
||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
data := dnsGenHTTPSReplySuccess(t, nil, nil, nil)
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeHTTPS(data)
|
||||
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with full answer", func(t *testing.T) {
|
||||
alpn := []string{"h3"}
|
||||
v4 := []string{"1.1.1.1"}
|
||||
v6 := []string{"::1"}
|
||||
data := dnsGenHTTPSReplySuccess(t, alpn, v4, v6)
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeHTTPS(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// dnsGenReplyWithError generates a DNS reply for the given
|
||||
// query type (e.g., dns.TypeA) using code as the Rcode.
|
||||
func dnsGenReplyWithError(t *testing.T, qtype uint16, code int) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetRcode(query, code)
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given
|
||||
// qtype (e.g., dns.TypeA) containing the given ips... in the answer.
|
||||
func dnsGenLookupHostReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
for _, ip := range ips {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
reply.Answer = append(reply.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: net.ParseIP(ip),
|
||||
})
|
||||
case dns.TypeAAAA:
|
||||
reply.Answer = append(reply.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
AAAA: net.ParseIP(ip),
|
||||
})
|
||||
}
|
||||
}
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// dnsGenHTTPSReplySuccess generates a successful HTTPS response containing
|
||||
// the given (possibly nil) alpns, ipv4s, and ipv6s.
|
||||
func dnsGenHTTPSReplySuccess(t *testing.T, alpns, ipv4s, ipv6s []string) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: dns.TypeHTTPS,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
answer := &dns.HTTPS{
|
||||
SVCB: dns.SVCB{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: dns.TypeHTTPS,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 100,
|
||||
},
|
||||
Target: dns.Fqdn("x.org"),
|
||||
Value: []dns.SVCBKeyValue{},
|
||||
},
|
||||
}
|
||||
reply.Answer = append(reply.Answer, answer)
|
||||
if len(alpns) > 0 {
|
||||
answer.Value = append(answer.Value, &dns.SVCBAlpn{Alpn: alpns})
|
||||
}
|
||||
if len(ipv4s) > 0 {
|
||||
var addrs []net.IP
|
||||
for _, addr := range ipv4s {
|
||||
addrs = append(addrs, net.ParseIP(addr))
|
||||
}
|
||||
answer.Value = append(answer.Value, &dns.SVCBIPv4Hint{Hint: addrs})
|
||||
}
|
||||
if len(ipv6s) > 0 {
|
||||
var addrs []net.IP
|
||||
for _, addr := range ipv6s {
|
||||
addrs = append(addrs, net.ParseIP(addr))
|
||||
}
|
||||
answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{Hint: addrs})
|
||||
}
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// The DNSEncoder encodes DNS queries to bytes
|
||||
type DNSEncoder interface {
|
||||
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// DNSEncoderMiekg uses github.com/miekg/dns to implement the Encoder.
|
||||
type DNSEncoderMiekg struct{}
|
||||
|
||||
const (
|
||||
// dnsPaddingDesiredBlockSize is the size that the padded query should be multiple of
|
||||
dnsPaddingDesiredBlockSize = 128
|
||||
|
||||
// dnsEDNS0MaxResponseSize is the maximum response size for EDNS0
|
||||
dnsEDNS0MaxResponseSize = 4096
|
||||
|
||||
// dnsDNSSECEnabled turns on support for DNSSEC when using EDNS0
|
||||
dnsDNSSECEnabled = true
|
||||
)
|
||||
|
||||
// Encode implements Encoder.Encode
|
||||
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
if padding {
|
||||
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
|
||||
// Clients SHOULD pad queries to the closest multiple of
|
||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||
// length by the size of the option (i.e. 4 octets). The
|
||||
// cast to uint is necessary to make the modulus operation
|
||||
// work as intended when the desiredBlockSize is smaller
|
||||
// than (query.Len()+4) ¯\_(ツ)_/¯.
|
||||
remainder := (dnsPaddingDesiredBlockSize - uint(query.Len()+4)) % dnsPaddingDesiredBlockSize
|
||||
opt := new(dns.EDNS0_PADDING)
|
||||
opt.Padding = make([]byte, remainder)
|
||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||
}
|
||||
return query.Pack()
|
||||
}
|
||||
|
||||
var _ DNSEncoder = &DNSEncoderMiekg{}
|
||||
@@ -1,103 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDNSEncoder(t *testing.T) {
|
||||
|
||||
t.Run("encode A", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
data, err := e.Encode("x.org", dns.TypeA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
|
||||
})
|
||||
|
||||
t.Run("encode AAAA", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
data, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
|
||||
})
|
||||
|
||||
t.Run("encode padding", func(t *testing.T) {
|
||||
// The purpose of this unit test is to make sure that for a wide
|
||||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := &DNSEncoderMiekg{}
|
||||
data, err := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
// us to generate such name and we are not going to send
|
||||
// it on the wire. Also, we check below that the query that
|
||||
// we generate is long enough, so we should be good.
|
||||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||
dns.TypeA, padding,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
for domainlen := 1; domainlen <= 4000; domainlen++ {
|
||||
vanillalen := getquerylen(domainlen, false)
|
||||
paddedlen := getquerylen(domainlen, true)
|
||||
if vanillalen < domainlen {
|
||||
t.Fatal("vanillalen is smaller than domainlen")
|
||||
}
|
||||
if (paddedlen % dnsPaddingDesiredBlockSize) != 0 {
|
||||
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
|
||||
}
|
||||
if paddedlen < vanillalen {
|
||||
t.Fatal("paddedlen is smaller than vanillalen")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// dnsValidateEncodedQueryBytes validates the query serialized in data
|
||||
// for the given query type qtype (e.g., dns.TypeAAAA).
|
||||
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) {
|
||||
// skipping over the query ID
|
||||
if data[2] != 1 {
|
||||
t.Fatal("FLAGS should only have RD set")
|
||||
}
|
||||
if data[3] != 0 {
|
||||
t.Fatal("RA|Z|Rcode should be zero")
|
||||
}
|
||||
if data[4] != 0 || data[5] != 1 {
|
||||
t.Fatal("QCOUNT high should be one")
|
||||
}
|
||||
if data[6] != 0 || data[7] != 0 {
|
||||
t.Fatal("ANCOUNT should be zero")
|
||||
}
|
||||
if data[8] != 0 || data[9] != 0 {
|
||||
t.Fatal("NSCOUNT should be zero")
|
||||
}
|
||||
if data[10] != 0 || data[11] != 0 {
|
||||
t.Fatal("ARCOUNT should be zero")
|
||||
}
|
||||
t.Log(data[12])
|
||||
if data[12] != 1 || data[13] != byte('x') {
|
||||
t.Fatal("The name does not contain 1:x")
|
||||
}
|
||||
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
|
||||
t.Fatal("The name does not contain 3:org")
|
||||
}
|
||||
if data[18] != 0 {
|
||||
t.Fatal("The name does not terminate where expected")
|
||||
}
|
||||
if data[19] != 0 && data[20] != qtype {
|
||||
t.Fatal("The query is not for the expected type")
|
||||
}
|
||||
if data[21] != 0 && data[22] != 1 {
|
||||
t.Fatal("The query is not IN")
|
||||
}
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/iox"
|
||||
)
|
||||
|
||||
// HTTPClient is the HTTP client expected by DNSOverHTTPS.
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||
type DNSOverHTTPS struct {
|
||||
Client HTTPClient
|
||||
URL string
|
||||
HostOverride string
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||
// specified http.Client and URL, as a convenience.
|
||||
func NewDNSOverHTTPS(client HTTPClient, URL string) *DNSOverHTTPS {
|
||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||
// it's creating a resolver where we use the specified host.
|
||||
func NewDNSOverHTTPSWithHostOverride(
|
||||
client HTTPClient, URL, hostOverride string) *DNSOverHTTPS {
|
||||
return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = t.HostOverride
|
||||
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
// TODO(bassosimone): we should map the status code to a
|
||||
// proper Error in the DNS context.
|
||||
return nil, errors.New("doh: server returned error")
|
||||
}
|
||||
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||
return nil, errors.New("doh: invalid content-type")
|
||||
}
|
||||
return iox.ReadAllContext(ctx, resp.Body)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467
|
||||
func (t *DNSOverHTTPS) RequiresPadding() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverHTTPS) Network() string {
|
||||
return "doh"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverHTTPS) Address() string {
|
||||
return t.URL
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverHTTPS) CloseIdleConnections() {
|
||||
t.Client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
var _ DNSTransport = &DNSOverHTTPS{}
|
||||
@@ -1,196 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverHTTPS(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("NewRequestFailure", func(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("client.Do failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("server returns 500", func(t *testing.T) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing content type", func(t *testing.T) {
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
body := []byte("AAA")
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, body) {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets the correct user-agent", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var correct bool
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct user agent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("we can override the Host header", func(t *testing.T) {
|
||||
var correct bool
|
||||
expected := errors.New("mocked error")
|
||||
hostOverride := "test.com"
|
||||
txp := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Host == hostOverride
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct host override")
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
t.Run("other functions behave correctly", func(t *testing.T) {
|
||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||
txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
if txp.Network() != "doh" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("should require padding")
|
||||
}
|
||||
if txp.Address() != queryURL {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
doh := &DNSOverHTTPS{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
},
|
||||
}
|
||||
doh.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialContextFunc is a generic function for dialing a connection.
|
||||
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
|
||||
// and NewDNSOverTLS to create specific instances that use plaintext
|
||||
// queries or encrypted queries over TLS.
|
||||
//
|
||||
// As a known bug, this implementation always creates a new connection
|
||||
// for each incoming query, thus increasing the response delay.
|
||||
type DNSOverTCP struct {
|
||||
dial DialContextFunc
|
||||
address string
|
||||
network string
|
||||
requiresPadding bool
|
||||
}
|
||||
|
||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
requiresPadding: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
||||
return &DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "dot",
|
||||
requiresPadding: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
}
|
||||
conn, err := t.dial(ctx, "tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write request
|
||||
buf := []byte{byte(len(query) >> 8)}
|
||||
buf = append(buf, byte(len(query)))
|
||||
buf = append(buf, query...)
|
||||
if _, err = conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
header := make([]byte, 2)
|
||||
if _, err = io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := int(header[0])<<8 | int(header[1])
|
||||
reply := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
// according to RFC8467.
|
||||
func (t *DNSOverTCP) RequiresPadding() bool {
|
||||
return t.requiresPadding
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverTCP) Network() string {
|
||||
return t.network
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverTCP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverTCP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ DNSTransport = &DNSOverTCP{}
|
||||
@@ -1,226 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverTCP(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("query too large", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("dial failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("first read fails", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second read fails", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
input := io.MultiReader(
|
||||
bytes.NewReader([]byte{byte(0), byte(2)}),
|
||||
&mocks.Reader{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
},
|
||||
)
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful case", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 || reply[0] != 1 {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("other functions okay with TCP", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "tcp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("other functions okay with TLS", func(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "dot" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer is the network dialer interface assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// DNSOverUDP is a DNS over UDP RoundTripper.
|
||||
type DNSOverUDP struct {
|
||||
dialer Dialer
|
||||
address string
|
||||
}
|
||||
|
||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||
func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP {
|
||||
return &DNSOverUDP{dialer: dialer, address: address}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
// Use five seconds timeout like Bionic does. See
|
||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply := make([]byte, 1<<17)
|
||||
var n int
|
||||
n, err = conn.Read(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply[:n], nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467
|
||||
func (t *DNSOverUDP) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t *DNSOverUDP) Network() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t *DNSOverUDP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (t *DNSOverUDP) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
var _ DNSTransport = &DNSOverUDP{}
|
||||
@@ -1,161 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverUDP(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("dial failure", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDP(&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write failure", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Read failure", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("read success", func(t *testing.T) {
|
||||
const expected = 17
|
||||
input := bytes.NewReader(make([]byte, expected))
|
||||
txp := NewDNSOverUDP(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != expected {
|
||||
t.Fatal("expected non nil data")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("other functions okay", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDP(&net.Dialer{}, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "udp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import "context"
|
||||
|
||||
// DNSTransport represents an abstract DNS transport.
|
||||
type DNSTransport interface {
|
||||
// RoundTrip sends a DNS query and receives the reply.
|
||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
// RequiresPadding return true for DoH and DoT according to RFC8467
|
||||
RequiresPadding() bool
|
||||
|
||||
// Network is the network of the round tripper (e.g. "dot")
|
||||
Network() string
|
||||
|
||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||
Address() string
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
CloseIdleConnections()
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// Package model contains the dnsx model.
|
||||
package model
|
||||
// Package dnsx contains the dnsx model.
|
||||
package dnsx
|
||||
|
||||
// HTTPSSvc is an HTTPSSvc reply.
|
||||
type HTTPSSvc struct {
|
||||
@@ -1,23 +0,0 @@
|
||||
package mocks
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/model"
|
||||
|
||||
// HTTPSSvc is the result of HTTPS queries.
|
||||
type HTTPSSvc = model.HTTPSSvc
|
||||
|
||||
// DNSDecoder allows mocking dnsx.DNSDecoder.
|
||||
type DNSDecoder struct {
|
||||
MockDecodeLookupHost func(qtype uint16, reply []byte) ([]string, error)
|
||||
|
||||
MockDecodeHTTPS func(reply []byte) (*HTTPSSvc, error)
|
||||
}
|
||||
|
||||
// DecodeLookupHost calls MockDecodeLookupHost.
|
||||
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte) ([]string, error) {
|
||||
return e.MockDecodeLookupHost(qtype, reply)
|
||||
}
|
||||
|
||||
// DecodeHTTPS calls MockDecodeHTTPS.
|
||||
func (e *DNSDecoder) DecodeHTTPS(reply []byte) (*HTTPSSvc, error) {
|
||||
return e.MockDecodeHTTPS(reply)
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDNSDecoder(t *testing.T) {
|
||||
t.Run("DecodeLookupHost", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSDecoder{
|
||||
MockDecodeLookupHost: func(qtype uint16, reply []byte) ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSDecoder{
|
||||
MockDecodeHTTPS: func(reply []byte) (*HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.DecodeHTTPS(make([]byte, 17))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package mocks
|
||||
|
||||
// DNSEncoder allows mocking dnsx.DNSEncoder.
|
||||
type DNSEncoder struct {
|
||||
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// Encode calls MockEncode.
|
||||
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return e.MockEncode(domain, qtype, padding)
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDNSEncoder(t *testing.T) {
|
||||
t.Run("Encode", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.Encode("dns.google", dns.TypeA, true)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package mocks
|
||||
|
||||
import "context"
|
||||
|
||||
// DNSTransport allows mocking dnsx.DNSTransport.
|
||||
type DNSTransport struct {
|
||||
MockRoundTrip func(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
MockRequiresPadding func() bool
|
||||
|
||||
MockNetwork func() string
|
||||
|
||||
MockAddress func() string
|
||||
|
||||
MockCloseIdleConnections func()
|
||||
}
|
||||
|
||||
// RoundTrip calls MockRoundTrip.
|
||||
func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return txp.MockRoundTrip(ctx, query)
|
||||
}
|
||||
|
||||
// RequiresPadding calls MockRequiresPadding.
|
||||
func (txp *DNSTransport) RequiresPadding() bool {
|
||||
return txp.MockRequiresPadding()
|
||||
}
|
||||
|
||||
// Network calls MockNetwork.
|
||||
func (txp *DNSTransport) Network() string {
|
||||
return txp.MockNetwork()
|
||||
}
|
||||
|
||||
// Address calls MockAddress.
|
||||
func (txp *DNSTransport) Address() string {
|
||||
return txp.MockAddress()
|
||||
}
|
||||
|
||||
// CloseIdleConnections calls MockCloseIdleConnections.
|
||||
func (txp *DNSTransport) CloseIdleConnections() {
|
||||
txp.MockCloseIdleConnections()
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
)
|
||||
|
||||
func TestDNSTransport(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RequiresPadding", func(t *testing.T) {
|
||||
txp := &DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network", func(t *testing.T) {
|
||||
txp := &DNSTransport{
|
||||
MockNetwork: func() string {
|
||||
return "antani"
|
||||
},
|
||||
}
|
||||
if txp.Network() != "antani" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Address", func(t *testing.T) {
|
||||
txp := &DNSTransport{
|
||||
MockAddress: func() string {
|
||||
return "mascetti"
|
||||
},
|
||||
}
|
||||
if txp.Address() != "mascetti" {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
called := &atomicx.Int64{}
|
||||
txp := &DNSTransport{
|
||||
MockCloseIdleConnections: func() {
|
||||
called.Add(1)
|
||||
},
|
||||
}
|
||||
txp.CloseIdleConnections()
|
||||
if called.Load() != 1 {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
// Package mocks contains mocks for dnsx.
|
||||
package mocks
|
||||
@@ -1,118 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
)
|
||||
|
||||
// SerialResolver is a resolver that first issues an A query and then
|
||||
// issues an AAAA query for the requested domain.
|
||||
type SerialResolver struct {
|
||||
Encoder DNSEncoder
|
||||
Decoder DNSDecoder
|
||||
NumTimeouts *atomicx.Int64
|
||||
Txp DNSTransport
|
||||
}
|
||||
|
||||
// NewSerialResolver creates a new OONI Resolver instance.
|
||||
func NewSerialResolver(t DNSTransport) *SerialResolver {
|
||||
return &SerialResolver{
|
||||
Encoder: &DNSEncoderMiekg{},
|
||||
Decoder: &DNSDecoderMiekg{},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the transport being used.
|
||||
func (r *SerialResolver) Transport() DNSTransport {
|
||||
return r.Txp
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r *SerialResolver) Network() string {
|
||||
return r.Txp.Network()
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (r *SerialResolver) Address() string {
|
||||
return r.Txp.Address()
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (r *SerialResolver) CloseIdleConnections() {
|
||||
r.Txp.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var addrs []string
|
||||
addrsA, errA := r.lookupHostWithRetry(ctx, hostname, dns.TypeA)
|
||||
addrsAAAA, errAAAA := r.lookupHostWithRetry(ctx, hostname, dns.TypeAAAA)
|
||||
if errA != nil && errAAAA != nil {
|
||||
return nil, errA
|
||||
}
|
||||
addrs = append(addrs, addrsA...)
|
||||
addrs = append(addrs, addrsAAAA...)
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// LookupHTTPS implements Resolver.LookupHTTPS.
|
||||
func (r *SerialResolver) LookupHTTPS(
|
||||
ctx context.Context, hostname string) (*HTTPSSvc, error) {
|
||||
querydata, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeHTTPS(replydata)
|
||||
}
|
||||
|
||||
func (r *SerialResolver) lookupHostWithRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
var errorslist []error
|
||||
for i := 0; i < 3; i++ {
|
||||
replies, err := r.lookupHostWithoutRetry(ctx, hostname, qtype)
|
||||
if err == nil {
|
||||
return replies, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
var operr *net.OpError
|
||||
if !errors.As(err, &operr) || !operr.Timeout() {
|
||||
// The first error is the one that is most likely to be caused
|
||||
// by the network. Subsequent errors are more likely to be caused
|
||||
// by context deadlines. So, the first error is attached to an
|
||||
// operation, while subsequent errors may possibly not be. If
|
||||
// so, the resulting failing operation is not correct.
|
||||
break
|
||||
}
|
||||
r.NumTimeouts.Add(1)
|
||||
}
|
||||
// bugfix: we MUST return one of the errors otherwise we confuse the
|
||||
// mechanism in errwrap that classifies the root cause operation, since
|
||||
// it would not be able to find a child with a major operation error
|
||||
return nil, errorslist[0]
|
||||
}
|
||||
|
||||
// lookupHostWithoutRetry issues a lookup host query for the specified
|
||||
// qtype (dns.A or dns.AAAA) without retrying on failure.
|
||||
func (r *SerialResolver) lookupHostWithoutRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeLookupHost(qtype, replydata)
|
||||
}
|
||||
@@ -1,257 +0,0 @@
|
||||
package dnsx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
|
||||
)
|
||||
|
||||
func TestSerialResolver(t *testing.T) {
|
||||
t.Run("transport okay", func(t *testing.T) {
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := NewSerialResolver(txp)
|
||||
rtx := r.Transport()
|
||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if r.Network() != rtx.Network() {
|
||||
t.Fatal("invalid network seen from the resolver")
|
||||
}
|
||||
if r.Address() != rtx.Address() {
|
||||
t.Fatal("invalid address seen from the resolver")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("Encode error", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
},
|
||||
Txp: txp,
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RoundTrip error", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, mocked
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, errorsx.ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with A reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with AAAA reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "::1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with timeout", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, &net.OpError{Err: errorsx.ETIMEDOUT, Op: "dial"}
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
r := NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, errorsx.ETIMEDOUT) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
if r.NumTimeouts.Load() <= 0 {
|
||||
t.Fatal("we didn't actually take the timeouts")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
r := &SerialResolver{
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
},
|
||||
}
|
||||
r.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return make([]byte, 64), nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for decode error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return make([]byte, 64), nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeHTTPS: func(reply []byte) (*mocks.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user