chore: upgrade to github.com/upper/db/v4 (#705)

* Upgrade to github.com/upper/db/v4

* fix(oonitest): repair imports after merge

Oops, okay, it seems the merge did not preserve all the import
changes, so let's ensure we use the right imports here!

* cleanup(go.mod): don't refer to upper.io/db/v3

These lines didn't disappear previously because the merge commit
failed to remove all references to upper.io/db/v3.

Co-authored-by: stergem <sgemelas@protonmail.com>
Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
stergem
2022-05-06 14:05:24 +03:00
committed by GitHub
parent 5d2afaade4
commit 8010e9783a
10 changed files with 201 additions and 111 deletions
+2 -3
View File
@@ -9,11 +9,10 @@ import (
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/database"
db "upper.io/db.v3"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
func deleteAll(sess sqlbuilder.Database, skipInteractive bool) error {
func deleteAll(sess db.Session, skipInteractive bool) error {
if skipInteractive == false {
answer := ""
confirm := &survey.Select{
+68 -78
View File
@@ -15,14 +15,13 @@ import (
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/enginex"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/utils"
"github.com/pkg/errors"
db "upper.io/db.v3"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// ListMeasurements given a result ID
func ListMeasurements(sess sqlbuilder.Database, resultID int64) ([]MeasurementURLNetwork, error) {
func ListMeasurements(sess db.Session, resultID int64) ([]MeasurementURLNetwork, error) {
measurements := []MeasurementURLNetwork{}
req := sess.Select(
req := sess.SQL().Select(
db.Raw("networks.*"),
db.Raw("urls.*"),
db.Raw("measurements.*"),
@@ -41,12 +40,12 @@ func ListMeasurements(sess sqlbuilder.Database, resultID int64) ([]MeasurementUR
}
// GetMeasurementJSON returns a map[string]interface{} given a database and a measurementID
func GetMeasurementJSON(sess sqlbuilder.Database, measurementID int64) (map[string]interface{}, error) {
func GetMeasurementJSON(sess db.Session, measurementID int64) (map[string]interface{}, error) {
var (
measurement MeasurementURLNetwork
msmtJSON map[string]interface{}
)
req := sess.Select(
req := sess.SQL().Select(
db.Raw("urls.*"),
db.Raw("measurements.*"),
).From("measurements").
@@ -104,10 +103,10 @@ func GetMeasurementJSON(sess sqlbuilder.Database, measurementID int64) (map[stri
}
// ListResults return the list of results
func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, error) {
func ListResults(sess db.Session) ([]ResultNetwork, []ResultNetwork, error) {
doneResults := []ResultNetwork{}
incompleteResults := []ResultNetwork{}
req := sess.Select(
req := sess.SQL().Select(
db.Raw("networks.network_name"),
db.Raw("networks.network_type"),
db.Raw("networks.ip"),
@@ -167,7 +166,7 @@ func ListResults(sess sqlbuilder.Database) ([]ResultNetwork, []ResultNetwork, er
// DeleteResult will delete a particular result and the relative measurement on
// disk.
func DeleteResult(sess sqlbuilder.Database, resultID int64) error {
func DeleteResult(sess db.Session, resultID int64) error {
var result Result
res := sess.Collection("results").Find("result_id", resultID)
if err := res.One(&result); err != nil {
@@ -187,37 +186,33 @@ func DeleteResult(sess sqlbuilder.Database, resultID int64) error {
}
// UpdateUploadedStatus will check if all the measurements inside of a given result set have been uploaded and if so will set the is_uploaded flag to true
func UpdateUploadedStatus(sess sqlbuilder.Database, result *Result) error {
tx, err := sess.NewTx(nil)
if err != nil {
log.WithError(err).Error("failed to create transaction")
return err
}
func UpdateUploadedStatus(sess db.Session, result *Result) error {
err := sess.Tx(func(tx db.Session) error {
uploadedTotal := UploadedTotalCount{}
req := tx.SQL().Select(
db.Raw("SUM(measurements.measurement_is_uploaded)"),
db.Raw("COUNT(*)"),
).From("results").
Join("measurements").On("measurements.result_id = results.result_id").
Where("results.result_id = ?", result.ID)
uploadedTotal := UploadedTotalCount{}
req := tx.Select(
db.Raw("SUM(measurements.measurement_is_uploaded)"),
db.Raw("COUNT(*)"),
).From("results").
Join("measurements").On("measurements.result_id = results.result_id").
Where("results.result_id = ?", result.ID)
err = req.One(&uploadedTotal)
if err != nil {
log.WithError(err).Error("failed to retrieve total vs uploaded counts")
return err
}
if uploadedTotal.UploadedCount == uploadedTotal.TotalCount {
result.IsUploaded = true
} else {
result.IsUploaded = false
}
err = tx.Collection("results").Find("result_id", result.ID).Update(result)
if err != nil {
log.WithError(err).Error("failed to update result")
return errors.Wrap(err, "updating result")
}
err = tx.Commit()
err := req.One(&uploadedTotal)
if err != nil {
log.WithError(err).Error("failed to retrieve total vs uploaded counts")
return err
}
if uploadedTotal.UploadedCount == uploadedTotal.TotalCount {
result.IsUploaded = true
} else {
result.IsUploaded = false
}
err = tx.Collection("results").Find("result_id", result.ID).Update(result)
if err != nil {
log.WithError(err).Error("failed to update result")
return errors.Wrap(err, "updating result")
}
return nil
})
if err != nil {
log.WithError(err).Error("Failed to write to the results table")
return err
@@ -228,7 +223,7 @@ func UpdateUploadedStatus(sess sqlbuilder.Database, result *Result) error {
// CreateMeasurement writes the measurement to the database a returns a pointer
// to the Measurement
func CreateMeasurement(sess sqlbuilder.Database, reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) {
func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) {
// TODO we should look into generating this file path in a more robust way.
// If there are two identical test_names in the same test group there is
// going to be a clash of test_name
@@ -250,13 +245,13 @@ func CreateMeasurement(sess sqlbuilder.Database, reportID sql.NullString, testNa
if err != nil {
return nil, errors.Wrap(err, "creating measurement")
}
msmt.ID = newID.(int64)
msmt.ID = newID.ID().(int64)
return &msmt, nil
}
// CreateResult writes the Result to the database a returns a pointer
// to the Result
func CreateResult(sess sqlbuilder.Database, homePath string, testGroupName string, networkID int64) (*Result, error) {
func CreateResult(sess db.Session, homePath string, testGroupName string, networkID int64) (*Result, error) {
startTime := time.Now().UTC()
p, err := utils.MakeResultsDir(homePath, testGroupName, startTime)
@@ -276,12 +271,12 @@ func CreateResult(sess sqlbuilder.Database, homePath string, testGroupName strin
if err != nil {
return nil, errors.Wrap(err, "creating result")
}
result.ID = newID.(int64)
result.ID = newID.ID().(int64)
return &result, nil
}
// CreateNetwork will create a new network in the network table
func CreateNetwork(sess sqlbuilder.Database, loc enginex.LocationProvider) (*Network, error) {
func CreateNetwork(sess db.Session, loc enginex.LocationProvider) (*Network, error) {
network := Network{
ASN: loc.ProbeASN(),
CountryCode: loc.ProbeCC(),
@@ -295,59 +290,54 @@ func CreateNetwork(sess sqlbuilder.Database, loc enginex.LocationProvider) (*Net
return nil, err
}
network.ID = newID.(int64)
network.ID = newID.ID().(int64)
return &network, nil
}
// CreateOrUpdateURL will create a new URL entry to the urls table if it doesn't
// exists, otherwise it will update the category code of the one already in
// there.
func CreateOrUpdateURL(sess sqlbuilder.Database, urlStr string, categoryCode string, countryCode string) (int64, error) {
func CreateOrUpdateURL(sess db.Session, urlStr string, categoryCode string, countryCode string) (int64, error) {
var url URL
tx, err := sess.NewTx(nil)
if err != nil {
log.WithError(err).Error("failed to create transaction")
return 0, err
}
res := tx.Collection("urls").Find(
db.Cond{"url": urlStr, "url_country_code": countryCode},
)
err = res.One(&url)
err := sess.Tx(func(tx db.Session) error {
res := tx.Collection("urls").Find(
db.Cond{"url": urlStr, "url_country_code": countryCode},
)
err := res.One(&url)
if err == db.ErrNoMoreRows {
url = URL{
URL: sql.NullString{String: urlStr, Valid: true},
CategoryCode: sql.NullString{String: categoryCode, Valid: true},
CountryCode: sql.NullString{String: countryCode, Valid: true},
if err == db.ErrNoMoreRows {
url = URL{
URL: sql.NullString{String: urlStr, Valid: true},
CategoryCode: sql.NullString{String: categoryCode, Valid: true},
CountryCode: sql.NullString{String: countryCode, Valid: true},
}
newID, insErr := tx.Collection("urls").Insert(url)
if insErr != nil {
log.Error("Failed to insert into the URLs table")
return insErr
}
url.ID = sql.NullInt64{Int64: newID.ID().(int64), Valid: true}
} else if err != nil {
log.WithError(err).Error("Failed to get single result")
return err
} else {
url.CategoryCode = sql.NullString{String: categoryCode, Valid: true}
res.Update(url)
}
newID, insErr := tx.Collection("urls").Insert(url)
if insErr != nil {
log.Error("Failed to insert into the URLs table")
return 0, insErr
}
url.ID = sql.NullInt64{Int64: newID.(int64), Valid: true}
} else if err != nil {
log.WithError(err).Error("Failed to get single result")
return 0, err
} else {
url.CategoryCode = sql.NullString{String: categoryCode, Valid: true}
res.Update(url)
}
err = tx.Commit()
return nil
})
if err != nil {
log.WithError(err).Error("Failed to write to the URL table")
return 0, err
}
log.Debugf("returning url %d", url.ID.Int64)
return url.ID.Int64, nil
}
// AddTestKeys writes the summary to the measurement
func AddTestKeys(sess sqlbuilder.Database, msmt *Measurement, tk interface{}) error {
func AddTestKeys(sess db.Session, msmt *Measurement, tk interface{}) error {
var (
isAnomaly bool
isAnomalyValid bool
@@ -8,7 +8,7 @@ import (
"os"
"testing"
db "upper.io/db.v3"
"github.com/upper/db/v4"
)
type locationInfo struct {
+6 -6
View File
@@ -8,8 +8,8 @@ import (
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/netxlite"
migrate "github.com/rubenv/sql-migrate"
"upper.io/db.v3/lib/sqlbuilder"
"upper.io/db.v3/sqlite"
"github.com/upper/db/v4"
"github.com/upper/db/v4/adapter/sqlite"
)
//go:embed migrations/*.sql
@@ -36,14 +36,14 @@ func readAssetDir(path string) ([]string, error) {
}
// RunMigrations runs the database migrations
func RunMigrations(db *sql.DB) error {
func RunMigrations(sess *sql.DB) error {
log.Debugf("running migrations")
migrations := &migrate.AssetMigrationSource{
Asset: readAsset,
AssetDir: readAssetDir,
Dir: "migrations",
}
n, err := migrate.Exec(db, "sqlite3", migrations, migrate.Up)
n, err := migrate.Exec(sess, "sqlite3", migrations, migrate.Up)
if err != nil {
return err
}
@@ -52,12 +52,12 @@ func RunMigrations(db *sql.DB) error {
}
// Connect to the database
func Connect(path string) (db sqlbuilder.Database, err error) {
func Connect(path string) (sess db.Session, err error) {
settings := sqlite.ConnectionURL{
Database: path,
Options: map[string]string{"_foreign_keys": "1"},
}
sess, err := sqlite.Open(settings)
sess, err = sqlite.Open(settings)
if err != nil {
log.WithError(err).Error("failed to open the DB")
return nil, err
+6 -6
View File
@@ -5,7 +5,7 @@ import (
"time"
"github.com/pkg/errors"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// ResultNetwork is used to represent the structure made from the JOIN
@@ -98,7 +98,7 @@ type PerformanceTestKeys struct {
}
// Finished marks the result as done and sets the runtime
func (r *Result) Finished(sess sqlbuilder.Database) error {
func (r *Result) Finished(sess db.Session) error {
if r.IsDone == true || r.Runtime != 0 {
return errors.New("Result is already finished")
}
@@ -113,7 +113,7 @@ func (r *Result) Finished(sess sqlbuilder.Database) error {
}
// Failed writes the error string to the measurement
func (m *Measurement) Failed(sess sqlbuilder.Database, failure string) error {
func (m *Measurement) Failed(sess db.Session, failure string) error {
m.FailureMsg = sql.NullString{String: failure, Valid: true}
m.IsFailed = true
err := sess.Collection("measurements").Find("measurement_id", m.ID).Update(m)
@@ -124,7 +124,7 @@ func (m *Measurement) Failed(sess sqlbuilder.Database, failure string) error {
}
// Done marks the measurement as completed
func (m *Measurement) Done(sess sqlbuilder.Database) error {
func (m *Measurement) Done(sess db.Session) error {
runtime := time.Now().UTC().Sub(m.StartTime)
m.Runtime = runtime.Seconds()
m.IsDone = true
@@ -137,7 +137,7 @@ func (m *Measurement) Done(sess sqlbuilder.Database) error {
}
// UploadFailed writes the error string for the upload failure to the measurement
func (m *Measurement) UploadFailed(sess sqlbuilder.Database, failure string) error {
func (m *Measurement) UploadFailed(sess db.Session, failure string) error {
m.UploadFailureMsg = sql.NullString{String: failure, Valid: true}
m.IsUploaded = false
@@ -149,7 +149,7 @@ func (m *Measurement) UploadFailed(sess sqlbuilder.Database, failure string) err
}
// UploadSucceeded writes the error string for the upload failure to the measurement
func (m *Measurement) UploadSucceeded(sess sqlbuilder.Database) error {
func (m *Measurement) UploadSucceeded(sess db.Session) error {
m.IsUploaded = true
err := sess.Collection("measurements").Find("measurement_id", m.ID).Update(m)
+4 -4
View File
@@ -13,7 +13,7 @@ import (
engine "github.com/ooni/probe-cli/v3/internal/engine"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/pkg/errors"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// Nettest interface. Every Nettest should implement this.
@@ -78,7 +78,7 @@ type Controller struct {
//
// Arguments:
//
// - db is the database in which to register the URL;
// - sess is the database in which to register the URL;
//
// - testlist is the result from the check-in API (or possibly
// a manually constructed list when applicable, e.g., for dnscheck
@@ -90,13 +90,13 @@ type Controller struct {
//
// - on failure, an error.
func (c *Controller) BuildAndSetInputIdxMap(
db sqlbuilder.Database, testlist []model.OOAPIURLInfo) ([]string, error) {
sess db.Session, testlist []model.OOAPIURLInfo) ([]string, error) {
var urls []string
urlIDMap := make(map[int64]int64)
for idx, url := range testlist {
log.Debugf("Going over URL %d", idx)
urlID, err := database.CreateOrUpdateURL(
db, url.URL, url.CategoryCode, url.CountryCode,
sess, url.URL, url.CategoryCode, url.CountryCode,
)
if err != nil {
log.Error("failed to add to the URL table")
+4 -4
View File
@@ -19,7 +19,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/kvstore"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/pkg/errors"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// DefaultSoftwareName is the default software name.
@@ -28,7 +28,7 @@ const DefaultSoftwareName = "ooniprobe-cli"
// ProbeCLI is the OONI Probe CLI context.
type ProbeCLI interface {
Config() *config.Config
DB() sqlbuilder.Database
DB() db.Session
IsBatch() bool
Home() string
TempDir() string
@@ -48,7 +48,7 @@ type ProbeEngine interface {
// Probe contains the ooniprobe CLI context.
type Probe struct {
config *config.Config
db sqlbuilder.Database
db db.Session
isBatch bool
home string
@@ -80,7 +80,7 @@ func (p *Probe) Config() *config.Config {
}
// DB returns the database we're using
func (p *Probe) DB() sqlbuilder.Database {
func (p *Probe) DB() db.Session {
return p.db
}
+3 -3
View File
@@ -9,7 +9,7 @@ import (
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/config"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni"
"github.com/ooni/probe-cli/v3/internal/model"
"upper.io/db.v3/lib/sqlbuilder"
"github.com/upper/db/v4"
)
// FakeOutput allows to fake the output package.
@@ -28,7 +28,7 @@ func (fo *FakeOutput) SectionTitle(s string) {
// FakeProbeCLI fakes ooni.ProbeCLI
type FakeProbeCLI struct {
FakeConfig *config.Config
FakeDB sqlbuilder.Database
FakeDB db.Session
FakeIsBatch bool
FakeHome string
FakeTempDir string
@@ -42,7 +42,7 @@ func (cli *FakeProbeCLI) Config() *config.Config {
}
// DB implements ProbeCLI.DB
func (cli *FakeProbeCLI) DB() sqlbuilder.Database {
func (cli *FakeProbeCLI) DB() db.Session {
return cli.FakeDB
}