chore: merge probe-engine into probe-cli (#201)

This is how I did it:

1. `git clone https://github.com/ooni/probe-engine internal/engine`

2. ```
(cd internal/engine && git describe --tags)
v0.23.0
```

3. `nvim go.mod` (merging `go.mod` with `internal/engine/go.mod`

4. `rm -rf internal/.git internal/engine/go.{mod,sum}`

5. `git add internal/engine`

6. `find . -type f -name \*.go -exec sed -i 's@/ooni/probe-engine@/ooni/probe-cli/v3/internal/engine@g' {} \;`

7. `go build ./...` (passes)

8. `go test -race ./...` (temporary failure on RiseupVPN)

9. `go mod tidy`

10. this commit message

Once this piece of work is done, we can build a new version of `ooniprobe` that
is using `internal/engine` directly. We need to do more work to ensure all the
other functionality in `probe-engine` (e.g. making mobile packages) are still WAI.

Part of https://github.com/ooni/probe/issues/1335
This commit is contained in:
Simone Basso
2021-02-02 12:05:47 +01:00
committed by GitHub
parent b1ce300c8d
commit d57c78bc71
535 changed files with 66182 additions and 23 deletions
+22
View File
@@ -0,0 +1,22 @@
package resolver
import (
"context"
"net"
)
// AddressResolver is a resolver that knows how to correctly
// resolve IP addresses to themselves.
type AddressResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r AddressResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
}
return r.Resolver.LookupHost(ctx, hostname)
}
var _ Resolver = AddressResolver{}
@@ -0,0 +1,36 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestAddressSuccess(t *testing.T) {
r := resolver.AddressResolver{}
addrs, err := r.LookupHost(context.Background(), "8.8.8.8")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
}
func TestAddressFailure(t *testing.T) {
expected := errors.New("mocked error")
r := resolver.AddressResolver{
Resolver: resolver.FakeResolver{
Err: expected,
},
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil addrs")
}
}
+71
View File
@@ -0,0 +1,71 @@
package resolver
import (
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
var privateIPBlocks []*net.IPNet
func init() {
for _, cidr := range []string{
"0.0.0.0/8", // "This" network (however, Linux...)
"10.0.0.0/8", // RFC1918
"100.64.0.0/10", // Carrier grade NAT
"127.0.0.0/8", // IPv4 loopback
"169.254.0.0/16", // RFC3927 link-local
"172.16.0.0/12", // RFC1918
"192.168.0.0/16", // RFC1918
"224.0.0.0/4", // Multicast
"::1/128", // IPv6 loopback
"fe80::/10", // IPv6 link-local
"fc00::/7", // IPv6 unique local addr
} {
_, block, err := net.ParseCIDR(cidr)
runtimex.PanicOnError(err, "net.ParseCIDR failed")
privateIPBlocks = append(privateIPBlocks, block)
}
}
func isPrivate(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
for _, block := range privateIPBlocks {
if block.Contains(ip) {
return true
}
}
return false
}
// IsBogon returns whether if an IP address is bogon. Passing to this
// function a non-IP address causes it to return bogon.
func IsBogon(address string) bool {
ip := net.ParseIP(address)
return ip == nil || isPrivate(ip)
}
// BogonResolver is a bogon aware resolver. When a bogon is encountered in
// a reply, this resolver will return an error.
type BogonResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r BogonResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
for _, addr := range addrs {
if IsBogon(addr) == true {
// We need to return the addrs otherwise the caller cannot see/log/save
// the specific addresses that triggered our bogon filter
return addrs, errorx.ErrDNSBogon
}
}
return addrs, err
}
var _ Resolver = BogonResolver{}
@@ -0,0 +1,52 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestResolverIsBogon(t *testing.T) {
if resolver.IsBogon("antani") != true {
t.Fatal("unexpected result")
}
if resolver.IsBogon("127.0.0.1") != true {
t.Fatal("unexpected result")
}
if resolver.IsBogon("1.1.1.1") != false {
t.Fatal("unexpected result")
}
if resolver.IsBogon("10.0.1.1") != true {
t.Fatal("unexpected result")
}
}
func TestBogonAwareResolverWithBogon(t *testing.T) {
r := resolver.BogonResolver{
Resolver: resolver.NewFakeResolverWithResult([]string{"127.0.0.1"}),
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if !errors.Is(err, errorx.ErrDNSBogon) {
t.Fatal("not the error we expected")
}
if len(addrs) != 1 || addrs[0] != "127.0.0.1" {
t.Fatal("expected to see address here")
}
}
func TestBogonAwareResolverWithoutBogon(t *testing.T) {
orig := []string{"8.8.8.8"}
r := resolver.BogonResolver{
Resolver: resolver.NewFakeResolverWithResult(orig),
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != len(orig) || addrs[0] != orig[0] {
t.Fatal("not the error we expected")
}
}
+47
View File
@@ -0,0 +1,47 @@
package resolver
import (
"context"
"sync"
)
// CacheResolver is a resolver that caches successful replies.
type CacheResolver struct {
ReadOnly bool
Resolver
mu sync.Mutex
cache map[string][]string
}
// LookupHost implements Resolver.LookupHost
func (r *CacheResolver) LookupHost(
ctx context.Context, hostname string) ([]string, error) {
if entry := r.Get(hostname); entry != nil {
return entry, nil
}
entry, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {
return nil, err
}
if r.ReadOnly == false {
r.Set(hostname, entry)
}
return entry, nil
}
// Get gets the currently configured entry for domain, or nil
func (r *CacheResolver) Get(domain string) []string {
r.mu.Lock()
defer r.mu.Unlock()
return r.cache[domain]
}
// Set allows to pre-populate the cache
func (r *CacheResolver) Set(domain string, addresses []string) {
r.mu.Lock()
if r.cache == nil {
r.cache = make(map[string][]string)
}
r.cache[domain] = addresses
r.mu.Unlock()
}
@@ -0,0 +1,76 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestCacheFailure(t *testing.T) {
expected := errors.New("mocked error")
var r resolver.Resolver = resolver.FakeResolver{
Err: expected,
}
cache := &resolver.CacheResolver{Resolver: r}
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil addrs here")
}
if cache.Get("www.google.com") != nil {
t.Fatal("expected empty cache here")
}
}
func TestCacheHitSuccess(t *testing.T) {
var r resolver.Resolver = resolver.FakeResolver{
Err: errors.New("mocked error"),
}
cache := &resolver.CacheResolver{Resolver: r}
cache.Set("dns.google.com", []string{"8.8.8.8"})
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
}
func TestCacheMissSuccess(t *testing.T) {
var r resolver.Resolver = resolver.FakeResolver{
Result: []string{"8.8.8.8"},
}
cache := &resolver.CacheResolver{Resolver: r}
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
if cache.Get("dns.google.com")[0] != "8.8.8.8" {
t.Fatal("expected full cache here")
}
}
func TestCacheReadonlySuccess(t *testing.T) {
var r resolver.Resolver = resolver.FakeResolver{
Result: []string{"8.8.8.8"},
}
cache := &resolver.CacheResolver{Resolver: r, ReadOnly: true}
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
if cache.Get("dns.google.com") != nil {
t.Fatal("expected empty cache here")
}
}
+33
View File
@@ -0,0 +1,33 @@
package resolver
import (
"context"
)
// ChainResolver is a chain resolver. The primary resolver is used first and, if that
// fails, we then attempt with the secondary resolver.
type ChainResolver struct {
Primary Resolver
Secondary Resolver
}
// LookupHost implements Resolver.LookupHost
func (c ChainResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := c.Primary.LookupHost(ctx, hostname)
if err != nil {
addrs, err = c.Secondary.LookupHost(ctx, hostname)
}
return addrs, err
}
// Network implements Resolver.Network
func (c ChainResolver) Network() string {
return "chain"
}
// Address implements Resolver.Address
func (c ChainResolver) Address() string {
return ""
}
var _ Resolver = ChainResolver{}
@@ -0,0 +1,28 @@
package resolver_test
import (
"context"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestChainLookupHost(t *testing.T) {
r := resolver.ChainResolver{
Primary: resolver.NewFakeResolverThatFails(),
Secondary: resolver.SystemResolver{},
}
if r.Address() != "" {
t.Fatal("invalid address")
}
if r.Network() != "chain" {
t.Fatal("invalid network")
}
addrs, err := r.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expect non nil return value here")
}
}
+54
View File
@@ -0,0 +1,54 @@
package resolver
import (
"errors"
"github.com/miekg/dns"
)
// The Decoder decodes a DNS reply into A or AAAA entries. It will use the
// provided qtype and only look for mathing entries. It will return error if
// there are no entries for the requested qtype inside the reply.
type Decoder interface {
Decode(qtype uint16, data []byte) ([]string, error)
}
// MiekgDecoder uses github.com/miekg/dns to implement the Decoder.
type MiekgDecoder struct{}
// Decode implements Decoder.Decode.
func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
reply := new(dns.Msg)
if err := reply.Unpack(data); err != nil {
return nil, err
}
// TODO(bassosimone): map more errors to net.DNSError names
switch reply.Rcode {
case dns.RcodeSuccess:
case dns.RcodeNameError:
return nil, errors.New("ooniresolver: no such host")
default:
return nil, errors.New("ooniresolver: query failed")
}
var addrs []string
for _, answer := range reply.Answer {
switch qtype {
case dns.TypeA:
if rra, ok := answer.(*dns.A); ok {
ip := rra.A
addrs = append(addrs, ip.String())
}
case dns.TypeAAAA:
if rra, ok := answer.(*dns.AAAA); ok {
ip := rra.AAAA
addrs = append(addrs, ip.String())
}
}
}
if len(addrs) <= 0 {
return nil, errors.New("ooniresolver: no response returned")
}
return addrs, nil
}
var _ Decoder = MiekgDecoder{}
@@ -0,0 +1,113 @@
package resolver_test
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDecoderUnpackError(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, nil)
if err == nil {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderNXDOMAIN(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeNameError))
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderOtherError(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeRefused))
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderNoAddress(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderDecodeA(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if data[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
}
}
func TestDecoderDecodeAAAA(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if data[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
}
}
func TestDecoderUnexpectedAReply(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
@@ -0,0 +1,77 @@
package resolver
import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
)
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
// an HTTP/HTTPS channel provided by URL using the Do function.
type DNSOverHTTPS struct {
Do func(req *http.Request) (*http.Response, error)
URL string
HostOverride string
}
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
// specified http.Client and URL, as a convenience.
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
}
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
// it's creating a resolver where we use the specified host.
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
if err != nil {
return nil, err
}
req.Host = t.HostOverride
req.Header.Set("user-agent", httpheader.UserAgent())
req.Header.Set("content-type", "application/dns-message")
var resp *http.Response
resp, err = t.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
// TODO(bassosimone): we should map the status code to a
// proper Error in the DNS context.
return nil, errors.New("doh: server returned error")
}
if resp.Header.Get("content-type") != "application/dns-message" {
return nil, errors.New("doh: invalid content-type")
}
return ioutil.ReadAll(resp.Body)
}
// RequiresPadding returns true for DoH according to RFC8467
func (t DNSOverHTTPS) RequiresPadding() bool {
return true
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverHTTPS) Network() string {
return "doh"
}
// Address returns the upstream server address.
func (t DNSOverHTTPS) Address() string {
return t.URL
}
var _ RoundTripper = DNSOverHTTPS{}
@@ -0,0 +1,165 @@
package resolver_test
import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
const invalidURL = "\t"
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
expected := errors.New("mocked error")
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return nil, expected
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
Body: ioutil.NopCloser(strings.NewReader("")),
}, nil
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}, nil
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSSuccess(t *testing.T) {
body := []byte("AAA")
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, body) {
t.Fatal("not the response we expected")
}
}
func TestDNSOverHTTPTransportOK(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query"
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
if txp.Network() != "doh" {
t.Fatal("invalid network")
}
if txp.RequiresPadding() != true {
t.Fatal("should require padding")
}
if txp.Address() != queryURL {
t.Fatal("invalid address")
}
}
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
expected := errors.New("mocked error")
var correct bool
txp := resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
return nil, expected
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct user agent")
}
}
func TestDNSOverHTTPSHostOverride(t *testing.T) {
var correct bool
expected := errors.New("mocked error")
hostOverride := "test.com"
txp := resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride
return nil, expected
},
URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride,
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct host override")
}
}
@@ -0,0 +1,97 @@
package resolver
import (
"context"
"errors"
"io"
"math"
"net"
"time"
)
// DialContextFunc is a generic function for dialing a connection.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
// and NewDNSOverTLS to create specific instances that use plaintext
// queries or encrypted queries over TLS.
//
// As a known bug, this implementation always creates a new connection
// for each incoming query, thus increasing the response delay.
type DNSOverTCP struct {
dial DialContextFunc
address string
network string
requiresPadding bool
}
// NewDNSOverTCP creates a new DNSOverTCP transport.
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
return DNSOverTCP{
dial: dial,
address: address,
network: "tcp",
requiresPadding: false,
}
}
// NewDNSOverTLS creates a new DNSOverTLS transport.
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
return DNSOverTCP{
dial: dial,
address: address,
network: "dot",
requiresPadding: true,
}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
if len(query) > math.MaxUint16 {
return nil, errors.New("query too long")
}
conn, err := t.dial(ctx, "tcp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
return nil, err
}
// Write request
buf := []byte{byte(len(query) >> 8)}
buf = append(buf, byte(len(query)))
buf = append(buf, query...)
if _, err = conn.Write(buf); err != nil {
return nil, err
}
// Read response
header := make([]byte, 2)
if _, err = io.ReadFull(conn, header); err != nil {
return nil, err
}
length := int(header[0])<<8 | int(header[1])
reply := make([]byte, length)
if _, err = io.ReadFull(conn, reply); err != nil {
return nil, err
}
return reply, nil
}
// RequiresPadding returns true for DoT and false for TCP
// according to RFC8467.
func (t DNSOverTCP) RequiresPadding() bool {
return t.requiresPadding
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverTCP) Network() string {
return t.network
}
// Address returns the upstream server address.
func (t DNSOverTCP) Address() string {
return t.address
}
var _ RoundTripper = DNSOverTCP{}
@@ -0,0 +1,146 @@
package resolver_test
import (
"context"
"errors"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
if err == nil {
t.Fatal("expected an error here")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Err: mocked}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
SetDeadlineError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
WriteError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
ReadError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(2)},
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportAllGood(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(1), byte(1)},
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil {
t.Fatal(err)
}
if len(reply) != 1 || reply[0] != 1 {
t.Fatal("not the response we expected")
}
}
func TestDNSOverTCPTransportOK(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "tcp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
}
func TestDNSOverTLSTransportOK(t *testing.T) {
const address = "9.9.9.9:853"
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "dot" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
}
@@ -0,0 +1,64 @@
package resolver
import (
"context"
"net"
"time"
)
// Dialer is the network dialer interface assumed by this package.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// DNSOverUDP is a DNS over UDP RoundTripper.
type DNSOverUDP struct {
dialer Dialer
address string
}
// NewDNSOverUDP creates a DNSOverUDP instance.
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
return DNSOverUDP{dialer: dialer, address: address}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
// Use five seconds timeout like Bionic does. See
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, err
}
if _, err = conn.Write(query); err != nil {
return nil, err
}
reply := make([]byte, 1<<17)
var n int
n, err = conn.Read(reply)
if err != nil {
return nil, err
}
return reply[:n], nil
}
// RequiresPadding returns false for UDP according to RFC8467
func (t DNSOverUDP) RequiresPadding() bool {
return false
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverUDP) Network() string {
return "udp"
}
// Address returns the upstream server address.
func (t DNSOverUDP) Address() string {
return t.address
}
var _ RoundTripper = DNSOverUDP{}
@@ -0,0 +1,107 @@
package resolver_test
import (
"context"
"errors"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverUDPDialFailure(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
SetDeadlineError: mocked,
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPWriteFailure(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
WriteError: mocked,
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPReadFailure(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
ReadError: mocked,
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPReadSuccess(t *testing.T) {
const expected = 17
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(data) != expected {
t.Fatal("expected non nil data")
}
}
func TestDNSOverUDPTransportOK(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "udp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
}
+89
View File
@@ -0,0 +1,89 @@
package resolver
import (
"context"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
)
// EmitterTransport is a RoundTripper that emits events when they occur.
type EmitterTransport struct {
RoundTripper
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp EmitterTransport) RoundTrip(ctx context.Context, querydata []byte) ([]byte, error) {
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
DNSQuery: &modelx.DNSQueryEvent{
Data: querydata,
DialID: dialid.ContextDialID(ctx),
DurationSinceBeginning: time.Now().Sub(root.Beginning),
},
})
replydata, err := txp.RoundTripper.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
root.Handler.OnMeasurement(modelx.Measurement{
DNSReply: &modelx.DNSReplyEvent{
Data: replydata,
DialID: dialid.ContextDialID(ctx),
DurationSinceBeginning: time.Now().Sub(root.Beginning),
},
})
return replydata, nil
}
// EmitterResolver is a resolver that emits events
type EmitterResolver struct {
Resolver
}
// LookupHost returns the IP addresses of a host
func (r EmitterResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
var (
network string
address string
)
type queryableResolver interface {
Transport() RoundTripper
}
if qr, ok := r.Resolver.(queryableResolver); ok {
txp := qr.Transport()
network, address = txp.Network(), txp.Address()
}
dialID := dialid.ContextDialID(ctx)
txID := transactionid.ContextTransactionID(ctx)
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
ResolveStart: &modelx.ResolveStartEvent{
DialID: dialID,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
Hostname: hostname,
TransactionID: txID,
TransportAddress: address,
TransportNetwork: network,
},
})
addrs, err := r.Resolver.LookupHost(ctx, hostname)
root.Handler.OnMeasurement(modelx.Measurement{
ResolveDone: &modelx.ResolveDoneEvent{
Addresses: addrs,
DialID: dialID,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
Error: err,
Hostname: hostname,
TransactionID: txID,
TransportAddress: address,
TransportNetwork: network,
},
})
return addrs, err
}
var _ RoundTripper = EmitterTransport{}
var _ Resolver = EmitterResolver{}
@@ -0,0 +1,220 @@
package resolver_test
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"testing"
"time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestEmitterTransportSuccess(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
}}
e := resolver.MiekgEncoder{}
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
if err != nil {
t.Fatal(err)
}
replydata, err := txp.RoundTrip(ctx, querydata)
if err != nil {
t.Fatal(err)
}
events := handler.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
if events[0].DNSQuery == nil {
t.Fatal("missing DNSQuery field")
}
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
t.Fatal("invalid query data")
}
if events[0].DNSQuery.DialID == 0 {
t.Fatal("invalid query DialID")
}
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
if events[1].DNSReply == nil {
t.Fatal("missing DNSReply field")
}
if !bytes.Equal(events[1].DNSReply.Data, replydata) {
t.Fatal("missing reply data")
}
if events[1].DNSReply.DialID != 1 {
t.Fatal("invalid query DialID")
}
if events[1].DNSReply.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
}
func TestEmitterTransportFailure(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
mocked := errors.New("mocked error")
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
Err: mocked,
}}
e := resolver.MiekgEncoder{}
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
if err != nil {
t.Fatal(err)
}
replydata, err := txp.RoundTrip(ctx, querydata)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if replydata != nil {
t.Fatal("expected nil replydata")
}
events := handler.Read()
if len(events) != 1 {
t.Fatal("unexpected number of events")
}
if events[0].DNSQuery == nil {
t.Fatal("missing DNSQuery field")
}
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
t.Fatal("invalid query data")
}
if events[0].DNSQuery.DialID == 0 {
t.Fatal("invalid query DialID")
}
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
}
func TestEmitterResolverFailure(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
ctx = transactionid.WithTransactionID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
return nil, io.EOF
},
URL: "https://dns.google.com/",
},
)}
replies, err := r.LookupHost(ctx, "www.google.com")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if replies != nil {
t.Fatal("expected nil replies")
}
events := handler.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
if events[0].ResolveStart == nil {
t.Fatal("missing ResolveStart field")
}
if events[0].ResolveStart.DialID == 0 {
t.Fatal("invalid DialID")
}
if events[0].ResolveStart.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
if events[0].ResolveStart.Hostname != "www.google.com" {
t.Fatal("invalid Hostname")
}
if events[0].ResolveStart.TransactionID == 0 {
t.Fatal("invalid TransactionID")
}
if events[0].ResolveStart.TransportAddress != "https://dns.google.com/" {
t.Fatal("invalid TransportAddress")
}
if events[0].ResolveStart.TransportNetwork != "doh" {
t.Fatal("invalid TransportNetwork")
}
if events[1].ResolveDone == nil {
t.Fatal("missing ResolveDone field")
}
if events[1].ResolveDone.DialID == 0 {
t.Fatal("invalid DialID")
}
if events[1].ResolveDone.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
if events[1].ResolveDone.Error != io.EOF {
t.Fatal("invalid Error")
}
if events[1].ResolveDone.Hostname != "www.google.com" {
t.Fatal("invalid Hostname")
}
if events[1].ResolveDone.TransactionID == 0 {
t.Fatal("invalid TransactionID")
}
if events[1].ResolveDone.TransportAddress != "https://dns.google.com/" {
t.Fatal("invalid TransportAddress")
}
if events[1].ResolveDone.TransportNetwork != "doh" {
t.Fatal("invalid TransportNetwork")
}
}
func TestEmitterResolverSuccess(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
ctx = transactionid.WithTransactionID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
r := resolver.EmitterResolver{Resolver: resolver.NewFakeResolverWithResult(
[]string{"8.8.8.8"},
)}
replies, err := r.LookupHost(ctx, "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(replies) != 1 {
t.Fatal("expected a single replies")
}
events := handler.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
if events[1].ResolveDone == nil {
t.Fatal("missing ResolveDone field")
}
if events[1].ResolveDone.Addresses[0] != "8.8.8.8" {
t.Fatal("invalid Addresses")
}
}
+52
View File
@@ -0,0 +1,52 @@
package resolver
import "github.com/miekg/dns"
// The Encoder encodes DNS queries to bytes
type Encoder interface {
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
}
// MiekgEncoder uses github.com/miekg/dns to implement the Encoder.
type MiekgEncoder struct{}
const (
// PaddingDesiredBlockSize is the size that the padded query should be multiple of
PaddingDesiredBlockSize = 128
// EDNS0MaxResponseSize is the maximum response size for EDNS0
EDNS0MaxResponseSize = 4096
// DNSSECEnabled turns on support for DNSSEC when using EDNS0
DNSSECEnabled = true
)
// Encode implements Encoder.Encode
func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
question := dns.Question{
Name: dns.Fqdn(domain),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
if padding {
query.SetEdns0(EDNS0MaxResponseSize, DNSSECEnabled)
// Clients SHOULD pad queries to the closest multiple of
// 128 octets RFC8467#section-4.1. We inflate the query
// length by the size of the option (i.e. 4 octets). The
// cast to uint is necessary to make the modulus operation
// work as intended when the desiredBlockSize is smaller
// than (query.Len()+4) ¯\_(ツ)_/¯.
remainder := (PaddingDesiredBlockSize - uint(query.Len()+4)) % PaddingDesiredBlockSize
opt := new(dns.EDNS0_PADDING)
opt.Padding = make([]byte, remainder)
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
}
return query.Pack()
}
var _ Encoder = MiekgEncoder{}
@@ -0,0 +1,99 @@
package resolver_test
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestEncoderEncodeA(t *testing.T) {
e := resolver.MiekgEncoder{}
data, err := e.Encode("x.org", dns.TypeA, false)
if err != nil {
t.Fatal(err)
}
validate(t, data, byte(dns.TypeA))
}
func TestEncoderEncodeAAAA(t *testing.T) {
e := resolver.MiekgEncoder{}
data, err := e.Encode("x.org", dns.TypeAAAA, false)
if err != nil {
t.Fatal(err)
}
validate(t, data, byte(dns.TypeA))
}
func validate(t *testing.T, data []byte, qtype byte) {
// skipping over the query ID
if data[2] != 1 {
t.Fatal("FLAGS should only have RD set")
}
if data[3] != 0 {
t.Fatal("RA|Z|Rcode should be zero")
}
if data[4] != 0 || data[5] != 1 {
t.Fatal("QCOUNT high should be one")
}
if data[6] != 0 || data[7] != 0 {
t.Fatal("ANCOUNT should be zero")
}
if data[8] != 0 || data[9] != 0 {
t.Fatal("NSCOUNT should be zero")
}
if data[10] != 0 || data[11] != 0 {
t.Fatal("ARCOUNT should be zero")
}
t.Log(data[12])
if data[12] != 1 || data[13] != byte('x') {
t.Fatal("The name does not contain 1:x")
}
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
t.Fatal("The name does not containg 3:org")
}
if data[18] != 0 {
t.Fatal("The name does not terminate where expected")
}
if data[19] != 0 && data[20] != qtype {
t.Fatal("The query is not for the expected type")
}
if data[21] != 0 && data[22] != 1 {
t.Fatal("The query is not IN")
}
}
func TestEncoderPadding(t *testing.T) {
// The purpose of this unit test is to make sure that for a wide
// array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int {
e := resolver.MiekgEncoder{}
data, err := e.Encode(
// This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing
// us to generate such name and we are not going to send
// it on the wire. Also, we check below that the query that
// we generate is long enough, so we should be good.
dns.Fqdn(strings.Repeat("x.", domainlen)),
dns.TypeA, padding,
)
if err != nil {
t.Fatal(err)
}
return len(data)
}
for domainlen := 1; domainlen <= 4000; domainlen++ {
vanillalen := getquerylen(domainlen, false)
paddedlen := getquerylen(domainlen, true)
if vanillalen < domainlen {
t.Fatal("vanillalen is smaller than domainlen")
}
if (paddedlen % resolver.PaddingDesiredBlockSize) != 0 {
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
}
if paddedlen < vanillalen {
t.Fatal("paddedlen is smaller than vanillalen")
}
}
}
@@ -0,0 +1,30 @@
package resolver
import (
"context"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// ErrorWrapperResolver is a Resolver that knows about wrapping errors.
type ErrorWrapperResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r ErrorWrapperResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
dialID := dialid.ContextDialID(ctx)
txID := transactionid.ContextTransactionID(ctx)
addrs, err := r.Resolver.LookupHost(ctx, hostname)
err = errorx.SafeErrWrapperBuilder{
DialID: dialID,
Error: err,
Operation: errorx.ResolveOperation,
TransactionID: txID,
}.MaybeBuild()
return addrs, err
}
var _ Resolver = ErrorWrapperResolver{}
@@ -0,0 +1,58 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestErrorWrapperSuccess(t *testing.T) {
orig := []string{"8.8.8.8"}
r := resolver.ErrorWrapperResolver{
Resolver: resolver.NewFakeResolverWithResult(orig),
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != len(orig) || addrs[0] != orig[0] {
t.Fatal("not the result we expected")
}
}
func TestErrorWrapperFailure(t *testing.T) {
r := resolver.ErrorWrapperResolver{
Resolver: resolver.NewFakeResolverThatFails(),
}
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
ctx = transactionid.WithTransactionID(ctx)
addrs, err := r.LookupHost(ctx, "dns.google.com")
if addrs != nil {
t.Fatal("expected nil addr here")
}
var errWrapper *errorx.ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot properly cast the returned error")
}
if errWrapper.Failure != errorx.FailureDNSNXDOMAINError {
t.Fatal("unexpected failure")
}
if errWrapper.ConnID != 0 {
t.Fatal("unexpected ConnID")
}
if errWrapper.DialID == 0 {
t.Fatal("unexpected DialID")
}
if errWrapper.TransactionID == 0 {
t.Fatal("unexpected TransactionID")
}
if errWrapper.Operation != errorx.ResolveOperation {
t.Fatal("unexpected Operation")
}
}
+142
View File
@@ -0,0 +1,142 @@
package resolver
import (
"context"
"io"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
)
type FakeDialer struct {
Conn net.Conn
Err error
}
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return d.Conn, d.Err
}
type FakeConn struct {
ReadError error
ReadData []byte
SetDeadlineError error
SetReadDeadlineError error
SetWriteDeadlineError error
WriteError error
}
func (c *FakeConn) Read(b []byte) (int, error) {
if len(c.ReadData) > 0 {
n := copy(b, c.ReadData)
c.ReadData = c.ReadData[n:]
return n, nil
}
if c.ReadError != nil {
return 0, c.ReadError
}
return 0, io.EOF
}
func (c *FakeConn) Write(b []byte) (n int, err error) {
if c.WriteError != nil {
return 0, c.WriteError
}
n = len(b)
return
}
func (*FakeConn) Close() (err error) {
return
}
func (*FakeConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (*FakeConn) RemoteAddr() net.Addr {
return &net.TCPAddr{}
}
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
return c.SetDeadlineError
}
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
return c.SetReadDeadlineError
}
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
return c.SetWriteDeadlineError
}
type FakeTransport struct {
Data []byte
Err error
}
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
return ft.Data, ft.Err
}
func (ft FakeTransport) RequiresPadding() bool {
return false
}
func (ft FakeTransport) Address() string {
return ""
}
func (ft FakeTransport) Network() string {
return "fake"
}
type FakeEncoder struct {
Data []byte
Err error
}
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
return fe.Data, fe.Err
}
type FakeResolver struct {
NumFailures *atomicx.Int64
Err error
Result []string
}
func NewFakeResolverThatFails() FakeResolver {
return FakeResolver{NumFailures: atomicx.NewInt64(), Err: errNotFound}
}
func NewFakeResolverWithResult(r []string) FakeResolver {
return FakeResolver{NumFailures: atomicx.NewInt64(), Result: r}
}
var errNotFound = &net.DNSError{
Err: "no such host",
}
func (c FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
time.Sleep(10 * time.Microsecond)
if c.Err != nil {
if c.NumFailures != nil {
c.NumFailures.Add(1)
}
return nil, c.Err
}
return c.Result, nil
}
func (c FakeResolver) Network() string {
return "fake"
}
func (c FakeResolver) Address() string {
return ""
}
var _ Resolver = FakeResolver{}
@@ -0,0 +1,76 @@
package resolver
import (
"net"
"testing"
"github.com/miekg/dns"
)
func GenReplyError(t *testing.T, code int) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetRcode(query, code)
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
func GenReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
for _, ip := range ips {
switch qtype {
case dns.TypeA:
reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
A: net.ParseIP(ip),
})
case dns.TypeAAAA:
reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: net.ParseIP(ip),
})
}
}
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
+34
View File
@@ -0,0 +1,34 @@
package resolver
import (
"context"
"golang.org/x/net/idna"
)
// IDNAResolver is to support resolving Internationalized Domain Names.
// See RFC3492 for more information.
type IDNAResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r IDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
host, err := idna.ToASCII(hostname)
if err != nil {
return nil, err
}
return r.Resolver.LookupHost(ctx, host)
}
// Network implements Resolver.Network.
func (r IDNAResolver) Network() string {
return "idna"
}
// Address implements Resolver.Address.
func (r IDNAResolver) Address() string {
return ""
}
var _ Resolver = IDNAResolver{}
@@ -0,0 +1,76 @@
package resolver_test
import (
"context"
"errors"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
var ErrUnexpectedPunycode = errors.New("unexpected punycode value")
type CheckIDNAResolver struct {
Addresses []string
Error error
Expect string
}
func (resolv CheckIDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if resolv.Error != nil {
return nil, resolv.Error
}
if hostname != resolv.Expect {
return nil, ErrUnexpectedPunycode
}
return resolv.Addresses, nil
}
func (r CheckIDNAResolver) Network() string {
return "checkidna"
}
func (r CheckIDNAResolver) Address() string {
return ""
}
func TestIDNAResolverSuccess(t *testing.T) {
expectedIPs := []string{"77.88.55.66"}
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
Addresses: expectedIPs,
Expect: "xn--d1acpjx3f.xn--p1ai",
}}
addrs, err := resolv.LookupHost(context.Background(), "яндекс.рф")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
t.Fatal(diff)
}
}
func TestIDNAResolverFailure(t *testing.T) {
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
Error: errors.New("we should not arrive here"),
}}
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
addrs, err := resolv.LookupHost(context.Background(), "xn--0000h")
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected no response here")
}
}
func TestIDNAResolverTransportOK(t *testing.T) {
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{}}
if resolv.Network() != "idna" {
t.Fatal("invalid network")
}
if resolv.Address() != "" {
t.Fatal("invalid address")
}
}
@@ -0,0 +1,111 @@
package resolver_test
import (
"context"
"net"
"net/http"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func init() {
log.SetLevel(log.DebugLevel)
}
func testresolverquick(t *testing.T, reso resolver.Resolver) {
if testing.Short() {
t.Skip("skip test in short mode")
}
reso = resolver.LoggingResolver{Logger: log.Log, Resolver: reso}
addrs, err := reso.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil addrs here")
}
var foundquad8 bool
for _, addr := range addrs {
// See https://github.com/ooni/probe-cli/v3/internal/engine/pull/954/checks?check_run_id=1182269025
if addr == "8.8.8.8" || addr == "2001:4860:4860::8888" {
foundquad8 = true
}
}
if !foundquad8 {
t.Fatalf("did not find 8.8.8.8 in ouput; output=%+v", addrs)
}
}
// Ensuring we can handle Internationalized Domain Names (IDNs) without issues
func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
if testing.Short() {
t.Skip("skip test in short mode")
}
reso = resolver.IDNAResolver{
resolver.LoggingResolver{Logger: log.Log, Resolver: reso},
}
addrs, err := reso.LookupHost(context.Background(), "яндекс.рф")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil addrs here")
}
}
func TestNewResolverSystem(t *testing.T) {
reso := resolver.SystemResolver{}
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverUDPAddress(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverUDP(new(net.Dialer), "8.8.8.8:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverUDPDomain(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverUDP(new(net.Dialer), "dns.google.com:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverTCPAddress(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "8.8.8.8:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverTCPDomain(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "dns.google.com:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTAddress(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTDomain(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTLS(resolver.DialTLSContext, "dns.google.com:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoH(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverHTTPS(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
+30
View File
@@ -0,0 +1,30 @@
package resolver
import (
"context"
"time"
)
// Logger is the logger assumed by this package
type Logger interface {
Debugf(format string, v ...interface{})
Debug(message string)
}
// LoggingResolver is a resolver that emits events
type LoggingResolver struct {
Resolver
Logger Logger
}
// LookupHost returns the IP addresses of a host
func (r LoggingResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
r.Logger.Debugf("resolve %s...", hostname)
start := time.Now()
addrs, err := r.Resolver.LookupHost(ctx, hostname)
stop := time.Now()
r.Logger.Debugf("resolve %s... (%+v, %+v) in %s", hostname, addrs, err, stop.Sub(start))
return addrs, err
}
var _ Resolver = LoggingResolver{}
@@ -0,0 +1,23 @@
package resolver_test
import (
"context"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestLoggingResolver(t *testing.T) {
r := resolver.LoggingResolver{
Logger: log.Log,
Resolver: resolver.NewFakeResolverThatFails(),
}
addrs, err := r.LookupHost(context.Background(), "www.google.com")
if err == nil {
t.Fatal("expected an error here")
}
if addrs != nil {
t.Fatal("expected nil addr here")
}
}
+18
View File
@@ -0,0 +1,18 @@
package resolver
import (
"context"
)
// Resolver is a DNS resolver. The *net.Resolver used by Go implements
// this interface, but other implementations are possible.
type Resolver interface {
// LookupHost resolves a hostname to a list of IP addresses.
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
// Network returns the network being used by the resolver
Network() string
// Address returns the address being used by the resolver
Address() string
}
+73
View File
@@ -0,0 +1,73 @@
package resolver
import (
"context"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// SaverResolver is a resolver that saves events
type SaverResolver struct {
Resolver
Saver *trace.Saver
}
// LookupHost implements Resolver.LookupHost
func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
start := time.Now()
r.Saver.Write(trace.Event{
Address: r.Resolver.Address(),
Hostname: hostname,
Name: "resolve_start",
Proto: r.Resolver.Network(),
Time: start,
})
addrs, err := r.Resolver.LookupHost(ctx, hostname)
stop := time.Now()
r.Saver.Write(trace.Event{
Addresses: addrs,
Address: r.Resolver.Address(),
Duration: stop.Sub(start),
Err: err,
Hostname: hostname,
Name: "resolve_done",
Proto: r.Resolver.Network(),
Time: stop,
})
return addrs, err
}
// SaverDNSTransport is a DNS transport that saves events
type SaverDNSTransport struct {
RoundTripper
Saver *trace.Saver
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
start := time.Now()
txp.Saver.Write(trace.Event{
Address: txp.Address(),
DNSQuery: query,
Name: "dns_round_trip_start",
Proto: txp.Network(),
Time: start,
})
reply, err := txp.RoundTripper.RoundTrip(ctx, query)
stop := time.Now()
txp.Saver.Write(trace.Event{
Address: txp.Address(),
DNSQuery: query,
DNSReply: reply,
Duration: stop.Sub(start),
Err: err,
Name: "dns_round_trip_done",
Proto: txp.Network(),
Time: stop,
})
return reply, err
}
var _ Resolver = SaverResolver{}
var _ RoundTripper = SaverDNSTransport{}
+211
View File
@@ -0,0 +1,211 @@
package resolver_test
import (
"bytes"
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSaverResolverFailure(t *testing.T) {
expected := errors.New("no such host")
saver := &trace.Saver{}
reso := resolver.SaverResolver{
Resolver: resolver.FakeResolver{
Err: expected,
},
Saver: saver,
}
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if ev[1].Addresses != nil {
t.Fatal("unexpected Addresses")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverResolverSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &trace.Saver{}
reso := resolver.SaverResolver{
Resolver: resolver.FakeResolver{
Result: expected,
},
Saver: saver,
}
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal("expected nil error here")
}
if !reflect.DeepEqual(addrs, expected) {
t.Fatal("not the result we expected")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !reflect.DeepEqual(ev[1].Addresses, expected) {
t.Fatal("unexpected Addresses")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverDNSTransportFailure(t *testing.T) {
expected := errors.New("no such host")
saver := &trace.Saver{}
txp := resolver.SaverDNSTransport{
RoundTripper: resolver.FakeTransport{
Err: expected,
},
Saver: saver,
}
query := []byte("abc")
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].DNSReply != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Name != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverDNSTransportSuccess(t *testing.T) {
expected := []byte("def")
saver := &trace.Saver{}
txp := resolver.SaverDNSTransport{
RoundTripper: resolver.FakeTransport{
Data: expected,
},
Saver: saver,
}
query := []byte("abc")
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply, expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].DNSReply, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
+113
View File
@@ -0,0 +1,113 @@
package resolver
import (
"context"
"errors"
"net"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
)
// RoundTripper represents an abstract DNS transport.
type RoundTripper interface {
// RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
// RequiresPadding return true for DoH and DoT according to RFC8467
RequiresPadding() bool
// Network is the network of the round tripper (e.g. "dot")
Network() string
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
Address() string
}
// SerialResolver is a resolver that first issues an A query and then
// issues an AAAA query for the requested domain.
type SerialResolver struct {
Encoder Encoder
Decoder Decoder
NumTimeouts *atomicx.Int64
Txp RoundTripper
}
// NewSerialResolver creates a new OONI Resolver instance.
func NewSerialResolver(t RoundTripper) SerialResolver {
return SerialResolver{
Encoder: MiekgEncoder{},
Decoder: MiekgDecoder{},
NumTimeouts: atomicx.NewInt64(),
Txp: t,
}
}
// Transport returns the transport being used.
func (r SerialResolver) Transport() RoundTripper {
return r.Txp
}
// Network implements Resolver.Network
func (r SerialResolver) Network() string {
return r.Txp.Network()
}
// Address implements Resolver.Address
func (r SerialResolver) Address() string {
return r.Txp.Address()
}
// LookupHost implements Resolver.LookupHost.
func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
var addrs []string
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
if errA != nil && errAAAA != nil {
return nil, errA
}
addrs = append(addrs, addrsA...)
addrs = append(addrs, addrsAAAA...)
return addrs, nil
}
func (r SerialResolver) roundTripWithRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
var errorslist []error
for i := 0; i < 3; i++ {
replies, err := r.roundTrip(ctx, hostname, qtype)
if err == nil {
return replies, nil
}
errorslist = append(errorslist, err)
var operr *net.OpError
if errors.As(err, &operr) == false || operr.Timeout() == false {
// The first error is the one that is most likely to be caused
// by the network. Subsequent errors are more likely to be caused
// by context deadlines. So, the first error is attached to an
// operation, while subsequent errors may possibly not be. If
// so, the resulting failing operation is not correct.
break
}
r.NumTimeouts.Add(1)
}
// bugfix: we MUST return one of the errors otherwise we confuse the
// mechanism in errwrap that classifies the root cause operation, since
// it would not be able to find a child with a major operation error
return nil, errorslist[0]
}
func (r SerialResolver) roundTrip(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.Decode(qtype, replydata)
}
var _ Resolver = SerialResolver{}
@@ -0,0 +1,111 @@
package resolver_test
import (
"context"
"errors"
"net"
"strings"
"syscall"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestOONIGettingTransport(t *testing.T) {
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853")
r := resolver.NewSerialResolver(txp)
rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
t.Fatal("not the transport we expected")
}
if r.Network() != rtx.Network() {
t.Fatal("invalid network seen from the resolver")
}
if r.Address() != rtx.Address() {
t.Fatal("invalid address seen from the resolver")
}
}
func TestOONIEncodeError(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853")
r := resolver.SerialResolver{Encoder: resolver.FakeEncoder{Err: mocked}, Txp: txp}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
}
func TestOONIRoundTripError(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.FakeTransport{Err: mocked}
r := resolver.NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
}
func TestOONIWithEmptyReply(t *testing.T) {
txp := resolver.FakeTransport{Data: resolver.GenReplySuccess(t, dns.TypeA)}
r := resolver.NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
}
func TestOONIWithAReply(t *testing.T) {
txp := resolver.FakeTransport{
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
}
r := resolver.NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
}
func TestOONIWithAAAAReply(t *testing.T) {
txp := resolver.FakeTransport{
Data: resolver.GenReplySuccess(t, dns.TypeAAAA, "::1"),
}
r := resolver.NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "::1" {
t.Fatal("not the result we expected")
}
}
func TestOONIWithTimeout(t *testing.T) {
txp := resolver.FakeTransport{
Err: &net.OpError{Err: syscall.ETIMEDOUT, Op: "dial"},
}
r := resolver.NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, syscall.ETIMEDOUT) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
if r.NumTimeouts.Load() <= 0 {
t.Fatal("we didn't actually take the timeouts")
}
}
+10
View File
@@ -0,0 +1,10 @@
package resolver
import "github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor"
// SystemResolver is the system resolver. It is implemented using
// selfcensor.SystemResolver so that we can perform integration testing
// by forcing the code to return specific responses.
type SystemResolver = selfcensor.SystemResolver
var _ Resolver = SystemResolver{}
@@ -0,0 +1,25 @@
package resolver_test
import (
"context"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestSystemResolverLookupHost(t *testing.T) {
r := resolver.SystemResolver{}
if r.Network() != "system" {
t.Fatal("invalid Network")
}
if r.Address() != "" {
t.Fatal("invalid Address")
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil result here")
}
}
+32
View File
@@ -0,0 +1,32 @@
package resolver
import (
"context"
"crypto/tls"
"net"
)
func DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
connch := make(chan net.Conn)
errch := make(chan error, 1)
go func() {
conn, err := tls.Dial(network, address, new(tls.Config))
if err != nil {
errch <- err
return
}
select {
case <-ctx.Done():
conn.Close()
case connch <- conn:
}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case conn := <-connch:
return conn, nil
case err := <-errch:
return nil, err
}
}