refactor: move httpx into the internal package (#646)

This concludes the TODO list at https://github.com/ooni/probe/issues/1951
This commit is contained in:
Simone Basso
2022-01-05 17:17:20 +01:00
committed by GitHub
parent dba861d262
commit f0181c432f
26 changed files with 24 additions and 24 deletions
+241
View File
@@ -0,0 +1,241 @@
// Package httpx contains http extensions.
package httpx
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// APIClientTemplate is a template for constructing an APIClient.
type APIClientTemplate struct {
// Accept contains the OPTIONAL accept header.
Accept string
// Authorization contains the OPTIONAL authorization header.
Authorization string
// BaseURL is the MANDATORY base URL of the API.
BaseURL string
// HTTPClient is the MANDATORY underlying http client to use.
HTTPClient model.HTTPClient
// Host allows to OPTIONALLY set a specific host header. This is useful
// to implement, e.g., cloudfronting.
Host string
// LogBody is the OPTIONAL flag to force logging the bodies.
LogBody bool
// Logger is MANDATORY the logger to use.
Logger model.DebugLogger
// UserAgent is the OPTIONAL user agent to use.
UserAgent string
}
// WithBodyLogging enables logging of request and response bodies.
func (tmpl *APIClientTemplate) WithBodyLogging() *APIClientTemplate {
out := APIClientTemplate(*tmpl)
out.LogBody = true
return &out
}
// Build creates an APIClient from the APIClientTemplate.
func (tmpl *APIClientTemplate) Build() APIClient {
return tmpl.BuildWithAuthorization(tmpl.Authorization)
}
// BuildWithAuthorization creates an APIClient from the
// APIClientTemplate and ensures it uses the given authorization
// value for APIClient.Authorization in subsequent API calls.
func (tmpl *APIClientTemplate) BuildWithAuthorization(authorization string) APIClient {
ac := apiClient(*tmpl)
ac.Authorization = authorization
return &ac
}
// DefaultMaxBodySize is the default value for the maximum
// body size you can fetch using an APIClient.
const DefaultMaxBodySize = 1 << 22
// APIClient is a client configured to call a given API identified
// by a given baseURL and using a given model.HTTPClient.
type APIClient interface {
// GetJSON reads the JSON resource at resourcePath and unmarshals the
// results into output. The request is bounded by the lifetime of the
// context passed as argument. Returns the error that occurred.
GetJSON(ctx context.Context, resourcePath string, output interface{}) error
// GetJSONWithQuery is like GetJSON but also has a query.
GetJSONWithQuery(ctx context.Context, resourcePath string,
query url.Values, output interface{}) error
// PostJSON creates a JSON subresource of the resource at resourcePath
// using the JSON document at input and returning the result into the
// JSON document at output. The request is bounded by the context's
// lifetime. Returns the error that occurred.
PostJSON(ctx context.Context, resourcePath string, input, output interface{}) error
// FetchResource fetches the specified resource and returns it.
FetchResource(ctx context.Context, URLPath string) ([]byte, error)
}
// apiClient is an extended HTTP client. To construct this struct, make
// sure you initialize all fields marked as MANDATORY.
type apiClient struct {
// Accept contains the OPTIONAL accept header.
Accept string
// Authorization contains the OPTIONAL authorization header.
Authorization string
// BaseURL is the MANDATORY base URL of the API.
BaseURL string
// HTTPClient is the MANDATORY underlying http client to use.
HTTPClient model.HTTPClient
// Host allows to OPTIONALLY set a specific host header. This is useful
// to implement, e.g., cloudfronting.
Host string
// LogBody is the OPTIONAL flag to force logging the bodies.
LogBody bool
// Logger is MANDATORY the logger to use.
Logger model.DebugLogger
// UserAgent is the OPTIONAL user agent to use.
UserAgent string
}
// newRequestWithJSONBody creates a new request with a JSON body
func (c *apiClient) newRequestWithJSONBody(
ctx context.Context, method, resourcePath string,
query url.Values, body interface{}) (*http.Request, error) {
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
c.Logger.Debugf("httpx: request body length: %d bytes", len(data))
if c.LogBody {
c.Logger.Debugf("httpx: request body: %s", string(data))
}
request, err := c.newRequest(
ctx, method, resourcePath, query, bytes.NewReader(data))
if err != nil {
return nil, err
}
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
return request, nil
}
// newRequest creates a new request.
func (c *apiClient) newRequest(ctx context.Context, method, resourcePath string,
query url.Values, body io.Reader) (*http.Request, error) {
URL, err := url.Parse(c.BaseURL)
if err != nil {
return nil, err
}
URL.Path = resourcePath
if query != nil {
URL.RawQuery = query.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, URL.String(), body)
if err != nil {
return nil, err
}
request.Host = c.Host // allow cloudfronting
if c.Authorization != "" {
request.Header.Set("Authorization", c.Authorization)
}
if c.Accept != "" {
request.Header.Set("Accept", c.Accept)
}
request.Header.Set("User-Agent", c.UserAgent)
return request, nil
}
// ErrRequestFailed indicates that the server returned >= 400.
var ErrRequestFailed = errors.New("httpx: request failed")
// do performs the provided request and returns the response body or an error.
func (c *apiClient) do(request *http.Request) ([]byte, error) {
response, err := c.HTTPClient.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
// Implementation note: always read and log the response body since
// it's quite useful to see the response JSON on API error.
r := io.LimitReader(response.Body, DefaultMaxBodySize)
data, err := netxlite.ReadAllContext(request.Context(), r)
if err != nil {
return nil, err
}
c.Logger.Debugf("httpx: response body length: %d bytes", len(data))
if c.LogBody {
c.Logger.Debugf("httpx: response body: %s", string(data))
}
if response.StatusCode >= 400 {
return nil, fmt.Errorf("%w: %s", ErrRequestFailed, response.Status)
}
return data, nil
}
// doJSON performs the provided request and unmarshals the JSON response body
// into the provided output variable.
func (c *apiClient) doJSON(request *http.Request, output interface{}) error {
data, err := c.do(request)
if err != nil {
return err
}
return json.Unmarshal(data, output)
}
// GetJSON implements APIClient.GetJSON.
func (c *apiClient) GetJSON(ctx context.Context, resourcePath string, output interface{}) error {
return c.GetJSONWithQuery(ctx, resourcePath, nil, output)
}
// GetJSONWithQuery implements APIClient.GetJSONWithQuery.
func (c *apiClient) GetJSONWithQuery(
ctx context.Context, resourcePath string,
query url.Values, output interface{}) error {
request, err := c.newRequest(ctx, "GET", resourcePath, query, nil)
if err != nil {
return err
}
return c.doJSON(request, output)
}
// PostJSON implements APIClient.PostJSON.
func (c *apiClient) PostJSON(
ctx context.Context, resourcePath string, input, output interface{}) error {
request, err := c.newRequestWithJSONBody(ctx, "POST", resourcePath, nil, input)
if err != nil {
return err
}
return c.doJSON(request, output)
}
// FetchResource implements APIClient.FetchResource.
func (c *apiClient) FetchResource(ctx context.Context, URLPath string) ([]byte, error) {
request, err := c.newRequest(ctx, "GET", URLPath, nil, nil)
if err != nil {
return nil, err
}
return c.do(request)
}
+612
View File
@@ -0,0 +1,612 @@
package httpx
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/fakefill"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/version"
)
// userAgent is the user agent used by this test suite
var userAgent = fmt.Sprintf("ooniprobe-cli/%s", version.Version)
func TestAPIClientTemplate(t *testing.T) {
t.Run("WithBodyLogging", func(t *testing.T) {
tmpl := &APIClientTemplate{
HTTPClient: http.DefaultClient,
LogBody: false, // explicit default initialization for clarity
Logger: model.DiscardLogger,
}
child := tmpl.WithBodyLogging()
if !child.LogBody {
t.Fatal("expected body logging to be enabled")
}
if tmpl.LogBody {
t.Fatal("expected body logging to still be disabled")
}
})
t.Run("normal constructor", func(t *testing.T) {
// Implementation note: the fakefiller will ignore the
// fields it does not know how to fill, so we are filling
// those fields with plausible values in advance
tmpl := &APIClientTemplate{
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}
ff := &fakefill.Filler{}
ff.Fill(tmpl)
ac := tmpl.Build()
orig := apiClient(*tmpl)
if diff := cmp.Diff(&orig, ac); diff != "" {
t.Fatal(diff)
}
})
t.Run("constructor with authorization", func(t *testing.T) {
// Implementation note: the fakefiller will ignore the
// fields it does not know how to fill, so we are filling
// those fields with plausible values in advance
tmpl := &APIClientTemplate{
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}
ff := &fakefill.Filler{}
ff.Fill(tmpl)
tok := ""
ff.Fill(&tok)
ac := tmpl.BuildWithAuthorization(tok)
// the authorization should be different now
if tmpl.Authorization == ac.(*apiClient).Authorization {
t.Fatal("we expect Authorization to be different")
}
// clear authorization for the comparison
tmpl.Authorization = ""
ac.(*apiClient).Authorization = ""
orig := apiClient(*tmpl)
if diff := cmp.Diff(&orig, ac); diff != "" {
t.Fatal(diff)
}
})
}
// newAPIClient is an helper factory creating a client for testing.
func newAPIClient() *apiClient {
return &apiClient{
BaseURL: "https://example.com",
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
UserAgent: userAgent,
}
}
// fakeRequest is a fake request we serialize.
type fakeRequest struct {
Name string
Age int
Sleeping bool
Attributes map[string][]string
}
func TestAPIClient(t *testing.T) {
t.Run("newRequestWithJSONBody", func(t *testing.T) {
t.Run("JSON marshal failure", func(t *testing.T) {
client := newAPIClient()
req, err := client.newRequestWithJSONBody(
context.Background(), "GET", "/", nil, make(chan interface{}),
)
if err == nil || !strings.HasPrefix(err.Error(), "json: unsupported type") {
t.Fatal("not the error we expected", err)
}
if req != nil {
t.Fatal("expected nil request here")
}
})
t.Run("newRequest failure", func(t *testing.T) {
client := newAPIClient()
client.BaseURL = "\t\t\t" // cause URL parse error
req, err := client.newRequestWithJSONBody(
context.Background(), "GET", "/", nil, nil,
)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("not the error we expected")
}
if req != nil {
t.Fatal("expected nil request here")
}
})
t.Run("sets the content-type properly", func(t *testing.T) {
var jsonReq fakeRequest
ff := &fakefill.Filler{}
ff.Fill(&jsonReq)
client := newAPIClient()
req, err := client.newRequestWithJSONBody(
context.Background(), "GET", "/", nil, jsonReq,
)
if err != nil {
t.Fatal(err)
}
if req.Header.Get("Content-Type") != "application/json" {
t.Fatal("did not set content-type properly")
}
})
})
t.Run("newRequest", func(t *testing.T) {
t.Run("with invalid method", func(t *testing.T) {
client := newAPIClient()
req, err := client.newRequest(
context.Background(), "\t\t\t", "/", nil, nil,
)
if err == nil || !strings.HasPrefix(err.Error(), "net/http: invalid method") {
t.Fatal("not the error we expected")
}
if req != nil {
t.Fatal("expected nil request here")
}
})
t.Run("with query", func(t *testing.T) {
client := newAPIClient()
q := url.Values{}
q.Add("antani", "mascetti")
q.Add("melandri", "conte")
req, err := client.newRequest(
context.Background(), "GET", "/", q, nil,
)
if err != nil {
t.Fatal(err)
}
if req.URL.Query().Get("antani") != "mascetti" {
t.Fatal("expected different query string here")
}
if req.URL.Query().Get("melandri") != "conte" {
t.Fatal("expected different query string here")
}
})
t.Run("with authorization", func(t *testing.T) {
client := newAPIClient()
client.Authorization = "deadbeef"
req, err := client.newRequest(
context.Background(), "GET", "/", nil, nil,
)
if err != nil {
t.Fatal(err)
}
if req.Header.Get("Authorization") != client.Authorization {
t.Fatal("expected different Authorization here")
}
})
t.Run("with accept", func(t *testing.T) {
client := newAPIClient()
client.Accept = "application/xml"
req, err := client.newRequestWithJSONBody(
context.Background(), "GET", "/", nil, []string{},
)
if err != nil {
t.Fatal(err)
}
if req.Header.Get("Accept") != "application/xml" {
t.Fatal("expected different Accept here")
}
})
t.Run("with custom host header", func(t *testing.T) {
client := newAPIClient()
client.Host = "www.x.org"
req, err := client.newRequest(
context.Background(), "GET", "/", nil, nil,
)
if err != nil {
t.Fatal(err)
}
if req.Host != client.Host {
t.Fatal("expected different req.Host here")
}
})
t.Run("with user agent", func(t *testing.T) {
client := newAPIClient()
req, err := client.newRequest(
context.Background(), "GET", "/", nil, nil,
)
if err != nil {
t.Fatal(err)
}
if req.Header.Get("User-Agent") != userAgent {
t.Fatal("expected different User-Agent here")
}
})
})
t.Run("doJSON", func(t *testing.T) {
t.Run("do failure", func(t *testing.T) {
expected := errors.New("mocked error")
client := newAPIClient()
client.HTTPClient = &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
return nil, expected
},
}
err := client.doJSON(&http.Request{URL: &url.URL{Scheme: "https", Host: "x.org"}}, nil)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})
t.Run("response is not successful (i.e., >= 400)", func(t *testing.T) {
client := newAPIClient()
client.HTTPClient = &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 401,
Body: io.NopCloser(strings.NewReader("{}")),
}, nil
},
}
err := client.doJSON(&http.Request{URL: &url.URL{Scheme: "https", Host: "x.org"}}, nil)
if !errors.Is(err, ErrRequestFailed) {
t.Fatal("not the error we expected", err)
}
})
t.Run("cannot read body", func(t *testing.T) {
expected := errors.New("mocked error")
client := newAPIClient()
client.HTTPClient = &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}),
}, nil
},
}
err := client.doJSON(&http.Request{URL: &url.URL{Scheme: "https", Host: "x.org"}}, nil)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})
t.Run("response is not JSON", func(t *testing.T) {
client := newAPIClient()
client.HTTPClient = &mocks.HTTPClient{
MockDo: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("[")),
}, nil
},
}
err := client.doJSON(&http.Request{URL: &url.URL{Scheme: "https", Host: "x.org"}}, nil)
if err == nil || err.Error() != "unexpected end of JSON input" {
t.Fatal("not the error we expected")
}
})
})
t.Run("GetJSON", func(t *testing.T) {
t.Run("successful case", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`["foo", "bar"]`))
},
))
defer server.Close()
ctx := context.Background()
var result []string
err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}).GetJSON(ctx, "/", &result)
if err != nil {
t.Fatal(err)
}
if len(result) != 2 || result[0] != "foo" || result[1] != "bar" {
t.Fatal("invalid result", result)
}
})
t.Run("failure case", func(t *testing.T) {
var headers []string
client := newAPIClient()
client.BaseURL = "\t\t\t\t"
err := client.GetJSON(context.Background(), "/", &headers)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("not the error we expected")
}
})
})
t.Run("PostJSON", func(t *testing.T) {
t.Run("successful case", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
var incoming []string
data, err := netxlite.ReadAllContext(r.Context(), r.Body)
if err != nil {
w.WriteHeader(500)
return
}
if err := json.Unmarshal(data, &incoming); err != nil {
w.WriteHeader(500)
return
}
w.Write(data)
},
))
defer server.Close()
ctx := context.Background()
incoming := []string{"foo", "bar"}
var result []string
err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}).PostJSON(ctx, "/", incoming, &result)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(incoming, result); diff != "" {
t.Fatal(diff)
}
})
t.Run("failure case", func(t *testing.T) {
incoming := []string{"foo", "bar"}
var result []string
client := newAPIClient()
client.BaseURL = "\t\t\t\t"
err := client.PostJSON(context.Background(), "/", incoming, &result)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("not the error we expected")
}
})
})
t.Run("FetchResource", func(t *testing.T) {
t.Run("successful case", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("deadbeef"))
},
))
defer server.Close()
ctx := context.Background()
data, err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
Logger: model.DiscardLogger,
}).FetchResource(ctx, "/")
if err != nil {
t.Fatal(err)
}
if string(data) != "deadbeef" {
t.Fatal("invalid data")
}
})
t.Run("failure case", func(t *testing.T) {
client := newAPIClient()
client.BaseURL = "\t\t\t\t"
data, err := client.FetchResource(context.Background(), "/")
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("unexpected data")
}
})
})
t.Run("we honour context", func(t *testing.T) {
// It should suffice to check one of the public methods here
client := newAPIClient()
ctx, cancel := context.WithCancel(context.Background())
cancel() // test should fail
data, err := client.FetchResource(ctx, "/")
if !errors.Is(err, context.Canceled) {
t.Fatal("unexpected err", err)
}
if data != nil {
t.Fatal("unexpected data")
}
})
t.Run("body logging", func(t *testing.T) {
t.Run("logging enabled and 200 Ok", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("[]"))
},
))
logs := make(chan string, 1024)
defer server.Close()
var (
input []string
output []string
)
ctx := context.Background()
err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
LogBody: true,
Logger: &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
logs <- fmt.Sprintf(format, v...)
},
},
}).PostJSON(ctx, "/", input, &output)
var found int
close(logs)
for entry := range logs {
if strings.HasPrefix(entry, "httpx: request body: ") {
found |= 1 << 0
continue
}
if strings.HasPrefix(entry, "httpx: response body: ") {
found |= 1 << 1
continue
}
}
if found != (1<<0 | 1<<1) {
t.Fatal("did not find logs")
}
if err != nil {
t.Fatal(err)
}
})
t.Run("logging enabled and 401 Unauthorized", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte("[]"))
},
))
logs := make(chan string, 1024)
defer server.Close()
var (
input []string
output []string
)
ctx := context.Background()
err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
LogBody: true,
Logger: &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
logs <- fmt.Sprintf(format, v...)
},
},
}).PostJSON(ctx, "/", input, &output)
var found int
close(logs)
for entry := range logs {
if strings.HasPrefix(entry, "httpx: request body: ") {
found |= 1 << 0
continue
}
if strings.HasPrefix(entry, "httpx: response body: ") {
found |= 1 << 1
continue
}
}
if found != (1<<0 | 1<<1) {
t.Fatal("did not find logs")
}
if !errors.Is(err, ErrRequestFailed) {
t.Fatal("unexpected err", err)
}
})
t.Run("logging NOT enabled and 200 Ok", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("[]"))
},
))
logs := make(chan string, 1024)
defer server.Close()
var (
input []string
output []string
)
ctx := context.Background()
err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
LogBody: false, // explicit initialization
Logger: &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
logs <- fmt.Sprintf(format, v...)
},
},
}).PostJSON(ctx, "/", input, &output)
var found int
close(logs)
for entry := range logs {
if strings.HasPrefix(entry, "httpx: request body: ") {
found |= 1 << 0
continue
}
if strings.HasPrefix(entry, "httpx: response body: ") {
found |= 1 << 1
continue
}
}
if found != 0 {
t.Fatal("did find logs")
}
if err != nil {
t.Fatal(err)
}
})
t.Run("logging NOT enabled and 401 Unauthorized", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(401)
w.Write([]byte("[]"))
},
))
logs := make(chan string, 1024)
defer server.Close()
var (
input []string
output []string
)
ctx := context.Background()
err := (&apiClient{
BaseURL: server.URL,
HTTPClient: http.DefaultClient,
LogBody: false, // explicit initialization
Logger: &mocks.Logger{
MockDebugf: func(format string, v ...interface{}) {
logs <- fmt.Sprintf(format, v...)
},
},
}).PostJSON(ctx, "/", input, &output)
var found int
close(logs)
for entry := range logs {
if strings.HasPrefix(entry, "httpx: request body: ") {
found |= 1 << 0
continue
}
if strings.HasPrefix(entry, "httpx: response body: ") {
found |= 1 << 1
continue
}
}
if found != 0 {
t.Fatal("did find logs")
}
if !errors.Is(err, ErrRequestFailed) {
t.Fatal("unexpected err", err)
}
})
})
}