ooni-probe-cli/internal/netxlite/dnsoverhttps_test.go
2022-06-09 00:30:18 +02:00

353 lines
9.3 KiB
Go

package netxlite
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
// TestNewDNSOverHTTPSTransport checks that NewDNSOverHTTPSTransport returns a
// *dnsTransportErrWrapper around a *DNSOverHTTPSTransport that keeps the
// client and URL it was constructed with.
func TestNewDNSOverHTTPSTransport(t *testing.T) {
	const URL = "https://1.1.1.1/dns-query"
	clnt := NewHTTPClientStdlib(model.DiscardLogger)
	txp := NewDNSOverHTTPSTransport(clnt, URL)
	// Use checked type assertions: an unchecked one panics, and a panic
	// aborts the whole test binary instead of failing just this test.
	ew, ok := txp.(*dnsTransportErrWrapper)
	if !ok {
		t.Fatalf("expected *dnsTransportErrWrapper, got %T", txp)
	}
	https, ok := ew.DNSTransport.(*DNSOverHTTPSTransport)
	if !ok {
		t.Fatalf("expected *DNSOverHTTPSTransport, got %T", ew.DNSTransport)
	}
	if https.Client != clnt {
		t.Fatal("invalid client")
	}
	if https.URL != URL {
		t.Fatal("invalid URL")
	}
}
// TestNewDNSOverHTTPSTransportWithHTTPTransport checks that
// NewDNSOverHTTPSTransportWithHTTPTransport builds the following chain:
// *dnsTransportErrWrapper -> *DNSOverHTTPSTransport -> *httpClientErrWrapper
// -> *http.Client. The innermost client must use the given HTTP transport,
// and the DoH transport must keep the given URL.
func TestNewDNSOverHTTPSTransportWithHTTPTransport(t *testing.T) {
	const URL = "https://1.1.1.1/dns-query"
	httpTxp := NewHTTPTransportStdlib(model.DiscardLogger)
	txp := NewDNSOverHTTPSTransportWithHTTPTransport(httpTxp, URL)
	// Use checked type assertions so that a wrapping change fails this test
	// with a descriptive message rather than panicking the test binary.
	ew, ok := txp.(*dnsTransportErrWrapper)
	if !ok {
		t.Fatalf("expected *dnsTransportErrWrapper, got %T", txp)
	}
	https, ok := ew.DNSTransport.(*DNSOverHTTPSTransport)
	if !ok {
		t.Fatalf("expected *DNSOverHTTPSTransport, got %T", ew.DNSTransport)
	}
	ewClient, ok := https.Client.(*httpClientErrWrapper)
	if !ok {
		t.Fatalf("expected *httpClientErrWrapper, got %T", https.Client)
	}
	clnt, ok := ewClient.HTTPClient.(*http.Client)
	if !ok {
		t.Fatalf("expected *http.Client, got %T", ewClient.HTTPClient)
	}
	if clnt.Transport != httpTxp {
		t.Fatal("invalid transport")
	}
	if https.URL != URL {
		t.Fatal("invalid URL")
	}
}
// TestDNSOverHTTPSTransport exercises DNSOverHTTPSTransport. It covers the
// RoundTrip failure paths below (query serialization, request construction,
// client.Do, non-200 status, missing content-type, body read, response
// decoding), the success path, the User-Agent and Host headers set on the
// outgoing request, and the Network/RequiresPadding/Address/
// CloseIdleConnections methods.
func TestDNSOverHTTPSTransport(t *testing.T) {
	t.Run("RoundTrip", func(t *testing.T) {
		t.Run("query serialization failure", func(t *testing.T) {
			txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
			expected := errors.New("mocked error")
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return nil, expected
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("NewRequestFailure", func(t *testing.T) {
			// A control character in the URL makes building the request fail.
			const invalidURL = "\t"
			txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("client.Do failure", func(t *testing.T) {
			expected := errors.New("mocked error")
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return nil, expected
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("server returns 500", func(t *testing.T) {
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 500,
							Body:       io.NopCloser(strings.NewReader("")),
						}, nil
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if err == nil || err.Error() != "doh: server returned error" {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("missing content type", func(t *testing.T) {
			// A 200 response that carries no Content-Type header at all.
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 200,
							Body:       io.NopCloser(strings.NewReader("")),
						}, nil
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if err == nil || err.Error() != "doh: invalid content-type" {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("ReadAllContext fails", func(t *testing.T) {
			// The headers are valid, but every read from the body fails.
			expected := errors.New("mocked error")
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 200,
							Body: io.NopCloser(&mocks.Reader{
								MockRead: func(b []byte) (int, error) {
									return 0, expected
								},
							}),
							Header: http.Header{
								"Content-Type": []string{"application/dns-message"},
							},
						}, nil
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("decode response failure", func(t *testing.T) {
			expected := errors.New("mocked error")
			body := []byte("AAA")
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 200,
							Body:       io.NopCloser(bytes.NewReader(body)),
							Header: http.Header{
								"Content-Type": []string{"application/dns-message"},
							},
						}, nil
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
				Decoder: &mocks.DNSDecoder{
					MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
						return nil, expected
					},
				},
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("success", func(t *testing.T) {
			body := []byte("AAA")
			expectedResp := &mocks.DNSResponse{}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(*http.Request) (*http.Response, error) {
						return &http.Response{
							StatusCode: 200,
							Body:       io.NopCloser(bytes.NewReader(body)),
							Header: http.Header{
								"Content-Type": []string{"application/dns-message"},
							},
						}, nil
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
				Decoder: &mocks.DNSDecoder{
					// Make sure the decoder sees exactly the HTTP body and
					// the original query, not some other buffer or query.
					MockDecodeResponse: func(data []byte, q model.DNSQuery) (model.DNSResponse, error) {
						if !bytes.Equal(data, body) {
							return nil, errors.New("decoder received unexpected data")
						}
						if q != query {
							return nil, errors.New("decoder received unexpected query")
						}
						return expectedResp, nil
					},
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if err != nil {
				t.Fatal(err)
			}
			if resp != expectedResp {
				t.Fatal("expected the decoder's response to be returned unchanged")
			}
		})

		t.Run("sets the correct user-agent", func(t *testing.T) {
			expected := errors.New("mocked error")
			var correct bool
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(req *http.Request) (*http.Response, error) {
						correct = req.Header.Get("User-Agent") == model.HTTPHeaderUserAgent
						return nil, expected
					},
				},
				URL: "https://cloudflare-dns.com/dns-query",
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
			if !correct {
				t.Fatal("did not see correct user agent")
			}
		})

		t.Run("we can override the Host header", func(t *testing.T) {
			var correct bool
			expected := errors.New("mocked error")
			hostOverride := "test.com"
			txp := &DNSOverHTTPSTransport{
				Client: &mocks.HTTPClient{
					MockDo: func(req *http.Request) (*http.Response, error) {
						correct = req.Host == hostOverride
						return nil, expected
					},
				},
				URL:          "https://cloudflare-dns.com/dns-query",
				HostOverride: hostOverride,
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 17), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
			if !correct {
				t.Fatal("did not see correct host override")
			}
		})
	})

	t.Run("other functions behave correctly", func(t *testing.T) {
		const queryURL = "https://cloudflare-dns.com/dns-query"
		txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, queryURL)
		if txp.Network() != "doh" {
			t.Fatal("invalid network")
		}
		if !txp.RequiresPadding() {
			t.Fatal("should require padding")
		}
		if txp.Address() != queryURL {
			t.Fatal("invalid address")
		}
	})

	t.Run("CloseIdleConnections", func(t *testing.T) {
		// CloseIdleConnections is expected to forward to the HTTP client.
		var called bool
		doh := &DNSOverHTTPSTransport{
			Client: &mocks.HTTPClient{
				MockCloseIdleConnections: func() {
					called = true
				},
			},
		}
		doh.CloseIdleConnections()
		if !called {
			t.Fatal("not called")
		}
	})
}