chore: merge probe-engine into probe-cli (#201)
This is how I did it: 1. `git clone https://github.com/ooni/probe-engine internal/engine` 2. ``` (cd internal/engine && git describe --tags) v0.23.0 ``` 3. `nvim go.mod` (merging `go.mod` with `internal/engine/go.mod` 4. `rm -rf internal/.git internal/engine/go.{mod,sum}` 5. `git add internal/engine` 6. `find . -type f -name \*.go -exec sed -i 's@/ooni/probe-engine@/ooni/probe-cli/v3/internal/engine@g' {} \;` 7. `go build ./...` (passes) 8. `go test -race ./...` (temporary failure on RiseupVPN) 9. `go mod tidy` 10. this commit message Once this piece of work is done, we can build a new version of `ooniprobe` that is using `internal/engine` directly. We need to do more work to ensure all the other functionality in `probe-engine` (e.g. making mobile packages) are still WAI. Part of https://github.com/ooni/probe/issues/1335
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// AddressResolver is a resolver that knows how to correctly
|
||||
// resolve IP addresses to themselves.
|
||||
type AddressResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r AddressResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return r.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
var _ Resolver = AddressResolver{}
|
||||
@@ -0,0 +1,36 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestAddressSuccess(t *testing.T) {
|
||||
r := resolver.AddressResolver{}
|
||||
addrs, err := r.LookupHost(context.Background(), "8.8.8.8")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddressFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := resolver.AddressResolver{
|
||||
Resolver: resolver.FakeResolver{
|
||||
Err: expected,
|
||||
},
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
var privateIPBlocks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"0.0.0.0/8", // "This" network (however, Linux...)
|
||||
"10.0.0.0/8", // RFC1918
|
||||
"100.64.0.0/10", // Carrier grade NAT
|
||||
"127.0.0.0/8", // IPv4 loopback
|
||||
"169.254.0.0/16", // RFC3927 link-local
|
||||
"172.16.0.0/12", // RFC1918
|
||||
"192.168.0.0/16", // RFC1918
|
||||
"224.0.0.0/4", // Multicast
|
||||
"::1/128", // IPv6 loopback
|
||||
"fe80::/10", // IPv6 link-local
|
||||
"fc00::/7", // IPv6 unique local addr
|
||||
} {
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
runtimex.PanicOnError(err, "net.ParseCIDR failed")
|
||||
privateIPBlocks = append(privateIPBlocks, block)
|
||||
}
|
||||
}
|
||||
|
||||
func isPrivate(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
for _, block := range privateIPBlocks {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsBogon returns whether if an IP address is bogon. Passing to this
|
||||
// function a non-IP address causes it to return bogon.
|
||||
func IsBogon(address string) bool {
|
||||
ip := net.ParseIP(address)
|
||||
return ip == nil || isPrivate(ip)
|
||||
}
|
||||
|
||||
// BogonResolver is a bogon aware resolver. When a bogon is encountered in
|
||||
// a reply, this resolver will return an error.
|
||||
type BogonResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r BogonResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
for _, addr := range addrs {
|
||||
if IsBogon(addr) == true {
|
||||
// We need to return the addrs otherwise the caller cannot see/log/save
|
||||
// the specific addresses that triggered our bogon filter
|
||||
return addrs, errorx.ErrDNSBogon
|
||||
}
|
||||
}
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ Resolver = BogonResolver{}
|
||||
@@ -0,0 +1,52 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestResolverIsBogon(t *testing.T) {
|
||||
if resolver.IsBogon("antani") != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if resolver.IsBogon("127.0.0.1") != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if resolver.IsBogon("1.1.1.1") != false {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if resolver.IsBogon("10.0.1.1") != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBogonAwareResolverWithBogon(t *testing.T) {
|
||||
r := resolver.BogonResolver{
|
||||
Resolver: resolver.NewFakeResolverWithResult([]string{"127.0.0.1"}),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if !errors.Is(err, errorx.ErrDNSBogon) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "127.0.0.1" {
|
||||
t.Fatal("expected to see address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBogonAwareResolverWithoutBogon(t *testing.T) {
|
||||
orig := []string{"8.8.8.8"}
|
||||
r := resolver.BogonResolver{
|
||||
Resolver: resolver.NewFakeResolverWithResult(orig),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != len(orig) || addrs[0] != orig[0] {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CacheResolver is a resolver that caches successful replies.
|
||||
type CacheResolver struct {
|
||||
ReadOnly bool
|
||||
Resolver
|
||||
mu sync.Mutex
|
||||
cache map[string][]string
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r *CacheResolver) LookupHost(
|
||||
ctx context.Context, hostname string) ([]string, error) {
|
||||
if entry := r.Get(hostname); entry != nil {
|
||||
return entry, nil
|
||||
}
|
||||
entry, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.ReadOnly == false {
|
||||
r.Set(hostname, entry)
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Get gets the currently configured entry for domain, or nil
|
||||
func (r *CacheResolver) Get(domain string) []string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.cache[domain]
|
||||
}
|
||||
|
||||
// Set allows to pre-populate the cache
|
||||
func (r *CacheResolver) Set(domain string, addresses []string) {
|
||||
r.mu.Lock()
|
||||
if r.cache == nil {
|
||||
r.cache = make(map[string][]string)
|
||||
}
|
||||
r.cache[domain] = addresses
|
||||
r.mu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestCacheFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Err: expected,
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs here")
|
||||
}
|
||||
if cache.Get("www.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheHitSuccess(t *testing.T) {
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Err: errors.New("mocked error"),
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r}
|
||||
cache.Set("dns.google.com", []string{"8.8.8.8"})
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMissSuccess(t *testing.T) {
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Result: []string{"8.8.8.8"},
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.Get("dns.google.com")[0] != "8.8.8.8" {
|
||||
t.Fatal("expected full cache here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheReadonlySuccess(t *testing.T) {
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Result: []string{"8.8.8.8"},
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r, ReadOnly: true}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.Get("dns.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// ChainResolver is a chain resolver. The primary resolver is used first and, if that
|
||||
// fails, we then attempt with the secondary resolver.
|
||||
type ChainResolver struct {
|
||||
Primary Resolver
|
||||
Secondary Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (c ChainResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := c.Primary.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
addrs, err = c.Secondary.LookupHost(ctx, hostname)
|
||||
}
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (c ChainResolver) Network() string {
|
||||
return "chain"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (c ChainResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ Resolver = ChainResolver{}
|
||||
@@ -0,0 +1,28 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestChainLookupHost(t *testing.T) {
|
||||
r := resolver.ChainResolver{
|
||||
Primary: resolver.NewFakeResolverThatFails(),
|
||||
Secondary: resolver.SystemResolver{},
|
||||
}
|
||||
if r.Address() != "" {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
if r.Network() != "chain" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expect non nil return value here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The Decoder decodes a DNS reply into A or AAAA entries. It will use the
|
||||
// provided qtype and only look for mathing entries. It will return error if
|
||||
// there are no entries for the requested qtype inside the reply.
|
||||
type Decoder interface {
|
||||
Decode(qtype uint16, data []byte) ([]string, error)
|
||||
}
|
||||
|
||||
// MiekgDecoder uses github.com/miekg/dns to implement the Decoder.
|
||||
type MiekgDecoder struct{}
|
||||
|
||||
// Decode implements Decoder.Decode.
|
||||
func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
||||
reply := new(dns.Msg)
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bassosimone): map more errors to net.DNSError names
|
||||
switch reply.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
case dns.RcodeNameError:
|
||||
return nil, errors.New("ooniresolver: no such host")
|
||||
default:
|
||||
return nil, errors.New("ooniresolver: query failed")
|
||||
}
|
||||
var addrs []string
|
||||
for _, answer := range reply.Answer {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if rra, ok := answer.(*dns.AAAA); ok {
|
||||
ip := rra.AAAA
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(addrs) <= 0 {
|
||||
return nil, errors.New("ooniresolver: no response returned")
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
var _ Decoder = MiekgDecoder{}
|
||||
@@ -0,0 +1,113 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDecoderUnpackError(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNXDOMAIN(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeNameError))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderOtherError(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeRefused))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNoAddress(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeA(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if data[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeAAAA(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if data[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAReply(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
|
||||
)
|
||||
|
||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||
type DNSOverHTTPS struct {
|
||||
Do func(req *http.Request) (*http.Response, error)
|
||||
URL string
|
||||
HostOverride string
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||
// specified http.Client and URL, as a convenience.
|
||||
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
|
||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||
// it's creating a resolver where we use the specified host.
|
||||
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
|
||||
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = t.HostOverride
|
||||
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
// TODO(bassosimone): we should map the status code to a
|
||||
// proper Error in the DNS context.
|
||||
return nil, errors.New("doh: server returned error")
|
||||
}
|
||||
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||
return nil, errors.New("doh: invalid content-type")
|
||||
}
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467
|
||||
func (t DNSOverHTTPS) RequiresPadding() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverHTTPS) Network() string {
|
||||
return "doh"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverHTTPS) Address() string {
|
||||
return t.URL
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverHTTPS{}
|
||||
@@ -0,0 +1,165 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||
body := []byte("AAA")
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, body) {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
if txp.Network() != "doh" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("should require padding")
|
||||
}
|
||||
if txp.Address() != queryURL {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var correct bool
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||
return nil, expected
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct user agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHostOverride(t *testing.T) {
|
||||
var correct bool
|
||||
expected := errors.New("mocked error")
|
||||
|
||||
hostOverride := "test.com"
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Host == hostOverride
|
||||
return nil, expected
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct host override")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialContextFunc is a generic function for dialing a connection.
|
||||
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
|
||||
// and NewDNSOverTLS to create specific instances that use plaintext
|
||||
// queries or encrypted queries over TLS.
|
||||
//
|
||||
// As a known bug, this implementation always creates a new connection
|
||||
// for each incoming query, thus increasing the response delay.
|
||||
type DNSOverTCP struct {
|
||||
dial DialContextFunc
|
||||
address string
|
||||
network string
|
||||
requiresPadding bool
|
||||
}
|
||||
|
||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
|
||||
return DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
requiresPadding: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
|
||||
return DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "dot",
|
||||
requiresPadding: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
}
|
||||
conn, err := t.dial(ctx, "tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write request
|
||||
buf := []byte{byte(len(query) >> 8)}
|
||||
buf = append(buf, byte(len(query)))
|
||||
buf = append(buf, query...)
|
||||
if _, err = conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
header := make([]byte, 2)
|
||||
if _, err = io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := int(header[0])<<8 | int(header[1])
|
||||
reply := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
// according to RFC8467.
|
||||
func (t DNSOverTCP) RequiresPadding() bool {
|
||||
return t.requiresPadding
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverTCP) Network() string {
|
||||
return t.network
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverTCP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverTCP{}
|
||||
@@ -0,0 +1,146 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Err: mocked}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
WriteError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(2)},
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(1), byte(1)},
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 || reply[0] != 1 {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "tcp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTLSTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "dot" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer is the network dialer interface assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// DNSOverUDP is a DNS over UDP RoundTripper.
|
||||
type DNSOverUDP struct {
|
||||
dialer Dialer
|
||||
address string
|
||||
}
|
||||
|
||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
|
||||
return DNSOverUDP{dialer: dialer, address: address}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
// Use five seconds timeout like Bionic does. See
|
||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply := make([]byte, 1<<17)
|
||||
var n int
|
||||
n, err = conn.Read(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply[:n], nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467
|
||||
func (t DNSOverUDP) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverUDP) Network() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverUDP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverUDP{}
|
||||
@@ -0,0 +1,107 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
WriteError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||
const expected = 17
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != expected {
|
||||
t.Fatal("expected non nil data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "udp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
// EmitterTransport is a RoundTripper that emits events when they occur.
|
||||
type EmitterTransport struct {
|
||||
RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp EmitterTransport) RoundTrip(ctx context.Context, querydata []byte) ([]byte, error) {
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
DNSQuery: &modelx.DNSQueryEvent{
|
||||
Data: querydata,
|
||||
DialID: dialid.ContextDialID(ctx),
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
},
|
||||
})
|
||||
replydata, err := txp.RoundTripper.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
DNSReply: &modelx.DNSReplyEvent{
|
||||
Data: replydata,
|
||||
DialID: dialid.ContextDialID(ctx),
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
},
|
||||
})
|
||||
return replydata, nil
|
||||
}
|
||||
|
||||
// EmitterResolver is a resolver that emits events
|
||||
type EmitterResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r EmitterResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var (
|
||||
network string
|
||||
address string
|
||||
)
|
||||
type queryableResolver interface {
|
||||
Transport() RoundTripper
|
||||
}
|
||||
if qr, ok := r.Resolver.(queryableResolver); ok {
|
||||
txp := qr.Transport()
|
||||
network, address = txp.Network(), txp.Address()
|
||||
}
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
txID := transactionid.ContextTransactionID(ctx)
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
ResolveStart: &modelx.ResolveStartEvent{
|
||||
DialID: dialID,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Hostname: hostname,
|
||||
TransactionID: txID,
|
||||
TransportAddress: address,
|
||||
TransportNetwork: network,
|
||||
},
|
||||
})
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
ResolveDone: &modelx.ResolveDoneEvent{
|
||||
Addresses: addrs,
|
||||
DialID: dialID,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Error: err,
|
||||
Hostname: hostname,
|
||||
TransactionID: txID,
|
||||
TransportAddress: address,
|
||||
TransportNetwork: network,
|
||||
},
|
||||
})
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ RoundTripper = EmitterTransport{}
|
||||
var _ Resolver = EmitterResolver{}
|
||||
@@ -0,0 +1,220 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestEmitterTransportSuccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
|
||||
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
|
||||
}}
|
||||
e := resolver.MiekgEncoder{}
|
||||
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
replydata, err := txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[0].DNSQuery == nil {
|
||||
t.Fatal("missing DNSQuery field")
|
||||
}
|
||||
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
|
||||
t.Fatal("invalid query data")
|
||||
}
|
||||
if events[0].DNSQuery.DialID == 0 {
|
||||
t.Fatal("invalid query DialID")
|
||||
}
|
||||
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
if events[1].DNSReply == nil {
|
||||
t.Fatal("missing DNSReply field")
|
||||
}
|
||||
if !bytes.Equal(events[1].DNSReply.Data, replydata) {
|
||||
t.Fatal("missing reply data")
|
||||
}
|
||||
if events[1].DNSReply.DialID != 1 {
|
||||
t.Fatal("invalid query DialID")
|
||||
}
|
||||
if events[1].DNSReply.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterTransportFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
|
||||
Err: mocked,
|
||||
}}
|
||||
e := resolver.MiekgEncoder{}
|
||||
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
replydata, err := txp.RoundTrip(ctx, querydata)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if replydata != nil {
|
||||
t.Fatal("expected nil replydata")
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[0].DNSQuery == nil {
|
||||
t.Fatal("missing DNSQuery field")
|
||||
}
|
||||
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
|
||||
t.Fatal("invalid query data")
|
||||
}
|
||||
if events[0].DNSQuery.DialID == 0 {
|
||||
t.Fatal("invalid query DialID")
|
||||
}
|
||||
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterResolverFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
|
||||
resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
URL: "https://dns.google.com/",
|
||||
},
|
||||
)}
|
||||
replies, err := r.LookupHost(ctx, "www.google.com")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if replies != nil {
|
||||
t.Fatal("expected nil replies")
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[0].ResolveStart == nil {
|
||||
t.Fatal("missing ResolveStart field")
|
||||
}
|
||||
if events[0].ResolveStart.DialID == 0 {
|
||||
t.Fatal("invalid DialID")
|
||||
}
|
||||
if events[0].ResolveStart.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
if events[0].ResolveStart.Hostname != "www.google.com" {
|
||||
t.Fatal("invalid Hostname")
|
||||
}
|
||||
if events[0].ResolveStart.TransactionID == 0 {
|
||||
t.Fatal("invalid TransactionID")
|
||||
}
|
||||
if events[0].ResolveStart.TransportAddress != "https://dns.google.com/" {
|
||||
t.Fatal("invalid TransportAddress")
|
||||
}
|
||||
if events[0].ResolveStart.TransportNetwork != "doh" {
|
||||
t.Fatal("invalid TransportNetwork")
|
||||
}
|
||||
if events[1].ResolveDone == nil {
|
||||
t.Fatal("missing ResolveDone field")
|
||||
}
|
||||
if events[1].ResolveDone.DialID == 0 {
|
||||
t.Fatal("invalid DialID")
|
||||
}
|
||||
if events[1].ResolveDone.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
if events[1].ResolveDone.Error != io.EOF {
|
||||
t.Fatal("invalid Error")
|
||||
}
|
||||
if events[1].ResolveDone.Hostname != "www.google.com" {
|
||||
t.Fatal("invalid Hostname")
|
||||
}
|
||||
if events[1].ResolveDone.TransactionID == 0 {
|
||||
t.Fatal("invalid TransactionID")
|
||||
}
|
||||
if events[1].ResolveDone.TransportAddress != "https://dns.google.com/" {
|
||||
t.Fatal("invalid TransportAddress")
|
||||
}
|
||||
if events[1].ResolveDone.TransportNetwork != "doh" {
|
||||
t.Fatal("invalid TransportNetwork")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterResolverSuccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
r := resolver.EmitterResolver{Resolver: resolver.NewFakeResolverWithResult(
|
||||
[]string{"8.8.8.8"},
|
||||
)}
|
||||
replies, err := r.LookupHost(ctx, "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(replies) != 1 {
|
||||
t.Fatal("expected a single replies")
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[1].ResolveDone == nil {
|
||||
t.Fatal("missing ResolveDone field")
|
||||
}
|
||||
if events[1].ResolveDone.Addresses[0] != "8.8.8.8" {
|
||||
t.Fatal("invalid Addresses")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package resolver
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// The Encoder encodes DNS queries to bytes
|
||||
type Encoder interface {
|
||||
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// MiekgEncoder uses github.com/miekg/dns to implement the Encoder.
|
||||
type MiekgEncoder struct{}
|
||||
|
||||
const (
|
||||
// PaddingDesiredBlockSize is the size that the padded query should be multiple of
|
||||
PaddingDesiredBlockSize = 128
|
||||
|
||||
// EDNS0MaxResponseSize is the maximum response size for EDNS0
|
||||
EDNS0MaxResponseSize = 4096
|
||||
|
||||
// DNSSECEnabled turns on support for DNSSEC when using EDNS0
|
||||
DNSSECEnabled = true
|
||||
)
|
||||
|
||||
// Encode implements Encoder.Encode
|
||||
func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
if padding {
|
||||
query.SetEdns0(EDNS0MaxResponseSize, DNSSECEnabled)
|
||||
// Clients SHOULD pad queries to the closest multiple of
|
||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||
// length by the size of the option (i.e. 4 octets). The
|
||||
// cast to uint is necessary to make the modulus operation
|
||||
// work as intended when the desiredBlockSize is smaller
|
||||
// than (query.Len()+4) ¯\_(ツ)_/¯.
|
||||
remainder := (PaddingDesiredBlockSize - uint(query.Len()+4)) % PaddingDesiredBlockSize
|
||||
opt := new(dns.EDNS0_PADDING)
|
||||
opt.Padding = make([]byte, remainder)
|
||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||
}
|
||||
return query.Pack()
|
||||
}
|
||||
|
||||
var _ Encoder = MiekgEncoder{}
|
||||
@@ -0,0 +1,99 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestEncoderEncodeA(t *testing.T) {
|
||||
e := resolver.MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func TestEncoderEncodeAAAA(t *testing.T) {
|
||||
e := resolver.MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func validate(t *testing.T, data []byte, qtype byte) {
|
||||
// skipping over the query ID
|
||||
if data[2] != 1 {
|
||||
t.Fatal("FLAGS should only have RD set")
|
||||
}
|
||||
if data[3] != 0 {
|
||||
t.Fatal("RA|Z|Rcode should be zero")
|
||||
}
|
||||
if data[4] != 0 || data[5] != 1 {
|
||||
t.Fatal("QCOUNT high should be one")
|
||||
}
|
||||
if data[6] != 0 || data[7] != 0 {
|
||||
t.Fatal("ANCOUNT should be zero")
|
||||
}
|
||||
if data[8] != 0 || data[9] != 0 {
|
||||
t.Fatal("NSCOUNT should be zero")
|
||||
}
|
||||
if data[10] != 0 || data[11] != 0 {
|
||||
t.Fatal("ARCOUNT should be zero")
|
||||
}
|
||||
t.Log(data[12])
|
||||
if data[12] != 1 || data[13] != byte('x') {
|
||||
t.Fatal("The name does not contain 1:x")
|
||||
}
|
||||
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
|
||||
t.Fatal("The name does not containg 3:org")
|
||||
}
|
||||
if data[18] != 0 {
|
||||
t.Fatal("The name does not terminate where expected")
|
||||
}
|
||||
if data[19] != 0 && data[20] != qtype {
|
||||
t.Fatal("The query is not for the expected type")
|
||||
}
|
||||
if data[21] != 0 && data[22] != 1 {
|
||||
t.Fatal("The query is not IN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderPadding(t *testing.T) {
|
||||
// The purpose of this unit test is to make sure that for a wide
|
||||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := resolver.MiekgEncoder{}
|
||||
data, err := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
// us to generate such name and we are not going to send
|
||||
// it on the wire. Also, we check below that the query that
|
||||
// we generate is long enough, so we should be good.
|
||||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||
dns.TypeA, padding,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
for domainlen := 1; domainlen <= 4000; domainlen++ {
|
||||
vanillalen := getquerylen(domainlen, false)
|
||||
paddedlen := getquerylen(domainlen, true)
|
||||
if vanillalen < domainlen {
|
||||
t.Fatal("vanillalen is smaller than domainlen")
|
||||
}
|
||||
if (paddedlen % resolver.PaddingDesiredBlockSize) != 0 {
|
||||
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
|
||||
}
|
||||
if paddedlen < vanillalen {
|
||||
t.Fatal("paddedlen is smaller than vanillalen")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// ErrorWrapperResolver is a Resolver that knows about wrapping errors.
|
||||
type ErrorWrapperResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r ErrorWrapperResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
txID := transactionid.ContextTransactionID(ctx)
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
DialID: dialID,
|
||||
Error: err,
|
||||
Operation: errorx.ResolveOperation,
|
||||
TransactionID: txID,
|
||||
}.MaybeBuild()
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ Resolver = ErrorWrapperResolver{}
|
||||
@@ -0,0 +1,58 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestErrorWrapperSuccess(t *testing.T) {
|
||||
orig := []string{"8.8.8.8"}
|
||||
r := resolver.ErrorWrapperResolver{
|
||||
Resolver: resolver.NewFakeResolverWithResult(orig),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != len(orig) || addrs[0] != orig[0] {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapperFailure(t *testing.T) {
|
||||
r := resolver.ErrorWrapperResolver{
|
||||
Resolver: resolver.NewFakeResolverThatFails(),
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
addrs, err := r.LookupHost(ctx, "dns.google.com")
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addr here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot properly cast the returned error")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureDNSNXDOMAINError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
if errWrapper.ConnID != 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if errWrapper.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
if errWrapper.TransactionID == 0 {
|
||||
t.Fatal("unexpected TransactionID")
|
||||
}
|
||||
if errWrapper.Operation != errorx.ResolveOperation {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
)
|
||||
|
||||
type FakeDialer struct {
|
||||
Conn net.Conn
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return d.Conn, d.Err
|
||||
}
|
||||
|
||||
type FakeConn struct {
|
||||
ReadError error
|
||||
ReadData []byte
|
||||
SetDeadlineError error
|
||||
SetReadDeadlineError error
|
||||
SetWriteDeadlineError error
|
||||
WriteError error
|
||||
}
|
||||
|
||||
func (c *FakeConn) Read(b []byte) (int, error) {
|
||||
if len(c.ReadData) > 0 {
|
||||
n := copy(b, c.ReadData)
|
||||
c.ReadData = c.ReadData[n:]
|
||||
return n, nil
|
||||
}
|
||||
if c.ReadError != nil {
|
||||
return 0, c.ReadError
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *FakeConn) Write(b []byte) (n int, err error) {
|
||||
if c.WriteError != nil {
|
||||
return 0, c.WriteError
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) Close() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (*FakeConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
|
||||
return c.SetDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
|
||||
return c.SetReadDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
|
||||
return c.SetWriteDeadlineError
|
||||
}
|
||||
|
||||
type FakeTransport struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
return ft.Data, ft.Err
|
||||
}
|
||||
|
||||
func (ft FakeTransport) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (ft FakeTransport) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ft FakeTransport) Network() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
type FakeEncoder struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return fe.Data, fe.Err
|
||||
}
|
||||
|
||||
type FakeResolver struct {
|
||||
NumFailures *atomicx.Int64
|
||||
Err error
|
||||
Result []string
|
||||
}
|
||||
|
||||
func NewFakeResolverThatFails() FakeResolver {
|
||||
return FakeResolver{NumFailures: atomicx.NewInt64(), Err: errNotFound}
|
||||
}
|
||||
|
||||
func NewFakeResolverWithResult(r []string) FakeResolver {
|
||||
return FakeResolver{NumFailures: atomicx.NewInt64(), Result: r}
|
||||
}
|
||||
|
||||
var errNotFound = &net.DNSError{
|
||||
Err: "no such host",
|
||||
}
|
||||
|
||||
func (c FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if c.Err != nil {
|
||||
if c.NumFailures != nil {
|
||||
c.NumFailures.Add(1)
|
||||
}
|
||||
return nil, c.Err
|
||||
}
|
||||
return c.Result, nil
|
||||
}
|
||||
|
||||
func (c FakeResolver) Network() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (c FakeResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ Resolver = FakeResolver{}
|
||||
@@ -0,0 +1,76 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func GenReplyError(t *testing.T, code int) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetRcode(query, code)
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func GenReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
for _, ip := range ips {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
reply.Answer = append(reply.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: net.ParseIP(ip),
|
||||
})
|
||||
case dns.TypeAAAA:
|
||||
reply.Answer = append(reply.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
AAAA: net.ParseIP(ip),
|
||||
})
|
||||
}
|
||||
}
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// IDNAResolver is to support resolving Internationalized Domain Names.
|
||||
// See RFC3492 for more information.
|
||||
type IDNAResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r IDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
host, err := idna.ToASCII(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Resolver.LookupHost(ctx, host)
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r IDNAResolver) Network() string {
|
||||
return "idna"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address.
|
||||
func (r IDNAResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ Resolver = IDNAResolver{}
|
||||
@@ -0,0 +1,76 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
var ErrUnexpectedPunycode = errors.New("unexpected punycode value")
|
||||
|
||||
type CheckIDNAResolver struct {
|
||||
Addresses []string
|
||||
Error error
|
||||
Expect string
|
||||
}
|
||||
|
||||
func (resolv CheckIDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if resolv.Error != nil {
|
||||
return nil, resolv.Error
|
||||
}
|
||||
if hostname != resolv.Expect {
|
||||
return nil, ErrUnexpectedPunycode
|
||||
}
|
||||
return resolv.Addresses, nil
|
||||
}
|
||||
|
||||
func (r CheckIDNAResolver) Network() string {
|
||||
return "checkidna"
|
||||
}
|
||||
|
||||
func (r CheckIDNAResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestIDNAResolverSuccess(t *testing.T) {
|
||||
expectedIPs := []string{"77.88.55.66"}
|
||||
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
|
||||
Addresses: expectedIPs,
|
||||
Expect: "xn--d1acpjx3f.xn--p1ai",
|
||||
}}
|
||||
addrs, err := resolv.LookupHost(context.Background(), "яндекс.рф")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDNAResolverFailure(t *testing.T) {
|
||||
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
|
||||
Error: errors.New("we should not arrive here"),
|
||||
}}
|
||||
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
|
||||
addrs, err := resolv.LookupHost(context.Background(), "xn--0000h")
|
||||
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDNAResolverTransportOK(t *testing.T) {
|
||||
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{}}
|
||||
if resolv.Network() != "idna" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if resolv.Address() != "" {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
func testresolverquick(t *testing.T, reso resolver.Resolver) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
reso = resolver.LoggingResolver{Logger: log.Log, Resolver: reso}
|
||||
addrs, err := reso.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil addrs here")
|
||||
}
|
||||
var foundquad8 bool
|
||||
for _, addr := range addrs {
|
||||
// See https://github.com/ooni/probe-cli/v3/internal/engine/pull/954/checks?check_run_id=1182269025
|
||||
if addr == "8.8.8.8" || addr == "2001:4860:4860::8888" {
|
||||
foundquad8 = true
|
||||
}
|
||||
}
|
||||
if !foundquad8 {
|
||||
t.Fatalf("did not find 8.8.8.8 in ouput; output=%+v", addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensuring we can handle Internationalized Domain Names (IDNs) without issues
|
||||
func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
reso = resolver.IDNAResolver{
|
||||
resolver.LoggingResolver{Logger: log.Log, Resolver: reso},
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "яндекс.рф")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil addrs here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolverSystem(t *testing.T) {
|
||||
reso := resolver.SystemResolver{}
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverUDPAddress(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverUDP(new(net.Dialer), "8.8.8.8:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverUDPDomain(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverUDP(new(net.Dialer), "dns.google.com:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverTCPAddress(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "8.8.8.8:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverTCPDomain(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "dns.google.com:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoTAddress(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoTDomain(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTLS(resolver.DialTLSContext, "dns.google.com:853"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoH(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverHTTPS(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger is the logger assumed by this package
|
||||
type Logger interface {
|
||||
Debugf(format string, v ...interface{})
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
// LoggingResolver is a resolver that emits events
|
||||
type LoggingResolver struct {
|
||||
Resolver
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r LoggingResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
r.Logger.Debugf("resolve %s...", hostname)
|
||||
start := time.Now()
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
stop := time.Now()
|
||||
r.Logger.Debugf("resolve %s... (%+v, %+v) in %s", hostname, addrs, err, stop.Sub(start))
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ Resolver = LoggingResolver{}
|
||||
@@ -0,0 +1,23 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestLoggingResolver(t *testing.T) {
|
||||
r := resolver.LoggingResolver{
|
||||
Logger: log.Log,
|
||||
Resolver: resolver.NewFakeResolverThatFails(),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.google.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addr here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Resolver is a DNS resolver. The *net.Resolver used by Go implements
|
||||
// this interface, but other implementations are possible.
|
||||
type Resolver interface {
|
||||
// LookupHost resolves a hostname to a list of IP addresses.
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
|
||||
// Network returns the network being used by the resolver
|
||||
Network() string
|
||||
|
||||
// Address returns the address being used by the resolver
|
||||
Address() string
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// SaverResolver is a resolver that saves events
|
||||
type SaverResolver struct {
|
||||
Resolver
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
start := time.Now()
|
||||
r.Saver.Write(trace.Event{
|
||||
Address: r.Resolver.Address(),
|
||||
Hostname: hostname,
|
||||
Name: "resolve_start",
|
||||
Proto: r.Resolver.Network(),
|
||||
Time: start,
|
||||
})
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
stop := time.Now()
|
||||
r.Saver.Write(trace.Event{
|
||||
Addresses: addrs,
|
||||
Address: r.Resolver.Address(),
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Hostname: hostname,
|
||||
Name: "resolve_done",
|
||||
Proto: r.Resolver.Network(),
|
||||
Time: stop,
|
||||
})
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
// SaverDNSTransport is a DNS transport that saves events
|
||||
type SaverDNSTransport struct {
|
||||
RoundTripper
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
start := time.Now()
|
||||
txp.Saver.Write(trace.Event{
|
||||
Address: txp.Address(),
|
||||
DNSQuery: query,
|
||||
Name: "dns_round_trip_start",
|
||||
Proto: txp.Network(),
|
||||
Time: start,
|
||||
})
|
||||
reply, err := txp.RoundTripper.RoundTrip(ctx, query)
|
||||
stop := time.Now()
|
||||
txp.Saver.Write(trace.Event{
|
||||
Address: txp.Address(),
|
||||
DNSQuery: query,
|
||||
DNSReply: reply,
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: "dns_round_trip_done",
|
||||
Proto: txp.Network(),
|
||||
Time: stop,
|
||||
})
|
||||
return reply, err
|
||||
}
|
||||
|
||||
var _ Resolver = SaverResolver{}
|
||||
var _ RoundTripper = SaverDNSTransport{}
|
||||
@@ -0,0 +1,211 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSaverResolverFailure(t *testing.T) {
|
||||
expected := errors.New("no such host")
|
||||
saver := &trace.Saver{}
|
||||
reso := resolver.SaverResolver{
|
||||
Resolver: resolver.FakeResolver{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if ev[0].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[0].Name != "resolve_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if ev[1].Addresses != nil {
|
||||
t.Fatal("unexpected Addresses")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if !errors.Is(ev[1].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[1].Name != "resolve_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverResolverSuccess(t *testing.T) {
|
||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||
saver := &trace.Saver{}
|
||||
reso := resolver.SaverResolver{
|
||||
Resolver: resolver.FakeResolver{
|
||||
Result: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal("expected nil error here")
|
||||
}
|
||||
if !reflect.DeepEqual(addrs, expected) {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if ev[0].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[0].Name != "resolve_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[1].Addresses, expected) {
|
||||
t.Fatal("unexpected Addresses")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[1].Name != "resolve_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverDNSTransportFailure(t *testing.T) {
|
||||
expected := errors.New("no such host")
|
||||
saver := &trace.Saver{}
|
||||
txp := resolver.SaverDNSTransport{
|
||||
RoundTripper: resolver.FakeTransport{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
query := []byte("abc")
|
||||
reply, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[0].Name != "dns_round_trip_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[1].DNSReply != nil {
|
||||
t.Fatal("unexpected DNSReply")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if !errors.Is(ev[1].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "dns_round_trip_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverDNSTransportSuccess(t *testing.T) {
|
||||
expected := []byte("def")
|
||||
saver := &trace.Saver{}
|
||||
txp := resolver.SaverDNSTransport{
|
||||
RoundTripper: resolver.FakeTransport{
|
||||
Data: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
query := []byte("abc")
|
||||
reply, err := txp.RoundTrip(context.Background(), query)
|
||||
if err != nil {
|
||||
t.Fatal("we expected nil error here")
|
||||
}
|
||||
if !bytes.Equal(reply, expected) {
|
||||
t.Fatal("expected another reply here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[0].Name != "dns_round_trip_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSReply, expected) {
|
||||
t.Fatal("unexpected DNSReply")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "dns_round_trip_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
)
|
||||
|
||||
// RoundTripper represents an abstract DNS transport.
|
||||
type RoundTripper interface {
|
||||
// RoundTrip sends a DNS query and receives the reply.
|
||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
// RequiresPadding return true for DoH and DoT according to RFC8467
|
||||
RequiresPadding() bool
|
||||
|
||||
// Network is the network of the round tripper (e.g. "dot")
|
||||
Network() string
|
||||
|
||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||
Address() string
|
||||
}
|
||||
|
||||
// SerialResolver is a resolver that first issues an A query and then
|
||||
// issues an AAAA query for the requested domain.
|
||||
type SerialResolver struct {
|
||||
Encoder Encoder
|
||||
Decoder Decoder
|
||||
NumTimeouts *atomicx.Int64
|
||||
Txp RoundTripper
|
||||
}
|
||||
|
||||
// NewSerialResolver creates a new OONI Resolver instance.
|
||||
func NewSerialResolver(t RoundTripper) SerialResolver {
|
||||
return SerialResolver{
|
||||
Encoder: MiekgEncoder{},
|
||||
Decoder: MiekgDecoder{},
|
||||
NumTimeouts: atomicx.NewInt64(),
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the transport being used.
|
||||
func (r SerialResolver) Transport() RoundTripper {
|
||||
return r.Txp
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r SerialResolver) Network() string {
|
||||
return r.Txp.Network()
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (r SerialResolver) Address() string {
|
||||
return r.Txp.Address()
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var addrs []string
|
||||
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
|
||||
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
|
||||
if errA != nil && errAAAA != nil {
|
||||
return nil, errA
|
||||
}
|
||||
addrs = append(addrs, addrsA...)
|
||||
addrs = append(addrs, addrsAAAA...)
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (r SerialResolver) roundTripWithRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
var errorslist []error
|
||||
for i := 0; i < 3; i++ {
|
||||
replies, err := r.roundTrip(ctx, hostname, qtype)
|
||||
if err == nil {
|
||||
return replies, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
var operr *net.OpError
|
||||
if errors.As(err, &operr) == false || operr.Timeout() == false {
|
||||
// The first error is the one that is most likely to be caused
|
||||
// by the network. Subsequent errors are more likely to be caused
|
||||
// by context deadlines. So, the first error is attached to an
|
||||
// operation, while subsequent errors may possibly not be. If
|
||||
// so, the resulting failing operation is not correct.
|
||||
break
|
||||
}
|
||||
r.NumTimeouts.Add(1)
|
||||
}
|
||||
// bugfix: we MUST return one of the errors otherwise we confuse the
|
||||
// mechanism in errwrap that classifies the root cause operation, since
|
||||
// it would not be able to find a child with a major operation error
|
||||
return nil, errorslist[0]
|
||||
}
|
||||
|
||||
func (r SerialResolver) roundTrip(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.Decode(qtype, replydata)
|
||||
}
|
||||
|
||||
var _ Resolver = SerialResolver{}
|
||||
@@ -0,0 +1,111 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestOONIGettingTransport(t *testing.T) {
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853")
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
rtx := r.Transport()
|
||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||
t.Fatal("not the transport we expected")
|
||||
}
|
||||
if r.Network() != rtx.Network() {
|
||||
t.Fatal("invalid network seen from the resolver")
|
||||
}
|
||||
if r.Address() != rtx.Address() {
|
||||
t.Fatal("invalid address seen from the resolver")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIEncodeError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853")
|
||||
r := resolver.SerialResolver{Encoder: resolver.FakeEncoder{Err: mocked}, Txp: txp}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIRoundTripError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.FakeTransport{Err: mocked}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithEmptyReply(t *testing.T) {
|
||||
txp := resolver.FakeTransport{Data: resolver.GenReplySuccess(t, dns.TypeA)}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithAReply(t *testing.T) {
|
||||
txp := resolver.FakeTransport{
|
||||
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
|
||||
}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithAAAAReply(t *testing.T) {
|
||||
txp := resolver.FakeTransport{
|
||||
Data: resolver.GenReplySuccess(t, dns.TypeAAAA, "::1"),
|
||||
}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "::1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOONIWithTimeout(t *testing.T) {
|
||||
txp := resolver.FakeTransport{
|
||||
Err: &net.OpError{Err: syscall.ETIMEDOUT, Op: "dial"},
|
||||
}
|
||||
r := resolver.NewSerialResolver(txp)
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, syscall.ETIMEDOUT) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
if r.NumTimeouts.Load() <= 0 {
|
||||
t.Fatal("we didn't actually take the timeouts")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package resolver
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor"
|
||||
|
||||
// SystemResolver is the system resolver. It is implemented using
|
||||
// selfcensor.SystemResolver so that we can perform integration testing
|
||||
// by forcing the code to return specific responses.
|
||||
type SystemResolver = selfcensor.SystemResolver
|
||||
|
||||
var _ Resolver = SystemResolver{}
|
||||
@@ -0,0 +1,25 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestSystemResolverLookupHost(t *testing.T) {
|
||||
r := resolver.SystemResolver{}
|
||||
if r.Network() != "system" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if r.Address() != "" {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil result here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
func DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
connch := make(chan net.Conn)
|
||||
errch := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := tls.Dial(network, address, new(tls.Config))
|
||||
if err != nil {
|
||||
errch <- err
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.Close()
|
||||
case connch <- conn:
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case conn := <-connch:
|
||||
return conn, nil
|
||||
case err := <-errch:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user