From d419ed8ac8c584f5de3078c34a3d347a9e33fc77 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Tue, 5 Jul 2022 20:25:18 +0200 Subject: [PATCH] refactor(oohelperd): improve tests implementation (#835) After this diff has landed, we have addressed all the points originally published at https://github.com/ooni/probe/issues/2134. --- internal/cmd/oohelperd/fake_test.go | 143 ------------------------- internal/cmd/oohelperd/handler_test.go | 21 +++- internal/cmd/oohelperd/http_test.go | 10 +- internal/cmd/oohelperd/main.go | 28 ++--- internal/cmd/oohelperd/main_test.go | 83 ++++++++++++-- internal/model/mocks/http.go | 23 ++++ internal/model/mocks/http_test.go | 46 ++++++++ 7 files changed, 186 insertions(+), 168 deletions(-) delete mode 100644 internal/cmd/oohelperd/fake_test.go diff --git a/internal/cmd/oohelperd/fake_test.go b/internal/cmd/oohelperd/fake_test.go deleted file mode 100644 index 0f93128..0000000 --- a/internal/cmd/oohelperd/fake_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package main - -import ( - "context" - "errors" - "io" - "net" - "net/http" - "time" - - "github.com/ooni/probe-cli/v3/internal/atomicx" - "github.com/ooni/probe-cli/v3/internal/model" - "github.com/ooni/probe-cli/v3/internal/netxlite" -) - -type FakeResolver struct { - NumFailures *atomicx.Int64 - Err error - Result []string -} - -func NewFakeResolverThatFails() FakeResolver { - return FakeResolver{NumFailures: &atomicx.Int64{}, Err: ErrNotFound} -} - -func NewFakeResolverWithResult(r []string) FakeResolver { - return FakeResolver{NumFailures: &atomicx.Int64{}, Result: r} -} - -var ErrNotFound = &net.DNSError{ - Err: "no such host", -} - -func (c FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { - time.Sleep(10 * time.Microsecond) - if c.Err != nil { - if c.NumFailures != nil { - c.NumFailures.Add(1) - } - return nil, c.Err - } - return c.Result, nil -} - -func (c FakeResolver) Network() string { - return "fake" -} - -func (c FakeResolver) Address() string { - return "" -} - -func (c FakeResolver) CloseIdleConnections() {} - -func (c FakeResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { - return nil, errors.New("not implemented") -} - -func (c FakeResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { - return nil, errors.New("not implemented") -} - -var _ model.Resolver = FakeResolver{} - -type FakeTransport struct { - Name string - Err error - Func func(*http.Request) (*http.Response, error) - Resp *http.Response -} - -func (txp FakeTransport) Network() string { - return txp.Name -} - -func (txp FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { - time.Sleep(10 * time.Microsecond) - if txp.Func != nil { - return txp.Func(req) - } - if req.Body != nil { - netxlite.ReadAllContext(req.Context(), req.Body) - req.Body.Close() - } - if txp.Err != nil { - return nil, txp.Err - } - txp.Resp.Request = req // non thread safe but it doesn't matter - return txp.Resp, nil -} - -func (txp FakeTransport) CloseIdleConnections() {} - -var _ model.HTTPTransport = FakeTransport{} - -type FakeBody struct { - Data []byte - Err error -} - -func (fb *FakeBody) Read(p []byte) (int, error) { - time.Sleep(10 * time.Microsecond) - if fb.Err != nil { - return 0, fb.Err - } - if len(fb.Data) <= 0 { - return 0, io.EOF - } - n := copy(p, fb.Data) - fb.Data = fb.Data[n:] - return n, nil -} - -func (fb *FakeBody) Close() error { - return nil -} - -var _ io.ReadCloser = &FakeBody{} - -type FakeResponseWriter struct { - Body [][]byte - HeaderMap http.Header - StatusCode int -} - -func NewFakeResponseWriter() *FakeResponseWriter { - return &FakeResponseWriter{HeaderMap: make(http.Header)} -} - -func (frw *FakeResponseWriter) Header() http.Header { - return frw.HeaderMap -} - -func (frw *FakeResponseWriter) Write(b []byte) (int, error) { - frw.Body = append(frw.Body, b) - return len(b), nil -} - -func (frw *FakeResponseWriter) WriteHeader(statusCode int) { - frw.StatusCode = statusCode -} - -var _ http.ResponseWriter = &FakeResponseWriter{} diff --git a/internal/cmd/oohelperd/handler_test.go b/internal/cmd/oohelperd/handler_test.go index e523826..6613e1f 100644 --- a/internal/cmd/oohelperd/handler_test.go +++ b/internal/cmd/oohelperd/handler_test.go @@ -4,12 +4,14 @@ import ( "context" "encoding/json" "errors" + "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -149,17 +151,30 @@ func TestWorkingAsIntended(t *testing.T) { func TestHandlerWithRequestBodyReadingError(t *testing.T) { expected := errors.New("mocked error") handler := handler{MaxAcceptableBody: 1 << 24} - rw := NewFakeResponseWriter() + var statusCode int + headers := http.Header{} + rw := &mocks.HTTPResponseWriter{ + MockWriteHeader: func(code int) { + statusCode = code + }, + MockHeader: func() http.Header { + return headers + }, + } req := &http.Request{ Method: "POST", Header: map[string][]string{ "Content-Type": {"application/json"}, "Content-Length": {"2048"}, }, - Body: &FakeBody{Err: expected}, + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }), } handler.ServeHTTP(rw, req) - if rw.StatusCode != 400 { + if statusCode != 400 { t.Fatal("unexpected status code") } } diff --git a/internal/cmd/oohelperd/http_test.go b/internal/cmd/oohelperd/http_test.go index c51d8f8..8821a8e 100644 --- a/internal/cmd/oohelperd/http_test.go +++ b/internal/cmd/oohelperd/http_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -46,8 +47,13 @@ func TestHTTPDoWithHTTPTransportFailure(t *testing.T) { MaxAcceptableBody: 1 << 24, NewClient: func() model.HTTPClient { return &http.Client{ - Transport: FakeTransport{ - Err: expected, + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, + MockCloseIdleConnections: func() { + // nothing + }, }, } }, diff --git a/internal/cmd/oohelperd/main.go b/internal/cmd/oohelperd/main.go index 245562f..485005c 100644 --- a/internal/cmd/oohelperd/main.go +++ b/internal/cmd/oohelperd/main.go @@ -4,6 +4,7 @@ package main import ( "context" "flag" + "net" "net/http" "sync" "time" @@ -11,19 +12,21 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) const maxAcceptableBody = 1 << 24 var ( endpoint = flag.String("endpoint", ":8080", "Endpoint where to listen") - srvcancel context.CancelFunc - srvctx context.Context - srvwg = new(sync.WaitGroup) + srvAddr = make(chan string, 1) // with buffer + srvCancel context.CancelFunc + srvCtx context.Context + srvWg = new(sync.WaitGroup) ) func init() { - srvctx, srvcancel = context.WithCancel(context.Background()) + srvCtx, srvCancel = context.WithCancel(context.Background()) } func newResolver() model.Resolver { @@ -48,10 +51,7 @@ func main() { debug := flag.Bool("debug", false, "Toggle debug mode") flag.Parse() log.SetLevel(logmap[*debug]) - testableMain() -} - -func testableMain() { + defer srvCancel() mux := http.NewServeMux() mux.Handle("/", &handler{ MaxAcceptableBody: maxAcceptableBody, @@ -64,9 +64,13 @@ func testableMain() { NewResolver: newResolver, }) srv := &http.Server{Addr: *endpoint, Handler: mux} - srvwg.Add(1) - go srv.ListenAndServe() - <-srvctx.Done() + listener, err := net.Listen("tcp", *endpoint) + runtimex.PanicOnError(err, "net.Listen failed") + srvAddr <- listener.Addr().String() + srvWg.Add(1) + go srv.Serve(listener) + <-srvCtx.Done() shutdown(srv) - srvwg.Done() + listener.Close() + srvWg.Done() } diff --git a/internal/cmd/oohelperd/main_test.go b/internal/cmd/oohelperd/main_test.go index 9a6d486..54aab25 100644 --- a/internal/cmd/oohelperd/main_test.go +++ b/internal/cmd/oohelperd/main_test.go @@ -1,15 +1,82 @@ package main import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/url" + "strings" "testing" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestSmoke(t *testing.T) { - // Just check whether we can start and then tear down the server, so - // we have coverage of this code and when we see that some lines aren't - // covered we know these are genuine places where we're not testing - // the code rather than just places like this simple main. - go testableMain() - srvcancel() // kills the listener - srvwg.Wait() // joined +func TestWorkAsIntended(t *testing.T) { + // let the kernel pick a random free port + *endpoint = "127.0.0.1:0" + + // run the main function in a background goroutine + go main() + + // prepare the HTTP request body + jsonReq := ctrlRequest{ + HTTPRequest: "https://dns.google", + HTTPRequestHeaders: map[string][]string{ + "Accept": {model.HTTPHeaderAccept}, + "Accept-Language": {model.HTTPHeaderAcceptLanguage}, + "User-Agent": {model.HTTPHeaderUserAgent}, + }, + TCPConnect: []string{ + "8.8.8.8:443", + "8.8.4.4:443", + }, + } + data, err := json.Marshal(jsonReq) + runtimex.PanicOnError(err, "cannot marshal request") + + // construct the test helper's URL + endpoint := <-srvAddr + URL := &url.URL{ + Scheme: "http", + Host: endpoint, + Path: "/", + } + req, err := http.NewRequest("POST", URL.String(), bytes.NewReader(data)) + runtimex.PanicOnError(err, "cannot create new HTTP request") + + // issue the request and get the response + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatal("unexpected status code", resp.StatusCode) + } + + // read the response body + data, err = netxlite.ReadAllContext(context.Background(), resp.Body) + if err != nil { + t.Fatal(err) + } + + // parse the response + var jsonResp ctrlResponse + if err := json.Unmarshal(data, &jsonResp); err != nil { + t.Fatal(err) + } + + // very simple correctness check + if !strings.Contains(jsonResp.HTTPRequest.Title, "Google") { + t.Fatal("expected the response title to contain the string Google") + } + + // tear down the TH + srvCancel() + + // wait for the background goroutine to join + srvWg.Wait() } diff --git a/internal/model/mocks/http.go b/internal/model/mocks/http.go index 4500f84..a599c95 100644 --- a/internal/model/mocks/http.go +++ b/internal/model/mocks/http.go @@ -40,3 +40,26 @@ func (txp *HTTPClient) Do(req *http.Request) (*http.Response, error) { func (txp *HTTPClient) CloseIdleConnections() { txp.MockCloseIdleConnections() } + +// HTTPResponseWriter allows mocking http.ResponseWriter. +type HTTPResponseWriter struct { + MockHeader func() http.Header + + MockWrite func(b []byte) (int, error) + + MockWriteHeader func(statusCode int) +} + +var _ http.ResponseWriter = &HTTPResponseWriter{} + +func (w *HTTPResponseWriter) Header() http.Header { + return w.MockHeader() +} + +func (w *HTTPResponseWriter) Write(b []byte) (int, error) { + return w.MockWrite(b) +} + +func (w *HTTPResponseWriter) WriteHeader(statusCode int) { + w.MockWriteHeader(statusCode) +} diff --git a/internal/model/mocks/http_test.go b/internal/model/mocks/http_test.go index 0cabb67..cf6f868 100644 --- a/internal/model/mocks/http_test.go +++ b/internal/model/mocks/http_test.go @@ -81,3 +81,49 @@ func TestHTTPClient(t *testing.T) { } }) } + +func TestHTTPResponseWriter(t *testing.T) { + t.Run("Header", func(t *testing.T) { + expect := http.Header{} + w := &HTTPResponseWriter{ + MockHeader: func() http.Header { + return expect + }, + } + got := w.Header() + got.Set("Content-Type", "text/plain") + if expect.Get("Content-Type") != "text/plain" { + t.Fatal("we didn't get the expected header value") + } + }) + + t.Run("Write", func(t *testing.T) { + expected := errors.New("mocked error") + w := &HTTPResponseWriter{ + MockWrite: func(b []byte) (int, error) { + return 0, expected + }, + } + buffer := make([]byte, 16) + count, err := w.Write(buffer) + if count != 0 { + t.Fatal("invalid count") + } + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + }) + + t.Run("WriteHeader", func(t *testing.T) { + var called bool + w := &HTTPResponseWriter{ + MockWriteHeader: func(statusCode int) { + called = true + }, + } + w.WriteHeader(200) + if !called { + t.Fatal("not called") + } + }) +}