6d3a4f1db8
When preparing a tutorial for netxlite, I figured it is easier to tell people "hey, this is the package you should use for all low-level networking stuff" rather than introducing people to a set of packages working together where some piece of functionality is here and some other piece is there. Part of https://github.com/ooni/probe/issues/1591
273 lines
7.0 KiB
Go
273 lines
7.0 KiB
Go
package oldhttptransport
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptrace"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
|
)
|
|
|
|
func TestTraceTripperSuccess(t *testing.T) {
|
|
client := &http.Client{
|
|
Transport: NewTraceTripper(http.DefaultTransport),
|
|
}
|
|
resp, err := client.Get("https://www.google.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
_, err = netxlite.ReadAllContext(context.Background(), resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
client.CloseIdleConnections()
|
|
}
|
|
|
|
type roundTripHandler struct {
|
|
roundTrips []*modelx.HTTPRoundTripDoneEvent
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (h *roundTripHandler) OnMeasurement(m modelx.Measurement) {
|
|
if m.HTTPRoundTripDone != nil {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
h.roundTrips = append(h.roundTrips, m.HTTPRoundTripDone)
|
|
}
|
|
}
|
|
|
|
func TestTraceTripperReadAllFailure(t *testing.T) {
|
|
transport := NewTraceTripper(http.DefaultTransport)
|
|
transport.readAllContext = func(ctx context.Context, r io.Reader) ([]byte, error) {
|
|
return nil, io.EOF
|
|
}
|
|
client := &http.Client{Transport: transport}
|
|
resp, err := client.Get("https://google.com")
|
|
if err == nil {
|
|
t.Fatal("expected an error here")
|
|
}
|
|
if !errors.Is(err, io.EOF) {
|
|
t.Fatal("not the error we expected")
|
|
}
|
|
if resp != nil {
|
|
t.Fatal("expected nil response here")
|
|
}
|
|
if transport.readAllErrs.Load() <= 0 {
|
|
t.Fatal("not the error we expected")
|
|
}
|
|
client.CloseIdleConnections()
|
|
}
|
|
|
|
func TestTraceTripperFailure(t *testing.T) {
|
|
client := &http.Client{
|
|
Transport: NewTraceTripper(http.DefaultTransport),
|
|
}
|
|
// This fails the request because we attempt to speak cleartext HTTP with
|
|
// a server that instead is expecting TLS.
|
|
resp, err := client.Get("http://www.google.com:443")
|
|
if err == nil {
|
|
t.Fatal("expected an error here")
|
|
}
|
|
if resp != nil {
|
|
t.Fatal("expected a nil response here")
|
|
}
|
|
client.CloseIdleConnections()
|
|
}
|
|
|
|
func TestTraceTripperWithClientTrace(t *testing.T) {
|
|
client := &http.Client{
|
|
Transport: NewTraceTripper(http.DefaultTransport),
|
|
}
|
|
req, err := http.NewRequest("GET", "https://www.kernel.org/", nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req = req.WithContext(
|
|
httptrace.WithClientTrace(req.Context(), new(httptrace.ClientTrace)),
|
|
)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("expected a good response here")
|
|
}
|
|
resp.Body.Close()
|
|
client.CloseIdleConnections()
|
|
}
|
|
|
|
func TestTraceTripperWithCorrectSnaps(t *testing.T) {
|
|
// Prepare a DNS query for dns.google.com A, for which we
|
|
// know the answer in terms of well know IP addresses
|
|
query := new(dns.Msg)
|
|
query.Id = dns.Id()
|
|
query.RecursionDesired = true
|
|
query.Question = make([]dns.Question, 1)
|
|
query.Question[0] = dns.Question{
|
|
Name: dns.Fqdn("dns.google.com"),
|
|
Qtype: dns.TypeA,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
queryData, err := query.Pack()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Prepare a new transport with limited snapshot size and
|
|
// use such transport to configure an ordinary client
|
|
transport := NewTraceTripper(http.DefaultTransport)
|
|
const snapSize = 15
|
|
client := &http.Client{Transport: transport}
|
|
|
|
// Prepare a new request for Cloudflare DNS, register
|
|
// a handler, issue the request, fetch the response.
|
|
req, err := http.NewRequest(
|
|
"POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData),
|
|
)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/dns-message")
|
|
handler := &roundTripHandler{}
|
|
ctx := modelx.WithMeasurementRoot(
|
|
context.Background(), &modelx.MeasurementRoot{
|
|
Beginning: time.Now(),
|
|
Handler: handler,
|
|
MaxBodySnapSize: snapSize,
|
|
},
|
|
)
|
|
req = req.WithContext(ctx)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
t.Fatal("HTTP request failed")
|
|
}
|
|
|
|
// Read the whole response body, parse it as valid DNS
|
|
// reply and verify we obtained what we expected
|
|
replyData, err := netxlite.ReadAllContext(context.Background(), resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer resp.Body.Close()
|
|
reply := new(dns.Msg)
|
|
err = reply.Unpack(replyData)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if reply.Rcode != 0 {
|
|
t.Fatal("unexpected Rcode")
|
|
}
|
|
if len(reply.Answer) < 1 {
|
|
t.Fatal("no answers?!")
|
|
}
|
|
found8888, found8844, foundother := false, false, false
|
|
for _, answer := range reply.Answer {
|
|
if rra, ok := answer.(*dns.A); ok {
|
|
ip := rra.A.String()
|
|
if ip == "8.8.8.8" {
|
|
found8888 = true
|
|
} else if ip == "8.8.4.4" {
|
|
found8844 = true
|
|
} else {
|
|
foundother = true
|
|
}
|
|
}
|
|
}
|
|
if !found8888 || !found8844 || foundother {
|
|
t.Fatal("unexpected reply")
|
|
}
|
|
|
|
// Finally, make sure we have captured the correct
|
|
// snapshots for the request and response bodies
|
|
if len(handler.roundTrips) != 1 {
|
|
t.Fatal("more round trips than expected")
|
|
}
|
|
roundTrip := handler.roundTrips[0]
|
|
if len(roundTrip.RequestBodySnap) != snapSize {
|
|
t.Fatal("unexpected request body snap length")
|
|
}
|
|
if len(roundTrip.ResponseBodySnap) != snapSize {
|
|
t.Fatal("unexpected response body snap length")
|
|
}
|
|
if !bytes.Equal(roundTrip.RequestBodySnap, queryData[:snapSize]) {
|
|
t.Fatal("the request body snap is wrong")
|
|
}
|
|
if !bytes.Equal(roundTrip.ResponseBodySnap, replyData[:snapSize]) {
|
|
t.Fatal("the response body snap is wrong")
|
|
}
|
|
}
|
|
|
|
func TestTraceTripperWithReadAllFailingForBody(t *testing.T) {
|
|
// Prepare a DNS query for dns.google.com A, for which we
|
|
// know the answer in terms of well know IP addresses
|
|
query := new(dns.Msg)
|
|
query.Id = dns.Id()
|
|
query.RecursionDesired = true
|
|
query.Question = make([]dns.Question, 1)
|
|
query.Question[0] = dns.Question{
|
|
Name: dns.Fqdn("dns.google.com"),
|
|
Qtype: dns.TypeA,
|
|
Qclass: dns.ClassINET,
|
|
}
|
|
queryData, err := query.Pack()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Prepare a new transport with limited snapshot size and
|
|
// use such transport to configure an ordinary client
|
|
transport := NewTraceTripper(http.DefaultTransport)
|
|
errorMocked := errors.New("mocked error")
|
|
transport.readAllContext = func(ctx context.Context, r io.Reader) ([]byte, error) {
|
|
return nil, errorMocked
|
|
}
|
|
const snapSize = 15
|
|
client := &http.Client{Transport: transport}
|
|
|
|
// Prepare a new request for Cloudflare DNS, register
|
|
// a handler, issue the request, fetch the response.
|
|
req, err := http.NewRequest(
|
|
"POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData),
|
|
)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/dns-message")
|
|
handler := &roundTripHandler{}
|
|
ctx := modelx.WithMeasurementRoot(
|
|
context.Background(), &modelx.MeasurementRoot{
|
|
Beginning: time.Now(),
|
|
Handler: handler,
|
|
MaxBodySnapSize: snapSize,
|
|
},
|
|
)
|
|
req = req.WithContext(ctx)
|
|
resp, err := client.Do(req)
|
|
if err == nil {
|
|
t.Fatal("expected an error here")
|
|
}
|
|
if !errors.Is(err, errorMocked) {
|
|
t.Fatal("not the error we expected")
|
|
}
|
|
if resp != nil {
|
|
t.Fatal("expected nil response here")
|
|
}
|
|
|
|
// Finally, make sure we got something that makes sense
|
|
if len(handler.roundTrips) != 0 {
|
|
t.Fatal("more round trips than expected")
|
|
}
|
|
}
|