From c2ea0b470485c2699ce888286ce444a77a7a50c3 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 21 Nov 2022 16:28:53 +0100 Subject: [PATCH] feat(webconnectivity): try all the available THs (#980) We introduce a fork of internal/httpx, named internal/httpapi, where there is a clear split between the concept of an API endpoint (such as https://0.th.ooni.org/) and of an API descriptor (such as using `GET` to access /api/v1/test-list/url). Additionally, httpapi allows to create a SequenceCaller that tries to call a given API descriptor using multiple API endpoints. The SequenceCaller will stop once an endpoint works or when all the available endpoints have been tried unsuccessfully. The definition of "success" is the following: we consider "failure" any error that occurs during the HTTP round trip or when reading the response body. We DO NOT consider "failure" errors (1) when parsing the input URL; (2) when the server returns >= 400; (3) when the server returns a string that does not parse as valid JSON. The idea of this classification of failures is that we ONLY want to retry when we see what looks like a network error that may be caused by (collateral or targeted) censorship. We take advantage of the availability of this new package and we refactor web_connectivity@v0.4 and web_connectivity@v0.5 to use a SequenceCaller for calling the web connectivity TH API. This means that we will now try all the available THs advertised by the backend rather than just selecting and using the first one provided by the backend. Because this diff is designed to be backported to the `release/3.16` branch, we have omitted additional changes to always use httpapi where we are currently using httpx. Yet, to remind ourselves about the need to do that, we have deprecated the httpx package. We will rewrite all the code currently using httpx to use httpapi as part of future work. It is also worth noting that httpapi will allow us to refactor the backend code such that (1) we remove code to select a backend URL endpoint at the beginning and (2) we try several endpoints. The design of the code is such that we can add to the mix some endpoints using as `http.Client` a special client using a tunnel. This will allow us to automatically fallback backend queries. Closes https://github.com/ooni/probe/issues/2353. Related to https://github.com/ooni/probe/issues/1519. --- .../experiment/webconnectivity/control.go | 30 +- .../webconnectivity/webconnectivity.go | 24 +- .../webconnectivity/webconnectivity_test.go | 2 +- .../webconnectivity/cleartextflow.go | 2 +- .../experiment/webconnectivity/control.go | 33 +- .../webconnectivity/dnsresolvers.go | 13 +- .../experiment/webconnectivity/measurer.go | 26 +- .../experiment/webconnectivity/secureflow.go | 2 +- .../experiment/webconnectivity/testkeys.go | 20 + internal/httpapi/call.go | 181 +++ internal/httpapi/call_test.go | 1163 +++++++++++++++++ internal/httpapi/descriptor.go | 155 +++ internal/httpapi/descriptor_test.go | 248 ++++ internal/httpapi/doc.go | 15 + internal/httpapi/endpoint.go | 76 ++ internal/httpapi/endpoint_test.go | 69 + internal/httpapi/sequence.go | 92 ++ internal/httpapi/sequence_test.go | 358 +++++ internal/httpx/httpx.go | 4 + 19 files changed, 2446 insertions(+), 67 deletions(-) create mode 100644 internal/httpapi/call.go create mode 100644 internal/httpapi/call_test.go create mode 100644 internal/httpapi/descriptor.go create mode 100644 internal/httpapi/descriptor_test.go create mode 100644 internal/httpapi/doc.go create mode 100644 internal/httpapi/endpoint.go create mode 100644 internal/httpapi/endpoint_test.go create mode 100644 internal/httpapi/sequence.go create mode 100644 internal/httpapi/sequence_test.go diff --git a/internal/engine/experiment/webconnectivity/control.go b/internal/engine/experiment/webconnectivity/control.go index ccaf870..bd953a1 100644 --- a/internal/engine/experiment/webconnectivity/control.go +++ b/internal/engine/experiment/webconnectivity/control.go @@ -4,9 +4,10 @@ import ( "context" "github.com/ooni/probe-cli/v3/internal/geoipx" - "github.com/ooni/probe-cli/v3/internal/httpx" + "github.com/ooni/probe-cli/v3/internal/httpapi" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) // Redirect to types defined inside the model package @@ -21,22 +22,23 @@ type ( // Control performs the control request and returns the response. func Control( ctx context.Context, sess model.ExperimentSession, - thAddr string, creq ControlRequest) (out ControlResponse, err error) { - clnt := &httpx.APIClientTemplate{ - BaseURL: thAddr, - HTTPClient: sess.DefaultHTTPClient(), - Logger: sess.Logger(), - UserAgent: sess.UserAgent(), - } + testhelpers []model.OOAPIService, creq ControlRequest) (ControlResponse, *model.OOAPIService, error) { + seqCaller := httpapi.NewSequenceCaller( + httpapi.MustNewPOSTJSONWithJSONResponseDescriptor(sess.Logger(), "/", creq).WithBodyLogging(true), + httpapi.NewEndpointList(sess.DefaultHTTPClient(), sess.UserAgent(), testhelpers...)..., + ) sess.Logger().Infof("control for %s...", creq.HTTPRequest) - // make sure error is wrapped - err = clnt.WithBodyLogging().Build().PostJSON(ctx, "/", creq, &out) - if err != nil { - err = netxlite.NewTopLevelGenericErrWrapper(err) - } + var out ControlResponse + idx, err := seqCaller.CallWithJSONResponse(ctx, &out) sess.Logger().Infof("control for %s... %+v", creq.HTTPRequest, model.ErrorToStringOrOK(err)) + if err != nil { + // make sure error is wrapped + err = netxlite.NewTopLevelGenericErrWrapper(err) + return ControlResponse{}, nil, err + } fillASNs(&out.DNS) - return + runtimex.Assert(idx >= 0 && idx < len(testhelpers), "idx out of bounds") + return out, &testhelpers[idx], nil } // fillASNs fills the ASNs array of ControlDNSResult. For each Addr inside diff --git a/internal/engine/experiment/webconnectivity/webconnectivity.go b/internal/engine/experiment/webconnectivity/webconnectivity.go index 526b9f7..bdbea35 100644 --- a/internal/engine/experiment/webconnectivity/webconnectivity.go +++ b/internal/engine/experiment/webconnectivity/webconnectivity.go @@ -15,7 +15,7 @@ import ( const ( testName = "web_connectivity" - testVersion = "0.4.1" + testVersion = "0.4.2" ) // Config contains the experiment config. @@ -145,19 +145,9 @@ func (m Measurer) Run( } // 1. find test helper testhelpers, _ := sess.GetTestHelpersByName("web-connectivity") - var testhelper *model.OOAPIService - for _, th := range testhelpers { - if th.Type == "https" { - testhelper = &th - break - } - } - if testhelper == nil { + if len(testhelpers) < 1 { return ErrNoAvailableTestHelpers } - measurement.TestHelpers = map[string]interface{}{ - "backend": testhelper, - } // 2. perform the DNS lookup step dnsBegin := time.Now() dnsResult := DNSLookup(ctx, DNSLookupConfig{ @@ -167,10 +157,11 @@ func (m Measurer) Run( tk.Queries = append(tk.Queries, dnsResult.TestKeys.Queries...) tk.DNSExperimentFailure = dnsResult.Failure epnts := NewEndpoints(URL, dnsResult.Addresses()) - sess.Logger().Infof("using control: %s", testhelper.Address) + sess.Logger().Infof("using control: %+v", testhelpers) // 3. perform the control measurement thBegin := time.Now() - tk.Control, err = Control(ctx, sess, testhelper.Address, ControlRequest{ + var usedTH *model.OOAPIService + tk.Control, usedTH, err = Control(ctx, sess, testhelpers, ControlRequest{ HTTPRequest: URL.String(), HTTPRequestHeaders: map[string][]string{ "Accept": {model.HTTPHeaderAccept}, @@ -179,6 +170,11 @@ func (m Measurer) Run( }, TCPConnect: epnts.Endpoints(), }) + if usedTH != nil { + measurement.TestHelpers = map[string]interface{}{ + "backend": usedTH, + } + } tk.THRuntime = time.Since(thBegin) tk.ControlFailure = tracex.NewFailure(err) // 4. analyze DNS results diff --git a/internal/engine/experiment/webconnectivity/webconnectivity_test.go b/internal/engine/experiment/webconnectivity/webconnectivity_test.go index c1fcb37..6db3989 100644 --- a/internal/engine/experiment/webconnectivity/webconnectivity_test.go +++ b/internal/engine/experiment/webconnectivity/webconnectivity_test.go @@ -21,7 +21,7 @@ func TestNewExperimentMeasurer(t *testing.T) { if measurer.ExperimentName() != "web_connectivity" { t.Fatal("unexpected name") } - if measurer.ExperimentVersion() != "0.4.1" { + if measurer.ExperimentVersion() != "0.4.2" { t.Fatal("unexpected version") } } diff --git a/internal/experiment/webconnectivity/cleartextflow.go b/internal/experiment/webconnectivity/cleartextflow.go index 351a27c..d515d4e 100644 --- a/internal/experiment/webconnectivity/cleartextflow.go +++ b/internal/experiment/webconnectivity/cleartextflow.go @@ -285,7 +285,7 @@ func (t *CleartextFlow) maybeFollowRedirects(ctx context.Context, resp *http.Res WaitGroup: t.WaitGroup, Referer: resp.Request.URL.String(), Session: nil, // no need to issue another control request - THAddr: "", // ditto + TestHelpers: nil, // ditto UDPAddress: t.UDPAddress, } resolvers.Start(ctx) diff --git a/internal/experiment/webconnectivity/control.go b/internal/experiment/webconnectivity/control.go index 3e58f42..bf7003a 100644 --- a/internal/experiment/webconnectivity/control.go +++ b/internal/experiment/webconnectivity/control.go @@ -8,10 +8,11 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/engine/experiment/webconnectivity" - "github.com/ooni/probe-cli/v3/internal/httpx" + "github.com/ooni/probe-cli/v3/internal/httpapi" "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) // EndpointMeasurementsStarter is used by Control to start extra @@ -51,8 +52,8 @@ type Control struct { // Session is the MANDATORY session to use. Session model.ExperimentSession - // THAddr is the MANDATORY TH's URL. - THAddr string + // TestHelpers is the MANDATORY list of test helpers. + TestHelpers []model.OOAPIService // URL is the MANDATORY URL we are measuring. URL *url.URL @@ -102,26 +103,20 @@ func (c *Control) Run(parentCtx context.Context) { // create logger for this operation ol := measurexlite.NewOperationLogger( c.Logger, - "control for %s using %s", + "control for %s using %+v", creq.HTTPRequest, - c.THAddr, + c.TestHelpers, ) - // create an API client - clnt := (&httpx.APIClientTemplate{ - Accept: "", - Authorization: "", - BaseURL: c.THAddr, - HTTPClient: c.Session.DefaultHTTPClient(), - Host: "", // use the one inside the URL - LogBody: true, - Logger: c.Logger, - UserAgent: c.Session.UserAgent(), - }).Build() + // create an httpapi sequence caller + seqCaller := httpapi.NewSequenceCaller( + httpapi.MustNewPOSTJSONWithJSONResponseDescriptor(c.Logger, "/", creq).WithBodyLogging(true), + httpapi.NewEndpointList(c.Session.DefaultHTTPClient(), c.Session.UserAgent(), c.TestHelpers...)..., + ) // issue the control request and wait for the response var cresp webconnectivity.ControlResponse - err := clnt.PostJSON(opCtx, "/", creq, &cresp) + idx, err := seqCaller.CallWithJSONResponse(opCtx, &cresp) if err != nil { // make sure error is wrapped err = netxlite.NewTopLevelGenericErrWrapper(err) @@ -134,6 +129,10 @@ func (c *Control) Run(parentCtx context.Context) { c.TestKeys.SetControl(&cresp) ol.Stop(nil) + // record the specific TH that worked + runtimex.Assert(idx >= 0 && idx < len(c.TestHelpers), "idx out of bounds") + c.TestKeys.setTestHelper(&c.TestHelpers[idx]) + // if the TH returned us addresses we did not previously were // aware of, make sure we also measure them c.maybeStartExtraMeasurements(parentCtx, cresp.DNS.Addrs) diff --git a/internal/experiment/webconnectivity/dnsresolvers.go b/internal/experiment/webconnectivity/dnsresolvers.go index 0c9df8d..8709635 100644 --- a/internal/experiment/webconnectivity/dnsresolvers.go +++ b/internal/experiment/webconnectivity/dnsresolvers.go @@ -67,8 +67,9 @@ type DNSResolvers struct { // always follow the redirect chain caused by the provided URL. Session model.ExperimentSession - // THAddr is the OPTIONAL test helper address. - THAddr string + // TestHelpers is the OPTIONAL list of test helpers. If the list is + // empty, we are not going to try to contact any test helper. + TestHelpers []model.OOAPIService // UDPAddress is the OPTIONAL address of the UDP resolver to use. If this // field is not set we use a default one (e.g., `8.8.8.8:53`). @@ -498,15 +499,15 @@ func (t *DNSResolvers) startSecureFlows( } } -// maybeStartControlFlow starts the control flow iff .Session and .THAddr are set. +// maybeStartControlFlow starts the control flow iff .Session and .TestHelpers are set. func (t *DNSResolvers) maybeStartControlFlow( ctx context.Context, ps *prioritySelector, addresses []DNSEntry, ) { - // note: for subsequent requests we don't set .Session and .THAddr hence + // note: for subsequent requests we don't set .Session and .TestHelpers hence // we are not going to query the test helper more than once - if t.Session != nil && t.THAddr != "" { + if t.Session != nil && len(t.TestHelpers) > 0 { var addrs []string for _, addr := range addresses { addrs = append(addrs, addr.Addr) @@ -518,7 +519,7 @@ func (t *DNSResolvers) maybeStartControlFlow( PrioSelector: ps, TestKeys: t.TestKeys, Session: t.Session, - THAddr: t.THAddr, + TestHelpers: t.TestHelpers, URL: t.URL, WaitGroup: t.WaitGroup, } diff --git a/internal/experiment/webconnectivity/measurer.go b/internal/experiment/webconnectivity/measurer.go index c6731e4..5a0b977 100644 --- a/internal/experiment/webconnectivity/measurer.go +++ b/internal/experiment/webconnectivity/measurer.go @@ -36,7 +36,7 @@ func (m *Measurer) ExperimentName() string { // ExperimentVersion implements model.ExperimentMeasurer. func (m *Measurer) ExperimentVersion() string { - return "0.5.18" + return "0.5.19" } // Run implements model.ExperimentMeasurer. @@ -89,17 +89,7 @@ func (m *Measurer) Run(ctx context.Context, sess model.ExperimentSession, // obtain the test helper's address testhelpers, _ := sess.GetTestHelpersByName("web-connectivity") - var thAddr string - for _, th := range testhelpers { - if th.Type == "https" { - thAddr = th.Address - measurement.TestHelpers = map[string]any{ - "backend": &th, - } - break - } - } - if thAddr == "" { + if len(testhelpers) < 1 { sess.Logger().Warnf("continuing without a valid TH address") tk.SetControlFailure(webconnectivity.ErrNoAvailableTestHelpers) } @@ -120,7 +110,7 @@ func (m *Measurer) Run(ctx context.Context, sess model.ExperimentSession, CookieJar: jar, Referer: "", Session: sess, - THAddr: thAddr, + TestHelpers: testhelpers, UDPAddress: "", } resos.Start(ctx) @@ -137,6 +127,16 @@ func (m *Measurer) Run(ctx context.Context, sess model.ExperimentSession, // perform any deferred computation on the test keys tk.Finalize(sess.Logger()) + // set the test helper we used + // TODO(bassosimone): it may be more informative to know about all the + // test helpers we _tried_ to use, however the data format does not have + // support for that as far as I can tell... + if th := tk.getTestHelper(); th != nil { + measurement.TestHelpers = map[string]interface{}{ + "backend": th, + } + } + // return whether there was a fundamental failure, which would prevent // the measurement from being submitted to the OONI collector. return tk.fundamentalFailure diff --git a/internal/experiment/webconnectivity/secureflow.go b/internal/experiment/webconnectivity/secureflow.go index 68f41e1..6a47b73 100644 --- a/internal/experiment/webconnectivity/secureflow.go +++ b/internal/experiment/webconnectivity/secureflow.go @@ -337,7 +337,7 @@ func (t *SecureFlow) maybeFollowRedirects(ctx context.Context, resp *http.Respon WaitGroup: t.WaitGroup, Referer: resp.Request.URL.String(), Session: nil, // no need to issue another control request - THAddr: "", // ditto + TestHelpers: nil, // ditto UDPAddress: t.UDPAddress, } resolvers.Start(ctx) diff --git a/internal/experiment/webconnectivity/testkeys.go b/internal/experiment/webconnectivity/testkeys.go index 37dbc94..5cd917f 100644 --- a/internal/experiment/webconnectivity/testkeys.go +++ b/internal/experiment/webconnectivity/testkeys.go @@ -134,6 +134,10 @@ type TestKeys struct { // mu provides mutual exclusion for accessing the test keys. mu *sync.Mutex + + // testHelper is used to communicate the TH that worked to the main + // goroutine such that we can fill measurement.TestHelpers. + testHelper *model.OOAPIService } // ConnPriorityLogEntry is an entry in the TestKeys.ConnPriorityLog slice. @@ -302,6 +306,21 @@ func (tk *TestKeys) AppendConnPriorityLogEntry(entry *ConnPriorityLogEntry) { tk.mu.Unlock() } +// setTestHelper sets .testHelper in a thread safe way +func (tk *TestKeys) setTestHelper(th *model.OOAPIService) { + tk.mu.Lock() + tk.testHelper = th + tk.mu.Unlock() +} + +// getTestHelper gets .testHelper in a thread safe way +func (tk *TestKeys) getTestHelper() (th *model.OOAPIService) { + tk.mu.Lock() + th = tk.testHelper + tk.mu.Unlock() + return +} + // NewTestKeys creates a new instance of TestKeys. func NewTestKeys() *TestKeys { return &TestKeys{ @@ -348,6 +367,7 @@ func NewTestKeys() *TestKeys { ControlRequest: nil, fundamentalFailure: nil, mu: &sync.Mutex{}, + testHelper: nil, } } diff --git a/internal/httpapi/call.go b/internal/httpapi/call.go new file mode 100644 index 0000000..1228080 --- /dev/null +++ b/internal/httpapi/call.go @@ -0,0 +1,181 @@ +package httpapi + +// +// Calling HTTP APIs. +// + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +// joinURLPath appends |resourcePath| to |urlPath|. +func joinURLPath(urlPath, resourcePath string) string { + if resourcePath == "" { + if urlPath == "" { + return "/" + } + return urlPath + } + if !strings.HasSuffix(urlPath, "/") { + urlPath += "/" + } + resourcePath = strings.TrimPrefix(resourcePath, "/") + return urlPath + resourcePath +} + +// newRequest creates a new http.Request from the given |ctx|, |endpoint|, and |desc|. +func newRequest(ctx context.Context, endpoint *Endpoint, desc *Descriptor) (*http.Request, error) { + URL, err := url.Parse(endpoint.BaseURL) + if err != nil { + return nil, err + } + // BaseURL and resource URL are joined if they have a path + URL.Path = joinURLPath(URL.Path, desc.URLPath) + if len(desc.URLQuery) > 0 { + URL.RawQuery = desc.URLQuery.Encode() + } else { + URL.RawQuery = "" // as documented we only honour desc.URLQuery + } + var reqBody io.Reader + if len(desc.RequestBody) > 0 { + reqBody = bytes.NewReader(desc.RequestBody) + desc.Logger.Debugf("httpapi: request body length: %d", len(desc.RequestBody)) + if desc.LogBody { + desc.Logger.Debugf("httpapi: request body: %s", string(desc.RequestBody)) + } + } + request, err := http.NewRequestWithContext(ctx, desc.Method, URL.String(), reqBody) + if err != nil { + return nil, err + } + request.Host = endpoint.Host // allow cloudfronting + if desc.Authorization != "" { + request.Header.Set("Authorization", desc.Authorization) + } + if desc.ContentType != "" { + request.Header.Set("Content-Type", desc.ContentType) + } + if desc.Accept != "" { + request.Header.Set("Accept", desc.Accept) + } + if endpoint.UserAgent != "" { + request.Header.Set("User-Agent", endpoint.UserAgent) + } + return request, nil +} + +// ErrHTTPRequestFailed indicates that the server returned >= 400. +type ErrHTTPRequestFailed struct { + // StatusCode is the status code that failed. + StatusCode int +} + +// Error implements error. +func (err *ErrHTTPRequestFailed) Error() string { + return fmt.Sprintf("httpapi: http request failed: %d", err.StatusCode) +} + +// errMaybeCensorship indicates that there was an error at the networking layer +// including, e.g., DNS, TCP connect, TLS. When we see this kind of error, we +// will consider retrying with another endpoint under the assumption that it +// may be that the current endpoint is censored. +type errMaybeCensorship struct { + // Err is the underlying error + Err error +} + +// Error implements error +func (err *errMaybeCensorship) Error() string { + return err.Err.Error() +} + +// Unwrap allows to get the underlying error +func (err *errMaybeCensorship) Unwrap() error { + return err.Err +} + +// docall calls the API represented by the given request |req| on the given |endpoint| +// and returns the response and its body or an error. +func docall(endpoint *Endpoint, desc *Descriptor, request *http.Request) (*http.Response, []byte, error) { + // Implementation note: remember to mark errors for which you want + // to retry with another endpoint using errMaybeCensorship. + response, err := endpoint.HTTPClient.Do(request) + if err != nil { + return nil, nil, &errMaybeCensorship{err} + } + defer response.Body.Close() + // Implementation note: always read and log the response body since + // it's quite useful to see the response JSON on API error. + r := io.LimitReader(response.Body, DefaultMaxBodySize) + data, err := netxlite.ReadAllContext(request.Context(), r) + if err != nil { + return response, nil, &errMaybeCensorship{err} + } + desc.Logger.Debugf("httpapi: response body length: %d bytes", len(data)) + if desc.LogBody { + desc.Logger.Debugf("httpapi: response body: %s", string(data)) + } + if response.StatusCode >= 400 { + return response, nil, &ErrHTTPRequestFailed{response.StatusCode} + } + return response, data, nil +} + +// call is like Call but also returns the response. +func call(ctx context.Context, desc *Descriptor, endpoint *Endpoint) (*http.Response, []byte, error) { + timeout := desc.Timeout + if timeout <= 0 { + timeout = DefaultCallTimeout // as documented + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + request, err := newRequest(ctx, endpoint, desc) + if err != nil { + return nil, nil, err + } + return docall(endpoint, desc, request) +} + +// Call invokes the API described by |desc| on the given HTTP |endpoint| and +// returns the response body (as a slice of bytes) or an error. +// +// Note: this function returns ErrHTTPRequestFailed if the HTTP status code is +// greater or equal than 400. You could use errors.As to obtain a copy of the +// error that was returned and see for yourself the actual status code. +func Call(ctx context.Context, desc *Descriptor, endpoint *Endpoint) ([]byte, error) { + _, rawResponseBody, err := call(ctx, desc, endpoint) + return rawResponseBody, err +} + +// goodContentTypeForJSON tracks known-good content-types for JSON. If the content-type +// is not in this map, |CallWithJSONResponse| emits a warning message. +var goodContentTypeForJSON = map[string]bool{ + applicationJSON: true, +} + +// CallWithJSONResponse is like Call but also assumes that the response is a +// JSON body and attempts to parse it into the |response| field. +// +// Note: this function returns ErrHTTPRequestFailed if the HTTP status code is +// greater or equal than 400. You could use errors.As to obtain a copy of the +// error that was returned and see for yourself the actual status code. +func CallWithJSONResponse(ctx context.Context, desc *Descriptor, endpoint *Endpoint, response any) error { + httpResp, rawRespBody, err := call(ctx, desc, endpoint) + if err != nil { + return err + } + if ctype := httpResp.Header.Get("Content-Type"); !goodContentTypeForJSON[ctype] { + desc.Logger.Warnf("httpapi: unexpected content-type: %s", ctype) + // fallthrough + } + return json.Unmarshal(rawRespBody, response) +} diff --git a/internal/httpapi/call_test.go b/internal/httpapi/call_test.go new file mode 100644 index 0000000..8fd45d7 --- /dev/null +++ b/internal/httpapi/call_test.go @@ -0,0 +1,1163 @@ +package httpapi + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "syscall" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +func Test_joinURLPath(t *testing.T) { + tests := []struct { + name string + urlPath string + resourcePath string + want string + }{{ + name: "whole path inside urlPath and empty resourcePath", + urlPath: "/robots.txt", + resourcePath: "", + want: "/robots.txt", + }, { + name: "empty urlPath and slash-prefixed resourcePath", + urlPath: "", + resourcePath: "/foo", + want: "/foo", + }, { + name: "slash urlPath and slash-prefixed resourcePath", + urlPath: "/", + resourcePath: "/foo", + want: "/foo", + }, { + name: "empty urlPath and empty resourcePath", + urlPath: "", + resourcePath: "", + want: "/", + }, { + name: "non-slash-terminated urlPath and slash-prefixed resourcePath", + urlPath: "/foo", + resourcePath: "/bar", + want: "/foo/bar", + }, { + name: "slash-terminated urlPath and slash-prefixed resourcePath", + urlPath: "/foo/", + resourcePath: "/bar", + want: "/foo/bar", + }, { + name: "slash-terminated urlPath and non-slash-prefixed resourcePath", + urlPath: "/foo", + resourcePath: "bar", + want: "/foo/bar", + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := joinURLPath(tt.urlPath, tt.resourcePath) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func Test_newRequest(t *testing.T) { + type args struct { + ctx context.Context + endpoint *Endpoint + desc *Descriptor + } + tests := []struct { + name string + args args + wantFn func(*testing.T, *http.Request) + wantErr error + }{{ + name: "url.Parse fails", + args: args{ + ctx: nil, + endpoint: &Endpoint{ + BaseURL: "\t\t\t", // does not parse! + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: "", + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: nil, + wantErr: errors.New(`parse "\t\t\t": net/url: invalid control character in URL`), + }, { + name: "http.NewRequestWithContext fails", + args: args{ + ctx: nil, // causes http.NewRequestWithContext to fail + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: "", + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: nil, + wantErr: errors.New("net/http: nil Context"), + }, { + name: "successful case with GET method, no body, and no extra headers", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + if req.Body != nil { + t.Fatal("invalid body", req.Body) + } + }, + wantErr: nil, + }, { + name: "successful case with POST method and body", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: 0, + Method: http.MethodPost, + RequestBody: []byte("deadbeef"), + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodPost { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + data, err := netxlite.ReadAllContext(context.Background(), req.Body) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff([]byte("deadbeef"), data); diff != "" { + t.Fatal(diff) + } + }, + wantErr: nil, + }, { + name: "with GET method and custom headers", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "antani.org", + UserAgent: "httpclient/1.0.1", + }, + desc: &Descriptor{ + Accept: "application/json", + Authorization: "deafbeef", + ContentType: "text/plain", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodPut, + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodPut { + t.Fatal("invalid method") + } + if req.Host != "antani.org" { + t.Fatal("invalid request host") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + if req.Header.Get("Authorization") != "deafbeef" { + t.Fatal("invalid authorization") + } + if req.Header.Get("Content-Type") != "text/plain" { + t.Fatal("invalid content-type") + } + if req.Header.Get("Accept") != "application/json" { + t.Fatal("invalid accept") + } + if req.Header.Get("User-Agent") != "httpclient/1.0.1" { + t.Fatal("invalid user-agent") + } + }, + wantErr: nil, + }, { + name: "we join the urlPath with the resourcePath", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/api/v1", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "/test-list/urls", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://www.example.com/api/v1/test-list/urls" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }, { + name: "we discard any query element inside the Endpoint.BaseURL", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.org/api/v1/?probe_cc=IT", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.org/api/v1/" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }, { + name: "we include query elements from Descriptor.URLQuery", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/api/v1/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "test-list/urls", + URLQuery: map[string][]string{ + "probe_cc": {"IT"}, + }, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://www.example.com/api/v1/test-list/urls?probe_cc=IT" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }, { + name: "with as many implicitly-initialized fields as possible", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + }, + desc: &Descriptor{}, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newRequest(tt.args.ctx, tt.args.endpoint, tt.args.desc) + switch { + case err == nil && tt.wantErr == nil: + // nothing + case err != nil && tt.wantErr == nil: + t.Fatalf("expected error but got %s", err.Error()) + case err == nil && tt.wantErr != nil: + t.Fatalf("expected %s but got ", tt.wantErr.Error()) + case err.Error() == tt.wantErr.Error(): + // nothing + default: + t.Fatalf("expected %s but got %s", err.Error(), tt.wantErr.Error()) + } + if tt.wantFn != nil { + tt.wantFn(t, got) + return + } + if got != nil { + t.Fatal("got response with nil tt.wantFn") + } + }) + } +} + +func TestCall(t *testing.T) { + type args struct { + ctx context.Context + desc *Descriptor + endpoint *Endpoint + } + tests := []struct { + name string + args args + want []byte + wantErr error + errfn func(t *testing.T, err error) + }{{ + name: "newRequest fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: "", + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + endpoint: &Endpoint{ + BaseURL: "\t\t\t", // causes newRequest to fail + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + }, + want: nil, + wantErr: errors.New(`parse "\t\t\t": net/url: invalid control character in URL`), + errfn: nil, + }, { + name: "endpoint.HTTPClient.Do fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + }, + want: nil, + wantErr: io.EOF, + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("unexpected error type") + } + }, + }, { + name: "reading body fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, netxlite.ECONNRESET + }, + }), + } + return resp, nil + }, + }, + }, + }, + want: nil, + wantErr: errors.New(netxlite.FailureConnectionReset), + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("unexpected error type") + } + }, + }, { + name: "status code indicates failure", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader("deadbeef")), + StatusCode: 403, + } + return resp, nil + }, + }, + }, + }, + want: nil, + wantErr: errors.New("httpapi: http request failed: 403"), + errfn: func(t *testing.T, err error) { + var expect *ErrHTTPRequestFailed + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "success with log body flag", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + LogBody: true, // as documented by this test's name + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader("deadbeef")), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + want: []byte("deadbeef"), + wantErr: nil, + errfn: nil, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Call(tt.args.ctx, tt.args.desc, tt.args.endpoint) + switch { + case err == nil && tt.wantErr == nil: + // nothing + case err != nil && tt.wantErr == nil: + t.Fatalf("expected error but got %s", err.Error()) + case err == nil && tt.wantErr != nil: + t.Fatalf("expected %s but got ", tt.wantErr.Error()) + case err.Error() == tt.wantErr.Error(): + // nothing + default: + t.Fatalf("expected %s but got %s", err.Error(), tt.wantErr.Error()) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func TestCallWithJSONResponse(t *testing.T) { + type response struct { + Name string + Age int64 + } + expectedResponse := response{ + Name: "sbs", + Age: 99, + } + type args struct { + ctx context.Context + desc *Descriptor + endpoint *Endpoint + } + tests := []struct { + name string + args args + wantErr error + errfn func(*testing.T, error) + }{{ + name: "call fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "\t\t\t\t", // causes failure + }, + }, + wantErr: errors.New(`parse "\t\t\t\t": net/url: invalid control character in URL`), + errfn: nil, + }, { + name: "with error during httpClient.Do", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + }, + wantErr: io.EOF, + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "with error when reading the response body", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, netxlite.ECONNRESET + }, + }), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: errors.New(netxlite.FailureConnectionReset), + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "with HTTP failure", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader(`{"Name": "sbs", "Age": 99}`)), + StatusCode: 400, + } + return resp, nil + }, + }, + }, + }, + wantErr: errors.New("httpapi: http request failed: 400"), + errfn: func(t *testing.T, err error) { + var expect *ErrHTTPRequestFailed + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "with good response and missing header", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader(`{"Name": "sbs", "Age": 99}`)), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: nil, + errfn: nil, + }, { + name: "with good response and good header", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Header: http.Header{ + "Content-Type": {"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"Name": "sbs", "Age": 99}`)), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: nil, + errfn: nil, + }, { + name: "response is not JSON", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + LogBody: false, + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Header: http.Header{ + "Content-Type": {"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{`)), // invalid JSON + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: errors.New("unexpected end of JSON input"), + errfn: nil, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var response response + err := CallWithJSONResponse(tt.args.ctx, tt.args.desc, tt.args.endpoint, &response) + switch { + case err == nil && tt.wantErr == nil: + if diff := cmp.Diff(expectedResponse, response); err != nil { + t.Fatal(diff) + } + case err != nil && tt.wantErr == nil: + t.Fatalf("expected error but got %s", err.Error()) + case err == nil && tt.wantErr != nil: + t.Fatalf("expected %s but got ", tt.wantErr.Error()) + case err.Error() == tt.wantErr.Error(): + // nothing + default: + t.Fatalf("expected %s but got %s", err.Error(), tt.wantErr.Error()) + } + if tt.errfn != nil { + tt.errfn(t, err) + } + }) + } +} + +func TestCallHonoursContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // should fail HTTP request immediately + desc := &Descriptor{ + LogBody: false, + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/robots.txt", + } + endpoint := &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: http.DefaultClient, + UserAgent: model.HTTPHeaderUserAgent, + } + body, err := Call(ctx, desc, endpoint) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } + if len(body) > 0 { + t.Fatal("expected zero-length body") + } +} + +func TestCallWithJSONResponseHonoursContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // should fail HTTP request immediately + desc := &Descriptor{ + LogBody: false, + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/robots.txt", + } + endpoint := &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: http.DefaultClient, + UserAgent: model.HTTPHeaderUserAgent, + } + var resp url.URL + err := CallWithJSONResponse(ctx, desc, endpoint, &resp) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } +} + +func TestDescriptorLogging(t *testing.T) { + + // This test was originally written for the httpx package and we have adapted it + // by keeping the ~same implementation with a custom callx function that converts + // the previous semantics of httpx to the new semantics of httpapi. + callx := func(baseURL string, logBody bool, logger model.Logger, request, response any) error { + desc := MustNewPOSTJSONWithJSONResponseDescriptor(logger, "/", request).WithBodyLogging(logBody) + runtimex.Assert(desc.LogBody == logBody, "desc.LogBody should be equal to logBody here") + endpoint := &Endpoint{ + BaseURL: baseURL, + HTTPClient: http.DefaultClient, + } + return CallWithJSONResponse(context.Background(), desc, endpoint, response) + } + + // we also needed to create a constructor for the logger + newlogger := func(logs chan string) model.Logger { + return &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + logs <- fmt.Sprintf(format, v...) + }, + MockWarnf: func(format string, v ...interface{}) { + logs <- fmt.Sprintf(format, v...) + }, + } + } + + t.Run("body logging enabled, 200 Ok, and without content-type", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, true, logger, input, &output) + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // we expect this because body logging is enabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // we expect this because body logging is enabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // we would expect this because the server does not send us any content-type + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we should see this because we sent a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we should see this because we receive a body + found |= 1 << 4 + continue + } + } + if found != (1<<0 | 1<<1 | 1<<2 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + if err != nil { + t.Fatal(err) + } + }) + + t.Run("body logging enabled, 200 Ok, and with content-type", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, true, logger, input, &output) + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // we expect this because body logging is enabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // we expect this because body logging is enabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // we do not expect this because the server sends us a content-type + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we should see this because we sent a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we should see this because we receive a body + found |= 1 << 4 + continue + } + } + if found != (1<<0 | 1<<1 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + if err != nil { + t.Fatal(err) + } + }) + + t.Run("body logging enabled and 401 Unauthorized", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, true, logger, input, &output) + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // should occur because body logging is enabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // should occur because body logging is enabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // note: this one should not occur because the code is 401 so we're not + // actually going to parse the JSON document + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we should see this because we send a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we should see this because we receive a body + found |= 1 << 4 + continue + } + } + if found != (1<<0 | 1<<1 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 401 { + t.Fatal("unexpected err", err) + } + }) + + t.Run("body logging NOT enabled and 200 Ok", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, false, logger, input, &output) // no logging + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // should not see it: body logging is disabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // should not see it: body logging is disabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // this one should be logged ANYWAY because it's orthogonal to the + // body logging so we should see it also in this case. + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // should see this because we send a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // should see this because we're receiving a body + found |= 1 << 4 + continue + } + } + if found != (1<<2 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + if err != nil { + t.Fatal(err) + } + }) + + t.Run("body logging NOT enabled and 401 Unauthorized", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, false, logger, input, &output) // no logging + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // should not see it: body logging is disabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // should not see it: body logging is disabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // should not see it because we don't parse the body on 401 errors + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we send a body so we should see it + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we receive a body so we should see it + found |= 1 << 4 + continue + } + } + if found != (1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 401 { + t.Fatal("unexpected err", err) + } + }) +} + +func Test_errMaybeCensorship_Unwrap(t *testing.T) { + t.Run("for errors.Is", func(t *testing.T) { + var err error = &errMaybeCensorship{io.EOF} + if !errors.Is(err, io.EOF) { + t.Fatal("cannot unwrap") + } + }) + + t.Run("for errors.As", func(t *testing.T) { + var err error = &errMaybeCensorship{netxlite.ECONNRESET} + var syserr syscall.Errno + if !errors.As(err, &syserr) || syserr != netxlite.ECONNRESET { + t.Fatal("cannot unwrap") + } + }) +} diff --git a/internal/httpapi/descriptor.go b/internal/httpapi/descriptor.go new file mode 100644 index 0000000..ed35e85 --- /dev/null +++ b/internal/httpapi/descriptor.go @@ -0,0 +1,155 @@ +package httpapi + +// +// HTTP API descriptor (e.g., GET /api/v1/test-list/urls) +// + +import ( + "encoding/json" + "net/http" + "net/url" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// Descriptor contains the parameters for calling a given HTTP +// API (e.g., GET /api/v1/test-list/urls). +// +// The zero value of this struct is invalid. Please, fill all the +// fields marked as MANDATORY for correct initialization. +type Descriptor struct { + // Accept contains the OPTIONAL accept header. + Accept string + + // Authorization is the OPTIONAL authorization. + Authorization string + + // ContentType is the OPTIONAL content-type header. + ContentType string + + // LogBody OPTIONALLY enables logging bodies. + LogBody bool + + // Logger is the MANDATORY logger to use. + // + // For example, model.DiscardLogger. + Logger model.Logger + + // MaxBodySize is the OPTIONAL maximum response body size. If + // not set, we use the |DefaultMaxBodySize| constant. + MaxBodySize int64 + + // Method is the MANDATORY request method. + Method string + + // RequestBody is the OPTIONAL request body. + RequestBody []byte + + // Timeout is the OPTIONAL timeout for this call. If no timeout + // is specified we will use the |DefaultCallTimeout| const. + Timeout time.Duration + + // URLPath is the MANDATORY URL path. + URLPath string + + // URLQuery is the OPTIONAL query. + URLQuery url.Values +} + +// WithBodyLogging returns a SHALLOW COPY of |Descriptor| with LogBody set to |value|. You SHOULD +// only use this method when initializing the descriptor you want to use. +func (desc *Descriptor) WithBodyLogging(value bool) *Descriptor { + out := &Descriptor{} + *out = *desc + out.LogBody = value + return out +} + +// DefaultMaxBodySize is the default value for the maximum +// body size you can fetch using the httpapi package. +const DefaultMaxBodySize = 1 << 22 + +// DefaultCallTimeout is the default timeout for an httpapi call. +const DefaultCallTimeout = 60 * time.Second + +// NewGETJSONDescriptor is a convenience factory for creating a new descriptor +// that uses the GET method and expects a JSON response. +func NewGETJSONDescriptor(logger model.Logger, urlPath string) *Descriptor { + return NewGETJSONWithQueryDescriptor(logger, urlPath, url.Values{}) +} + +// applicationJSON is the content-type for JSON +const applicationJSON = "application/json" + +// NewGETJSONWithQueryDescriptor is like NewGETJSONDescriptor but it also +// allows you to provide |query| arguments. Leaving |query| nil or empty +// is equivalent to calling NewGETJSONDescriptor directly. +func NewGETJSONWithQueryDescriptor(logger model.Logger, urlPath string, query url.Values) *Descriptor { + return &Descriptor{ + Accept: applicationJSON, + Authorization: "", + ContentType: "", + LogBody: false, + Logger: logger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: urlPath, + URLQuery: query, + } +} + +// NewPOSTJSONWithJSONResponseDescriptor creates a descriptor that POSTs a JSON document +// and expects to receive back a JSON document from the API. +// +// This function ONLY fails if we cannot serialize the |request| to JSON. So, if you know +// that |request| is JSON-serializable, you can safely call MustNewPostJSONWithJSONResponseDescriptor instead. +func NewPOSTJSONWithJSONResponseDescriptor(logger model.Logger, urlPath string, request any) (*Descriptor, error) { + rawRequest, err := json.Marshal(request) + if err != nil { + return nil, err + } + desc := &Descriptor{ + Accept: applicationJSON, + Authorization: "", + ContentType: applicationJSON, + LogBody: false, + Logger: logger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodPost, + RequestBody: rawRequest, + Timeout: DefaultCallTimeout, + URLPath: urlPath, + URLQuery: nil, + } + return desc, nil +} + +// MustNewPOSTJSONWithJSONResponseDescriptor is like NewPOSTJSONWithJSONResponseDescriptor except that +// it panics in case it's not possible to JSON serialize the |request|. +func MustNewPOSTJSONWithJSONResponseDescriptor(logger model.Logger, urlPath string, request any) *Descriptor { + desc, err := NewPOSTJSONWithJSONResponseDescriptor(logger, urlPath, request) + runtimex.PanicOnError(err, "NewPOSTJSONWithJSONResponseDescriptor failed") + return desc +} + +// NewGETResourceDescriptor creates a generic descriptor for GETting a +// resource of unspecified type using the given |urlPath|. +func NewGETResourceDescriptor(logger model.Logger, urlPath string) *Descriptor { + return &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: logger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: urlPath, + URLQuery: url.Values{}, + } +} diff --git a/internal/httpapi/descriptor_test.go b/internal/httpapi/descriptor_test.go new file mode 100644 index 0000000..c3c58c0 --- /dev/null +++ b/internal/httpapi/descriptor_test.go @@ -0,0 +1,248 @@ +package httpapi + +import ( + "log" + "net/http" + "net/url" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func TestDescriptor_WithBodyLogging(t *testing.T) { + type fields struct { + Accept string + Authorization string + ContentType string + LogBody bool + Logger model.Logger + MaxBodySize int64 + Method string + RequestBody []byte + Timeout time.Duration + URLPath string + URLQuery url.Values + } + tests := []struct { + name string + fields fields + want *Descriptor + }{{ + name: "with empty fields", + fields: fields{}, // LogBody defaults to false + want: &Descriptor{ + LogBody: true, + }, + }, { + name: "with nonempty fields", + fields: fields{ + Accept: "xx", + Authorization: "y", + ContentType: "zzz", + LogBody: false, // obviously must be false + Logger: model.DiscardLogger, + MaxBodySize: 123, + Method: "POST", + RequestBody: []byte("123"), + Timeout: 15555, + URLPath: "/", + URLQuery: map[string][]string{ + "a": {"b"}, + }, + }, + want: &Descriptor{ + Accept: "xx", + Authorization: "y", + ContentType: "zzz", + LogBody: true, + Logger: model.DiscardLogger, + MaxBodySize: 123, + Method: "POST", + RequestBody: []byte("123"), + Timeout: 15555, + URLPath: "/", + URLQuery: map[string][]string{ + "a": {"b"}, + }, + }, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + desc := &Descriptor{ + Accept: tt.fields.Accept, + Authorization: tt.fields.Authorization, + ContentType: tt.fields.ContentType, + LogBody: tt.fields.LogBody, + Logger: tt.fields.Logger, + MaxBodySize: tt.fields.MaxBodySize, + Method: tt.fields.Method, + RequestBody: tt.fields.RequestBody, + Timeout: tt.fields.Timeout, + URLPath: tt.fields.URLPath, + URLQuery: tt.fields.URLQuery, + } + got := desc.WithBodyLogging(true) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func TestNewGetJSONDescriptor(t *testing.T) { + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: url.Values{}, + } + got := NewGETJSONDescriptor(model.DiscardLogger, "/robots.txt") + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } +} + +func TestNewGetJSONWithQueryDescriptor(t *testing.T) { + query := url.Values{ + "a": {"b"}, + "c": {"d"}, + } + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: query, + } + got := NewGETJSONWithQueryDescriptor(model.DiscardLogger, "/robots.txt", query) + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } +} + +func TestNewPOSTJSONWithJSONResponseDescriptor(t *testing.T) { + type request struct { + Name string + Age int64 + } + + t.Run("with failure", func(t *testing.T) { + request := make(chan int64) + got, err := NewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + if err == nil || err.Error() != "json: unsupported type: chan int64" { + log.Fatal("unexpected err", err) + } + if got != nil { + log.Fatal("expected to get a nil Descriptor") + } + }) + + t.Run("with success", func(t *testing.T) { + request := request{ + Name: "sbs", + Age: 99, + } + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "application/json", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodPost, + RequestBody: []byte(`{"Name":"sbs","Age":99}`), + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: nil, + } + got, err := NewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + if err != nil { + log.Fatal(err) + } + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } + }) +} + +func TestMustNewPOSTJSONWithJSONResponseDescriptor(t *testing.T) { + type request struct { + Name string + Age int64 + } + + t.Run("with failure", func(t *testing.T) { + var panicked bool + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + request := make(chan int64) + _ = MustNewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + }() + if !panicked { + t.Fatal("did not panic") + } + }) + + t.Run("with success", func(t *testing.T) { + request := request{ + Name: "sbs", + Age: 99, + } + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "application/json", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodPost, + RequestBody: []byte(`{"Name":"sbs","Age":99}`), + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: nil, + } + got := MustNewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } + }) +} + +func TestNewGetResourceDescriptor(t *testing.T) { + expected := &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: url.Values{}, + } + got := NewGETResourceDescriptor(model.DiscardLogger, "/robots.txt") + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } +} diff --git a/internal/httpapi/doc.go b/internal/httpapi/doc.go new file mode 100644 index 0000000..0c1361c --- /dev/null +++ b/internal/httpapi/doc.go @@ -0,0 +1,15 @@ +// Package httpapi contains code for calling HTTP APIs. +// +// We model HTTP APIs as follows: +// +// 1. |Endpoint| is an API endpoint (e.g., https://api.ooni.io); +// +// 2. |Descriptor| describes the specific API you want to use (e.g., +// GET /api/v1/test-list/urls with JSON response body). +// +// Generally, you use |Call| to call the API identified by a |Descriptor| +// on the specified |Endpoint|. However, there are cases where you +// need more complex calling patterns. For example, with |SequenceCaller| +// you can invoke the same API |Descriptor| with multiple equivalent +// API |Endpoint|s until one of them succeeds or all fail. +package httpapi diff --git a/internal/httpapi/endpoint.go b/internal/httpapi/endpoint.go new file mode 100644 index 0000000..acfc4d8 --- /dev/null +++ b/internal/httpapi/endpoint.go @@ -0,0 +1,76 @@ +package httpapi + +// +// HTTP API Endpoint (e.g., https://api.ooni.io) +// + +import "github.com/ooni/probe-cli/v3/internal/model" + +// Endpoint models an HTTP endpoint on which you can call +// several HTTP APIs (e.g., https://api.ooni.io) using a +// given HTTP client potentially using a circumvention tunnel +// mechanism such as psiphon or torsf. +// +// The zero value of this struct is invalid. Please, fill all the +// fields marked as MANDATORY for correct initialization. +type Endpoint struct { + // BaseURL is the MANDATORY endpoint base URL. We will honour the + // path of this URL and prepend it to the actual path specified inside + // a |Descriptor.URLPath|. However, we will always discard any query + // that may have been set inside the BaseURL. The only query string + // will be composed from the |Descriptor.URLQuery| values. + // + // For example, https://api.ooni.io. + BaseURL string + + // HTTPClient is the MANDATORY HTTP client to use. + // + // For example, http.DefaultClient. You can introduce circumvention + // here by using an HTTPClient bound to a specific tunnel. + HTTPClient model.HTTPClient + + // Host is the OPTIONAL host header to use. + // + // If this field is empty we use the BaseURL's hostname. A specific + // host header may be needed when using cloudfronting. + Host string + + // User-Agent is the OPTIONAL user-agent to use. If empty, + // we'll use the stdlib's default user-agent string. + UserAgent string +} + +// NewEndpointList constructs a list of API endpoints from |services| +// returned by the OONI backend (or known in advance). +// +// Arguments: +// +// - httpClient is the HTTP client to use for accessing the endpoints; +// +// - userAgent is the user agent you would like to use; +// +// - service is the list of services gathered from the backend. +func NewEndpointList(httpClient model.HTTPClient, + userAgent string, services ...model.OOAPIService) (out []*Endpoint) { + for _, svc := range services { + switch svc.Type { + case "https": + out = append(out, &Endpoint{ + BaseURL: svc.Address, + HTTPClient: httpClient, + Host: "", + UserAgent: userAgent, + }) + case "cloudfront": + out = append(out, &Endpoint{ + BaseURL: svc.Address, + HTTPClient: httpClient, + Host: svc.Front, + UserAgent: userAgent, + }) + default: + // nothing! + } + } + return +} diff --git a/internal/httpapi/endpoint_test.go b/internal/httpapi/endpoint_test.go new file mode 100644 index 0000000..7077e14 --- /dev/null +++ b/internal/httpapi/endpoint_test.go @@ -0,0 +1,69 @@ +package httpapi + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestNewEndpointList(t *testing.T) { + type args struct { + httpClient model.HTTPClient + userAgent string + services []model.OOAPIService + } + defaultHTTPClient := &mocks.HTTPClient{} + tests := []struct { + name string + args args + wantOut []*Endpoint + }{{ + name: "with no services", + args: args{ + httpClient: defaultHTTPClient, + userAgent: model.HTTPHeaderUserAgent, + services: nil, + }, + wantOut: nil, + }, { + name: "common cases", + args: args{ + httpClient: defaultHTTPClient, + userAgent: model.HTTPHeaderUserAgent, + services: []model.OOAPIService{{ + Address: "https://www.example.com/", + Type: "https", + Front: "", + }, { + Address: "https://www.example.org/", + Type: "cloudfront", + Front: "example.org.it", + }, { + Address: "https://nonexistent.onion/", + Type: "onion", + Front: "", + }}, + }, + wantOut: []*Endpoint{{ + BaseURL: "https://www.example.com/", + HTTPClient: defaultHTTPClient, + Host: "", + UserAgent: model.HTTPHeaderUserAgent, + }, { + BaseURL: "https://www.example.org/", + HTTPClient: defaultHTTPClient, + Host: "example.org.it", + UserAgent: model.HTTPHeaderUserAgent, + }}, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOut := NewEndpointList(tt.args.httpClient, tt.args.userAgent, tt.args.services...) + if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" { + t.Fatal(diff) + } + }) + } +} diff --git a/internal/httpapi/sequence.go b/internal/httpapi/sequence.go new file mode 100644 index 0000000..da1f11d --- /dev/null +++ b/internal/httpapi/sequence.go @@ -0,0 +1,92 @@ +package httpapi + +// +// Sequentially call available API endpoints until one succeed +// or all of them fail. A future implementation of this code may +// (probably should?) take into account knowledge of what is +// working and what is not working to optimize the order with +// which to try different alternatives. +// + +import ( + "context" + "errors" + + "github.com/ooni/probe-cli/v3/internal/multierror" +) + +// SequenceCaller calls the API specified by |Descriptor| once for each of +// the available |Endpoints| until one of them succeeds. +// +// CAVEAT: this code will ONLY retry API calls with subsequent endpoints when +// the error originates in the HTTP round trip or while reading the body. +type SequenceCaller struct { + // Descriptor is the API |Descriptor|. + Descriptor *Descriptor + + // Endpoints is the list of |Endpoint| to use. + Endpoints []*Endpoint +} + +// NewSequenceCaller is a factory for creating a |SequenceCaller|. +func NewSequenceCaller(desc *Descriptor, endpoints ...*Endpoint) *SequenceCaller { + return &SequenceCaller{ + Descriptor: desc, + Endpoints: endpoints, + } +} + +// ErrAllEndpointsFailed indicates that all endpoints failed. +var ErrAllEndpointsFailed = errors.New("httpapi: all endpoints failed") + +// shouldRetry returns true when we should try with another endpoint given the +// value of |err| which could (obviously) be nil in case of success. +func (sc *SequenceCaller) shouldRetry(err error) bool { + var kind *errMaybeCensorship + belongs := errors.As(err, &kind) + return belongs +} + +// Call calls |Call| for each |Endpoint| and |Descriptor| until one endpoint succeeds. The +// return value is the response body and the selected endpoint index or the error. +// +// CAVEAT: this code will ONLY retry API calls with subsequent endpoints when +// the error originates in the HTTP round trip or while reading the body. +func (sc *SequenceCaller) Call(ctx context.Context) ([]byte, int, error) { + var selected int + merr := multierror.New(ErrAllEndpointsFailed) + for _, epnt := range sc.Endpoints { + respBody, err := Call(ctx, sc.Descriptor, epnt) + if sc.shouldRetry(err) { + merr.Add(err) + selected++ + continue + } + // Note: some errors will lead us to return + // early as documented for this method + return respBody, selected, err + } + return nil, -1, merr +} + +// CallWithJSONResponse is like |SequenceCaller.Call| except that it invokes the +// underlying |CallWithJSONResponse| rather than invoking |Call|. +// +// CAVEAT: this code will ONLY retry API calls with subsequent endpoints when +// the error originates in the HTTP round trip or while reading the body. +func (sc *SequenceCaller) CallWithJSONResponse(ctx context.Context, response any) (int, error) { + var selected int + merr := multierror.New(ErrAllEndpointsFailed) + for _, epnt := range sc.Endpoints { + err := CallWithJSONResponse(ctx, sc.Descriptor, epnt, response) + if sc.shouldRetry(err) { + merr.Add(err) + selected++ + continue + } + // Note: some errors will lead us to return + // early as documented for this method + return selected, err + } + return -1, merr +} diff --git a/internal/httpapi/sequence_test.go b/internal/httpapi/sequence_test.go new file mode 100644 index 0000000..13cc50f --- /dev/null +++ b/internal/httpapi/sequence_test.go @@ -0,0 +1,358 @@ +package httpapi + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestSequenceCaller(t *testing.T) { + t.Run("Call", func(t *testing.T) { + t.Run("first success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("deadbeef")), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + if err != nil { + t.Fatal(err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff([]byte("deadbeef"), data); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("first HTTP failure and we immediately stop", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 403, // should cause us to return early + Body: io.NopCloser(strings.NewReader("deadbeef")), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 403 { + t.Fatal("unexpected err", err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if len(data) > 0 { + t.Fatal("expected to see no response body") + } + }) + + t.Run("first network failure, second success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to cycle to the second entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("abad1dea")), + } + return resp, nil + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + if err != nil { + t.Fatal(err) + } + if idx != 1 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff([]byte("abad1dea"), data); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("all network failure", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to cycle to the next entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to cycle to the next entry + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + if !errors.Is(err, ErrAllEndpointsFailed) { + t.Fatal("unexpected err", err) + } + if idx != -1 { + t.Fatal("invalid idx") + } + if len(data) > 0 { + t.Fatal("expected zero-length data") + } + }) + }) + + t.Run("CallWithJSONResponse", func(t *testing.T) { + type response struct { + Name string + Age int64 + } + + t.Run("first success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"Name":"sbs","Age":99}`)), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{}`)), // different + } + return resp, nil + }, + }, + }, + ) + expect := response{ + Name: "sbs", + Age: 99, + } + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + if err != nil { + t.Fatal(err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("first HTTP failure and we immediately stop", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 403, // should be enough to cause us fail immediately + Body: io.NopCloser(strings.NewReader(`{"Age": 155, "Name": "sbs"}`)), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + ) + // even though there is a JSON body we don't care about reading it + // and so we expect to see in output the zero-value struct + expect := response{ + Name: "", + Age: 0, + } + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 403 { + t.Fatal("unexpected err", err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("first network failure, second success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to try the next entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"Age":155}`)), + } + return resp, nil + }, + }, + }, + ) + expect := response{ + Name: "", + Age: 155, + } + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + if err != nil { + t.Fatal(err) + } + if idx != 1 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("all network failure", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to try the next entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to try the next entry + }, + }, + }, + ) + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + if !errors.Is(err, ErrAllEndpointsFailed) { + t.Fatal("unexpected err", err) + } + if idx != -1 { + t.Fatal("invalid idx") + } + }) + }) +} diff --git a/internal/httpx/httpx.go b/internal/httpx/httpx.go index f330745..8083f10 100644 --- a/internal/httpx/httpx.go +++ b/internal/httpx/httpx.go @@ -1,4 +1,8 @@ // Package httpx contains http extensions. +// +// Deprecated: new code should use httpapi instead. While this package and httpapi +// are basically using the same implementation, the API exposed by httpapi allows +// us to try the same request with multiple HTTP endpoints. package httpx import (