From ae32ffa13b56d2210aef883b3cb4d7043c52adab Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 4 Mar 2021 11:36:41 +0100 Subject: [PATCH 1/6] chore: set version to 3.7.0-alpha (#240) While there, make sure we don't always skip a currently failing riseupvpn test, and slightly clarify the readme. --- Readme.md | 4 ++-- internal/engine/experiment/riseupvpn/riseupvpn_test.go | 4 +++- internal/version/version.go | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Readme.md b/Readme.md index 5757837..7abcd38 100644 --- a/Readme.md +++ b/Readme.md @@ -1,8 +1,8 @@ -# OONI Probe CLI +# OONI Probe Client Library and CLI [![GoDoc](https://godoc.org/github.com/ooni/probe-cli?status.svg)](https://godoc.org/github.com/ooni/probe-cli) [![Short Tests Status](https://github.com/ooni/probe-cli/workflows/shorttests/badge.svg)](https://github.com/ooni/probe-cli/actions?query=workflow%3Ashorttests) [![All Tests Status](https://github.com/ooni/probe-cli/workflows/alltests/badge.svg)](https://github.com/ooni/probe-cli/actions?query=workflow%3Aalltests) [![Coverage Status](https://coveralls.io/repos/github/ooni/probe-cli/badge.svg?branch=master)](https://coveralls.io/github/ooni/probe-cli?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/ooni/probe-cli)](https://goreportcard.com/report/github.com/ooni/probe-cli) [![linux-debian-packages](https://github.com/ooni/probe-cli/workflows/linux-debian-packages/badge.svg)](https://github.com/ooni/probe-cli/actions?query=workflow%3Alinux-debian-packages) [![GitHub issues by-label](https://img.shields.io/github/issues/ooni/probe/ooni/probe-cli?style=plastic)](https://github.com/ooni/probe/labels/ooni%2Fprobe-cli) -The next generation OONI Probe: library and Command Line Interface. +The next generation OONI Probe: client library and Command Line Interface. ## User setup diff --git a/internal/engine/experiment/riseupvpn/riseupvpn_test.go b/internal/engine/experiment/riseupvpn/riseupvpn_test.go index 81a57cb..3798d5d 100644 --- a/internal/engine/experiment/riseupvpn/riseupvpn_test.go +++ b/internal/engine/experiment/riseupvpn/riseupvpn_test.go @@ -278,7 +278,9 @@ func TestFailureGeoIpServiceBlocked(t *testing.T) { } func TestFailureGateway(t *testing.T) { - t.Skip("test currently not WAI - will restore after release") + if testing.Short() { + t.Skip("skip test in short mode") + } var testCases = [...]string{"openvpn", "obfs4"} eipService, err := fetchEipService() if err != nil { diff --git a/internal/version/version.go b/internal/version/version.go index b57de1a..f3a8148 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -3,5 +3,5 @@ package version const ( // Version is the software version - Version = "3.6.0" + Version = "3.7.0-alpha" ) From 55bdebe8b2241b87238486a625e41d1c6daae31f Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 4 Mar 2021 11:51:07 +0100 Subject: [PATCH 2/6] engine/ooapi: autogenerated API with login and caching (#234) * internal/engine/ooapi: auto-generated API client * feat: introduce the callers abstraction * feat: implement API caching on disk * feat: implement cloneWithToken when we require login * feat: implement login * fix: do not cache all APIs * feat: start making space for more tests * feat: implement caching policy * feat: write tests for caching layer * feat: add integration tests and fix some minor issues * feat: write much more unit tests * feat: add some more easy unit tests * feat: add tests that use a local server While there, make sure many fields we care about are OK. * doc: write basic documentation * fix: tweak sentence * doc: improve ooapi documentation * doc(ooapi): other documentation improvements * fix(ooapi): remove caching for most APIs We discussed this topic yesterday with @FedericoCeratto. The only place where we want LRU caching is MeasurementMeta. * feat(ooapi): improve handling of errors during login This was also discussed yesterday with @FedericoCeratto * fix(swaggerdiff_test.go): temporarily disable Before I work on this, I need to tend onto other tasks. * fix(ootest): add one more test case We're going towards 100% coverage of this package, as it ought to be. * feat(ooapi): test cases for when the probe clock is off * fix(ooapi): change test to have 100% unittest coverage * feat: sync server and client APIs definition Companion PR: https://github.com/ooni/api/pull/218 * fix(ooapi): start testing again against API * fix(ooapi): only generate each file once * chore: set version to 3.7.0-alpha While there, make sure we don't always skip a currently failing riseupvpn test, and slightly clarify the readme. * fix(kvstore): less scoped error message --- go.mod | 1 + go.sum | 2 + internal/engine/ooapi/README.md | 5 + internal/engine/ooapi/apimodel/checkin.go | 47 + .../engine/ooapi/apimodel/checkreportid.go | 13 + internal/engine/ooapi/apimodel/doc.go | 22 + internal/engine/ooapi/apimodel/login.go | 15 + .../engine/ooapi/apimodel/measurementmeta.go | 25 + internal/engine/ooapi/apimodel/openreport.go | 21 + .../engine/ooapi/apimodel/psiphonconfig.go | 7 + internal/engine/ooapi/apimodel/register.go | 26 + .../ooapi/apimodel/submitmeasurement.go | 13 + internal/engine/ooapi/apimodel/testhelpers.go | 15 + internal/engine/ooapi/apimodel/tortargets.go | 16 + internal/engine/ooapi/apimodel/urls.go | 26 + internal/engine/ooapi/apis.go | 607 ++++ internal/engine/ooapi/apis_test.go | 2776 +++++++++++++++++ internal/engine/ooapi/caching.go | 98 + internal/engine/ooapi/caching_test.go | 222 ++ internal/engine/ooapi/callers.go | 78 + internal/engine/ooapi/cloners.go | 18 + internal/engine/ooapi/default.go | 57 + internal/engine/ooapi/default_test.go | 41 + internal/engine/ooapi/dependencies.go | 54 + internal/engine/ooapi/doc.go | 163 + internal/engine/ooapi/errors.go | 14 + internal/engine/ooapi/fake_test.go | 96 + internal/engine/ooapi/fakeapi_test.go | 190 ++ internal/engine/ooapi/fakefill_test.go | 146 + internal/engine/ooapi/integration_test.go | 204 ++ .../engine/ooapi/internal/generator/apis.go | 180 ++ .../ooapi/internal/generator/apistest.go | 461 +++ .../ooapi/internal/generator/caching.go | 130 + .../ooapi/internal/generator/cachingtest.go | 274 ++ .../ooapi/internal/generator/callers.go | 35 + .../ooapi/internal/generator/cloners.go | 32 + .../ooapi/internal/generator/fakeapitest.go | 59 + .../ooapi/internal/generator/generator.go | 53 + .../engine/ooapi/internal/generator/login.go | 182 ++ .../ooapi/internal/generator/logintest.go | 864 +++++ .../ooapi/internal/generator/reflect.go | 147 + .../ooapi/internal/generator/requests.go | 141 + .../ooapi/internal/generator/responses.go | 80 + .../engine/ooapi/internal/generator/spec.go | 136 + .../ooapi/internal/generator/swaggertest.go | 194 ++ .../ooapi/internal/generator/writefile.go | 27 + .../engine/ooapi/internal/openapi/openapi.go | 64 + internal/engine/ooapi/kvstore_test.go | 34 + internal/engine/ooapi/login.go | 295 ++ internal/engine/ooapi/login_test.go | 1365 ++++++++ internal/engine/ooapi/loginhandler_test.go | 204 ++ internal/engine/ooapi/loginmodel.go | 59 + internal/engine/ooapi/requests.go | 192 ++ internal/engine/ooapi/responses.go | 276 ++ internal/engine/ooapi/swagger_test.go | 578 ++++ internal/engine/ooapi/swaggerdiff_test.go | 158 + internal/engine/ooapi/utils.go | 23 + internal/engine/ooapi/utils_test.go | 12 + 58 files changed, 11273 insertions(+) create mode 100644 internal/engine/ooapi/README.md create mode 100644 internal/engine/ooapi/apimodel/checkin.go create mode 100644 internal/engine/ooapi/apimodel/checkreportid.go create mode 100644 internal/engine/ooapi/apimodel/doc.go create mode 100644 internal/engine/ooapi/apimodel/login.go create mode 100644 internal/engine/ooapi/apimodel/measurementmeta.go create mode 100644 internal/engine/ooapi/apimodel/openreport.go create mode 100644 internal/engine/ooapi/apimodel/psiphonconfig.go create mode 100644 internal/engine/ooapi/apimodel/register.go create mode 100644 internal/engine/ooapi/apimodel/submitmeasurement.go create mode 100644 internal/engine/ooapi/apimodel/testhelpers.go create mode 100644 internal/engine/ooapi/apimodel/tortargets.go create mode 100644 internal/engine/ooapi/apimodel/urls.go create mode 100644 internal/engine/ooapi/apis.go create mode 100644 internal/engine/ooapi/apis_test.go create mode 100644 internal/engine/ooapi/caching.go create mode 100644 internal/engine/ooapi/caching_test.go create mode 100644 internal/engine/ooapi/callers.go create mode 100644 internal/engine/ooapi/cloners.go create mode 100644 internal/engine/ooapi/default.go create mode 100644 internal/engine/ooapi/default_test.go create mode 100644 internal/engine/ooapi/dependencies.go create mode 100644 internal/engine/ooapi/doc.go create mode 100644 internal/engine/ooapi/errors.go create mode 100644 internal/engine/ooapi/fake_test.go create mode 100644 internal/engine/ooapi/fakeapi_test.go create mode 100644 internal/engine/ooapi/fakefill_test.go create mode 100644 internal/engine/ooapi/integration_test.go create mode 100644 internal/engine/ooapi/internal/generator/apis.go create mode 100644 internal/engine/ooapi/internal/generator/apistest.go create mode 100644 internal/engine/ooapi/internal/generator/caching.go create mode 100644 internal/engine/ooapi/internal/generator/cachingtest.go create mode 100644 internal/engine/ooapi/internal/generator/callers.go create mode 100644 internal/engine/ooapi/internal/generator/cloners.go create mode 100644 internal/engine/ooapi/internal/generator/fakeapitest.go create mode 100644 internal/engine/ooapi/internal/generator/generator.go create mode 100644 internal/engine/ooapi/internal/generator/login.go create mode 100644 internal/engine/ooapi/internal/generator/logintest.go create mode 100644 internal/engine/ooapi/internal/generator/reflect.go create mode 100644 internal/engine/ooapi/internal/generator/requests.go create mode 100644 internal/engine/ooapi/internal/generator/responses.go create mode 100644 internal/engine/ooapi/internal/generator/spec.go create mode 100644 internal/engine/ooapi/internal/generator/swaggertest.go create mode 100644 internal/engine/ooapi/internal/generator/writefile.go create mode 100644 internal/engine/ooapi/internal/openapi/openapi.go create mode 100644 internal/engine/ooapi/kvstore_test.go create mode 100644 internal/engine/ooapi/login.go create mode 100644 internal/engine/ooapi/login_test.go create mode 100644 internal/engine/ooapi/loginhandler_test.go create mode 100644 internal/engine/ooapi/loginmodel.go create mode 100644 internal/engine/ooapi/requests.go create mode 100644 internal/engine/ooapi/responses.go create mode 100644 internal/engine/ooapi/swagger_test.go create mode 100644 internal/engine/ooapi/swaggerdiff_test.go create mode 100644 internal/engine/ooapi/utils.go create mode 100644 internal/engine/ooapi/utils_test.go diff --git a/go.mod b/go.mod index c81527c..b24c571 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.2.0 github.com/gorilla/websocket v1.4.2 + github.com/hexops/gotextdiff v1.0.3 github.com/iancoleman/strcase v0.1.3 github.com/lucas-clemente/quic-go v0.19.3 github.com/marten-seemann/qtls-go1-15 v0.1.2 // indirect diff --git a/go.sum b/go.sum index 98d1875..4547262 100644 --- a/go.sum +++ b/go.sum @@ -218,6 +218,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174 h1:WlZsjVhE8Af9IcZDGgJGQpNflI3+MJSBhsgT5PCtzBQ= github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174/go.mod h1:DqJ97dSdRW1W22yXSB90986pcOyQ7r45iio1KN2ez1A= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/internal/engine/ooapi/README.md b/internal/engine/ooapi/README.md new file mode 100644 index 0000000..63de899 --- /dev/null +++ b/internal/engine/ooapi/README.md @@ -0,0 +1,5 @@ +# Package ./internal/engine/ooapi + +Automatically generated API clients for speaking with OONI servers. + +Please, run `go doc ./internal/engine/ooapi` to see API documentation. diff --git a/internal/engine/ooapi/apimodel/checkin.go b/internal/engine/ooapi/apimodel/checkin.go new file mode 100644 index 0000000..8df8085 --- /dev/null +++ b/internal/engine/ooapi/apimodel/checkin.go @@ -0,0 +1,47 @@ +package apimodel + +// CheckInRequestWebConnectivity contains WebConnectivity +// specific parameters to include into CheckInRequest +type CheckInRequestWebConnectivity struct { + CategoryCodes []string `json:"category_codes"` +} + +// CheckInRequest is the check-in API request +type CheckInRequest struct { + Charging bool `json:"charging"` + OnWiFi bool `json:"on_wifi"` + Platform string `json:"platform"` + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + RunType string `json:"run_type"` + SoftwareName string `json:"software_name"` + SoftwareVersion string `json:"software_version"` + WebConnectivity CheckInRequestWebConnectivity `json:"web_connectivity"` +} + +// CheckInResponseURLInfo contains information about an URL. +type CheckInResponseURLInfo struct { + CategoryCode string `json:"category_code"` + CountryCode string `json:"country_code"` + URL string `json:"url"` +} + +// CheckInResponseWebConnectivity contains WebConnectivity +// specific information of a CheckInResponse +type CheckInResponseWebConnectivity struct { + ReportID string `json:"report_id"` + URLs []CheckInResponseURLInfo `json:"urls"` +} + +// CheckInResponse is the check-in API response +type CheckInResponse struct { + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + Tests CheckInResponseTests `json:"tests"` + V int64 `json:"v"` +} + +// CheckInResponseTests contains configuration for tests +type CheckInResponseTests struct { + WebConnectivity CheckInResponseWebConnectivity `json:"web_connectivity"` +} diff --git a/internal/engine/ooapi/apimodel/checkreportid.go b/internal/engine/ooapi/apimodel/checkreportid.go new file mode 100644 index 0000000..068a047 --- /dev/null +++ b/internal/engine/ooapi/apimodel/checkreportid.go @@ -0,0 +1,13 @@ +package apimodel + +// CheckReportIDRequest is the CheckReportID request. +type CheckReportIDRequest struct { + ReportID string `query:"report_id" required:"true"` +} + +// CheckReportIDResponse is the CheckReportID response. +type CheckReportIDResponse struct { + Error string `json:"error"` + Found bool `json:"found"` + V int64 `json:"v"` +} diff --git a/internal/engine/ooapi/apimodel/doc.go b/internal/engine/ooapi/apimodel/doc.go new file mode 100644 index 0000000..da78f36 --- /dev/null +++ b/internal/engine/ooapi/apimodel/doc.go @@ -0,0 +1,22 @@ +// Package apimodel describes the data types used by OONI's API. +// +// If you edit this package to integrate the data model, remember to +// run `go generate ./...`. +// +// We annotate fields with tagging. When a field should be sent +// over as JSON, use the usual `json` tag. +// +// When a field needs to be sent using the query string, use +// the `query` tag instead. We limit what can be sent using the +// query string to int64, string, and bool. +// +// The `path` tag indicates that the URL path contains a +// template. We will replace the value of this field with +// the template. Note that the template should use the +// Go name of the field (e.g. `{{ .ReportID }}`) as opposed +// to the name in the tag, which is only used when we +// generate the API Swagger. +// +// The `required` tag indicates required fields. A required +// field cannot be empty (for the Go definition of empty). +package apimodel diff --git a/internal/engine/ooapi/apimodel/login.go b/internal/engine/ooapi/apimodel/login.go new file mode 100644 index 0000000..408b347 --- /dev/null +++ b/internal/engine/ooapi/apimodel/login.go @@ -0,0 +1,15 @@ +package apimodel + +import "time" + +// LoginRequest is the login API request +type LoginRequest struct { + ClientID string `json:"username"` + Password string `json:"password"` +} + +// LoginResponse is the login API response +type LoginResponse struct { + Expire time.Time `json:"expire"` + Token string `json:"token"` +} diff --git a/internal/engine/ooapi/apimodel/measurementmeta.go b/internal/engine/ooapi/apimodel/measurementmeta.go new file mode 100644 index 0000000..e97da69 --- /dev/null +++ b/internal/engine/ooapi/apimodel/measurementmeta.go @@ -0,0 +1,25 @@ +package apimodel + +// MeasurementMetaRequest is the MeasurementMeta Request. +type MeasurementMetaRequest struct { + ReportID string `query:"report_id" required:"true"` + Full bool `query:"full"` + Input string `query:"input"` +} + +// MeasurementMetaResponse is the MeasurementMeta Response. +type MeasurementMetaResponse struct { + Anomaly bool `json:"anomaly"` + CategoryCode string `json:"category_code"` + Confirmed bool `json:"confirmed"` + Failure bool `json:"failure"` + Input string `json:"input"` + MeasurementStartTime string `json:"measurement_start_time"` + ProbeASN int64 `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + RawMeasurement string `json:"raw_measurement"` + ReportID string `json:"report_id"` + Scores string `json:"scores"` + TestName string `json:"test_name"` + TestStartTime string `json:"test_start_time"` +} diff --git a/internal/engine/ooapi/apimodel/openreport.go b/internal/engine/ooapi/apimodel/openreport.go new file mode 100644 index 0000000..4432bba --- /dev/null +++ b/internal/engine/ooapi/apimodel/openreport.go @@ -0,0 +1,21 @@ +package apimodel + +// OpenReportRequest is the OpenReport request. +type OpenReportRequest struct { + DataFormatVersion string `json:"data_format_version"` + Format string `json:"format"` + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + SoftwareName string `json:"software_name"` + SoftwareVersion string `json:"software_version"` + TestName string `json:"test_name"` + TestStartTime string `json:"test_start_time"` + TestVersion string `json:"test_version"` +} + +// OpenReportResponse is the OpenReport response. +type OpenReportResponse struct { + BackendVersion string `json:"backend_version"` + ReportID string `json:"report_id"` + SupportedFormats []string `json:"supported_formats"` +} diff --git a/internal/engine/ooapi/apimodel/psiphonconfig.go b/internal/engine/ooapi/apimodel/psiphonconfig.go new file mode 100644 index 0000000..40f9726 --- /dev/null +++ b/internal/engine/ooapi/apimodel/psiphonconfig.go @@ -0,0 +1,7 @@ +package apimodel + +// PsiphonConfigRequest is the request for the PsiphonConfig API +type PsiphonConfigRequest struct{} + +// PsiphonConfigResponse is the response from the PsiphonConfig API +type PsiphonConfigResponse map[string]interface{} diff --git a/internal/engine/ooapi/apimodel/register.go b/internal/engine/ooapi/apimodel/register.go new file mode 100644 index 0000000..e167306 --- /dev/null +++ b/internal/engine/ooapi/apimodel/register.go @@ -0,0 +1,26 @@ +package apimodel + +// RegisterRequest is the request for the Register API. +type RegisterRequest struct { + // just password + Password string `json:"password"` + + // metadata + AvailableBandwidth string `json:"available_bandwidth,omitempty"` + DeviceToken string `json:"device_token,omitempty"` + Language string `json:"language,omitempty"` + NetworkType string `json:"network_type,omitempty"` + Platform string `json:"platform"` + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + ProbeFamily string `json:"probe_family,omitempty"` + ProbeTimezone string `json:"probe_timezone,omitempty"` + SoftwareName string `json:"software_name"` + SoftwareVersion string `json:"software_version"` + SupportedTests []string `json:"supported_tests"` +} + +// RegisterResponse is the response from the Register API. +type RegisterResponse struct { + ClientID string `json:"client_id"` +} diff --git a/internal/engine/ooapi/apimodel/submitmeasurement.go b/internal/engine/ooapi/apimodel/submitmeasurement.go new file mode 100644 index 0000000..da542d5 --- /dev/null +++ b/internal/engine/ooapi/apimodel/submitmeasurement.go @@ -0,0 +1,13 @@ +package apimodel + +// SubmitMeasurementRequest is the SubmitMeasurement request. +type SubmitMeasurementRequest struct { + ReportID string `path:"report_id"` + Format string `json:"format"` + Content interface{} `json:"content"` +} + +// SubmitMeasurementResponse is the SubmitMeasurement response. +type SubmitMeasurementResponse struct { + MeasurementUID string `json:"measurement_uid"` +} diff --git a/internal/engine/ooapi/apimodel/testhelpers.go b/internal/engine/ooapi/apimodel/testhelpers.go new file mode 100644 index 0000000..775a40d --- /dev/null +++ b/internal/engine/ooapi/apimodel/testhelpers.go @@ -0,0 +1,15 @@ +package apimodel + +// TestHelpersRequest is the TestHelpers request. +type TestHelpersRequest struct{} + +// TestHelpersResponse is the TestHelpers response. +type TestHelpersResponse map[string][]TestHelpersHelperInfo + +// TestHelpersHelperInfo is a single helper within the +// response returned by the TestHelpers API. +type TestHelpersHelperInfo struct { + Address string `json:"address"` + Type string `json:"type"` + Front string `json:"front,omitempty"` +} diff --git a/internal/engine/ooapi/apimodel/tortargets.go b/internal/engine/ooapi/apimodel/tortargets.go new file mode 100644 index 0000000..a4d7c74 --- /dev/null +++ b/internal/engine/ooapi/apimodel/tortargets.go @@ -0,0 +1,16 @@ +package apimodel + +// TorTargetsRequest is a request for the TorTargets API. +type TorTargetsRequest struct{} + +// TorTargetsResponse is the response from the TorTargets API. +type TorTargetsResponse map[string]TorTargetsTarget + +// TorTargetsTarget is a target for the tor experiment. +type TorTargetsTarget struct { + Address string `json:"address"` + Name string `json:"name"` + Params map[string][]string `json:"params"` + Protocol string `json:"protocol"` + Source string `json:"source"` +} diff --git a/internal/engine/ooapi/apimodel/urls.go b/internal/engine/ooapi/apimodel/urls.go new file mode 100644 index 0000000..dd9094f --- /dev/null +++ b/internal/engine/ooapi/apimodel/urls.go @@ -0,0 +1,26 @@ +package apimodel + +// URLsRequest is the URLs request. +type URLsRequest struct { + CategoryCodes string `query:"category_codes"` + CountryCode string `query:"country_code"` + Limit int64 `query:"limit"` +} + +// URLsResponse is the URLs response. +type URLsResponse struct { + Metadata URLsMetadata `json:"metadata"` + Results []URLsResponseURL `json:"results"` +} + +// URLsMetadata contains metadata in the URLs response. +type URLsMetadata struct { + Count int64 `json:"count"` +} + +// URLsResponseURL is a single URL in the URLs response. +type URLsResponseURL struct { + CategoryCode string `json:"category_code"` + CountryCode string `json:"country_code"` + URL string `json:"url"` +} diff --git a/internal/engine/ooapi/apis.go b/internal/engine/ooapi/apis.go new file mode 100644 index 0000000..3c079be --- /dev/null +++ b/internal/engine/ooapi/apis.go @@ -0,0 +1,607 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:50.431349269 +0100 CET m=+0.000196051 + +package ooapi + +//go:generate go run ./internal/generator -file apis.go + +import ( + "context" + "net/http" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// CheckReportIDAPI implements the CheckReportID API. +type CheckReportIDAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *CheckReportIDAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *CheckReportIDAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *CheckReportIDAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *CheckReportIDAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the CheckReportID API. +func (api *CheckReportIDAPI) Call(ctx context.Context, req *apimodel.CheckReportIDRequest) (*apimodel.CheckReportIDResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// CheckInAPI implements the CheckIn API. +type CheckInAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *CheckInAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *CheckInAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *CheckInAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *CheckInAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the CheckIn API. +func (api *CheckInAPI) Call(ctx context.Context, req *apimodel.CheckInRequest) (*apimodel.CheckInResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// LoginAPI implements the Login API. +type LoginAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *LoginAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *LoginAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *LoginAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *LoginAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the Login API. +func (api *LoginAPI) Call(ctx context.Context, req *apimodel.LoginRequest) (*apimodel.LoginResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// MeasurementMetaAPI implements the MeasurementMeta API. +type MeasurementMetaAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *MeasurementMetaAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *MeasurementMetaAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *MeasurementMetaAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *MeasurementMetaAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the MeasurementMeta API. +func (api *MeasurementMetaAPI) Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// RegisterAPI implements the Register API. +type RegisterAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *RegisterAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *RegisterAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *RegisterAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *RegisterAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the Register API. +func (api *RegisterAPI) Call(ctx context.Context, req *apimodel.RegisterRequest) (*apimodel.RegisterResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// TestHelpersAPI implements the TestHelpers API. +type TestHelpersAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *TestHelpersAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *TestHelpersAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *TestHelpersAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *TestHelpersAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the TestHelpers API. +func (api *TestHelpersAPI) Call(ctx context.Context, req *apimodel.TestHelpersRequest) (apimodel.TestHelpersResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// PsiphonConfigAPI implements the PsiphonConfig API. +type PsiphonConfigAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + Token string // mandatory + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +// WithToken returns a copy of the API where the +// value of the Token field is replaced with token. +func (api *PsiphonConfigAPI) WithToken(token string) PsiphonConfigCaller { + out := &PsiphonConfigAPI{} + out.BaseURL = api.BaseURL + out.HTTPClient = api.HTTPClient + out.JSONCodec = api.JSONCodec + out.RequestMaker = api.RequestMaker + out.UserAgent = api.UserAgent + out.Token = token + return out +} + +func (api *PsiphonConfigAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *PsiphonConfigAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *PsiphonConfigAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *PsiphonConfigAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the PsiphonConfig API. +func (api *PsiphonConfigAPI) Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.Token == "" { + return nil, ErrMissingToken + } + httpReq.Header.Add("Authorization", newAuthorizationHeader(api.Token)) + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// TorTargetsAPI implements the TorTargets API. +type TorTargetsAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + Token string // mandatory + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +// WithToken returns a copy of the API where the +// value of the Token field is replaced with token. +func (api *TorTargetsAPI) WithToken(token string) TorTargetsCaller { + out := &TorTargetsAPI{} + out.BaseURL = api.BaseURL + out.HTTPClient = api.HTTPClient + out.JSONCodec = api.JSONCodec + out.RequestMaker = api.RequestMaker + out.UserAgent = api.UserAgent + out.Token = token + return out +} + +func (api *TorTargetsAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *TorTargetsAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *TorTargetsAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *TorTargetsAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the TorTargets API. +func (api *TorTargetsAPI) Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.Token == "" { + return nil, ErrMissingToken + } + httpReq.Header.Add("Authorization", newAuthorizationHeader(api.Token)) + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// URLsAPI implements the URLs API. +type URLsAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *URLsAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *URLsAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *URLsAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *URLsAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the URLs API. +func (api *URLsAPI) Call(ctx context.Context, req *apimodel.URLsRequest) (*apimodel.URLsResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// OpenReportAPI implements the OpenReport API. +type OpenReportAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *OpenReportAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *OpenReportAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *OpenReportAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *OpenReportAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the OpenReport API. +func (api *OpenReportAPI) Call(ctx context.Context, req *apimodel.OpenReportRequest) (*apimodel.OpenReportResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// SubmitMeasurementAPI implements the SubmitMeasurement API. +type SubmitMeasurementAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + TemplateExecutor TemplateExecutor // optional + UserAgent string // optional +} + +func (api *SubmitMeasurementAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *SubmitMeasurementAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *SubmitMeasurementAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *SubmitMeasurementAPI) templateExecutor() TemplateExecutor { + if api.TemplateExecutor != nil { + return api.TemplateExecutor + } + return &defaultTemplateExecutor{} +} + +func (api *SubmitMeasurementAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the SubmitMeasurement API. +func (api *SubmitMeasurementAPI) Call(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*apimodel.SubmitMeasurementResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} diff --git a/internal/engine/ooapi/apis_test.go b/internal/engine/ooapi/apis_test.go new file mode 100644 index 0000000..ace748d --- /dev/null +++ b/internal/engine/ooapi/apis_test.go @@ -0,0 +1,2776 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:50.81792142 +0100 CET m=+0.000095792 + +package ooapi + +//go:generate go run ./internal/generator -file apis_test.go + +import ( + "context" + "encoding/json" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func TestCheckReportIDInvalidURL(t *testing.T) { + api := &CheckReportIDAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &CheckReportIDAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleCheckReportID struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.CheckReportIDResponse + url *url.URL + userAgent string +} + +func (h *handleCheckReportID) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.CheckReportIDResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestCheckReportIDRoundTrip(t *testing.T) { + // setup + handler := &handleCheckReportID{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &CheckReportIDAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestCheckReportIDMandatoryFields(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 500, + }} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} // deliberately empty + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrEmptyField) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInInvalidURL(t *testing.T) { + api := &CheckInAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &CheckInAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &CheckInAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &CheckInAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleCheckIn struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.CheckInResponse + url *url.URL + userAgent string +} + +func (h *handleCheckIn) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.CheckInResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestCheckInRoundTrip(t *testing.T) { + // setup + handler := &handleCheckIn{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &CheckInAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.CheckInRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestLoginInvalidURL(t *testing.T) { + api := &LoginAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &LoginAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &LoginAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &LoginAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleLogin struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.LoginResponse + url *url.URL + userAgent string +} + +func (h *handleLogin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.LoginResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestLoginRoundTrip(t *testing.T) { + // setup + handler := &handleLogin{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &LoginAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.LoginRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestMeasurementMetaInvalidURL(t *testing.T) { + api := &MeasurementMetaAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &MeasurementMetaAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleMeasurementMeta struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.MeasurementMetaResponse + url *url.URL + userAgent string +} + +func (h *handleMeasurementMeta) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.MeasurementMetaResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestMeasurementMetaRoundTrip(t *testing.T) { + // setup + handler := &handleMeasurementMeta{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &MeasurementMetaAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestMeasurementMetaMandatoryFields(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 500, + }} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} // deliberately empty + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrEmptyField) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterInvalidURL(t *testing.T) { + api := &RegisterAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &RegisterAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &RegisterAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &RegisterAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleRegister struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.RegisterResponse + url *url.URL + userAgent string +} + +func (h *handleRegister) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.RegisterResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestRegisterRoundTrip(t *testing.T) { + // setup + handler := &handleRegister{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &RegisterAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.RegisterRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestTestHelpersInvalidURL(t *testing.T) { + api := &TestHelpersAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &TestHelpersAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &TestHelpersAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleTestHelpers struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp apimodel.TestHelpersResponse + url *url.URL + userAgent string +} + +func (h *handleTestHelpers) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out apimodel.TestHelpersResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestTestHelpersRoundTrip(t *testing.T) { + // setup + handler := &handleTestHelpers{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &TestHelpersAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestTestHelpersResponseLiteralNull(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`null`)}, + }} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrJSONLiteralNull) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigInvalidURL(t *testing.T) { + api := &PsiphonConfigAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithMissingToken(t *testing.T) { + api := &PsiphonConfigAPI{} // no token + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrMissingToken) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &PsiphonConfigAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handlePsiphonConfig struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp apimodel.PsiphonConfigResponse + url *url.URL + userAgent string +} + +func (h *handlePsiphonConfig) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out apimodel.PsiphonConfigResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestPsiphonConfigRoundTrip(t *testing.T) { + // setup + handler := &handlePsiphonConfig{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &PsiphonConfigAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + ff.fill(&api.Token) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestPsiphonConfigResponseLiteralNull(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`null`)}, + }} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrJSONLiteralNull) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsInvalidURL(t *testing.T) { + api := &TorTargetsAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithMissingToken(t *testing.T) { + api := &TorTargetsAPI{} // no token + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrMissingToken) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &TorTargetsAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &TorTargetsAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleTorTargets struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp apimodel.TorTargetsResponse + url *url.URL + userAgent string +} + +func (h *handleTorTargets) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out apimodel.TorTargetsResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestTorTargetsRoundTrip(t *testing.T) { + // setup + handler := &handleTorTargets{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &TorTargetsAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + ff.fill(&api.Token) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestTorTargetsResponseLiteralNull(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`null`)}, + }} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrJSONLiteralNull) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsInvalidURL(t *testing.T) { + api := &URLsAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &URLsAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &URLsAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleURLs struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.URLsResponse + url *url.URL + userAgent string +} + +func (h *handleURLs) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.URLsResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestURLsRoundTrip(t *testing.T) { + // setup + handler := &handleURLs{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &URLsAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestOpenReportInvalidURL(t *testing.T) { + api := &OpenReportAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &OpenReportAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &OpenReportAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &OpenReportAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleOpenReport struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.OpenReportResponse + url *url.URL + userAgent string +} + +func (h *handleOpenReport) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.OpenReportResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestOpenReportRoundTrip(t *testing.T) { + // setup + handler := &handleOpenReport{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &OpenReportAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.OpenReportRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestSubmitMeasurementInvalidURL(t *testing.T) { + api := &SubmitMeasurementAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &SubmitMeasurementAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &SubmitMeasurementAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleSubmitMeasurement struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.SubmitMeasurementResponse + url *url.URL + userAgent string +} + +func (h *handleSubmitMeasurement) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.SubmitMeasurementResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestSubmitMeasurementRoundTrip(t *testing.T) { + // setup + handler := &handleSubmitMeasurement{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &SubmitMeasurementAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.SubmitMeasurementRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestSubmitMeasurementTemplateErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 500, + }} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + TemplateExecutor: &FakeTemplateExecutor{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} diff --git a/internal/engine/ooapi/caching.go b/internal/engine/ooapi/caching.go new file mode 100644 index 0000000..2b2e333 --- /dev/null +++ b/internal/engine/ooapi/caching.go @@ -0,0 +1,98 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:51.194159684 +0100 CET m=+0.000175181 + +package ooapi + +//go:generate go run ./internal/generator -file caching.go + +import ( + "context" + "reflect" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// MeasurementMetaCache implements caching for MeasurementMetaAPI. +type MeasurementMetaCache struct { + API MeasurementMetaCaller // mandatory + GobCodec GobCodec // optional + KVStore KVStore // mandatory +} + +type cacheEntryForMeasurementMeta struct { + Req *apimodel.MeasurementMetaRequest + Resp *apimodel.MeasurementMetaResponse +} + +// Call calls the API and implements caching. +func (c *MeasurementMetaCache) Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + if resp, _ := c.readcache(req); resp != nil { + return resp, nil + } + resp, err := c.API.Call(ctx, req) + if err != nil { + return nil, err + } + if err := c.writecache(req, resp); err != nil { + return nil, err + } + return resp, nil +} + +func (c *MeasurementMetaCache) gobCodec() GobCodec { + if c.GobCodec != nil { + return c.GobCodec + } + return &defaultGobCodec{} +} + +func (c *MeasurementMetaCache) getcache() ([]cacheEntryForMeasurementMeta, error) { + data, err := c.KVStore.Get("MeasurementMeta.cache") + if err != nil { + return nil, err + } + var out []cacheEntryForMeasurementMeta + if err := c.gobCodec().Decode(data, &out); err != nil { + return nil, err + } + return out, nil +} + +func (c *MeasurementMetaCache) setcache(in []cacheEntryForMeasurementMeta) error { + data, err := c.gobCodec().Encode(in) + if err != nil { + return err + } + return c.KVStore.Set("MeasurementMeta.cache", data) +} + +func (c *MeasurementMetaCache) readcache(req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + cache, err := c.getcache() + if err != nil { + return nil, err + } + for _, cur := range cache { + if reflect.DeepEqual(req, cur.Req) { + return cur.Resp, nil + } + } + return nil, errCacheNotFound +} + +func (c *MeasurementMetaCache) writecache(req *apimodel.MeasurementMetaRequest, resp *apimodel.MeasurementMetaResponse) error { + cache, _ := c.getcache() + out := []cacheEntryForMeasurementMeta{{Req: req, Resp: resp}} + const toomany = 64 + for idx, cur := range cache { + if reflect.DeepEqual(req, cur.Req) { + continue // we already updated the cache + } + if idx > toomany { + break + } + out = append(out, cur) + } + return c.setcache(out) +} + +var _ MeasurementMetaCaller = &MeasurementMetaCache{} diff --git a/internal/engine/ooapi/caching_test.go b/internal/engine/ooapi/caching_test.go new file mode 100644 index 0000000..592d69c --- /dev/null +++ b/internal/engine/ooapi/caching_test.go @@ -0,0 +1,222 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:51.49660021 +0100 CET m=+0.000217672 + +package ooapi + +//go:generate go run ./internal/generator -file caching_test.go + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func TestCacheMeasurementMetaAPISuccess(t *testing.T) { + ff := &fakeFill{} + var expect *apimodel.MeasurementMetaResponse + ff.fill(&expect) + cache := &MeasurementMetaCache{ + API: &FakeMeasurementMetaAPI{ + Response: expect, + }, + KVStore: &memkvstore{}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + resp, err := cache.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } +} + +func TestCacheMeasurementMetaAPIWriteCacheError(t *testing.T) { + errMocked := errors.New("mocked error") + ff := &fakeFill{} + var expect *apimodel.MeasurementMetaResponse + ff.fill(&expect) + cache := &MeasurementMetaCache{ + API: &FakeMeasurementMetaAPI{ + Response: expect, + }, + KVStore: &FakeKVStore{SetError: errMocked}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + resp, err := cache.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestCacheMeasurementMetaAPIFailureWithNoCache(t *testing.T) { + errMocked := errors.New("mocked error") + ff := &fakeFill{} + cache := &MeasurementMetaCache{ + API: &FakeMeasurementMetaAPI{ + Err: errMocked, + }, + KVStore: &memkvstore{}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + resp, err := cache.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestCacheMeasurementMetaAPIFailureWithPreviousCache(t *testing.T) { + ff := &fakeFill{} + var expect *apimodel.MeasurementMetaResponse + ff.fill(&expect) + fakeapi := &FakeMeasurementMetaAPI{ + Response: expect, + } + cache := &MeasurementMetaCache{ + API: fakeapi, + KVStore: &memkvstore{}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + // first pass with no error at all + // use a separate scope to be sure we avoid mistakes + { + resp, err := cache.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + } + // second pass with failure + errMocked := errors.New("mocked error") + fakeapi.Err = errMocked + fakeapi.Response = nil + resp2, err := cache.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp2 == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp2); diff != "" { + t.Fatal(diff) + } +} + +func TestCacheMeasurementMetaAPISetcacheWithEncodeError(t *testing.T) { + ff := &fakeFill{} + errMocked := errors.New("mocked error") + var in []cacheEntryForMeasurementMeta + ff.fill(&in) + cache := &MeasurementMetaCache{ + GobCodec: &FakeCodec{EncodeErr: errMocked}, + } + err := cache.setcache(in) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } +} + +func TestCacheMeasurementMetaAPIReadCacheNotFound(t *testing.T) { + ff := &fakeFill{} + var incache []cacheEntryForMeasurementMeta + ff.fill(&incache) + cache := &MeasurementMetaCache{ + KVStore: &memkvstore{}, + } + err := cache.setcache(incache) + if err != nil { + t.Fatal(err) + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + out, err := cache.readcache(req) + if !errors.Is(err, errCacheNotFound) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestCacheMeasurementMetaAPIWriteCacheDuplicate(t *testing.T) { + ff := &fakeFill{} + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + var resp1 *apimodel.MeasurementMetaResponse + ff.fill(&resp1) + var resp2 *apimodel.MeasurementMetaResponse + ff.fill(&resp2) + cache := &MeasurementMetaCache{ + KVStore: &memkvstore{}, + } + err := cache.writecache(req, resp1) + if err != nil { + t.Fatal(err) + } + err = cache.writecache(req, resp2) + if err != nil { + t.Fatal(err) + } + out, err := cache.readcache(req) + if err != nil { + t.Fatal(err) + } + if out == nil { + t.Fatal("expected non-nil here") + } + if diff := cmp.Diff(resp2, out); diff != "" { + t.Fatal(diff) + } +} + +func TestCacheMeasurementMetaAPICacheSizeLimited(t *testing.T) { + ff := &fakeFill{} + cache := &MeasurementMetaCache{ + KVStore: &memkvstore{}, + } + var prev int + for { + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + var resp *apimodel.MeasurementMetaResponse + ff.fill(&resp) + err := cache.writecache(req, resp) + if err != nil { + t.Fatal(err) + } + out, err := cache.getcache() + if err != nil { + t.Fatal(err) + } + if len(out) > prev { + prev = len(out) + continue + } + break + } +} diff --git a/internal/engine/ooapi/callers.go b/internal/engine/ooapi/callers.go new file mode 100644 index 0000000..89a1b94 --- /dev/null +++ b/internal/engine/ooapi/callers.go @@ -0,0 +1,78 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:51.773813223 +0100 CET m=+0.000114768 + +package ooapi + +//go:generate go run ./internal/generator -file callers.go + +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// CheckReportIDCaller represents any type exposing a method +// like CheckReportIDAPI.Call. +type CheckReportIDCaller interface { + Call(ctx context.Context, req *apimodel.CheckReportIDRequest) (*apimodel.CheckReportIDResponse, error) +} + +// CheckInCaller represents any type exposing a method +// like CheckInAPI.Call. +type CheckInCaller interface { + Call(ctx context.Context, req *apimodel.CheckInRequest) (*apimodel.CheckInResponse, error) +} + +// LoginCaller represents any type exposing a method +// like LoginAPI.Call. +type LoginCaller interface { + Call(ctx context.Context, req *apimodel.LoginRequest) (*apimodel.LoginResponse, error) +} + +// MeasurementMetaCaller represents any type exposing a method +// like MeasurementMetaAPI.Call. +type MeasurementMetaCaller interface { + Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) +} + +// RegisterCaller represents any type exposing a method +// like RegisterAPI.Call. +type RegisterCaller interface { + Call(ctx context.Context, req *apimodel.RegisterRequest) (*apimodel.RegisterResponse, error) +} + +// TestHelpersCaller represents any type exposing a method +// like TestHelpersAPI.Call. +type TestHelpersCaller interface { + Call(ctx context.Context, req *apimodel.TestHelpersRequest) (apimodel.TestHelpersResponse, error) +} + +// PsiphonConfigCaller represents any type exposing a method +// like PsiphonConfigAPI.Call. +type PsiphonConfigCaller interface { + Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) +} + +// TorTargetsCaller represents any type exposing a method +// like TorTargetsAPI.Call. +type TorTargetsCaller interface { + Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) +} + +// URLsCaller represents any type exposing a method +// like URLsAPI.Call. +type URLsCaller interface { + Call(ctx context.Context, req *apimodel.URLsRequest) (*apimodel.URLsResponse, error) +} + +// OpenReportCaller represents any type exposing a method +// like OpenReportAPI.Call. +type OpenReportCaller interface { + Call(ctx context.Context, req *apimodel.OpenReportRequest) (*apimodel.OpenReportResponse, error) +} + +// SubmitMeasurementCaller represents any type exposing a method +// like SubmitMeasurementAPI.Call. +type SubmitMeasurementCaller interface { + Call(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*apimodel.SubmitMeasurementResponse, error) +} diff --git a/internal/engine/ooapi/cloners.go b/internal/engine/ooapi/cloners.go new file mode 100644 index 0000000..d461e5a --- /dev/null +++ b/internal/engine/ooapi/cloners.go @@ -0,0 +1,18 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.108352268 +0100 CET m=+0.000275862 + +package ooapi + +//go:generate go run ./internal/generator -file cloners.go + +// PsiphonConfigCaller represents any type exposing a method +// like PsiphonConfigAPI.WithToken. +type PsiphonConfigCloner interface { + WithToken(token string) PsiphonConfigCaller +} + +// TorTargetsCaller represents any type exposing a method +// like TorTargetsAPI.WithToken. +type TorTargetsCloner interface { + WithToken(token string) TorTargetsCaller +} diff --git a/internal/engine/ooapi/default.go b/internal/engine/ooapi/default.go new file mode 100644 index 0000000..9917760 --- /dev/null +++ b/internal/engine/ooapi/default.go @@ -0,0 +1,57 @@ +package ooapi + +import ( + "bytes" + "context" + "encoding/gob" + "encoding/json" + "io" + "net/http" + "strings" + "text/template" +) + +type defaultRequestMaker struct{} + +func (*defaultRequestMaker) NewRequest( + ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, method, URL, body) +} + +type defaultJSONCodec struct{} + +func (*defaultJSONCodec) Encode(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (*defaultJSONCodec) Decode(b []byte, v interface{}) error { + return json.Unmarshal(b, v) +} + +type defaultTemplateExecutor struct{} + +func (*defaultTemplateExecutor) Execute(tmpl string, v interface{}) (string, error) { + to, err := template.New("t").Parse(tmpl) + if err != nil { + return "", err + } + var sb strings.Builder + if err := to.Execute(&sb, v); err != nil { + return "", err + } + return sb.String(), nil +} + +type defaultGobCodec struct{} + +func (*defaultGobCodec) Encode(v interface{}) ([]byte, error) { + var bb bytes.Buffer + if err := gob.NewEncoder(&bb).Encode(v); err != nil { + return nil, err + } + return bb.Bytes(), nil +} + +func (*defaultGobCodec) Decode(b []byte, v interface{}) error { + return gob.NewDecoder(bytes.NewReader(b)).Decode(v) +} diff --git a/internal/engine/ooapi/default_test.go b/internal/engine/ooapi/default_test.go new file mode 100644 index 0000000..930181f --- /dev/null +++ b/internal/engine/ooapi/default_test.go @@ -0,0 +1,41 @@ +package ooapi + +import ( + "strings" + "testing" +) + +func TestDefaultTemplateExecutorParseError(t *testing.T) { + te := &defaultTemplateExecutor{} + out, err := te.Execute("{{ .Foo", nil) + if err == nil || !strings.HasSuffix(err.Error(), "unclosed action") { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string") + } +} + +func TestDefaultTemplateExecutorExecError(t *testing.T) { + te := &defaultTemplateExecutor{} + arg := make(chan interface{}) + out, err := te.Execute("{{ .Foo }}", arg) + if err == nil || !strings.Contains(err.Error(), `can't evaluate field Foo`) { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string") + } +} + +func TestDefaultGobCodecEncodeError(t *testing.T) { + codec := &defaultGobCodec{} + arg := make(chan interface{}) + data, err := codec.Encode(arg) + if err == nil || !strings.Contains(err.Error(), "can't handle type") { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data") + } +} diff --git a/internal/engine/ooapi/dependencies.go b/internal/engine/ooapi/dependencies.go new file mode 100644 index 0000000..0b1bbf1 --- /dev/null +++ b/internal/engine/ooapi/dependencies.go @@ -0,0 +1,54 @@ +package ooapi + +import ( + "context" + "io" + "net/http" +) + +// JSONCodec is a JSON encoder and decoder. +type JSONCodec interface { + // Encode encodes v as a serialized JSON byte slice. + Encode(v interface{}) ([]byte, error) + + // Decode decodes the serialized JSON byte slice into v. + Decode(b []byte, v interface{}) error +} + +// RequestMaker makes an HTTP request. +type RequestMaker interface { + // NewRequest creates a new HTTP request. + NewRequest(ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) +} + +// TemplateExecutor parses and executes a text template. +type TemplateExecutor interface { + // Execute takes in input a template string and some piece of data. It + // returns either a string where template parameters have been replaced, + // on success, or an error, on failure. + Execute(tmpl string, v interface{}) (string, error) +} + +// HTTPClient is the interface of a generic HTTP client. +type HTTPClient interface { + // Do should work like http.Client.Do. + Do(req *http.Request) (*http.Response, error) +} + +// GobCodec is a Gob encoder and decoder. +type GobCodec interface { + // Encode encodes v as a serialized gob byte slice. + Encode(v interface{}) ([]byte, error) + + // Decode decodes the serialized gob byte slice into v. + Decode(b []byte, v interface{}) error +} + +// KVStore is a key-value store. +type KVStore interface { + // Get gets a value from the key-value store. + Get(key string) ([]byte, error) + + // Set stores a value into the key-value store. + Set(key string, value []byte) error +} diff --git a/internal/engine/ooapi/doc.go b/internal/engine/ooapi/doc.go new file mode 100644 index 0000000..1f8f6fd --- /dev/null +++ b/internal/engine/ooapi/doc.go @@ -0,0 +1,163 @@ +// Package ooapi contains clients for the OONI API. We +// automatically generate the code in this package from +// the apimodel and internal/generator packages. For +// each OONI API, we define up to three data structures: +// +// 1. a data structure representing the API; +// +// 2. a caching data structure, if the API +// supports caching; +// +// 3. an auto-login data structure, if the API +// requires login. +// +// The rest of this documentation page describes these +// three data structures and the design and architecture +// of this package. Refer to subpackages for more +// information on how to specify an API. +// +// API data structure +// +// For each API, this package defines a data structure +// representing the API. For example, for the TorTargets API, +// we define the TorTargetsAPI data structure. +// +// The API data structure defines a method named Call that +// allows calling the specified API. Call takes as arguments +// a context and the request for the API and returns the +// API response or an error. +// +// Request and response messages live inside the apimodel +// subpackage. We name them after the API. Thus, for +// the TorTargets API, the request is TorTargetsRequest, +// and the response is TorTargetsResponse. +// +// API data structures are cheap to create and do not +// mutate. They should be used in place and then forgotten +// off once the API call is complete. +// +// Unless explicitly indicated, the zero value of every +// API data structure is a valid API data structure. +// +// In terms of dependencies, APIs certainly need an http.Client +// to communicate with the OONI backend. To represent such a +// client, we use the HTTPClient interface. If you do not tell +// an API which http.Client to use, we will default to the +// standard library's http.DefaultClient. +// +// An API also depends on a JSONCodec. That is, on a data +// structures that encodes data to/from JSON. If you do not +// specify explicitly a JSONCodec, we will use the Go +// standard library's JSON implementation. +// +// When an API requires authentication, you need to tell +// it which authentication token to use. This gives you +// control over obtaining the token and is the low-level +// way of interacting with authenticated APIs. We recommend +// using the auto-login wrappers instead (see below). +// +// Authenticated APIs also define the WithToken method. This +// method takes as argument a token and returns a copy of the +// original API using the given token. We use this method +// to implement auto-login wrappers. +// +// For each API, we also define two interfaces: +// +// 1. the Caller interface represents the possibility of +// calling a specific API with the correct arguments; +// +// 2. the Cloner interface represents the possibility of +// calling WithToken on the given API. +// +// They abstract the interaction between the API type and +// its caching and auto-login wrappers. +// +// Caching +// +// If an API supports caching, we define a type whose name +// ends in Cache. The TorTargets API cache, for example, +// is TorTargetsCache. These caching types wrap the API type +// and provide the caching functionality. +// +// Because the cache needs to read from and write to the +// disk, a caching type needs a KVStore. A KVStore is +// an interface that allow you to bind a specific key to +// a given blob of bytes and to retrieve such bytes later. +// +// Caches use the gob data format from the Go standard +// library (`encoding/gob`). We abstract this dependency +// using the GobCodec interface. By default, when you +// do not specify a GobCodec we use the implementation +// of gob from the Go standard library. +// +// See the example describing caching for more information +// on how to use caching. +// +// Auto-login +// +// If an API supports auto-login, we define a type whose +// name ends with WithLogin. The TorTargets auto-login struct, +// for example, is called TorTargetsAPIWithLogin. +// +// Auto-login wrappers need to store persistent data. We +// use a KVStore for that (see above). We encode login data +// using JSON. To this end, we use a JSONCodec (also +// described above). +// +// See the example describing auto-login for more information +// on how to use auto-login. +// +// Design +// +// Most of the code in this package is auto-generated from the +// data model in ./apimodel and the definition of APIs provided +// by ./internal/generator/spec.go. +// +// We keep the generated files up-to-date by running +// +// go generate ./... +// +// We have tests that ensure that the definition of the API +// used here is reasonably close to the server's one. +// +// Testing +// +// The following command +// +// go test ./... +// +// will, among other things, ensure that the our API spec +// is consistent with the server's one. Running +// +// go test -short ./... +// +// will exclude most (slow) integration tests. +// +// Architecture +// +// The ./apimodel package contains the definition of request +// and response messages. We rely on tagging to specify how +// we should encode and decode messages. +// +// The ./internal/generator contains code to generate most +// code in this package. In particular, the spec.go file is +// the specification of the APIs. +// +// Notable generated files +// +// - apis.go: contains APIs (e.g., TorTargetsAPI); +// +// - caching.go: contains caching wrappers for every API +// that declares that it needs a cache (e.g., TorTargetsCache); +// +// - callers.go: contains Callers; +// +// - cloners.go: contains the Cloners; +// +// - login.go: contains auto-login wrappers (e.g., +// TorTargetsAPIWithLogin); +// +// - requests.go: contains code to generate http.Requests. +// +// - responses.go: code to parse http.Responses. +package ooapi diff --git a/internal/engine/ooapi/errors.go b/internal/engine/ooapi/errors.go new file mode 100644 index 0000000..35ed7cc --- /dev/null +++ b/internal/engine/ooapi/errors.go @@ -0,0 +1,14 @@ +package ooapi + +import "errors" + +// Errors defined by this package. In addition to these errors, this +// package may of course return any other stdlib specific error. +var ( + ErrEmptyField = errors.New("apiclient: empty field") + ErrHTTPFailure = errors.New("apiclient: http request failed") + ErrJSONLiteralNull = errors.New("apiclient: server returned us a literal null") + ErrMissingToken = errors.New("apiclient: missing auth token") + ErrUnauthorized = errors.New("apiclient: not authorized") + errCacheNotFound = errors.New("apiclient: not found in cache") +) diff --git a/internal/engine/ooapi/fake_test.go b/internal/engine/ooapi/fake_test.go new file mode 100644 index 0000000..9b0b708 --- /dev/null +++ b/internal/engine/ooapi/fake_test.go @@ -0,0 +1,96 @@ +package ooapi + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "time" +) + +type FakeCodec struct { + DecodeErr error + EncodeData []byte + EncodeErr error +} + +func (mc *FakeCodec) Encode(v interface{}) ([]byte, error) { + return mc.EncodeData, mc.EncodeErr +} + +func (mc *FakeCodec) Decode(b []byte, v interface{}) error { + return mc.DecodeErr +} + +type FakeHTTPClient struct { + Err error + Resp *http.Response +} + +func (c *FakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + time.Sleep(10 * time.Microsecond) + if req.Body != nil { + _, _ = ioutil.ReadAll(req.Body) + req.Body.Close() + } + if c.Err != nil { + return nil, c.Err + } + c.Resp.Request = req // non thread safe but it doesn't matter + return c.Resp, nil +} + +type FakeBody struct { + Data []byte + Err error +} + +func (fb *FakeBody) Read(p []byte) (int, error) { + time.Sleep(10 * time.Microsecond) + if fb.Err != nil { + return 0, fb.Err + } + if len(fb.Data) <= 0 { + return 0, io.EOF + } + n := copy(p, fb.Data) + fb.Data = fb.Data[n:] + return n, nil +} + +func (fb *FakeBody) Close() error { + return nil +} + +type FakeRequestMaker struct { + Req *http.Request + Err error +} + +func (frm *FakeRequestMaker) NewRequest( + ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) { + return frm.Req, frm.Err +} + +type FakeTemplateExecutor struct { + Out string + Err error +} + +func (fte *FakeTemplateExecutor) Execute(tmpl string, v interface{}) (string, error) { + return fte.Out, fte.Err +} + +type FakeKVStore struct { + SetError error + GetData []byte + GetError error +} + +func (fs *FakeKVStore) Get(key string) ([]byte, error) { + return fs.GetData, fs.GetError +} + +func (fs *FakeKVStore) Set(key string, value []byte) error { + return fs.SetError +} diff --git a/internal/engine/ooapi/fakeapi_test.go b/internal/engine/ooapi/fakeapi_test.go new file mode 100644 index 0000000..3084483 --- /dev/null +++ b/internal/engine/ooapi/fakeapi_test.go @@ -0,0 +1,190 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.357709034 +0100 CET m=+0.000208565 + +package ooapi + +//go:generate go run ./internal/generator -file fakeapi_test.go + +import ( + "context" + "sync/atomic" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +type FakeCheckReportIDAPI struct { + Err error + Response *apimodel.CheckReportIDResponse + CountCall int32 +} + +func (fapi *FakeCheckReportIDAPI) Call(ctx context.Context, req *apimodel.CheckReportIDRequest) (*apimodel.CheckReportIDResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ CheckReportIDCaller = &FakeCheckReportIDAPI{} +) + +type FakeCheckInAPI struct { + Err error + Response *apimodel.CheckInResponse + CountCall int32 +} + +func (fapi *FakeCheckInAPI) Call(ctx context.Context, req *apimodel.CheckInRequest) (*apimodel.CheckInResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ CheckInCaller = &FakeCheckInAPI{} +) + +type FakeLoginAPI struct { + Err error + Response *apimodel.LoginResponse + CountCall int32 +} + +func (fapi *FakeLoginAPI) Call(ctx context.Context, req *apimodel.LoginRequest) (*apimodel.LoginResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ LoginCaller = &FakeLoginAPI{} +) + +type FakeMeasurementMetaAPI struct { + Err error + Response *apimodel.MeasurementMetaResponse + CountCall int32 +} + +func (fapi *FakeMeasurementMetaAPI) Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ MeasurementMetaCaller = &FakeMeasurementMetaAPI{} +) + +type FakeRegisterAPI struct { + Err error + Response *apimodel.RegisterResponse + CountCall int32 +} + +func (fapi *FakeRegisterAPI) Call(ctx context.Context, req *apimodel.RegisterRequest) (*apimodel.RegisterResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ RegisterCaller = &FakeRegisterAPI{} +) + +type FakeTestHelpersAPI struct { + Err error + Response apimodel.TestHelpersResponse + CountCall int32 +} + +func (fapi *FakeTestHelpersAPI) Call(ctx context.Context, req *apimodel.TestHelpersRequest) (apimodel.TestHelpersResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ TestHelpersCaller = &FakeTestHelpersAPI{} +) + +type FakePsiphonConfigAPI struct { + WithResult PsiphonConfigCaller + Err error + Response apimodel.PsiphonConfigResponse + CountCall int32 +} + +func (fapi *FakePsiphonConfigAPI) Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +func (fapi *FakePsiphonConfigAPI) WithToken(token string) PsiphonConfigCaller { + return fapi.WithResult +} + +var ( + _ PsiphonConfigCaller = &FakePsiphonConfigAPI{} + _ PsiphonConfigCloner = &FakePsiphonConfigAPI{} +) + +type FakeTorTargetsAPI struct { + WithResult TorTargetsCaller + Err error + Response apimodel.TorTargetsResponse + CountCall int32 +} + +func (fapi *FakeTorTargetsAPI) Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +func (fapi *FakeTorTargetsAPI) WithToken(token string) TorTargetsCaller { + return fapi.WithResult +} + +var ( + _ TorTargetsCaller = &FakeTorTargetsAPI{} + _ TorTargetsCloner = &FakeTorTargetsAPI{} +) + +type FakeURLsAPI struct { + Err error + Response *apimodel.URLsResponse + CountCall int32 +} + +func (fapi *FakeURLsAPI) Call(ctx context.Context, req *apimodel.URLsRequest) (*apimodel.URLsResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ URLsCaller = &FakeURLsAPI{} +) + +type FakeOpenReportAPI struct { + Err error + Response *apimodel.OpenReportResponse + CountCall int32 +} + +func (fapi *FakeOpenReportAPI) Call(ctx context.Context, req *apimodel.OpenReportRequest) (*apimodel.OpenReportResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ OpenReportCaller = &FakeOpenReportAPI{} +) + +type FakeSubmitMeasurementAPI struct { + Err error + Response *apimodel.SubmitMeasurementResponse + CountCall int32 +} + +func (fapi *FakeSubmitMeasurementAPI) Call(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*apimodel.SubmitMeasurementResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ SubmitMeasurementCaller = &FakeSubmitMeasurementAPI{} +) diff --git a/internal/engine/ooapi/fakefill_test.go b/internal/engine/ooapi/fakefill_test.go new file mode 100644 index 0000000..5ee7776 --- /dev/null +++ b/internal/engine/ooapi/fakefill_test.go @@ -0,0 +1,146 @@ +package ooapi + +import ( + "math/rand" + "reflect" + "sync" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// fakeFill fills specific data structures with random data. The only +// exception to this behaviour is time.Time, which is instead filled +// with the current time plus a small random number of seconds. +// +// We use this implementation to initialize data in our model. The code +// has been written with that in mind. It will require some hammering in +// case we extend the model with new field types. +type fakeFill struct { + mu sync.Mutex + now func() time.Time + rnd *rand.Rand +} + +func (ff *fakeFill) getRandLocked() *rand.Rand { + if ff.rnd == nil { + now := time.Now + if ff.now != nil { + now = ff.now + } + ff.rnd = rand.New(rand.NewSource(now().UnixNano())) + } + return ff.rnd +} + +func (ff *fakeFill) getRandomString() string { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + n := rnd.Intn(63) + 1 + // See https://stackoverflow.com/a/31832326 + var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rnd.Intn(len(letterRunes))] + } + return string(b) +} + +func (ff *fakeFill) getRandomInt64() int64 { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + return rnd.Int63() +} + +func (ff *fakeFill) getRandomBool() bool { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + return rnd.Float64() >= 0.5 +} + +func (ff *fakeFill) getRandomSmallPositiveInt() int { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + return int(rnd.Int63n(8)) + 1 // safe cast +} + +func (ff *fakeFill) doFill(v reflect.Value) { + for v.Type().Kind() == reflect.Ptr { + if v.IsNil() { + // if the pointer is nil, allocate an element + v.Set(reflect.New(v.Type().Elem())) + } + // switch to the element + v = v.Elem() + } + switch v.Type().Kind() { + case reflect.String: + v.SetString(ff.getRandomString()) + case reflect.Int64: + v.SetInt(ff.getRandomInt64()) + case reflect.Bool: + v.SetBool(ff.getRandomBool()) + case reflect.Struct: + if v.Type().String() == "time.Time" { + // Implementation note: we treat the time specially + // and we avoid attempting to set its fields. + v.Set(reflect.ValueOf(time.Now().Add( + time.Duration(ff.getRandomSmallPositiveInt()) * time.Second))) + return + } + for idx := 0; idx < v.NumField(); idx++ { + ff.doFill(v.Field(idx)) // visit all fields + } + case reflect.Slice: + kind := v.Type().Elem() + total := ff.getRandomSmallPositiveInt() + for idx := 0; idx < total; idx++ { + value := reflect.New(kind) // make a new element + ff.doFill(value) + v.Set(reflect.Append(v, value.Elem())) // append to slice + } + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + return // not supported + } + v.Set(reflect.MakeMap(v.Type())) // we need to init the map + total := ff.getRandomSmallPositiveInt() + kind := v.Type().Elem() + for idx := 0; idx < total; idx++ { + value := reflect.New(kind) + ff.doFill(value) + v.SetMapIndex(reflect.ValueOf(ff.getRandomString()), value.Elem()) + } + } +} + +// fill fills in with random data. +func (ff *fakeFill) fill(in interface{}) { + ff.doFill(reflect.ValueOf(in)) +} + +func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) { + var req *apimodel.URLsRequest + ff := &fakeFill{} + ff.fill(&req) + if req == nil { + t.Fatal("we expected non nil here") + } +} + +func TestFakeFillAllocatesIntoAMapLike(t *testing.T) { + var resp apimodel.TorTargetsResponse + ff := &fakeFill{} + ff.fill(&resp) + if resp == nil { + t.Fatal("we expected non nil here") + } + if len(resp) < 1 { + t.Fatal("we expected some data here") + } +} diff --git a/internal/engine/ooapi/integration_test.go b/internal/engine/ooapi/integration_test.go new file mode 100644 index 0000000..43465be --- /dev/null +++ b/internal/engine/ooapi/integration_test.go @@ -0,0 +1,204 @@ +package ooapi + +import ( + "context" + "net/http" + "testing" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +type VerboseHTTPClient struct { + t *testing.T +} + +func (c *VerboseHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.t.Logf("> %s %s", req.Method, req.URL.String()) + resp, err := http.DefaultClient.Do(req) + if err != nil { + c.t.Logf("< %s", err.Error()) + return nil, err + } + c.t.Logf("< %d", resp.StatusCode) + return resp, nil +} + +func TestWithRealServerDoCheckIn(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.CheckInRequest{ + Charging: true, + OnWiFi: true, + Platform: "android", + ProbeASN: "AS12353", + ProbeCC: "IT", + RunType: "timed", + SoftwareName: "ooniprobe-android", + SoftwareVersion: "2.7.1", + WebConnectivity: apimodel.CheckInRequestWebConnectivity{ + CategoryCodes: []string{"NEWS", "CULTR"}, + }, + } + httpClnt := &VerboseHTTPClient{t: t} + api := &CheckInAPI{ + HTTPClient: httpClnt, + } + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + for idx, url := range resp.Tests.WebConnectivity.URLs { + if idx >= 3 { + break + } + t.Logf("- %+v", url) + } +} + +func TestWithRealServerDoCheckReportID(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.CheckReportIDRequest{ + ReportID: "20210223T093606Z_ndt_JO_8376_n1_kDYToqrugDY54Soy", + } + api := &CheckReportIDAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} + +func TestWithRealServerDoMeasurementMeta(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.MeasurementMetaRequest{ + ReportID: "20210223T093606Z_ndt_JO_8376_n1_kDYToqrugDY54Soy", + } + api := &MeasurementMetaAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} + +func TestWithRealServerDoOpenReport(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.OpenReportRequest{ + DataFormatVersion: "0.2.0", + Format: "json", + ProbeASN: "AS137", + ProbeCC: "IT", + SoftwareName: "miniooni", + SoftwareVersion: "0.1.0-dev", + TestName: "example", + TestStartTime: "2018-11-01 15:33:20", + TestVersion: "0.1.0", + } + api := &OpenReportAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} + +func TestWithRealServerDoPsiphonConfig(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.PsiphonConfigRequest{} + httpClnt := &VerboseHTTPClient{t: t} + api := &PsiphonConfigAPIWithLogin{ + API: &PsiphonConfigAPI{ + HTTPClient: httpClnt, + }, + KVStore: &memkvstore{}, + RegisterAPI: &RegisterAPI{ + HTTPClient: httpClnt, + }, + LoginAPI: &LoginAPI{ + HTTPClient: httpClnt, + }, + } + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp != nil) +} + +func TestWithRealServerDoTorTargets(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.TorTargetsRequest{} + httpClnt := &VerboseHTTPClient{t: t} + api := &TorTargetsAPIWithLogin{ + API: &TorTargetsAPI{ + HTTPClient: httpClnt, + }, + KVStore: &memkvstore{}, + RegisterAPI: &RegisterAPI{ + HTTPClient: httpClnt, + }, + LoginAPI: &LoginAPI{ + HTTPClient: httpClnt, + }, + } + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp != nil) +} + +func TestWithRealServerDoURLs(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.URLsRequest{ + CountryCode: "IT", + Limit: 3, + } + api := &URLsAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} diff --git a/internal/engine/ooapi/internal/generator/apis.go b/internal/engine/ooapi/internal/generator/apis.go new file mode 100644 index 0000000..e69b1f7 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/apis.go @@ -0,0 +1,180 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +// apiField contains the fields of an API data structure +type apiField struct { + // name is the field name + name string + + // kind is the filed type + kind string + + // comment is a brief comment to document the field + comment string + + // ifLogin indicates whether this field should only be + // emitted when the API requires login + ifLogin bool + + // ifTemplate indicates whether this field should only be + // emitted when the URL path is a template + ifTemplate bool + + // noClone is true when this field should not be copied + // from the parent data structure when cloning + noClone bool +} + +var apiFields = []apiField{{ + name: "BaseURL", + kind: "string", + comment: "optional", +}, { + name: "HTTPClient", + kind: "HTTPClient", + comment: "optional", +}, { + name: "JSONCodec", + kind: "JSONCodec", + comment: "optional", +}, { + name: "Token", + kind: "string", + comment: "mandatory", + ifLogin: true, + noClone: true, +}, { + name: "RequestMaker", + kind: "RequestMaker", + comment: "optional", +}, { + name: "TemplateExecutor", + kind: "TemplateExecutor", + comment: "optional", + ifTemplate: true, +}, { + name: "UserAgent", + kind: "string", + comment: "optional", +}} + +func (d *Descriptor) genNewAPI(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s implements the %s API.\n", d.APIStructName(), d.Name) + fmt.Fprintf(sb, "type %s struct {\n", d.APIStructName()) + for _, f := range apiFields { + if !d.RequiresLogin && f.ifLogin { + continue + } + if !d.URLPath.IsTemplate && f.ifTemplate { + continue + } + fmt.Fprintf(sb, "\t%s %s // %s\n", f.name, f.kind, f.comment) + } + fmt.Fprint(sb, "}\n\n") + + if d.RequiresLogin { + fmt.Fprintf(sb, "// WithToken returns a copy of the API where the\n") + fmt.Fprintf(sb, "// value of the Token field is replaced with token.\n") + fmt.Fprintf(sb, "func (api *%s) WithToken(token string) %s {\n", + d.APIStructName(), d.CallerInterfaceName()) + fmt.Fprintf(sb, "out := &%s{}\n", d.APIStructName()) + for _, f := range apiFields { + if !d.URLPath.IsTemplate && f.ifTemplate { + continue + } + if f.noClone == true { + continue + } + fmt.Fprintf(sb, "out.%s = api.%s\n", f.name, f.name) + } + fmt.Fprint(sb, "out.Token = token\n") + fmt.Fprint(sb, "return out\n") + fmt.Fprint(sb, "}\n\n") + } + + fmt.Fprintf(sb, "func (api *%s) baseURL() string {\n", d.APIStructName()) + fmt.Fprint(sb, "\tif api.BaseURL != \"\" {\n") + fmt.Fprint(sb, "\t\treturn api.BaseURL\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn \"https://ps1.ooni.io\"\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) requestMaker() RequestMaker {\n", d.APIStructName()) + fmt.Fprint(sb, "\tif api.RequestMaker != nil {\n") + fmt.Fprint(sb, "\t\treturn api.RequestMaker\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultRequestMaker{}\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) jsonCodec() JSONCodec {\n", d.APIStructName()) + fmt.Fprint(sb, "\tif api.JSONCodec != nil {\n") + fmt.Fprint(sb, "\t\treturn api.JSONCodec\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultJSONCodec{}\n") + fmt.Fprint(sb, "}\n\n") + + if d.URLPath.IsTemplate { + fmt.Fprintf( + sb, "func (api *%s) templateExecutor() TemplateExecutor {\n", + d.APIStructName()) + fmt.Fprint(sb, "\tif api.TemplateExecutor != nil {\n") + fmt.Fprint(sb, "\t\treturn api.TemplateExecutor\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultTemplateExecutor{}\n") + fmt.Fprint(sb, "}\n\n") + } + + fmt.Fprintf( + sb, "func (api *%s) httpClient() HTTPClient {\n", + d.APIStructName()) + fmt.Fprint(sb, "\tif api.HTTPClient != nil {\n") + fmt.Fprint(sb, "\t\treturn api.HTTPClient\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn http.DefaultClient\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "// Call calls the %s API.\n", d.Name) + fmt.Fprintf( + sb, "func (api *%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.APIStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\thttpReq, err := api.newRequest(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\thttpReq.Header.Add(\"Accept\", \"application/json\")\n") + if d.RequiresLogin { + fmt.Fprint(sb, "\tif api.Token == \"\" {\n") + fmt.Fprint(sb, "\t\treturn nil, ErrMissingToken\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\thttpReq.Header.Add(\"Authorization\", newAuthorizationHeader(api.Token))\n") + } + fmt.Fprint(sb, "\tif api.UserAgent != \"\" {\n") + fmt.Fprint(sb, "\t\thttpReq.Header.Add(\"User-Agent\", api.UserAgent)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.newResponse(api.httpClient().Do(httpReq))\n") + fmt.Fprint(sb, "}\n\n") +} + +// GenAPIsGo generates apis.go. +func GenAPIsGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genNewAPI(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/apistest.go b/internal/engine/ooapi/internal/generator/apistest.go new file mode 100644 index 0000000..31cacbe --- /dev/null +++ b/internal/engine/ooapi/internal/generator/apistest.go @@ -0,0 +1,461 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +func (d *Descriptor) genTestNewRequest(sb *strings.Builder) { + fmt.Fprintf(sb, "\treq := &%s{}\n", d.RequestTypeNameAsStruct()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\tff.fill(req)\n") +} + +func (d *Descriptor) genTestInvalidURL(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sInvalidURL(t *testing.T) {\n", d.Name) + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tBaseURL: \"\\t\", // invalid\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err == nil || !strings.HasSuffix(err.Error(), \"invalid control character in URL\") {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithMissingToken(sb *strings.Builder) { + if d.RequiresLogin == false { + return // does not make sense when login isn't required + } + fmt.Fprintf(sb, "func Test%sWithMissingToken(t *testing.T) {\n", d.Name) + fmt.Fprintf(sb, "\tapi := &%s{} // no token\n", d.APIStructName()) + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrMissingToken) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithHTTPErr(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithHTTPErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Err: errMocked}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestMarshalErr(sb *strings.Builder) { + if d.Method != "POST" { + return // does not make sense when we don't send a request body + } + fmt.Fprintf(sb, "func Test%sMarshalErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tJSONCodec: &FakeCodec{EncodeErr: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithNewRequestErr(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithNewRequestErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tRequestMaker: &FakeRequestMaker{Err: errMocked},\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWith401(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWith401(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrUnauthorized) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWith400(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWith400(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrHTTPFailure) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithResponseBodyReadErr(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithResponseBodyReadErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 200,\n") + fmt.Fprint(sb, "\t\tBody: &FakeBody{Err: errMocked},\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithUnmarshalFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithUnmarshalFailure(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 200,\n") + fmt.Fprint(sb, "\t\tBody: &FakeBody{Data: []byte(`{}`)},\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + fmt.Fprintf(sb, "\t\tJSONCodec: &FakeCodec{DecodeErr: errMocked},\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestRoundTrip(sb *strings.Builder) { + // generate the type of the handler + fmt.Fprintf(sb, "type handle%s struct {\n", d.Name) + fmt.Fprint(sb, "\taccept string\n") + fmt.Fprint(sb, "\tbody []byte\n") + fmt.Fprint(sb, "\tcontentType string\n") + fmt.Fprint(sb, "\tcount int32\n") + fmt.Fprint(sb, "\tmethod string\n") + fmt.Fprint(sb, "\tmu sync.Mutex\n") + fmt.Fprintf(sb, "\tresp %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\turl *url.URL\n") + fmt.Fprint(sb, "\tuserAgent string\n") + fmt.Fprint(sb, "}\n\n") + + // generate the handling function + fmt.Fprintf(sb, + "func (h *handle%s) ServeHTTP(w http.ResponseWriter, r *http.Request) {", + d.Name) + fmt.Fprint(sb, "\tdefer h.mu.Unlock()\n") + fmt.Fprint(sb, "\th.mu.Lock()\n") + fmt.Fprint(sb, "\tif h.count > 0 {\n") + fmt.Fprint(sb, "\t\tw.WriteHeader(400)\n") + fmt.Fprint(sb, "\t\treturn\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\th.count++\n") + fmt.Fprint(sb, "\tif r.Body != nil {\n") + fmt.Fprint(sb, "\t\tdata, err := ioutil.ReadAll(r.Body)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprintf(sb, "\t\t\tw.WriteHeader(400)\n") + fmt.Fprintf(sb, "\t\t\treturn\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\th.body = data\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\th.method = r.Method\n") + fmt.Fprint(sb, "\th.url = r.URL\n") + fmt.Fprint(sb, "\th.accept = r.Header.Get(\"Accept\")\n") + fmt.Fprint(sb, "\th.contentType = r.Header.Get(\"Content-Type\")\n") + fmt.Fprint(sb, "\th.userAgent = r.Header.Get(\"User-Agent\")\n") + fmt.Fprintf(sb, "\tvar out %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff := fakeFill{}\n") + fmt.Fprint(sb, "\tff.fill(&out)\n") + fmt.Fprintf(sb, "\th.resp = out\n") + fmt.Fprintf(sb, "\tdata, err := json.Marshal(out)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tw.WriteHeader(400)\n") + fmt.Fprintf(sb, "\t\treturn\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tw.Write(data)\n") + fmt.Fprintf(sb, "\t}\n\n") + + // generate the test itself + fmt.Fprintf(sb, "func Test%sRoundTrip(t *testing.T) {\n", d.Name) + + fmt.Fprint(sb, "\t// setup\n") + fmt.Fprintf(sb, "\thandler := &handle%s{}\n", d.Name) + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + fmt.Fprintf(sb, "\treq := &%s{}\n", d.RequestTypeNameAsStruct()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprintf(sb, "\tapi := &%s{BaseURL: srvr.URL}\n", d.APIStructName()) + fmt.Fprint(sb, "\tff.fill(&api.UserAgent)\n") + if d.RequiresLogin { + fmt.Fprint(sb, "\tff.fill(&api.Token)\n") + } + + fmt.Fprint(sb, "\t// issue request\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response here\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// compare our response and server's one\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(handler.resp, resp); diff != \"\" {") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// check whether headers are OK\n") + fmt.Fprint(sb, "\tif handler.accept != \"application/json\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid accept header\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif handler.userAgent != api.UserAgent {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid user-agent header\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// check whether the method is OK\n") + fmt.Fprintf(sb, "\tif handler.method != \"%s\" {\n", d.Method) + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid method\")\n") + fmt.Fprint(sb, "\t}\n") + + if d.Method == "POST" { + fmt.Fprint(sb, "\t// check the body\n") + fmt.Fprint(sb, "\tif handler.contentType != \"application/json\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid content-type header\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tgot := &%s{}\n", d.RequestTypeNameAsStruct()) + fmt.Fprintf(sb, "\tif err := json.Unmarshal(handler.body, &got); err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(req, got); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + } else { + fmt.Fprint(sb, "\t// check the query\n") + fmt.Fprint(sb, "\thttpReq, err := api.newRequest(context.Background(), req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + } + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestResponseLiteralNull(sb *strings.Builder) { + switch d.ResponseTypeKind() { + case reflect.Map: + // fallthrough + case reflect.Struct: + return // test not applicable + } + fmt.Fprintf(sb, "func Test%sResponseLiteralNull(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 200,\n") + fmt.Fprint(sb, "\t\tBody: &FakeBody{Data: []byte(`null`)},\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrJSONLiteralNull) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestMandatoryFields(sb *strings.Builder) { + fields := d.StructFieldsWithTag(d.Request, tagForRequired) + if len(fields) < 1 { + return // nothing to test + } + fmt.Fprintf(sb, "func Test%sMandatoryFields(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 500,\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprintf(sb, "\treq := &%s{} // deliberately empty\n", d.RequestTypeNameAsStruct()) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrEmptyField) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestTemplateErr(sb *strings.Builder) { + if !d.URLPath.IsTemplate { + return // nothing to test + } + fmt.Fprintf(sb, "func Test%sTemplateErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 500,\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t\tTemplateExecutor: &FakeTemplateExecutor{Err: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +// TODO(bassosimone): we should add a panic for every switch for +// the type of a request or a response for robustness. + +func (d *Descriptor) genAPITests(sb *strings.Builder) { + d.genTestInvalidURL(sb) + d.genTestWithMissingToken(sb) + d.genTestWithHTTPErr(sb) + d.genTestMarshalErr(sb) + d.genTestWithNewRequestErr(sb) + d.genTestWith401(sb) + d.genTestWith400(sb) + d.genTestWithResponseBodyReadErr(sb) + d.genTestWithUnmarshalFailure(sb) + d.genTestRoundTrip(sb) + d.genTestResponseLiteralNull(sb) + d.genTestMandatoryFields(sb) + d.genTestTemplateErr(sb) +} + +// GenAPIsTestGo generates apis_test.go. +func GenAPIsTestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"encoding/json\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\t\"io/ioutil\"\n") + fmt.Fprint(&sb, "\t\"net/http/httptest\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\t\"net/url\"\n") + fmt.Fprint(&sb, "\t\"strings\"\n") + fmt.Fprint(&sb, "\t\"testing\"\n") + fmt.Fprint(&sb, "\t\"sync\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/google/go-cmp/cmp\"\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genAPITests(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/caching.go b/internal/engine/ooapi/internal/generator/caching.go new file mode 100644 index 0000000..ca1dd6a --- /dev/null +++ b/internal/engine/ooapi/internal/generator/caching.go @@ -0,0 +1,130 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewCache(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s implements caching for %s.\n", + d.CacheStructName(), d.APIStructName()) + fmt.Fprintf(sb, "type %s struct {\n", d.CacheStructName()) + fmt.Fprintf(sb, "\tAPI %s // mandatory\n", d.CallerInterfaceName()) + fmt.Fprint(sb, "\tGobCodec GobCodec // optional\n") + fmt.Fprint(sb, "\tKVStore KVStore // mandatory\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "type %s struct {\n", d.CacheEntryName()) + fmt.Fprintf(sb, "\tReq %s\n", d.RequestTypeName()) + fmt.Fprintf(sb, "\tResp %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "// Call calls the API and implements caching.\n") + fmt.Fprintf(sb, "func (c *%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.CacheStructName(), d.RequestTypeName(), d.ResponseTypeName()) + if d.CachePolicy == CacheAlways { + fmt.Fprint(sb, "\tif resp, _ := c.readcache(req); resp != nil {\n") + fmt.Fprint(sb, "\t\treturn resp, nil\n") + fmt.Fprint(sb, "\t}\n") + } + fmt.Fprint(sb, "\tresp, err := c.API.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + if d.CachePolicy == CacheFallback { + fmt.Fprint(sb, "\t\tif resp, _ := c.readcache(req); resp != nil {\n") + fmt.Fprint(sb, "\t\t\treturn resp, nil\n") + fmt.Fprint(sb, "\t\t}\n") + } + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif err := c.writecache(req, resp); err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn resp, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) gobCodec() GobCodec {\n", d.CacheStructName()) + fmt.Fprint(sb, "\tif c.GobCodec != nil {\n") + fmt.Fprint(sb, "\t\treturn c.GobCodec\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultGobCodec{}\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) getcache() ([]%s, error) {\n", + d.CacheStructName(), d.CacheEntryName()) + fmt.Fprintf(sb, "\tdata, err := c.KVStore.Get(\"%s\")\n", d.CacheKey()) + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar out []%s\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tif err := c.gobCodec().Decode(data, &out); err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn out, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) setcache(in []%s) error {\n", + d.CacheStructName(), d.CacheEntryName()) + fmt.Fprint(sb, "\tdata, err := c.gobCodec().Encode(in)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\treturn c.KVStore.Set(\"%s\", data)\n", d.CacheKey()) + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) readcache(req %s) (%s, error) {\n", + d.CacheStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\tcache, err := c.getcache()\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tfor _, cur := range cache {\n") + fmt.Fprint(sb, "\t\tif reflect.DeepEqual(req, cur.Req) {\n") + fmt.Fprint(sb, "\t\t\treturn cur.Resp, nil\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn nil, errCacheNotFound\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) writecache(req %s, resp %s) error {\n", + d.CacheStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\tcache, _ := c.getcache()\n") + fmt.Fprintf(sb, "\tout := []%s{{Req: req, Resp: resp}}\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tconst toomany = 64\n") + fmt.Fprint(sb, "\tfor idx, cur := range cache {\n") + fmt.Fprint(sb, "\t\tif reflect.DeepEqual(req, cur.Req) {\n") + fmt.Fprint(sb, "\t\t\tcontinue // we already updated the cache\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif idx > toomany {\n") + fmt.Fprint(sb, "\t\t\tbreak\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tout = append(out, cur)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn c.setcache(out)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "var _ %s = &%s{}\n\n", d.CallerInterfaceName(), + d.CacheStructName()) +} + +// GenCachingGo generates caching.go. +func GenCachingGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"reflect\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if desc.CachePolicy == CacheNone { + continue + } + desc.genNewCache(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/cachingtest.go b/internal/engine/ooapi/internal/generator/cachingtest.go new file mode 100644 index 0000000..a556885 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/cachingtest.go @@ -0,0 +1,274 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genTestCacheSuccess(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sSuccess(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWriteCacheError(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sWriteCacheError(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tKVStore: &FakeKVStore{SetError: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestFailureWithNoCache(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sFailureWithNoCache(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestFailureWithPreviousCache(sb *strings.Builder) { + // This works for both caching policies. + fmt.Fprintf(sb, "func TestCache%sFailureWithPreviousCache(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + fmt.Fprintf(sb, "\tfakeapi := &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tAPI: fakeapi,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\t// first pass with no error at all\n") + fmt.Fprint(sb, "\t// use a separate scope to be sure we avoid mistakes\n") + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t// second pass with failure\n") + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tfakeapi.Err = errMocked\n") + fmt.Fprint(sb, "\tfakeapi.Response = nil\n") + fmt.Fprint(sb, "\tresp2, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp2 == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp2); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestSetcacheWithEncodeError(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sSetcacheWithEncodeError(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tvar in []%s\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tff.fill(&in)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tGobCodec: &FakeCodec{EncodeErr: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\terr := cache.setcache(in)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestReadCacheNotFound(sb *strings.Builder) { + if fields := d.StructFields(d.Request); len(fields) <= 0 { + // this test cannot work when there are no fields in the + // request because we will always find a match. + // TODO(bassosimone): how to avoid having uncovered code? + return + } + fmt.Fprintf(sb, "func TestCache%sReadCacheNotFound(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar incache []%s\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tff.fill(&incache)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\terr := cache.setcache(incache)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprintf(sb, "\tout, err := cache.readcache(req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errCacheNotFound) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif out != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil here\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWriteCacheDuplicate(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sWriteCacheDuplicate(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprintf(sb, "\tvar resp1 %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&resp1)\n") + fmt.Fprintf(sb, "\tvar resp2 %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&resp2)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\terr := cache.writecache(req, resp1)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\terr = cache.writecache(req, resp2)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tout, err := cache.readcache(req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif out == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil here\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(resp2, out); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestCachSizeLimited(sb *strings.Builder) { + if fields := d.StructFields(d.Request); len(fields) <= 0 { + // this test cannot work when there are no fields in the + // request because we will always find a match. + // TODO(bassosimone): how to avoid having uncovered code? + return + } + fmt.Fprintf(sb, "func TestCache%sCacheSizeLimited(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar prev int\n") + fmt.Fprintf(sb, "\tfor {\n") + fmt.Fprintf(sb, "\t\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\t\tff.fill(&req)\n") + fmt.Fprintf(sb, "\t\tvar resp %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\t\tff.fill(&resp)\n") + fmt.Fprintf(sb, "\t\terr := cache.writecache(req, resp)\n") + fmt.Fprintf(sb, "\t\tif err != nil {\n") + fmt.Fprintf(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t\t}\n") + fmt.Fprintf(sb, "\t\tout, err := cache.getcache()\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif len(out) > prev {\n") + fmt.Fprint(sb, "\t\t\tprev = len(out)\n") + fmt.Fprint(sb, "\t\t\tcontinue\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tbreak\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +// GenCachingTestGo generates caching_test.go. +func GenCachingTestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\t\"testing\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/google/go-cmp/cmp\"\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if desc.CachePolicy == CacheNone { + continue + } + desc.genTestCacheSuccess(&sb) + desc.genTestWriteCacheError(&sb) + desc.genTestFailureWithNoCache(&sb) + desc.genTestFailureWithPreviousCache(&sb) + desc.genTestSetcacheWithEncodeError(&sb) + desc.genTestReadCacheNotFound(&sb) + desc.genTestWriteCacheDuplicate(&sb) + desc.genTestCachSizeLimited(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/callers.go b/internal/engine/ooapi/internal/generator/callers.go new file mode 100644 index 0000000..4ba8ea8 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/callers.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewCaller(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s represents any type exposing a method\n", + d.CallerInterfaceName()) + fmt.Fprintf(sb, "// like %s.Call.\n", d.APIStructName()) + fmt.Fprintf(sb, "type %s interface {\n", d.CallerInterfaceName()) + fmt.Fprintf(sb, "\tCall(ctx context.Context, req %s) (%s, error)\n", + d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "}\n\n") +} + +// GenCallersGo generates callers.go. +func GenCallersGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genNewCaller(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/cloners.go b/internal/engine/ooapi/internal/generator/cloners.go new file mode 100644 index 0000000..1db6d92 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/cloners.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewCloner(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s represents any type exposing a method\n", + d.CallerInterfaceName()) + fmt.Fprintf(sb, "// like %s.WithToken.\n", d.APIStructName()) + fmt.Fprintf(sb, "type %s interface {\n", d.ClonerInterfaceName()) + fmt.Fprintf(sb, "\tWithToken(token string) %s\n", d.CallerInterfaceName()) + fmt.Fprint(sb, "}\n\n") +} + +// GenClonersGo generates cloners.go. +func GenClonersGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + for _, desc := range Descriptors { + if !desc.RequiresLogin { + continue + } + desc.genNewCloner(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/fakeapitest.go b/internal/engine/ooapi/internal/generator/fakeapitest.go new file mode 100644 index 0000000..deb2b56 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/fakeapitest.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewFakeAPI(sb *strings.Builder) { + fmt.Fprintf(sb, "type Fake%s struct {\n", d.APIStructName()) + if d.RequiresLogin { + fmt.Fprintf(sb, "\tWithResult %s\n", d.CallerInterfaceName()) + } + fmt.Fprint(sb, "\tErr error\n") + fmt.Fprintf(sb, "\tResponse %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tCountCall int32\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (fapi *Fake%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.APIStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\tatomic.AddInt32(&fapi.CountCall, 1)\n") + fmt.Fprint(sb, "\treturn fapi.Response, fapi.Err\n") + fmt.Fprint(sb, "}\n\n") + + if d.RequiresLogin { + fmt.Fprintf(sb, "func (fapi *Fake%s) WithToken(token string) %s {\n", + d.APIStructName(), d.CallerInterfaceName()) + fmt.Fprint(sb, "\treturn fapi.WithResult\n") + fmt.Fprint(sb, "}\n\n") + } + + fmt.Fprint(sb, "var (\n") + fmt.Fprintf(sb, "\t_ %s = &Fake%s{}\n", d.CallerInterfaceName(), + d.APIStructName()) + if d.RequiresLogin { + fmt.Fprintf(sb, "\t_ %s = &Fake%s{}\n", d.ClonerInterfaceName(), + d.APIStructName()) + } + fmt.Fprint(sb, ")\n\n") +} + +// GenFakeAPITestGo generates fakeapi_test.go. +func GenFakeAPITestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"sync/atomic\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genNewFakeAPI(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/generator.go b/internal/engine/ooapi/internal/generator/generator.go new file mode 100644 index 0000000..6158767 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/generator.go @@ -0,0 +1,53 @@ +// Command generator generates code in the ooapi package. +// +// To this end, it uses the content of the apimodel package as +// well as the content of the spec.go file. +// +// The apimodel package defines the model, i.e., the structure +// of requests and responses and how messages should be sent +// and received. +// +// The spec.go file describes all the implemented APIs. +// +// If you change apimodel or spec.go, remember to run the +// `go generate ./...` command to regenerate all files. +package main + +import ( + "flag" + "fmt" +) + +var flagFile = flag.String("file", "", "Indicate which file to regenerate") + +func main() { + flag.Parse() + switch file := *flagFile; file { + case "apis.go": + GenAPIsGo(file) + case "responses.go": + GenResponsesGo(file) + case "requests.go": + GenRequestsGo(file) + case "swagger_test.go": + GenSwaggerTestGo(file) + case "apis_test.go": + GenAPIsTestGo(file) + case "callers.go": + GenCallersGo(file) + case "caching.go": + GenCachingGo(file) + case "login.go": + GenLoginGo(file) + case "cloners.go": + GenClonersGo(file) + case "fakeapi_test.go": + GenFakeAPITestGo(file) + case "caching_test.go": + GenCachingTestGo(file) + case "login_test.go": + GenLoginTestGo(file) + default: + panic(fmt.Sprintf("don't know how to create this file: %s", file)) + } +} diff --git a/internal/engine/ooapi/internal/generator/login.go b/internal/engine/ooapi/internal/generator/login.go new file mode 100644 index 0000000..4328d2e --- /dev/null +++ b/internal/engine/ooapi/internal/generator/login.go @@ -0,0 +1,182 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewLogin(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s implements login for %s.\n", + d.WithLoginAPIStructName(), d.APIStructName()) + fmt.Fprintf(sb, "type %s struct {\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI %s // mandatory\n", d.ClonerInterfaceName()) + fmt.Fprint(sb, "\tJSONCodec JSONCodec // optional\n") + fmt.Fprint(sb, "\tKVStore KVStore // mandatory\n") + fmt.Fprint(sb, "\tRegisterAPI RegisterCaller // mandatory\n") + fmt.Fprint(sb, "\tLoginAPI LoginCaller // mandatory\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "// Call logins, if needed, then calls the API.\n") + fmt.Fprintf(sb, "func (api *%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.WithLoginAPIStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\ttoken, err := api.maybeLogin(ctx)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tresp, err := api.API.WithToken(token).Call(ctx, req)\n") + fmt.Fprint(sb, "\tif errors.Is(err, ErrUnauthorized) {\n") + fmt.Fprint(sb, "\t\t// Maybe the clock is just off? Let's try to obtain\n") + fmt.Fprint(sb, "\t\t// a token again and see if this fixes it.\n") + fmt.Fprint(sb, "\t\tif token, err = api.forceLogin(ctx); err == nil {\n") + fmt.Fprint(sb, "\t\t\tswitch resp, err = api.API.WithToken(token).Call(ctx, req); err {\n") + fmt.Fprint(sb, "\t\t\tcase nil:\n") + fmt.Fprint(sb, "\t\t\t\treturn resp, nil\n") + fmt.Fprint(sb, "\t\t\tcase ErrUnauthorized:\n") + fmt.Fprint(sb, "\t\t\t\t// fallthrough\n") + fmt.Fprint(sb, "\t\t\tdefault:\n") + fmt.Fprint(sb, "\t\t\t\treturn nil, err\n") + fmt.Fprint(sb, "\t\t\t}\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\t// Okay, this seems a broader problem. How about we try\n") + fmt.Fprint(sb, "\t\t// and re-register ourselves again instead?\n") + fmt.Fprint(sb, "\t\ttoken, err = api.forceRegister(ctx)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\treturn nil, err\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tresp, err = api.API.WithToken(token).Call(ctx, req)\n") + fmt.Fprint(sb, "\t\t// fallthrough\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn resp, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) jsonCodec() JSONCodec {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tif api.JSONCodec != nil {\n") + fmt.Fprint(sb, "\t\treturn api.JSONCodec\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultJSONCodec{}\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) readstate() (*loginState, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tdata, err := api.KVStore.Get(loginKey)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tvar ls loginState\n") + fmt.Fprint(sb, "\tif err := api.jsonCodec().Decode(data, &ls); err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &ls, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) writestate(ls *loginState) error {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tdata, err := api.jsonCodec().Encode(*ls)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.KVStore.Set(loginKey, data)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) doRegister(ctx context.Context, password string) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\treq := newRegisterRequest(password)\n") + fmt.Fprint(sb, "\tls := &loginState{}\n") + fmt.Fprint(sb, "\tresp, err := api.RegisterAPI.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tls.ClientID = resp.ClientID\n") + fmt.Fprint(sb, "\tls.Password = req.Password\n") + fmt.Fprint(sb, "\treturn api.doLogin(ctx, ls)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) forceRegister(ctx context.Context) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tvar password string\n") + fmt.Fprint(sb, "\t// If we already have a previous password, let us keep\n") + fmt.Fprint(sb, "\t// using it. This will allow a new version of the API to\n") + fmt.Fprint(sb, "\t// be able to continue to identify this probe. (This\n") + fmt.Fprint(sb, "\t// assumes that we have a stateless API that generates\n") + fmt.Fprint(sb, "\t// the user ID as a signature of the password plus a\n") + fmt.Fprint(sb, "\t// timestamp and that the key to generate the signature\n") + fmt.Fprint(sb, "\t// is not lost. If all these conditions are met, we\n") + fmt.Fprint(sb, "\t// can then serve better test targets to more long running\n") + fmt.Fprint(sb, "\t// (and therefore trusted) probes.)\n") + fmt.Fprint(sb, "\tif ls, err := api.readstate(); err == nil {\n") + fmt.Fprint(sb, "\t\tpassword = ls.Password\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif password == \"\" {\n") + fmt.Fprint(sb, "\t\tpassword = newRandomPassword()\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.doRegister(ctx, password)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) forceLogin(ctx context.Context) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tls, err := api.readstate()\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.doLogin(ctx, ls)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) maybeLogin(ctx context.Context) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tls, _ := api.readstate()\n") + fmt.Fprint(sb, "\tif ls == nil || !ls.credentialsValid() {\n") + fmt.Fprint(sb, "\t\treturn api.forceRegister(ctx)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif !ls.tokenValid() {\n") + fmt.Fprint(sb, "\t\treturn api.doLogin(ctx, ls)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn ls.Token, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) doLogin(ctx context.Context, ls *loginState) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\treq := &apimodel.LoginRequest{\n") + fmt.Fprint(sb, "\t\tClientID: ls.ClientID,\n") + fmt.Fprint(sb, "\t\tPassword: ls.Password,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tresp, err := api.LoginAPI.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tls.Token = resp.Token\n") + fmt.Fprint(sb, "\tls.Expire = resp.Expire\n") + fmt.Fprint(sb, "\tif err := api.writestate(ls); err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn ls.Token, nil\n") + fmt.Fprint(sb, "}\n\n") + fmt.Fprintf(sb, "var _ %s = &%s{}\n\n", d.CallerInterfaceName(), + d.WithLoginAPIStructName()) +} + +// GenLoginGo generates login.go. +func GenLoginGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if !desc.RequiresLogin { + continue + } + desc.genNewLogin(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/logintest.go b/internal/engine/ooapi/internal/generator/logintest.go new file mode 100644 index 0000000..d74c9f5 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/logintest.go @@ -0,0 +1,864 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genTestRegisterAndLoginSuccess(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestRegisterAndLogin%sSuccess(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestContinueUsingToken(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sContinueUsingToken(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we disable register and login but we\n") + fmt.Fprint(sb, "\t// should be okay because of the token\n") + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tregisterAPI.Err = errMocked\n") + fmt.Fprint(sb, "\tregisterAPI.Response = nil\n") + fmt.Fprint(sb, "\tloginAPI.Err = errMocked\n") + fmt.Fprint(sb, "\tloginAPI.Response = nil\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithValidButExpiredToken(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithValidButExpiredToken(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tls := &loginState{\n") + fmt.Fprintf(sb, "\t\tClientID: \"antani-antani\",\n") + fmt.Fprintf(sb, "\t\tExpire: time.Now().Add(-5 * time.Second),\n") + fmt.Fprintf(sb, "\t\tToken: \"antani-antani-token\",\n") + fmt.Fprintf(sb, "\t\tPassword: \"antani-antani-password\",\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tif err := login.writestate(ls); err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 0 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithRegisterAPIError(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithRegisterAPIError(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithLoginFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithLoginFailure(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestRegisterAndLoginThenFail(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestRegisterAndLogin%sThenFail(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestTheDatabaseIsReplaced(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sTheDatabaseIsReplaced(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget accounts and try again.\n") + fmt.Fprint(sb, "\thandler.forgetLogins()\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 3 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestTheDatabaseIsReplacedThenFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sTheDatabaseIsReplacedThenFailure(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget accounts and try again.\n") + fmt.Fprint(sb, "\t// but registrations are also failing.\n") + fmt.Fprint(sb, "\thandler.forgetLogins()\n") + fmt.Fprint(sb, "\thandler.noRegister = true\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrHTTPFailure) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestRegisterAndLoginCannotWriteState(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestRegisterAndLogin%sCannotWriteState(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t\tJSONCodec: &FakeCodec{\n") + fmt.Fprint(sb, "\t\t\tEncodeErr: errMocked,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestReadStateDecodeFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sReadStateDecodeFailure(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t\tJSONCodec: &FakeCodec{DecodeErr: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tls := &loginState{\n") + fmt.Fprintf(sb, "\t\tClientID: \"antani-antani\",\n") + fmt.Fprintf(sb, "\t\tExpire: time.Now().Add(-5 * time.Second),\n") + fmt.Fprintf(sb, "\t\tToken: \"antani-antani-token\",\n") + fmt.Fprintf(sb, "\t\tPassword: \"antani-antani-password\",\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tif err := login.writestate(ls); err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + + fmt.Fprintf(sb, "\tout, err := login.forceLogin(context.Background())\n") + fmt.Fprintf(sb, "if !errors.Is(err, errMocked) {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "if out != \"\" {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(\"expected empty string here\")\n") + fmt.Fprintf(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestClockIsOffThenSuccess(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sClockIsOffThenSuccess(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget tokens and try again.\n") + fmt.Fprint(sb, "\t// this should simulate the client clock\n") + fmt.Fprint(sb, "\t// being off and considering a token still valid\n") + fmt.Fprint(sb, "\thandler.forgetTokens()\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestClockIsOffThen401(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sClockIsOffThen401(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget tokens and try again.\n") + fmt.Fprint(sb, "\t// this should simulate the client clock\n") + fmt.Fprint(sb, "\t// being off and considering a token still valid\n") + fmt.Fprint(sb, "\thandler.forgetTokens()\n") + fmt.Fprint(sb, "\thandler.failCallWith = []int{401, 401}\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 3 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestClockIsOffThen500(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sClockIsOffThen500(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget tokens and try again.\n") + fmt.Fprint(sb, "\t// this should simulate the client clock\n") + fmt.Fprint(sb, "\t// being off and considering a token still valid\n") + fmt.Fprint(sb, "\thandler.forgetTokens()\n") + fmt.Fprint(sb, "\thandler.failCallWith = []int{401, 500}\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrHTTPFailure) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +// GenLoginTestGo generates login_test.go. +func GenLoginTestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\t\"net/http/httptest\"\n") + fmt.Fprint(&sb, "\t\"testing\"\n") + fmt.Fprint(&sb, "\t\"time\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/google/go-cmp/cmp\"\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if !desc.RequiresLogin { + continue + } + desc.genTestRegisterAndLoginSuccess(&sb) + desc.genTestContinueUsingToken(&sb) + desc.genTestWithValidButExpiredToken(&sb) + desc.genTestWithRegisterAPIError(&sb) + desc.genTestWithLoginFailure(&sb) + desc.genTestRegisterAndLoginThenFail(&sb) + desc.genTestTheDatabaseIsReplaced(&sb) + desc.genTestRegisterAndLoginCannotWriteState(&sb) + desc.genTestReadStateDecodeFailure(&sb) + desc.genTestTheDatabaseIsReplacedThenFailure(&sb) + desc.genTestClockIsOffThenSuccess(&sb) + desc.genTestClockIsOffThen401(&sb) + desc.genTestClockIsOffThen500(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/reflect.go b/internal/engine/ooapi/internal/generator/reflect.go new file mode 100644 index 0000000..9b2f27b --- /dev/null +++ b/internal/engine/ooapi/internal/generator/reflect.go @@ -0,0 +1,147 @@ +package main + +import ( + "fmt" + "reflect" +) + +// TypeName returns v's package-qualified type name. +func (d *Descriptor) TypeName(v interface{}) string { + return reflect.TypeOf(v).String() +} + +// RequestTypeName calls d.TypeName(d.Request). +func (d *Descriptor) RequestTypeName() string { + return d.TypeName(d.Request) +} + +// ResponseTypeName calls d.TypeName(d.Response). +func (d *Descriptor) ResponseTypeName() string { + return d.TypeName(d.Response) +} + +// APIStructName returns the correct struct type name +// for the API we're currently processing. +func (d *Descriptor) APIStructName() string { + return fmt.Sprintf("%sAPI", d.Name) +} + +// WithLoginAPIStructName returns the correct struct type name +// for the WithLoginAPI we're currently processing. +func (d *Descriptor) WithLoginAPIStructName() string { + return fmt.Sprintf("%sAPIWithLogin", d.Name) +} + +// CallerInterfaceName returns the correct caller interface name +// for the API we're currently processing. +func (d *Descriptor) CallerInterfaceName() string { + return fmt.Sprintf("%sCaller", d.Name) +} + +// ClonerInterfaceName returns the correct cloner interface name +// for the API we're currently processing. +func (d *Descriptor) ClonerInterfaceName() string { + return fmt.Sprintf("%sCloner", d.Name) +} + +// CacheStructName returns the correct struct type name for +// the cache for the API we're currently processing. +func (d *Descriptor) CacheStructName() string { + return fmt.Sprintf("%sCache", d.Name) +} + +// CacheEntryName returns the correct struct type name for the +// cache entry for the API we're currently processing. +func (d *Descriptor) CacheEntryName() string { + return fmt.Sprintf("cacheEntryFor%s", d.Name) +} + +// CacheKey returns the correct cache key for the API +// we're currently processing. +func (d *Descriptor) CacheKey() string { + return fmt.Sprintf("%s.cache", d.Name) +} + +// StructFields returns all the struct fields of in. This function +// assumes that in is a pointer to struct, and will otherwise panic. +func (d *Descriptor) StructFields(in interface{}) []*reflect.StructField { + t := reflect.TypeOf(in) + if t.Kind() != reflect.Ptr { + panic("not a pointer") + } + t = t.Elem() + if t.Kind() != reflect.Struct { + panic("not a struct") + } + var out []*reflect.StructField + for idx := 0; idx < t.NumField(); idx++ { + f := t.Field(idx) + out = append(out, &f) + } + return out +} + +// StructFieldsWithTag returns all the struct fields of +// in that have the specified tag. +func (d *Descriptor) StructFieldsWithTag(in interface{}, tag string) []*reflect.StructField { + var out []*reflect.StructField + for _, f := range d.StructFields(in) { + if f.Tag.Get(tag) != "" { + out = append(out, f) + } + } + return out +} + +// RequestOrResponseTypeKind returns the type kind of in, which should +// be a request or a response. This function assumes that in is either a +// pointer to struct or a map and will panic otherwise. +func (d *Descriptor) RequestOrResponseTypeKind(in interface{}) reflect.Kind { + t := reflect.TypeOf(in) + if t.Kind() == reflect.Ptr { + t = t.Elem() + if t.Kind() != reflect.Struct { + panic("not a struct") + } + return reflect.Struct + } + if t.Kind() != reflect.Map { + panic("not a map") + } + return reflect.Map +} + +// RequestTypeKind calls d.RequestOrResponseTypeKind(d.Request). +func (d *Descriptor) RequestTypeKind() reflect.Kind { + return d.RequestOrResponseTypeKind(d.Request) +} + +// ResponseTypeKind calls d.RequestOrResponseTypeKind(d.Response). +func (d *Descriptor) ResponseTypeKind() reflect.Kind { + return d.RequestOrResponseTypeKind(d.Response) +} + +// TypeNameAsStruct assumes that in is a pointer to struct and +// returns the type of the corresponding struct. The returned +// type is package qualified. +func (d *Descriptor) TypeNameAsStruct(in interface{}) string { + t := reflect.TypeOf(in) + if t.Kind() != reflect.Ptr { + panic("not a pointer") + } + t = t.Elem() + if t.Kind() != reflect.Struct { + panic("not a struct") + } + return t.String() +} + +// RequestTypeNameAsStruct calls d.TypeNameAsStruct(d.Request) +func (d *Descriptor) RequestTypeNameAsStruct() string { + return d.TypeNameAsStruct(d.Request) +} + +// ResponseTypeNameAsStruct calls d.TypeNameAsStruct(d.Response) +func (d *Descriptor) ResponseTypeNameAsStruct() string { + return d.TypeNameAsStruct(d.Response) +} diff --git a/internal/engine/ooapi/internal/generator/requests.go b/internal/engine/ooapi/internal/generator/requests.go new file mode 100644 index 0000000..1e02fa2 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/requests.go @@ -0,0 +1,141 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +const ( + tagForQuery = "query" + tagForRequired = "required" +) + +func (d *Descriptor) genNewRequestQueryElemString(sb *strings.Builder, f *reflect.StructField) { + name := f.Name + query := f.Tag.Get(tagForQuery) + if f.Tag.Get(tagForRequired) == "true" { + fmt.Fprintf(sb, "\tif req.%s == \"\" {\n", name) + fmt.Fprintf(sb, "\t\treturn nil, newErrEmptyField(\"%s\")\n", name) + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tq.Add(\"%s\", req.%s)\n", query, name) + return + } + fmt.Fprintf(sb, "\tif req.%s != \"\" {\n", name) + fmt.Fprintf(sb, "\t\tq.Add(\"%s\", req.%s)\n", query, name) + fmt.Fprint(sb, "\t}\n") +} + +func (d *Descriptor) genNewRequestQueryElemBool(sb *strings.Builder, f *reflect.StructField) { + // required does not make much sense for a boolean field + name := f.Name + query := f.Tag.Get(tagForQuery) + fmt.Fprintf(sb, "\tif req.%s {\n", name) + fmt.Fprintf(sb, "\t\tq.Add(\"%s\", \"true\")\n", query) + fmt.Fprint(sb, "\t}\n") +} + +func (d *Descriptor) genNewRequestQueryElemInt64(sb *strings.Builder, f *reflect.StructField) { + // required does not make much sense for an integer field + name := f.Name + query := f.Tag.Get(tagForQuery) + fmt.Fprintf(sb, "\tif req.%s != 0 {\n", name) + fmt.Fprintf(sb, "\t\tq.Add(\"%s\", newQueryFieldInt64(req.%s))\n", query, name) + fmt.Fprint(sb, "\t}\n") +} + +func (d *Descriptor) genNewRequestQuery(sb *strings.Builder) { + if d.Method != "GET" { + return // we only generate query for GET + } + fields := d.StructFieldsWithTag(d.Request, tagForQuery) + if len(fields) <= 0 { + return + } + fmt.Fprint(sb, "\tq := url.Values{}\n") + for idx, f := range fields { + switch f.Type.Kind() { + case reflect.String: + d.genNewRequestQueryElemString(sb, f) + case reflect.Bool: + d.genNewRequestQueryElemBool(sb, f) + case reflect.Int64: + d.genNewRequestQueryElemInt64(sb, f) + default: + panic(fmt.Sprintf("unexpected query type at index %d", idx)) + } + } + fmt.Fprint(sb, "\tURL.RawQuery = q.Encode()\n") +} + +func (d *Descriptor) genNewRequestCallNewRequest(sb *strings.Builder) { + if d.Method == "POST" { + fmt.Fprint(sb, "\tbody, err := api.jsonCodec().Encode(req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tout, err := api.requestMaker().NewRequest(") + fmt.Fprintf(sb, "ctx, \"%s\", URL.String(), ", d.Method) + fmt.Fprint(sb, "bytes.NewReader(body))\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tout.Header.Set(\"Content-Type\", \"application/json\")\n") + fmt.Fprint(sb, "\treturn out, nil\n") + return + } + fmt.Fprint(sb, "\treturn api.requestMaker().NewRequest(") + fmt.Fprintf(sb, "ctx, \"%s\", URL.String(), ", d.Method) + fmt.Fprint(sb, "nil)\n") +} + +func (d *Descriptor) genNewRequest(sb *strings.Builder) { + + fmt.Fprintf( + sb, "func (api *%s) newRequest(ctx context.Context, req %s) %s {\n", + d.APIStructName(), d.RequestTypeName(), "(*http.Request, error)") + fmt.Fprint(sb, "\tURL, err := url.Parse(api.baseURL())\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + + switch d.URLPath.IsTemplate { + case false: + fmt.Fprintf(sb, "\tURL.Path = \"%s\"\n", d.URLPath.Value) + case true: + fmt.Fprintf( + sb, "\tup, err := api.templateExecutor().Execute(\"%s\", req)\n", + d.URLPath.Value) + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tURL.Path = up\n") + } + + d.genNewRequestQuery(sb) + d.genNewRequestCallNewRequest(sb) + + fmt.Fprintf(sb, "}\n\n") +} + +// GenRequestsGo generates requests.go. +func GenRequestsGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"bytes\"\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\t\"net/url\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n\n") + for _, desc := range Descriptors { + desc.genNewRequest(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/responses.go b/internal/engine/ooapi/internal/generator/responses.go new file mode 100644 index 0000000..18ce9e2 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/responses.go @@ -0,0 +1,80 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +func (d *Descriptor) genNewResponse(sb *strings.Builder) { + fmt.Fprintf(sb, + "func (api *%s) newResponse(resp *http.Response, err error) (%s, error) {\n", + d.APIStructName(), d.ResponseTypeName()) + + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp.StatusCode == 401 {\n") + fmt.Fprint(sb, "\t\treturn nil, ErrUnauthorized\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp.StatusCode != 200 {\n") + fmt.Fprint(sb, "\t\treturn nil, newHTTPFailure(resp.StatusCode)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tdefer resp.Body.Close()\n") + fmt.Fprint(sb, "\treader := io.LimitReader(resp.Body, 4<<20)\n") + fmt.Fprint(sb, "\tdata, err := ioutil.ReadAll(reader)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + + switch d.ResponseTypeKind() { + case reflect.Map: + fmt.Fprintf(sb, "\tout := %s{}\n", d.ResponseTypeName()) + case reflect.Struct: + fmt.Fprintf(sb, "\tout := &%s{}\n", d.ResponseTypeNameAsStruct()) + } + + switch d.ResponseTypeKind() { + case reflect.Map: + fmt.Fprint(sb, "\tif err := api.jsonCodec().Decode(data, &out); err != nil {\n") + case reflect.Struct: + fmt.Fprint(sb, "\tif err := api.jsonCodec().Decode(data, out); err != nil {\n") + } + + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + + switch d.ResponseTypeKind() { + case reflect.Map: + // For rationale, see https://play.golang.org/p/m9-MsTaQ5wt and + // https://play.golang.org/p/6h-v-PShMk9. + fmt.Fprint(sb, "\tif out == nil {\n") + fmt.Fprint(sb, "\t\treturn nil, ErrJSONLiteralNull\n") + fmt.Fprint(sb, "\t}\n") + case reflect.Struct: + // nothing + } + fmt.Fprintf(sb, "\treturn out, nil\n") + fmt.Fprintf(sb, "}\n\n") +} + +// GenResponsesGo generates responses.go. +func GenResponsesGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"io\"\n") + fmt.Fprint(&sb, "\t\"io/ioutil\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n\n") + for _, desc := range Descriptors { + desc.genNewResponse(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/spec.go b/internal/engine/ooapi/internal/generator/spec.go new file mode 100644 index 0000000..01bd0c8 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/spec.go @@ -0,0 +1,136 @@ +package main + +import "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" + +// URLPath describes a URLPath. +type URLPath struct { + // IsTemplate indicates whether Value contains a template. A future + // version of this implementation will automatically deduce that. + IsTemplate bool + + // Value is the value of the URL path. + Value string + + // InSwagger indicates the corresponding name to be used in + // the Swagger specification. + InSwagger string +} + +// Descriptor is an API descriptor. It tells the generator +// what code it should emit for a given API. +type Descriptor struct { + // Name is the name of the API. + Name string + + // CachePolicy indicates the caching policy to use. + CachePolicy int + + // RequiresLogin indicates whether the API requires login. + RequiresLogin bool + + // Method is the method to use ("GET" or "POST"). + Method string + + // URLPath is the URL path. + URLPath URLPath + + // Request is an instance of the request type. + Request interface{} + + // Response is an instance of the response type. + Response interface{} +} + +// These are the caching policies. +const ( + // CacheNone indicates we don't use a cache. + CacheNone = iota + + // CacheFallback indicates we fallback to the cache + // when there is a failure. + CacheFallback + + // CacheAlways indicates that we always check the + // cache before sending a request. + CacheAlways +) + +// Descriptors describes all the APIs. +// +// Note that it matters whether the requests and responses +// are pointers. Generally speaking, if the message is a +// struct, use a pointer. If it's a map, don't. +var Descriptors = []Descriptor{{ + Name: "CheckReportID", + Method: "GET", + URLPath: URLPath{Value: "/api/_/check_report_id"}, + Request: &apimodel.CheckReportIDRequest{}, + Response: &apimodel.CheckReportIDResponse{}, +}, { + Name: "CheckIn", + Method: "POST", + URLPath: URLPath{Value: "/api/v1/check-in"}, + Request: &apimodel.CheckInRequest{}, + Response: &apimodel.CheckInResponse{}, +}, { + Name: "Login", + Method: "POST", + URLPath: URLPath{Value: "/api/v1/login"}, + Request: &apimodel.LoginRequest{}, + Response: &apimodel.LoginResponse{}, +}, { + Name: "MeasurementMeta", + Method: "GET", + URLPath: URLPath{Value: "/api/v1/measurement_meta"}, + Request: &apimodel.MeasurementMetaRequest{}, + Response: &apimodel.MeasurementMetaResponse{}, + CachePolicy: CacheAlways, +}, { + Name: "Register", + Method: "POST", + URLPath: URLPath{Value: "/api/v1/register"}, + Request: &apimodel.RegisterRequest{}, + Response: &apimodel.RegisterResponse{}, +}, { + Name: "TestHelpers", + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-helpers"}, + Request: &apimodel.TestHelpersRequest{}, + Response: apimodel.TestHelpersResponse{}, +}, { + Name: "PsiphonConfig", + RequiresLogin: true, + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-list/psiphon-config"}, + Request: &apimodel.PsiphonConfigRequest{}, + Response: apimodel.PsiphonConfigResponse{}, +}, { + Name: "TorTargets", + RequiresLogin: true, + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-list/tor-targets"}, + Request: &apimodel.TorTargetsRequest{}, + Response: apimodel.TorTargetsResponse{}, +}, { + Name: "URLs", + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-list/urls"}, + Request: &apimodel.URLsRequest{}, + Response: &apimodel.URLsResponse{}, +}, { + Name: "OpenReport", + Method: "POST", + URLPath: URLPath{Value: "/report"}, + Request: &apimodel.OpenReportRequest{}, + Response: &apimodel.OpenReportResponse{}, +}, { + Name: "SubmitMeasurement", + Method: "POST", + URLPath: URLPath{ + InSwagger: "/report/{report_id}", + IsTemplate: true, + Value: "/report/{{ .ReportID }}", + }, + Request: &apimodel.SubmitMeasurementRequest{}, + Response: &apimodel.SubmitMeasurementResponse{}, +}} diff --git a/internal/engine/ooapi/internal/generator/swaggertest.go b/internal/engine/ooapi/internal/generator/swaggertest.go new file mode 100644 index 0000000..3fadc38 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/swaggertest.go @@ -0,0 +1,194 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "reflect" + "strings" + "sync" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/internal/openapi" +) + +const ( + tagForJSON = "json" + tagForPath = "path" +) + +func (d *Descriptor) genSwaggerURLPath() string { + up := d.URLPath + if up.InSwagger != "" { + return up.InSwagger + } + if up.IsTemplate { + panic("we should always use InSwapper and IsTemplate together") + } + return up.Value +} + +func (d *Descriptor) genSwaggerSchema(cur reflect.Type) *openapi.Schema { + switch cur.Kind() { + case reflect.String: + return &openapi.Schema{Type: "string"} + case reflect.Bool: + return &openapi.Schema{Type: "boolean"} + case reflect.Int64: + return &openapi.Schema{Type: "integer"} + case reflect.Slice: + return &openapi.Schema{Type: "array", Items: d.genSwaggerSchema(cur.Elem())} + case reflect.Map: + return &openapi.Schema{Type: "object"} + case reflect.Ptr: + return d.genSwaggerSchema(cur.Elem()) + case reflect.Struct: + if cur.String() == "time.Time" { + // Implementation note: we don't want to dive into time.Time but + // rather we want to pretend it's a string. The JSON parser for + // time.Time can indeed reconstruct a time.Time from a string, and + // it's much easier for us to let it do the parsing. + return &openapi.Schema{Type: "string"} + } + sinfo := &openapi.Schema{Type: "object"} + var once sync.Once + initmap := func() { + sinfo.Properties = make(map[string]*openapi.Schema) + } + for idx := 0; idx < cur.NumField(); idx++ { + field := cur.Field(idx) + if field.Tag.Get(tagForPath) != "" { + continue // skipping because this is a path param + } + if field.Tag.Get(tagForQuery) != "" { + continue // skipping because this is a query param + } + v := field.Name + if j := field.Tag.Get(tagForJSON); j != "" { + j = strings.Replace(j, ",omitempty", "", 1) // remove options + if j == "-" { + continue // not exported via JSON + } + v = j + } + once.Do(initmap) + sinfo.Properties[v] = d.genSwaggerSchema(field.Type) + } + return sinfo + case reflect.Interface: + return &openapi.Schema{Type: "object"} + default: + panic("unsupported type") + } +} + +func (d *Descriptor) swaggerParamForType(t reflect.Type) string { + switch t.Kind() { + case reflect.String: + return "string" + case reflect.Bool: + return "boolean" + case reflect.Int64: + return "integer" + default: + panic("unsupported type") + } +} + +func (d *Descriptor) genSwaggerParams(cur reflect.Type) []*openapi.Parameter { + // when we have params the input must be a pointer to struct + if cur.Kind() != reflect.Ptr { + panic("not a pointer") + } + cur = cur.Elem() + if cur.Kind() != reflect.Struct { + panic("not a pointer to struct") + } + // now that we're sure of the type, inspect the fields + var out []*openapi.Parameter + for idx := 0; idx < cur.NumField(); idx++ { + f := cur.Field(idx) + if q := f.Tag.Get(tagForQuery); q != "" { + out = append( + out, &openapi.Parameter{ + Name: q, + In: "query", + Required: f.Tag.Get(tagForRequired) == "true", + Type: d.swaggerParamForType(f.Type), + }) + continue + } + if p := f.Tag.Get(tagForPath); p != "" { + out = append(out, &openapi.Parameter{ + Name: p, + In: "path", + Required: true, + Type: d.swaggerParamForType(f.Type), + }) + continue + } + } + return out +} + +func (d *Descriptor) genSwaggerPath() (string, *openapi.Path) { + pathStr, pathInfo := d.genSwaggerURLPath(), &openapi.Path{} + rtinfo := &openapi.RoundTrip{Produces: []string{"application/json"}} + switch d.Method { + case "GET": + pathInfo.Get = rtinfo + case "POST": + rtinfo.Consumes = append(rtinfo.Consumes, "application/json") + pathInfo.Post = rtinfo + default: + panic("unsupported method") + } + rtinfo.Parameters = d.genSwaggerParams(reflect.TypeOf(d.Request)) + if d.Method != "GET" { + rtinfo.Parameters = append(rtinfo.Parameters, &openapi.Parameter{ + Name: "body", + In: "body", + Required: true, + Schema: d.genSwaggerSchema(reflect.TypeOf(d.Request)), + }) + } + rtinfo.Responses = &openapi.Responses{Successful: openapi.Body{ + Description: "all good", + Schema: d.genSwaggerSchema(reflect.TypeOf(d.Response)), + }} + return pathStr, pathInfo +} + +func genSwaggerVersion() string { + return time.Now().UTC().Format("0.20060102.1150405") +} + +// GenSwaggerTestGo generates swagger_test.go +func GenSwaggerTestGo(file string) { + swagger := openapi.Swagger{ + Swagger: "2.0", + Info: openapi.API{ + Title: "OONI API specification", + Version: genSwaggerVersion(), + }, + Host: "api.ooni.io", + BasePath: "/", + Schemes: []string{"https"}, + Paths: make(map[string]*openapi.Path), + } + for _, desc := range Descriptors { + pathStr, pathInfo := desc.genSwaggerPath() + swagger.Paths[pathStr] = pathInfo + } + data, err := json.MarshalIndent(swagger, "", " ") + if err != nil { + log.Fatal(err) + } + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprintf(&sb, "const swagger = `%s`\n", string(data)) + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/writefile.go b/internal/engine/ooapi/internal/generator/writefile.go new file mode 100644 index 0000000..48bbf2a --- /dev/null +++ b/internal/engine/ooapi/internal/generator/writefile.go @@ -0,0 +1,27 @@ +package main + +import ( + "fmt" + "log" + "os" + "strings" + + "golang.org/x/sys/execabs" +) + +func writefile(name string, sb *strings.Builder) { + filep, err := os.Create(name) + if err != nil { + log.Fatal(err) + } + if _, err := fmt.Fprint(filep, sb.String()); err != nil { + log.Fatal(err) + } + if err := filep.Close(); err != nil { + log.Fatal(err) + } + cmd := execabs.Command("go", "fmt", name) + if err := cmd.Run(); err != nil { + log.Fatal(err) + } +} diff --git a/internal/engine/ooapi/internal/openapi/openapi.go b/internal/engine/ooapi/internal/openapi/openapi.go new file mode 100644 index 0000000..d20d4bb --- /dev/null +++ b/internal/engine/ooapi/internal/openapi/openapi.go @@ -0,0 +1,64 @@ +// Package openapi contains data structures for Swagger v2.0. +// +// We use these data structures to compare the API specification we +// have here with the one of the server. +package openapi + +// Schema is the schema of a specific parameter or +// or the schema used by the response body +type Schema struct { + Properties map[string]*Schema `json:"properties,omitempty"` + Items *Schema `json:"items,omitempty"` + Type string `json:"type"` +} + +// Parameter describes an input parameter, which could be in the +// URL path, in the query string, or in the request body +type Parameter struct { + In string `json:"in"` + Name string `json:"name"` + Required bool `json:"required,omitempty"` + Schema *Schema `json:"schema,omitempty"` + Type string `json:"type,omitempty"` +} + +// Body describes a response body +type Body struct { + Description interface{} `json:"description,omitempty"` + Schema *Schema `json:"schema"` +} + +// Responses describes the possible responses +type Responses struct { + Successful Body `json:"200"` +} + +// RoundTrip describes an HTTP round trip with a given method and path +type RoundTrip struct { + Consumes []string `json:"consumes,omitempty"` + Produces []string `json:"produces,omitempty"` + Parameters []*Parameter `json:"parameters,omitempty"` + Responses *Responses `json:"responses,omitempty"` +} + +// Path describes a path served by the API +type Path struct { + Get *RoundTrip `json:"get,omitempty"` + Post *RoundTrip `json:"post,omitempty"` +} + +// API contains info about the API +type API struct { + Title string `json:"title"` + Version string `json:"version"` +} + +// Swagger is the toplevel structure +type Swagger struct { + Swagger string `json:"swagger"` + Info API `json:"info"` + Host string `json:"host"` + BasePath string `json:"basePath"` + Schemes []string `json:"schemes"` + Paths map[string]*Path `json:"paths"` +} diff --git a/internal/engine/ooapi/kvstore_test.go b/internal/engine/ooapi/kvstore_test.go new file mode 100644 index 0000000..31c7e71 --- /dev/null +++ b/internal/engine/ooapi/kvstore_test.go @@ -0,0 +1,34 @@ +package ooapi + +import ( + "errors" + "fmt" + "sync" +) + +var errMemkvstoreNotFound = errors.New("memkvstore: not found") + +type memkvstore struct { + m map[string][]byte + mu sync.Mutex +} + +func (kvs *memkvstore) Get(key string) ([]byte, error) { + defer kvs.mu.Unlock() + kvs.mu.Lock() + out, good := kvs.m[key] + if !good { + return nil, fmt.Errorf("%w: %s", errMemkvstoreNotFound, key) + } + return out, nil +} + +func (kvs *memkvstore) Set(key string, value []byte) error { + defer kvs.mu.Unlock() + kvs.mu.Lock() + if kvs.m == nil { + kvs.m = make(map[string][]byte) + } + kvs.m[key] = value + return nil +} diff --git a/internal/engine/ooapi/login.go b/internal/engine/ooapi/login.go new file mode 100644 index 0000000..49f9833 --- /dev/null +++ b/internal/engine/ooapi/login.go @@ -0,0 +1,295 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.62521737 +0100 CET m=+0.000161706 + +package ooapi + +//go:generate go run ./internal/generator -file login.go + +import ( + "context" + "errors" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// PsiphonConfigAPIWithLogin implements login for PsiphonConfigAPI. +type PsiphonConfigAPIWithLogin struct { + API PsiphonConfigCloner // mandatory + JSONCodec JSONCodec // optional + KVStore KVStore // mandatory + RegisterAPI RegisterCaller // mandatory + LoginAPI LoginCaller // mandatory +} + +// Call logins, if needed, then calls the API. +func (api *PsiphonConfigAPIWithLogin) Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) { + token, err := api.maybeLogin(ctx) + if err != nil { + return nil, err + } + resp, err := api.API.WithToken(token).Call(ctx, req) + if errors.Is(err, ErrUnauthorized) { + // Maybe the clock is just off? Let's try to obtain + // a token again and see if this fixes it. + if token, err = api.forceLogin(ctx); err == nil { + switch resp, err = api.API.WithToken(token).Call(ctx, req); err { + case nil: + return resp, nil + case ErrUnauthorized: + // fallthrough + default: + return nil, err + } + } + // Okay, this seems a broader problem. How about we try + // and re-register ourselves again instead? + token, err = api.forceRegister(ctx) + if err != nil { + return nil, err + } + resp, err = api.API.WithToken(token).Call(ctx, req) + // fallthrough + } + if err != nil { + return nil, err + } + return resp, nil +} + +func (api *PsiphonConfigAPIWithLogin) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *PsiphonConfigAPIWithLogin) readstate() (*loginState, error) { + data, err := api.KVStore.Get(loginKey) + if err != nil { + return nil, err + } + var ls loginState + if err := api.jsonCodec().Decode(data, &ls); err != nil { + return nil, err + } + return &ls, nil +} + +func (api *PsiphonConfigAPIWithLogin) writestate(ls *loginState) error { + data, err := api.jsonCodec().Encode(*ls) + if err != nil { + return err + } + return api.KVStore.Set(loginKey, data) +} + +func (api *PsiphonConfigAPIWithLogin) doRegister(ctx context.Context, password string) (string, error) { + req := newRegisterRequest(password) + ls := &loginState{} + resp, err := api.RegisterAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.ClientID = resp.ClientID + ls.Password = req.Password + return api.doLogin(ctx, ls) +} + +func (api *PsiphonConfigAPIWithLogin) forceRegister(ctx context.Context) (string, error) { + var password string + // If we already have a previous password, let us keep + // using it. This will allow a new version of the API to + // be able to continue to identify this probe. (This + // assumes that we have a stateless API that generates + // the user ID as a signature of the password plus a + // timestamp and that the key to generate the signature + // is not lost. If all these conditions are met, we + // can then serve better test targets to more long running + // (and therefore trusted) probes.) + if ls, err := api.readstate(); err == nil { + password = ls.Password + } + if password == "" { + password = newRandomPassword() + } + return api.doRegister(ctx, password) +} + +func (api *PsiphonConfigAPIWithLogin) forceLogin(ctx context.Context) (string, error) { + ls, err := api.readstate() + if err != nil { + return "", err + } + return api.doLogin(ctx, ls) +} + +func (api *PsiphonConfigAPIWithLogin) maybeLogin(ctx context.Context) (string, error) { + ls, _ := api.readstate() + if ls == nil || !ls.credentialsValid() { + return api.forceRegister(ctx) + } + if !ls.tokenValid() { + return api.doLogin(ctx, ls) + } + return ls.Token, nil +} + +func (api *PsiphonConfigAPIWithLogin) doLogin(ctx context.Context, ls *loginState) (string, error) { + req := &apimodel.LoginRequest{ + ClientID: ls.ClientID, + Password: ls.Password, + } + resp, err := api.LoginAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.Token = resp.Token + ls.Expire = resp.Expire + if err := api.writestate(ls); err != nil { + return "", err + } + return ls.Token, nil +} + +var _ PsiphonConfigCaller = &PsiphonConfigAPIWithLogin{} + +// TorTargetsAPIWithLogin implements login for TorTargetsAPI. +type TorTargetsAPIWithLogin struct { + API TorTargetsCloner // mandatory + JSONCodec JSONCodec // optional + KVStore KVStore // mandatory + RegisterAPI RegisterCaller // mandatory + LoginAPI LoginCaller // mandatory +} + +// Call logins, if needed, then calls the API. +func (api *TorTargetsAPIWithLogin) Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) { + token, err := api.maybeLogin(ctx) + if err != nil { + return nil, err + } + resp, err := api.API.WithToken(token).Call(ctx, req) + if errors.Is(err, ErrUnauthorized) { + // Maybe the clock is just off? Let's try to obtain + // a token again and see if this fixes it. + if token, err = api.forceLogin(ctx); err == nil { + switch resp, err = api.API.WithToken(token).Call(ctx, req); err { + case nil: + return resp, nil + case ErrUnauthorized: + // fallthrough + default: + return nil, err + } + } + // Okay, this seems a broader problem. How about we try + // and re-register ourselves again instead? + token, err = api.forceRegister(ctx) + if err != nil { + return nil, err + } + resp, err = api.API.WithToken(token).Call(ctx, req) + // fallthrough + } + if err != nil { + return nil, err + } + return resp, nil +} + +func (api *TorTargetsAPIWithLogin) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *TorTargetsAPIWithLogin) readstate() (*loginState, error) { + data, err := api.KVStore.Get(loginKey) + if err != nil { + return nil, err + } + var ls loginState + if err := api.jsonCodec().Decode(data, &ls); err != nil { + return nil, err + } + return &ls, nil +} + +func (api *TorTargetsAPIWithLogin) writestate(ls *loginState) error { + data, err := api.jsonCodec().Encode(*ls) + if err != nil { + return err + } + return api.KVStore.Set(loginKey, data) +} + +func (api *TorTargetsAPIWithLogin) doRegister(ctx context.Context, password string) (string, error) { + req := newRegisterRequest(password) + ls := &loginState{} + resp, err := api.RegisterAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.ClientID = resp.ClientID + ls.Password = req.Password + return api.doLogin(ctx, ls) +} + +func (api *TorTargetsAPIWithLogin) forceRegister(ctx context.Context) (string, error) { + var password string + // If we already have a previous password, let us keep + // using it. This will allow a new version of the API to + // be able to continue to identify this probe. (This + // assumes that we have a stateless API that generates + // the user ID as a signature of the password plus a + // timestamp and that the key to generate the signature + // is not lost. If all these conditions are met, we + // can then serve better test targets to more long running + // (and therefore trusted) probes.) + if ls, err := api.readstate(); err == nil { + password = ls.Password + } + if password == "" { + password = newRandomPassword() + } + return api.doRegister(ctx, password) +} + +func (api *TorTargetsAPIWithLogin) forceLogin(ctx context.Context) (string, error) { + ls, err := api.readstate() + if err != nil { + return "", err + } + return api.doLogin(ctx, ls) +} + +func (api *TorTargetsAPIWithLogin) maybeLogin(ctx context.Context) (string, error) { + ls, _ := api.readstate() + if ls == nil || !ls.credentialsValid() { + return api.forceRegister(ctx) + } + if !ls.tokenValid() { + return api.doLogin(ctx, ls) + } + return ls.Token, nil +} + +func (api *TorTargetsAPIWithLogin) doLogin(ctx context.Context, ls *loginState) (string, error) { + req := &apimodel.LoginRequest{ + ClientID: ls.ClientID, + Password: ls.Password, + } + resp, err := api.LoginAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.Token = resp.Token + ls.Expire = resp.Expire + if err := api.writestate(ls); err != nil { + return "", err + } + return ls.Token, nil +} + +var _ TorTargetsCaller = &TorTargetsAPIWithLogin{} diff --git a/internal/engine/ooapi/login_test.go b/internal/engine/ooapi/login_test.go new file mode 100644 index 0000000..c5b3e44 --- /dev/null +++ b/internal/engine/ooapi/login_test.go @@ -0,0 +1,1365 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.9205436 +0100 CET m=+0.000137951 + +package ooapi + +//go:generate go run ./internal/generator -file login_test.go + +import ( + "context" + "errors" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func TestRegisterAndLoginPsiphonConfigAPISuccess(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIContinueUsingToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } + } + // step 2: we disable register and login but we + // should be okay because of the token + errMocked := errors.New("mocked error") + registerAPI.Err = errMocked + registerAPI.Response = nil + loginAPI.Err = errMocked + loginAPI.Response = nil + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIWithValidButExpiredToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 0 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIWithRegisterAPIError(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIWithLoginFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + errMocked := errors.New("mocked error") + loginAPI := &FakeLoginAPI{ + Err: errMocked, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestRegisterAndLoginPsiphonConfigAPIThenFail(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Err: errMocked, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPITheDatabaseIsReplaced(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + handler.forgetLogins() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestRegisterAndLoginPsiphonConfigAPICannotWriteState(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{ + EncodeErr: errMocked, + }, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIReadStateDecodeFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + login := &PsiphonConfigAPIWithLogin{ + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + out, err := login.forceLogin(context.Background()) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string here") + } +} + +func TestPsiphonConfigAPITheDatabaseIsReplacedThenFailure(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + // but registrations are also failing. + handler.forgetLogins() + handler.noRegister = true + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestPsiphonConfigAPIClockIsOffThenSuccess(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} + +func TestPsiphonConfigAPIClockIsOffThen401(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 401} + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal("not the error we expected", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestPsiphonConfigAPIClockIsOffThen500(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 500} + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} + +func TestRegisterAndLoginTorTargetsAPISuccess(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIContinueUsingToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } + } + // step 2: we disable register and login but we + // should be okay because of the token + errMocked := errors.New("mocked error") + registerAPI.Err = errMocked + registerAPI.Response = nil + loginAPI.Err = errMocked + loginAPI.Response = nil + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIWithValidButExpiredToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 0 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIWithRegisterAPIError(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIWithLoginFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + errMocked := errors.New("mocked error") + loginAPI := &FakeLoginAPI{ + Err: errMocked, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestRegisterAndLoginTorTargetsAPIThenFail(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Err: errMocked, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPITheDatabaseIsReplaced(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + handler.forgetLogins() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestRegisterAndLoginTorTargetsAPICannotWriteState(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{ + EncodeErr: errMocked, + }, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIReadStateDecodeFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + login := &TorTargetsAPIWithLogin{ + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + out, err := login.forceLogin(context.Background()) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string here") + } +} + +func TestTorTargetsAPITheDatabaseIsReplacedThenFailure(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + // but registrations are also failing. + handler.forgetLogins() + handler.noRegister = true + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestTorTargetsAPIClockIsOffThenSuccess(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} + +func TestTorTargetsAPIClockIsOffThen401(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 401} + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal("not the error we expected", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestTorTargetsAPIClockIsOffThen500(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 500} + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} diff --git a/internal/engine/ooapi/loginhandler_test.go b/internal/engine/ooapi/loginhandler_test.go new file mode 100644 index 0000000..f9e0274 --- /dev/null +++ b/internal/engine/ooapi/loginhandler_test.go @@ -0,0 +1,204 @@ +package ooapi + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// LoginHandler is an http.Handler to test login +type LoginHandler struct { + failCallWith []int // ignored by login and register + mu sync.Mutex + noRegister bool + state []*loginState + t *testing.T + logins int32 + registers int32 +} + +func (lh *LoginHandler) forgetLogins() { + defer lh.mu.Unlock() + lh.mu.Lock() + lh.state = nil +} + +func (lh *LoginHandler) forgetTokens() { + defer lh.mu.Unlock() + lh.mu.Lock() + for _, entry := range lh.state { + // This should be enough to cause all tokens to + // be expired and force clients to relogin. + // + // (It does not matter much whether the client + // clock is off, or the server clock is off, + // thanks Galileo for explaining this to us <3.) + entry.Expire = time.Now().Add(-3600 * time.Second) + } +} + +func (lh *LoginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Implementation note: we don't check for the method + // for simplicity since it's already tested. + switch r.URL.Path { + case "/api/v1/register": + atomic.AddInt32(&lh.registers, 1) + lh.register(w, r) + case "/api/v1/login": + atomic.AddInt32(&lh.logins, 1) + lh.login(w, r) + case "/api/v1/test-list/psiphon-config": + lh.psiphon(w, r) + case "/api/v1/test-list/tor-targets": + lh.tor(w, r) + default: + w.WriteHeader(500) + } +} + +func (lh *LoginHandler) register(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + w.WriteHeader(400) + return + } + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + var req apimodel.RegisterRequest + if err := json.Unmarshal(data, &req); err != nil { + w.WriteHeader(400) + return + } + if req.Password == "" { + w.WriteHeader(400) + return + } + defer lh.mu.Unlock() + lh.mu.Lock() + if lh.noRegister { + // We have been asked to stop registering clients so + // we're going to make a boo boo. + w.WriteHeader(500) + return + } + var resp apimodel.RegisterResponse + ff := &fakeFill{} + ff.fill(&resp) + lh.state = append(lh.state, &loginState{ + ClientID: resp.ClientID, Password: req.Password}) + data, err = json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("register: %+v", string(data)) + w.Write(data) +} + +func (lh *LoginHandler) login(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + w.WriteHeader(400) + return + } + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + var req apimodel.LoginRequest + if err := json.Unmarshal(data, &req); err != nil { + w.WriteHeader(400) + return + } + defer lh.mu.Unlock() + lh.mu.Lock() + for _, s := range lh.state { + if req.ClientID == s.ClientID && req.Password == s.Password { + var resp apimodel.LoginResponse + ff := &fakeFill{} + ff.fill(&resp) + // We want the token to be many seconds in the future while + // ff.fill only sets the tokent to now plus a small delta. + resp.Expire = time.Now().Add(3600 * time.Second) + s.Expire = resp.Expire + s.Token = resp.Token + data, err = json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("login: %+v", string(data)) + w.Write(data) + return + } + } + lh.t.Log("login: 401") + w.WriteHeader(401) +} + +func (lh *LoginHandler) psiphon(w http.ResponseWriter, r *http.Request) { + defer lh.mu.Unlock() + lh.mu.Lock() + if len(lh.failCallWith) > 0 { + code := lh.failCallWith[0] + lh.failCallWith = lh.failCallWith[1:] + w.WriteHeader(code) + return + } + token := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1) + for _, s := range lh.state { + if token == s.Token && time.Now().Before(s.Expire) { + var resp apimodel.PsiphonConfigResponse + ff := &fakeFill{} + ff.fill(&resp) + data, err := json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("psiphon: %+v", string(data)) + w.Write(data) + return + } + } + lh.t.Log("psiphon: 401") + w.WriteHeader(401) +} + +func (lh *LoginHandler) tor(w http.ResponseWriter, r *http.Request) { + defer lh.mu.Unlock() + lh.mu.Lock() + if len(lh.failCallWith) > 0 { + code := lh.failCallWith[0] + lh.failCallWith = lh.failCallWith[1:] + w.WriteHeader(code) + return + } + token := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1) + for _, s := range lh.state { + if token == s.Token && time.Now().Before(s.Expire) { + var resp apimodel.TorTargetsResponse + ff := &fakeFill{} + ff.fill(&resp) + data, err := json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("tor: %+v", string(data)) + w.Write(data) + return + } + } + lh.t.Log("tor: 401") + w.WriteHeader(401) +} diff --git a/internal/engine/ooapi/loginmodel.go b/internal/engine/ooapi/loginmodel.go new file mode 100644 index 0000000..4337348 --- /dev/null +++ b/internal/engine/ooapi/loginmodel.go @@ -0,0 +1,59 @@ +package ooapi + +import ( + "crypto/rand" + "encoding/base64" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" + "github.com/ooni/probe-cli/v3/internal/engine/runtimex" +) + +// loginState is the struct saved in the kvstore +// to keep track of the login state. +type loginState struct { + ClientID string + Expire time.Time + Password string + Token string +} + +func (ls *loginState) credentialsValid() bool { + return ls.ClientID != "" && ls.Password != "" +} + +func (ls *loginState) tokenValid() bool { + return ls.Token != "" && time.Now().Add(60*time.Second).Before(ls.Expire) +} + +// loginKey is the key with which loginState is saved +// into the key-value store used by Client. +const loginKey = "orchestra.state" + +// newRandomPassword generates a new random password. +func newRandomPassword() string { + b := make([]byte, 48) + _, err := rand.Read(b) + runtimex.PanicOnError(err, "rand.Read failed") + return base64.StdEncoding.EncodeToString(b) +} + +// newRegisterRequest creates a new RegisterRequest. +func newRegisterRequest(password string) *apimodel.RegisterRequest { + return &apimodel.RegisterRequest{ + // The original implementation has as its only use case that we + // were registering and logging in for sending an update regarding + // the probe whereabouts. Yet here in probe-engine, the orchestra + // is currently only used to fetch inputs. For this purpose, we don't + // need to communicate any specific information. The code that will + // perform an update used to be responsible of doing that. Now, we + // are not using orchestra for this purpose anymore. + Platform: "miniooni", + ProbeASN: "AS0", + ProbeCC: "ZZ", + SoftwareName: "miniooni", + SoftwareVersion: "0.1.0-dev", + SupportedTests: []string{"web_connectivity"}, + Password: password, + } +} diff --git a/internal/engine/ooapi/requests.go b/internal/engine/ooapi/requests.go new file mode 100644 index 0000000..806d5f2 --- /dev/null +++ b/internal/engine/ooapi/requests.go @@ -0,0 +1,192 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:53.210720456 +0100 CET m=+0.000083649 + +package ooapi + +//go:generate go run ./internal/generator -file requests.go + +import ( + "bytes" + "context" + "net/http" + "net/url" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func (api *CheckReportIDAPI) newRequest(ctx context.Context, req *apimodel.CheckReportIDRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/_/check_report_id" + q := url.Values{} + if req.ReportID == "" { + return nil, newErrEmptyField("ReportID") + } + q.Add("report_id", req.ReportID) + URL.RawQuery = q.Encode() + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *CheckInAPI) newRequest(ctx context.Context, req *apimodel.CheckInRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/check-in" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *LoginAPI) newRequest(ctx context.Context, req *apimodel.LoginRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/login" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *MeasurementMetaAPI) newRequest(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/measurement_meta" + q := url.Values{} + if req.ReportID == "" { + return nil, newErrEmptyField("ReportID") + } + q.Add("report_id", req.ReportID) + if req.Full { + q.Add("full", "true") + } + if req.Input != "" { + q.Add("input", req.Input) + } + URL.RawQuery = q.Encode() + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *RegisterAPI) newRequest(ctx context.Context, req *apimodel.RegisterRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/register" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *TestHelpersAPI) newRequest(ctx context.Context, req *apimodel.TestHelpersRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-helpers" + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *PsiphonConfigAPI) newRequest(ctx context.Context, req *apimodel.PsiphonConfigRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-list/psiphon-config" + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *TorTargetsAPI) newRequest(ctx context.Context, req *apimodel.TorTargetsRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-list/tor-targets" + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *URLsAPI) newRequest(ctx context.Context, req *apimodel.URLsRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-list/urls" + q := url.Values{} + if req.CategoryCodes != "" { + q.Add("category_codes", req.CategoryCodes) + } + if req.CountryCode != "" { + q.Add("country_code", req.CountryCode) + } + if req.Limit != 0 { + q.Add("limit", newQueryFieldInt64(req.Limit)) + } + URL.RawQuery = q.Encode() + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *OpenReportAPI) newRequest(ctx context.Context, req *apimodel.OpenReportRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/report" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *SubmitMeasurementAPI) newRequest(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + up, err := api.templateExecutor().Execute("/report/{{ .ReportID }}", req) + if err != nil { + return nil, err + } + URL.Path = up + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} diff --git a/internal/engine/ooapi/responses.go b/internal/engine/ooapi/responses.go new file mode 100644 index 0000000..b148aea --- /dev/null +++ b/internal/engine/ooapi/responses.go @@ -0,0 +1,276 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:53.567815989 +0100 CET m=+0.000158731 + +package ooapi + +//go:generate go run ./internal/generator -file responses.go + +import ( + "io" + "io/ioutil" + "net/http" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func (api *CheckReportIDAPI) newResponse(resp *http.Response, err error) (*apimodel.CheckReportIDResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.CheckReportIDResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *CheckInAPI) newResponse(resp *http.Response, err error) (*apimodel.CheckInResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.CheckInResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *LoginAPI) newResponse(resp *http.Response, err error) (*apimodel.LoginResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.LoginResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *MeasurementMetaAPI) newResponse(resp *http.Response, err error) (*apimodel.MeasurementMetaResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.MeasurementMetaResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *RegisterAPI) newResponse(resp *http.Response, err error) (*apimodel.RegisterResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.RegisterResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *TestHelpersAPI) newResponse(resp *http.Response, err error) (apimodel.TestHelpersResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := apimodel.TestHelpersResponse{} + if err := api.jsonCodec().Decode(data, &out); err != nil { + return nil, err + } + if out == nil { + return nil, ErrJSONLiteralNull + } + return out, nil +} + +func (api *PsiphonConfigAPI) newResponse(resp *http.Response, err error) (apimodel.PsiphonConfigResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := apimodel.PsiphonConfigResponse{} + if err := api.jsonCodec().Decode(data, &out); err != nil { + return nil, err + } + if out == nil { + return nil, ErrJSONLiteralNull + } + return out, nil +} + +func (api *TorTargetsAPI) newResponse(resp *http.Response, err error) (apimodel.TorTargetsResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := apimodel.TorTargetsResponse{} + if err := api.jsonCodec().Decode(data, &out); err != nil { + return nil, err + } + if out == nil { + return nil, ErrJSONLiteralNull + } + return out, nil +} + +func (api *URLsAPI) newResponse(resp *http.Response, err error) (*apimodel.URLsResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.URLsResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *OpenReportAPI) newResponse(resp *http.Response, err error) (*apimodel.OpenReportResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.OpenReportResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *SubmitMeasurementAPI) newResponse(resp *http.Response, err error) (*apimodel.SubmitMeasurementResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.SubmitMeasurementResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} diff --git a/internal/engine/ooapi/swagger_test.go b/internal/engine/ooapi/swagger_test.go new file mode 100644 index 0000000..5cf18ab --- /dev/null +++ b/internal/engine/ooapi/swagger_test.go @@ -0,0 +1,578 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:53.881261959 +0100 CET m=+0.000594905 + +package ooapi + +//go:generate go run ./internal/generator -file swagger_test.go + +const swagger = `{ + "swagger": "2.0", + "info": { + "title": "OONI API specification", + "version": "0.20210226.2144553" + }, + "host": "api.ooni.io", + "basePath": "/", + "schemes": [ + "https" + ], + "paths": { + "/api/_/check_report_id": { + "get": { + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "report_id", + "required": true, + "type": "string" + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "error": { + "type": "string" + }, + "found": { + "type": "boolean" + }, + "v": { + "type": "integer" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/check-in": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "charging": { + "type": "boolean" + }, + "on_wifi": { + "type": "boolean" + }, + "platform": { + "type": "string" + }, + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "run_type": { + "type": "string" + }, + "software_name": { + "type": "string" + }, + "software_version": { + "type": "string" + }, + "web_connectivity": { + "properties": { + "category_codes": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "tests": { + "properties": { + "web_connectivity": { + "properties": { + "report_id": { + "type": "string" + }, + "urls": { + "items": { + "properties": { + "category_code": { + "type": "string" + }, + "country_code": { + "type": "string" + }, + "url": { + "type": "string" + } + }, + "type": "object" + }, + "type": "array" + } + }, + "type": "object" + } + }, + "type": "object" + }, + "v": { + "type": "integer" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/login": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "password": { + "type": "string" + }, + "username": { + "type": "string" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "expire": { + "type": "string" + }, + "token": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/measurement_meta": { + "get": { + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "report_id", + "required": true, + "type": "string" + }, + { + "in": "query", + "name": "full", + "type": "boolean" + }, + { + "in": "query", + "name": "input", + "type": "string" + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "anomaly": { + "type": "boolean" + }, + "category_code": { + "type": "string" + }, + "confirmed": { + "type": "boolean" + }, + "failure": { + "type": "boolean" + }, + "input": { + "type": "string" + }, + "measurement_start_time": { + "type": "string" + }, + "probe_asn": { + "type": "integer" + }, + "probe_cc": { + "type": "string" + }, + "raw_measurement": { + "type": "string" + }, + "report_id": { + "type": "string" + }, + "scores": { + "type": "string" + }, + "test_name": { + "type": "string" + }, + "test_start_time": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/register": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "available_bandwidth": { + "type": "string" + }, + "device_token": { + "type": "string" + }, + "language": { + "type": "string" + }, + "network_type": { + "type": "string" + }, + "password": { + "type": "string" + }, + "platform": { + "type": "string" + }, + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "probe_family": { + "type": "string" + }, + "probe_timezone": { + "type": "string" + }, + "software_name": { + "type": "string" + }, + "software_version": { + "type": "string" + }, + "supported_tests": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "client_id": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/test-helpers": { + "get": { + "produces": [ + "application/json" + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "type": "object" + } + } + } + } + }, + "/api/v1/test-list/psiphon-config": { + "get": { + "produces": [ + "application/json" + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "type": "object" + } + } + } + } + }, + "/api/v1/test-list/tor-targets": { + "get": { + "produces": [ + "application/json" + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "type": "object" + } + } + } + } + }, + "/api/v1/test-list/urls": { + "get": { + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "category_codes", + "type": "string" + }, + { + "in": "query", + "name": "country_code", + "type": "string" + }, + { + "in": "query", + "name": "limit", + "type": "integer" + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "metadata": { + "properties": { + "count": { + "type": "integer" + } + }, + "type": "object" + }, + "results": { + "items": { + "properties": { + "category_code": { + "type": "string" + }, + "country_code": { + "type": "string" + }, + "url": { + "type": "string" + } + }, + "type": "object" + }, + "type": "array" + } + }, + "type": "object" + } + } + } + } + }, + "/report": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "data_format_version": { + "type": "string" + }, + "format": { + "type": "string" + }, + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "software_name": { + "type": "string" + }, + "software_version": { + "type": "string" + }, + "test_name": { + "type": "string" + }, + "test_start_time": { + "type": "string" + }, + "test_version": { + "type": "string" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "backend_version": { + "type": "string" + }, + "report_id": { + "type": "string" + }, + "supported_formats": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + } + } + } + }, + "/report/{report_id}": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "path", + "name": "report_id", + "required": true, + "type": "string" + }, + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "content": { + "type": "object" + }, + "format": { + "type": "string" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "measurement_uid": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + } + } +}` diff --git a/internal/engine/ooapi/swaggerdiff_test.go b/internal/engine/ooapi/swaggerdiff_test.go new file mode 100644 index 0000000..6715bc9 --- /dev/null +++ b/internal/engine/ooapi/swaggerdiff_test.go @@ -0,0 +1,158 @@ +package ooapi + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "sort" + "strings" + "testing" + + "github.com/hexops/gotextdiff" + "github.com/hexops/gotextdiff/myers" + "github.com/hexops/gotextdiff/span" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/internal/openapi" +) + +const ( + productionURL = "https://api.ooni.io/apispec_1.json" + testingURL = "https://ams-pg-test.ooni.org/apispec_1.json" +) + +func makeModel(data []byte) *openapi.Swagger { + var out openapi.Swagger + if err := json.Unmarshal(data, &out); err != nil { + log.Fatal(err) + } + // We reduce irrelevant differences by producing a common header + return &openapi.Swagger{Paths: out.Paths} +} + +func getServerModel(serverURL string) *openapi.Swagger { + resp, err := http.Get(serverURL) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + return makeModel(data) +} + +func getClientModel() *openapi.Swagger { + return makeModel([]byte(swagger)) +} + +func simplifyRoundTrip(rt *openapi.RoundTrip) { + // Normalize the used name when a parameter is in body. This + // should only have a cosmetic impact on the spec. + for _, param := range rt.Parameters { + if param.In == "body" { + param.Name = "body" + } + } + + // Sort parameters so the comparison does not depend on order. + sort.SliceStable(rt.Parameters, func(i, j int) bool { + left, right := rt.Parameters[i].Name, rt.Parameters[j].Name + return strings.Compare(left, right) < 0 + }) + + // Normalize description of 200 response + rt.Responses.Successful.Description = "all good" +} + +func simplifyInPlace(path *openapi.Path) *openapi.Path { + if path.Get != nil && path.Post != nil { + log.Fatal("unsupported configuration") + } + if path.Get != nil { + simplifyRoundTrip(path.Get) + } + if path.Post != nil { + simplifyRoundTrip(path.Post) + } + return path +} + +func jsonify(model interface{}) string { + data, err := json.MarshalIndent(model, "", " ") + if err != nil { + log.Fatal(err) + } + return string(data) +} + +type diffable struct { + name string + value string +} + +func computediff(server, client *diffable) string { + d := gotextdiff.ToUnified(server.name, client.name, server.value, myers.ComputeEdits( + span.URIFromPath(server.name), server.value, client.value, + )) + return fmt.Sprint(d) +} + +// maybediff emits the diff between the server and the client and +// returns the length of the diff itself in bytes. +func maybediff(key string, server, client *openapi.Path) int { + diff := computediff(&diffable{ + name: fmt.Sprintf("server%s.json", key), + value: jsonify(simplifyInPlace(server)), + }, &diffable{ + name: fmt.Sprintf("client%s.json", key), + value: jsonify(simplifyInPlace(client)), + }) + if diff != "" { + fmt.Printf("%s", diff) + } + return len(diff) +} + +func compare(serverURL string) bool { + good := true + serverModel, clientModel := getServerModel(serverURL), getClientModel() + // Implementation note: the server model is richer than the client + // model, so we ignore everything not defined by the client. + var count int + for key := range serverModel.Paths { + if _, found := clientModel.Paths[key]; !found { + delete(serverModel.Paths, key) + continue + } + count++ + if maybediff(key, serverModel.Paths[key], clientModel.Paths[key]) > 0 { + good = false + } + } + if count <= 0 { + panic("no element found") + } + return good +} + +func TestWithProductionAPI(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + t.Log("using ", productionURL) + if !compare(productionURL) { + t.Fatal("model mismatch (see above)") + } +} + +func TestWithTestingAPI(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + t.Log("using ", testingURL) + if !compare(testingURL) { + t.Fatal("model mismatch (see above)") + } +} diff --git a/internal/engine/ooapi/utils.go b/internal/engine/ooapi/utils.go new file mode 100644 index 0000000..6d93a0b --- /dev/null +++ b/internal/engine/ooapi/utils.go @@ -0,0 +1,23 @@ +package ooapi + +import "fmt" + +func newErrEmptyField(field string) error { + return fmt.Errorf("%w: %s", ErrEmptyField, field) +} + +func newHTTPFailure(status int) error { + return fmt.Errorf("%w: %d", ErrHTTPFailure, status) +} + +func newQueryFieldInt64(v int64) string { + return fmt.Sprintf("%d", v) +} + +func newQueryFieldBool(v bool) string { + return fmt.Sprintf("%v", v) +} + +func newAuthorizationHeader(token string) string { + return fmt.Sprintf("Bearer %s", token) +} diff --git a/internal/engine/ooapi/utils_test.go b/internal/engine/ooapi/utils_test.go new file mode 100644 index 0000000..54b158b --- /dev/null +++ b/internal/engine/ooapi/utils_test.go @@ -0,0 +1,12 @@ +package ooapi + +import "testing" + +func TestNewQueryFieldBoolWorks(t *testing.T) { + if s := newQueryFieldBool(true); s != "true" { + t.Fatal("invalid encoding of true") + } + if s := newQueryFieldBool(false); s != "false" { + t.Fatal("invalid encoding of false") + } +} From 2ef5fb503ac9a71aad676147d50cdb498778c62e Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 8 Mar 2021 12:05:43 +0100 Subject: [PATCH 3/6] fix(webconnectivity): allow measuring https://1.1.1.1 (#241) * fix(webconnectivity): allow measuring https://1.1.1.1 There were two issues preventing us from doing so: 1. in netx, the address resolver was too later in the resolver chain. Therefore, its result wasn't added to the events. 2. when building the DNSCache (in httpget.go), we didn't consider the case where the input is an address. We need to treat this case specially to make sure there is no DNSCache. See https://github.com/ooni/probe/issues/1376. * fix: add unit tests for code making the dnscache * fix(netx): make sure all tests pass * chore: bump webconnectivity version --- .../experiment/webconnectivity/control.go | 4 +- .../experiment/webconnectivity/dnslookup.go | 1 + .../experiment/webconnectivity/httpget.go | 19 ++++- .../webconnectivity/httpget_test.go | 17 ++++ .../webconnectivity/webconnectivity.go | 2 +- .../webconnectivity/webconnectivity_test.go | 2 +- internal/engine/netx/netx.go | 2 +- internal/engine/netx/netx_test.go | 82 +++++++++---------- 8 files changed, 82 insertions(+), 47 deletions(-) diff --git a/internal/engine/experiment/webconnectivity/control.go b/internal/engine/experiment/webconnectivity/control.go index 8d4b3ab..4d189b0 100644 --- a/internal/engine/experiment/webconnectivity/control.go +++ b/internal/engine/experiment/webconnectivity/control.go @@ -57,13 +57,13 @@ func Control( HTTPClient: sess.DefaultHTTPClient(), Logger: sess.Logger(), } - sess.Logger().Infof("control %s...", creq.HTTPRequest) + sess.Logger().Infof("control for %s...", creq.HTTPRequest) // make sure error is wrapped err = errorx.SafeErrWrapperBuilder{ Error: clnt.PostJSON(ctx, "/", creq, &out), Operation: errorx.TopLevelOperation, }.MaybeBuild() - sess.Logger().Infof("control %s... %+v", creq.HTTPRequest, err) + sess.Logger().Infof("control for %s... %+v", creq.HTTPRequest, err) (&out.DNS).FillASNs(sess) return } diff --git a/internal/engine/experiment/webconnectivity/dnslookup.go b/internal/engine/experiment/webconnectivity/dnslookup.go index 92133b7..6d0a324 100644 --- a/internal/engine/experiment/webconnectivity/dnslookup.go +++ b/internal/engine/experiment/webconnectivity/dnslookup.go @@ -37,6 +37,7 @@ func DNSLookup(ctx context.Context, config DNSLookupConfig) (out DNSLookupResult } if answer.IPv6 != "" { out.Addrs[answer.IPv6] = answer.ASN + continue } } } diff --git a/internal/engine/experiment/webconnectivity/httpget.go b/internal/engine/experiment/webconnectivity/httpget.go index ad177f5..35200ca 100644 --- a/internal/engine/experiment/webconnectivity/httpget.go +++ b/internal/engine/experiment/webconnectivity/httpget.go @@ -3,6 +3,7 @@ package webconnectivity import ( "context" "fmt" + "net" "net/url" "strings" @@ -25,6 +26,22 @@ type HTTPGetResult struct { Failure *string } +// TODO(bassosimone): Web Connectivity uses too much external testing +// and we should actually expose much less to the outside by using +// internal testing and by making _many_ functions private. + +// HTTPGetMakeDNSCache constructs the DNSCache option for HTTPGet +// by combining domain and addresses into a single string. As a +// corner case, if the domain is an IP address, we return an empty +// string. This corner case corresponds to Web Connectivity +// inputs like https://1.1.1.1. +func HTTPGetMakeDNSCache(domain, addresses string) string { + if net.ParseIP(domain) != nil { + return "" + } + return fmt.Sprintf("%s %s", domain, addresses) +} + // HTTPGet performs the HTTP/HTTPS part of Web Connectivity. func HTTPGet(ctx context.Context, config HTTPGetConfig) (out HTTPGetResult) { addresses := strings.Join(config.Addresses, " ") @@ -38,7 +55,7 @@ func HTTPGet(ctx context.Context, config HTTPGetConfig) (out HTTPGetResult) { domain := config.TargetURL.Hostname() result, err := urlgetter.Getter{ Config: urlgetter.Config{ - DNSCache: fmt.Sprintf("%s %s", domain, addresses), + DNSCache: HTTPGetMakeDNSCache(domain, addresses), }, Session: config.Session, Target: target, diff --git a/internal/engine/experiment/webconnectivity/httpget_test.go b/internal/engine/experiment/webconnectivity/httpget_test.go index d19cdbf..6f1ddc9 100644 --- a/internal/engine/experiment/webconnectivity/httpget_test.go +++ b/internal/engine/experiment/webconnectivity/httpget_test.go @@ -25,3 +25,20 @@ func TestHTTPGet(t *testing.T) { t.Fatal(*r.Failure) } } + +func TestHTTPGetMakeDNSCache(t *testing.T) { + // test for input being an IP + out := webconnectivity.HTTPGetMakeDNSCache( + "1.1.1.1", "1.1.1.1", + ) + if out != "" { + t.Fatal("expected empty output here") + } + // test for input being a domain + out = webconnectivity.HTTPGetMakeDNSCache( + "dns.google", "8.8.8.8 8.8.4.4", + ) + if out != "dns.google 8.8.8.8 8.8.4.4" { + t.Fatal("expected ordinary output here") + } +} diff --git a/internal/engine/experiment/webconnectivity/webconnectivity.go b/internal/engine/experiment/webconnectivity/webconnectivity.go index 9788e49..7f76ed4 100644 --- a/internal/engine/experiment/webconnectivity/webconnectivity.go +++ b/internal/engine/experiment/webconnectivity/webconnectivity.go @@ -19,7 +19,7 @@ import ( const ( testName = "web_connectivity" - testVersion = "0.2.0" + testVersion = "0.3.0" ) // Config contains the experiment config. diff --git a/internal/engine/experiment/webconnectivity/webconnectivity_test.go b/internal/engine/experiment/webconnectivity/webconnectivity_test.go index e9e3335..8a04618 100644 --- a/internal/engine/experiment/webconnectivity/webconnectivity_test.go +++ b/internal/engine/experiment/webconnectivity/webconnectivity_test.go @@ -21,7 +21,7 @@ func TestNewExperimentMeasurer(t *testing.T) { if measurer.ExperimentName() != "web_connectivity" { t.Fatal("unexpected name") } - if measurer.ExperimentVersion() != "0.2.0" { + if measurer.ExperimentVersion() != "0.3.0" { t.Fatal("unexpected version") } } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 554fcdf..25f2365 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -126,6 +126,7 @@ func NewResolver(config Config) Resolver { config.BaseResolver = resolver.SystemResolver{} } var r Resolver = config.BaseResolver + r = resolver.AddressResolver{Resolver: r} if config.CacheResolutions { r = &resolver.CacheResolver{Resolver: r} } @@ -146,7 +147,6 @@ func NewResolver(config Config) Resolver { if config.ResolveSaver != nil { r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver} } - r = resolver.AddressResolver{Resolver: r} return resolver.IDNAResolver{Resolver: r} } diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 00bf58e..9526ee8 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -23,15 +23,15 @@ func TestNewResolverVanilla(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) + ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ar.Resolver.(resolver.ErrorWrapperResolver) + ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } - _, ok = ewr.Resolver.(resolver.SystemResolver) + _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -47,15 +47,15 @@ func TestNewResolverSpecificResolver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) + ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } - ewr, ok := ar.Resolver.(resolver.ErrorWrapperResolver) + ar, ok := ewr.Resolver.(resolver.AddressResolver) if !ok { t.Fatal("not the resolver we expected") } - _, ok = ewr.Resolver.(resolver.BogonResolver) + _, ok = ar.Resolver.(resolver.BogonResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -69,11 +69,7 @@ func TestNewResolverWithBogonFilter(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - ewr, ok := ar.Resolver.(resolver.ErrorWrapperResolver) + ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -81,7 +77,11 @@ func TestNewResolverWithBogonFilter(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = br.Resolver.(resolver.SystemResolver) + ar, ok := br.Resolver.(resolver.AddressResolver) + if !ok { + t.Fatal("not the resolver we expected") + } + _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -95,11 +95,7 @@ func TestNewResolverWithLogging(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - lr, ok := ar.Resolver.(resolver.LoggingResolver) + lr, ok := ir.Resolver.(resolver.LoggingResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -110,7 +106,11 @@ func TestNewResolverWithLogging(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ewr.Resolver.(resolver.SystemResolver) + ar, ok := ewr.Resolver.(resolver.AddressResolver) + if !ok { + t.Fatal("not the resolver we expected") + } + _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -125,11 +125,7 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - sr, ok := ar.Resolver.(resolver.SaverResolver) + sr, ok := ir.Resolver.(resolver.SaverResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -140,7 +136,11 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - _, ok = ewr.Resolver.(resolver.SystemResolver) + ar, ok := ewr.Resolver.(resolver.AddressResolver) + if !ok { + t.Fatal("not the resolver we expected") + } + _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -154,11 +154,7 @@ func TestNewResolverWithReadWriteCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - ewr, ok := ar.Resolver.(resolver.ErrorWrapperResolver) + ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -169,7 +165,11 @@ func TestNewResolverWithReadWriteCache(t *testing.T) { if cr.ReadOnly != false { t.Fatal("expected readwrite cache here") } - _, ok = cr.Resolver.(resolver.SystemResolver) + ar, ok := cr.Resolver.(resolver.AddressResolver) + if !ok { + t.Fatal("not the resolver we expected") + } + _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -185,11 +185,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - ar, ok := ir.Resolver.(resolver.AddressResolver) - if !ok { - t.Fatal("not the resolver we expected") - } - ewr, ok := ar.Resolver.(resolver.ErrorWrapperResolver) + ewr, ok := ir.Resolver.(resolver.ErrorWrapperResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -203,7 +199,11 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) { if cr.Get("dns.google.com")[0] != "8.8.8.8" { t.Fatal("cache not correctly prefilled") } - _, ok = cr.Resolver.(resolver.SystemResolver) + ar, ok := cr.Resolver.(resolver.AddressResolver) + if !ok { + t.Fatal("not the resolver we expected") + } + _, ok = ar.Resolver.(resolver.SystemResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -233,7 +233,7 @@ func TestNewDialerVanilla(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := ir.Resolver.(resolver.AddressResolver); !ok { + if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) @@ -315,7 +315,7 @@ func TestNewDialerWithLogger(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := ir.Resolver.(resolver.AddressResolver); !ok { + if _, ok := ir.Resolver.(resolver.LoggingResolver); !ok { t.Fatal("not the resolver we expected") } ld, ok := dnsd.Dialer.(dialer.LoggingDialer) @@ -365,7 +365,7 @@ func TestNewDialerWithDialSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := ir.Resolver.(resolver.AddressResolver); !ok { + if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } sad, ok := dnsd.Dialer.(dialer.SaverDialer) @@ -415,7 +415,7 @@ func TestNewDialerWithReadWriteSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := ir.Resolver.(resolver.AddressResolver); !ok { + if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } scd, ok := dnsd.Dialer.(dialer.SaverConnDialer) @@ -468,7 +468,7 @@ func TestNewDialerWithContextByteCounting(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := ir.Resolver.(resolver.AddressResolver); !ok { + if _, ok := ir.Resolver.(resolver.ErrorWrapperResolver); !ok { t.Fatal("not the resolver we expected") } ewd, ok := dnsd.Dialer.(dialer.ErrorWrapperDialer) From f5461323dbcc6d0624dc4bae5e0aa3d061752e60 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 8 Mar 2021 12:47:19 +0100 Subject: [PATCH 4/6] fix(run): run unattended short also on Windows (#242) We already have a short run unattended on macOS and we wanna do the same for Windows. See https://github.com/ooni/probe/issues/1377. --- cmd/ooniprobe/internal/cli/run/run.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cmd/ooniprobe/internal/cli/run/run.go b/cmd/ooniprobe/internal/cli/run/run.go index 49d773f..5e8ca40 100644 --- a/cmd/ooniprobe/internal/cli/run/run.go +++ b/cmd/ooniprobe/internal/cli/run/run.go @@ -76,10 +76,11 @@ func init() { unattendedCmd := cmd.Command("unattended", "") unattendedCmd.Action(func(_ *kingpin.ParseContext) error { - if runtime.GOOS == "darwin" { - // Until we have enabled the check-in API we're called every - // hour on darwin and we need to self throttle. - // TODO(bassosimone): switch to check-in and remove this hack. + // Until we have enabled the check-in API we're called every + // hour on darwin and we need to self throttle. + // TODO(bassosimone): switch to check-in and remove this hack. + switch runtime.GOOS { + case "darwin", "windows": const veryFew = 10 probe.Config().Nettests.WebsitesURLLimit = veryFew } From da95fa936506d35ced4c23ea3d2d65f5bc00d5a6 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 8 Mar 2021 13:38:34 +0100 Subject: [PATCH 5/6] refactor: signal et al. are now experimental nettests (#243) * refactor: signal et al. are now experimental nettests We move signal into the experimental nettests group. While there, also start adding dnscheck and stunreachability as well. It seems there's more work to be done to correctly represent the results of dnscheck, but this is fine! The experimental section is here exactly for this reason! In terms of UI, the new command is `ooniprobe run experimental`. We will most likely move signal out of experimental soon, since it's already working quite well. We need to keep it here for one more cycle because the desktop app is not ready for it. See the following issues: 1. https://github.com/ooni/probe/issues/1378 2. https://github.com/ooni/probe/issues/1262 * fix(dnscheck): spell check * fix: improve documentation --- cmd/ooniprobe/internal/cli/run/run.go | 3 +- .../internal/log/handlers/cli/results.go | 7 +++ cmd/ooniprobe/internal/nettests/dnscheck.go | 60 +++++++++++++++++++ cmd/ooniprobe/internal/nettests/groups.go | 9 ++- cmd/ooniprobe/internal/nettests/signal.go | 8 +-- .../internal/nettests/stunreachability.go | 13 ++++ 6 files changed, 93 insertions(+), 7 deletions(-) create mode 100644 cmd/ooniprobe/internal/nettests/dnscheck.go create mode 100644 cmd/ooniprobe/internal/nettests/stunreachability.go diff --git a/cmd/ooniprobe/internal/cli/run/run.go b/cmd/ooniprobe/internal/cli/run/run.go index 5e8ca40..32d27cd 100644 --- a/cmd/ooniprobe/internal/cli/run/run.go +++ b/cmd/ooniprobe/internal/cli/run/run.go @@ -69,7 +69,8 @@ func init() { }) }) - easyRuns := []string{"im", "performance", "circumvention", "middlebox"} + easyRuns := []string{ + "im", "performance", "circumvention", "middlebox", "experimental"} for _, name := range easyRuns { cmd.Command(name, "").Action(genRunWithGroupName(name)) } diff --git a/cmd/ooniprobe/internal/log/handlers/cli/results.go b/cmd/ooniprobe/internal/log/handlers/cli/results.go index 09ca641..9350bdd 100644 --- a/cmd/ooniprobe/internal/log/handlers/cli/results.go +++ b/cmd/ooniprobe/internal/log/handlers/cli/results.go @@ -76,6 +76,13 @@ var summarizers = map[string]func(uint64, uint64, string) []string{ "", } }, + "experimental": func(totalCount uint64, anomalyCount uint64, ss string) []string { + return []string{ + fmt.Sprintf("%d tested", totalCount), + fmt.Sprintf("%d blocked", anomalyCount), + "", + } + }, } func makeSummary(name string, totalCount uint64, anomalyCount uint64, ss string) []string { diff --git a/cmd/ooniprobe/internal/nettests/dnscheck.go b/cmd/ooniprobe/internal/nettests/dnscheck.go new file mode 100644 index 0000000..02cf963 --- /dev/null +++ b/cmd/ooniprobe/internal/nettests/dnscheck.go @@ -0,0 +1,60 @@ +package nettests + +import ( + "encoding/json" + + "github.com/ooni/probe-cli/v3/internal/engine/experiment/dnscheck" + "github.com/ooni/probe-cli/v3/internal/engine/experiment/run" + "github.com/ooni/probe-cli/v3/internal/engine/runtimex" +) + +// DNSCheck nettest implementation. +type DNSCheck struct{} + +var dnsCheckDefaultInput []string + +func dnsCheckMustMakeInput(input *run.StructuredInput) string { + data, err := json.Marshal(input) + runtimex.PanicOnError(err, "json.Marshal failed") + return string(data) +} + +func init() { + // The following code just adds a minimal set of URLs to + // test using DNSCheck, so we start exposing it. + // + // TODO(bassosimone): + // + // 1. we should be getting input from the backend instead of + // having an hardcoded list of inputs here. + // + // 2. we should modify dnscheck to accept http3://... as a + // shortcut for https://... with h3. If we don't do that, we + // are stuck with the h3 results hiding h2 results in OONI + // Explorer because they use the same URL. + // + // 3. it seems we have the problem that dnscheck results + // appear as the `run` nettest in `ooniprobe list ` because + // dnscheck is run using the `run` functionality. + dnsCheckDefaultInput = append(dnsCheckDefaultInput, dnsCheckMustMakeInput( + &run.StructuredInput{ + DNSCheck: dnscheck.Config{}, + Name: "dnscheck", + Input: "https://dns.google/dns-query", + })) + dnsCheckDefaultInput = append(dnsCheckDefaultInput, dnsCheckMustMakeInput( + &run.StructuredInput{ + DNSCheck: dnscheck.Config{}, + Name: "dnscheck", + Input: "https://cloudflare-dns.com/dns-query", + })) +} + +// Run starts the nettest. +func (n DNSCheck) Run(ctl *Controller) error { + builder, err := ctl.Session.NewExperimentBuilder("run") + if err != nil { + return err + } + return ctl.Run(builder, dnsCheckDefaultInput) +} diff --git a/cmd/ooniprobe/internal/nettests/groups.go b/cmd/ooniprobe/internal/nettests/groups.go index e2735b7..5c4d9bc 100644 --- a/cmd/ooniprobe/internal/nettests/groups.go +++ b/cmd/ooniprobe/internal/nettests/groups.go @@ -35,7 +35,6 @@ var All = map[string]Group{ Label: "Instant Messaging", Nettests: []Nettest{ FacebookMessenger{}, - Signal{}, Telegram{}, WhatsApp{}, }, @@ -50,4 +49,12 @@ var All = map[string]Group{ }, UnattendedOK: true, }, + "experimental": { + Label: "Experimental Nettests", + Nettests: []Nettest{ + DNSCheck{}, + STUNReachability{}, + Signal{}, + }, + }, } diff --git a/cmd/ooniprobe/internal/nettests/signal.go b/cmd/ooniprobe/internal/nettests/signal.go index 9b17d34..3a2df6b 100644 --- a/cmd/ooniprobe/internal/nettests/signal.go +++ b/cmd/ooniprobe/internal/nettests/signal.go @@ -1,10 +1,9 @@ package nettests -// Signal test implementation -type Signal struct { -} +// Signal nettest implementation. +type Signal struct{} -// Run starts the test +// Run starts the nettest. func (h Signal) Run(ctl *Controller) error { builder, err := ctl.Session.NewExperimentBuilder( "signal", @@ -12,6 +11,5 @@ func (h Signal) Run(ctl *Controller) error { if err != nil { return err } - return ctl.Run(builder, []string{""}) } diff --git a/cmd/ooniprobe/internal/nettests/stunreachability.go b/cmd/ooniprobe/internal/nettests/stunreachability.go new file mode 100644 index 0000000..68fb40e --- /dev/null +++ b/cmd/ooniprobe/internal/nettests/stunreachability.go @@ -0,0 +1,13 @@ +package nettests + +// STUNReachability nettest implementation. +type STUNReachability struct{} + +// Run starts the nettest. +func (n STUNReachability) Run(ctl *Controller) error { + builder, err := ctl.Session.NewExperimentBuilder("stun_reachability") + if err != nil { + return err + } + return ctl.Run(builder, []string{""}) +} From 784d3d0f73dba9ad18781eac09d18c5157f0f99d Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Mon, 8 Mar 2021 13:51:43 +0100 Subject: [PATCH 6/6] chore: release 3.7.0 (#244) This comes just a few days after 3.6.0. It contains small improvements required by ooni/probe-desktop. For this reason, I am going to skeep the normal release process and I am just bumping the version number. --- internal/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/version/version.go b/internal/version/version.go index f3a8148..9f7674a 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -3,5 +3,5 @@ package version const ( // Version is the software version - Version = "3.7.0-alpha" + Version = "3.7.0" )