Refactor the list measurements function to make use of nested queries (#662)

* Refactor the list measurements function to make use of nested queries

With a dataset of 1489 test, the ooniprobe list command went from
taking 17.27s to run, to requiring 0.17s or a 100x speed boost

See https://github.com/ooni/probe/issues/1966

* Remove dead code from actions

* Improve the tests for the ListResults function

* Add test for AnomalyCount

* Add more documentation about the merging of the test_keys
This commit is contained in:
Arturo Filastò 2022-01-14 11:24:43 +01:00 committed by GitHub
parent b5da8be183
commit a884481b12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 80 deletions

View File

@ -1,6 +1,8 @@
package list package list
import ( import (
"strings"
"github.com/alecthomas/kingpin" "github.com/alecthomas/kingpin"
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root"
@ -92,14 +94,23 @@ func init() {
netCount := make(map[uint]int) netCount := make(map[uint]int)
output.SectionTitle("Results") output.SectionTitle("Results")
for idx, result := range doneResults { for idx, result := range doneResults {
totalCount, anmlyCount, err := database.GetMeasurementCounts(probeCLI.DB(), result.Result.ID) testKeys := "{}"
if err != nil {
log.WithError(err).Error("failed to list measurement counts") // We only care to expose in the testKeys the value of the ndt test result
} if result.TestGroupName == "performance" {
testKeys, err := database.GetResultTestKeys(probeCLI.DB(), result.Result.ID) // The test_keys column are concanetated with the "|" character as a separator.
if err != nil { // We consider this to be safe since we only really care about values of the
log.WithError(err).Error("failed to get testKeys") // performance test_keys where the values are all numbers and none of the keys
// contain the "|" character.
for _, e := range strings.Split(result.TestKeys, "|") {
// We use the presence of the "download" key to indicate we have found the
// ndt test_keys, since the dash result does not contain it.
if strings.Contains(e, "download") {
testKeys = e
}
}
} }
output.ResultItem(output.ResultItemData{ output.ResultItem(output.ResultItemData{
ID: result.Result.ID, ID: result.Result.ID,
Index: idx, Index: idx,
@ -110,8 +121,8 @@ func init() {
Country: result.Network.CountryCode, Country: result.Network.CountryCode,
ASN: result.Network.ASN, ASN: result.Network.ASN,
TestKeys: testKeys, TestKeys: testKeys,
MeasurementCount: totalCount, MeasurementCount: result.TotalCount,
MeasurementAnomalyCount: anmlyCount, MeasurementAnomalyCount: result.AnomalyCount,
Done: result.IsDone, Done: result.IsDone,
DataUsageUp: result.DataUsageUp, DataUsageUp: result.DataUsageUp,
DataUsageDown: result.DataUsageDown, DataUsageDown: result.DataUsageDown,

View File

@ -103,76 +103,59 @@ func GetMeasurementJSON(sess sqlbuilder.Database, measurementID int64) (map[stri
return msmtJSON, nil return msmtJSON, nil
} }
// GetResultTestKeys returns a list of TestKeys for a given result
func GetResultTestKeys(sess sqlbuilder.Database, resultID int64) (string, error) {
res := sess.Collection("measurements").Find("result_id", resultID)
defer res.Close()
var (
msmt Measurement
tk PerformanceTestKeys
)
for res.Next(&msmt) {
// We only really care about performance keys.
// Note: since even in case of failure we still initialise an empty struct,
// it could be that these keys come out as initializes with the default
// values.
// XXX we may want to change this behaviour by adding `omitempty` to the
// struct definition.
if msmt.TestName != "ndt" && msmt.TestName != "dash" {
return "{}", nil
}
if err := json.Unmarshal([]byte(msmt.TestKeys), &tk); err != nil {
log.WithError(err).Error("failed to parse testKeys")
return "{}", err
}
}
b, err := json.Marshal(tk)
if err != nil {
log.WithError(err).Error("failed to serialize testKeys")
return "{}", err
}
return string(b), nil
}
// GetMeasurementCounts returns the number of anomalous and total measurement for a given result
func GetMeasurementCounts(sess sqlbuilder.Database, resultID int64) (uint64, uint64, error) {
var (
totalCount uint64
anmlyCount uint64
err error
)
col := sess.Collection("measurements")
// XXX these two queries can be done with a single query
totalCount, err = col.Find("result_id", resultID).
Count()
if err != nil {
log.WithError(err).Error("failed to get total count")
return totalCount, anmlyCount, err
}
anmlyCount, err = col.Find("result_id", resultID).
And(db.Cond{"is_anomaly": true}).Count()
if err != nil {
log.WithError(err).Error("failed to get anmly count")
return totalCount, anmlyCount, err
}
log.Debugf("counts: %d, %d, %d", resultID, totalCount, anmlyCount)
return totalCount, anmlyCount, err
}
// ListResults return the list of results // ListResults return the list of results
func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, error) { func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, error) {
doneResults := []ResultNetwork{} doneResults := []ResultNetwork{}
incompleteResults := []ResultNetwork{} incompleteResults := []ResultNetwork{}
req := sess.Select( req := sess.Select(
db.Raw("networks.*"), db.Raw("networks.network_name"),
db.Raw("results.*"), db.Raw("networks.network_type"),
db.Raw("networks.ip"),
db.Raw("networks.asn"),
db.Raw("networks.network_country_code"),
db.Raw("results.result_id"),
db.Raw("results.test_group_name"),
db.Raw("results.result_start_time"),
db.Raw("results.network_id"),
db.Raw("results.result_is_viewed"),
db.Raw("results.result_runtime"),
db.Raw("results.result_is_done"),
db.Raw("results.result_is_uploaded"),
db.Raw("results.result_data_usage_up"),
db.Raw("results.result_data_usage_down"),
db.Raw("results.measurement_dir"),
db.Raw("COUNT(CASE WHEN measurements.is_anomaly = TRUE THEN 1 END) as anomaly_count"),
db.Raw("COUNT() as total_count"),
// The test_keys column are concanetated with the "|" character as a separator.
// We consider this to be safe since we only really care about values of the
// performance test_keys where the values are all numbers and none of the keys
// contain the "|" character.
db.Raw("group_concat(test_keys, '|') as test_keys"),
).From("results"). ).From("results").
Join("networks").On("results.network_id = networks.network_id"). Join("networks").On("results.network_id = networks.network_id").
OrderBy("results.result_start_time") Join("measurements").On("measurements.result_id = results.result_id").
OrderBy("results.result_start_time").
GroupBy(
db.Raw("networks.network_name"),
db.Raw("networks.network_type"),
db.Raw("networks.ip"),
db.Raw("networks.asn"),
db.Raw("networks.network_country_code"),
db.Raw("results.result_id"),
db.Raw("results.test_group_name"),
db.Raw("results.result_start_time"),
db.Raw("results.network_id"),
db.Raw("results.result_is_viewed"),
db.Raw("results.result_runtime"),
db.Raw("results.result_is_done"),
db.Raw("results.result_is_uploaded"),
db.Raw("results.result_data_usage_up"),
db.Raw("results.result_data_usage_down"),
db.Raw("results.measurement_dir"),
)
if err := req.Where("result_is_done = true").All(&doneResults); err != nil { if err := req.Where("result_is_done = true").All(&doneResults); err != nil {
return doneResults, incompleteResults, errors.Wrap(err, "failed to get result done list") return doneResults, incompleteResults, errors.Wrap(err, "failed to get result done list")
} }

View File

@ -87,17 +87,18 @@ func TestMeasurementWorkflow(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
m1.IsUploaded = true m1.IsUploaded = true
m1.IsAnomaly = sql.NullBool{Valid: true, Bool: false}
err = sess.Collection("measurements").Find("measurement_id", m1.ID).Update(m1) err = sess.Collection("measurements").Find("measurement_id", m1.ID).Update(m1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
var m2 Measurement m2, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID)
err = sess.Collection("measurements").Find("measurement_id", m1.ID).One(&m2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
m2.IsUploaded = false m2.IsUploaded = false
m2.IsAnomaly = sql.NullBool{Valid: true, Bool: true}
err = sess.Collection("measurements").Find("measurement_id", m2.ID).Update(m2) err = sess.Collection("measurements").Find("measurement_id", m2.ID).Update(m2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -110,6 +111,7 @@ func TestMeasurementWorkflow(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
result.Finished(sess)
var r Result var r Result
err = sess.Collection("measurements").Find("result_id", result.ID).One(&r) err = sess.Collection("measurements").Find("result_id", result.ID).One(&r)
@ -125,11 +127,18 @@ func TestMeasurementWorkflow(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if len(incomplete) != 1 { if len(incomplete) != 0 {
t.Error("there should be 1 incomplete measurement") t.Error("there should be 0 incomplete result")
} }
if len(done) != 0 { if len(done) != 1 {
t.Error("there should be 0 done measurements") t.Error("there should be 1 done result")
}
if done[0].TotalCount != 2 {
t.Error("there should be a total of 2 measurements in the result")
}
if done[0].AnomalyCount != 1 {
t.Error("there should be a total of 1 anomalies in the result")
} }
msmts, err := ListMeasurements(sess, resultID) msmts, err := ListMeasurements(sess, resultID)

View File

@ -11,8 +11,11 @@ import (
// ResultNetwork is used to represent the structure made from the JOIN // ResultNetwork is used to represent the structure made from the JOIN
// between the results and networks tables. // between the results and networks tables.
type ResultNetwork struct { type ResultNetwork struct {
Result `db:",inline"` Result `db:",inline"`
Network `db:",inline"` Network `db:",inline"`
AnomalyCount uint64 `db:"anomaly_count"`
TotalCount uint64 `db:"total_count"`
TestKeys string `db:"test_keys"`
} }
// UploadedTotalCount is the count of the measurements which have been uploaded vs the total measurements in a given result set // UploadedTotalCount is the count of the measurements which have been uploaded vs the total measurements in a given result set