diff --git a/internal/engine/experiment/webconnectivity/control.go b/internal/engine/experiment/webconnectivity/control.go index ccaf870..bd953a1 100644 --- a/internal/engine/experiment/webconnectivity/control.go +++ b/internal/engine/experiment/webconnectivity/control.go @@ -4,9 +4,10 @@ import ( "context" "github.com/ooni/probe-cli/v3/internal/geoipx" - "github.com/ooni/probe-cli/v3/internal/httpx" + "github.com/ooni/probe-cli/v3/internal/httpapi" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) // Redirect to types defined inside the model package @@ -21,22 +22,23 @@ type ( // Control performs the control request and returns the response. func Control( ctx context.Context, sess model.ExperimentSession, - thAddr string, creq ControlRequest) (out ControlResponse, err error) { - clnt := &httpx.APIClientTemplate{ - BaseURL: thAddr, - HTTPClient: sess.DefaultHTTPClient(), - Logger: sess.Logger(), - UserAgent: sess.UserAgent(), - } + testhelpers []model.OOAPIService, creq ControlRequest) (ControlResponse, *model.OOAPIService, error) { + seqCaller := httpapi.NewSequenceCaller( + httpapi.MustNewPOSTJSONWithJSONResponseDescriptor(sess.Logger(), "/", creq).WithBodyLogging(true), + httpapi.NewEndpointList(sess.DefaultHTTPClient(), sess.UserAgent(), testhelpers...)..., + ) sess.Logger().Infof("control for %s...", creq.HTTPRequest) - // make sure error is wrapped - err = clnt.WithBodyLogging().Build().PostJSON(ctx, "/", creq, &out) - if err != nil { - err = netxlite.NewTopLevelGenericErrWrapper(err) - } + var out ControlResponse + idx, err := seqCaller.CallWithJSONResponse(ctx, &out) sess.Logger().Infof("control for %s... %+v", creq.HTTPRequest, model.ErrorToStringOrOK(err)) + if err != nil { + // make sure error is wrapped + err = netxlite.NewTopLevelGenericErrWrapper(err) + return ControlResponse{}, nil, err + } fillASNs(&out.DNS) - return + runtimex.Assert(idx >= 0 && idx < len(testhelpers), "idx out of bounds") + return out, &testhelpers[idx], nil } // fillASNs fills the ASNs array of ControlDNSResult. For each Addr inside diff --git a/internal/engine/experiment/webconnectivity/webconnectivity.go b/internal/engine/experiment/webconnectivity/webconnectivity.go index 526b9f7..bdbea35 100644 --- a/internal/engine/experiment/webconnectivity/webconnectivity.go +++ b/internal/engine/experiment/webconnectivity/webconnectivity.go @@ -15,7 +15,7 @@ import ( const ( testName = "web_connectivity" - testVersion = "0.4.1" + testVersion = "0.4.2" ) // Config contains the experiment config. @@ -145,19 +145,9 @@ func (m Measurer) Run( } // 1. find test helper testhelpers, _ := sess.GetTestHelpersByName("web-connectivity") - var testhelper *model.OOAPIService - for _, th := range testhelpers { - if th.Type == "https" { - testhelper = &th - break - } - } - if testhelper == nil { + if len(testhelpers) < 1 { return ErrNoAvailableTestHelpers } - measurement.TestHelpers = map[string]interface{}{ - "backend": testhelper, - } // 2. perform the DNS lookup step dnsBegin := time.Now() dnsResult := DNSLookup(ctx, DNSLookupConfig{ @@ -167,10 +157,11 @@ func (m Measurer) Run( tk.Queries = append(tk.Queries, dnsResult.TestKeys.Queries...) tk.DNSExperimentFailure = dnsResult.Failure epnts := NewEndpoints(URL, dnsResult.Addresses()) - sess.Logger().Infof("using control: %s", testhelper.Address) + sess.Logger().Infof("using control: %+v", testhelpers) // 3. perform the control measurement thBegin := time.Now() - tk.Control, err = Control(ctx, sess, testhelper.Address, ControlRequest{ + var usedTH *model.OOAPIService + tk.Control, usedTH, err = Control(ctx, sess, testhelpers, ControlRequest{ HTTPRequest: URL.String(), HTTPRequestHeaders: map[string][]string{ "Accept": {model.HTTPHeaderAccept}, @@ -179,6 +170,11 @@ func (m Measurer) Run( }, TCPConnect: epnts.Endpoints(), }) + if usedTH != nil { + measurement.TestHelpers = map[string]interface{}{ + "backend": usedTH, + } + } tk.THRuntime = time.Since(thBegin) tk.ControlFailure = tracex.NewFailure(err) // 4. analyze DNS results diff --git a/internal/engine/experiment/webconnectivity/webconnectivity_test.go b/internal/engine/experiment/webconnectivity/webconnectivity_test.go index c1fcb37..6db3989 100644 --- a/internal/engine/experiment/webconnectivity/webconnectivity_test.go +++ b/internal/engine/experiment/webconnectivity/webconnectivity_test.go @@ -21,7 +21,7 @@ func TestNewExperimentMeasurer(t *testing.T) { if measurer.ExperimentName() != "web_connectivity" { t.Fatal("unexpected name") } - if measurer.ExperimentVersion() != "0.4.1" { + if measurer.ExperimentVersion() != "0.4.2" { t.Fatal("unexpected version") } } diff --git a/internal/experiment/webconnectivity/cleartextflow.go b/internal/experiment/webconnectivity/cleartextflow.go index 351a27c..d515d4e 100644 --- a/internal/experiment/webconnectivity/cleartextflow.go +++ b/internal/experiment/webconnectivity/cleartextflow.go @@ -285,7 +285,7 @@ func (t *CleartextFlow) maybeFollowRedirects(ctx context.Context, resp *http.Res WaitGroup: t.WaitGroup, Referer: resp.Request.URL.String(), Session: nil, // no need to issue another control request - THAddr: "", // ditto + TestHelpers: nil, // ditto UDPAddress: t.UDPAddress, } resolvers.Start(ctx) diff --git a/internal/experiment/webconnectivity/control.go b/internal/experiment/webconnectivity/control.go index 3e58f42..bf7003a 100644 --- a/internal/experiment/webconnectivity/control.go +++ b/internal/experiment/webconnectivity/control.go @@ -8,10 +8,11 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/engine/experiment/webconnectivity" - "github.com/ooni/probe-cli/v3/internal/httpx" + "github.com/ooni/probe-cli/v3/internal/httpapi" "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) // EndpointMeasurementsStarter is used by Control to start extra @@ -51,8 +52,8 @@ type Control struct { // Session is the MANDATORY session to use. Session model.ExperimentSession - // THAddr is the MANDATORY TH's URL. - THAddr string + // TestHelpers is the MANDATORY list of test helpers. + TestHelpers []model.OOAPIService // URL is the MANDATORY URL we are measuring. URL *url.URL @@ -102,26 +103,20 @@ func (c *Control) Run(parentCtx context.Context) { // create logger for this operation ol := measurexlite.NewOperationLogger( c.Logger, - "control for %s using %s", + "control for %s using %+v", creq.HTTPRequest, - c.THAddr, + c.TestHelpers, ) - // create an API client - clnt := (&httpx.APIClientTemplate{ - Accept: "", - Authorization: "", - BaseURL: c.THAddr, - HTTPClient: c.Session.DefaultHTTPClient(), - Host: "", // use the one inside the URL - LogBody: true, - Logger: c.Logger, - UserAgent: c.Session.UserAgent(), - }).Build() + // create an httpapi sequence caller + seqCaller := httpapi.NewSequenceCaller( + httpapi.MustNewPOSTJSONWithJSONResponseDescriptor(c.Logger, "/", creq).WithBodyLogging(true), + httpapi.NewEndpointList(c.Session.DefaultHTTPClient(), c.Session.UserAgent(), c.TestHelpers...)..., + ) // issue the control request and wait for the response var cresp webconnectivity.ControlResponse - err := clnt.PostJSON(opCtx, "/", creq, &cresp) + idx, err := seqCaller.CallWithJSONResponse(opCtx, &cresp) if err != nil { // make sure error is wrapped err = netxlite.NewTopLevelGenericErrWrapper(err) @@ -134,6 +129,10 @@ func (c *Control) Run(parentCtx context.Context) { c.TestKeys.SetControl(&cresp) ol.Stop(nil) + // record the specific TH that worked + runtimex.Assert(idx >= 0 && idx < len(c.TestHelpers), "idx out of bounds") + c.TestKeys.setTestHelper(&c.TestHelpers[idx]) + // if the TH returned us addresses we did not previously were // aware of, make sure we also measure them c.maybeStartExtraMeasurements(parentCtx, cresp.DNS.Addrs) diff --git a/internal/experiment/webconnectivity/dnsresolvers.go b/internal/experiment/webconnectivity/dnsresolvers.go index 0c9df8d..8709635 100644 --- a/internal/experiment/webconnectivity/dnsresolvers.go +++ b/internal/experiment/webconnectivity/dnsresolvers.go @@ -67,8 +67,9 @@ type DNSResolvers struct { // always follow the redirect chain caused by the provided URL. Session model.ExperimentSession - // THAddr is the OPTIONAL test helper address. - THAddr string + // TestHelpers is the OPTIONAL list of test helpers. If the list is + // empty, we are not going to try to contact any test helper. + TestHelpers []model.OOAPIService // UDPAddress is the OPTIONAL address of the UDP resolver to use. If this // field is not set we use a default one (e.g., `8.8.8.8:53`). @@ -498,15 +499,15 @@ func (t *DNSResolvers) startSecureFlows( } } -// maybeStartControlFlow starts the control flow iff .Session and .THAddr are set. +// maybeStartControlFlow starts the control flow iff .Session and .TestHelpers are set. func (t *DNSResolvers) maybeStartControlFlow( ctx context.Context, ps *prioritySelector, addresses []DNSEntry, ) { - // note: for subsequent requests we don't set .Session and .THAddr hence + // note: for subsequent requests we don't set .Session and .TestHelpers hence // we are not going to query the test helper more than once - if t.Session != nil && t.THAddr != "" { + if t.Session != nil && len(t.TestHelpers) > 0 { var addrs []string for _, addr := range addresses { addrs = append(addrs, addr.Addr) @@ -518,7 +519,7 @@ func (t *DNSResolvers) maybeStartControlFlow( PrioSelector: ps, TestKeys: t.TestKeys, Session: t.Session, - THAddr: t.THAddr, + TestHelpers: t.TestHelpers, URL: t.URL, WaitGroup: t.WaitGroup, } diff --git a/internal/experiment/webconnectivity/measurer.go b/internal/experiment/webconnectivity/measurer.go index c6731e4..5a0b977 100644 --- a/internal/experiment/webconnectivity/measurer.go +++ b/internal/experiment/webconnectivity/measurer.go @@ -36,7 +36,7 @@ func (m *Measurer) ExperimentName() string { // ExperimentVersion implements model.ExperimentMeasurer. func (m *Measurer) ExperimentVersion() string { - return "0.5.18" + return "0.5.19" } // Run implements model.ExperimentMeasurer. @@ -89,17 +89,7 @@ func (m *Measurer) Run(ctx context.Context, sess model.ExperimentSession, // obtain the test helper's address testhelpers, _ := sess.GetTestHelpersByName("web-connectivity") - var thAddr string - for _, th := range testhelpers { - if th.Type == "https" { - thAddr = th.Address - measurement.TestHelpers = map[string]any{ - "backend": &th, - } - break - } - } - if thAddr == "" { + if len(testhelpers) < 1 { sess.Logger().Warnf("continuing without a valid TH address") tk.SetControlFailure(webconnectivity.ErrNoAvailableTestHelpers) } @@ -120,7 +110,7 @@ func (m *Measurer) Run(ctx context.Context, sess model.ExperimentSession, CookieJar: jar, Referer: "", Session: sess, - THAddr: thAddr, + TestHelpers: testhelpers, UDPAddress: "", } resos.Start(ctx) @@ -137,6 +127,16 @@ func (m *Measurer) Run(ctx context.Context, sess model.ExperimentSession, // perform any deferred computation on the test keys tk.Finalize(sess.Logger()) + // set the test helper we used + // TODO(bassosimone): it may be more informative to know about all the + // test helpers we _tried_ to use, however the data format does not have + // support for that as far as I can tell... + if th := tk.getTestHelper(); th != nil { + measurement.TestHelpers = map[string]interface{}{ + "backend": th, + } + } + // return whether there was a fundamental failure, which would prevent // the measurement from being submitted to the OONI collector. return tk.fundamentalFailure diff --git a/internal/experiment/webconnectivity/secureflow.go b/internal/experiment/webconnectivity/secureflow.go index 68f41e1..6a47b73 100644 --- a/internal/experiment/webconnectivity/secureflow.go +++ b/internal/experiment/webconnectivity/secureflow.go @@ -337,7 +337,7 @@ func (t *SecureFlow) maybeFollowRedirects(ctx context.Context, resp *http.Respon WaitGroup: t.WaitGroup, Referer: resp.Request.URL.String(), Session: nil, // no need to issue another control request - THAddr: "", // ditto + TestHelpers: nil, // ditto UDPAddress: t.UDPAddress, } resolvers.Start(ctx) diff --git a/internal/experiment/webconnectivity/testkeys.go b/internal/experiment/webconnectivity/testkeys.go index 37dbc94..5cd917f 100644 --- a/internal/experiment/webconnectivity/testkeys.go +++ b/internal/experiment/webconnectivity/testkeys.go @@ -134,6 +134,10 @@ type TestKeys struct { // mu provides mutual exclusion for accessing the test keys. mu *sync.Mutex + + // testHelper is used to communicate the TH that worked to the main + // goroutine such that we can fill measurement.TestHelpers. + testHelper *model.OOAPIService } // ConnPriorityLogEntry is an entry in the TestKeys.ConnPriorityLog slice. @@ -302,6 +306,21 @@ func (tk *TestKeys) AppendConnPriorityLogEntry(entry *ConnPriorityLogEntry) { tk.mu.Unlock() } +// setTestHelper sets .testHelper in a thread safe way +func (tk *TestKeys) setTestHelper(th *model.OOAPIService) { + tk.mu.Lock() + tk.testHelper = th + tk.mu.Unlock() +} + +// getTestHelper gets .testHelper in a thread safe way +func (tk *TestKeys) getTestHelper() (th *model.OOAPIService) { + tk.mu.Lock() + th = tk.testHelper + tk.mu.Unlock() + return +} + // NewTestKeys creates a new instance of TestKeys. func NewTestKeys() *TestKeys { return &TestKeys{ @@ -348,6 +367,7 @@ func NewTestKeys() *TestKeys { ControlRequest: nil, fundamentalFailure: nil, mu: &sync.Mutex{}, + testHelper: nil, } } diff --git a/internal/httpapi/call.go b/internal/httpapi/call.go new file mode 100644 index 0000000..1228080 --- /dev/null +++ b/internal/httpapi/call.go @@ -0,0 +1,181 @@ +package httpapi + +// +// Calling HTTP APIs. +// + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +// joinURLPath appends |resourcePath| to |urlPath|. +func joinURLPath(urlPath, resourcePath string) string { + if resourcePath == "" { + if urlPath == "" { + return "/" + } + return urlPath + } + if !strings.HasSuffix(urlPath, "/") { + urlPath += "/" + } + resourcePath = strings.TrimPrefix(resourcePath, "/") + return urlPath + resourcePath +} + +// newRequest creates a new http.Request from the given |ctx|, |endpoint|, and |desc|. +func newRequest(ctx context.Context, endpoint *Endpoint, desc *Descriptor) (*http.Request, error) { + URL, err := url.Parse(endpoint.BaseURL) + if err != nil { + return nil, err + } + // BaseURL and resource URL are joined if they have a path + URL.Path = joinURLPath(URL.Path, desc.URLPath) + if len(desc.URLQuery) > 0 { + URL.RawQuery = desc.URLQuery.Encode() + } else { + URL.RawQuery = "" // as documented we only honour desc.URLQuery + } + var reqBody io.Reader + if len(desc.RequestBody) > 0 { + reqBody = bytes.NewReader(desc.RequestBody) + desc.Logger.Debugf("httpapi: request body length: %d", len(desc.RequestBody)) + if desc.LogBody { + desc.Logger.Debugf("httpapi: request body: %s", string(desc.RequestBody)) + } + } + request, err := http.NewRequestWithContext(ctx, desc.Method, URL.String(), reqBody) + if err != nil { + return nil, err + } + request.Host = endpoint.Host // allow cloudfronting + if desc.Authorization != "" { + request.Header.Set("Authorization", desc.Authorization) + } + if desc.ContentType != "" { + request.Header.Set("Content-Type", desc.ContentType) + } + if desc.Accept != "" { + request.Header.Set("Accept", desc.Accept) + } + if endpoint.UserAgent != "" { + request.Header.Set("User-Agent", endpoint.UserAgent) + } + return request, nil +} + +// ErrHTTPRequestFailed indicates that the server returned >= 400. +type ErrHTTPRequestFailed struct { + // StatusCode is the status code that failed. + StatusCode int +} + +// Error implements error. +func (err *ErrHTTPRequestFailed) Error() string { + return fmt.Sprintf("httpapi: http request failed: %d", err.StatusCode) +} + +// errMaybeCensorship indicates that there was an error at the networking layer +// including, e.g., DNS, TCP connect, TLS. When we see this kind of error, we +// will consider retrying with another endpoint under the assumption that it +// may be that the current endpoint is censored. +type errMaybeCensorship struct { + // Err is the underlying error + Err error +} + +// Error implements error +func (err *errMaybeCensorship) Error() string { + return err.Err.Error() +} + +// Unwrap allows to get the underlying error +func (err *errMaybeCensorship) Unwrap() error { + return err.Err +} + +// docall calls the API represented by the given request |req| on the given |endpoint| +// and returns the response and its body or an error. +func docall(endpoint *Endpoint, desc *Descriptor, request *http.Request) (*http.Response, []byte, error) { + // Implementation note: remember to mark errors for which you want + // to retry with another endpoint using errMaybeCensorship. + response, err := endpoint.HTTPClient.Do(request) + if err != nil { + return nil, nil, &errMaybeCensorship{err} + } + defer response.Body.Close() + // Implementation note: always read and log the response body since + // it's quite useful to see the response JSON on API error. + r := io.LimitReader(response.Body, DefaultMaxBodySize) + data, err := netxlite.ReadAllContext(request.Context(), r) + if err != nil { + return response, nil, &errMaybeCensorship{err} + } + desc.Logger.Debugf("httpapi: response body length: %d bytes", len(data)) + if desc.LogBody { + desc.Logger.Debugf("httpapi: response body: %s", string(data)) + } + if response.StatusCode >= 400 { + return response, nil, &ErrHTTPRequestFailed{response.StatusCode} + } + return response, data, nil +} + +// call is like Call but also returns the response. +func call(ctx context.Context, desc *Descriptor, endpoint *Endpoint) (*http.Response, []byte, error) { + timeout := desc.Timeout + if timeout <= 0 { + timeout = DefaultCallTimeout // as documented + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + request, err := newRequest(ctx, endpoint, desc) + if err != nil { + return nil, nil, err + } + return docall(endpoint, desc, request) +} + +// Call invokes the API described by |desc| on the given HTTP |endpoint| and +// returns the response body (as a slice of bytes) or an error. +// +// Note: this function returns ErrHTTPRequestFailed if the HTTP status code is +// greater or equal than 400. You could use errors.As to obtain a copy of the +// error that was returned and see for yourself the actual status code. +func Call(ctx context.Context, desc *Descriptor, endpoint *Endpoint) ([]byte, error) { + _, rawResponseBody, err := call(ctx, desc, endpoint) + return rawResponseBody, err +} + +// goodContentTypeForJSON tracks known-good content-types for JSON. If the content-type +// is not in this map, |CallWithJSONResponse| emits a warning message. +var goodContentTypeForJSON = map[string]bool{ + applicationJSON: true, +} + +// CallWithJSONResponse is like Call but also assumes that the response is a +// JSON body and attempts to parse it into the |response| field. +// +// Note: this function returns ErrHTTPRequestFailed if the HTTP status code is +// greater or equal than 400. You could use errors.As to obtain a copy of the +// error that was returned and see for yourself the actual status code. +func CallWithJSONResponse(ctx context.Context, desc *Descriptor, endpoint *Endpoint, response any) error { + httpResp, rawRespBody, err := call(ctx, desc, endpoint) + if err != nil { + return err + } + if ctype := httpResp.Header.Get("Content-Type"); !goodContentTypeForJSON[ctype] { + desc.Logger.Warnf("httpapi: unexpected content-type: %s", ctype) + // fallthrough + } + return json.Unmarshal(rawRespBody, response) +} diff --git a/internal/httpapi/call_test.go b/internal/httpapi/call_test.go new file mode 100644 index 0000000..8fd45d7 --- /dev/null +++ b/internal/httpapi/call_test.go @@ -0,0 +1,1163 @@ +package httpapi + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "syscall" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +func Test_joinURLPath(t *testing.T) { + tests := []struct { + name string + urlPath string + resourcePath string + want string + }{{ + name: "whole path inside urlPath and empty resourcePath", + urlPath: "/robots.txt", + resourcePath: "", + want: "/robots.txt", + }, { + name: "empty urlPath and slash-prefixed resourcePath", + urlPath: "", + resourcePath: "/foo", + want: "/foo", + }, { + name: "slash urlPath and slash-prefixed resourcePath", + urlPath: "/", + resourcePath: "/foo", + want: "/foo", + }, { + name: "empty urlPath and empty resourcePath", + urlPath: "", + resourcePath: "", + want: "/", + }, { + name: "non-slash-terminated urlPath and slash-prefixed resourcePath", + urlPath: "/foo", + resourcePath: "/bar", + want: "/foo/bar", + }, { + name: "slash-terminated urlPath and slash-prefixed resourcePath", + urlPath: "/foo/", + resourcePath: "/bar", + want: "/foo/bar", + }, { + name: "slash-terminated urlPath and non-slash-prefixed resourcePath", + urlPath: "/foo", + resourcePath: "bar", + want: "/foo/bar", + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := joinURLPath(tt.urlPath, tt.resourcePath) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func Test_newRequest(t *testing.T) { + type args struct { + ctx context.Context + endpoint *Endpoint + desc *Descriptor + } + tests := []struct { + name string + args args + wantFn func(*testing.T, *http.Request) + wantErr error + }{{ + name: "url.Parse fails", + args: args{ + ctx: nil, + endpoint: &Endpoint{ + BaseURL: "\t\t\t", // does not parse! + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: "", + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: nil, + wantErr: errors.New(`parse "\t\t\t": net/url: invalid control character in URL`), + }, { + name: "http.NewRequestWithContext fails", + args: args{ + ctx: nil, // causes http.NewRequestWithContext to fail + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: "", + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: nil, + wantErr: errors.New("net/http: nil Context"), + }, { + name: "successful case with GET method, no body, and no extra headers", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + if req.Body != nil { + t.Fatal("invalid body", req.Body) + } + }, + wantErr: nil, + }, { + name: "successful case with POST method and body", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: 0, + Method: http.MethodPost, + RequestBody: []byte("deadbeef"), + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodPost { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + data, err := netxlite.ReadAllContext(context.Background(), req.Body) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff([]byte("deadbeef"), data); diff != "" { + t.Fatal(diff) + } + }, + wantErr: nil, + }, { + name: "with GET method and custom headers", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: nil, + Host: "antani.org", + UserAgent: "httpclient/1.0.1", + }, + desc: &Descriptor{ + Accept: "application/json", + Authorization: "deafbeef", + ContentType: "text/plain", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodPut, + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodPut { + t.Fatal("invalid method") + } + if req.Host != "antani.org" { + t.Fatal("invalid request host") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + if req.Header.Get("Authorization") != "deafbeef" { + t.Fatal("invalid authorization") + } + if req.Header.Get("Content-Type") != "text/plain" { + t.Fatal("invalid content-type") + } + if req.Header.Get("Accept") != "application/json" { + t.Fatal("invalid accept") + } + if req.Header.Get("User-Agent") != "httpclient/1.0.1" { + t.Fatal("invalid user-agent") + } + }, + wantErr: nil, + }, { + name: "we join the urlPath with the resourcePath", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/api/v1", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "/test-list/urls", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://www.example.com/api/v1/test-list/urls" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }, { + name: "we discard any query element inside the Endpoint.BaseURL", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.org/api/v1/?probe_cc=IT", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.org/api/v1/" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }, { + name: "we include query elements from Descriptor.URLQuery", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/api/v1/", + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: http.MethodGet, + RequestBody: nil, + Timeout: 0, + URLPath: "test-list/urls", + URLQuery: map[string][]string{ + "probe_cc": {"IT"}, + }, + }, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://www.example.com/api/v1/test-list/urls?probe_cc=IT" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }, { + name: "with as many implicitly-initialized fields as possible", + args: args{ + ctx: context.Background(), + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + }, + desc: &Descriptor{}, + }, + wantFn: func(t *testing.T, req *http.Request) { + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Method != http.MethodGet { + t.Fatal("invalid method") + } + if req.URL.String() != "https://example.com/" { + t.Fatal("invalid URL") + } + }, + wantErr: nil, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newRequest(tt.args.ctx, tt.args.endpoint, tt.args.desc) + switch { + case err == nil && tt.wantErr == nil: + // nothing + case err != nil && tt.wantErr == nil: + t.Fatalf("expected error but got %s", err.Error()) + case err == nil && tt.wantErr != nil: + t.Fatalf("expected %s but got ", tt.wantErr.Error()) + case err.Error() == tt.wantErr.Error(): + // nothing + default: + t.Fatalf("expected %s but got %s", err.Error(), tt.wantErr.Error()) + } + if tt.wantFn != nil { + tt.wantFn(t, got) + return + } + if got != nil { + t.Fatal("got response with nil tt.wantFn") + } + }) + } +} + +func TestCall(t *testing.T) { + type args struct { + ctx context.Context + desc *Descriptor + endpoint *Endpoint + } + tests := []struct { + name string + args args + want []byte + wantErr error + errfn func(t *testing.T, err error) + }{{ + name: "newRequest fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: nil, + MaxBodySize: 0, + Method: "", + RequestBody: nil, + Timeout: 0, + URLPath: "", + URLQuery: nil, + }, + endpoint: &Endpoint{ + BaseURL: "\t\t\t", // causes newRequest to fail + HTTPClient: nil, + Host: "", + UserAgent: "", + }, + }, + want: nil, + wantErr: errors.New(`parse "\t\t\t": net/url: invalid control character in URL`), + errfn: nil, + }, { + name: "endpoint.HTTPClient.Do fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + }, + want: nil, + wantErr: io.EOF, + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("unexpected error type") + } + }, + }, { + name: "reading body fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, netxlite.ECONNRESET + }, + }), + } + return resp, nil + }, + }, + }, + }, + want: nil, + wantErr: errors.New(netxlite.FailureConnectionReset), + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("unexpected error type") + } + }, + }, { + name: "status code indicates failure", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader("deadbeef")), + StatusCode: 403, + } + return resp, nil + }, + }, + }, + }, + want: nil, + wantErr: errors.New("httpapi: http request failed: 403"), + errfn: func(t *testing.T, err error) { + var expect *ErrHTTPRequestFailed + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "success with log body flag", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + LogBody: true, // as documented by this test's name + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader("deadbeef")), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + want: []byte("deadbeef"), + wantErr: nil, + errfn: nil, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Call(tt.args.ctx, tt.args.desc, tt.args.endpoint) + switch { + case err == nil && tt.wantErr == nil: + // nothing + case err != nil && tt.wantErr == nil: + t.Fatalf("expected error but got %s", err.Error()) + case err == nil && tt.wantErr != nil: + t.Fatalf("expected %s but got ", tt.wantErr.Error()) + case err.Error() == tt.wantErr.Error(): + // nothing + default: + t.Fatalf("expected %s but got %s", err.Error(), tt.wantErr.Error()) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func TestCallWithJSONResponse(t *testing.T) { + type response struct { + Name string + Age int64 + } + expectedResponse := response{ + Name: "sbs", + Age: 99, + } + type args struct { + ctx context.Context + desc *Descriptor + endpoint *Endpoint + } + tests := []struct { + name string + args args + wantErr error + errfn func(*testing.T, error) + }{{ + name: "call fails", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "\t\t\t\t", // causes failure + }, + }, + wantErr: errors.New(`parse "\t\t\t\t": net/url: invalid control character in URL`), + errfn: nil, + }, { + name: "with error during httpClient.Do", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + }, + wantErr: io.EOF, + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "with error when reading the response body", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, netxlite.ECONNRESET + }, + }), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: errors.New(netxlite.FailureConnectionReset), + errfn: func(t *testing.T, err error) { + var expect *errMaybeCensorship + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "with HTTP failure", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader(`{"Name": "sbs", "Age": 99}`)), + StatusCode: 400, + } + return resp, nil + }, + }, + }, + }, + wantErr: errors.New("httpapi: http request failed: 400"), + errfn: func(t *testing.T, err error) { + var expect *ErrHTTPRequestFailed + if !errors.As(err, &expect) { + t.Fatal("invalid error type") + } + }, + }, { + name: "with good response and missing header", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Body: io.NopCloser(strings.NewReader(`{"Name": "sbs", "Age": 99}`)), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: nil, + errfn: nil, + }, { + name: "with good response and good header", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + Logger: model.DiscardLogger, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/a", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Header: http.Header{ + "Content-Type": {"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{"Name": "sbs", "Age": 99}`)), + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: nil, + errfn: nil, + }, { + name: "response is not JSON", + args: args{ + ctx: context.Background(), + desc: &Descriptor{ + LogBody: false, + Logger: model.DiscardLogger, + Method: http.MethodGet, + }, + endpoint: &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + Header: http.Header{ + "Content-Type": {"application/json"}, + }, + Body: io.NopCloser(strings.NewReader(`{`)), // invalid JSON + StatusCode: 200, + } + return resp, nil + }, + }, + }, + }, + wantErr: errors.New("unexpected end of JSON input"), + errfn: nil, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var response response + err := CallWithJSONResponse(tt.args.ctx, tt.args.desc, tt.args.endpoint, &response) + switch { + case err == nil && tt.wantErr == nil: + if diff := cmp.Diff(expectedResponse, response); err != nil { + t.Fatal(diff) + } + case err != nil && tt.wantErr == nil: + t.Fatalf("expected error but got %s", err.Error()) + case err == nil && tt.wantErr != nil: + t.Fatalf("expected %s but got ", tt.wantErr.Error()) + case err.Error() == tt.wantErr.Error(): + // nothing + default: + t.Fatalf("expected %s but got %s", err.Error(), tt.wantErr.Error()) + } + if tt.errfn != nil { + tt.errfn(t, err) + } + }) + } +} + +func TestCallHonoursContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // should fail HTTP request immediately + desc := &Descriptor{ + LogBody: false, + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/robots.txt", + } + endpoint := &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: http.DefaultClient, + UserAgent: model.HTTPHeaderUserAgent, + } + body, err := Call(ctx, desc, endpoint) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } + if len(body) > 0 { + t.Fatal("expected zero-length body") + } +} + +func TestCallWithJSONResponseHonoursContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // should fail HTTP request immediately + desc := &Descriptor{ + LogBody: false, + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/robots.txt", + } + endpoint := &Endpoint{ + BaseURL: "https://www.example.com/", + HTTPClient: http.DefaultClient, + UserAgent: model.HTTPHeaderUserAgent, + } + var resp url.URL + err := CallWithJSONResponse(ctx, desc, endpoint, &resp) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } +} + +func TestDescriptorLogging(t *testing.T) { + + // This test was originally written for the httpx package and we have adapted it + // by keeping the ~same implementation with a custom callx function that converts + // the previous semantics of httpx to the new semantics of httpapi. + callx := func(baseURL string, logBody bool, logger model.Logger, request, response any) error { + desc := MustNewPOSTJSONWithJSONResponseDescriptor(logger, "/", request).WithBodyLogging(logBody) + runtimex.Assert(desc.LogBody == logBody, "desc.LogBody should be equal to logBody here") + endpoint := &Endpoint{ + BaseURL: baseURL, + HTTPClient: http.DefaultClient, + } + return CallWithJSONResponse(context.Background(), desc, endpoint, response) + } + + // we also needed to create a constructor for the logger + newlogger := func(logs chan string) model.Logger { + return &mocks.Logger{ + MockDebugf: func(format string, v ...interface{}) { + logs <- fmt.Sprintf(format, v...) + }, + MockWarnf: func(format string, v ...interface{}) { + logs <- fmt.Sprintf(format, v...) + }, + } + } + + t.Run("body logging enabled, 200 Ok, and without content-type", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, true, logger, input, &output) + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // we expect this because body logging is enabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // we expect this because body logging is enabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // we would expect this because the server does not send us any content-type + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we should see this because we sent a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we should see this because we receive a body + found |= 1 << 4 + continue + } + } + if found != (1<<0 | 1<<1 | 1<<2 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + if err != nil { + t.Fatal(err) + } + }) + + t.Run("body logging enabled, 200 Ok, and with content-type", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, true, logger, input, &output) + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // we expect this because body logging is enabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // we expect this because body logging is enabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // we do not expect this because the server sends us a content-type + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we should see this because we sent a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we should see this because we receive a body + found |= 1 << 4 + continue + } + } + if found != (1<<0 | 1<<1 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + if err != nil { + t.Fatal(err) + } + }) + + t.Run("body logging enabled and 401 Unauthorized", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, true, logger, input, &output) + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // should occur because body logging is enabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // should occur because body logging is enabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // note: this one should not occur because the code is 401 so we're not + // actually going to parse the JSON document + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we should see this because we send a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we should see this because we receive a body + found |= 1 << 4 + continue + } + } + if found != (1<<0 | 1<<1 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 401 { + t.Fatal("unexpected err", err) + } + }) + + t.Run("body logging NOT enabled and 200 Ok", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, false, logger, input, &output) // no logging + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // should not see it: body logging is disabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // should not see it: body logging is disabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // this one should be logged ANYWAY because it's orthogonal to the + // body logging so we should see it also in this case. + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // should see this because we send a body + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // should see this because we're receiving a body + found |= 1 << 4 + continue + } + } + if found != (1<<2 | 1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + if err != nil { + t.Fatal(err) + } + }) + + t.Run("body logging NOT enabled and 401 Unauthorized", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(401) + w.Write([]byte("[]")) + }, + )) + logs := make(chan string, 1024) + defer server.Close() + var ( + input []string + output []string + ) + logger := newlogger(logs) + err := callx(server.URL, false, logger, input, &output) // no logging + var found int + close(logs) + for entry := range logs { + if strings.HasPrefix(entry, "httpapi: request body: ") { + // should not see it: body logging is disabled + found |= 1 << 0 + continue + } + if strings.HasPrefix(entry, "httpapi: response body: ") { + // should not see it: body logging is disabled + found |= 1 << 1 + continue + } + if strings.HasPrefix(entry, "httpapi: unexpected content-type: ") { + // should not see it because we don't parse the body on 401 errors + found |= 1 << 2 + continue + } + if strings.HasPrefix(entry, "httpapi: request body length: ") { + // we send a body so we should see it + found |= 1 << 3 + continue + } + if strings.HasPrefix(entry, "httpapi: response body length: ") { + // we receive a body so we should see it + found |= 1 << 4 + continue + } + } + if found != (1<<3 | 1<<4) { + t.Fatal("did not find the expected logs") + } + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 401 { + t.Fatal("unexpected err", err) + } + }) +} + +func Test_errMaybeCensorship_Unwrap(t *testing.T) { + t.Run("for errors.Is", func(t *testing.T) { + var err error = &errMaybeCensorship{io.EOF} + if !errors.Is(err, io.EOF) { + t.Fatal("cannot unwrap") + } + }) + + t.Run("for errors.As", func(t *testing.T) { + var err error = &errMaybeCensorship{netxlite.ECONNRESET} + var syserr syscall.Errno + if !errors.As(err, &syserr) || syserr != netxlite.ECONNRESET { + t.Fatal("cannot unwrap") + } + }) +} diff --git a/internal/httpapi/descriptor.go b/internal/httpapi/descriptor.go new file mode 100644 index 0000000..ed35e85 --- /dev/null +++ b/internal/httpapi/descriptor.go @@ -0,0 +1,155 @@ +package httpapi + +// +// HTTP API descriptor (e.g., GET /api/v1/test-list/urls) +// + +import ( + "encoding/json" + "net/http" + "net/url" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// Descriptor contains the parameters for calling a given HTTP +// API (e.g., GET /api/v1/test-list/urls). +// +// The zero value of this struct is invalid. Please, fill all the +// fields marked as MANDATORY for correct initialization. +type Descriptor struct { + // Accept contains the OPTIONAL accept header. + Accept string + + // Authorization is the OPTIONAL authorization. + Authorization string + + // ContentType is the OPTIONAL content-type header. + ContentType string + + // LogBody OPTIONALLY enables logging bodies. + LogBody bool + + // Logger is the MANDATORY logger to use. + // + // For example, model.DiscardLogger. + Logger model.Logger + + // MaxBodySize is the OPTIONAL maximum response body size. If + // not set, we use the |DefaultMaxBodySize| constant. + MaxBodySize int64 + + // Method is the MANDATORY request method. + Method string + + // RequestBody is the OPTIONAL request body. + RequestBody []byte + + // Timeout is the OPTIONAL timeout for this call. If no timeout + // is specified we will use the |DefaultCallTimeout| const. + Timeout time.Duration + + // URLPath is the MANDATORY URL path. + URLPath string + + // URLQuery is the OPTIONAL query. + URLQuery url.Values +} + +// WithBodyLogging returns a SHALLOW COPY of |Descriptor| with LogBody set to |value|. You SHOULD +// only use this method when initializing the descriptor you want to use. +func (desc *Descriptor) WithBodyLogging(value bool) *Descriptor { + out := &Descriptor{} + *out = *desc + out.LogBody = value + return out +} + +// DefaultMaxBodySize is the default value for the maximum +// body size you can fetch using the httpapi package. +const DefaultMaxBodySize = 1 << 22 + +// DefaultCallTimeout is the default timeout for an httpapi call. +const DefaultCallTimeout = 60 * time.Second + +// NewGETJSONDescriptor is a convenience factory for creating a new descriptor +// that uses the GET method and expects a JSON response. +func NewGETJSONDescriptor(logger model.Logger, urlPath string) *Descriptor { + return NewGETJSONWithQueryDescriptor(logger, urlPath, url.Values{}) +} + +// applicationJSON is the content-type for JSON +const applicationJSON = "application/json" + +// NewGETJSONWithQueryDescriptor is like NewGETJSONDescriptor but it also +// allows you to provide |query| arguments. Leaving |query| nil or empty +// is equivalent to calling NewGETJSONDescriptor directly. +func NewGETJSONWithQueryDescriptor(logger model.Logger, urlPath string, query url.Values) *Descriptor { + return &Descriptor{ + Accept: applicationJSON, + Authorization: "", + ContentType: "", + LogBody: false, + Logger: logger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: urlPath, + URLQuery: query, + } +} + +// NewPOSTJSONWithJSONResponseDescriptor creates a descriptor that POSTs a JSON document +// and expects to receive back a JSON document from the API. +// +// This function ONLY fails if we cannot serialize the |request| to JSON. So, if you know +// that |request| is JSON-serializable, you can safely call MustNewPostJSONWithJSONResponseDescriptor instead. +func NewPOSTJSONWithJSONResponseDescriptor(logger model.Logger, urlPath string, request any) (*Descriptor, error) { + rawRequest, err := json.Marshal(request) + if err != nil { + return nil, err + } + desc := &Descriptor{ + Accept: applicationJSON, + Authorization: "", + ContentType: applicationJSON, + LogBody: false, + Logger: logger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodPost, + RequestBody: rawRequest, + Timeout: DefaultCallTimeout, + URLPath: urlPath, + URLQuery: nil, + } + return desc, nil +} + +// MustNewPOSTJSONWithJSONResponseDescriptor is like NewPOSTJSONWithJSONResponseDescriptor except that +// it panics in case it's not possible to JSON serialize the |request|. +func MustNewPOSTJSONWithJSONResponseDescriptor(logger model.Logger, urlPath string, request any) *Descriptor { + desc, err := NewPOSTJSONWithJSONResponseDescriptor(logger, urlPath, request) + runtimex.PanicOnError(err, "NewPOSTJSONWithJSONResponseDescriptor failed") + return desc +} + +// NewGETResourceDescriptor creates a generic descriptor for GETting a +// resource of unspecified type using the given |urlPath|. +func NewGETResourceDescriptor(logger model.Logger, urlPath string) *Descriptor { + return &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: logger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: urlPath, + URLQuery: url.Values{}, + } +} diff --git a/internal/httpapi/descriptor_test.go b/internal/httpapi/descriptor_test.go new file mode 100644 index 0000000..c3c58c0 --- /dev/null +++ b/internal/httpapi/descriptor_test.go @@ -0,0 +1,248 @@ +package httpapi + +import ( + "log" + "net/http" + "net/url" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func TestDescriptor_WithBodyLogging(t *testing.T) { + type fields struct { + Accept string + Authorization string + ContentType string + LogBody bool + Logger model.Logger + MaxBodySize int64 + Method string + RequestBody []byte + Timeout time.Duration + URLPath string + URLQuery url.Values + } + tests := []struct { + name string + fields fields + want *Descriptor + }{{ + name: "with empty fields", + fields: fields{}, // LogBody defaults to false + want: &Descriptor{ + LogBody: true, + }, + }, { + name: "with nonempty fields", + fields: fields{ + Accept: "xx", + Authorization: "y", + ContentType: "zzz", + LogBody: false, // obviously must be false + Logger: model.DiscardLogger, + MaxBodySize: 123, + Method: "POST", + RequestBody: []byte("123"), + Timeout: 15555, + URLPath: "/", + URLQuery: map[string][]string{ + "a": {"b"}, + }, + }, + want: &Descriptor{ + Accept: "xx", + Authorization: "y", + ContentType: "zzz", + LogBody: true, + Logger: model.DiscardLogger, + MaxBodySize: 123, + Method: "POST", + RequestBody: []byte("123"), + Timeout: 15555, + URLPath: "/", + URLQuery: map[string][]string{ + "a": {"b"}, + }, + }, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + desc := &Descriptor{ + Accept: tt.fields.Accept, + Authorization: tt.fields.Authorization, + ContentType: tt.fields.ContentType, + LogBody: tt.fields.LogBody, + Logger: tt.fields.Logger, + MaxBodySize: tt.fields.MaxBodySize, + Method: tt.fields.Method, + RequestBody: tt.fields.RequestBody, + Timeout: tt.fields.Timeout, + URLPath: tt.fields.URLPath, + URLQuery: tt.fields.URLQuery, + } + got := desc.WithBodyLogging(true) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) + } +} + +func TestNewGetJSONDescriptor(t *testing.T) { + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: url.Values{}, + } + got := NewGETJSONDescriptor(model.DiscardLogger, "/robots.txt") + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } +} + +func TestNewGetJSONWithQueryDescriptor(t *testing.T) { + query := url.Values{ + "a": {"b"}, + "c": {"d"}, + } + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: query, + } + got := NewGETJSONWithQueryDescriptor(model.DiscardLogger, "/robots.txt", query) + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } +} + +func TestNewPOSTJSONWithJSONResponseDescriptor(t *testing.T) { + type request struct { + Name string + Age int64 + } + + t.Run("with failure", func(t *testing.T) { + request := make(chan int64) + got, err := NewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + if err == nil || err.Error() != "json: unsupported type: chan int64" { + log.Fatal("unexpected err", err) + } + if got != nil { + log.Fatal("expected to get a nil Descriptor") + } + }) + + t.Run("with success", func(t *testing.T) { + request := request{ + Name: "sbs", + Age: 99, + } + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "application/json", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodPost, + RequestBody: []byte(`{"Name":"sbs","Age":99}`), + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: nil, + } + got, err := NewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + if err != nil { + log.Fatal(err) + } + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } + }) +} + +func TestMustNewPOSTJSONWithJSONResponseDescriptor(t *testing.T) { + type request struct { + Name string + Age int64 + } + + t.Run("with failure", func(t *testing.T) { + var panicked bool + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + request := make(chan int64) + _ = MustNewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + }() + if !panicked { + t.Fatal("did not panic") + } + }) + + t.Run("with success", func(t *testing.T) { + request := request{ + Name: "sbs", + Age: 99, + } + expected := &Descriptor{ + Accept: "application/json", + Authorization: "", + ContentType: "application/json", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodPost, + RequestBody: []byte(`{"Name":"sbs","Age":99}`), + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: nil, + } + got := MustNewPOSTJSONWithJSONResponseDescriptor(model.DiscardLogger, "/robots.txt", request) + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } + }) +} + +func TestNewGetResourceDescriptor(t *testing.T) { + expected := &Descriptor{ + Accept: "", + Authorization: "", + ContentType: "", + LogBody: false, + Logger: model.DiscardLogger, + MaxBodySize: DefaultMaxBodySize, + Method: http.MethodGet, + RequestBody: nil, + Timeout: DefaultCallTimeout, + URLPath: "/robots.txt", + URLQuery: url.Values{}, + } + got := NewGETResourceDescriptor(model.DiscardLogger, "/robots.txt") + if diff := cmp.Diff(expected, got); diff != "" { + t.Fatal(diff) + } +} diff --git a/internal/httpapi/doc.go b/internal/httpapi/doc.go new file mode 100644 index 0000000..0c1361c --- /dev/null +++ b/internal/httpapi/doc.go @@ -0,0 +1,15 @@ +// Package httpapi contains code for calling HTTP APIs. +// +// We model HTTP APIs as follows: +// +// 1. |Endpoint| is an API endpoint (e.g., https://api.ooni.io); +// +// 2. |Descriptor| describes the specific API you want to use (e.g., +// GET /api/v1/test-list/urls with JSON response body). +// +// Generally, you use |Call| to call the API identified by a |Descriptor| +// on the specified |Endpoint|. However, there are cases where you +// need more complex calling patterns. For example, with |SequenceCaller| +// you can invoke the same API |Descriptor| with multiple equivalent +// API |Endpoint|s until one of them succeeds or all fail. +package httpapi diff --git a/internal/httpapi/endpoint.go b/internal/httpapi/endpoint.go new file mode 100644 index 0000000..acfc4d8 --- /dev/null +++ b/internal/httpapi/endpoint.go @@ -0,0 +1,76 @@ +package httpapi + +// +// HTTP API Endpoint (e.g., https://api.ooni.io) +// + +import "github.com/ooni/probe-cli/v3/internal/model" + +// Endpoint models an HTTP endpoint on which you can call +// several HTTP APIs (e.g., https://api.ooni.io) using a +// given HTTP client potentially using a circumvention tunnel +// mechanism such as psiphon or torsf. +// +// The zero value of this struct is invalid. Please, fill all the +// fields marked as MANDATORY for correct initialization. +type Endpoint struct { + // BaseURL is the MANDATORY endpoint base URL. We will honour the + // path of this URL and prepend it to the actual path specified inside + // a |Descriptor.URLPath|. However, we will always discard any query + // that may have been set inside the BaseURL. The only query string + // will be composed from the |Descriptor.URLQuery| values. + // + // For example, https://api.ooni.io. + BaseURL string + + // HTTPClient is the MANDATORY HTTP client to use. + // + // For example, http.DefaultClient. You can introduce circumvention + // here by using an HTTPClient bound to a specific tunnel. + HTTPClient model.HTTPClient + + // Host is the OPTIONAL host header to use. + // + // If this field is empty we use the BaseURL's hostname. A specific + // host header may be needed when using cloudfronting. + Host string + + // User-Agent is the OPTIONAL user-agent to use. If empty, + // we'll use the stdlib's default user-agent string. + UserAgent string +} + +// NewEndpointList constructs a list of API endpoints from |services| +// returned by the OONI backend (or known in advance). +// +// Arguments: +// +// - httpClient is the HTTP client to use for accessing the endpoints; +// +// - userAgent is the user agent you would like to use; +// +// - service is the list of services gathered from the backend. +func NewEndpointList(httpClient model.HTTPClient, + userAgent string, services ...model.OOAPIService) (out []*Endpoint) { + for _, svc := range services { + switch svc.Type { + case "https": + out = append(out, &Endpoint{ + BaseURL: svc.Address, + HTTPClient: httpClient, + Host: "", + UserAgent: userAgent, + }) + case "cloudfront": + out = append(out, &Endpoint{ + BaseURL: svc.Address, + HTTPClient: httpClient, + Host: svc.Front, + UserAgent: userAgent, + }) + default: + // nothing! + } + } + return +} diff --git a/internal/httpapi/endpoint_test.go b/internal/httpapi/endpoint_test.go new file mode 100644 index 0000000..7077e14 --- /dev/null +++ b/internal/httpapi/endpoint_test.go @@ -0,0 +1,69 @@ +package httpapi + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestNewEndpointList(t *testing.T) { + type args struct { + httpClient model.HTTPClient + userAgent string + services []model.OOAPIService + } + defaultHTTPClient := &mocks.HTTPClient{} + tests := []struct { + name string + args args + wantOut []*Endpoint + }{{ + name: "with no services", + args: args{ + httpClient: defaultHTTPClient, + userAgent: model.HTTPHeaderUserAgent, + services: nil, + }, + wantOut: nil, + }, { + name: "common cases", + args: args{ + httpClient: defaultHTTPClient, + userAgent: model.HTTPHeaderUserAgent, + services: []model.OOAPIService{{ + Address: "https://www.example.com/", + Type: "https", + Front: "", + }, { + Address: "https://www.example.org/", + Type: "cloudfront", + Front: "example.org.it", + }, { + Address: "https://nonexistent.onion/", + Type: "onion", + Front: "", + }}, + }, + wantOut: []*Endpoint{{ + BaseURL: "https://www.example.com/", + HTTPClient: defaultHTTPClient, + Host: "", + UserAgent: model.HTTPHeaderUserAgent, + }, { + BaseURL: "https://www.example.org/", + HTTPClient: defaultHTTPClient, + Host: "example.org.it", + UserAgent: model.HTTPHeaderUserAgent, + }}, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOut := NewEndpointList(tt.args.httpClient, tt.args.userAgent, tt.args.services...) + if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" { + t.Fatal(diff) + } + }) + } +} diff --git a/internal/httpapi/sequence.go b/internal/httpapi/sequence.go new file mode 100644 index 0000000..da1f11d --- /dev/null +++ b/internal/httpapi/sequence.go @@ -0,0 +1,92 @@ +package httpapi + +// +// Sequentially call available API endpoints until one succeed +// or all of them fail. A future implementation of this code may +// (probably should?) take into account knowledge of what is +// working and what is not working to optimize the order with +// which to try different alternatives. +// + +import ( + "context" + "errors" + + "github.com/ooni/probe-cli/v3/internal/multierror" +) + +// SequenceCaller calls the API specified by |Descriptor| once for each of +// the available |Endpoints| until one of them succeeds. +// +// CAVEAT: this code will ONLY retry API calls with subsequent endpoints when +// the error originates in the HTTP round trip or while reading the body. +type SequenceCaller struct { + // Descriptor is the API |Descriptor|. + Descriptor *Descriptor + + // Endpoints is the list of |Endpoint| to use. + Endpoints []*Endpoint +} + +// NewSequenceCaller is a factory for creating a |SequenceCaller|. +func NewSequenceCaller(desc *Descriptor, endpoints ...*Endpoint) *SequenceCaller { + return &SequenceCaller{ + Descriptor: desc, + Endpoints: endpoints, + } +} + +// ErrAllEndpointsFailed indicates that all endpoints failed. +var ErrAllEndpointsFailed = errors.New("httpapi: all endpoints failed") + +// shouldRetry returns true when we should try with another endpoint given the +// value of |err| which could (obviously) be nil in case of success. +func (sc *SequenceCaller) shouldRetry(err error) bool { + var kind *errMaybeCensorship + belongs := errors.As(err, &kind) + return belongs +} + +// Call calls |Call| for each |Endpoint| and |Descriptor| until one endpoint succeeds. The +// return value is the response body and the selected endpoint index or the error. +// +// CAVEAT: this code will ONLY retry API calls with subsequent endpoints when +// the error originates in the HTTP round trip or while reading the body. +func (sc *SequenceCaller) Call(ctx context.Context) ([]byte, int, error) { + var selected int + merr := multierror.New(ErrAllEndpointsFailed) + for _, epnt := range sc.Endpoints { + respBody, err := Call(ctx, sc.Descriptor, epnt) + if sc.shouldRetry(err) { + merr.Add(err) + selected++ + continue + } + // Note: some errors will lead us to return + // early as documented for this method + return respBody, selected, err + } + return nil, -1, merr +} + +// CallWithJSONResponse is like |SequenceCaller.Call| except that it invokes the +// underlying |CallWithJSONResponse| rather than invoking |Call|. +// +// CAVEAT: this code will ONLY retry API calls with subsequent endpoints when +// the error originates in the HTTP round trip or while reading the body. +func (sc *SequenceCaller) CallWithJSONResponse(ctx context.Context, response any) (int, error) { + var selected int + merr := multierror.New(ErrAllEndpointsFailed) + for _, epnt := range sc.Endpoints { + err := CallWithJSONResponse(ctx, sc.Descriptor, epnt, response) + if sc.shouldRetry(err) { + merr.Add(err) + selected++ + continue + } + // Note: some errors will lead us to return + // early as documented for this method + return selected, err + } + return -1, merr +} diff --git a/internal/httpapi/sequence_test.go b/internal/httpapi/sequence_test.go new file mode 100644 index 0000000..13cc50f --- /dev/null +++ b/internal/httpapi/sequence_test.go @@ -0,0 +1,358 @@ +package httpapi + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestSequenceCaller(t *testing.T) { + t.Run("Call", func(t *testing.T) { + t.Run("first success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("deadbeef")), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + if err != nil { + t.Fatal(err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff([]byte("deadbeef"), data); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("first HTTP failure and we immediately stop", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 403, // should cause us to return early + Body: io.NopCloser(strings.NewReader("deadbeef")), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 403 { + t.Fatal("unexpected err", err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if len(data) > 0 { + t.Fatal("expected to see no response body") + } + }) + + t.Run("first network failure, second success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to cycle to the second entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("abad1dea")), + } + return resp, nil + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + if err != nil { + t.Fatal(err) + } + if idx != 1 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff([]byte("abad1dea"), data); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("all network failure", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to cycle to the next entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to cycle to the next entry + }, + }, + }, + ) + data, idx, err := sc.Call(context.Background()) + if !errors.Is(err, ErrAllEndpointsFailed) { + t.Fatal("unexpected err", err) + } + if idx != -1 { + t.Fatal("invalid idx") + } + if len(data) > 0 { + t.Fatal("expected zero-length data") + } + }) + }) + + t.Run("CallWithJSONResponse", func(t *testing.T) { + type response struct { + Name string + Age int64 + } + + t.Run("first success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"Name":"sbs","Age":99}`)), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{}`)), // different + } + return resp, nil + }, + }, + }, + ) + expect := response{ + Name: "sbs", + Age: 99, + } + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + if err != nil { + t.Fatal(err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("first HTTP failure and we immediately stop", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 403, // should be enough to cause us fail immediately + Body: io.NopCloser(strings.NewReader(`{"Age": 155, "Name": "sbs"}`)), + } + return resp, nil + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + }, + ) + // even though there is a JSON body we don't care about reading it + // and so we expect to see in output the zero-value struct + expect := response{ + Name: "", + Age: 0, + } + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + var failure *ErrHTTPRequestFailed + if !errors.As(err, &failure) || failure.StatusCode != 403 { + t.Fatal("unexpected err", err) + } + if idx != 0 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("first network failure, second success", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to try the next entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"Age":155}`)), + } + return resp, nil + }, + }, + }, + ) + expect := response{ + Name: "", + Age: 155, + } + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + if err != nil { + t.Fatal(err) + } + if idx != 1 { + t.Fatal("invalid idx") + } + if diff := cmp.Diff(expect, got); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("all network failure", func(t *testing.T) { + sc := NewSequenceCaller( + &Descriptor{ + Logger: model.DiscardLogger, + Method: http.MethodGet, + URLPath: "/", + }, + &Endpoint{ + BaseURL: "https://a.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to try the next entry + }, + }, + }, + &Endpoint{ + BaseURL: "https://b.example.com/", + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF // should cause us to try the next entry + }, + }, + }, + ) + var got response + idx, err := sc.CallWithJSONResponse(context.Background(), &got) + if !errors.Is(err, ErrAllEndpointsFailed) { + t.Fatal("unexpected err", err) + } + if idx != -1 { + t.Fatal("invalid idx") + } + }) + }) +} diff --git a/internal/httpx/httpx.go b/internal/httpx/httpx.go index f330745..8083f10 100644 --- a/internal/httpx/httpx.go +++ b/internal/httpx/httpx.go @@ -1,4 +1,8 @@ // Package httpx contains http extensions. +// +// Deprecated: new code should use httpapi instead. While this package and httpapi +// are basically using the same implementation, the API exposed by httpapi allows +// us to try the same request with multiple HTTP endpoints. package httpx import (