From a884481b12bf7d33831381be75a31714a64c39eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arturo=20Filast=C3=B2?= Date: Fri, 14 Jan 2022 11:24:43 +0100 Subject: [PATCH] Refactor the list measurements function to make use of nested queries (#662) * Refactor the list measurements function to make use of nested queries With a dataset of 1489 tests, the ooniprobe list command went from taking 17.27s to run to requiring 0.17s, a 100x speed boost. See https://github.com/ooni/probe/issues/1966 * Remove dead code from actions * Improve the tests for the ListResults function * Add test for AnomalyCount * Add more documentation about the merging of the test_keys --- cmd/ooniprobe/internal/cli/list/list.go | 29 +++-- cmd/ooniprobe/internal/database/actions.go | 109 ++++++++---------- .../internal/database/actions_test.go | 21 +++- cmd/ooniprobe/internal/database/models.go | 7 +- 4 files changed, 86 insertions(+), 80 deletions(-) diff --git a/cmd/ooniprobe/internal/cli/list/list.go b/cmd/ooniprobe/internal/cli/list/list.go index 05f45b6..779f1e2 100644 --- a/cmd/ooniprobe/internal/cli/list/list.go +++ b/cmd/ooniprobe/internal/cli/list/list.go @@ -1,6 +1,8 @@ package list import ( + "strings" + "github.com/alecthomas/kingpin" "github.com/apex/log" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root" @@ -92,14 +94,23 @@ func init() { netCount := make(map[uint]int) output.SectionTitle("Results") for idx, result := range doneResults { - totalCount, anmlyCount, err := database.GetMeasurementCounts(probeCLI.DB(), result.Result.ID) - if err != nil { - log.WithError(err).Error("failed to list measurement counts") - } - testKeys, err := database.GetResultTestKeys(probeCLI.DB(), result.Result.ID) - if err != nil { - log.WithError(err).Error("failed to get testKeys") + testKeys := "{}" + + // We only care to expose the value of the ndt test result in the testKeys + if result.TestGroupName == "performance" { + // The test_keys column values are concatenated with the "|" character as a 
separator. + // We consider this to be safe since we only really care about values of the + // performance test_keys where the values are all numbers and none of the keys + // contain the "|" character. + for _, e := range strings.Split(result.TestKeys, "|") { + // We use the presence of the "download" key to indicate we have found the + // ndt test_keys, since the dash result does not contain it. + if strings.Contains(e, "download") { + testKeys = e + } + } } + output.ResultItem(output.ResultItemData{ ID: result.Result.ID, Index: idx, @@ -110,8 +121,8 @@ func init() { Country: result.Network.CountryCode, ASN: result.Network.ASN, TestKeys: testKeys, - MeasurementCount: totalCount, - MeasurementAnomalyCount: anmlyCount, + MeasurementCount: result.TotalCount, + MeasurementAnomalyCount: result.AnomalyCount, Done: result.IsDone, DataUsageUp: result.DataUsageUp, DataUsageDown: result.DataUsageDown, diff --git a/cmd/ooniprobe/internal/database/actions.go b/cmd/ooniprobe/internal/database/actions.go index 2f4b1c8..7725243 100644 --- a/cmd/ooniprobe/internal/database/actions.go +++ b/cmd/ooniprobe/internal/database/actions.go @@ -103,76 +103,59 @@ func GetMeasurementJSON(sess sqlbuilder.Database, measurementID int64) (map[stri return msmtJSON, nil } -// GetResultTestKeys returns a list of TestKeys for a given result -func GetResultTestKeys(sess sqlbuilder.Database, resultID int64) (string, error) { - res := sess.Collection("measurements").Find("result_id", resultID) - defer res.Close() - - var ( - msmt Measurement - tk PerformanceTestKeys - ) - for res.Next(&msmt) { - // We only really care about performance keys. - // Note: since even in case of failure we still initialise an empty struct, - // it could be that these keys come out as initializes with the default - // values. - // XXX we may want to change this behaviour by adding `omitempty` to the - // struct definition. 
- if msmt.TestName != "ndt" && msmt.TestName != "dash" { - return "{}", nil - } - if err := json.Unmarshal([]byte(msmt.TestKeys), &tk); err != nil { - log.WithError(err).Error("failed to parse testKeys") - return "{}", err - } - } - b, err := json.Marshal(tk) - if err != nil { - log.WithError(err).Error("failed to serialize testKeys") - return "{}", err - } - return string(b), nil -} - -// GetMeasurementCounts returns the number of anomalous and total measurement for a given result -func GetMeasurementCounts(sess sqlbuilder.Database, resultID int64) (uint64, uint64, error) { - var ( - totalCount uint64 - anmlyCount uint64 - err error - ) - col := sess.Collection("measurements") - - // XXX these two queries can be done with a single query - totalCount, err = col.Find("result_id", resultID). - Count() - if err != nil { - log.WithError(err).Error("failed to get total count") - return totalCount, anmlyCount, err - } - - anmlyCount, err = col.Find("result_id", resultID). - And(db.Cond{"is_anomaly": true}).Count() - if err != nil { - log.WithError(err).Error("failed to get anmly count") - return totalCount, anmlyCount, err - } - - log.Debugf("counts: %d, %d, %d", resultID, totalCount, anmlyCount) - return totalCount, anmlyCount, err -} - // ListResults return the list of results func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, error) { doneResults := []ResultNetwork{} incompleteResults := []ResultNetwork{} req := sess.Select( - db.Raw("networks.*"), - db.Raw("results.*"), + db.Raw("networks.network_name"), + db.Raw("networks.network_type"), + db.Raw("networks.ip"), + db.Raw("networks.asn"), + db.Raw("networks.network_country_code"), + + db.Raw("results.result_id"), + db.Raw("results.test_group_name"), + db.Raw("results.result_start_time"), + db.Raw("results.network_id"), + db.Raw("results.result_is_viewed"), + db.Raw("results.result_runtime"), + db.Raw("results.result_is_done"), + db.Raw("results.result_is_uploaded"), + 
db.Raw("results.result_data_usage_up"), + db.Raw("results.result_data_usage_down"), + db.Raw("results.measurement_dir"), + + db.Raw("COUNT(CASE WHEN measurements.is_anomaly = TRUE THEN 1 END) as anomaly_count"), + db.Raw("COUNT() as total_count"), + // The test_keys column values are concatenated with the "|" character as a separator. + // We consider this to be safe since we only really care about values of the + // performance test_keys where the values are all numbers and none of the keys + // contain the "|" character. + db.Raw("group_concat(test_keys, '|') as test_keys"), ).From("results"). Join("networks").On("results.network_id = networks.network_id"). - OrderBy("results.result_start_time") + Join("measurements").On("measurements.result_id = results.result_id"). + OrderBy("results.result_start_time"). + GroupBy( + db.Raw("networks.network_name"), + db.Raw("networks.network_type"), + db.Raw("networks.ip"), + db.Raw("networks.asn"), + db.Raw("networks.network_country_code"), + + db.Raw("results.result_id"), + db.Raw("results.test_group_name"), + db.Raw("results.result_start_time"), + db.Raw("results.network_id"), + db.Raw("results.result_is_viewed"), + db.Raw("results.result_runtime"), + db.Raw("results.result_is_done"), + db.Raw("results.result_is_uploaded"), + db.Raw("results.result_data_usage_up"), + db.Raw("results.result_data_usage_down"), + db.Raw("results.measurement_dir"), + ) if err := req.Where("result_is_done = true").All(&doneResults); err != nil { return doneResults, incompleteResults, errors.Wrap(err, "failed to get result done list") } diff --git a/cmd/ooniprobe/internal/database/actions_test.go b/cmd/ooniprobe/internal/database/actions_test.go index e1feeb5..b328fa0 100644 --- a/cmd/ooniprobe/internal/database/actions_test.go +++ b/cmd/ooniprobe/internal/database/actions_test.go @@ -87,17 +87,18 @@ func TestMeasurementWorkflow(t *testing.T) { t.Fatal(err) } m1.IsUploaded = true + m1.IsAnomaly = sql.NullBool{Valid: true, Bool: false} err = 
sess.Collection("measurements").Find("measurement_id", m1.ID).Update(m1) if err != nil { t.Fatal(err) } - var m2 Measurement - err = sess.Collection("measurements").Find("measurement_id", m1.ID).One(&m2) + m2, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) if err != nil { t.Fatal(err) } m2.IsUploaded = false + m2.IsAnomaly = sql.NullBool{Valid: true, Bool: true} err = sess.Collection("measurements").Find("measurement_id", m2.ID).Update(m2) if err != nil { t.Fatal(err) @@ -110,6 +111,7 @@ func TestMeasurementWorkflow(t *testing.T) { if err != nil { t.Fatal(err) } + result.Finished(sess) var r Result err = sess.Collection("measurements").Find("result_id", result.ID).One(&r) @@ -125,11 +127,18 @@ func TestMeasurementWorkflow(t *testing.T) { t.Fatal(err) } - if len(incomplete) != 1 { - t.Error("there should be 1 incomplete measurement") + if len(incomplete) != 0 { + t.Error("there should be 0 incomplete result") } - if len(done) != 0 { - t.Error("there should be 0 done measurements") + if len(done) != 1 { + t.Error("there should be 1 done result") + } + + if done[0].TotalCount != 2 { + t.Error("there should be a total of 2 measurements in the result") + } + if done[0].AnomalyCount != 1 { + t.Error("there should be a total of 1 anomalies in the result") } msmts, err := ListMeasurements(sess, resultID) diff --git a/cmd/ooniprobe/internal/database/models.go b/cmd/ooniprobe/internal/database/models.go index a9ec835..c239be7 100644 --- a/cmd/ooniprobe/internal/database/models.go +++ b/cmd/ooniprobe/internal/database/models.go @@ -11,8 +11,11 @@ import ( // ResultNetwork is used to represent the structure made from the JOIN // between the results and networks tables. 
type ResultNetwork struct { - Result `db:",inline"` - Network `db:",inline"` + Result `db:",inline"` + Network `db:",inline"` + AnomalyCount uint64 `db:"anomaly_count"` + TotalCount uint64 `db:"total_count"` + TestKeys string `db:"test_keys"` } // UploadedTotalCount is the count of the measurements which have been uploaded vs the total measurements in a given result set