package main import ( "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" ) const simplerequest = `{ "http_request": "https://dns.google", "http_request_headers": { "Accept": [ "*/*" ], "Accept-Language": [ "en-US;q=0.8,en;q=0.5" ], "User-Agent": [ "Mozilla/5.0" ] }, "tcp_connect": [ "8.8.8.8:443" ] }` const requestWithoutDomainName = `{ "http_request": "https://8.8.8.8", "http_request_headers": { "Accept": [ "*/*" ], "Accept-Language": [ "en-US;q=0.8,en;q=0.5" ], "User-Agent": [ "Mozilla/5.0" ] }, "tcp_connect": [ "8.8.8.8:443" ] }` func TestHandlerWorkingAsIntended(t *testing.T) { handler := &handler{ BaseLogger: model.DiscardLogger, Indexer: &atomicx.Int64{}, MaxAcceptableBody: 1 << 24, NewClient: func(model.Logger) model.HTTPClient { return http.DefaultClient }, NewDialer: func(model.Logger) model.Dialer { return netxlite.NewDialerWithStdlibResolver(model.DiscardLogger) }, NewResolver: func(model.Logger) model.Resolver { return netxlite.NewUnwrappedStdlibResolver() }, NewTLSHandshaker: func(model.Logger) model.TLSHandshaker { return netxlite.NewTLSHandshakerStdlib(model.DiscardLogger) }, } srv := httptest.NewServer(handler) defer srv.Close() type expectationSpec struct { name string reqMethod string reqContentType string reqBody string respStatusCode int respContentType string parseBody bool } expectations := []expectationSpec{{ name: "check for invalid method", reqMethod: "GET", reqContentType: "", reqBody: "", respStatusCode: 400, respContentType: "", parseBody: false, }, { name: "check for invalid content-type", reqMethod: "POST", reqContentType: "", reqBody: "", respStatusCode: 400, respContentType: "", parseBody: false, }, { name: "check for invalid request body", reqMethod: "POST", reqContentType: "application/json", reqBody: "{", respStatusCode: 400, respContentType: "", parseBody: false, }, { name: "with measurement failure", reqMethod: "POST", reqContentType: "application/json", reqBody: `{"http_request": "http://[::1]aaaa"}`, respStatusCode: 400, respContentType: "", parseBody: false, }, { name: "with reasonably good request", reqMethod: "POST", reqContentType: "application/json", reqBody: simplerequest, respStatusCode: 200, respContentType: "application/json", parseBody: true, }, { name: "when there's no domain name in the request", reqMethod: "POST", reqContentType: "application/json", reqBody: requestWithoutDomainName, respStatusCode: 200, respContentType: "application/json", parseBody: true, }} for _, expect := range expectations { t.Run(expect.name, func(t *testing.T) { body := strings.NewReader(expect.reqBody) req, err := http.NewRequest(expect.reqMethod, srv.URL, body) if err != nil { t.Fatalf("%s: %+v", expect.name, err) } if expect.reqContentType != "" { req.Header.Add("content-type", expect.reqContentType) } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("%s: %+v", expect.name, err) } defer resp.Body.Close() if resp.StatusCode != expect.respStatusCode { t.Fatalf("unexpected status code: %+v", resp.StatusCode) } if v := resp.Header.Get("content-type"); v != expect.respContentType { t.Fatalf("unexpected content-type: %s", v) } data, err := netxlite.ReadAllContext(context.Background(), resp.Body) if err != nil { t.Fatal(err) } if !expect.parseBody { return } var v interface{} if err := json.Unmarshal(data, &v); err != nil { t.Fatal(err) } }) } } func TestHandlerWithRequestBodyReadingError(t *testing.T) { expected := errors.New("mocked error") handler := handler{MaxAcceptableBody: 1 << 24} var statusCode int headers := http.Header{} rw := &mocks.HTTPResponseWriter{ MockWriteHeader: func(code int) { statusCode = code }, MockHeader: func() http.Header { return headers }, } req := &http.Request{ Method: "POST", Header: map[string][]string{ "Content-Type": {"application/json"}, "Content-Length": {"2048"}, }, Body: io.NopCloser(&mocks.Reader{ MockRead: func(b []byte) (int, error) { return 0, expected }, }), } handler.ServeHTTP(rw, req) if statusCode != 400 { t.Fatal("unexpected status code") } }