ooni-probe-cli/internal/engine/legacy/netx/oldhttptransport/tracetripper_test.go
Simone Basso 6d3a4f1db8
refactor: merge dnsx and errorsx into netxlite (#517)
When preparing a tutorial for netxlite, I figured it is easier
to tell people "hey, this is the package you should use for all
low-level networking stuff" rather than introducing people to
a set of packages working together where some piece of functionality
is here and some other piece is there.

Part of https://github.com/ooni/probe/issues/1591
2021-09-28 12:42:01 +02:00

273 lines
7.0 KiB
Go

package oldhttptransport
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptrace"
"sync"
"testing"
"time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// TestTraceTripperSuccess performs a real HTTPS GET through a TraceTripper
// wrapping http.DefaultTransport and checks that both the round trip and the
// subsequent read of the whole response body succeed.
func TestTraceTripperSuccess(t *testing.T) {
	txp := NewTraceTripper(http.DefaultTransport)
	clnt := &http.Client{Transport: txp}
	resp, err := clnt.Get("https://www.google.com")
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if _, err := netxlite.ReadAllContext(context.Background(), resp.Body); err != nil {
		t.Fatal(err)
	}
	clnt.CloseIdleConnections()
}
// roundTripHandler is a modelx measurement handler that records every
// HTTPRoundTripDone event it is given and ignores all other measurements.
// Appends to roundTrips are serialized through mu.
type roundTripHandler struct {
	roundTrips []*modelx.HTTPRoundTripDoneEvent
	mu         sync.Mutex
}

// OnMeasurement appends m.HTTPRoundTripDone to h.roundTrips when it is
// non-nil; any other measurement is dropped.
func (h *roundTripHandler) OnMeasurement(m modelx.Measurement) {
	event := m.HTTPRoundTripDone
	if event == nil {
		return
	}
	h.mu.Lock()
	h.roundTrips = append(h.roundTrips, event)
	h.mu.Unlock()
}
// TestTraceTripperReadAllFailure replaces the TraceTripper's readAllContext
// hook with one that always fails with io.EOF. It then checks that a GET
// through the tripper fails with an error wrapping io.EOF, yields a nil
// response, and leaves transport.readAllErrs with a positive value.
func TestTraceTripperReadAllFailure(t *testing.T) {
	transport := NewTraceTripper(http.DefaultTransport)
	transport.readAllContext = func(ctx context.Context, r io.Reader) ([]byte, error) {
		return nil, io.EOF
	}
	client := &http.Client{Transport: transport}
	// Registered up front so idle connections are released even when an
	// assertion below aborts the test (t.Fatal still runs deferred calls).
	defer client.CloseIdleConnections()
	resp, err := client.Get("https://google.com")
	if err == nil {
		// On success the client hands us an open body: release it before bailing.
		resp.Body.Close()
		t.Fatal("expected an error here")
	}
	if !errors.Is(err, io.EOF) {
		t.Fatalf("not the error we expected: %v", err)
	}
	if resp != nil {
		t.Fatal("expected nil response here")
	}
	// Distinct message: this checks the failure counter, not the error value.
	if transport.readAllErrs.Load() <= 0 {
		t.Fatal("expected readAllErrs to be incremented")
	}
}
// TestTraceTripperFailure checks that a failing round trip surfaces its
// error through the TraceTripper and produces no response.
func TestTraceTripperFailure(t *testing.T) {
	client := &http.Client{Transport: NewTraceTripper(http.DefaultTransport)}
	// Speaking cleartext HTTP to port 443, where the server expects a TLS
	// handshake instead, makes the request fail.
	const plaintextToTLSPort = "http://www.google.com:443"
	resp, err := client.Get(plaintextToTLSPort)
	switch {
	case err == nil:
		t.Fatal("expected an error here")
	case resp != nil:
		t.Fatal("expected a nil response here")
	}
	client.CloseIdleConnections()
}
// TestTraceTripperWithClientTrace attaches an empty, caller-provided
// httptrace.ClientTrace to the request context before sending the request
// through a TraceTripper. It then checks that the round trip still succeeds.
func TestTraceTripperWithClientTrace(t *testing.T) {
	tripper := NewTraceTripper(http.DefaultTransport)
	client := &http.Client{Transport: tripper}
	request, err := http.NewRequest("GET", "https://www.kernel.org/", nil)
	if err != nil {
		t.Fatal(err)
	}
	traced := httptrace.WithClientTrace(request.Context(), &httptrace.ClientTrace{})
	request = request.WithContext(traced)
	response, err := client.Do(request)
	if err != nil {
		t.Fatal(err)
	}
	if response == nil {
		t.Fatal("expected a good response here")
	}
	response.Body.Close()
	client.CloseIdleConnections()
}
// TestTraceTripperWithCorrectSnaps sends a DNS-over-HTTPS query for
// dns.google.com A to Cloudflare through a TraceTripper. A small
// MaxBodySnapSize is configured on the measurement root. The test checks that:
//
//   - the DNS reply contains 8.8.8.8 and 8.8.4.4 and no other A records;
//   - the handler receives exactly one HTTPRoundTripDone event;
//   - that event's request and response body snapshots equal the first
//     snapSize bytes of the bodies actually sent and received.
func TestTraceTripperWithCorrectSnaps(t *testing.T) {
	// Prepare a DNS query for dns.google.com A, for which we
	// know the answer in terms of well-known IP addresses.
	query := new(dns.Msg)
	query.Id = dns.Id()
	query.RecursionDesired = true
	query.Question = make([]dns.Question, 1)
	query.Question[0] = dns.Question{
		Name:   dns.Fqdn("dns.google.com"),
		Qtype:  dns.TypeA,
		Qclass: dns.ClassINET,
	}
	queryData, err := query.Pack()
	if err != nil {
		t.Fatal(err)
	}
	// Use an ordinary client on top of the TraceTripper. The snapshot limit
	// is not set on the transport: it is configured through the measurement
	// root attached to the request context below.
	transport := NewTraceTripper(http.DefaultTransport)
	const snapSize = 15
	client := &http.Client{Transport: transport}
	defer client.CloseIdleConnections()
	// Prepare a new request for Cloudflare DNS, register
	// a handler, issue the request, fetch the response.
	req, err := http.NewRequest(
		"POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData),
	)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/dns-message")
	handler := &roundTripHandler{}
	ctx := modelx.WithMeasurementRoot(
		context.Background(), &modelx.MeasurementRoot{
			Beginning:       time.Now(),
			Handler:         handler,
			MaxBodySnapSize: snapSize,
		},
	)
	req = req.WithContext(ctx)
	resp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close immediately so the body is released on every exit
	// path, including the t.Fatal calls below (t.Fatal runs deferred calls).
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Fatalf("HTTP request failed: status %d", resp.StatusCode)
	}
	// Read the whole response body, parse it as valid DNS
	// reply and verify we obtained what we expected.
	replyData, err := netxlite.ReadAllContext(context.Background(), resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	reply := new(dns.Msg)
	if err := reply.Unpack(replyData); err != nil {
		t.Fatal(err)
	}
	if reply.Rcode != 0 {
		t.Fatal("unexpected Rcode")
	}
	if len(reply.Answer) < 1 {
		t.Fatal("no answers?!")
	}
	found8888, found8844, foundother := false, false, false
	for _, answer := range reply.Answer {
		rra, ok := answer.(*dns.A)
		if !ok {
			continue // only A records matter for this check
		}
		switch rra.A.String() {
		case "8.8.8.8":
			found8888 = true
		case "8.8.4.4":
			found8844 = true
		default:
			foundother = true
		}
	}
	if !found8888 || !found8844 || foundother {
		t.Fatal("unexpected reply")
	}
	// Finally, make sure we have captured the correct
	// snapshots for the request and response bodies.
	if n := len(handler.roundTrips); n != 1 {
		t.Fatalf("expected exactly one round trip, got %d", n)
	}
	roundTrip := handler.roundTrips[0]
	if len(roundTrip.RequestBodySnap) != snapSize {
		t.Fatal("unexpected request body snap length")
	}
	if len(roundTrip.ResponseBodySnap) != snapSize {
		t.Fatal("unexpected response body snap length")
	}
	if !bytes.Equal(roundTrip.RequestBodySnap, queryData[:snapSize]) {
		t.Fatal("the request body snap is wrong")
	}
	if !bytes.Equal(roundTrip.ResponseBodySnap, replyData[:snapSize]) {
		t.Fatal("the response body snap is wrong")
	}
}
// TestTraceTripperWithReadAllFailingForBody makes the TraceTripper's
// readAllContext hook always fail with a sentinel error. It then POSTs a
// DNS-over-HTTPS query through the tripper and checks three things: the
// request fails with that sentinel, no response is returned, and the handler
// receives no HTTPRoundTripDone event.
func TestTraceTripperWithReadAllFailingForBody(t *testing.T) {
	// A DNS query for dns.google.com A serves as the POST body.
	query := &dns.Msg{
		MsgHdr: dns.MsgHdr{
			Id:               dns.Id(),
			RecursionDesired: true,
		},
		Question: []dns.Question{{
			Name:   dns.Fqdn("dns.google.com"),
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}},
	}
	queryData, err := query.Pack()
	if err != nil {
		t.Fatal(err)
	}
	// Every read performed through the readAllContext hook fails with errMocked.
	errMocked := errors.New("mocked error")
	transport := NewTraceTripper(http.DefaultTransport)
	transport.readAllContext = func(ctx context.Context, r io.Reader) ([]byte, error) {
		return nil, errMocked
	}
	client := &http.Client{Transport: transport}
	// The body snapshot limit is configured on the measurement root carried
	// by the request context, together with the event-collecting handler.
	const snapSize = 15
	handler := &roundTripHandler{}
	root := &modelx.MeasurementRoot{
		Beginning:       time.Now(),
		Handler:         handler,
		MaxBodySnapSize: snapSize,
	}
	ctx := modelx.WithMeasurementRoot(context.Background(), root)
	req, err := http.NewRequestWithContext(
		ctx, "POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData),
	)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/dns-message")
	resp, err := client.Do(req)
	switch {
	case err == nil:
		t.Fatal("expected an error here")
	case !errors.Is(err, errMocked):
		t.Fatal("not the error we expected")
	case resp != nil:
		t.Fatal("expected nil response here")
	}
	// A failed round trip must not have been reported to the handler.
	if len(handler.roundTrips) != 0 {
		t.Fatal("more round trips than expected")
	}
}