diff --git a/go.mod b/go.mod index c81527c..b24c571 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.2.0 github.com/gorilla/websocket v1.4.2 + github.com/hexops/gotextdiff v1.0.3 github.com/iancoleman/strcase v0.1.3 github.com/lucas-clemente/quic-go v0.19.3 github.com/marten-seemann/qtls-go1-15 v0.1.2 // indirect diff --git a/go.sum b/go.sum index 98d1875..4547262 100644 --- a/go.sum +++ b/go.sum @@ -218,6 +218,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174 h1:WlZsjVhE8Af9IcZDGgJGQpNflI3+MJSBhsgT5PCtzBQ= github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174/go.mod h1:DqJ97dSdRW1W22yXSB90986pcOyQ7r45iio1KN2ez1A= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/internal/engine/ooapi/README.md b/internal/engine/ooapi/README.md new file mode 100644 index 0000000..63de899 --- /dev/null +++ b/internal/engine/ooapi/README.md @@ -0,0 +1,5 @@ +# Package ./internal/engine/ooapi + +Automatically generated API clients for speaking with OONI servers. + +Please, run `go doc ./internal/engine/ooapi` to see API documentation. diff --git a/internal/engine/ooapi/apimodel/checkin.go b/internal/engine/ooapi/apimodel/checkin.go new file mode 100644 index 0000000..8df8085 --- /dev/null +++ b/internal/engine/ooapi/apimodel/checkin.go @@ -0,0 +1,47 @@ +package apimodel + +// CheckInRequestWebConnectivity contains WebConnectivity +// specific parameters to include into CheckInRequest +type CheckInRequestWebConnectivity struct { + CategoryCodes []string `json:"category_codes"` +} + +// CheckInRequest is the check-in API request +type CheckInRequest struct { + Charging bool `json:"charging"` + OnWiFi bool `json:"on_wifi"` + Platform string `json:"platform"` + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + RunType string `json:"run_type"` + SoftwareName string `json:"software_name"` + SoftwareVersion string `json:"software_version"` + WebConnectivity CheckInRequestWebConnectivity `json:"web_connectivity"` +} + +// CheckInResponseURLInfo contains information about an URL. +type CheckInResponseURLInfo struct { + CategoryCode string `json:"category_code"` + CountryCode string `json:"country_code"` + URL string `json:"url"` +} + +// CheckInResponseWebConnectivity contains WebConnectivity +// specific information of a CheckInResponse +type CheckInResponseWebConnectivity struct { + ReportID string `json:"report_id"` + URLs []CheckInResponseURLInfo `json:"urls"` +} + +// CheckInResponse is the check-in API response +type CheckInResponse struct { + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + Tests CheckInResponseTests `json:"tests"` + V int64 `json:"v"` +} + +// CheckInResponseTests contains configuration for tests +type CheckInResponseTests struct { + WebConnectivity CheckInResponseWebConnectivity `json:"web_connectivity"` +} diff --git a/internal/engine/ooapi/apimodel/checkreportid.go b/internal/engine/ooapi/apimodel/checkreportid.go new file mode 100644 index 0000000..068a047 --- /dev/null +++ b/internal/engine/ooapi/apimodel/checkreportid.go @@ -0,0 +1,13 @@ +package apimodel + +// CheckReportIDRequest is the CheckReportID request. +type CheckReportIDRequest struct { + ReportID string `query:"report_id" required:"true"` +} + +// CheckReportIDResponse is the CheckReportID response. +type CheckReportIDResponse struct { + Error string `json:"error"` + Found bool `json:"found"` + V int64 `json:"v"` +} diff --git a/internal/engine/ooapi/apimodel/doc.go b/internal/engine/ooapi/apimodel/doc.go new file mode 100644 index 0000000..da78f36 --- /dev/null +++ b/internal/engine/ooapi/apimodel/doc.go @@ -0,0 +1,22 @@ +// Package apimodel describes the data types used by OONI's API. +// +// If you edit this package to integrate the data model, remember to +// run `go generate ./...`. +// +// We annotate fields with tagging. When a field should be sent +// over as JSON, use the usual `json` tag. +// +// When a field needs to be sent using the query string, use +// the `query` tag instead. We limit what can be sent using the +// query string to int64, string, and bool. +// +// The `path` tag indicates that the URL path contains a +// template. We will replace the value of this field with +// the template. Note that the template should use the +// Go name of the field (e.g. `{{ .ReportID }}`) as opposed +// to the name in the tag, which is only used when we +// generate the API Swagger. +// +// The `required` tag indicates required fields. A required +// field cannot be empty (for the Go definition of empty). +package apimodel diff --git a/internal/engine/ooapi/apimodel/login.go b/internal/engine/ooapi/apimodel/login.go new file mode 100644 index 0000000..408b347 --- /dev/null +++ b/internal/engine/ooapi/apimodel/login.go @@ -0,0 +1,15 @@ +package apimodel + +import "time" + +// LoginRequest is the login API request +type LoginRequest struct { + ClientID string `json:"username"` + Password string `json:"password"` +} + +// LoginResponse is the login API response +type LoginResponse struct { + Expire time.Time `json:"expire"` + Token string `json:"token"` +} diff --git a/internal/engine/ooapi/apimodel/measurementmeta.go b/internal/engine/ooapi/apimodel/measurementmeta.go new file mode 100644 index 0000000..e97da69 --- /dev/null +++ b/internal/engine/ooapi/apimodel/measurementmeta.go @@ -0,0 +1,25 @@ +package apimodel + +// MeasurementMetaRequest is the MeasurementMeta Request. +type MeasurementMetaRequest struct { + ReportID string `query:"report_id" required:"true"` + Full bool `query:"full"` + Input string `query:"input"` +} + +// MeasurementMetaResponse is the MeasurementMeta Response. +type MeasurementMetaResponse struct { + Anomaly bool `json:"anomaly"` + CategoryCode string `json:"category_code"` + Confirmed bool `json:"confirmed"` + Failure bool `json:"failure"` + Input string `json:"input"` + MeasurementStartTime string `json:"measurement_start_time"` + ProbeASN int64 `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + RawMeasurement string `json:"raw_measurement"` + ReportID string `json:"report_id"` + Scores string `json:"scores"` + TestName string `json:"test_name"` + TestStartTime string `json:"test_start_time"` +} diff --git a/internal/engine/ooapi/apimodel/openreport.go b/internal/engine/ooapi/apimodel/openreport.go new file mode 100644 index 0000000..4432bba --- /dev/null +++ b/internal/engine/ooapi/apimodel/openreport.go @@ -0,0 +1,21 @@ +package apimodel + +// OpenReportRequest is the OpenReport request. +type OpenReportRequest struct { + DataFormatVersion string `json:"data_format_version"` + Format string `json:"format"` + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + SoftwareName string `json:"software_name"` + SoftwareVersion string `json:"software_version"` + TestName string `json:"test_name"` + TestStartTime string `json:"test_start_time"` + TestVersion string `json:"test_version"` +} + +// OpenReportResponse is the OpenReport response. +type OpenReportResponse struct { + BackendVersion string `json:"backend_version"` + ReportID string `json:"report_id"` + SupportedFormats []string `json:"supported_formats"` +} diff --git a/internal/engine/ooapi/apimodel/psiphonconfig.go b/internal/engine/ooapi/apimodel/psiphonconfig.go new file mode 100644 index 0000000..40f9726 --- /dev/null +++ b/internal/engine/ooapi/apimodel/psiphonconfig.go @@ -0,0 +1,7 @@ +package apimodel + +// PsiphonConfigRequest is the request for the PsiphonConfig API +type PsiphonConfigRequest struct{} + +// PsiphonConfigResponse is the response from the PsiphonConfig API +type PsiphonConfigResponse map[string]interface{} diff --git a/internal/engine/ooapi/apimodel/register.go b/internal/engine/ooapi/apimodel/register.go new file mode 100644 index 0000000..e167306 --- /dev/null +++ b/internal/engine/ooapi/apimodel/register.go @@ -0,0 +1,26 @@ +package apimodel + +// RegisterRequest is the request for the Register API. +type RegisterRequest struct { + // just password + Password string `json:"password"` + + // metadata + AvailableBandwidth string `json:"available_bandwidth,omitempty"` + DeviceToken string `json:"device_token,omitempty"` + Language string `json:"language,omitempty"` + NetworkType string `json:"network_type,omitempty"` + Platform string `json:"platform"` + ProbeASN string `json:"probe_asn"` + ProbeCC string `json:"probe_cc"` + ProbeFamily string `json:"probe_family,omitempty"` + ProbeTimezone string `json:"probe_timezone,omitempty"` + SoftwareName string `json:"software_name"` + SoftwareVersion string `json:"software_version"` + SupportedTests []string `json:"supported_tests"` +} + +// RegisterResponse is the response from the Register API. +type RegisterResponse struct { + ClientID string `json:"client_id"` +} diff --git a/internal/engine/ooapi/apimodel/submitmeasurement.go b/internal/engine/ooapi/apimodel/submitmeasurement.go new file mode 100644 index 0000000..da542d5 --- /dev/null +++ b/internal/engine/ooapi/apimodel/submitmeasurement.go @@ -0,0 +1,13 @@ +package apimodel + +// SubmitMeasurementRequest is the SubmitMeasurement request. +type SubmitMeasurementRequest struct { + ReportID string `path:"report_id"` + Format string `json:"format"` + Content interface{} `json:"content"` +} + +// SubmitMeasurementResponse is the SubmitMeasurement response. +type SubmitMeasurementResponse struct { + MeasurementUID string `json:"measurement_uid"` +} diff --git a/internal/engine/ooapi/apimodel/testhelpers.go b/internal/engine/ooapi/apimodel/testhelpers.go new file mode 100644 index 0000000..775a40d --- /dev/null +++ b/internal/engine/ooapi/apimodel/testhelpers.go @@ -0,0 +1,15 @@ +package apimodel + +// TestHelpersRequest is the TestHelpers request. +type TestHelpersRequest struct{} + +// TestHelpersResponse is the TestHelpers response. +type TestHelpersResponse map[string][]TestHelpersHelperInfo + +// TestHelpersHelperInfo is a single helper within the +// response returned by the TestHelpers API. +type TestHelpersHelperInfo struct { + Address string `json:"address"` + Type string `json:"type"` + Front string `json:"front,omitempty"` +} diff --git a/internal/engine/ooapi/apimodel/tortargets.go b/internal/engine/ooapi/apimodel/tortargets.go new file mode 100644 index 0000000..a4d7c74 --- /dev/null +++ b/internal/engine/ooapi/apimodel/tortargets.go @@ -0,0 +1,16 @@ +package apimodel + +// TorTargetsRequest is a request for the TorTargets API. +type TorTargetsRequest struct{} + +// TorTargetsResponse is the response from the TorTargets API. +type TorTargetsResponse map[string]TorTargetsTarget + +// TorTargetsTarget is a target for the tor experiment. +type TorTargetsTarget struct { + Address string `json:"address"` + Name string `json:"name"` + Params map[string][]string `json:"params"` + Protocol string `json:"protocol"` + Source string `json:"source"` +} diff --git a/internal/engine/ooapi/apimodel/urls.go b/internal/engine/ooapi/apimodel/urls.go new file mode 100644 index 0000000..dd9094f --- /dev/null +++ b/internal/engine/ooapi/apimodel/urls.go @@ -0,0 +1,26 @@ +package apimodel + +// URLsRequest is the URLs request. +type URLsRequest struct { + CategoryCodes string `query:"category_codes"` + CountryCode string `query:"country_code"` + Limit int64 `query:"limit"` +} + +// URLsResponse is the URLs response. +type URLsResponse struct { + Metadata URLsMetadata `json:"metadata"` + Results []URLsResponseURL `json:"results"` +} + +// URLsMetadata contains metadata in the URLs response. +type URLsMetadata struct { + Count int64 `json:"count"` +} + +// URLsResponseURL is a single URL in the URLs response. +type URLsResponseURL struct { + CategoryCode string `json:"category_code"` + CountryCode string `json:"country_code"` + URL string `json:"url"` +} diff --git a/internal/engine/ooapi/apis.go b/internal/engine/ooapi/apis.go new file mode 100644 index 0000000..3c079be --- /dev/null +++ b/internal/engine/ooapi/apis.go @@ -0,0 +1,607 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:50.431349269 +0100 CET m=+0.000196051 + +package ooapi + +//go:generate go run ./internal/generator -file apis.go + +import ( + "context" + "net/http" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// CheckReportIDAPI implements the CheckReportID API. +type CheckReportIDAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *CheckReportIDAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *CheckReportIDAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *CheckReportIDAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *CheckReportIDAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the CheckReportID API. +func (api *CheckReportIDAPI) Call(ctx context.Context, req *apimodel.CheckReportIDRequest) (*apimodel.CheckReportIDResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// CheckInAPI implements the CheckIn API. +type CheckInAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *CheckInAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *CheckInAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *CheckInAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *CheckInAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the CheckIn API. +func (api *CheckInAPI) Call(ctx context.Context, req *apimodel.CheckInRequest) (*apimodel.CheckInResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// LoginAPI implements the Login API. +type LoginAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *LoginAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *LoginAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *LoginAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *LoginAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the Login API. +func (api *LoginAPI) Call(ctx context.Context, req *apimodel.LoginRequest) (*apimodel.LoginResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// MeasurementMetaAPI implements the MeasurementMeta API. +type MeasurementMetaAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *MeasurementMetaAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *MeasurementMetaAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *MeasurementMetaAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *MeasurementMetaAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the MeasurementMeta API. +func (api *MeasurementMetaAPI) Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// RegisterAPI implements the Register API. +type RegisterAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *RegisterAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *RegisterAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *RegisterAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *RegisterAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the Register API. +func (api *RegisterAPI) Call(ctx context.Context, req *apimodel.RegisterRequest) (*apimodel.RegisterResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// TestHelpersAPI implements the TestHelpers API. +type TestHelpersAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *TestHelpersAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *TestHelpersAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *TestHelpersAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *TestHelpersAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the TestHelpers API. +func (api *TestHelpersAPI) Call(ctx context.Context, req *apimodel.TestHelpersRequest) (apimodel.TestHelpersResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// PsiphonConfigAPI implements the PsiphonConfig API. +type PsiphonConfigAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + Token string // mandatory + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +// WithToken returns a copy of the API where the +// value of the Token field is replaced with token. +func (api *PsiphonConfigAPI) WithToken(token string) PsiphonConfigCaller { + out := &PsiphonConfigAPI{} + out.BaseURL = api.BaseURL + out.HTTPClient = api.HTTPClient + out.JSONCodec = api.JSONCodec + out.RequestMaker = api.RequestMaker + out.UserAgent = api.UserAgent + out.Token = token + return out +} + +func (api *PsiphonConfigAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *PsiphonConfigAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *PsiphonConfigAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *PsiphonConfigAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the PsiphonConfig API. +func (api *PsiphonConfigAPI) Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.Token == "" { + return nil, ErrMissingToken + } + httpReq.Header.Add("Authorization", newAuthorizationHeader(api.Token)) + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// TorTargetsAPI implements the TorTargets API. +type TorTargetsAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + Token string // mandatory + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +// WithToken returns a copy of the API where the +// value of the Token field is replaced with token. +func (api *TorTargetsAPI) WithToken(token string) TorTargetsCaller { + out := &TorTargetsAPI{} + out.BaseURL = api.BaseURL + out.HTTPClient = api.HTTPClient + out.JSONCodec = api.JSONCodec + out.RequestMaker = api.RequestMaker + out.UserAgent = api.UserAgent + out.Token = token + return out +} + +func (api *TorTargetsAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *TorTargetsAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *TorTargetsAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *TorTargetsAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the TorTargets API. +func (api *TorTargetsAPI) Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.Token == "" { + return nil, ErrMissingToken + } + httpReq.Header.Add("Authorization", newAuthorizationHeader(api.Token)) + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// URLsAPI implements the URLs API. +type URLsAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *URLsAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *URLsAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *URLsAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *URLsAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the URLs API. +func (api *URLsAPI) Call(ctx context.Context, req *apimodel.URLsRequest) (*apimodel.URLsResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// OpenReportAPI implements the OpenReport API. +type OpenReportAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + UserAgent string // optional +} + +func (api *OpenReportAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *OpenReportAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *OpenReportAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *OpenReportAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the OpenReport API. +func (api *OpenReportAPI) Call(ctx context.Context, req *apimodel.OpenReportRequest) (*apimodel.OpenReportResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} + +// SubmitMeasurementAPI implements the SubmitMeasurement API. +type SubmitMeasurementAPI struct { + BaseURL string // optional + HTTPClient HTTPClient // optional + JSONCodec JSONCodec // optional + RequestMaker RequestMaker // optional + TemplateExecutor TemplateExecutor // optional + UserAgent string // optional +} + +func (api *SubmitMeasurementAPI) baseURL() string { + if api.BaseURL != "" { + return api.BaseURL + } + return "https://ps1.ooni.io" +} + +func (api *SubmitMeasurementAPI) requestMaker() RequestMaker { + if api.RequestMaker != nil { + return api.RequestMaker + } + return &defaultRequestMaker{} +} + +func (api *SubmitMeasurementAPI) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *SubmitMeasurementAPI) templateExecutor() TemplateExecutor { + if api.TemplateExecutor != nil { + return api.TemplateExecutor + } + return &defaultTemplateExecutor{} +} + +func (api *SubmitMeasurementAPI) httpClient() HTTPClient { + if api.HTTPClient != nil { + return api.HTTPClient + } + return http.DefaultClient +} + +// Call calls the SubmitMeasurement API. +func (api *SubmitMeasurementAPI) Call(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*apimodel.SubmitMeasurementResponse, error) { + httpReq, err := api.newRequest(ctx, req) + if err != nil { + return nil, err + } + httpReq.Header.Add("Accept", "application/json") + if api.UserAgent != "" { + httpReq.Header.Add("User-Agent", api.UserAgent) + } + return api.newResponse(api.httpClient().Do(httpReq)) +} diff --git a/internal/engine/ooapi/apis_test.go b/internal/engine/ooapi/apis_test.go new file mode 100644 index 0000000..ace748d --- /dev/null +++ b/internal/engine/ooapi/apis_test.go @@ -0,0 +1,2776 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:50.81792142 +0100 CET m=+0.000095792 + +package ooapi + +//go:generate go run ./internal/generator -file apis_test.go + +import ( + "context" + "encoding/json" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func TestCheckReportIDInvalidURL(t *testing.T) { + api := &CheckReportIDAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &CheckReportIDAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckReportIDWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleCheckReportID struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.CheckReportIDResponse + url *url.URL + userAgent string +} + +func (h *handleCheckReportID) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.CheckReportIDResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestCheckReportIDRoundTrip(t *testing.T) { + // setup + handler := &handleCheckReportID{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.CheckReportIDRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &CheckReportIDAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestCheckReportIDMandatoryFields(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 500, + }} + api := &CheckReportIDAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckReportIDRequest{} // deliberately empty + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrEmptyField) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInInvalidURL(t *testing.T) { + api := &CheckInAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &CheckInAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &CheckInAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &CheckInAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestCheckInWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &CheckInAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleCheckIn struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.CheckInResponse + url *url.URL + userAgent string +} + +func (h *handleCheckIn) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.CheckInResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestCheckInRoundTrip(t *testing.T) { + // setup + handler := &handleCheckIn{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.CheckInRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &CheckInAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.CheckInRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestLoginInvalidURL(t *testing.T) { + api := &LoginAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &LoginAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &LoginAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &LoginAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestLoginWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &LoginAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleLogin struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.LoginResponse + url *url.URL + userAgent string +} + +func (h *handleLogin) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.LoginResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestLoginRoundTrip(t *testing.T) { + // setup + handler := &handleLogin{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.LoginRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &LoginAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.LoginRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestMeasurementMetaInvalidURL(t *testing.T) { + api := &MeasurementMetaAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &MeasurementMetaAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestMeasurementMetaWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleMeasurementMeta struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.MeasurementMetaResponse + url *url.URL + userAgent string +} + +func (h *handleMeasurementMeta) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.MeasurementMetaResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestMeasurementMetaRoundTrip(t *testing.T) { + // setup + handler := &handleMeasurementMeta{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.MeasurementMetaRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &MeasurementMetaAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestMeasurementMetaMandatoryFields(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 500, + }} + api := &MeasurementMetaAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.MeasurementMetaRequest{} // deliberately empty + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrEmptyField) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterInvalidURL(t *testing.T) { + api := &RegisterAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &RegisterAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &RegisterAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &RegisterAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestRegisterWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &RegisterAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleRegister struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.RegisterResponse + url *url.URL + userAgent string +} + +func (h *handleRegister) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.RegisterResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestRegisterRoundTrip(t *testing.T) { + // setup + handler := &handleRegister{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.RegisterRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &RegisterAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.RegisterRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestTestHelpersInvalidURL(t *testing.T) { + api := &TestHelpersAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &TestHelpersAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTestHelpersWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &TestHelpersAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleTestHelpers struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp apimodel.TestHelpersResponse + url *url.URL + userAgent string +} + +func (h *handleTestHelpers) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out apimodel.TestHelpersResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestTestHelpersRoundTrip(t *testing.T) { + // setup + handler := &handleTestHelpers{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &TestHelpersAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestTestHelpersResponseLiteralNull(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`null`)}, + }} + api := &TestHelpersAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.TestHelpersRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrJSONLiteralNull) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigInvalidURL(t *testing.T) { + api := &PsiphonConfigAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithMissingToken(t *testing.T) { + api := &PsiphonConfigAPI{} // no token + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrMissingToken) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &PsiphonConfigAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestPsiphonConfigWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handlePsiphonConfig struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp apimodel.PsiphonConfigResponse + url *url.URL + userAgent string +} + +func (h *handlePsiphonConfig) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out apimodel.PsiphonConfigResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestPsiphonConfigRoundTrip(t *testing.T) { + // setup + handler := &handlePsiphonConfig{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &PsiphonConfigAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + ff.fill(&api.Token) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestPsiphonConfigResponseLiteralNull(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`null`)}, + }} + api := &PsiphonConfigAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.PsiphonConfigRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrJSONLiteralNull) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsInvalidURL(t *testing.T) { + api := &TorTargetsAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithMissingToken(t *testing.T) { + api := &TorTargetsAPI{} // no token + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrMissingToken) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &TorTargetsAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestTorTargetsWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &TorTargetsAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleTorTargets struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp apimodel.TorTargetsResponse + url *url.URL + userAgent string +} + +func (h *handleTorTargets) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out apimodel.TorTargetsResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestTorTargetsRoundTrip(t *testing.T) { + // setup + handler := &handleTorTargets{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &TorTargetsAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + ff.fill(&api.Token) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestTorTargetsResponseLiteralNull(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`null`)}, + }} + api := &TorTargetsAPI{ + HTTPClient: clnt, + Token: "fakeToken", + } + ctx := context.Background() + req := &apimodel.TorTargetsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrJSONLiteralNull) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsInvalidURL(t *testing.T) { + api := &URLsAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &URLsAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &URLsAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestURLsWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &URLsAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleURLs struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.URLsResponse + url *url.URL + userAgent string +} + +func (h *handleURLs) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.URLsResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestURLsRoundTrip(t *testing.T) { + // setup + handler := &handleURLs{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.URLsRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &URLsAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "GET" { + t.Fatal("invalid method") + } + // check the query + httpReq, err := api.newRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != "" { + t.Fatal(diff) + } +} + +func TestOpenReportInvalidURL(t *testing.T) { + api := &OpenReportAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &OpenReportAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &OpenReportAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &OpenReportAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestOpenReportWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &OpenReportAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleOpenReport struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.OpenReportResponse + url *url.URL + userAgent string +} + +func (h *handleOpenReport) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.OpenReportResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestOpenReportRoundTrip(t *testing.T) { + // setup + handler := &handleOpenReport{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.OpenReportRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &OpenReportAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.OpenReportRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestSubmitMeasurementInvalidURL(t *testing.T) { + api := &SubmitMeasurementAPI{ + BaseURL: "\t", // invalid + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithHTTPErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Err: errMocked} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementMarshalErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &SubmitMeasurementAPI{ + JSONCodec: &FakeCodec{EncodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithNewRequestErr(t *testing.T) { + errMocked := errors.New("mocked error") + api := &SubmitMeasurementAPI{ + RequestMaker: &FakeRequestMaker{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWith401(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrUnauthorized) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWith400(t *testing.T) { + clnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithResponseBodyReadErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Err: errMocked}, + }} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +func TestSubmitMeasurementWithUnmarshalFailure(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 200, + Body: &FakeBody{Data: []byte(`{}`)}, + }} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} + +type handleSubmitMeasurement struct { + accept string + body []byte + contentType string + count int32 + method string + mu sync.Mutex + resp *apimodel.SubmitMeasurementResponse + url *url.URL + userAgent string +} + +func (h *handleSubmitMeasurement) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer h.mu.Unlock() + h.mu.Lock() + if h.count > 0 { + w.WriteHeader(400) + return + } + h.count++ + if r.Body != nil { + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + h.body = data + } + h.method = r.Method + h.url = r.URL + h.accept = r.Header.Get("Accept") + h.contentType = r.Header.Get("Content-Type") + h.userAgent = r.Header.Get("User-Agent") + var out *apimodel.SubmitMeasurementResponse + ff := fakeFill{} + ff.fill(&out) + h.resp = out + data, err := json.Marshal(out) + if err != nil { + w.WriteHeader(400) + return + } + w.Write(data) +} + +func TestSubmitMeasurementRoundTrip(t *testing.T) { + // setup + handler := &handleSubmitMeasurement{} + srvr := httptest.NewServer(handler) + defer srvr.Close() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(&req) + api := &SubmitMeasurementAPI{BaseURL: srvr.URL} + ff.fill(&api.UserAgent) + // issue request + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response here") + } + // compare our response and server's one + if diff := cmp.Diff(handler.resp, resp); diff != "" { + t.Fatal(diff) + } + // check whether headers are OK + if handler.accept != "application/json" { + t.Fatal("invalid accept header") + } + if handler.userAgent != api.UserAgent { + t.Fatal("invalid user-agent header") + } + // check whether the method is OK + if handler.method != "POST" { + t.Fatal("invalid method") + } + // check the body + if handler.contentType != "application/json" { + t.Fatal("invalid content-type header") + } + got := &apimodel.SubmitMeasurementRequest{} + if err := json.Unmarshal(handler.body, &got); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(req, got); diff != "" { + t.Fatal(diff) + } +} + +func TestSubmitMeasurementTemplateErr(t *testing.T) { + errMocked := errors.New("mocked error") + clnt := &FakeHTTPClient{Resp: &http.Response{ + StatusCode: 500, + }} + api := &SubmitMeasurementAPI{ + HTTPClient: clnt, + TemplateExecutor: &FakeTemplateExecutor{Err: errMocked}, + } + ctx := context.Background() + req := &apimodel.SubmitMeasurementRequest{} + ff := &fakeFill{} + ff.fill(req) + resp, err := api.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } +} diff --git a/internal/engine/ooapi/caching.go b/internal/engine/ooapi/caching.go new file mode 100644 index 0000000..2b2e333 --- /dev/null +++ b/internal/engine/ooapi/caching.go @@ -0,0 +1,98 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:51.194159684 +0100 CET m=+0.000175181 + +package ooapi + +//go:generate go run ./internal/generator -file caching.go + +import ( + "context" + "reflect" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// MeasurementMetaCache implements caching for MeasurementMetaAPI. +type MeasurementMetaCache struct { + API MeasurementMetaCaller // mandatory + GobCodec GobCodec // optional + KVStore KVStore // mandatory +} + +type cacheEntryForMeasurementMeta struct { + Req *apimodel.MeasurementMetaRequest + Resp *apimodel.MeasurementMetaResponse +} + +// Call calls the API and implements caching. +func (c *MeasurementMetaCache) Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + if resp, _ := c.readcache(req); resp != nil { + return resp, nil + } + resp, err := c.API.Call(ctx, req) + if err != nil { + return nil, err + } + if err := c.writecache(req, resp); err != nil { + return nil, err + } + return resp, nil +} + +func (c *MeasurementMetaCache) gobCodec() GobCodec { + if c.GobCodec != nil { + return c.GobCodec + } + return &defaultGobCodec{} +} + +func (c *MeasurementMetaCache) getcache() ([]cacheEntryForMeasurementMeta, error) { + data, err := c.KVStore.Get("MeasurementMeta.cache") + if err != nil { + return nil, err + } + var out []cacheEntryForMeasurementMeta + if err := c.gobCodec().Decode(data, &out); err != nil { + return nil, err + } + return out, nil +} + +func (c *MeasurementMetaCache) setcache(in []cacheEntryForMeasurementMeta) error { + data, err := c.gobCodec().Encode(in) + if err != nil { + return err + } + return c.KVStore.Set("MeasurementMeta.cache", data) +} + +func (c *MeasurementMetaCache) readcache(req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + cache, err := c.getcache() + if err != nil { + return nil, err + } + for _, cur := range cache { + if reflect.DeepEqual(req, cur.Req) { + return cur.Resp, nil + } + } + return nil, errCacheNotFound +} + +func (c *MeasurementMetaCache) writecache(req *apimodel.MeasurementMetaRequest, resp *apimodel.MeasurementMetaResponse) error { + cache, _ := c.getcache() + out := []cacheEntryForMeasurementMeta{{Req: req, Resp: resp}} + const toomany = 64 + for idx, cur := range cache { + if reflect.DeepEqual(req, cur.Req) { + continue // we already updated the cache + } + if idx > toomany { + break + } + out = append(out, cur) + } + return c.setcache(out) +} + +var _ MeasurementMetaCaller = &MeasurementMetaCache{} diff --git a/internal/engine/ooapi/caching_test.go b/internal/engine/ooapi/caching_test.go new file mode 100644 index 0000000..592d69c --- /dev/null +++ b/internal/engine/ooapi/caching_test.go @@ -0,0 +1,222 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:51.49660021 +0100 CET m=+0.000217672 + +package ooapi + +//go:generate go run ./internal/generator -file caching_test.go + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func TestCacheMeasurementMetaAPISuccess(t *testing.T) { + ff := &fakeFill{} + var expect *apimodel.MeasurementMetaResponse + ff.fill(&expect) + cache := &MeasurementMetaCache{ + API: &FakeMeasurementMetaAPI{ + Response: expect, + }, + KVStore: &memkvstore{}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + resp, err := cache.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } +} + +func TestCacheMeasurementMetaAPIWriteCacheError(t *testing.T) { + errMocked := errors.New("mocked error") + ff := &fakeFill{} + var expect *apimodel.MeasurementMetaResponse + ff.fill(&expect) + cache := &MeasurementMetaCache{ + API: &FakeMeasurementMetaAPI{ + Response: expect, + }, + KVStore: &FakeKVStore{SetError: errMocked}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + resp, err := cache.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestCacheMeasurementMetaAPIFailureWithNoCache(t *testing.T) { + errMocked := errors.New("mocked error") + ff := &fakeFill{} + cache := &MeasurementMetaCache{ + API: &FakeMeasurementMetaAPI{ + Err: errMocked, + }, + KVStore: &memkvstore{}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + resp, err := cache.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } +} + +func TestCacheMeasurementMetaAPIFailureWithPreviousCache(t *testing.T) { + ff := &fakeFill{} + var expect *apimodel.MeasurementMetaResponse + ff.fill(&expect) + fakeapi := &FakeMeasurementMetaAPI{ + Response: expect, + } + cache := &MeasurementMetaCache{ + API: fakeapi, + KVStore: &memkvstore{}, + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + ctx := context.Background() + // first pass with no error at all + // use a separate scope to be sure we avoid mistakes + { + resp, err := cache.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + } + // second pass with failure + errMocked := errors.New("mocked error") + fakeapi.Err = errMocked + fakeapi.Response = nil + resp2, err := cache.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp2 == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp2); diff != "" { + t.Fatal(diff) + } +} + +func TestCacheMeasurementMetaAPISetcacheWithEncodeError(t *testing.T) { + ff := &fakeFill{} + errMocked := errors.New("mocked error") + var in []cacheEntryForMeasurementMeta + ff.fill(&in) + cache := &MeasurementMetaCache{ + GobCodec: &FakeCodec{EncodeErr: errMocked}, + } + err := cache.setcache(in) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } +} + +func TestCacheMeasurementMetaAPIReadCacheNotFound(t *testing.T) { + ff := &fakeFill{} + var incache []cacheEntryForMeasurementMeta + ff.fill(&incache) + cache := &MeasurementMetaCache{ + KVStore: &memkvstore{}, + } + err := cache.setcache(incache) + if err != nil { + t.Fatal(err) + } + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + out, err := cache.readcache(req) + if !errors.Is(err, errCacheNotFound) { + t.Fatal("not the error we expected", err) + } + if out != nil { + t.Fatal("expected nil here") + } +} + +func TestCacheMeasurementMetaAPIWriteCacheDuplicate(t *testing.T) { + ff := &fakeFill{} + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + var resp1 *apimodel.MeasurementMetaResponse + ff.fill(&resp1) + var resp2 *apimodel.MeasurementMetaResponse + ff.fill(&resp2) + cache := &MeasurementMetaCache{ + KVStore: &memkvstore{}, + } + err := cache.writecache(req, resp1) + if err != nil { + t.Fatal(err) + } + err = cache.writecache(req, resp2) + if err != nil { + t.Fatal(err) + } + out, err := cache.readcache(req) + if err != nil { + t.Fatal(err) + } + if out == nil { + t.Fatal("expected non-nil here") + } + if diff := cmp.Diff(resp2, out); diff != "" { + t.Fatal(diff) + } +} + +func TestCacheMeasurementMetaAPICacheSizeLimited(t *testing.T) { + ff := &fakeFill{} + cache := &MeasurementMetaCache{ + KVStore: &memkvstore{}, + } + var prev int + for { + var req *apimodel.MeasurementMetaRequest + ff.fill(&req) + var resp *apimodel.MeasurementMetaResponse + ff.fill(&resp) + err := cache.writecache(req, resp) + if err != nil { + t.Fatal(err) + } + out, err := cache.getcache() + if err != nil { + t.Fatal(err) + } + if len(out) > prev { + prev = len(out) + continue + } + break + } +} diff --git a/internal/engine/ooapi/callers.go b/internal/engine/ooapi/callers.go new file mode 100644 index 0000000..89a1b94 --- /dev/null +++ b/internal/engine/ooapi/callers.go @@ -0,0 +1,78 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:51.773813223 +0100 CET m=+0.000114768 + +package ooapi + +//go:generate go run ./internal/generator -file callers.go + +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// CheckReportIDCaller represents any type exposing a method +// like CheckReportIDAPI.Call. +type CheckReportIDCaller interface { + Call(ctx context.Context, req *apimodel.CheckReportIDRequest) (*apimodel.CheckReportIDResponse, error) +} + +// CheckInCaller represents any type exposing a method +// like CheckInAPI.Call. +type CheckInCaller interface { + Call(ctx context.Context, req *apimodel.CheckInRequest) (*apimodel.CheckInResponse, error) +} + +// LoginCaller represents any type exposing a method +// like LoginAPI.Call. +type LoginCaller interface { + Call(ctx context.Context, req *apimodel.LoginRequest) (*apimodel.LoginResponse, error) +} + +// MeasurementMetaCaller represents any type exposing a method +// like MeasurementMetaAPI.Call. +type MeasurementMetaCaller interface { + Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) +} + +// RegisterCaller represents any type exposing a method +// like RegisterAPI.Call. +type RegisterCaller interface { + Call(ctx context.Context, req *apimodel.RegisterRequest) (*apimodel.RegisterResponse, error) +} + +// TestHelpersCaller represents any type exposing a method +// like TestHelpersAPI.Call. +type TestHelpersCaller interface { + Call(ctx context.Context, req *apimodel.TestHelpersRequest) (apimodel.TestHelpersResponse, error) +} + +// PsiphonConfigCaller represents any type exposing a method +// like PsiphonConfigAPI.Call. +type PsiphonConfigCaller interface { + Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) +} + +// TorTargetsCaller represents any type exposing a method +// like TorTargetsAPI.Call. +type TorTargetsCaller interface { + Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) +} + +// URLsCaller represents any type exposing a method +// like URLsAPI.Call. +type URLsCaller interface { + Call(ctx context.Context, req *apimodel.URLsRequest) (*apimodel.URLsResponse, error) +} + +// OpenReportCaller represents any type exposing a method +// like OpenReportAPI.Call. +type OpenReportCaller interface { + Call(ctx context.Context, req *apimodel.OpenReportRequest) (*apimodel.OpenReportResponse, error) +} + +// SubmitMeasurementCaller represents any type exposing a method +// like SubmitMeasurementAPI.Call. +type SubmitMeasurementCaller interface { + Call(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*apimodel.SubmitMeasurementResponse, error) +} diff --git a/internal/engine/ooapi/cloners.go b/internal/engine/ooapi/cloners.go new file mode 100644 index 0000000..d461e5a --- /dev/null +++ b/internal/engine/ooapi/cloners.go @@ -0,0 +1,18 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.108352268 +0100 CET m=+0.000275862 + +package ooapi + +//go:generate go run ./internal/generator -file cloners.go + +// PsiphonConfigCaller represents any type exposing a method +// like PsiphonConfigAPI.WithToken. +type PsiphonConfigCloner interface { + WithToken(token string) PsiphonConfigCaller +} + +// TorTargetsCaller represents any type exposing a method +// like TorTargetsAPI.WithToken. +type TorTargetsCloner interface { + WithToken(token string) TorTargetsCaller +} diff --git a/internal/engine/ooapi/default.go b/internal/engine/ooapi/default.go new file mode 100644 index 0000000..9917760 --- /dev/null +++ b/internal/engine/ooapi/default.go @@ -0,0 +1,57 @@ +package ooapi + +import ( + "bytes" + "context" + "encoding/gob" + "encoding/json" + "io" + "net/http" + "strings" + "text/template" +) + +type defaultRequestMaker struct{} + +func (*defaultRequestMaker) NewRequest( + ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, method, URL, body) +} + +type defaultJSONCodec struct{} + +func (*defaultJSONCodec) Encode(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +func (*defaultJSONCodec) Decode(b []byte, v interface{}) error { + return json.Unmarshal(b, v) +} + +type defaultTemplateExecutor struct{} + +func (*defaultTemplateExecutor) Execute(tmpl string, v interface{}) (string, error) { + to, err := template.New("t").Parse(tmpl) + if err != nil { + return "", err + } + var sb strings.Builder + if err := to.Execute(&sb, v); err != nil { + return "", err + } + return sb.String(), nil +} + +type defaultGobCodec struct{} + +func (*defaultGobCodec) Encode(v interface{}) ([]byte, error) { + var bb bytes.Buffer + if err := gob.NewEncoder(&bb).Encode(v); err != nil { + return nil, err + } + return bb.Bytes(), nil +} + +func (*defaultGobCodec) Decode(b []byte, v interface{}) error { + return gob.NewDecoder(bytes.NewReader(b)).Decode(v) +} diff --git a/internal/engine/ooapi/default_test.go b/internal/engine/ooapi/default_test.go new file mode 100644 index 0000000..930181f --- /dev/null +++ b/internal/engine/ooapi/default_test.go @@ -0,0 +1,41 @@ +package ooapi + +import ( + "strings" + "testing" +) + +func TestDefaultTemplateExecutorParseError(t *testing.T) { + te := &defaultTemplateExecutor{} + out, err := te.Execute("{{ .Foo", nil) + if err == nil || !strings.HasSuffix(err.Error(), "unclosed action") { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string") + } +} + +func TestDefaultTemplateExecutorExecError(t *testing.T) { + te := &defaultTemplateExecutor{} + arg := make(chan interface{}) + out, err := te.Execute("{{ .Foo }}", arg) + if err == nil || !strings.Contains(err.Error(), `can't evaluate field Foo`) { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string") + } +} + +func TestDefaultGobCodecEncodeError(t *testing.T) { + codec := &defaultGobCodec{} + arg := make(chan interface{}) + data, err := codec.Encode(arg) + if err == nil || !strings.Contains(err.Error(), "can't handle type") { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data") + } +} diff --git a/internal/engine/ooapi/dependencies.go b/internal/engine/ooapi/dependencies.go new file mode 100644 index 0000000..0b1bbf1 --- /dev/null +++ b/internal/engine/ooapi/dependencies.go @@ -0,0 +1,54 @@ +package ooapi + +import ( + "context" + "io" + "net/http" +) + +// JSONCodec is a JSON encoder and decoder. +type JSONCodec interface { + // Encode encodes v as a serialized JSON byte slice. + Encode(v interface{}) ([]byte, error) + + // Decode decodes the serialized JSON byte slice into v. + Decode(b []byte, v interface{}) error +} + +// RequestMaker makes an HTTP request. +type RequestMaker interface { + // NewRequest creates a new HTTP request. + NewRequest(ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) +} + +// TemplateExecutor parses and executes a text template. +type TemplateExecutor interface { + // Execute takes in input a template string and some piece of data. It + // returns either a string where template parameters have been replaced, + // on success, or an error, on failure. + Execute(tmpl string, v interface{}) (string, error) +} + +// HTTPClient is the interface of a generic HTTP client. +type HTTPClient interface { + // Do should work like http.Client.Do. + Do(req *http.Request) (*http.Response, error) +} + +// GobCodec is a Gob encoder and decoder. +type GobCodec interface { + // Encode encodes v as a serialized gob byte slice. + Encode(v interface{}) ([]byte, error) + + // Decode decodes the serialized gob byte slice into v. + Decode(b []byte, v interface{}) error +} + +// KVStore is a key-value store. +type KVStore interface { + // Get gets a value from the key-value store. + Get(key string) ([]byte, error) + + // Set stores a value into the key-value store. + Set(key string, value []byte) error +} diff --git a/internal/engine/ooapi/doc.go b/internal/engine/ooapi/doc.go new file mode 100644 index 0000000..1f8f6fd --- /dev/null +++ b/internal/engine/ooapi/doc.go @@ -0,0 +1,163 @@ +// Package ooapi contains clients for the OONI API. We +// automatically generate the code in this package from +// the apimodel and internal/generator packages. For +// each OONI API, we define up to three data structures: +// +// 1. a data structure representing the API; +// +// 2. a caching data structure, if the API +// supports caching; +// +// 3. an auto-login data structure, if the API +// requires login. +// +// The rest of this documentation page describes these +// three data structures and the design and architecture +// of this package. Refer to subpackages for more +// information on how to specify an API. +// +// API data structure +// +// For each API, this package defines a data structure +// representing the API. For example, for the TorTargets API, +// we define the TorTargetsAPI data structure. +// +// The API data structure defines a method named Call that +// allows calling the specified API. Call takes as arguments +// a context and the request for the API and returns the +// API response or an error. +// +// Request and response messages live inside the apimodel +// subpackage. We name them after the API. Thus, for +// the TorTargets API, the request is TorTargetsRequest, +// and the response is TorTargetsResponse. +// +// API data structures are cheap to create and do not +// mutate. They should be used in place and then forgotten +// off once the API call is complete. +// +// Unless explicitly indicated, the zero value of every +// API data structure is a valid API data structure. +// +// In terms of dependencies, APIs certainly need an http.Client +// to communicate with the OONI backend. To represent such a +// client, we use the HTTPClient interface. If you do not tell +// an API which http.Client to use, we will default to the +// standard library's http.DefaultClient. +// +// An API also depends on a JSONCodec. That is, on a data +// structures that encodes data to/from JSON. If you do not +// specify explicitly a JSONCodec, we will use the Go +// standard library's JSON implementation. +// +// When an API requires authentication, you need to tell +// it which authentication token to use. This gives you +// control over obtaining the token and is the low-level +// way of interacting with authenticated APIs. We recommend +// using the auto-login wrappers instead (see below). +// +// Authenticated APIs also define the WithToken method. This +// method takes as argument a token and returns a copy of the +// original API using the given token. We use this method +// to implement auto-login wrappers. +// +// For each API, we also define two interfaces: +// +// 1. the Caller interface represents the possibility of +// calling a specific API with the correct arguments; +// +// 2. the Cloner interface represents the possibility of +// calling WithToken on the given API. +// +// They abstract the interaction between the API type and +// its caching and auto-login wrappers. +// +// Caching +// +// If an API supports caching, we define a type whose name +// ends in Cache. The TorTargets API cache, for example, +// is TorTargetsCache. These caching types wrap the API type +// and provide the caching functionality. +// +// Because the cache needs to read from and write to the +// disk, a caching type needs a KVStore. A KVStore is +// an interface that allow you to bind a specific key to +// a given blob of bytes and to retrieve such bytes later. +// +// Caches use the gob data format from the Go standard +// library (`encoding/gob`). We abstract this dependency +// using the GobCodec interface. By default, when you +// do not specify a GobCodec we use the implementation +// of gob from the Go standard library. +// +// See the example describing caching for more information +// on how to use caching. +// +// Auto-login +// +// If an API supports auto-login, we define a type whose +// name ends with WithLogin. The TorTargets auto-login struct, +// for example, is called TorTargetsAPIWithLogin. +// +// Auto-login wrappers need to store persistent data. We +// use a KVStore for that (see above). We encode login data +// using JSON. To this end, we use a JSONCodec (also +// described above). +// +// See the example describing auto-login for more information +// on how to use auto-login. +// +// Design +// +// Most of the code in this package is auto-generated from the +// data model in ./apimodel and the definition of APIs provided +// by ./internal/generator/spec.go. +// +// We keep the generated files up-to-date by running +// +// go generate ./... +// +// We have tests that ensure that the definition of the API +// used here is reasonably close to the server's one. +// +// Testing +// +// The following command +// +// go test ./... +// +// will, among other things, ensure that the our API spec +// is consistent with the server's one. Running +// +// go test -short ./... +// +// will exclude most (slow) integration tests. +// +// Architecture +// +// The ./apimodel package contains the definition of request +// and response messages. We rely on tagging to specify how +// we should encode and decode messages. +// +// The ./internal/generator contains code to generate most +// code in this package. In particular, the spec.go file is +// the specification of the APIs. +// +// Notable generated files +// +// - apis.go: contains APIs (e.g., TorTargetsAPI); +// +// - caching.go: contains caching wrappers for every API +// that declares that it needs a cache (e.g., TorTargetsCache); +// +// - callers.go: contains Callers; +// +// - cloners.go: contains the Cloners; +// +// - login.go: contains auto-login wrappers (e.g., +// TorTargetsAPIWithLogin); +// +// - requests.go: contains code to generate http.Requests. +// +// - responses.go: code to parse http.Responses. +package ooapi diff --git a/internal/engine/ooapi/errors.go b/internal/engine/ooapi/errors.go new file mode 100644 index 0000000..35ed7cc --- /dev/null +++ b/internal/engine/ooapi/errors.go @@ -0,0 +1,14 @@ +package ooapi + +import "errors" + +// Errors defined by this package. In addition to these errors, this +// package may of course return any other stdlib specific error. +var ( + ErrEmptyField = errors.New("apiclient: empty field") + ErrHTTPFailure = errors.New("apiclient: http request failed") + ErrJSONLiteralNull = errors.New("apiclient: server returned us a literal null") + ErrMissingToken = errors.New("apiclient: missing auth token") + ErrUnauthorized = errors.New("apiclient: not authorized") + errCacheNotFound = errors.New("apiclient: not found in cache") +) diff --git a/internal/engine/ooapi/fake_test.go b/internal/engine/ooapi/fake_test.go new file mode 100644 index 0000000..9b0b708 --- /dev/null +++ b/internal/engine/ooapi/fake_test.go @@ -0,0 +1,96 @@ +package ooapi + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "time" +) + +type FakeCodec struct { + DecodeErr error + EncodeData []byte + EncodeErr error +} + +func (mc *FakeCodec) Encode(v interface{}) ([]byte, error) { + return mc.EncodeData, mc.EncodeErr +} + +func (mc *FakeCodec) Decode(b []byte, v interface{}) error { + return mc.DecodeErr +} + +type FakeHTTPClient struct { + Err error + Resp *http.Response +} + +func (c *FakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + time.Sleep(10 * time.Microsecond) + if req.Body != nil { + _, _ = ioutil.ReadAll(req.Body) + req.Body.Close() + } + if c.Err != nil { + return nil, c.Err + } + c.Resp.Request = req // non thread safe but it doesn't matter + return c.Resp, nil +} + +type FakeBody struct { + Data []byte + Err error +} + +func (fb *FakeBody) Read(p []byte) (int, error) { + time.Sleep(10 * time.Microsecond) + if fb.Err != nil { + return 0, fb.Err + } + if len(fb.Data) <= 0 { + return 0, io.EOF + } + n := copy(p, fb.Data) + fb.Data = fb.Data[n:] + return n, nil +} + +func (fb *FakeBody) Close() error { + return nil +} + +type FakeRequestMaker struct { + Req *http.Request + Err error +} + +func (frm *FakeRequestMaker) NewRequest( + ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) { + return frm.Req, frm.Err +} + +type FakeTemplateExecutor struct { + Out string + Err error +} + +func (fte *FakeTemplateExecutor) Execute(tmpl string, v interface{}) (string, error) { + return fte.Out, fte.Err +} + +type FakeKVStore struct { + SetError error + GetData []byte + GetError error +} + +func (fs *FakeKVStore) Get(key string) ([]byte, error) { + return fs.GetData, fs.GetError +} + +func (fs *FakeKVStore) Set(key string, value []byte) error { + return fs.SetError +} diff --git a/internal/engine/ooapi/fakeapi_test.go b/internal/engine/ooapi/fakeapi_test.go new file mode 100644 index 0000000..3084483 --- /dev/null +++ b/internal/engine/ooapi/fakeapi_test.go @@ -0,0 +1,190 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.357709034 +0100 CET m=+0.000208565 + +package ooapi + +//go:generate go run ./internal/generator -file fakeapi_test.go + +import ( + "context" + "sync/atomic" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +type FakeCheckReportIDAPI struct { + Err error + Response *apimodel.CheckReportIDResponse + CountCall int32 +} + +func (fapi *FakeCheckReportIDAPI) Call(ctx context.Context, req *apimodel.CheckReportIDRequest) (*apimodel.CheckReportIDResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ CheckReportIDCaller = &FakeCheckReportIDAPI{} +) + +type FakeCheckInAPI struct { + Err error + Response *apimodel.CheckInResponse + CountCall int32 +} + +func (fapi *FakeCheckInAPI) Call(ctx context.Context, req *apimodel.CheckInRequest) (*apimodel.CheckInResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ CheckInCaller = &FakeCheckInAPI{} +) + +type FakeLoginAPI struct { + Err error + Response *apimodel.LoginResponse + CountCall int32 +} + +func (fapi *FakeLoginAPI) Call(ctx context.Context, req *apimodel.LoginRequest) (*apimodel.LoginResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ LoginCaller = &FakeLoginAPI{} +) + +type FakeMeasurementMetaAPI struct { + Err error + Response *apimodel.MeasurementMetaResponse + CountCall int32 +} + +func (fapi *FakeMeasurementMetaAPI) Call(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*apimodel.MeasurementMetaResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ MeasurementMetaCaller = &FakeMeasurementMetaAPI{} +) + +type FakeRegisterAPI struct { + Err error + Response *apimodel.RegisterResponse + CountCall int32 +} + +func (fapi *FakeRegisterAPI) Call(ctx context.Context, req *apimodel.RegisterRequest) (*apimodel.RegisterResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ RegisterCaller = &FakeRegisterAPI{} +) + +type FakeTestHelpersAPI struct { + Err error + Response apimodel.TestHelpersResponse + CountCall int32 +} + +func (fapi *FakeTestHelpersAPI) Call(ctx context.Context, req *apimodel.TestHelpersRequest) (apimodel.TestHelpersResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ TestHelpersCaller = &FakeTestHelpersAPI{} +) + +type FakePsiphonConfigAPI struct { + WithResult PsiphonConfigCaller + Err error + Response apimodel.PsiphonConfigResponse + CountCall int32 +} + +func (fapi *FakePsiphonConfigAPI) Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +func (fapi *FakePsiphonConfigAPI) WithToken(token string) PsiphonConfigCaller { + return fapi.WithResult +} + +var ( + _ PsiphonConfigCaller = &FakePsiphonConfigAPI{} + _ PsiphonConfigCloner = &FakePsiphonConfigAPI{} +) + +type FakeTorTargetsAPI struct { + WithResult TorTargetsCaller + Err error + Response apimodel.TorTargetsResponse + CountCall int32 +} + +func (fapi *FakeTorTargetsAPI) Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +func (fapi *FakeTorTargetsAPI) WithToken(token string) TorTargetsCaller { + return fapi.WithResult +} + +var ( + _ TorTargetsCaller = &FakeTorTargetsAPI{} + _ TorTargetsCloner = &FakeTorTargetsAPI{} +) + +type FakeURLsAPI struct { + Err error + Response *apimodel.URLsResponse + CountCall int32 +} + +func (fapi *FakeURLsAPI) Call(ctx context.Context, req *apimodel.URLsRequest) (*apimodel.URLsResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ URLsCaller = &FakeURLsAPI{} +) + +type FakeOpenReportAPI struct { + Err error + Response *apimodel.OpenReportResponse + CountCall int32 +} + +func (fapi *FakeOpenReportAPI) Call(ctx context.Context, req *apimodel.OpenReportRequest) (*apimodel.OpenReportResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ OpenReportCaller = &FakeOpenReportAPI{} +) + +type FakeSubmitMeasurementAPI struct { + Err error + Response *apimodel.SubmitMeasurementResponse + CountCall int32 +} + +func (fapi *FakeSubmitMeasurementAPI) Call(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*apimodel.SubmitMeasurementResponse, error) { + atomic.AddInt32(&fapi.CountCall, 1) + return fapi.Response, fapi.Err +} + +var ( + _ SubmitMeasurementCaller = &FakeSubmitMeasurementAPI{} +) diff --git a/internal/engine/ooapi/fakefill_test.go b/internal/engine/ooapi/fakefill_test.go new file mode 100644 index 0000000..5ee7776 --- /dev/null +++ b/internal/engine/ooapi/fakefill_test.go @@ -0,0 +1,146 @@ +package ooapi + +import ( + "math/rand" + "reflect" + "sync" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// fakeFill fills specific data structures with random data. The only +// exception to this behaviour is time.Time, which is instead filled +// with the current time plus a small random number of seconds. +// +// We use this implementation to initialize data in our model. The code +// has been written with that in mind. It will require some hammering in +// case we extend the model with new field types. +type fakeFill struct { + mu sync.Mutex + now func() time.Time + rnd *rand.Rand +} + +func (ff *fakeFill) getRandLocked() *rand.Rand { + if ff.rnd == nil { + now := time.Now + if ff.now != nil { + now = ff.now + } + ff.rnd = rand.New(rand.NewSource(now().UnixNano())) + } + return ff.rnd +} + +func (ff *fakeFill) getRandomString() string { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + n := rnd.Intn(63) + 1 + // See https://stackoverflow.com/a/31832326 + var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rnd.Intn(len(letterRunes))] + } + return string(b) +} + +func (ff *fakeFill) getRandomInt64() int64 { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + return rnd.Int63() +} + +func (ff *fakeFill) getRandomBool() bool { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + return rnd.Float64() >= 0.5 +} + +func (ff *fakeFill) getRandomSmallPositiveInt() int { + defer ff.mu.Unlock() + ff.mu.Lock() + rnd := ff.getRandLocked() + return int(rnd.Int63n(8)) + 1 // safe cast +} + +func (ff *fakeFill) doFill(v reflect.Value) { + for v.Type().Kind() == reflect.Ptr { + if v.IsNil() { + // if the pointer is nil, allocate an element + v.Set(reflect.New(v.Type().Elem())) + } + // switch to the element + v = v.Elem() + } + switch v.Type().Kind() { + case reflect.String: + v.SetString(ff.getRandomString()) + case reflect.Int64: + v.SetInt(ff.getRandomInt64()) + case reflect.Bool: + v.SetBool(ff.getRandomBool()) + case reflect.Struct: + if v.Type().String() == "time.Time" { + // Implementation note: we treat the time specially + // and we avoid attempting to set its fields. + v.Set(reflect.ValueOf(time.Now().Add( + time.Duration(ff.getRandomSmallPositiveInt()) * time.Second))) + return + } + for idx := 0; idx < v.NumField(); idx++ { + ff.doFill(v.Field(idx)) // visit all fields + } + case reflect.Slice: + kind := v.Type().Elem() + total := ff.getRandomSmallPositiveInt() + for idx := 0; idx < total; idx++ { + value := reflect.New(kind) // make a new element + ff.doFill(value) + v.Set(reflect.Append(v, value.Elem())) // append to slice + } + case reflect.Map: + if v.Type().Key().Kind() != reflect.String { + return // not supported + } + v.Set(reflect.MakeMap(v.Type())) // we need to init the map + total := ff.getRandomSmallPositiveInt() + kind := v.Type().Elem() + for idx := 0; idx < total; idx++ { + value := reflect.New(kind) + ff.doFill(value) + v.SetMapIndex(reflect.ValueOf(ff.getRandomString()), value.Elem()) + } + } +} + +// fill fills in with random data. +func (ff *fakeFill) fill(in interface{}) { + ff.doFill(reflect.ValueOf(in)) +} + +func TestFakeFillAllocatesIntoAPointerToPointer(t *testing.T) { + var req *apimodel.URLsRequest + ff := &fakeFill{} + ff.fill(&req) + if req == nil { + t.Fatal("we expected non nil here") + } +} + +func TestFakeFillAllocatesIntoAMapLike(t *testing.T) { + var resp apimodel.TorTargetsResponse + ff := &fakeFill{} + ff.fill(&resp) + if resp == nil { + t.Fatal("we expected non nil here") + } + if len(resp) < 1 { + t.Fatal("we expected some data here") + } +} diff --git a/internal/engine/ooapi/integration_test.go b/internal/engine/ooapi/integration_test.go new file mode 100644 index 0000000..43465be --- /dev/null +++ b/internal/engine/ooapi/integration_test.go @@ -0,0 +1,204 @@ +package ooapi + +import ( + "context" + "net/http" + "testing" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +type VerboseHTTPClient struct { + t *testing.T +} + +func (c *VerboseHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.t.Logf("> %s %s", req.Method, req.URL.String()) + resp, err := http.DefaultClient.Do(req) + if err != nil { + c.t.Logf("< %s", err.Error()) + return nil, err + } + c.t.Logf("< %d", resp.StatusCode) + return resp, nil +} + +func TestWithRealServerDoCheckIn(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.CheckInRequest{ + Charging: true, + OnWiFi: true, + Platform: "android", + ProbeASN: "AS12353", + ProbeCC: "IT", + RunType: "timed", + SoftwareName: "ooniprobe-android", + SoftwareVersion: "2.7.1", + WebConnectivity: apimodel.CheckInRequestWebConnectivity{ + CategoryCodes: []string{"NEWS", "CULTR"}, + }, + } + httpClnt := &VerboseHTTPClient{t: t} + api := &CheckInAPI{ + HTTPClient: httpClnt, + } + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + for idx, url := range resp.Tests.WebConnectivity.URLs { + if idx >= 3 { + break + } + t.Logf("- %+v", url) + } +} + +func TestWithRealServerDoCheckReportID(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.CheckReportIDRequest{ + ReportID: "20210223T093606Z_ndt_JO_8376_n1_kDYToqrugDY54Soy", + } + api := &CheckReportIDAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} + +func TestWithRealServerDoMeasurementMeta(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.MeasurementMetaRequest{ + ReportID: "20210223T093606Z_ndt_JO_8376_n1_kDYToqrugDY54Soy", + } + api := &MeasurementMetaAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} + +func TestWithRealServerDoOpenReport(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.OpenReportRequest{ + DataFormatVersion: "0.2.0", + Format: "json", + ProbeASN: "AS137", + ProbeCC: "IT", + SoftwareName: "miniooni", + SoftwareVersion: "0.1.0-dev", + TestName: "example", + TestStartTime: "2018-11-01 15:33:20", + TestVersion: "0.1.0", + } + api := &OpenReportAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} + +func TestWithRealServerDoPsiphonConfig(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.PsiphonConfigRequest{} + httpClnt := &VerboseHTTPClient{t: t} + api := &PsiphonConfigAPIWithLogin{ + API: &PsiphonConfigAPI{ + HTTPClient: httpClnt, + }, + KVStore: &memkvstore{}, + RegisterAPI: &RegisterAPI{ + HTTPClient: httpClnt, + }, + LoginAPI: &LoginAPI{ + HTTPClient: httpClnt, + }, + } + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp != nil) +} + +func TestWithRealServerDoTorTargets(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.TorTargetsRequest{} + httpClnt := &VerboseHTTPClient{t: t} + api := &TorTargetsAPIWithLogin{ + API: &TorTargetsAPI{ + HTTPClient: httpClnt, + }, + KVStore: &memkvstore{}, + RegisterAPI: &RegisterAPI{ + HTTPClient: httpClnt, + }, + LoginAPI: &LoginAPI{ + HTTPClient: httpClnt, + }, + } + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp != nil) +} + +func TestWithRealServerDoURLs(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + req := &apimodel.URLsRequest{ + CountryCode: "IT", + Limit: 3, + } + api := &URLsAPI{} + ctx := context.Background() + resp, err := api.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil pointer here") + } + t.Logf("%+v", resp) +} diff --git a/internal/engine/ooapi/internal/generator/apis.go b/internal/engine/ooapi/internal/generator/apis.go new file mode 100644 index 0000000..e69b1f7 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/apis.go @@ -0,0 +1,180 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +// apiField contains the fields of an API data structure +type apiField struct { + // name is the field name + name string + + // kind is the filed type + kind string + + // comment is a brief comment to document the field + comment string + + // ifLogin indicates whether this field should only be + // emitted when the API requires login + ifLogin bool + + // ifTemplate indicates whether this field should only be + // emitted when the URL path is a template + ifTemplate bool + + // noClone is true when this field should not be copied + // from the parent data structure when cloning + noClone bool +} + +var apiFields = []apiField{{ + name: "BaseURL", + kind: "string", + comment: "optional", +}, { + name: "HTTPClient", + kind: "HTTPClient", + comment: "optional", +}, { + name: "JSONCodec", + kind: "JSONCodec", + comment: "optional", +}, { + name: "Token", + kind: "string", + comment: "mandatory", + ifLogin: true, + noClone: true, +}, { + name: "RequestMaker", + kind: "RequestMaker", + comment: "optional", +}, { + name: "TemplateExecutor", + kind: "TemplateExecutor", + comment: "optional", + ifTemplate: true, +}, { + name: "UserAgent", + kind: "string", + comment: "optional", +}} + +func (d *Descriptor) genNewAPI(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s implements the %s API.\n", d.APIStructName(), d.Name) + fmt.Fprintf(sb, "type %s struct {\n", d.APIStructName()) + for _, f := range apiFields { + if !d.RequiresLogin && f.ifLogin { + continue + } + if !d.URLPath.IsTemplate && f.ifTemplate { + continue + } + fmt.Fprintf(sb, "\t%s %s // %s\n", f.name, f.kind, f.comment) + } + fmt.Fprint(sb, "}\n\n") + + if d.RequiresLogin { + fmt.Fprintf(sb, "// WithToken returns a copy of the API where the\n") + fmt.Fprintf(sb, "// value of the Token field is replaced with token.\n") + fmt.Fprintf(sb, "func (api *%s) WithToken(token string) %s {\n", + d.APIStructName(), d.CallerInterfaceName()) + fmt.Fprintf(sb, "out := &%s{}\n", d.APIStructName()) + for _, f := range apiFields { + if !d.URLPath.IsTemplate && f.ifTemplate { + continue + } + if f.noClone == true { + continue + } + fmt.Fprintf(sb, "out.%s = api.%s\n", f.name, f.name) + } + fmt.Fprint(sb, "out.Token = token\n") + fmt.Fprint(sb, "return out\n") + fmt.Fprint(sb, "}\n\n") + } + + fmt.Fprintf(sb, "func (api *%s) baseURL() string {\n", d.APIStructName()) + fmt.Fprint(sb, "\tif api.BaseURL != \"\" {\n") + fmt.Fprint(sb, "\t\treturn api.BaseURL\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn \"https://ps1.ooni.io\"\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) requestMaker() RequestMaker {\n", d.APIStructName()) + fmt.Fprint(sb, "\tif api.RequestMaker != nil {\n") + fmt.Fprint(sb, "\t\treturn api.RequestMaker\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultRequestMaker{}\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) jsonCodec() JSONCodec {\n", d.APIStructName()) + fmt.Fprint(sb, "\tif api.JSONCodec != nil {\n") + fmt.Fprint(sb, "\t\treturn api.JSONCodec\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultJSONCodec{}\n") + fmt.Fprint(sb, "}\n\n") + + if d.URLPath.IsTemplate { + fmt.Fprintf( + sb, "func (api *%s) templateExecutor() TemplateExecutor {\n", + d.APIStructName()) + fmt.Fprint(sb, "\tif api.TemplateExecutor != nil {\n") + fmt.Fprint(sb, "\t\treturn api.TemplateExecutor\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultTemplateExecutor{}\n") + fmt.Fprint(sb, "}\n\n") + } + + fmt.Fprintf( + sb, "func (api *%s) httpClient() HTTPClient {\n", + d.APIStructName()) + fmt.Fprint(sb, "\tif api.HTTPClient != nil {\n") + fmt.Fprint(sb, "\t\treturn api.HTTPClient\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn http.DefaultClient\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "// Call calls the %s API.\n", d.Name) + fmt.Fprintf( + sb, "func (api *%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.APIStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\thttpReq, err := api.newRequest(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\thttpReq.Header.Add(\"Accept\", \"application/json\")\n") + if d.RequiresLogin { + fmt.Fprint(sb, "\tif api.Token == \"\" {\n") + fmt.Fprint(sb, "\t\treturn nil, ErrMissingToken\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\thttpReq.Header.Add(\"Authorization\", newAuthorizationHeader(api.Token))\n") + } + fmt.Fprint(sb, "\tif api.UserAgent != \"\" {\n") + fmt.Fprint(sb, "\t\thttpReq.Header.Add(\"User-Agent\", api.UserAgent)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.newResponse(api.httpClient().Do(httpReq))\n") + fmt.Fprint(sb, "}\n\n") +} + +// GenAPIsGo generates apis.go. +func GenAPIsGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genNewAPI(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/apistest.go b/internal/engine/ooapi/internal/generator/apistest.go new file mode 100644 index 0000000..31cacbe --- /dev/null +++ b/internal/engine/ooapi/internal/generator/apistest.go @@ -0,0 +1,461 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +func (d *Descriptor) genTestNewRequest(sb *strings.Builder) { + fmt.Fprintf(sb, "\treq := &%s{}\n", d.RequestTypeNameAsStruct()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\tff.fill(req)\n") +} + +func (d *Descriptor) genTestInvalidURL(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sInvalidURL(t *testing.T) {\n", d.Name) + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tBaseURL: \"\\t\", // invalid\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err == nil || !strings.HasSuffix(err.Error(), \"invalid control character in URL\") {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithMissingToken(sb *strings.Builder) { + if d.RequiresLogin == false { + return // does not make sense when login isn't required + } + fmt.Fprintf(sb, "func Test%sWithMissingToken(t *testing.T) {\n", d.Name) + fmt.Fprintf(sb, "\tapi := &%s{} // no token\n", d.APIStructName()) + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrMissingToken) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithHTTPErr(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithHTTPErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Err: errMocked}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestMarshalErr(sb *strings.Builder) { + if d.Method != "POST" { + return // does not make sense when we don't send a request body + } + fmt.Fprintf(sb, "func Test%sMarshalErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tJSONCodec: &FakeCodec{EncodeErr: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithNewRequestErr(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithNewRequestErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tRequestMaker: &FakeRequestMaker{Err: errMocked},\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWith401(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWith401(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 401}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrUnauthorized) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWith400(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWith400(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{StatusCode: 400}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrHTTPFailure) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithResponseBodyReadErr(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithResponseBodyReadErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 200,\n") + fmt.Fprint(sb, "\t\tBody: &FakeBody{Err: errMocked},\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithUnmarshalFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithUnmarshalFailure(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 200,\n") + fmt.Fprint(sb, "\t\tBody: &FakeBody{Data: []byte(`{}`)},\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + fmt.Fprintf(sb, "\t\tJSONCodec: &FakeCodec{DecodeErr: errMocked},\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestRoundTrip(sb *strings.Builder) { + // generate the type of the handler + fmt.Fprintf(sb, "type handle%s struct {\n", d.Name) + fmt.Fprint(sb, "\taccept string\n") + fmt.Fprint(sb, "\tbody []byte\n") + fmt.Fprint(sb, "\tcontentType string\n") + fmt.Fprint(sb, "\tcount int32\n") + fmt.Fprint(sb, "\tmethod string\n") + fmt.Fprint(sb, "\tmu sync.Mutex\n") + fmt.Fprintf(sb, "\tresp %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\turl *url.URL\n") + fmt.Fprint(sb, "\tuserAgent string\n") + fmt.Fprint(sb, "}\n\n") + + // generate the handling function + fmt.Fprintf(sb, + "func (h *handle%s) ServeHTTP(w http.ResponseWriter, r *http.Request) {", + d.Name) + fmt.Fprint(sb, "\tdefer h.mu.Unlock()\n") + fmt.Fprint(sb, "\th.mu.Lock()\n") + fmt.Fprint(sb, "\tif h.count > 0 {\n") + fmt.Fprint(sb, "\t\tw.WriteHeader(400)\n") + fmt.Fprint(sb, "\t\treturn\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\th.count++\n") + fmt.Fprint(sb, "\tif r.Body != nil {\n") + fmt.Fprint(sb, "\t\tdata, err := ioutil.ReadAll(r.Body)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprintf(sb, "\t\t\tw.WriteHeader(400)\n") + fmt.Fprintf(sb, "\t\t\treturn\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\th.body = data\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\th.method = r.Method\n") + fmt.Fprint(sb, "\th.url = r.URL\n") + fmt.Fprint(sb, "\th.accept = r.Header.Get(\"Accept\")\n") + fmt.Fprint(sb, "\th.contentType = r.Header.Get(\"Content-Type\")\n") + fmt.Fprint(sb, "\th.userAgent = r.Header.Get(\"User-Agent\")\n") + fmt.Fprintf(sb, "\tvar out %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff := fakeFill{}\n") + fmt.Fprint(sb, "\tff.fill(&out)\n") + fmt.Fprintf(sb, "\th.resp = out\n") + fmt.Fprintf(sb, "\tdata, err := json.Marshal(out)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tw.WriteHeader(400)\n") + fmt.Fprintf(sb, "\t\treturn\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tw.Write(data)\n") + fmt.Fprintf(sb, "\t}\n\n") + + // generate the test itself + fmt.Fprintf(sb, "func Test%sRoundTrip(t *testing.T) {\n", d.Name) + + fmt.Fprint(sb, "\t// setup\n") + fmt.Fprintf(sb, "\thandler := &handle%s{}\n", d.Name) + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + fmt.Fprintf(sb, "\treq := &%s{}\n", d.RequestTypeNameAsStruct()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprintf(sb, "\tapi := &%s{BaseURL: srvr.URL}\n", d.APIStructName()) + fmt.Fprint(sb, "\tff.fill(&api.UserAgent)\n") + if d.RequiresLogin { + fmt.Fprint(sb, "\tff.fill(&api.Token)\n") + } + + fmt.Fprint(sb, "\t// issue request\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response here\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// compare our response and server's one\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(handler.resp, resp); diff != \"\" {") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// check whether headers are OK\n") + fmt.Fprint(sb, "\tif handler.accept != \"application/json\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid accept header\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif handler.userAgent != api.UserAgent {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid user-agent header\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// check whether the method is OK\n") + fmt.Fprintf(sb, "\tif handler.method != \"%s\" {\n", d.Method) + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid method\")\n") + fmt.Fprint(sb, "\t}\n") + + if d.Method == "POST" { + fmt.Fprint(sb, "\t// check the body\n") + fmt.Fprint(sb, "\tif handler.contentType != \"application/json\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid content-type header\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tgot := &%s{}\n", d.RequestTypeNameAsStruct()) + fmt.Fprintf(sb, "\tif err := json.Unmarshal(handler.body, &got); err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(req, got); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + } else { + fmt.Fprint(sb, "\t// check the query\n") + fmt.Fprint(sb, "\thttpReq, err := api.newRequest(context.Background(), req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(handler.url.Path, httpReq.URL.Path); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(handler.url.RawQuery, httpReq.URL.RawQuery); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + } + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestResponseLiteralNull(sb *strings.Builder) { + switch d.ResponseTypeKind() { + case reflect.Map: + // fallthrough + case reflect.Struct: + return // test not applicable + } + fmt.Fprintf(sb, "func Test%sResponseLiteralNull(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 200,\n") + fmt.Fprint(sb, "\t\tBody: &FakeBody{Data: []byte(`null`)},\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrJSONLiteralNull) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestMandatoryFields(sb *strings.Builder) { + fields := d.StructFieldsWithTag(d.Request, tagForRequired) + if len(fields) < 1 { + return // nothing to test + } + fmt.Fprintf(sb, "func Test%sMandatoryFields(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 500,\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprintf(sb, "\treq := &%s{} // deliberately empty\n", d.RequestTypeNameAsStruct()) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrEmptyField) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestTemplateErr(sb *strings.Builder) { + if !d.URLPath.IsTemplate { + return // nothing to test + } + fmt.Fprintf(sb, "func Test%sTemplateErr(t *testing.T) {\n", d.Name) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tclnt := &FakeHTTPClient{Resp: &http.Response{\n") + fmt.Fprint(sb, "\t\tStatusCode: 500,\n") + fmt.Fprint(sb, "\t}}\n") + fmt.Fprintf(sb, "\tapi := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: clnt,\n") + if d.RequiresLogin == true { + fmt.Fprint(sb, "\t\tToken: \"fakeToken\",\n") + } + fmt.Fprint(sb, "\t\tTemplateExecutor: &FakeTemplateExecutor{Err: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + d.genTestNewRequest(sb) + fmt.Fprint(sb, "\tresp, err := api.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil resp\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +// TODO(bassosimone): we should add a panic for every switch for +// the type of a request or a response for robustness. + +func (d *Descriptor) genAPITests(sb *strings.Builder) { + d.genTestInvalidURL(sb) + d.genTestWithMissingToken(sb) + d.genTestWithHTTPErr(sb) + d.genTestMarshalErr(sb) + d.genTestWithNewRequestErr(sb) + d.genTestWith401(sb) + d.genTestWith400(sb) + d.genTestWithResponseBodyReadErr(sb) + d.genTestWithUnmarshalFailure(sb) + d.genTestRoundTrip(sb) + d.genTestResponseLiteralNull(sb) + d.genTestMandatoryFields(sb) + d.genTestTemplateErr(sb) +} + +// GenAPIsTestGo generates apis_test.go. +func GenAPIsTestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"encoding/json\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\t\"io/ioutil\"\n") + fmt.Fprint(&sb, "\t\"net/http/httptest\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\t\"net/url\"\n") + fmt.Fprint(&sb, "\t\"strings\"\n") + fmt.Fprint(&sb, "\t\"testing\"\n") + fmt.Fprint(&sb, "\t\"sync\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/google/go-cmp/cmp\"\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genAPITests(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/caching.go b/internal/engine/ooapi/internal/generator/caching.go new file mode 100644 index 0000000..ca1dd6a --- /dev/null +++ b/internal/engine/ooapi/internal/generator/caching.go @@ -0,0 +1,130 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewCache(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s implements caching for %s.\n", + d.CacheStructName(), d.APIStructName()) + fmt.Fprintf(sb, "type %s struct {\n", d.CacheStructName()) + fmt.Fprintf(sb, "\tAPI %s // mandatory\n", d.CallerInterfaceName()) + fmt.Fprint(sb, "\tGobCodec GobCodec // optional\n") + fmt.Fprint(sb, "\tKVStore KVStore // mandatory\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "type %s struct {\n", d.CacheEntryName()) + fmt.Fprintf(sb, "\tReq %s\n", d.RequestTypeName()) + fmt.Fprintf(sb, "\tResp %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "// Call calls the API and implements caching.\n") + fmt.Fprintf(sb, "func (c *%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.CacheStructName(), d.RequestTypeName(), d.ResponseTypeName()) + if d.CachePolicy == CacheAlways { + fmt.Fprint(sb, "\tif resp, _ := c.readcache(req); resp != nil {\n") + fmt.Fprint(sb, "\t\treturn resp, nil\n") + fmt.Fprint(sb, "\t}\n") + } + fmt.Fprint(sb, "\tresp, err := c.API.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + if d.CachePolicy == CacheFallback { + fmt.Fprint(sb, "\t\tif resp, _ := c.readcache(req); resp != nil {\n") + fmt.Fprint(sb, "\t\t\treturn resp, nil\n") + fmt.Fprint(sb, "\t\t}\n") + } + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif err := c.writecache(req, resp); err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn resp, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) gobCodec() GobCodec {\n", d.CacheStructName()) + fmt.Fprint(sb, "\tif c.GobCodec != nil {\n") + fmt.Fprint(sb, "\t\treturn c.GobCodec\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultGobCodec{}\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) getcache() ([]%s, error) {\n", + d.CacheStructName(), d.CacheEntryName()) + fmt.Fprintf(sb, "\tdata, err := c.KVStore.Get(\"%s\")\n", d.CacheKey()) + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar out []%s\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tif err := c.gobCodec().Decode(data, &out); err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn out, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) setcache(in []%s) error {\n", + d.CacheStructName(), d.CacheEntryName()) + fmt.Fprint(sb, "\tdata, err := c.gobCodec().Encode(in)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\treturn c.KVStore.Set(\"%s\", data)\n", d.CacheKey()) + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) readcache(req %s) (%s, error) {\n", + d.CacheStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\tcache, err := c.getcache()\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tfor _, cur := range cache {\n") + fmt.Fprint(sb, "\t\tif reflect.DeepEqual(req, cur.Req) {\n") + fmt.Fprint(sb, "\t\t\treturn cur.Resp, nil\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn nil, errCacheNotFound\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (c *%s) writecache(req %s, resp %s) error {\n", + d.CacheStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\tcache, _ := c.getcache()\n") + fmt.Fprintf(sb, "\tout := []%s{{Req: req, Resp: resp}}\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tconst toomany = 64\n") + fmt.Fprint(sb, "\tfor idx, cur := range cache {\n") + fmt.Fprint(sb, "\t\tif reflect.DeepEqual(req, cur.Req) {\n") + fmt.Fprint(sb, "\t\t\tcontinue // we already updated the cache\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif idx > toomany {\n") + fmt.Fprint(sb, "\t\t\tbreak\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tout = append(out, cur)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn c.setcache(out)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "var _ %s = &%s{}\n\n", d.CallerInterfaceName(), + d.CacheStructName()) +} + +// GenCachingGo generates caching.go. +func GenCachingGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"reflect\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if desc.CachePolicy == CacheNone { + continue + } + desc.genNewCache(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/cachingtest.go b/internal/engine/ooapi/internal/generator/cachingtest.go new file mode 100644 index 0000000..a556885 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/cachingtest.go @@ -0,0 +1,274 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genTestCacheSuccess(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sSuccess(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWriteCacheError(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sWriteCacheError(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tKVStore: &FakeKVStore{SetError: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestFailureWithNoCache(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sFailureWithNoCache(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestFailureWithPreviousCache(sb *strings.Builder) { + // This works for both caching policies. + fmt.Fprintf(sb, "func TestCache%sFailureWithPreviousCache(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + fmt.Fprintf(sb, "\tfakeapi := &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tAPI: fakeapi,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\t// first pass with no error at all\n") + fmt.Fprint(sb, "\t// use a separate scope to be sure we avoid mistakes\n") + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t// second pass with failure\n") + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tfakeapi.Err = errMocked\n") + fmt.Fprint(sb, "\tfakeapi.Response = nil\n") + fmt.Fprint(sb, "\tresp2, err := cache.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp2 == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp2); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestSetcacheWithEncodeError(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sSetcacheWithEncodeError(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tvar in []%s\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tff.fill(&in)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tGobCodec: &FakeCodec{EncodeErr: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\terr := cache.setcache(in)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestReadCacheNotFound(sb *strings.Builder) { + if fields := d.StructFields(d.Request); len(fields) <= 0 { + // this test cannot work when there are no fields in the + // request because we will always find a match. + // TODO(bassosimone): how to avoid having uncovered code? + return + } + fmt.Fprintf(sb, "func TestCache%sReadCacheNotFound(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar incache []%s\n", d.CacheEntryName()) + fmt.Fprint(sb, "\tff.fill(&incache)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\terr := cache.setcache(incache)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprintf(sb, "\tout, err := cache.readcache(req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errCacheNotFound) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif out != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil here\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWriteCacheDuplicate(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestCache%sWriteCacheDuplicate(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprintf(sb, "\tvar resp1 %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&resp1)\n") + fmt.Fprintf(sb, "\tvar resp2 %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&resp2)\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\terr := cache.writecache(req, resp1)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\terr = cache.writecache(req, resp2)\n") + fmt.Fprintf(sb, "\tif err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tout, err := cache.readcache(req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif out == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil here\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(resp2, out); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestCachSizeLimited(sb *strings.Builder) { + if fields := d.StructFields(d.Request); len(fields) <= 0 { + // this test cannot work when there are no fields in the + // request because we will always find a match. + // TODO(bassosimone): how to avoid having uncovered code? + return + } + fmt.Fprintf(sb, "func TestCache%sCacheSizeLimited(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tcache := &%s{\n", d.CacheStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tvar prev int\n") + fmt.Fprintf(sb, "\tfor {\n") + fmt.Fprintf(sb, "\t\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\t\tff.fill(&req)\n") + fmt.Fprintf(sb, "\t\tvar resp %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\t\tff.fill(&resp)\n") + fmt.Fprintf(sb, "\t\terr := cache.writecache(req, resp)\n") + fmt.Fprintf(sb, "\t\tif err != nil {\n") + fmt.Fprintf(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t\t}\n") + fmt.Fprintf(sb, "\t\tout, err := cache.getcache()\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif len(out) > prev {\n") + fmt.Fprint(sb, "\t\t\tprev = len(out)\n") + fmt.Fprint(sb, "\t\t\tcontinue\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tbreak\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "}\n\n") +} + +// GenCachingTestGo generates caching_test.go. +func GenCachingTestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\t\"testing\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/google/go-cmp/cmp\"\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if desc.CachePolicy == CacheNone { + continue + } + desc.genTestCacheSuccess(&sb) + desc.genTestWriteCacheError(&sb) + desc.genTestFailureWithNoCache(&sb) + desc.genTestFailureWithPreviousCache(&sb) + desc.genTestSetcacheWithEncodeError(&sb) + desc.genTestReadCacheNotFound(&sb) + desc.genTestWriteCacheDuplicate(&sb) + desc.genTestCachSizeLimited(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/callers.go b/internal/engine/ooapi/internal/generator/callers.go new file mode 100644 index 0000000..4ba8ea8 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/callers.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewCaller(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s represents any type exposing a method\n", + d.CallerInterfaceName()) + fmt.Fprintf(sb, "// like %s.Call.\n", d.APIStructName()) + fmt.Fprintf(sb, "type %s interface {\n", d.CallerInterfaceName()) + fmt.Fprintf(sb, "\tCall(ctx context.Context, req %s) (%s, error)\n", + d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "}\n\n") +} + +// GenCallersGo generates callers.go. +func GenCallersGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genNewCaller(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/cloners.go b/internal/engine/ooapi/internal/generator/cloners.go new file mode 100644 index 0000000..1db6d92 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/cloners.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewCloner(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s represents any type exposing a method\n", + d.CallerInterfaceName()) + fmt.Fprintf(sb, "// like %s.WithToken.\n", d.APIStructName()) + fmt.Fprintf(sb, "type %s interface {\n", d.ClonerInterfaceName()) + fmt.Fprintf(sb, "\tWithToken(token string) %s\n", d.CallerInterfaceName()) + fmt.Fprint(sb, "}\n\n") +} + +// GenClonersGo generates cloners.go. +func GenClonersGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + for _, desc := range Descriptors { + if !desc.RequiresLogin { + continue + } + desc.genNewCloner(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/fakeapitest.go b/internal/engine/ooapi/internal/generator/fakeapitest.go new file mode 100644 index 0000000..deb2b56 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/fakeapitest.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewFakeAPI(sb *strings.Builder) { + fmt.Fprintf(sb, "type Fake%s struct {\n", d.APIStructName()) + if d.RequiresLogin { + fmt.Fprintf(sb, "\tWithResult %s\n", d.CallerInterfaceName()) + } + fmt.Fprint(sb, "\tErr error\n") + fmt.Fprintf(sb, "\tResponse %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tCountCall int32\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (fapi *Fake%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.APIStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\tatomic.AddInt32(&fapi.CountCall, 1)\n") + fmt.Fprint(sb, "\treturn fapi.Response, fapi.Err\n") + fmt.Fprint(sb, "}\n\n") + + if d.RequiresLogin { + fmt.Fprintf(sb, "func (fapi *Fake%s) WithToken(token string) %s {\n", + d.APIStructName(), d.CallerInterfaceName()) + fmt.Fprint(sb, "\treturn fapi.WithResult\n") + fmt.Fprint(sb, "}\n\n") + } + + fmt.Fprint(sb, "var (\n") + fmt.Fprintf(sb, "\t_ %s = &Fake%s{}\n", d.CallerInterfaceName(), + d.APIStructName()) + if d.RequiresLogin { + fmt.Fprintf(sb, "\t_ %s = &Fake%s{}\n", d.ClonerInterfaceName(), + d.APIStructName()) + } + fmt.Fprint(sb, ")\n\n") +} + +// GenFakeAPITestGo generates fakeapi_test.go. +func GenFakeAPITestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"sync/atomic\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + desc.genNewFakeAPI(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/generator.go b/internal/engine/ooapi/internal/generator/generator.go new file mode 100644 index 0000000..6158767 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/generator.go @@ -0,0 +1,53 @@ +// Command generator generates code in the ooapi package. +// +// To this end, it uses the content of the apimodel package as +// well as the content of the spec.go file. +// +// The apimodel package defines the model, i.e., the structure +// of requests and responses and how messages should be sent +// and received. +// +// The spec.go file describes all the implemented APIs. +// +// If you change apimodel or spec.go, remember to run the +// `go generate ./...` command to regenerate all files. +package main + +import ( + "flag" + "fmt" +) + +var flagFile = flag.String("file", "", "Indicate which file to regenerate") + +func main() { + flag.Parse() + switch file := *flagFile; file { + case "apis.go": + GenAPIsGo(file) + case "responses.go": + GenResponsesGo(file) + case "requests.go": + GenRequestsGo(file) + case "swagger_test.go": + GenSwaggerTestGo(file) + case "apis_test.go": + GenAPIsTestGo(file) + case "callers.go": + GenCallersGo(file) + case "caching.go": + GenCachingGo(file) + case "login.go": + GenLoginGo(file) + case "cloners.go": + GenClonersGo(file) + case "fakeapi_test.go": + GenFakeAPITestGo(file) + case "caching_test.go": + GenCachingTestGo(file) + case "login_test.go": + GenLoginTestGo(file) + default: + panic(fmt.Sprintf("don't know how to create this file: %s", file)) + } +} diff --git a/internal/engine/ooapi/internal/generator/login.go b/internal/engine/ooapi/internal/generator/login.go new file mode 100644 index 0000000..4328d2e --- /dev/null +++ b/internal/engine/ooapi/internal/generator/login.go @@ -0,0 +1,182 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genNewLogin(sb *strings.Builder) { + fmt.Fprintf(sb, "// %s implements login for %s.\n", + d.WithLoginAPIStructName(), d.APIStructName()) + fmt.Fprintf(sb, "type %s struct {\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI %s // mandatory\n", d.ClonerInterfaceName()) + fmt.Fprint(sb, "\tJSONCodec JSONCodec // optional\n") + fmt.Fprint(sb, "\tKVStore KVStore // mandatory\n") + fmt.Fprint(sb, "\tRegisterAPI RegisterCaller // mandatory\n") + fmt.Fprint(sb, "\tLoginAPI LoginCaller // mandatory\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "// Call logins, if needed, then calls the API.\n") + fmt.Fprintf(sb, "func (api *%s) Call(ctx context.Context, req %s) (%s, error) {\n", + d.WithLoginAPIStructName(), d.RequestTypeName(), d.ResponseTypeName()) + fmt.Fprint(sb, "\ttoken, err := api.maybeLogin(ctx)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tresp, err := api.API.WithToken(token).Call(ctx, req)\n") + fmt.Fprint(sb, "\tif errors.Is(err, ErrUnauthorized) {\n") + fmt.Fprint(sb, "\t\t// Maybe the clock is just off? Let's try to obtain\n") + fmt.Fprint(sb, "\t\t// a token again and see if this fixes it.\n") + fmt.Fprint(sb, "\t\tif token, err = api.forceLogin(ctx); err == nil {\n") + fmt.Fprint(sb, "\t\t\tswitch resp, err = api.API.WithToken(token).Call(ctx, req); err {\n") + fmt.Fprint(sb, "\t\t\tcase nil:\n") + fmt.Fprint(sb, "\t\t\t\treturn resp, nil\n") + fmt.Fprint(sb, "\t\t\tcase ErrUnauthorized:\n") + fmt.Fprint(sb, "\t\t\t\t// fallthrough\n") + fmt.Fprint(sb, "\t\t\tdefault:\n") + fmt.Fprint(sb, "\t\t\t\treturn nil, err\n") + fmt.Fprint(sb, "\t\t\t}\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\t// Okay, this seems a broader problem. How about we try\n") + fmt.Fprint(sb, "\t\t// and re-register ourselves again instead?\n") + fmt.Fprint(sb, "\t\ttoken, err = api.forceRegister(ctx)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\treturn nil, err\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tresp, err = api.API.WithToken(token).Call(ctx, req)\n") + fmt.Fprint(sb, "\t\t// fallthrough\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn resp, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) jsonCodec() JSONCodec {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tif api.JSONCodec != nil {\n") + fmt.Fprint(sb, "\t\treturn api.JSONCodec\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &defaultJSONCodec{}\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) readstate() (*loginState, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tdata, err := api.KVStore.Get(loginKey)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tvar ls loginState\n") + fmt.Fprint(sb, "\tif err := api.jsonCodec().Decode(data, &ls); err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn &ls, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) writestate(ls *loginState) error {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tdata, err := api.jsonCodec().Encode(*ls)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.KVStore.Set(loginKey, data)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) doRegister(ctx context.Context, password string) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\treq := newRegisterRequest(password)\n") + fmt.Fprint(sb, "\tls := &loginState{}\n") + fmt.Fprint(sb, "\tresp, err := api.RegisterAPI.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tls.ClientID = resp.ClientID\n") + fmt.Fprint(sb, "\tls.Password = req.Password\n") + fmt.Fprint(sb, "\treturn api.doLogin(ctx, ls)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) forceRegister(ctx context.Context) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tvar password string\n") + fmt.Fprint(sb, "\t// If we already have a previous password, let us keep\n") + fmt.Fprint(sb, "\t// using it. This will allow a new version of the API to\n") + fmt.Fprint(sb, "\t// be able to continue to identify this probe. (This\n") + fmt.Fprint(sb, "\t// assumes that we have a stateless API that generates\n") + fmt.Fprint(sb, "\t// the user ID as a signature of the password plus a\n") + fmt.Fprint(sb, "\t// timestamp and that the key to generate the signature\n") + fmt.Fprint(sb, "\t// is not lost. If all these conditions are met, we\n") + fmt.Fprint(sb, "\t// can then serve better test targets to more long running\n") + fmt.Fprint(sb, "\t// (and therefore trusted) probes.)\n") + fmt.Fprint(sb, "\tif ls, err := api.readstate(); err == nil {\n") + fmt.Fprint(sb, "\t\tpassword = ls.Password\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif password == \"\" {\n") + fmt.Fprint(sb, "\t\tpassword = newRandomPassword()\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.doRegister(ctx, password)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) forceLogin(ctx context.Context) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tls, err := api.readstate()\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn api.doLogin(ctx, ls)\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) maybeLogin(ctx context.Context) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\tls, _ := api.readstate()\n") + fmt.Fprint(sb, "\tif ls == nil || !ls.credentialsValid() {\n") + fmt.Fprint(sb, "\t\treturn api.forceRegister(ctx)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif !ls.tokenValid() {\n") + fmt.Fprint(sb, "\t\treturn api.doLogin(ctx, ls)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn ls.Token, nil\n") + fmt.Fprint(sb, "}\n\n") + + fmt.Fprintf(sb, "func (api *%s) doLogin(ctx context.Context, ls *loginState) (string, error) {\n", + d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\treq := &apimodel.LoginRequest{\n") + fmt.Fprint(sb, "\t\tClientID: ls.ClientID,\n") + fmt.Fprint(sb, "\t\tPassword: ls.Password,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tresp, err := api.LoginAPI.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tls.Token = resp.Token\n") + fmt.Fprint(sb, "\tls.Expire = resp.Expire\n") + fmt.Fprint(sb, "\tif err := api.writestate(ls); err != nil {\n") + fmt.Fprint(sb, "\t\treturn \"\", err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\treturn ls.Token, nil\n") + fmt.Fprint(sb, "}\n\n") + fmt.Fprintf(sb, "var _ %s = &%s{}\n\n", d.CallerInterfaceName(), + d.WithLoginAPIStructName()) +} + +// GenLoginGo generates login.go. +func GenLoginGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if !desc.RequiresLogin { + continue + } + desc.genNewLogin(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/logintest.go b/internal/engine/ooapi/internal/generator/logintest.go new file mode 100644 index 0000000..d74c9f5 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/logintest.go @@ -0,0 +1,864 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func (d *Descriptor) genTestRegisterAndLoginSuccess(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestRegisterAndLogin%sSuccess(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestContinueUsingToken(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sContinueUsingToken(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we disable register and login but we\n") + fmt.Fprint(sb, "\t// should be okay because of the token\n") + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tregisterAPI.Err = errMocked\n") + fmt.Fprint(sb, "\tregisterAPI.Response = nil\n") + fmt.Fprint(sb, "\tloginAPI.Err = errMocked\n") + fmt.Fprint(sb, "\tloginAPI.Response = nil\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithValidButExpiredToken(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithValidButExpiredToken(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tls := &loginState{\n") + fmt.Fprintf(sb, "\t\tClientID: \"antani-antani\",\n") + fmt.Fprintf(sb, "\t\tExpire: time.Now().Add(-5 * time.Second),\n") + fmt.Fprintf(sb, "\t\tToken: \"antani-antani-token\",\n") + fmt.Fprintf(sb, "\t\tPassword: \"antani-antani-password\",\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tif err := login.writestate(ls); err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif diff := cmp.Diff(expect, resp); diff != \"\" {\n") + fmt.Fprint(sb, "\t\tt.Fatal(diff)\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 0 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithRegisterAPIError(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithRegisterAPIError(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestWithLoginFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sWithLoginFailure(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestRegisterAndLoginThenFail(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestRegisterAndLogin%sThenFail(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tErr: errMocked,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestTheDatabaseIsReplaced(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sTheDatabaseIsReplaced(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget accounts and try again.\n") + fmt.Fprint(sb, "\thandler.forgetLogins()\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 3 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestTheDatabaseIsReplacedThenFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sTheDatabaseIsReplacedThenFailure(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget accounts and try again.\n") + fmt.Fprint(sb, "\t// but registrations are also failing.\n") + fmt.Fprint(sb, "\thandler.forgetLogins()\n") + fmt.Fprint(sb, "\thandler.noRegister = true\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrHTTPFailure) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestRegisterAndLoginCannotWriteState(sb *strings.Builder) { + fmt.Fprintf(sb, "func TestRegisterAndLogin%sCannotWriteState(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\tregisterAPI := &FakeRegisterAPI{\n") + fmt.Fprint(sb, "\t\tResponse: &apimodel.RegisterResponse{\n") + fmt.Fprint(sb, "\t\t\tClientID: \"antani-antani\",\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t\tloginAPI := &FakeLoginAPI{\n") + fmt.Fprint(sb, "\t\t\tResponse: &apimodel.LoginResponse{\n") + fmt.Fprint(sb, "\t\t\t\tExpire: time.Now().Add(3600*time.Second),\n") + fmt.Fprint(sb, "\t\t\t\tToken: \"antani-antani-token\",\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\t\tAPI: &Fake%s{\n", d.APIStructName()) + fmt.Fprintf(sb, "\t\t\tWithResult: &Fake%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\t\t\tResponse: expect,\n") + fmt.Fprint(sb, "\t\t\t},\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\t\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t\tJSONCodec: &FakeCodec{\n") + fmt.Fprint(sb, "\t\t\tEncodeErr: errMocked,\n") + fmt.Fprint(sb, "\t\t},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, errMocked) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif loginAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid loginAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif registerAPI.CountCall != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid registerAPI.CountCall\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestReadStateDecodeFailure(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sReadStateDecodeFailure(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprintf(sb, "\tvar expect %s\n", d.ResponseTypeName()) + fmt.Fprint(sb, "\tff.fill(&expect)\n") + + fmt.Fprint(sb, "\terrMocked := errors.New(\"mocked error\")\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprint(sb, "\t\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t\tJSONCodec: &FakeCodec{DecodeErr: errMocked},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tls := &loginState{\n") + fmt.Fprintf(sb, "\t\tClientID: \"antani-antani\",\n") + fmt.Fprintf(sb, "\t\tExpire: time.Now().Add(-5 * time.Second),\n") + fmt.Fprintf(sb, "\t\tToken: \"antani-antani-token\",\n") + fmt.Fprintf(sb, "\t\tPassword: \"antani-antani-password\",\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "\tif err := login.writestate(ls); err != nil {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(err)\n") + fmt.Fprintf(sb, "\t}\n") + + fmt.Fprintf(sb, "\tout, err := login.forceLogin(context.Background())\n") + fmt.Fprintf(sb, "if !errors.Is(err, errMocked) {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprintf(sb, "\t}\n") + fmt.Fprintf(sb, "if out != \"\" {\n") + fmt.Fprintf(sb, "\t\tt.Fatal(\"expected empty string here\")\n") + fmt.Fprintf(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestClockIsOffThenSuccess(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sClockIsOffThenSuccess(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget tokens and try again.\n") + fmt.Fprint(sb, "\t// this should simulate the client clock\n") + fmt.Fprint(sb, "\t// being off and considering a token still valid\n") + fmt.Fprint(sb, "\thandler.forgetTokens()\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestClockIsOffThen401(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sClockIsOffThen401(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget tokens and try again.\n") + fmt.Fprint(sb, "\t// this should simulate the client clock\n") + fmt.Fprint(sb, "\t// being off and considering a token still valid\n") + fmt.Fprint(sb, "\thandler.forgetTokens()\n") + fmt.Fprint(sb, "\thandler.failCallWith = []int{401, 401}\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp == nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 3 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +func (d *Descriptor) genTestClockIsOffThen500(sb *strings.Builder) { + fmt.Fprintf(sb, "func Test%sClockIsOffThen500(t *testing.T) {\n", d.APIStructName()) + fmt.Fprint(sb, "\tff := &fakeFill{}\n") + fmt.Fprint(sb, "\thandler := &LoginHandler{t: t}\n") + fmt.Fprint(sb, "\tsrvr := httptest.NewServer(handler)\n") + fmt.Fprint(sb, "\tdefer srvr.Close()\n") + + fmt.Fprint(sb, "\tregisterAPI := &RegisterAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\t\tloginAPI := &LoginAPI{\n") + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprintf(sb, "\tbaseAPI := &%s{\n", d.APIStructName()) + fmt.Fprint(sb, "\t\tHTTPClient: &VerboseHTTPClient{t: t},\n") + fmt.Fprint(sb, "\t\tBaseURL: srvr.URL,\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tlogin := &%s{\n", d.WithLoginAPIStructName()) + fmt.Fprintf(sb, "\tAPI : baseAPI,\n") + fmt.Fprint(sb, "\tRegisterAPI: registerAPI,\n") + fmt.Fprint(sb, "\tLoginAPI: loginAPI,\n") + fmt.Fprint(sb, "\tKVStore: &memkvstore{},\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprintf(sb, "\tvar req %s\n", d.RequestTypeName()) + fmt.Fprint(sb, "\tff.fill(&req)\n") + fmt.Fprint(sb, "\tctx := context.Background()\n") + + fmt.Fprint(sb, "\t// step 1: we register and login and use the token\n") + fmt.Fprint(sb, "\t// inside a scope just to avoid mistakes\n") + + fmt.Fprint(sb, "\t{\n") + fmt.Fprint(sb, "\t\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\t\tif err != nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(err)\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t\tif resp == nil {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"expected non-nil response\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.logins != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t\t}\n") + + fmt.Fprint(sb, "\t\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t\t}\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\t// step 2: we forget tokens and try again.\n") + fmt.Fprint(sb, "\t// this should simulate the client clock\n") + fmt.Fprint(sb, "\t// being off and considering a token still valid\n") + fmt.Fprint(sb, "\thandler.forgetTokens()\n") + fmt.Fprint(sb, "\thandler.failCallWith = []int{401, 500}\n") + + fmt.Fprint(sb, "\tresp, err := login.Call(ctx, req)\n") + fmt.Fprint(sb, "\tif !errors.Is(err, ErrHTTPFailure) {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"not the error we expected\", err)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp != nil {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"expected nil response\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.logins != 2 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.logins\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "\tif handler.registers != 1 {\n") + fmt.Fprint(sb, "\t\tt.Fatal(\"invalid handler.registers\")\n") + fmt.Fprint(sb, "\t}\n") + + fmt.Fprint(sb, "}\n\n") +} + +// GenLoginTestGo generates login_test.go. +func GenLoginTestGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"errors\"\n") + fmt.Fprint(&sb, "\t\"net/http/httptest\"\n") + fmt.Fprint(&sb, "\t\"testing\"\n") + fmt.Fprint(&sb, "\t\"time\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/google/go-cmp/cmp\"\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n") + for _, desc := range Descriptors { + if !desc.RequiresLogin { + continue + } + desc.genTestRegisterAndLoginSuccess(&sb) + desc.genTestContinueUsingToken(&sb) + desc.genTestWithValidButExpiredToken(&sb) + desc.genTestWithRegisterAPIError(&sb) + desc.genTestWithLoginFailure(&sb) + desc.genTestRegisterAndLoginThenFail(&sb) + desc.genTestTheDatabaseIsReplaced(&sb) + desc.genTestRegisterAndLoginCannotWriteState(&sb) + desc.genTestReadStateDecodeFailure(&sb) + desc.genTestTheDatabaseIsReplacedThenFailure(&sb) + desc.genTestClockIsOffThenSuccess(&sb) + desc.genTestClockIsOffThen401(&sb) + desc.genTestClockIsOffThen500(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/reflect.go b/internal/engine/ooapi/internal/generator/reflect.go new file mode 100644 index 0000000..9b2f27b --- /dev/null +++ b/internal/engine/ooapi/internal/generator/reflect.go @@ -0,0 +1,147 @@ +package main + +import ( + "fmt" + "reflect" +) + +// TypeName returns v's package-qualified type name. +func (d *Descriptor) TypeName(v interface{}) string { + return reflect.TypeOf(v).String() +} + +// RequestTypeName calls d.TypeName(d.Request). +func (d *Descriptor) RequestTypeName() string { + return d.TypeName(d.Request) +} + +// ResponseTypeName calls d.TypeName(d.Response). +func (d *Descriptor) ResponseTypeName() string { + return d.TypeName(d.Response) +} + +// APIStructName returns the correct struct type name +// for the API we're currently processing. +func (d *Descriptor) APIStructName() string { + return fmt.Sprintf("%sAPI", d.Name) +} + +// WithLoginAPIStructName returns the correct struct type name +// for the WithLoginAPI we're currently processing. +func (d *Descriptor) WithLoginAPIStructName() string { + return fmt.Sprintf("%sAPIWithLogin", d.Name) +} + +// CallerInterfaceName returns the correct caller interface name +// for the API we're currently processing. +func (d *Descriptor) CallerInterfaceName() string { + return fmt.Sprintf("%sCaller", d.Name) +} + +// ClonerInterfaceName returns the correct cloner interface name +// for the API we're currently processing. +func (d *Descriptor) ClonerInterfaceName() string { + return fmt.Sprintf("%sCloner", d.Name) +} + +// CacheStructName returns the correct struct type name for +// the cache for the API we're currently processing. +func (d *Descriptor) CacheStructName() string { + return fmt.Sprintf("%sCache", d.Name) +} + +// CacheEntryName returns the correct struct type name for the +// cache entry for the API we're currently processing. +func (d *Descriptor) CacheEntryName() string { + return fmt.Sprintf("cacheEntryFor%s", d.Name) +} + +// CacheKey returns the correct cache key for the API +// we're currently processing. +func (d *Descriptor) CacheKey() string { + return fmt.Sprintf("%s.cache", d.Name) +} + +// StructFields returns all the struct fields of in. This function +// assumes that in is a pointer to struct, and will otherwise panic. +func (d *Descriptor) StructFields(in interface{}) []*reflect.StructField { + t := reflect.TypeOf(in) + if t.Kind() != reflect.Ptr { + panic("not a pointer") + } + t = t.Elem() + if t.Kind() != reflect.Struct { + panic("not a struct") + } + var out []*reflect.StructField + for idx := 0; idx < t.NumField(); idx++ { + f := t.Field(idx) + out = append(out, &f) + } + return out +} + +// StructFieldsWithTag returns all the struct fields of +// in that have the specified tag. +func (d *Descriptor) StructFieldsWithTag(in interface{}, tag string) []*reflect.StructField { + var out []*reflect.StructField + for _, f := range d.StructFields(in) { + if f.Tag.Get(tag) != "" { + out = append(out, f) + } + } + return out +} + +// RequestOrResponseTypeKind returns the type kind of in, which should +// be a request or a response. This function assumes that in is either a +// pointer to struct or a map and will panic otherwise. +func (d *Descriptor) RequestOrResponseTypeKind(in interface{}) reflect.Kind { + t := reflect.TypeOf(in) + if t.Kind() == reflect.Ptr { + t = t.Elem() + if t.Kind() != reflect.Struct { + panic("not a struct") + } + return reflect.Struct + } + if t.Kind() != reflect.Map { + panic("not a map") + } + return reflect.Map +} + +// RequestTypeKind calls d.RequestOrResponseTypeKind(d.Request). +func (d *Descriptor) RequestTypeKind() reflect.Kind { + return d.RequestOrResponseTypeKind(d.Request) +} + +// ResponseTypeKind calls d.RequestOrResponseTypeKind(d.Response). +func (d *Descriptor) ResponseTypeKind() reflect.Kind { + return d.RequestOrResponseTypeKind(d.Response) +} + +// TypeNameAsStruct assumes that in is a pointer to struct and +// returns the type of the corresponding struct. The returned +// type is package qualified. +func (d *Descriptor) TypeNameAsStruct(in interface{}) string { + t := reflect.TypeOf(in) + if t.Kind() != reflect.Ptr { + panic("not a pointer") + } + t = t.Elem() + if t.Kind() != reflect.Struct { + panic("not a struct") + } + return t.String() +} + +// RequestTypeNameAsStruct calls d.TypeNameAsStruct(d.Request) +func (d *Descriptor) RequestTypeNameAsStruct() string { + return d.TypeNameAsStruct(d.Request) +} + +// ResponseTypeNameAsStruct calls d.TypeNameAsStruct(d.Response) +func (d *Descriptor) ResponseTypeNameAsStruct() string { + return d.TypeNameAsStruct(d.Response) +} diff --git a/internal/engine/ooapi/internal/generator/requests.go b/internal/engine/ooapi/internal/generator/requests.go new file mode 100644 index 0000000..1e02fa2 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/requests.go @@ -0,0 +1,141 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +const ( + tagForQuery = "query" + tagForRequired = "required" +) + +func (d *Descriptor) genNewRequestQueryElemString(sb *strings.Builder, f *reflect.StructField) { + name := f.Name + query := f.Tag.Get(tagForQuery) + if f.Tag.Get(tagForRequired) == "true" { + fmt.Fprintf(sb, "\tif req.%s == \"\" {\n", name) + fmt.Fprintf(sb, "\t\treturn nil, newErrEmptyField(\"%s\")\n", name) + fmt.Fprint(sb, "\t}\n") + fmt.Fprintf(sb, "\tq.Add(\"%s\", req.%s)\n", query, name) + return + } + fmt.Fprintf(sb, "\tif req.%s != \"\" {\n", name) + fmt.Fprintf(sb, "\t\tq.Add(\"%s\", req.%s)\n", query, name) + fmt.Fprint(sb, "\t}\n") +} + +func (d *Descriptor) genNewRequestQueryElemBool(sb *strings.Builder, f *reflect.StructField) { + // required does not make much sense for a boolean field + name := f.Name + query := f.Tag.Get(tagForQuery) + fmt.Fprintf(sb, "\tif req.%s {\n", name) + fmt.Fprintf(sb, "\t\tq.Add(\"%s\", \"true\")\n", query) + fmt.Fprint(sb, "\t}\n") +} + +func (d *Descriptor) genNewRequestQueryElemInt64(sb *strings.Builder, f *reflect.StructField) { + // required does not make much sense for an integer field + name := f.Name + query := f.Tag.Get(tagForQuery) + fmt.Fprintf(sb, "\tif req.%s != 0 {\n", name) + fmt.Fprintf(sb, "\t\tq.Add(\"%s\", newQueryFieldInt64(req.%s))\n", query, name) + fmt.Fprint(sb, "\t}\n") +} + +func (d *Descriptor) genNewRequestQuery(sb *strings.Builder) { + if d.Method != "GET" { + return // we only generate query for GET + } + fields := d.StructFieldsWithTag(d.Request, tagForQuery) + if len(fields) <= 0 { + return + } + fmt.Fprint(sb, "\tq := url.Values{}\n") + for idx, f := range fields { + switch f.Type.Kind() { + case reflect.String: + d.genNewRequestQueryElemString(sb, f) + case reflect.Bool: + d.genNewRequestQueryElemBool(sb, f) + case reflect.Int64: + d.genNewRequestQueryElemInt64(sb, f) + default: + panic(fmt.Sprintf("unexpected query type at index %d", idx)) + } + } + fmt.Fprint(sb, "\tURL.RawQuery = q.Encode()\n") +} + +func (d *Descriptor) genNewRequestCallNewRequest(sb *strings.Builder) { + if d.Method == "POST" { + fmt.Fprint(sb, "\tbody, err := api.jsonCodec().Encode(req)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tout, err := api.requestMaker().NewRequest(") + fmt.Fprintf(sb, "ctx, \"%s\", URL.String(), ", d.Method) + fmt.Fprint(sb, "bytes.NewReader(body))\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tout.Header.Set(\"Content-Type\", \"application/json\")\n") + fmt.Fprint(sb, "\treturn out, nil\n") + return + } + fmt.Fprint(sb, "\treturn api.requestMaker().NewRequest(") + fmt.Fprintf(sb, "ctx, \"%s\", URL.String(), ", d.Method) + fmt.Fprint(sb, "nil)\n") +} + +func (d *Descriptor) genNewRequest(sb *strings.Builder) { + + fmt.Fprintf( + sb, "func (api *%s) newRequest(ctx context.Context, req %s) %s {\n", + d.APIStructName(), d.RequestTypeName(), "(*http.Request, error)") + fmt.Fprint(sb, "\tURL, err := url.Parse(api.baseURL())\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + + switch d.URLPath.IsTemplate { + case false: + fmt.Fprintf(sb, "\tURL.Path = \"%s\"\n", d.URLPath.Value) + case true: + fmt.Fprintf( + sb, "\tup, err := api.templateExecutor().Execute(\"%s\", req)\n", + d.URLPath.Value) + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tURL.Path = up\n") + } + + d.genNewRequestQuery(sb) + d.genNewRequestCallNewRequest(sb) + + fmt.Fprintf(sb, "}\n\n") +} + +// GenRequestsGo generates requests.go. +func GenRequestsGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"bytes\"\n") + fmt.Fprint(&sb, "\t\"context\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\t\"net/url\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n\n") + for _, desc := range Descriptors { + desc.genNewRequest(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/responses.go b/internal/engine/ooapi/internal/generator/responses.go new file mode 100644 index 0000000..18ce9e2 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/responses.go @@ -0,0 +1,80 @@ +package main + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +func (d *Descriptor) genNewResponse(sb *strings.Builder) { + fmt.Fprintf(sb, + "func (api *%s) newResponse(resp *http.Response, err error) (%s, error) {\n", + d.APIStructName(), d.ResponseTypeName()) + + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp.StatusCode == 401 {\n") + fmt.Fprint(sb, "\t\treturn nil, ErrUnauthorized\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tif resp.StatusCode != 200 {\n") + fmt.Fprint(sb, "\t\treturn nil, newHTTPFailure(resp.StatusCode)\n") + fmt.Fprint(sb, "\t}\n") + fmt.Fprint(sb, "\tdefer resp.Body.Close()\n") + fmt.Fprint(sb, "\treader := io.LimitReader(resp.Body, 4<<20)\n") + fmt.Fprint(sb, "\tdata, err := ioutil.ReadAll(reader)\n") + fmt.Fprint(sb, "\tif err != nil {\n") + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + + switch d.ResponseTypeKind() { + case reflect.Map: + fmt.Fprintf(sb, "\tout := %s{}\n", d.ResponseTypeName()) + case reflect.Struct: + fmt.Fprintf(sb, "\tout := &%s{}\n", d.ResponseTypeNameAsStruct()) + } + + switch d.ResponseTypeKind() { + case reflect.Map: + fmt.Fprint(sb, "\tif err := api.jsonCodec().Decode(data, &out); err != nil {\n") + case reflect.Struct: + fmt.Fprint(sb, "\tif err := api.jsonCodec().Decode(data, out); err != nil {\n") + } + + fmt.Fprint(sb, "\t\treturn nil, err\n") + fmt.Fprint(sb, "\t}\n") + + switch d.ResponseTypeKind() { + case reflect.Map: + // For rationale, see https://play.golang.org/p/m9-MsTaQ5wt and + // https://play.golang.org/p/6h-v-PShMk9. + fmt.Fprint(sb, "\tif out == nil {\n") + fmt.Fprint(sb, "\t\treturn nil, ErrJSONLiteralNull\n") + fmt.Fprint(sb, "\t}\n") + case reflect.Struct: + // nothing + } + fmt.Fprintf(sb, "\treturn out, nil\n") + fmt.Fprintf(sb, "}\n\n") +} + +// GenResponsesGo generates responses.go. +func GenResponsesGo(file string) { + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprint(&sb, "import (\n") + fmt.Fprint(&sb, "\t\"io\"\n") + fmt.Fprint(&sb, "\t\"io/ioutil\"\n") + fmt.Fprint(&sb, "\t\"net/http\"\n") + fmt.Fprint(&sb, "\n") + fmt.Fprint(&sb, "\t\"github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel\"\n") + fmt.Fprint(&sb, ")\n\n") + for _, desc := range Descriptors { + desc.genNewResponse(&sb) + } + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/spec.go b/internal/engine/ooapi/internal/generator/spec.go new file mode 100644 index 0000000..01bd0c8 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/spec.go @@ -0,0 +1,136 @@ +package main + +import "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" + +// URLPath describes a URLPath. +type URLPath struct { + // IsTemplate indicates whether Value contains a template. A future + // version of this implementation will automatically deduce that. + IsTemplate bool + + // Value is the value of the URL path. + Value string + + // InSwagger indicates the corresponding name to be used in + // the Swagger specification. + InSwagger string +} + +// Descriptor is an API descriptor. It tells the generator +// what code it should emit for a given API. +type Descriptor struct { + // Name is the name of the API. + Name string + + // CachePolicy indicates the caching policy to use. + CachePolicy int + + // RequiresLogin indicates whether the API requires login. + RequiresLogin bool + + // Method is the method to use ("GET" or "POST"). + Method string + + // URLPath is the URL path. + URLPath URLPath + + // Request is an instance of the request type. + Request interface{} + + // Response is an instance of the response type. + Response interface{} +} + +// These are the caching policies. +const ( + // CacheNone indicates we don't use a cache. + CacheNone = iota + + // CacheFallback indicates we fallback to the cache + // when there is a failure. + CacheFallback + + // CacheAlways indicates that we always check the + // cache before sending a request. + CacheAlways +) + +// Descriptors describes all the APIs. +// +// Note that it matters whether the requests and responses +// are pointers. Generally speaking, if the message is a +// struct, use a pointer. If it's a map, don't. +var Descriptors = []Descriptor{{ + Name: "CheckReportID", + Method: "GET", + URLPath: URLPath{Value: "/api/_/check_report_id"}, + Request: &apimodel.CheckReportIDRequest{}, + Response: &apimodel.CheckReportIDResponse{}, +}, { + Name: "CheckIn", + Method: "POST", + URLPath: URLPath{Value: "/api/v1/check-in"}, + Request: &apimodel.CheckInRequest{}, + Response: &apimodel.CheckInResponse{}, +}, { + Name: "Login", + Method: "POST", + URLPath: URLPath{Value: "/api/v1/login"}, + Request: &apimodel.LoginRequest{}, + Response: &apimodel.LoginResponse{}, +}, { + Name: "MeasurementMeta", + Method: "GET", + URLPath: URLPath{Value: "/api/v1/measurement_meta"}, + Request: &apimodel.MeasurementMetaRequest{}, + Response: &apimodel.MeasurementMetaResponse{}, + CachePolicy: CacheAlways, +}, { + Name: "Register", + Method: "POST", + URLPath: URLPath{Value: "/api/v1/register"}, + Request: &apimodel.RegisterRequest{}, + Response: &apimodel.RegisterResponse{}, +}, { + Name: "TestHelpers", + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-helpers"}, + Request: &apimodel.TestHelpersRequest{}, + Response: apimodel.TestHelpersResponse{}, +}, { + Name: "PsiphonConfig", + RequiresLogin: true, + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-list/psiphon-config"}, + Request: &apimodel.PsiphonConfigRequest{}, + Response: apimodel.PsiphonConfigResponse{}, +}, { + Name: "TorTargets", + RequiresLogin: true, + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-list/tor-targets"}, + Request: &apimodel.TorTargetsRequest{}, + Response: apimodel.TorTargetsResponse{}, +}, { + Name: "URLs", + Method: "GET", + URLPath: URLPath{Value: "/api/v1/test-list/urls"}, + Request: &apimodel.URLsRequest{}, + Response: &apimodel.URLsResponse{}, +}, { + Name: "OpenReport", + Method: "POST", + URLPath: URLPath{Value: "/report"}, + Request: &apimodel.OpenReportRequest{}, + Response: &apimodel.OpenReportResponse{}, +}, { + Name: "SubmitMeasurement", + Method: "POST", + URLPath: URLPath{ + InSwagger: "/report/{report_id}", + IsTemplate: true, + Value: "/report/{{ .ReportID }}", + }, + Request: &apimodel.SubmitMeasurementRequest{}, + Response: &apimodel.SubmitMeasurementResponse{}, +}} diff --git a/internal/engine/ooapi/internal/generator/swaggertest.go b/internal/engine/ooapi/internal/generator/swaggertest.go new file mode 100644 index 0000000..3fadc38 --- /dev/null +++ b/internal/engine/ooapi/internal/generator/swaggertest.go @@ -0,0 +1,194 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "reflect" + "strings" + "sync" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/internal/openapi" +) + +const ( + tagForJSON = "json" + tagForPath = "path" +) + +func (d *Descriptor) genSwaggerURLPath() string { + up := d.URLPath + if up.InSwagger != "" { + return up.InSwagger + } + if up.IsTemplate { + panic("we should always use InSwapper and IsTemplate together") + } + return up.Value +} + +func (d *Descriptor) genSwaggerSchema(cur reflect.Type) *openapi.Schema { + switch cur.Kind() { + case reflect.String: + return &openapi.Schema{Type: "string"} + case reflect.Bool: + return &openapi.Schema{Type: "boolean"} + case reflect.Int64: + return &openapi.Schema{Type: "integer"} + case reflect.Slice: + return &openapi.Schema{Type: "array", Items: d.genSwaggerSchema(cur.Elem())} + case reflect.Map: + return &openapi.Schema{Type: "object"} + case reflect.Ptr: + return d.genSwaggerSchema(cur.Elem()) + case reflect.Struct: + if cur.String() == "time.Time" { + // Implementation note: we don't want to dive into time.Time but + // rather we want to pretend it's a string. The JSON parser for + // time.Time can indeed reconstruct a time.Time from a string, and + // it's much easier for us to let it do the parsing. + return &openapi.Schema{Type: "string"} + } + sinfo := &openapi.Schema{Type: "object"} + var once sync.Once + initmap := func() { + sinfo.Properties = make(map[string]*openapi.Schema) + } + for idx := 0; idx < cur.NumField(); idx++ { + field := cur.Field(idx) + if field.Tag.Get(tagForPath) != "" { + continue // skipping because this is a path param + } + if field.Tag.Get(tagForQuery) != "" { + continue // skipping because this is a query param + } + v := field.Name + if j := field.Tag.Get(tagForJSON); j != "" { + j = strings.Replace(j, ",omitempty", "", 1) // remove options + if j == "-" { + continue // not exported via JSON + } + v = j + } + once.Do(initmap) + sinfo.Properties[v] = d.genSwaggerSchema(field.Type) + } + return sinfo + case reflect.Interface: + return &openapi.Schema{Type: "object"} + default: + panic("unsupported type") + } +} + +func (d *Descriptor) swaggerParamForType(t reflect.Type) string { + switch t.Kind() { + case reflect.String: + return "string" + case reflect.Bool: + return "boolean" + case reflect.Int64: + return "integer" + default: + panic("unsupported type") + } +} + +func (d *Descriptor) genSwaggerParams(cur reflect.Type) []*openapi.Parameter { + // when we have params the input must be a pointer to struct + if cur.Kind() != reflect.Ptr { + panic("not a pointer") + } + cur = cur.Elem() + if cur.Kind() != reflect.Struct { + panic("not a pointer to struct") + } + // now that we're sure of the type, inspect the fields + var out []*openapi.Parameter + for idx := 0; idx < cur.NumField(); idx++ { + f := cur.Field(idx) + if q := f.Tag.Get(tagForQuery); q != "" { + out = append( + out, &openapi.Parameter{ + Name: q, + In: "query", + Required: f.Tag.Get(tagForRequired) == "true", + Type: d.swaggerParamForType(f.Type), + }) + continue + } + if p := f.Tag.Get(tagForPath); p != "" { + out = append(out, &openapi.Parameter{ + Name: p, + In: "path", + Required: true, + Type: d.swaggerParamForType(f.Type), + }) + continue + } + } + return out +} + +func (d *Descriptor) genSwaggerPath() (string, *openapi.Path) { + pathStr, pathInfo := d.genSwaggerURLPath(), &openapi.Path{} + rtinfo := &openapi.RoundTrip{Produces: []string{"application/json"}} + switch d.Method { + case "GET": + pathInfo.Get = rtinfo + case "POST": + rtinfo.Consumes = append(rtinfo.Consumes, "application/json") + pathInfo.Post = rtinfo + default: + panic("unsupported method") + } + rtinfo.Parameters = d.genSwaggerParams(reflect.TypeOf(d.Request)) + if d.Method != "GET" { + rtinfo.Parameters = append(rtinfo.Parameters, &openapi.Parameter{ + Name: "body", + In: "body", + Required: true, + Schema: d.genSwaggerSchema(reflect.TypeOf(d.Request)), + }) + } + rtinfo.Responses = &openapi.Responses{Successful: openapi.Body{ + Description: "all good", + Schema: d.genSwaggerSchema(reflect.TypeOf(d.Response)), + }} + return pathStr, pathInfo +} + +func genSwaggerVersion() string { + return time.Now().UTC().Format("0.20060102.1150405") +} + +// GenSwaggerTestGo generates swagger_test.go +func GenSwaggerTestGo(file string) { + swagger := openapi.Swagger{ + Swagger: "2.0", + Info: openapi.API{ + Title: "OONI API specification", + Version: genSwaggerVersion(), + }, + Host: "api.ooni.io", + BasePath: "/", + Schemes: []string{"https"}, + Paths: make(map[string]*openapi.Path), + } + for _, desc := range Descriptors { + pathStr, pathInfo := desc.genSwaggerPath() + swagger.Paths[pathStr] = pathInfo + } + data, err := json.MarshalIndent(swagger, "", " ") + if err != nil { + log.Fatal(err) + } + var sb strings.Builder + fmt.Fprint(&sb, "// Code generated by go generate; DO NOT EDIT.\n") + fmt.Fprintf(&sb, "// %s\n\n", time.Now()) + fmt.Fprint(&sb, "package ooapi\n\n") + fmt.Fprintf(&sb, "//go:generate go run ./internal/generator -file %s\n\n", file) + fmt.Fprintf(&sb, "const swagger = `%s`\n", string(data)) + writefile(file, &sb) +} diff --git a/internal/engine/ooapi/internal/generator/writefile.go b/internal/engine/ooapi/internal/generator/writefile.go new file mode 100644 index 0000000..48bbf2a --- /dev/null +++ b/internal/engine/ooapi/internal/generator/writefile.go @@ -0,0 +1,27 @@ +package main + +import ( + "fmt" + "log" + "os" + "strings" + + "golang.org/x/sys/execabs" +) + +func writefile(name string, sb *strings.Builder) { + filep, err := os.Create(name) + if err != nil { + log.Fatal(err) + } + if _, err := fmt.Fprint(filep, sb.String()); err != nil { + log.Fatal(err) + } + if err := filep.Close(); err != nil { + log.Fatal(err) + } + cmd := execabs.Command("go", "fmt", name) + if err := cmd.Run(); err != nil { + log.Fatal(err) + } +} diff --git a/internal/engine/ooapi/internal/openapi/openapi.go b/internal/engine/ooapi/internal/openapi/openapi.go new file mode 100644 index 0000000..d20d4bb --- /dev/null +++ b/internal/engine/ooapi/internal/openapi/openapi.go @@ -0,0 +1,64 @@ +// Package openapi contains data structures for Swagger v2.0. +// +// We use these data structures to compare the API specification we +// have here with the one of the server. +package openapi + +// Schema is the schema of a specific parameter or +// or the schema used by the response body +type Schema struct { + Properties map[string]*Schema `json:"properties,omitempty"` + Items *Schema `json:"items,omitempty"` + Type string `json:"type"` +} + +// Parameter describes an input parameter, which could be in the +// URL path, in the query string, or in the request body +type Parameter struct { + In string `json:"in"` + Name string `json:"name"` + Required bool `json:"required,omitempty"` + Schema *Schema `json:"schema,omitempty"` + Type string `json:"type,omitempty"` +} + +// Body describes a response body +type Body struct { + Description interface{} `json:"description,omitempty"` + Schema *Schema `json:"schema"` +} + +// Responses describes the possible responses +type Responses struct { + Successful Body `json:"200"` +} + +// RoundTrip describes an HTTP round trip with a given method and path +type RoundTrip struct { + Consumes []string `json:"consumes,omitempty"` + Produces []string `json:"produces,omitempty"` + Parameters []*Parameter `json:"parameters,omitempty"` + Responses *Responses `json:"responses,omitempty"` +} + +// Path describes a path served by the API +type Path struct { + Get *RoundTrip `json:"get,omitempty"` + Post *RoundTrip `json:"post,omitempty"` +} + +// API contains info about the API +type API struct { + Title string `json:"title"` + Version string `json:"version"` +} + +// Swagger is the toplevel structure +type Swagger struct { + Swagger string `json:"swagger"` + Info API `json:"info"` + Host string `json:"host"` + BasePath string `json:"basePath"` + Schemes []string `json:"schemes"` + Paths map[string]*Path `json:"paths"` +} diff --git a/internal/engine/ooapi/kvstore_test.go b/internal/engine/ooapi/kvstore_test.go new file mode 100644 index 0000000..31c7e71 --- /dev/null +++ b/internal/engine/ooapi/kvstore_test.go @@ -0,0 +1,34 @@ +package ooapi + +import ( + "errors" + "fmt" + "sync" +) + +var errMemkvstoreNotFound = errors.New("memkvstore: not found") + +type memkvstore struct { + m map[string][]byte + mu sync.Mutex +} + +func (kvs *memkvstore) Get(key string) ([]byte, error) { + defer kvs.mu.Unlock() + kvs.mu.Lock() + out, good := kvs.m[key] + if !good { + return nil, fmt.Errorf("%w: %s", errMemkvstoreNotFound, key) + } + return out, nil +} + +func (kvs *memkvstore) Set(key string, value []byte) error { + defer kvs.mu.Unlock() + kvs.mu.Lock() + if kvs.m == nil { + kvs.m = make(map[string][]byte) + } + kvs.m[key] = value + return nil +} diff --git a/internal/engine/ooapi/login.go b/internal/engine/ooapi/login.go new file mode 100644 index 0000000..49f9833 --- /dev/null +++ b/internal/engine/ooapi/login.go @@ -0,0 +1,295 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.62521737 +0100 CET m=+0.000161706 + +package ooapi + +//go:generate go run ./internal/generator -file login.go + +import ( + "context" + "errors" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// PsiphonConfigAPIWithLogin implements login for PsiphonConfigAPI. +type PsiphonConfigAPIWithLogin struct { + API PsiphonConfigCloner // mandatory + JSONCodec JSONCodec // optional + KVStore KVStore // mandatory + RegisterAPI RegisterCaller // mandatory + LoginAPI LoginCaller // mandatory +} + +// Call logins, if needed, then calls the API. +func (api *PsiphonConfigAPIWithLogin) Call(ctx context.Context, req *apimodel.PsiphonConfigRequest) (apimodel.PsiphonConfigResponse, error) { + token, err := api.maybeLogin(ctx) + if err != nil { + return nil, err + } + resp, err := api.API.WithToken(token).Call(ctx, req) + if errors.Is(err, ErrUnauthorized) { + // Maybe the clock is just off? Let's try to obtain + // a token again and see if this fixes it. + if token, err = api.forceLogin(ctx); err == nil { + switch resp, err = api.API.WithToken(token).Call(ctx, req); err { + case nil: + return resp, nil + case ErrUnauthorized: + // fallthrough + default: + return nil, err + } + } + // Okay, this seems a broader problem. How about we try + // and re-register ourselves again instead? + token, err = api.forceRegister(ctx) + if err != nil { + return nil, err + } + resp, err = api.API.WithToken(token).Call(ctx, req) + // fallthrough + } + if err != nil { + return nil, err + } + return resp, nil +} + +func (api *PsiphonConfigAPIWithLogin) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *PsiphonConfigAPIWithLogin) readstate() (*loginState, error) { + data, err := api.KVStore.Get(loginKey) + if err != nil { + return nil, err + } + var ls loginState + if err := api.jsonCodec().Decode(data, &ls); err != nil { + return nil, err + } + return &ls, nil +} + +func (api *PsiphonConfigAPIWithLogin) writestate(ls *loginState) error { + data, err := api.jsonCodec().Encode(*ls) + if err != nil { + return err + } + return api.KVStore.Set(loginKey, data) +} + +func (api *PsiphonConfigAPIWithLogin) doRegister(ctx context.Context, password string) (string, error) { + req := newRegisterRequest(password) + ls := &loginState{} + resp, err := api.RegisterAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.ClientID = resp.ClientID + ls.Password = req.Password + return api.doLogin(ctx, ls) +} + +func (api *PsiphonConfigAPIWithLogin) forceRegister(ctx context.Context) (string, error) { + var password string + // If we already have a previous password, let us keep + // using it. This will allow a new version of the API to + // be able to continue to identify this probe. (This + // assumes that we have a stateless API that generates + // the user ID as a signature of the password plus a + // timestamp and that the key to generate the signature + // is not lost. If all these conditions are met, we + // can then serve better test targets to more long running + // (and therefore trusted) probes.) + if ls, err := api.readstate(); err == nil { + password = ls.Password + } + if password == "" { + password = newRandomPassword() + } + return api.doRegister(ctx, password) +} + +func (api *PsiphonConfigAPIWithLogin) forceLogin(ctx context.Context) (string, error) { + ls, err := api.readstate() + if err != nil { + return "", err + } + return api.doLogin(ctx, ls) +} + +func (api *PsiphonConfigAPIWithLogin) maybeLogin(ctx context.Context) (string, error) { + ls, _ := api.readstate() + if ls == nil || !ls.credentialsValid() { + return api.forceRegister(ctx) + } + if !ls.tokenValid() { + return api.doLogin(ctx, ls) + } + return ls.Token, nil +} + +func (api *PsiphonConfigAPIWithLogin) doLogin(ctx context.Context, ls *loginState) (string, error) { + req := &apimodel.LoginRequest{ + ClientID: ls.ClientID, + Password: ls.Password, + } + resp, err := api.LoginAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.Token = resp.Token + ls.Expire = resp.Expire + if err := api.writestate(ls); err != nil { + return "", err + } + return ls.Token, nil +} + +var _ PsiphonConfigCaller = &PsiphonConfigAPIWithLogin{} + +// TorTargetsAPIWithLogin implements login for TorTargetsAPI. +type TorTargetsAPIWithLogin struct { + API TorTargetsCloner // mandatory + JSONCodec JSONCodec // optional + KVStore KVStore // mandatory + RegisterAPI RegisterCaller // mandatory + LoginAPI LoginCaller // mandatory +} + +// Call logins, if needed, then calls the API. +func (api *TorTargetsAPIWithLogin) Call(ctx context.Context, req *apimodel.TorTargetsRequest) (apimodel.TorTargetsResponse, error) { + token, err := api.maybeLogin(ctx) + if err != nil { + return nil, err + } + resp, err := api.API.WithToken(token).Call(ctx, req) + if errors.Is(err, ErrUnauthorized) { + // Maybe the clock is just off? Let's try to obtain + // a token again and see if this fixes it. + if token, err = api.forceLogin(ctx); err == nil { + switch resp, err = api.API.WithToken(token).Call(ctx, req); err { + case nil: + return resp, nil + case ErrUnauthorized: + // fallthrough + default: + return nil, err + } + } + // Okay, this seems a broader problem. How about we try + // and re-register ourselves again instead? + token, err = api.forceRegister(ctx) + if err != nil { + return nil, err + } + resp, err = api.API.WithToken(token).Call(ctx, req) + // fallthrough + } + if err != nil { + return nil, err + } + return resp, nil +} + +func (api *TorTargetsAPIWithLogin) jsonCodec() JSONCodec { + if api.JSONCodec != nil { + return api.JSONCodec + } + return &defaultJSONCodec{} +} + +func (api *TorTargetsAPIWithLogin) readstate() (*loginState, error) { + data, err := api.KVStore.Get(loginKey) + if err != nil { + return nil, err + } + var ls loginState + if err := api.jsonCodec().Decode(data, &ls); err != nil { + return nil, err + } + return &ls, nil +} + +func (api *TorTargetsAPIWithLogin) writestate(ls *loginState) error { + data, err := api.jsonCodec().Encode(*ls) + if err != nil { + return err + } + return api.KVStore.Set(loginKey, data) +} + +func (api *TorTargetsAPIWithLogin) doRegister(ctx context.Context, password string) (string, error) { + req := newRegisterRequest(password) + ls := &loginState{} + resp, err := api.RegisterAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.ClientID = resp.ClientID + ls.Password = req.Password + return api.doLogin(ctx, ls) +} + +func (api *TorTargetsAPIWithLogin) forceRegister(ctx context.Context) (string, error) { + var password string + // If we already have a previous password, let us keep + // using it. This will allow a new version of the API to + // be able to continue to identify this probe. (This + // assumes that we have a stateless API that generates + // the user ID as a signature of the password plus a + // timestamp and that the key to generate the signature + // is not lost. If all these conditions are met, we + // can then serve better test targets to more long running + // (and therefore trusted) probes.) + if ls, err := api.readstate(); err == nil { + password = ls.Password + } + if password == "" { + password = newRandomPassword() + } + return api.doRegister(ctx, password) +} + +func (api *TorTargetsAPIWithLogin) forceLogin(ctx context.Context) (string, error) { + ls, err := api.readstate() + if err != nil { + return "", err + } + return api.doLogin(ctx, ls) +} + +func (api *TorTargetsAPIWithLogin) maybeLogin(ctx context.Context) (string, error) { + ls, _ := api.readstate() + if ls == nil || !ls.credentialsValid() { + return api.forceRegister(ctx) + } + if !ls.tokenValid() { + return api.doLogin(ctx, ls) + } + return ls.Token, nil +} + +func (api *TorTargetsAPIWithLogin) doLogin(ctx context.Context, ls *loginState) (string, error) { + req := &apimodel.LoginRequest{ + ClientID: ls.ClientID, + Password: ls.Password, + } + resp, err := api.LoginAPI.Call(ctx, req) + if err != nil { + return "", err + } + ls.Token = resp.Token + ls.Expire = resp.Expire + if err := api.writestate(ls); err != nil { + return "", err + } + return ls.Token, nil +} + +var _ TorTargetsCaller = &TorTargetsAPIWithLogin{} diff --git a/internal/engine/ooapi/login_test.go b/internal/engine/ooapi/login_test.go new file mode 100644 index 0000000..c5b3e44 --- /dev/null +++ b/internal/engine/ooapi/login_test.go @@ -0,0 +1,1365 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:52.9205436 +0100 CET m=+0.000137951 + +package ooapi + +//go:generate go run ./internal/generator -file login_test.go + +import ( + "context" + "errors" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func TestRegisterAndLoginPsiphonConfigAPISuccess(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIContinueUsingToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } + } + // step 2: we disable register and login but we + // should be okay because of the token + errMocked := errors.New("mocked error") + registerAPI.Err = errMocked + registerAPI.Response = nil + loginAPI.Err = errMocked + loginAPI.Response = nil + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIWithValidButExpiredToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 0 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIWithRegisterAPIError(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIWithLoginFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + errMocked := errors.New("mocked error") + loginAPI := &FakeLoginAPI{ + Err: errMocked, + } + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestRegisterAndLoginPsiphonConfigAPIThenFail(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Err: errMocked, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPITheDatabaseIsReplaced(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + handler.forgetLogins() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestRegisterAndLoginPsiphonConfigAPICannotWriteState(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &PsiphonConfigAPIWithLogin{ + API: &FakePsiphonConfigAPI{ + WithResult: &FakePsiphonConfigAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{ + EncodeErr: errMocked, + }, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestPsiphonConfigAPIReadStateDecodeFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.PsiphonConfigResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + login := &PsiphonConfigAPIWithLogin{ + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + out, err := login.forceLogin(context.Background()) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string here") + } +} + +func TestPsiphonConfigAPITheDatabaseIsReplacedThenFailure(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + // but registrations are also failing. + handler.forgetLogins() + handler.noRegister = true + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestPsiphonConfigAPIClockIsOffThenSuccess(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} + +func TestPsiphonConfigAPIClockIsOffThen401(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 401} + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal("not the error we expected", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestPsiphonConfigAPIClockIsOffThen500(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &PsiphonConfigAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &PsiphonConfigAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.PsiphonConfigRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 500} + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} + +func TestRegisterAndLoginTorTargetsAPISuccess(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIContinueUsingToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } + } + // step 2: we disable register and login but we + // should be okay because of the token + errMocked := errors.New("mocked error") + registerAPI.Err = errMocked + registerAPI.Response = nil + loginAPI.Err = errMocked + loginAPI.Response = nil + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIWithValidButExpiredToken(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if diff := cmp.Diff(expect, resp); diff != "" { + t.Fatal(diff) + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 0 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIWithRegisterAPIError(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + registerAPI := &FakeRegisterAPI{ + Err: errMocked, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIWithLoginFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + errMocked := errors.New("mocked error") + loginAPI := &FakeLoginAPI{ + Err: errMocked, + } + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestRegisterAndLoginTorTargetsAPIThenFail(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Err: errMocked, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPITheDatabaseIsReplaced(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + handler.forgetLogins() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestRegisterAndLoginTorTargetsAPICannotWriteState(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + registerAPI := &FakeRegisterAPI{ + Response: &apimodel.RegisterResponse{ + ClientID: "antani-antani", + }, + } + loginAPI := &FakeLoginAPI{ + Response: &apimodel.LoginResponse{ + Expire: time.Now().Add(3600 * time.Second), + Token: "antani-antani-token", + }, + } + errMocked := errors.New("mocked error") + login := &TorTargetsAPIWithLogin{ + API: &FakeTorTargetsAPI{ + WithResult: &FakeTorTargetsAPI{ + Response: expect, + }, + }, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{ + EncodeErr: errMocked, + }, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + resp, err := login.Call(ctx, req) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if loginAPI.CountCall != 1 { + t.Fatal("invalid loginAPI.CountCall") + } + if registerAPI.CountCall != 1 { + t.Fatal("invalid registerAPI.CountCall") + } +} + +func TestTorTargetsAPIReadStateDecodeFailure(t *testing.T) { + ff := &fakeFill{} + var expect apimodel.TorTargetsResponse + ff.fill(&expect) + errMocked := errors.New("mocked error") + login := &TorTargetsAPIWithLogin{ + KVStore: &memkvstore{}, + JSONCodec: &FakeCodec{DecodeErr: errMocked}, + } + ls := &loginState{ + ClientID: "antani-antani", + Expire: time.Now().Add(-5 * time.Second), + Token: "antani-antani-token", + Password: "antani-antani-password", + } + if err := login.writestate(ls); err != nil { + t.Fatal(err) + } + out, err := login.forceLogin(context.Background()) + if !errors.Is(err, errMocked) { + t.Fatal("not the error we expected", err) + } + if out != "" { + t.Fatal("expected empty string here") + } +} + +func TestTorTargetsAPITheDatabaseIsReplacedThenFailure(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget accounts and try again. + // but registrations are also failing. + handler.forgetLogins() + handler.noRegister = true + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestTorTargetsAPIClockIsOffThenSuccess(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} + +func TestTorTargetsAPIClockIsOffThen401(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 401} + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal("not the error we expected", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 3 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 2 { + t.Fatal("invalid handler.registers") + } +} + +func TestTorTargetsAPIClockIsOffThen500(t *testing.T) { + ff := &fakeFill{} + handler := &LoginHandler{t: t} + srvr := httptest.NewServer(handler) + defer srvr.Close() + registerAPI := &RegisterAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + loginAPI := &LoginAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + baseAPI := &TorTargetsAPI{ + HTTPClient: &VerboseHTTPClient{t: t}, + BaseURL: srvr.URL, + } + login := &TorTargetsAPIWithLogin{ + API: baseAPI, + RegisterAPI: registerAPI, + LoginAPI: loginAPI, + KVStore: &memkvstore{}, + } + var req *apimodel.TorTargetsRequest + ff.fill(&req) + ctx := context.Background() + // step 1: we register and login and use the token + // inside a scope just to avoid mistakes + { + resp, err := login.Call(ctx, req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + if handler.logins != 1 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } + } + // step 2: we forget tokens and try again. + // this should simulate the client clock + // being off and considering a token still valid + handler.forgetTokens() + handler.failCallWith = []int{401, 500} + resp, err := login.Call(ctx, req) + if !errors.Is(err, ErrHTTPFailure) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + if handler.logins != 2 { + t.Fatal("invalid handler.logins") + } + if handler.registers != 1 { + t.Fatal("invalid handler.registers") + } +} diff --git a/internal/engine/ooapi/loginhandler_test.go b/internal/engine/ooapi/loginhandler_test.go new file mode 100644 index 0000000..f9e0274 --- /dev/null +++ b/internal/engine/ooapi/loginhandler_test.go @@ -0,0 +1,204 @@ +package ooapi + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +// LoginHandler is an http.Handler to test login +type LoginHandler struct { + failCallWith []int // ignored by login and register + mu sync.Mutex + noRegister bool + state []*loginState + t *testing.T + logins int32 + registers int32 +} + +func (lh *LoginHandler) forgetLogins() { + defer lh.mu.Unlock() + lh.mu.Lock() + lh.state = nil +} + +func (lh *LoginHandler) forgetTokens() { + defer lh.mu.Unlock() + lh.mu.Lock() + for _, entry := range lh.state { + // This should be enough to cause all tokens to + // be expired and force clients to relogin. + // + // (It does not matter much whether the client + // clock is off, or the server clock is off, + // thanks Galileo for explaining this to us <3.) + entry.Expire = time.Now().Add(-3600 * time.Second) + } +} + +func (lh *LoginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Implementation note: we don't check for the method + // for simplicity since it's already tested. + switch r.URL.Path { + case "/api/v1/register": + atomic.AddInt32(&lh.registers, 1) + lh.register(w, r) + case "/api/v1/login": + atomic.AddInt32(&lh.logins, 1) + lh.login(w, r) + case "/api/v1/test-list/psiphon-config": + lh.psiphon(w, r) + case "/api/v1/test-list/tor-targets": + lh.tor(w, r) + default: + w.WriteHeader(500) + } +} + +func (lh *LoginHandler) register(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + w.WriteHeader(400) + return + } + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + var req apimodel.RegisterRequest + if err := json.Unmarshal(data, &req); err != nil { + w.WriteHeader(400) + return + } + if req.Password == "" { + w.WriteHeader(400) + return + } + defer lh.mu.Unlock() + lh.mu.Lock() + if lh.noRegister { + // We have been asked to stop registering clients so + // we're going to make a boo boo. + w.WriteHeader(500) + return + } + var resp apimodel.RegisterResponse + ff := &fakeFill{} + ff.fill(&resp) + lh.state = append(lh.state, &loginState{ + ClientID: resp.ClientID, Password: req.Password}) + data, err = json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("register: %+v", string(data)) + w.Write(data) +} + +func (lh *LoginHandler) login(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + w.WriteHeader(400) + return + } + data, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(400) + return + } + var req apimodel.LoginRequest + if err := json.Unmarshal(data, &req); err != nil { + w.WriteHeader(400) + return + } + defer lh.mu.Unlock() + lh.mu.Lock() + for _, s := range lh.state { + if req.ClientID == s.ClientID && req.Password == s.Password { + var resp apimodel.LoginResponse + ff := &fakeFill{} + ff.fill(&resp) + // We want the token to be many seconds in the future while + // ff.fill only sets the tokent to now plus a small delta. + resp.Expire = time.Now().Add(3600 * time.Second) + s.Expire = resp.Expire + s.Token = resp.Token + data, err = json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("login: %+v", string(data)) + w.Write(data) + return + } + } + lh.t.Log("login: 401") + w.WriteHeader(401) +} + +func (lh *LoginHandler) psiphon(w http.ResponseWriter, r *http.Request) { + defer lh.mu.Unlock() + lh.mu.Lock() + if len(lh.failCallWith) > 0 { + code := lh.failCallWith[0] + lh.failCallWith = lh.failCallWith[1:] + w.WriteHeader(code) + return + } + token := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1) + for _, s := range lh.state { + if token == s.Token && time.Now().Before(s.Expire) { + var resp apimodel.PsiphonConfigResponse + ff := &fakeFill{} + ff.fill(&resp) + data, err := json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("psiphon: %+v", string(data)) + w.Write(data) + return + } + } + lh.t.Log("psiphon: 401") + w.WriteHeader(401) +} + +func (lh *LoginHandler) tor(w http.ResponseWriter, r *http.Request) { + defer lh.mu.Unlock() + lh.mu.Lock() + if len(lh.failCallWith) > 0 { + code := lh.failCallWith[0] + lh.failCallWith = lh.failCallWith[1:] + w.WriteHeader(code) + return + } + token := strings.Replace(r.Header.Get("Authorization"), "Bearer ", "", 1) + for _, s := range lh.state { + if token == s.Token && time.Now().Before(s.Expire) { + var resp apimodel.TorTargetsResponse + ff := &fakeFill{} + ff.fill(&resp) + data, err := json.Marshal(&resp) + if err != nil { + w.WriteHeader(500) + return + } + lh.t.Logf("tor: %+v", string(data)) + w.Write(data) + return + } + } + lh.t.Log("tor: 401") + w.WriteHeader(401) +} diff --git a/internal/engine/ooapi/loginmodel.go b/internal/engine/ooapi/loginmodel.go new file mode 100644 index 0000000..4337348 --- /dev/null +++ b/internal/engine/ooapi/loginmodel.go @@ -0,0 +1,59 @@ +package ooapi + +import ( + "crypto/rand" + "encoding/base64" + "time" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" + "github.com/ooni/probe-cli/v3/internal/engine/runtimex" +) + +// loginState is the struct saved in the kvstore +// to keep track of the login state. +type loginState struct { + ClientID string + Expire time.Time + Password string + Token string +} + +func (ls *loginState) credentialsValid() bool { + return ls.ClientID != "" && ls.Password != "" +} + +func (ls *loginState) tokenValid() bool { + return ls.Token != "" && time.Now().Add(60*time.Second).Before(ls.Expire) +} + +// loginKey is the key with which loginState is saved +// into the key-value store used by Client. +const loginKey = "orchestra.state" + +// newRandomPassword generates a new random password. +func newRandomPassword() string { + b := make([]byte, 48) + _, err := rand.Read(b) + runtimex.PanicOnError(err, "rand.Read failed") + return base64.StdEncoding.EncodeToString(b) +} + +// newRegisterRequest creates a new RegisterRequest. +func newRegisterRequest(password string) *apimodel.RegisterRequest { + return &apimodel.RegisterRequest{ + // The original implementation has as its only use case that we + // were registering and logging in for sending an update regarding + // the probe whereabouts. Yet here in probe-engine, the orchestra + // is currently only used to fetch inputs. For this purpose, we don't + // need to communicate any specific information. The code that will + // perform an update used to be responsible of doing that. Now, we + // are not using orchestra for this purpose anymore. + Platform: "miniooni", + ProbeASN: "AS0", + ProbeCC: "ZZ", + SoftwareName: "miniooni", + SoftwareVersion: "0.1.0-dev", + SupportedTests: []string{"web_connectivity"}, + Password: password, + } +} diff --git a/internal/engine/ooapi/requests.go b/internal/engine/ooapi/requests.go new file mode 100644 index 0000000..806d5f2 --- /dev/null +++ b/internal/engine/ooapi/requests.go @@ -0,0 +1,192 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:53.210720456 +0100 CET m=+0.000083649 + +package ooapi + +//go:generate go run ./internal/generator -file requests.go + +import ( + "bytes" + "context" + "net/http" + "net/url" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func (api *CheckReportIDAPI) newRequest(ctx context.Context, req *apimodel.CheckReportIDRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/_/check_report_id" + q := url.Values{} + if req.ReportID == "" { + return nil, newErrEmptyField("ReportID") + } + q.Add("report_id", req.ReportID) + URL.RawQuery = q.Encode() + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *CheckInAPI) newRequest(ctx context.Context, req *apimodel.CheckInRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/check-in" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *LoginAPI) newRequest(ctx context.Context, req *apimodel.LoginRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/login" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *MeasurementMetaAPI) newRequest(ctx context.Context, req *apimodel.MeasurementMetaRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/measurement_meta" + q := url.Values{} + if req.ReportID == "" { + return nil, newErrEmptyField("ReportID") + } + q.Add("report_id", req.ReportID) + if req.Full { + q.Add("full", "true") + } + if req.Input != "" { + q.Add("input", req.Input) + } + URL.RawQuery = q.Encode() + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *RegisterAPI) newRequest(ctx context.Context, req *apimodel.RegisterRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/register" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *TestHelpersAPI) newRequest(ctx context.Context, req *apimodel.TestHelpersRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-helpers" + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *PsiphonConfigAPI) newRequest(ctx context.Context, req *apimodel.PsiphonConfigRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-list/psiphon-config" + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *TorTargetsAPI) newRequest(ctx context.Context, req *apimodel.TorTargetsRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-list/tor-targets" + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *URLsAPI) newRequest(ctx context.Context, req *apimodel.URLsRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/api/v1/test-list/urls" + q := url.Values{} + if req.CategoryCodes != "" { + q.Add("category_codes", req.CategoryCodes) + } + if req.CountryCode != "" { + q.Add("country_code", req.CountryCode) + } + if req.Limit != 0 { + q.Add("limit", newQueryFieldInt64(req.Limit)) + } + URL.RawQuery = q.Encode() + return api.requestMaker().NewRequest(ctx, "GET", URL.String(), nil) +} + +func (api *OpenReportAPI) newRequest(ctx context.Context, req *apimodel.OpenReportRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + URL.Path = "/report" + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} + +func (api *SubmitMeasurementAPI) newRequest(ctx context.Context, req *apimodel.SubmitMeasurementRequest) (*http.Request, error) { + URL, err := url.Parse(api.baseURL()) + if err != nil { + return nil, err + } + up, err := api.templateExecutor().Execute("/report/{{ .ReportID }}", req) + if err != nil { + return nil, err + } + URL.Path = up + body, err := api.jsonCodec().Encode(req) + if err != nil { + return nil, err + } + out, err := api.requestMaker().NewRequest(ctx, "POST", URL.String(), bytes.NewReader(body)) + if err != nil { + return nil, err + } + out.Header.Set("Content-Type", "application/json") + return out, nil +} diff --git a/internal/engine/ooapi/responses.go b/internal/engine/ooapi/responses.go new file mode 100644 index 0000000..b148aea --- /dev/null +++ b/internal/engine/ooapi/responses.go @@ -0,0 +1,276 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:53.567815989 +0100 CET m=+0.000158731 + +package ooapi + +//go:generate go run ./internal/generator -file responses.go + +import ( + "io" + "io/ioutil" + "net/http" + + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/apimodel" +) + +func (api *CheckReportIDAPI) newResponse(resp *http.Response, err error) (*apimodel.CheckReportIDResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.CheckReportIDResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *CheckInAPI) newResponse(resp *http.Response, err error) (*apimodel.CheckInResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.CheckInResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *LoginAPI) newResponse(resp *http.Response, err error) (*apimodel.LoginResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.LoginResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *MeasurementMetaAPI) newResponse(resp *http.Response, err error) (*apimodel.MeasurementMetaResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.MeasurementMetaResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *RegisterAPI) newResponse(resp *http.Response, err error) (*apimodel.RegisterResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.RegisterResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *TestHelpersAPI) newResponse(resp *http.Response, err error) (apimodel.TestHelpersResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := apimodel.TestHelpersResponse{} + if err := api.jsonCodec().Decode(data, &out); err != nil { + return nil, err + } + if out == nil { + return nil, ErrJSONLiteralNull + } + return out, nil +} + +func (api *PsiphonConfigAPI) newResponse(resp *http.Response, err error) (apimodel.PsiphonConfigResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := apimodel.PsiphonConfigResponse{} + if err := api.jsonCodec().Decode(data, &out); err != nil { + return nil, err + } + if out == nil { + return nil, ErrJSONLiteralNull + } + return out, nil +} + +func (api *TorTargetsAPI) newResponse(resp *http.Response, err error) (apimodel.TorTargetsResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := apimodel.TorTargetsResponse{} + if err := api.jsonCodec().Decode(data, &out); err != nil { + return nil, err + } + if out == nil { + return nil, ErrJSONLiteralNull + } + return out, nil +} + +func (api *URLsAPI) newResponse(resp *http.Response, err error) (*apimodel.URLsResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.URLsResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *OpenReportAPI) newResponse(resp *http.Response, err error) (*apimodel.OpenReportResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.OpenReportResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} + +func (api *SubmitMeasurementAPI) newResponse(resp *http.Response, err error) (*apimodel.SubmitMeasurementResponse, error) { + if err != nil { + return nil, err + } + if resp.StatusCode == 401 { + return nil, ErrUnauthorized + } + if resp.StatusCode != 200 { + return nil, newHTTPFailure(resp.StatusCode) + } + defer resp.Body.Close() + reader := io.LimitReader(resp.Body, 4<<20) + data, err := ioutil.ReadAll(reader) + if err != nil { + return nil, err + } + out := &apimodel.SubmitMeasurementResponse{} + if err := api.jsonCodec().Decode(data, out); err != nil { + return nil, err + } + return out, nil +} diff --git a/internal/engine/ooapi/swagger_test.go b/internal/engine/ooapi/swagger_test.go new file mode 100644 index 0000000..5cf18ab --- /dev/null +++ b/internal/engine/ooapi/swagger_test.go @@ -0,0 +1,578 @@ +// Code generated by go generate; DO NOT EDIT. +// 2021-02-26 15:45:53.881261959 +0100 CET m=+0.000594905 + +package ooapi + +//go:generate go run ./internal/generator -file swagger_test.go + +const swagger = `{ + "swagger": "2.0", + "info": { + "title": "OONI API specification", + "version": "0.20210226.2144553" + }, + "host": "api.ooni.io", + "basePath": "/", + "schemes": [ + "https" + ], + "paths": { + "/api/_/check_report_id": { + "get": { + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "report_id", + "required": true, + "type": "string" + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "error": { + "type": "string" + }, + "found": { + "type": "boolean" + }, + "v": { + "type": "integer" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/check-in": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "charging": { + "type": "boolean" + }, + "on_wifi": { + "type": "boolean" + }, + "platform": { + "type": "string" + }, + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "run_type": { + "type": "string" + }, + "software_name": { + "type": "string" + }, + "software_version": { + "type": "string" + }, + "web_connectivity": { + "properties": { + "category_codes": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "tests": { + "properties": { + "web_connectivity": { + "properties": { + "report_id": { + "type": "string" + }, + "urls": { + "items": { + "properties": { + "category_code": { + "type": "string" + }, + "country_code": { + "type": "string" + }, + "url": { + "type": "string" + } + }, + "type": "object" + }, + "type": "array" + } + }, + "type": "object" + } + }, + "type": "object" + }, + "v": { + "type": "integer" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/login": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "password": { + "type": "string" + }, + "username": { + "type": "string" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "expire": { + "type": "string" + }, + "token": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/measurement_meta": { + "get": { + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "report_id", + "required": true, + "type": "string" + }, + { + "in": "query", + "name": "full", + "type": "boolean" + }, + { + "in": "query", + "name": "input", + "type": "string" + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "anomaly": { + "type": "boolean" + }, + "category_code": { + "type": "string" + }, + "confirmed": { + "type": "boolean" + }, + "failure": { + "type": "boolean" + }, + "input": { + "type": "string" + }, + "measurement_start_time": { + "type": "string" + }, + "probe_asn": { + "type": "integer" + }, + "probe_cc": { + "type": "string" + }, + "raw_measurement": { + "type": "string" + }, + "report_id": { + "type": "string" + }, + "scores": { + "type": "string" + }, + "test_name": { + "type": "string" + }, + "test_start_time": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/register": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "available_bandwidth": { + "type": "string" + }, + "device_token": { + "type": "string" + }, + "language": { + "type": "string" + }, + "network_type": { + "type": "string" + }, + "password": { + "type": "string" + }, + "platform": { + "type": "string" + }, + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "probe_family": { + "type": "string" + }, + "probe_timezone": { + "type": "string" + }, + "software_name": { + "type": "string" + }, + "software_version": { + "type": "string" + }, + "supported_tests": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "client_id": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + }, + "/api/v1/test-helpers": { + "get": { + "produces": [ + "application/json" + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "type": "object" + } + } + } + } + }, + "/api/v1/test-list/psiphon-config": { + "get": { + "produces": [ + "application/json" + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "type": "object" + } + } + } + } + }, + "/api/v1/test-list/tor-targets": { + "get": { + "produces": [ + "application/json" + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "type": "object" + } + } + } + } + }, + "/api/v1/test-list/urls": { + "get": { + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "query", + "name": "category_codes", + "type": "string" + }, + { + "in": "query", + "name": "country_code", + "type": "string" + }, + { + "in": "query", + "name": "limit", + "type": "integer" + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "metadata": { + "properties": { + "count": { + "type": "integer" + } + }, + "type": "object" + }, + "results": { + "items": { + "properties": { + "category_code": { + "type": "string" + }, + "country_code": { + "type": "string" + }, + "url": { + "type": "string" + } + }, + "type": "object" + }, + "type": "array" + } + }, + "type": "object" + } + } + } + } + }, + "/report": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "data_format_version": { + "type": "string" + }, + "format": { + "type": "string" + }, + "probe_asn": { + "type": "string" + }, + "probe_cc": { + "type": "string" + }, + "software_name": { + "type": "string" + }, + "software_version": { + "type": "string" + }, + "test_name": { + "type": "string" + }, + "test_start_time": { + "type": "string" + }, + "test_version": { + "type": "string" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "backend_version": { + "type": "string" + }, + "report_id": { + "type": "string" + }, + "supported_formats": { + "items": { + "type": "string" + }, + "type": "array" + } + }, + "type": "object" + } + } + } + } + }, + "/report/{report_id}": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "in": "path", + "name": "report_id", + "required": true, + "type": "string" + }, + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "properties": { + "content": { + "type": "object" + }, + "format": { + "type": "string" + } + }, + "type": "object" + } + } + ], + "responses": { + "200": { + "description": "all good", + "schema": { + "properties": { + "measurement_uid": { + "type": "string" + } + }, + "type": "object" + } + } + } + } + } + } +}` diff --git a/internal/engine/ooapi/swaggerdiff_test.go b/internal/engine/ooapi/swaggerdiff_test.go new file mode 100644 index 0000000..6715bc9 --- /dev/null +++ b/internal/engine/ooapi/swaggerdiff_test.go @@ -0,0 +1,158 @@ +package ooapi + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "sort" + "strings" + "testing" + + "github.com/hexops/gotextdiff" + "github.com/hexops/gotextdiff/myers" + "github.com/hexops/gotextdiff/span" + "github.com/ooni/probe-cli/v3/internal/engine/ooapi/internal/openapi" +) + +const ( + productionURL = "https://api.ooni.io/apispec_1.json" + testingURL = "https://ams-pg-test.ooni.org/apispec_1.json" +) + +func makeModel(data []byte) *openapi.Swagger { + var out openapi.Swagger + if err := json.Unmarshal(data, &out); err != nil { + log.Fatal(err) + } + // We reduce irrelevant differences by producing a common header + return &openapi.Swagger{Paths: out.Paths} +} + +func getServerModel(serverURL string) *openapi.Swagger { + resp, err := http.Get(serverURL) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + return makeModel(data) +} + +func getClientModel() *openapi.Swagger { + return makeModel([]byte(swagger)) +} + +func simplifyRoundTrip(rt *openapi.RoundTrip) { + // Normalize the used name when a parameter is in body. This + // should only have a cosmetic impact on the spec. + for _, param := range rt.Parameters { + if param.In == "body" { + param.Name = "body" + } + } + + // Sort parameters so the comparison does not depend on order. + sort.SliceStable(rt.Parameters, func(i, j int) bool { + left, right := rt.Parameters[i].Name, rt.Parameters[j].Name + return strings.Compare(left, right) < 0 + }) + + // Normalize description of 200 response + rt.Responses.Successful.Description = "all good" +} + +func simplifyInPlace(path *openapi.Path) *openapi.Path { + if path.Get != nil && path.Post != nil { + log.Fatal("unsupported configuration") + } + if path.Get != nil { + simplifyRoundTrip(path.Get) + } + if path.Post != nil { + simplifyRoundTrip(path.Post) + } + return path +} + +func jsonify(model interface{}) string { + data, err := json.MarshalIndent(model, "", " ") + if err != nil { + log.Fatal(err) + } + return string(data) +} + +type diffable struct { + name string + value string +} + +func computediff(server, client *diffable) string { + d := gotextdiff.ToUnified(server.name, client.name, server.value, myers.ComputeEdits( + span.URIFromPath(server.name), server.value, client.value, + )) + return fmt.Sprint(d) +} + +// maybediff emits the diff between the server and the client and +// returns the length of the diff itself in bytes. +func maybediff(key string, server, client *openapi.Path) int { + diff := computediff(&diffable{ + name: fmt.Sprintf("server%s.json", key), + value: jsonify(simplifyInPlace(server)), + }, &diffable{ + name: fmt.Sprintf("client%s.json", key), + value: jsonify(simplifyInPlace(client)), + }) + if diff != "" { + fmt.Printf("%s", diff) + } + return len(diff) +} + +func compare(serverURL string) bool { + good := true + serverModel, clientModel := getServerModel(serverURL), getClientModel() + // Implementation note: the server model is richer than the client + // model, so we ignore everything not defined by the client. + var count int + for key := range serverModel.Paths { + if _, found := clientModel.Paths[key]; !found { + delete(serverModel.Paths, key) + continue + } + count++ + if maybediff(key, serverModel.Paths[key], clientModel.Paths[key]) > 0 { + good = false + } + } + if count <= 0 { + panic("no element found") + } + return good +} + +func TestWithProductionAPI(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + t.Log("using ", productionURL) + if !compare(productionURL) { + t.Fatal("model mismatch (see above)") + } +} + +func TestWithTestingAPI(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + t.Log("using ", testingURL) + if !compare(testingURL) { + t.Fatal("model mismatch (see above)") + } +} diff --git a/internal/engine/ooapi/utils.go b/internal/engine/ooapi/utils.go new file mode 100644 index 0000000..6d93a0b --- /dev/null +++ b/internal/engine/ooapi/utils.go @@ -0,0 +1,23 @@ +package ooapi + +import "fmt" + +func newErrEmptyField(field string) error { + return fmt.Errorf("%w: %s", ErrEmptyField, field) +} + +func newHTTPFailure(status int) error { + return fmt.Errorf("%w: %d", ErrHTTPFailure, status) +} + +func newQueryFieldInt64(v int64) string { + return fmt.Sprintf("%d", v) +} + +func newQueryFieldBool(v bool) string { + return fmt.Sprintf("%v", v) +} + +func newAuthorizationHeader(token string) string { + return fmt.Sprintf("Bearer %s", token) +} diff --git a/internal/engine/ooapi/utils_test.go b/internal/engine/ooapi/utils_test.go new file mode 100644 index 0000000..54b158b --- /dev/null +++ b/internal/engine/ooapi/utils_test.go @@ -0,0 +1,12 @@ +package ooapi + +import "testing" + +func TestNewQueryFieldBoolWorks(t *testing.T) { + if s := newQueryFieldBool(true); s != "true" { + t.Fatal("invalid encoding of true") + } + if s := newQueryFieldBool(false); s != "false" { + t.Fatal("invalid encoding of false") + } +}