ooni-probe-cli/internal/netxlite/dnsoverhttps_test.go
Simone Basso 8f7e3803eb
feat(netxlite): implement DNSTransport wrapping (#776)
Acknowledge that transports MAY be used in isolation (i.e., outside
of a Resolver) and add support for wrapping.

Ensure that every factory that creates an unwrapped type is named
accordingly, so that, hopefully, there are no surprises.

Implement DNSTransport wrapping and use a technique similar to the
one used by Dialer to customize the DNSTransport while constructing
more complex data types (e.g., a specific resolver).

Ensure that the stdlib resolver's own "getaddrinfo" transport (1)
is wrapped and (2) can be extended during construction.

This work is part of my ongoing effort to bring to this repository
websteps-illustrated changes relative to netxlite.

Ref issue: https://github.com/ooni/probe/issues/2096
2022-06-01 11:10:08 +02:00

323 lines
8.4 KiB
Go

package netxlite
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("query serialization failure", func(t *testing.T) {
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
expected := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("NewRequestFailure", func(t *testing.T) {
const invalidURL = "\t"
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("client.Do failure", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("server returns 500", func(t *testing.T) {
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("")),
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("missing content type", func(t *testing.T) {
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("")),
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("ReadAllContext fails", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("decode response failure", func(t *testing.T) {
expected := errors.New("mocked error")
body := []byte("AAA")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
Decoder: &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("success", func(t *testing.T) {
body := []byte("AAA")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
Decoder: &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return &mocks.DNSResponse{}, nil
},
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected non-nil resp here")
}
})
t.Run("sets the correct user-agent", func(t *testing.T) {
expected := errors.New("mocked error")
var correct bool
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == model.HTTPHeaderUserAgent
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
data, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct user agent")
}
})
t.Run("we can override the Host header", func(t *testing.T) {
var correct bool
expected := errors.New("mocked error")
hostOverride := "test.com"
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride
return nil, expected
},
},
URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride,
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if resp != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct host override")
}
})
})
t.Run("other functions behave correctly", func(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query"
txp := NewUnwrappedDNSOverHTTPSTransport(http.DefaultClient, queryURL)
if txp.Network() != "doh" {
t.Fatal("invalid network")
}
if txp.RequiresPadding() != true {
t.Fatal("should require padding")
}
if txp.Address() != queryURL {
t.Fatal("invalid address")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
doh := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockCloseIdleConnections: func() {
called = true
},
},
}
doh.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}