ooni-probe-cli/internal/engine/netx/resolver/dnsoverhttps_test.go

166 lines
4.2 KiB
Go

package resolver_test
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
// TestDNSOverHTTPSNewRequestFailure checks that RoundTrip fails, and returns
// no data, when the configured URL cannot be turned into an HTTP request
// (here the URL consists only of a tab control character).
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
	const invalidURL = "\t"
	txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
	data, err := txp.RoundTrip(context.Background(), nil)
	if err == nil {
		t.Fatal("expected an error here")
	}
	// Distinguish "no error" from "wrong error" so a failure message is not
	// misleading when RoundTrip fails for an unexpected reason.
	if !strings.HasSuffix(err.Error(), "invalid control character in URL") {
		t.Fatalf("unexpected error: %+v", err)
	}
	if data != nil {
		t.Fatalf("expected no response here, got %v", data)
	}
}
// TestDNSOverHTTPSClientDoFailure verifies that an error returned by the
// injected Do function comes back out of RoundTrip in a form errors.Is still
// recognizes, and that no data accompanies it.
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
	mockedErr := errors.New("mocked error")
	failingDo := func(*http.Request) (*http.Response, error) {
		return nil, mockedErr
	}
	txp := resolver.DNSOverHTTPS{
		URL: "https://cloudflare-dns.com/dns-query",
		Do:  failingDo,
	}
	reply, err := txp.RoundTrip(context.Background(), nil)
	// Cases are evaluated in order, mirroring "check the error, then the data".
	switch {
	case !errors.Is(err, mockedErr):
		t.Fatal("expected an error here")
	case reply != nil:
		t.Fatal("expected no response here")
	}
}
// TestDNSOverHTTPSHTTPFailure checks that an HTTP 500 response from the
// server makes RoundTrip fail with "doh: server returned error" and return
// no data.
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
	txp := resolver.DNSOverHTTPS{
		Do: func(*http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: 500,
				Body:       io.NopCloser(strings.NewReader("")),
			}, nil
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	if err == nil {
		t.Fatal("expected an error here")
	}
	// Report the actual error when it is not the expected one, rather than
	// claiming that no error occurred.
	if err.Error() != "doh: server returned error" {
		t.Fatalf("unexpected error: %+v", err)
	}
	if data != nil {
		t.Fatalf("expected no response here, got %v", data)
	}
}
// TestDNSOverHTTPSMissingContentType checks that a 200 response carrying no
// Content-Type header makes RoundTrip fail with "doh: invalid content-type"
// and return no data.
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
	txp := resolver.DNSOverHTTPS{
		Do: func(*http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader("")),
			}, nil
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	if err == nil {
		t.Fatal("expected an error here")
	}
	// Report the actual error when it is not the expected one, rather than
	// claiming that no error occurred.
	if err.Error() != "doh: invalid content-type" {
		t.Fatalf("unexpected error: %+v", err)
	}
	if data != nil {
		t.Fatalf("expected no response here, got %v", data)
	}
}
// TestDNSOverHTTPSSuccess checks that, given a 200 response whose
// Content-Type is application/dns-message, RoundTrip succeeds and returns
// exactly the bytes of the response body.
func TestDNSOverHTTPSSuccess(t *testing.T) {
	want := []byte("AAA")
	respond := func(*http.Request) (*http.Response, error) {
		headers := http.Header{}
		headers["Content-Type"] = []string{"application/dns-message"}
		return &http.Response{
			StatusCode: 200,
			Header:     headers,
			Body:       io.NopCloser(bytes.NewReader(want)),
		}, nil
	}
	txp := resolver.DNSOverHTTPS{
		URL: "https://cloudflare-dns.com/dns-query",
		Do:  respond,
	}
	got, err := txp.RoundTrip(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(got, want) {
		t.Fatal("not the response we expected")
	}
}
// TestDNSOverHTTPTransportOK checks the static properties reported by a
// transport built with NewDNSOverHTTPS: the network is "doh", padding is
// required, and the address is the query URL it was constructed with.
func TestDNSOverHTTPTransportOK(t *testing.T) {
	const queryURL = "https://cloudflare-dns.com/dns-query"
	txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
	if got := txp.Network(); got != "doh" {
		t.Fatalf("invalid network: %s", got)
	}
	if !txp.RequiresPadding() {
		t.Fatal("should require padding")
	}
	if got := txp.Address(); got != queryURL {
		t.Fatalf("invalid address: %s", got)
	}
}
// TestDNSOverHTTPSClientSetsUserAgent verifies that the request handed to Do
// carries httpheader.UserAgent() as its User-Agent header. Do fails on
// purpose, so RoundTrip must also surface that error and return no data.
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
	mockedErr := errors.New("mocked error")
	sawUserAgent := false
	txp := resolver.DNSOverHTTPS{
		URL: "https://cloudflare-dns.com/dns-query",
		Do: func(req *http.Request) (*http.Response, error) {
			sawUserAgent = req.Header.Get("User-Agent") == httpheader.UserAgent()
			return nil, mockedErr
		},
	}
	reply, err := txp.RoundTrip(context.Background(), nil)
	// Ordered checks: error first, then data, then the observed header.
	switch {
	case !errors.Is(err, mockedErr):
		t.Fatal("expected an error here")
	case reply != nil:
		t.Fatal("expected no response here")
	case !sawUserAgent:
		t.Fatal("did not see correct user agent")
	}
}
// TestDNSOverHTTPSHostOverride verifies that, when HostOverride is set, the
// request handed to Do has its Host field equal to the override. Do fails on
// purpose, so RoundTrip must also surface that error and return no data.
func TestDNSOverHTTPSHostOverride(t *testing.T) {
	const hostOverride = "test.com"
	mockedErr := errors.New("mocked error")
	sawOverride := false
	txp := resolver.DNSOverHTTPS{
		URL:          "https://cloudflare-dns.com/dns-query",
		HostOverride: hostOverride,
		Do: func(req *http.Request) (*http.Response, error) {
			sawOverride = req.Host == hostOverride
			return nil, mockedErr
		},
	}
	reply, err := txp.RoundTrip(context.Background(), nil)
	// Ordered checks: error first, then data, then the observed Host.
	switch {
	case !errors.Is(err, mockedErr):
		t.Fatal("expected an error here")
	case reply != nil:
		t.Fatal("expected no response here")
	case !sawOverride:
		t.Fatal("did not see correct host override")
	}
}