ooni-probe-cli/internal/database/actions.go
DecFox 28aabe0947
feat: introduce database type (#982)
See https://github.com/ooni/probe/issues/2352

Co-authored-by: decfox <decfox@github.com>
2022-11-16 15:51:41 +01:00

393 lines
12 KiB
Go

package database
import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine"
"github.com/pkg/errors"
"github.com/upper/db/v4"
)
// Open connects to the database located at dbpath and returns a *Database
// wrapping the resulting session. A connection error is returned unchanged.
func Open(dbpath string) (*Database, error) {
	sess, err := Connect(dbpath)
	if err != nil {
		return nil, err
	}
	d := &Database{sess: sess}
	return d, nil
}
// Database is a database instance to store measurements.
//
// It wraps a single upper/db session, through which every Database method
// issues its queries.
type Database struct {
	// sess is the underlying database session used by all methods.
	sess db.Session
}
// Session returns the db.Session wrapped by this Database.
func (d *Database) Session() db.Session {
	sess := d.sess
	return sess
}
// ListMeasurements returns every measurement belonging to the result
// identified by resultID. Each row is joined with its result and network and,
// when present, with its URL. Rows are ordered by measurement start time.
//
// On failure the offending query is logged, and the error is returned
// together with the slice collected so far.
func (d *Database) ListMeasurements(resultID int64) ([]MeasurementURLNetwork, error) {
	columns := []interface{}{
		db.Raw("networks.*"),
		db.Raw("urls.*"),
		db.Raw("measurements.*"),
		db.Raw("results.*"),
	}
	query := d.sess.SQL().
		Select(columns...).
		From("results").
		Join("measurements").On("results.result_id = measurements.result_id").
		Join("networks").On("results.network_id = networks.network_id").
		LeftJoin("urls").On("urls.url_id = measurements.url_id").
		OrderBy("measurements.measurement_start_time").
		Where("results.result_id = ?", resultID)
	out := []MeasurementURLNetwork{}
	err := query.All(&out)
	if err != nil {
		log.Errorf("failed to run query %s: %v", query.String(), err)
	}
	return out, err
}
// GetMeasurementJSON returns the JSON document of the measurement identified
// by measurementID, decoded into a map[string]interface{}.
//
// If the measurement has been uploaded, the document is fetched from the OONI
// API (api.ooni.io) using its report ID and, when present, its input URL.
// Otherwise it is read from the measurement file recorded in the database.
//
// An error is returned when:
//   - the measurement cannot be loaded;
//   - the HTTP request fails or answers with a non-200 status;
//   - the file path is missing or unreadable;
//   - the content is not valid JSON.
func (d *Database) GetMeasurementJSON(measurementID int64) (map[string]interface{}, error) {
	// fetchTimeout bounds the whole HTTP exchange with the OONI API.
	const fetchTimeout = 60 * time.Second
	var (
		measurement MeasurementURLNetwork
		msmtJSON    map[string]interface{}
	)
	req := d.sess.SQL().Select(
		db.Raw("urls.*"),
		db.Raw("measurements.*"),
	).From("measurements").
		LeftJoin("urls").On("urls.url_id = measurements.url_id").
		Where("measurements.measurement_id= ?", measurementID)
	if err := req.One(&measurement); err != nil {
		log.Errorf("failed to run query %s: %v", req.String(), err)
		return nil, err
	}
	if measurement.Measurement.IsUploaded {
		// TODO(bassosimone): this should be a function exposed by probe-engine
		reportID := measurement.Measurement.ReportID.String
		measurementURL := &url.URL{
			Scheme: "https",
			Host:   "api.ooni.io",
			Path:   "/api/v1/raw_measurement",
		}
		query := url.Values{}
		query.Add("report_id", reportID)
		if measurement.URL.URL.Valid {
			query.Add("input", measurement.URL.URL.String)
		}
		measurementURL.RawQuery = query.Encode()
		log.Debugf("using %s", measurementURL.String())
		// http.Get would use http.DefaultClient, which has no timeout and can
		// hang forever on a stalled connection.
		client := &http.Client{Timeout: fetchTimeout}
		resp, err := client.Get(measurementURL.String())
		if err != nil {
			log.Errorf("failed to fetch the measurement %s %s", reportID, measurement.URL.URL.String)
			return nil, err
		}
		defer resp.Body.Close()
		// Without this check, an error response body would be decoded and
		// returned as if it were the measurement itself.
		if resp.StatusCode != http.StatusOK {
			log.Errorf("failed to fetch the measurement %s %s: HTTP status %d", reportID, measurement.URL.URL.String, resp.StatusCode)
			return nil, errors.Errorf("fetching measurement: unexpected HTTP status %d", resp.StatusCode)
		}
		if err := json.NewDecoder(resp.Body).Decode(&msmtJSON); err != nil {
			log.Error("failed to unmarshal the measurement_json")
			return nil, err
		}
		return msmtJSON, nil
	}
	// MeasurementFilePath might be NULL because the measurement comes from a
	// 3.0.0-beta install.
	if !measurement.Measurement.MeasurementFilePath.Valid {
		log.Error("invalid measurement_file_path")
		log.Error("backup your OONI_HOME and run `ooniprobe reset`")
		return nil, errors.New("cannot access measurement file")
	}
	measurementFilePath := measurement.Measurement.MeasurementFilePath.String
	b, err := os.ReadFile(measurementFilePath)
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal(b, &msmtJSON); err != nil {
		log.Error("failed to unmarshal the measurement_json")
		log.Error("backup your OONI_HOME and run `ooniprobe reset`")
		return nil, err
	}
	return msmtJSON, nil
}
// ListResults returns the list of results, split into results that are done
// (result_is_done = true) and results that are still incomplete
// (result_is_done = false).
//
// Each result is joined with its network and aggregated over its
// measurements:
//   - anomaly_count counts the measurements with is_anomaly = TRUE;
//   - total_count counts all of the measurements;
//   - test_keys concatenates their test keys.
//
// Results are ordered by start time. Because measurements are inner-joined
// (Join, not LeftJoin), a result without any measurement is not listed.
//
// On error, the slices collected so far are returned along with the error.
func (d *Database) ListResults() ([]ResultNetwork, []ResultNetwork, error) {
	doneResults := []ResultNetwork{}
	incompleteResults := []ResultNetwork{}
	req := d.sess.SQL().Select(
		db.Raw("networks.network_name"),
		db.Raw("networks.network_type"),
		db.Raw("networks.ip"),
		db.Raw("networks.asn"),
		db.Raw("networks.network_country_code"),
		db.Raw("results.result_id"),
		db.Raw("results.test_group_name"),
		db.Raw("results.result_start_time"),
		db.Raw("results.network_id"),
		db.Raw("results.result_is_viewed"),
		db.Raw("results.result_runtime"),
		db.Raw("results.result_is_done"),
		db.Raw("results.result_is_uploaded"),
		db.Raw("results.result_data_usage_up"),
		db.Raw("results.result_data_usage_down"),
		db.Raw("results.measurement_dir"),
		db.Raw("COUNT(CASE WHEN measurements.is_anomaly = TRUE THEN 1 END) as anomaly_count"),
		db.Raw("COUNT() as total_count"),
		// The test_keys columns are concatenated with the "|" character as a
		// separator. We consider this to be safe since we only really care
		// about the values of the performance test_keys, where the values are
		// all numbers and none of the keys contain the "|" character.
		db.Raw("group_concat(test_keys, '|') as test_keys"),
	).From("results").
		Join("networks").On("results.network_id = networks.network_id").
		Join("measurements").On("measurements.result_id = results.result_id").
		OrderBy("results.result_start_time").
		GroupBy(
			db.Raw("networks.network_name"),
			db.Raw("networks.network_type"),
			db.Raw("networks.ip"),
			db.Raw("networks.asn"),
			db.Raw("networks.network_country_code"),
			db.Raw("results.result_id"),
			db.Raw("results.test_group_name"),
			db.Raw("results.result_start_time"),
			db.Raw("results.network_id"),
			db.Raw("results.result_is_viewed"),
			db.Raw("results.result_runtime"),
			db.Raw("results.result_is_done"),
			db.Raw("results.result_is_uploaded"),
			db.Raw("results.result_data_usage_up"),
			db.Raw("results.result_data_usage_down"),
			db.Raw("results.measurement_dir"),
		)
	// req is shared by both queries below. This relies on upper/db v4
	// builders being immutable: each Where derives a new query, so the two
	// filters do not accumulate.
	if err := req.Where("result_is_done = true").All(&doneResults); err != nil {
		return doneResults, incompleteResults, errors.Wrap(err, "failed to get result done list")
	}
	if err := req.Where("result_is_done = false").All(&incompleteResults); err != nil {
		return doneResults, incompleteResults, errors.Wrap(err, "failed to get result incomplete list")
	}
	return doneResults, incompleteResults, nil
}
// DeleteResult deletes the result identified by resultID from the results
// table. It then removes the result's measurement directory from disk.
//
// When no such result exists, db.ErrNoMoreRows is returned without being
// logged, leaving its reporting to the caller. If removing the directory
// fails, the database row has already been deleted; the filesystem error is
// then logged and returned wrapped.
func (d *Database) DeleteResult(resultID int64) error {
	var result Result
	res := d.sess.Collection("results").Find("result_id", resultID)
	if err := res.One(&result); err != nil {
		if err == db.ErrNoMoreRows {
			return err
		}
		log.WithError(err).Error("error in obtaining the result")
		return err
	}
	if err := res.Delete(); err != nil {
		log.WithError(err).Error("failed to delete the result")
		return err
	}
	// os.RemoveAll returns nil for a missing path, so this reports only real
	// filesystem failures (e.g. permission problems).
	if err := os.RemoveAll(result.MeasurementDir); err != nil {
		log.WithError(err).Error("failed to delete the result directory")
		return errors.Wrap(err, "removing measurement directory")
	}
	return nil
}
// UpdateUploadedStatus recomputes result.IsUploaded inside a transaction and
// writes the result back to the results table.
//
// IsUploaded becomes true when the sum of measurement_is_uploaded over the
// result's measurements equals their count; otherwise it becomes false. With
// 0/1 values, that means every measurement has been uploaded.
func (d *Database) UpdateUploadedStatus(result *Result) error {
	txErr := d.sess.Tx(func(tx db.Session) error {
		var counts UploadedTotalCount
		countQuery := tx.SQL().Select(
			db.Raw("SUM(measurements.measurement_is_uploaded)"),
			db.Raw("COUNT(*)"),
		).From("results").
			Join("measurements").On("measurements.result_id = results.result_id").
			Where("results.result_id = ?", result.ID)
		if err := countQuery.One(&counts); err != nil {
			log.WithError(err).Error("failed to retrieve total vs uploaded counts")
			return err
		}
		result.IsUploaded = counts.UploadedCount == counts.TotalCount
		if err := tx.Collection("results").Find("result_id", result.ID).Update(result); err != nil {
			log.WithError(err).Error("failed to update result")
			return errors.Wrap(err, "updating result")
		}
		return nil
	})
	if txErr == nil {
		return nil
	}
	log.WithError(txErr).Error("Failed to write to the results table")
	return txErr
}
// CreateMeasurement inserts a new measurement row for the given result and
// returns a pointer to it, with its newly assigned ID.
//
// The row is initially neither failed nor done and has empty test keys. Its
// start time is the current UTC time. Its file path is
// measurementDir/msmt-<testName>-<idx>.json.
func (d *Database) CreateMeasurement(reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) {
	// TODO we should look into generating this file path in a more robust way.
	// If there are two identical test_names in the same test group, their
	// file paths may clash.
	fileName := fmt.Sprintf("msmt-%s-%d.json", testName, idx)
	filePath := sql.NullString{String: filepath.Join(measurementDir, fileName), Valid: true}
	msmt := Measurement{
		ReportID:            reportID,
		TestName:            testName,
		ResultID:            resultID,
		MeasurementFilePath: filePath,
		URLID:               urlID,
		IsFailed:            false,
		IsDone:              false,
		// XXX Do we want to have this be part of something else?
		StartTime: time.Now().UTC(),
		TestKeys:  "",
	}
	inserted, err := d.sess.Collection("measurements").Insert(msmt)
	if err != nil {
		return nil, errors.Wrap(err, "creating measurement")
	}
	msmt.ID = inserted.ID().(int64)
	return &msmt, nil
}
// CreateResult inserts a new result row for testGroupName on the network
// networkID. It returns a pointer to the result, with its newly assigned ID.
//
// The measurement directory comes from makeResultsDir, which presumably
// creates it under homePath; any error it returns is passed through
// unchanged. The start time is the current UTC time.
func (d *Database) CreateResult(homePath string, testGroupName string, networkID int64) (*Result, error) {
	now := time.Now().UTC()
	dir, err := makeResultsDir(homePath, testGroupName, now)
	if err != nil {
		return nil, err
	}
	result := Result{
		TestGroupName:  testGroupName,
		StartTime:      now,
		NetworkID:      networkID,
		MeasurementDir: dir,
	}
	log.Debugf("Creating result %v", result)
	inserted, err := d.sess.Collection("results").Insert(result)
	if err != nil {
		return nil, errors.Wrap(err, "creating result")
	}
	result.ID = inserted.ID().(int64)
	return &result, nil
}
// CreateNetwork inserts a row into the networks table describing the network
// reported by loc. It returns a pointer to that row, with its newly assigned
// ID.
func (d *Database) CreateNetwork(loc engine.LocationProvider) (*Network, error) {
	var network Network
	network.ASN = loc.ProbeASN()
	network.CountryCode = loc.ProbeCC()
	network.NetworkName = loc.ProbeNetworkName()
	// On desktop we consider it to always be wifi
	network.NetworkType = "wifi"
	network.IP = loc.ProbeIP()
	inserted, err := d.sess.Collection("networks").Insert(network)
	if err != nil {
		return nil, err
	}
	network.ID = inserted.ID().(int64)
	return &network, nil
}
// CreateOrUpdateURL ensures that the urls table has an entry for urlStr in
// countryCode and returns that entry's ID.
//
// A missing entry is inserted with categoryCode. An existing entry has its
// category code replaced with categoryCode. Lookup and write happen inside a
// single transaction, and any failure, including a failed update, is logged
// and returned.
func (d *Database) CreateOrUpdateURL(urlStr string, categoryCode string, countryCode string) (int64, error) {
	// entry is named so that it does not shadow the net/url package.
	var entry URL
	err := d.sess.Tx(func(tx db.Session) error {
		res := tx.Collection("urls").Find(
			db.Cond{"url": urlStr, "url_country_code": countryCode},
		)
		switch findErr := res.One(&entry); {
		case findErr == db.ErrNoMoreRows:
			entry = URL{
				URL:          sql.NullString{String: urlStr, Valid: true},
				CategoryCode: sql.NullString{String: categoryCode, Valid: true},
				CountryCode:  sql.NullString{String: countryCode, Valid: true},
			}
			newID, insErr := tx.Collection("urls").Insert(entry)
			if insErr != nil {
				log.WithError(insErr).Error("Failed to insert into the URLs table")
				return insErr
			}
			entry.ID = sql.NullInt64{Int64: newID.ID().(int64), Valid: true}
		case findErr != nil:
			log.WithError(findErr).Error("Failed to get single result")
			return findErr
		default:
			entry.CategoryCode = sql.NullString{String: categoryCode, Valid: true}
			// Propagate update failures instead of reporting success for a
			// write that did not happen.
			if updErr := res.Update(entry); updErr != nil {
				log.WithError(updErr).Error("Failed to update the URL category code")
				return updErr
			}
		}
		return nil
	})
	if err != nil {
		log.WithError(err).Error("Failed to write to the URL table")
		return 0, err
	}
	return entry.ID.Int64, nil
}
// AddTestKeys writes the summary to the measurement.
//
// It serializes tk to JSON and stores the result in msmt.TestKeys. It sets
// msmt.IsAnomaly from tk's boolean IsAnomaly field, or to NULL (Valid false)
// when tk has no such field. It then writes the updated msmt to the
// measurements table.
//
// tk may be a struct or a (possibly nested) pointer to a struct. Any other
// value, including nil, yields a NULL IsAnomaly instead of a panic.
//
// An error is returned when tk cannot be serialized or when the update fails.
func (d *Database) AddTestKeys(msmt *Measurement, tk interface{}) error {
	tkBytes, err := json.Marshal(tk)
	if err != nil {
		log.WithError(err).Error("failed to serialize summary")
		return errors.Wrap(err, "serializing test keys")
	}
	// This is necessary so that we can extract just the IsAnomaly field of
	// bool type from the opaque testKeys. Pointers are dereferenced, and only
	// structs are inspected, because reflect.Value.FieldByName panics on any
	// other kind (including the zero Value obtained from a nil tk).
	var isAnomaly sql.NullBool
	v := reflect.ValueOf(tk)
	for v.Kind() == reflect.Ptr && !v.IsNil() {
		v = v.Elem()
	}
	if v.Kind() == reflect.Struct {
		field := v.FieldByName("IsAnomaly")
		if field.IsValid() && field.Kind() == reflect.Bool {
			isAnomaly = sql.NullBool{Bool: field.Bool(), Valid: true}
		}
	}
	msmt.TestKeys = string(tkBytes)
	msmt.IsAnomaly = isAnomaly
	err = d.sess.Collection("measurements").Find("measurement_id", msmt.ID).Update(msmt)
	if err != nil {
		log.WithError(err).Error("failed to update measurement")
		return errors.Wrap(err, "updating measurement")
	}
	return nil
}
// Close closes the underlying database session and returns the error, if
// any, produced by doing so.
func (d *Database) Close() error {
	err := d.sess.Close()
	return err
}