ooni-probe-cli/internal/netxlite/dnsx/dnsoverhttps_test.go

197 lines
5.0 KiB
Go

package dnsx
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)
// TestDNSOverHTTPS exercises the DNSOverHTTPS transport. Subtests that expect
// the HTTP client to be invoked use a mocks.HTTPClient whose behavior is
// scripted by the subtest itself.
func TestDNSOverHTTPS(t *testing.T) {
	// dohURL is the resolver URL shared by every subtest that needs a valid one.
	const dohURL = "https://cloudflare-dns.com/dns-query"

	t.Run("RoundTrip", func(t *testing.T) {
		// A URL containing a control character must fail with the
		// "invalid control character in URL" error and yield no data.
		t.Run("NewRequestFailure", func(t *testing.T) {
			const invalidURL = "\t"
			const wantSuffix = "invalid control character in URL"
			txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL)
			data, err := txp.RoundTrip(context.Background(), nil)
			if err == nil || !strings.HasSuffix(err.Error(), wantSuffix) {
				t.Fatalf("unexpected error: got %v, want suffix %q", err, wantSuffix)
			}
			if data != nil {
				t.Fatalf("expected no response here, got %v", data)
			}
		})

		// An error returned by the HTTP client must be visible to the caller.
		t.Run("client.Do failure", func(t *testing.T) {
			expected := errors.New("mocked error")
			txp := &DNSOverHTTPS{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return nil, expected
					},
				},
				URL: dohURL,
			}
			data, err := txp.RoundTrip(context.Background(), nil)
			if !errors.Is(err, expected) {
				t.Fatalf("unexpected error: got %v, want %v", err, expected)
			}
			if data != nil {
				t.Fatalf("expected no response here, got %v", data)
			}
		})

		// A 500 status code must be reported as "doh: server returned error".
		t.Run("server returns 500", func(t *testing.T) {
			const wantErr = "doh: server returned error"
			txp := &DNSOverHTTPS{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 500,
							Body:       io.NopCloser(strings.NewReader("")),
						}, nil
					},
				},
				URL: dohURL,
			}
			data, err := txp.RoundTrip(context.Background(), nil)
			if err == nil || err.Error() != wantErr {
				t.Fatalf("unexpected error: got %v, want %q", err, wantErr)
			}
			if data != nil {
				t.Fatalf("expected no response here, got %v", data)
			}
		})

		// A 200 response without a Content-Type header must be rejected
		// with "doh: invalid content-type".
		t.Run("missing content type", func(t *testing.T) {
			const wantErr = "doh: invalid content-type"
			txp := &DNSOverHTTPS{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 200,
							Body:       io.NopCloser(strings.NewReader("")),
						}, nil
					},
				},
				URL: dohURL,
			}
			data, err := txp.RoundTrip(context.Background(), nil)
			if err == nil || err.Error() != wantErr {
				t.Fatalf("unexpected error: got %v, want %q", err, wantErr)
			}
			if data != nil {
				t.Fatalf("expected no response here, got %v", data)
			}
		})

		// A 200 response with the application/dns-message content type
		// must return the response body unchanged.
		t.Run("success", func(t *testing.T) {
			body := []byte("AAA")
			txp := &DNSOverHTTPS{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 200,
							Body:       io.NopCloser(bytes.NewReader(body)),
							Header: http.Header{
								"Content-Type": []string{"application/dns-message"},
							},
						}, nil
					},
				},
				URL: dohURL,
			}
			data, err := txp.RoundTrip(context.Background(), nil)
			if err != nil {
				t.Fatal(err)
			}
			if !bytes.Equal(data, body) {
				t.Fatalf("not the response we expected: got %q, want %q", data, body)
			}
		})

		// The outgoing request must carry httpheader.UserAgent() as its
		// User-Agent header.
		t.Run("sets the correct user-agent", func(t *testing.T) {
			expected := errors.New("mocked error")
			// gotUA records the header seen by the mock so that a failure
			// can report the actual value.
			var gotUA string
			txp := &DNSOverHTTPS{
				Client: &mocks.HTTPClient{
					MockDo: func(req *http.Request) (*http.Response, error) {
						gotUA = req.Header.Get("User-Agent")
						return nil, expected
					},
				},
				URL: dohURL,
			}
			data, err := txp.RoundTrip(context.Background(), nil)
			// Only the mock returns expected, so this check also proves
			// that the mock was actually invoked.
			if !errors.Is(err, expected) {
				t.Fatalf("unexpected error: got %v, want %v", err, expected)
			}
			if data != nil {
				t.Fatalf("expected no response here, got %v", data)
			}
			if wantUA := httpheader.UserAgent(); gotUA != wantUA {
				t.Fatalf("did not see correct user agent: got %q, want %q", gotUA, wantUA)
			}
		})

		// When HostOverride is set, the outgoing request's Host field must
		// equal the override.
		t.Run("we can override the Host header", func(t *testing.T) {
			expected := errors.New("mocked error")
			hostOverride := "test.com"
			// gotHost records the Host seen by the mock so that a failure
			// can report the actual value.
			var gotHost string
			txp := &DNSOverHTTPS{
				Client: &mocks.HTTPClient{
					MockDo: func(req *http.Request) (*http.Response, error) {
						gotHost = req.Host
						return nil, expected
					},
				},
				URL:          dohURL,
				HostOverride: hostOverride,
			}
			data, err := txp.RoundTrip(context.Background(), nil)
			// Only the mock returns expected, so this check also proves
			// that the mock was actually invoked.
			if !errors.Is(err, expected) {
				t.Fatalf("unexpected error: got %v, want %v", err, expected)
			}
			if data != nil {
				t.Fatalf("expected no response here, got %v", data)
			}
			if gotHost != hostOverride {
				t.Fatalf("did not see correct host override: got %q, want %q", gotHost, hostOverride)
			}
		})
	})

	// Network, RequiresPadding, and Address must report the DoH network
	// name, true, and the configured URL, respectively.
	t.Run("other functions behave correctly", func(t *testing.T) {
		const queryURL = dohURL
		txp := NewDNSOverHTTPS(http.DefaultClient, queryURL)
		if got := txp.Network(); got != "doh" {
			t.Fatalf("invalid network: got %q, want %q", got, "doh")
		}
		if !txp.RequiresPadding() {
			t.Fatal("should require padding")
		}
		if got := txp.Address(); got != queryURL {
			t.Fatalf("invalid address: got %q, want %q", got, queryURL)
		}
	})

	// CloseIdleConnections must be forwarded to the underlying client.
	t.Run("CloseIdleConnections", func(t *testing.T) {
		var called bool
		doh := &DNSOverHTTPS{
			Client: &mocks.HTTPClient{
				MockCloseIdleConnections: func() {
					called = true
				},
			},
		}
		doh.CloseIdleConnections()
		if !called {
			t.Fatal("not called")
		}
	})
}