chore: upgrade to github.com/upper/db/v4 (#705)

* Upgrade to github.com/upper/db/v4

* fix(oonitest): repair imports after merge

Oops, okay, it seems the merge did not preserve all the import
changes, so let's ensure we use the right imports here!

* cleanup(go.mod): don't refer to upper.io/db/v3

These lines didn't disappear previously because the merge commit
failed to remove all references to upper.io/db/v3.

Co-authored-by: stergem <sgemelas@protonmail.com>
Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
stergem
2022-05-06 14:05:24 +03:00
committed by GitHub
parent 5d2afaade4
commit 8010e9783a
10 changed files with 201 additions and 111 deletions
+68 -78
View File
@@ -15,14 +15,13 @@ import (
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/enginex"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/utils"
"github.com/pkg/errors"
db "upper.io/db.v3"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// ListMeasurements given a result ID
func ListMeasurements(sess sqlbuilder.Database, resultID int64) ([]MeasurementURLNetwork, error) {
func ListMeasurements(sess db.Session, resultID int64) ([]MeasurementURLNetwork, error) {
measurements := []MeasurementURLNetwork{}
req := sess.Select(
req := sess.SQL().Select(
db.Raw("networks.*"),
db.Raw("urls.*"),
db.Raw("measurements.*"),
@@ -41,12 +40,12 @@ func ListMeasurements(sess sqlbuilder.Database, resultID int64) ([]MeasurementUR
}
// GetMeasurementJSON returns a map[string]interface{} given a database and a measurementID
func GetMeasurementJSON(sess sqlbuilder.Database, measurementID int64) (map[string]interface{}, error) {
func GetMeasurementJSON(sess db.Session, measurementID int64) (map[string]interface{}, error) {
var (
measurement MeasurementURLNetwork
msmtJSON map[string]interface{}
)
req := sess.Select(
req := sess.SQL().Select(
db.Raw("urls.*"),
db.Raw("measurements.*"),
).From("measurements").
@@ -104,10 +103,10 @@ func GetMeasurementJSON(sess sqlbuilder.Database, measurementID int64) (map[stri
}
// ListResults return the list of results
func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, error) {
func ListResults(sess db.Session) ([]ResultNetwork, []ResultNetwork, error) {
doneResults := []ResultNetwork{}
incompleteResults := []ResultNetwork{}
req := sess.Select(
req := sess.SQL().Select(
db.Raw("networks.network_name"),
db.Raw("networks.network_type"),
db.Raw("networks.ip"),
@@ -167,7 +166,7 @@ func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, er
// DeleteResult will delete a particular result and the relative measurement on
// disk.
func DeleteResult(sess sqlbuilder.Database, resultID int64) error {
func DeleteResult(sess db.Session, resultID int64) error {
var result Result
res := sess.Collection("results").Find("result_id", resultID)
if err := res.One(&result); err != nil {
@@ -187,37 +186,33 @@ func DeleteResult(sess sqlbuilder.Database, resultID int64) error {
}
// UpdateUploadedStatus will check if all the measurements inside of a given result set have been uploaded and if so will set the is_uploaded flag to true
func UpdateUploadedStatus(sess sqlbuilder.Database, result *Result) error {
tx, err := sess.NewTx(nil)
if err != nil {
log.WithError(err).Error("failed to create transaction")
return err
}
func UpdateUploadedStatus(sess db.Session, result *Result) error {
err := sess.Tx(func(tx db.Session) error {
uploadedTotal := UploadedTotalCount{}
req := tx.SQL().Select(
db.Raw("SUM(measurements.measurement_is_uploaded)"),
db.Raw("COUNT(*)"),
).From("results").
Join("measurements").On("measurements.result_id = results.result_id").
Where("results.result_id = ?", result.ID)
uploadedTotal := UploadedTotalCount{}
req := tx.Select(
db.Raw("SUM(measurements.measurement_is_uploaded)"),
db.Raw("COUNT(*)"),
).From("results").
Join("measurements").On("measurements.result_id = results.result_id").
Where("results.result_id = ?", result.ID)
err = req.One(&uploadedTotal)
if err != nil {
log.WithError(err).Error("failed to retrieve total vs uploaded counts")
return err
}
if uploadedTotal.UploadedCount == uploadedTotal.TotalCount {
result.IsUploaded = true
} else {
result.IsUploaded = false
}
err = tx.Collection("results").Find("result_id", result.ID).Update(result)
if err != nil {
log.WithError(err).Error("failed to update result")
return errors.Wrap(err, "updating result")
}
err = tx.Commit()
err := req.One(&uploadedTotal)
if err != nil {
log.WithError(err).Error("failed to retrieve total vs uploaded counts")
return err
}
if uploadedTotal.UploadedCount == uploadedTotal.TotalCount {
result.IsUploaded = true
} else {
result.IsUploaded = false
}
err = tx.Collection("results").Find("result_id", result.ID).Update(result)
if err != nil {
log.WithError(err).Error("failed to update result")
return errors.Wrap(err, "updating result")
}
return nil
})
if err != nil {
log.WithError(err).Error("Failed to write to the results table")
return err
@@ -228,7 +223,7 @@ func UpdateUploadedStatus(sess sqlbuilder.Database, result *Result) error {
// CreateMeasurement writes the measurement to the database a returns a pointer
// to the Measurement
func CreateMeasurement(sess sqlbuilder.Database, reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) {
func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) {
// TODO we should look into generating this file path in a more robust way.
// If there are two identical test_names in the same test group there is
// going to be a clash of test_name
@@ -250,13 +245,13 @@ func CreateMeasurement(sess sqlbuilder.Database, reportID sql.NullString, testNa
if err != nil {
return nil, errors.Wrap(err, "creating measurement")
}
msmt.ID = newID.(int64)
msmt.ID = newID.ID().(int64)
return &msmt, nil
}
// CreateResult writes the Result to the database a returns a pointer
// to the Result
func CreateResult(sess sqlbuilder.Database, homePath string, testGroupName string, networkID int64) (*Result, error) {
func CreateResult(sess db.Session, homePath string, testGroupName string, networkID int64) (*Result, error) {
startTime := time.Now().UTC()
p, err := utils.MakeResultsDir(homePath, testGroupName, startTime)
@@ -276,12 +271,12 @@ func CreateResult(sess sqlbuilder.Database, homePath string, testGroupName strin
if err != nil {
return nil, errors.Wrap(err, "creating result")
}
result.ID = newID.(int64)
result.ID = newID.ID().(int64)
return &result, nil
}
// CreateNetwork will create a new network in the network table
func CreateNetwork(sess sqlbuilder.Database, loc enginex.LocationProvider) (*Network, error) {
func CreateNetwork(sess db.Session, loc enginex.LocationProvider) (*Network, error) {
network := Network{
ASN: loc.ProbeASN(),
CountryCode: loc.ProbeCC(),
@@ -295,59 +290,54 @@ func CreateNetwork(sess sqlbuilder.Database, loc enginex.LocationProvider) (*Net
return nil, err
}
network.ID = newID.(int64)
network.ID = newID.ID().(int64)
return &network, nil
}
// CreateOrUpdateURL will create a new URL entry to the urls table if it doesn't
// exists, otherwise it will update the category code of the one already in
// there.
func CreateOrUpdateURL(sess sqlbuilder.Database, urlStr string, categoryCode string, countryCode string) (int64, error) {
func CreateOrUpdateURL(sess db.Session, urlStr string, categoryCode string, countryCode string) (int64, error) {
var url URL
tx, err := sess.NewTx(nil)
if err != nil {
log.WithError(err).Error("failed to create transaction")
return 0, err
}
res := tx.Collection("urls").Find(
db.Cond{"url": urlStr, "url_country_code": countryCode},
)
err = res.One(&url)
err := sess.Tx(func(tx db.Session) error {
res := tx.Collection("urls").Find(
db.Cond{"url": urlStr, "url_country_code": countryCode},
)
err := res.One(&url)
if err == db.ErrNoMoreRows {
url = URL{
URL: sql.NullString{String: urlStr, Valid: true},
CategoryCode: sql.NullString{String: categoryCode, Valid: true},
CountryCode: sql.NullString{String: countryCode, Valid: true},
if err == db.ErrNoMoreRows {
url = URL{
URL: sql.NullString{String: urlStr, Valid: true},
CategoryCode: sql.NullString{String: categoryCode, Valid: true},
CountryCode: sql.NullString{String: countryCode, Valid: true},
}
newID, insErr := tx.Collection("urls").Insert(url)
if insErr != nil {
log.Error("Failed to insert into the URLs table")
return insErr
}
url.ID = sql.NullInt64{Int64: newID.ID().(int64), Valid: true}
} else if err != nil {
log.WithError(err).Error("Failed to get single result")
return err
} else {
url.CategoryCode = sql.NullString{String: categoryCode, Valid: true}
res.Update(url)
}
newID, insErr := tx.Collection("urls").Insert(url)
if insErr != nil {
log.Error("Failed to insert into the URLs table")
return 0, insErr
}
url.ID = sql.NullInt64{Int64: newID.(int64), Valid: true}
} else if err != nil {
log.WithError(err).Error("Failed to get single result")
return 0, err
} else {
url.CategoryCode = sql.NullString{String: categoryCode, Valid: true}
res.Update(url)
}
err = tx.Commit()
return nil
})
if err != nil {
log.WithError(err).Error("Failed to write to the URL table")
return 0, err
}
log.Debugf("returning url %d", url.ID.Int64)
return url.ID.Int64, nil
}
// AddTestKeys writes the summary to the measurement
func AddTestKeys(sess sqlbuilder.Database, msmt *Measurement, tk interface{}) error {
func AddTestKeys(sess db.Session, msmt *Measurement, tk interface{}) error {
var (
isAnomaly bool
isAnomalyValid bool
@@ -8,7 +8,7 @@ import (
"os"
"testing"
db "upper.io/db.v3"
"github.com/upper/db/v4"
)
type locationInfo struct {
+6 -6
View File
@@ -8,8 +8,8 @@ import (
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/netxlite"
migrate "github.com/rubenv/sql-migrate"
"upper.io/db.v3/lib/sqlbuilder"
"upper.io/db.v3/sqlite"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/sqlite"
)
//go:embed migrations/*.sql
@@ -36,14 +36,14 @@ func readAssetDir(path string) ([]string, error) {
}
// RunMigrations runs the database migrations
func RunMigrations(db *sql.DB) error {
func RunMigrations(sess *sql.DB) error {
log.Debugf("running migrations")
migrations := &migrate.AssetMigrationSource{
Asset: readAsset,
AssetDir: readAssetDir,
Dir: "migrations",
}
n, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up)
n, err := migrate.Exec(sess, "sqlite3", migrations, migrate.Up)
if err != nil {
return err
}
@@ -52,12 +52,12 @@ func RunMigrations(db *sql.DB) error {
}
// Connect to the database
func Connect(path string) (db sqlbuilder.Database, err error) {
func Connect(path string) (sess db.Session, err error) {
settings := sqlite.ConnectionURL{
Database: path,
Options: map[string]string{"_foreign_keys": "1"},
}
sess, err := sqlite.Open(settings)
sess, err = sqlite.Open(settings)
if err != nil {
log.WithError(err).Error("failed to open the DB")
return nil, err
+6 -6
View File
@@ -5,7 +5,7 @@ import (
"time"
"github.com/pkg/errors"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// ResultNetwork is used to represent the structure made from the JOIN
@@ -98,7 +98,7 @@ type PerformanceTestKeys struct {
}
// Finished marks the result as done and sets the runtime
func (r *Result) Finished(sess sqlbuilder.Database) error {
func (r *Result) Finished(sess db.Session) error {
if r.IsDone == true || r.Runtime != 0 {
return errors.New("Result is already finished")
}
@@ -113,7 +113,7 @@ func (r *Result) Finished(sess sqlbuilder.Database) error {
}
// Failed writes the error string to the measurement
func (m *Measurement) Failed(sess sqlbuilder.Database, failure string) error {
func (m *Measurement) Failed(sess db.Session, failure string) error {
m.FailureMsg = sql.NullString{String: failure, Valid: true}
m.IsFailed = true
err := sess.Collection("measurements").Find("measurement_id", m.ID).Update(m)
@@ -124,7 +124,7 @@ func (m *Measurement) Failed(sess sqlbuilder.Database, failure string) error {
}
// Done marks the measurement as completed
func (m *Measurement) Done(sess sqlbuilder.Database) error {
func (m *Measurement) Done(sess db.Session) error {
runtime := time.Now().UTC().Sub(m.StartTime)
m.Runtime = runtime.Seconds()
m.IsDone = true
@@ -137,7 +137,7 @@ func (m *Measurement) Done(sess sqlbuilder.Database) error {
}
// UploadFailed writes the error string for the upload failure to the measurement
func (m *Measurement) UploadFailed(sess sqlbuilder.Database, failure string) error {
func (m *Measurement) UploadFailed(sess db.Session, failure string) error {
m.UploadFailureMsg = sql.NullString{String: failure, Valid: true}
m.IsUploaded = false
@@ -149,7 +149,7 @@ func (m *Measurement) UploadFailed(sess sqlbuilder.Database, failure string) err
}
// UploadSucceeded writes the error string for the upload failure to the measurement
func (m *Measurement) UploadSucceeded(sess sqlbuilder.Database) error {
func (m *Measurement) UploadSucceeded(sess db.Session) error {
m.IsUploaded = true
err := sess.Collection("measurements").Find("measurement_id", m.ID).Update(m)