refactor(oohelperd): improve tests implementation (#835)
After this diff has landed, we have addressed all the points originally published at https://github.com/ooni/probe/issues/2134.
This commit is contained in:
parent
535a5d3e00
commit
d419ed8ac8
|
@ -1,143 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
type FakeResolver struct {
|
||||
NumFailures *atomicx.Int64
|
||||
Err error
|
||||
Result []string
|
||||
}
|
||||
|
||||
func NewFakeResolverThatFails() FakeResolver {
|
||||
return FakeResolver{NumFailures: &atomicx.Int64{}, Err: ErrNotFound}
|
||||
}
|
||||
|
||||
func NewFakeResolverWithResult(r []string) FakeResolver {
|
||||
return FakeResolver{NumFailures: &atomicx.Int64{}, Result: r}
|
||||
}
|
||||
|
||||
var ErrNotFound = &net.DNSError{
|
||||
Err: "no such host",
|
||||
}
|
||||
|
||||
func (c FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if c.Err != nil {
|
||||
if c.NumFailures != nil {
|
||||
c.NumFailures.Add(1)
|
||||
}
|
||||
return nil, c.Err
|
||||
}
|
||||
return c.Result, nil
|
||||
}
|
||||
|
||||
func (c FakeResolver) Network() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (c FakeResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c FakeResolver) CloseIdleConnections() {}
|
||||
|
||||
func (c FakeResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c FakeResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ model.Resolver = FakeResolver{}
|
||||
|
||||
type FakeTransport struct {
|
||||
Name string
|
||||
Err error
|
||||
Func func(*http.Request) (*http.Response, error)
|
||||
Resp *http.Response
|
||||
}
|
||||
|
||||
func (txp FakeTransport) Network() string {
|
||||
return txp.Name
|
||||
}
|
||||
|
||||
func (txp FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if txp.Func != nil {
|
||||
return txp.Func(req)
|
||||
}
|
||||
if req.Body != nil {
|
||||
netxlite.ReadAllContext(req.Context(), req.Body)
|
||||
req.Body.Close()
|
||||
}
|
||||
if txp.Err != nil {
|
||||
return nil, txp.Err
|
||||
}
|
||||
txp.Resp.Request = req // non thread safe but it doesn't matter
|
||||
return txp.Resp, nil
|
||||
}
|
||||
|
||||
func (txp FakeTransport) CloseIdleConnections() {}
|
||||
|
||||
var _ model.HTTPTransport = FakeTransport{}
|
||||
|
||||
type FakeBody struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (fb *FakeBody) Read(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if fb.Err != nil {
|
||||
return 0, fb.Err
|
||||
}
|
||||
if len(fb.Data) <= 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, fb.Data)
|
||||
fb.Data = fb.Data[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (fb *FakeBody) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ io.ReadCloser = &FakeBody{}
|
||||
|
||||
type FakeResponseWriter struct {
|
||||
Body [][]byte
|
||||
HeaderMap http.Header
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func NewFakeResponseWriter() *FakeResponseWriter {
|
||||
return &FakeResponseWriter{HeaderMap: make(http.Header)}
|
||||
}
|
||||
|
||||
func (frw *FakeResponseWriter) Header() http.Header {
|
||||
return frw.HeaderMap
|
||||
}
|
||||
|
||||
func (frw *FakeResponseWriter) Write(b []byte) (int, error) {
|
||||
frw.Body = append(frw.Body, b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (frw *FakeResponseWriter) WriteHeader(statusCode int) {
|
||||
frw.StatusCode = statusCode
|
||||
}
|
||||
|
||||
var _ http.ResponseWriter = &FakeResponseWriter{}
|
|
@ -4,12 +4,14 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
|
@ -149,17 +151,30 @@ func TestWorkingAsIntended(t *testing.T) {
|
|||
func TestHandlerWithRequestBodyReadingError(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
handler := handler{MaxAcceptableBody: 1 << 24}
|
||||
rw := NewFakeResponseWriter()
|
||||
var statusCode int
|
||||
headers := http.Header{}
|
||||
rw := &mocks.HTTPResponseWriter{
|
||||
MockWriteHeader: func(code int) {
|
||||
statusCode = code
|
||||
},
|
||||
MockHeader: func() http.Header {
|
||||
return headers
|
||||
},
|
||||
}
|
||||
req := &http.Request{
|
||||
Method: "POST",
|
||||
Header: map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"Content-Length": {"2048"},
|
||||
},
|
||||
Body: &FakeBody{Err: expected},
|
||||
Body: io.NopCloser(&mocks.Reader{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, expected
|
||||
},
|
||||
}),
|
||||
}
|
||||
handler.ServeHTTP(rw, req)
|
||||
if rw.StatusCode != 400 {
|
||||
if statusCode != 400 {
|
||||
t.Fatal("unexpected status code")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
|
||||
|
@ -46,8 +47,13 @@ func TestHTTPDoWithHTTPTransportFailure(t *testing.T) {
|
|||
MaxAcceptableBody: 1 << 24,
|
||||
NewClient: func() model.HTTPClient {
|
||||
return &http.Client{
|
||||
Transport: FakeTransport{
|
||||
Err: expected,
|
||||
Transport: &mocks.HTTPTransport{
|
||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockCloseIdleConnections: func() {
|
||||
// nothing
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
|
|
@ -4,6 +4,7 @@ package main
|
|||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -11,19 +12,21 @@ import (
|
|||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
const maxAcceptableBody = 1 << 24
|
||||
|
||||
var (
|
||||
endpoint = flag.String("endpoint", ":8080", "Endpoint where to listen")
|
||||
srvcancel context.CancelFunc
|
||||
srvctx context.Context
|
||||
srvwg = new(sync.WaitGroup)
|
||||
srvAddr = make(chan string, 1) // with buffer
|
||||
srvCancel context.CancelFunc
|
||||
srvCtx context.Context
|
||||
srvWg = new(sync.WaitGroup)
|
||||
)
|
||||
|
||||
func init() {
|
||||
srvctx, srvcancel = context.WithCancel(context.Background())
|
||||
srvCtx, srvCancel = context.WithCancel(context.Background())
|
||||
}
|
||||
|
||||
func newResolver() model.Resolver {
|
||||
|
@ -48,10 +51,7 @@ func main() {
|
|||
debug := flag.Bool("debug", false, "Toggle debug mode")
|
||||
flag.Parse()
|
||||
log.SetLevel(logmap[*debug])
|
||||
testableMain()
|
||||
}
|
||||
|
||||
func testableMain() {
|
||||
defer srvCancel()
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", &handler{
|
||||
MaxAcceptableBody: maxAcceptableBody,
|
||||
|
@ -64,9 +64,13 @@ func testableMain() {
|
|||
NewResolver: newResolver,
|
||||
})
|
||||
srv := &http.Server{Addr: *endpoint, Handler: mux}
|
||||
srvwg.Add(1)
|
||||
go srv.ListenAndServe()
|
||||
<-srvctx.Done()
|
||||
listener, err := net.Listen("tcp", *endpoint)
|
||||
runtimex.PanicOnError(err, "net.Listen failed")
|
||||
srvAddr <- listener.Addr().String()
|
||||
srvWg.Add(1)
|
||||
go srv.Serve(listener)
|
||||
<-srvCtx.Done()
|
||||
shutdown(srv)
|
||||
srvwg.Done()
|
||||
listener.Close()
|
||||
srvWg.Done()
|
||||
}
|
||||
|
|
|
@ -1,15 +1,82 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
func TestSmoke(t *testing.T) {
|
||||
// Just check whether we can start and then tear down the server, so
|
||||
// we have coverage of this code and when we see that some lines aren't
|
||||
// covered we know these are genuine places where we're not testing
|
||||
// the code rather than just places like this simple main.
|
||||
go testableMain()
|
||||
srvcancel() // kills the listener
|
||||
srvwg.Wait() // joined
|
||||
func TestWorkAsIntended(t *testing.T) {
|
||||
// let the kernel pick a random free port
|
||||
*endpoint = "127.0.0.1:0"
|
||||
|
||||
// run the main function in a background goroutine
|
||||
go main()
|
||||
|
||||
// prepare the HTTP request body
|
||||
jsonReq := ctrlRequest{
|
||||
HTTPRequest: "https://dns.google",
|
||||
HTTPRequestHeaders: map[string][]string{
|
||||
"Accept": {model.HTTPHeaderAccept},
|
||||
"Accept-Language": {model.HTTPHeaderAcceptLanguage},
|
||||
"User-Agent": {model.HTTPHeaderUserAgent},
|
||||
},
|
||||
TCPConnect: []string{
|
||||
"8.8.8.8:443",
|
||||
"8.8.4.4:443",
|
||||
},
|
||||
}
|
||||
data, err := json.Marshal(jsonReq)
|
||||
runtimex.PanicOnError(err, "cannot marshal request")
|
||||
|
||||
// construct the test helper's URL
|
||||
endpoint := <-srvAddr
|
||||
URL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: endpoint,
|
||||
Path: "/",
|
||||
}
|
||||
req, err := http.NewRequest("POST", URL.String(), bytes.NewReader(data))
|
||||
runtimex.PanicOnError(err, "cannot create new HTTP request")
|
||||
|
||||
// issue the request and get the response
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatal("unexpected status code", resp.StatusCode)
|
||||
}
|
||||
|
||||
// read the response body
|
||||
data, err = netxlite.ReadAllContext(context.Background(), resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// parse the response
|
||||
var jsonResp ctrlResponse
|
||||
if err := json.Unmarshal(data, &jsonResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// very simple correctness check
|
||||
if !strings.Contains(jsonResp.HTTPRequest.Title, "Google") {
|
||||
t.Fatal("expected the response title to contain the string Google")
|
||||
}
|
||||
|
||||
// tear down the TH
|
||||
srvCancel()
|
||||
|
||||
// wait for the background goroutine to join
|
||||
srvWg.Wait()
|
||||
}
|
||||
|
|
|
@ -40,3 +40,26 @@ func (txp *HTTPClient) Do(req *http.Request) (*http.Response, error) {
|
|||
func (txp *HTTPClient) CloseIdleConnections() {
|
||||
txp.MockCloseIdleConnections()
|
||||
}
|
||||
|
||||
// HTTPResponseWriter allows mocking http.ResponseWriter.
|
||||
type HTTPResponseWriter struct {
|
||||
MockHeader func() http.Header
|
||||
|
||||
MockWrite func(b []byte) (int, error)
|
||||
|
||||
MockWriteHeader func(statusCode int)
|
||||
}
|
||||
|
||||
var _ http.ResponseWriter = &HTTPResponseWriter{}
|
||||
|
||||
func (w *HTTPResponseWriter) Header() http.Header {
|
||||
return w.MockHeader()
|
||||
}
|
||||
|
||||
func (w *HTTPResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.MockWrite(b)
|
||||
}
|
||||
|
||||
func (w *HTTPResponseWriter) WriteHeader(statusCode int) {
|
||||
w.MockWriteHeader(statusCode)
|
||||
}
|
||||
|
|
|
@ -81,3 +81,49 @@ func TestHTTPClient(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPResponseWriter(t *testing.T) {
|
||||
t.Run("Header", func(t *testing.T) {
|
||||
expect := http.Header{}
|
||||
w := &HTTPResponseWriter{
|
||||
MockHeader: func() http.Header {
|
||||
return expect
|
||||
},
|
||||
}
|
||||
got := w.Header()
|
||||
got.Set("Content-Type", "text/plain")
|
||||
if expect.Get("Content-Type") != "text/plain" {
|
||||
t.Fatal("we didn't get the expected header value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
w := &HTTPResponseWriter{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, expected
|
||||
},
|
||||
}
|
||||
buffer := make([]byte, 16)
|
||||
count, err := w.Write(buffer)
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WriteHeader", func(t *testing.T) {
|
||||
var called bool
|
||||
w := &HTTPResponseWriter{
|
||||
MockWriteHeader: func(statusCode int) {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
w.WriteHeader(200)
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user