feat: introduce database type (#982)

See https://github.com/ooni/probe/issues/2352

Co-authored-by: decfox <decfox@github.com>
This commit is contained in:
DecFox 2022-11-16 20:21:41 +05:30 committed by GitHub
parent 6b01264373
commit 28aabe0947
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 144 additions and 94 deletions

View File

@ -7,7 +7,6 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/output" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/output"
"github.com/ooni/probe-cli/v3/internal/database"
) )
func init() { func init() {
@ -20,7 +19,7 @@ func init() {
return err return err
} }
if *resultID > 0 { if *resultID > 0 {
measurements, err := database.ListMeasurements(probeCLI.DB(), *resultID) measurements, err := probeCLI.DB().ListMeasurements(*resultID)
if err != nil { if err != nil {
log.WithError(err).Error("failed to list measurements") log.WithError(err).Error("failed to list measurements")
return err return err
@ -63,7 +62,7 @@ func init() {
} }
output.MeasurementSummary(msmtSummary) output.MeasurementSummary(msmtSummary)
} else { } else {
doneResults, incompleteResults, err := database.ListResults(probeCLI.DB()) doneResults, incompleteResults, err := probeCLI.DB().ListResults()
if err != nil { if err != nil {
log.WithError(err).Error("failed to list results") log.WithError(err).Error("failed to list results")
return err return err

View File

@ -25,7 +25,7 @@ func init() {
log.WithError(err).Error("failed to close the DB") log.WithError(err).Error("failed to close the DB")
return err return err
} }
if *force == true { if *force {
os.RemoveAll(ctx.Home()) os.RemoveAll(ctx.Home())
log.Infof("Deleted %s", ctx.Home()) log.Infof("Deleted %s", ctx.Home())
} else { } else {

View File

@ -12,7 +12,7 @@ import (
"github.com/upper/db/v4" "github.com/upper/db/v4"
) )
func deleteAll(sess db.Session, skipInteractive bool) error { func deleteAll(d *database.Database, skipInteractive bool) error {
if skipInteractive == false { if skipInteractive == false {
answer := "" answer := ""
confirm := &survey.Select{ confirm := &survey.Select{
@ -25,21 +25,21 @@ func deleteAll(sess db.Session, skipInteractive bool) error {
return errors.New("canceled by user") return errors.New("canceled by user")
} }
} }
doneResults, incompleteResults, err := database.ListResults(sess) doneResults, incompleteResults, err := d.ListResults()
if err != nil { if err != nil {
log.WithError(err).Error("failed to list results") log.WithError(err).Error("failed to list results")
return err return err
} }
cnt := 0 cnt := 0
for _, result := range incompleteResults { for _, result := range incompleteResults {
err = database.DeleteResult(sess, result.Result.ID) err = d.DeleteResult(result.Result.ID)
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
log.WithError(err).Errorf("failed to delete result #%d", result.Result.ID) log.WithError(err).Errorf("failed to delete result #%d", result.Result.ID)
} }
cnt++ cnt++
} }
for _, result := range doneResults { for _, result := range doneResults {
err = database.DeleteResult(sess, result.Result.ID) err = d.DeleteResult(result.Result.ID)
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
log.WithError(err).Errorf("failed to delete result #%d", result.Result.ID) log.WithError(err).Errorf("failed to delete result #%d", result.Result.ID)
} }
@ -68,7 +68,7 @@ func init() {
} }
if *yes == true { if *yes == true {
err = database.DeleteResult(ctx.DB(), *resultID) err = ctx.DB().DeleteResult(*resultID)
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
return errors.New("result not found") return errors.New("result not found")
} }
@ -84,7 +84,7 @@ func init() {
if answer == "false" { if answer == "false" {
return errors.New("canceled by user") return errors.New("canceled by user")
} }
err = database.DeleteResult(ctx.DB(), *resultID) err = ctx.DB().DeleteResult(*resultID)
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
return errors.New("result not found") return errors.New("result not found")
} }

View File

@ -5,7 +5,6 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/output" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/output"
"github.com/ooni/probe-cli/v3/internal/database"
) )
func init() { func init() {
@ -17,7 +16,7 @@ func init() {
log.WithError(err).Error("failed to initialize root context") log.WithError(err).Error("failed to initialize root context")
return err return err
} }
msmt, err := database.GetMeasurementJSON(ctx.DB(), *msmtID) msmt, err := ctx.DB().GetMeasurementJSON(*msmtID)
if err != nil { if err != nil {
log.Errorf("error: %v", err) log.Errorf("error: %v", err)
return err return err

View File

@ -25,7 +25,7 @@ func (n DNSCheck) lookupURLs(ctl *Controller) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ctl.BuildAndSetInputIdxMap(ctl.Probe.DB(), testlist) return ctl.BuildAndSetInputIdxMap(testlist)
} }
// Run starts the nettest. // Run starts the nettest.

View File

@ -14,7 +14,6 @@ import (
engine "github.com/ooni/probe-cli/v3/internal/engine" engine "github.com/ooni/probe-cli/v3/internal/engine"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/upper/db/v4"
) )
// Nettest interface. Every Nettest should implement this. // Nettest interface. Every Nettest should implement this.
@ -90,14 +89,13 @@ type Controller struct {
// - on success, a list of strings containing URLs to test; // - on success, a list of strings containing URLs to test;
// //
// - on failure, an error. // - on failure, an error.
func (c *Controller) BuildAndSetInputIdxMap( func (c *Controller) BuildAndSetInputIdxMap(testlist []model.OOAPIURLInfo) ([]string, error) {
sess db.Session, testlist []model.OOAPIURLInfo) ([]string, error) {
var urls []string var urls []string
urlIDMap := make(map[int64]int64) urlIDMap := make(map[int64]int64)
for idx, url := range testlist { for idx, url := range testlist {
log.Debugf("Going over URL %d", idx) log.Debugf("Going over URL %d", idx)
urlID, err := database.CreateOrUpdateURL( urlID, err := c.Probe.DB().CreateOrUpdateURL(
sess, url.URL, url.CategoryCode, url.CountryCode, url.URL, url.CategoryCode, url.CountryCode,
) )
if err != nil { if err != nil {
log.Error("failed to add to the URL table") log.Error("failed to add to the URL table")
@ -124,6 +122,7 @@ func (c *Controller) SetNettestIndex(i, n int) {
// This function will continue to run in most cases but will // This function will continue to run in most cases but will
// immediately halt if something's wrong with the file system. // immediately halt if something's wrong with the file system.
func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error { func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error {
db := c.Probe.DB()
// This will configure the controller as handler for the callbacks // This will configure the controller as handler for the callbacks
// called by ooni/probe-engine/experiment.Experiment. // called by ooni/probe-engine/experiment.Experiment.
builder.SetCallbacks(model.ExperimentCallbacks(c)) builder.SetCallbacks(model.ExperimentCallbacks(c))
@ -168,6 +167,7 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error
log.Debug("disabling maxRuntime with user-provided input") log.Debug("disabling maxRuntime with user-provided input")
maxRuntime = 0 maxRuntime = 0
} }
sess := db.Session()
start := time.Now() start := time.Now()
c.ntStartTime = start c.ntStartTime = start
for idx, input := range inputs { for idx, input := range inputs {
@ -187,8 +187,8 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error
urlID = sql.NullInt64{Int64: c.inputIdxMap[idx64], Valid: true} urlID = sql.NullInt64{Int64: c.inputIdxMap[idx64], Valid: true}
} }
msmt, err := database.CreateMeasurement( msmt, err := db.CreateMeasurement(
c.Probe.DB(), reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID, reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID,
) )
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create measurement") return errors.Wrap(err, "failed to create measurement")
@ -201,7 +201,7 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error
measurement, err := exp.MeasureWithContext(context.Background(), input) measurement, err := exp.MeasureWithContext(context.Background(), input)
if err != nil { if err != nil {
log.WithError(err).Debug(color.RedString("failure.measurement")) log.WithError(err).Debug(color.RedString("failure.measurement"))
if err := c.msmts[idx64].Failed(c.Probe.DB(), err.Error()); err != nil { if err := c.msmts[idx64].Failed(sess, err.Error()); err != nil {
return errors.Wrap(err, "failed to mark measurement as failed") return errors.Wrap(err, "failed to mark measurement as failed")
} }
// Since https://github.com/ooni/probe-cli/pull/527, the Measure // Since https://github.com/ooni/probe-cli/pull/527, the Measure
@ -221,10 +221,10 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error
// bit of a spew in the logs, perhaps, but stopping seems less efficient. // bit of a spew in the logs, perhaps, but stopping seems less efficient.
if err := exp.SubmitAndUpdateMeasurementContext(context.Background(), measurement); err != nil { if err := exp.SubmitAndUpdateMeasurementContext(context.Background(), measurement); err != nil {
log.Debug(color.RedString("failure.measurement_submission")) log.Debug(color.RedString("failure.measurement_submission"))
if err := c.msmts[idx64].UploadFailed(c.Probe.DB(), err.Error()); err != nil { if err := c.msmts[idx64].UploadFailed(sess, err.Error()); err != nil {
return errors.Wrap(err, "failed to mark upload as failed") return errors.Wrap(err, "failed to mark upload as failed")
} }
} else if err := c.msmts[idx64].UploadSucceeded(c.Probe.DB()); err != nil { } else if err := c.msmts[idx64].UploadSucceeded(sess); err != nil {
return errors.Wrap(err, "failed to mark upload as succeeded") return errors.Wrap(err, "failed to mark upload as succeeded")
} else { } else {
// Everything went OK, don't save to disk // Everything went OK, don't save to disk
@ -238,7 +238,7 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error
} }
} }
if err := c.msmts[idx64].Done(c.Probe.DB()); err != nil { if err := c.msmts[idx64].Done(sess); err != nil {
return errors.Wrap(err, "failed to mark measurement as done") return errors.Wrap(err, "failed to mark measurement as done")
} }
@ -253,11 +253,11 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error
continue continue
} }
log.Debugf("Fetching: %d %v", idx, c.msmts[idx64]) log.Debugf("Fetching: %d %v", idx, c.msmts[idx64])
if err := database.AddTestKeys(c.Probe.DB(), c.msmts[idx64], tk); err != nil { if err := db.AddTestKeys(c.msmts[idx64], tk); err != nil {
return errors.Wrap(err, "failed to add test keys to summary") return errors.Wrap(err, "failed to add test keys to summary")
} }
} }
database.UpdateUploadedStatus(c.Probe.DB(), c.res) db.UpdateUploadedStatus(c.res)
log.Debugf("status.end") log.Debugf("status.end")
return nil return nil
} }

View File

@ -8,7 +8,6 @@ import (
"testing" "testing"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni"
"github.com/ooni/probe-cli/v3/internal/database"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
@ -53,11 +52,12 @@ func TestRun(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
network, err := database.CreateNetwork(probe.DB(), sess) db := probe.DB()
network, err := db.CreateNetwork(sess)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := database.CreateResult(probe.DB(), probe.Home(), "middlebox", network.ID) res, err := db.CreateResult(probe.Home(), "middlebox", network.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -8,7 +8,6 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni"
"github.com/ooni/probe-cli/v3/internal/database"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -72,7 +71,8 @@ func RunGroup(config RunGroupConfig) error {
log.WithError(err).Error("Failed to lookup the location of the probe") log.WithError(err).Error("Failed to lookup the location of the probe")
return err return err
} }
network, err := database.CreateNetwork(config.Probe.DB(), sess) db := config.Probe.DB()
network, err := db.CreateNetwork(sess)
if err != nil { if err != nil {
log.WithError(err).Error("Failed to create the network row") log.WithError(err).Error("Failed to create the network row")
return err return err
@ -89,8 +89,8 @@ func RunGroup(config RunGroupConfig) error {
} }
log.Debugf("Running test group %s", group.Label) log.Debugf("Running test group %s", group.Label)
result, err := database.CreateResult( result, err := db.CreateResult(
config.Probe.DB(), config.Probe.Home(), config.GroupName, network.ID) config.Probe.Home(), config.GroupName, network.ID)
if err != nil { if err != nil {
log.Errorf("DB result error: %s", err) log.Errorf("DB result error: %s", err)
return err return err
@ -131,8 +131,8 @@ func RunGroup(config RunGroupConfig) error {
if err != nil { if err != nil {
os.Remove(result.MeasurementDir) os.Remove(result.MeasurementDir)
} }
dbSess := db.Session()
if err = result.Finished(config.Probe.DB()); err != nil { if err = result.Finished(dbSess); err != nil {
return err return err
} }
return nil return nil

View File

@ -25,7 +25,7 @@ func (n STUNReachability) lookupURLs(ctl *Controller) ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ctl.BuildAndSetInputIdxMap(ctl.Probe.DB(), testlist) return ctl.BuildAndSetInputIdxMap(testlist)
} }
// Run starts the nettest. // Run starts the nettest.

View File

@ -31,7 +31,7 @@ func (n WebConnectivity) lookupURLs(ctl *Controller, categories []string) ([]str
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ctl.BuildAndSetInputIdxMap(ctl.Probe.DB(), testlist) return ctl.BuildAndSetInputIdxMap(testlist)
} }
// WebConnectivity test implementation // WebConnectivity test implementation

View File

@ -19,7 +19,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/legacy/assetsdir" "github.com/ooni/probe-cli/v3/internal/legacy/assetsdir"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/upper/db/v4"
) )
// DefaultSoftwareName is the default software name. // DefaultSoftwareName is the default software name.
@ -33,7 +32,7 @@ var logger = log.WithFields(log.Fields{
// ProbeCLI is the OONI Probe CLI context. // ProbeCLI is the OONI Probe CLI context.
type ProbeCLI interface { type ProbeCLI interface {
Config() *config.Config Config() *config.Config
DB() db.Session DB() *database.Database
IsBatch() bool IsBatch() bool
Home() string Home() string
TempDir() string TempDir() string
@ -53,7 +52,7 @@ type ProbeEngine interface {
// Probe contains the ooniprobe CLI context. // Probe contains the ooniprobe CLI context.
type Probe struct { type Probe struct {
config *config.Config config *config.Config
db db.Session db *database.Database
isBatch bool isBatch bool
home string home string
@ -86,7 +85,7 @@ func (p *Probe) Config() *config.Config {
} }
// DB returns the database we're using // DB returns the database we're using
func (p *Probe) DB() db.Session { func (p *Probe) DB() *database.Database {
return p.db return p.db
} }
@ -180,7 +179,7 @@ func (p *Probe) Init(softwareName, softwareVersion, proxy string) error {
p.dbPath = utils.DBDir(p.home, "main") p.dbPath = utils.DBDir(p.home, "main")
log.Debugf("Connecting to database sqlite3://%s", p.dbPath) log.Debugf("Connecting to database sqlite3://%s", p.dbPath)
db, err := database.Connect(p.dbPath) db, err := database.Open(p.dbPath)
if err != nil { if err != nil {
return err return err
} }

View File

@ -8,8 +8,8 @@ import (
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/config" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/config"
"github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni"
"github.com/ooni/probe-cli/v3/internal/database"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/upper/db/v4"
) )
// FakeOutput allows to fake the output package. // FakeOutput allows to fake the output package.
@ -28,7 +28,7 @@ func (fo *FakeOutput) SectionTitle(s string) {
// FakeProbeCLI fakes ooni.ProbeCLI // FakeProbeCLI fakes ooni.ProbeCLI
type FakeProbeCLI struct { type FakeProbeCLI struct {
FakeConfig *config.Config FakeConfig *config.Config
FakeDB db.Session FakeDB *database.Database
FakeIsBatch bool FakeIsBatch bool
FakeHome string FakeHome string
FakeTempDir string FakeTempDir string
@ -42,7 +42,7 @@ func (cli *FakeProbeCLI) Config() *config.Config {
} }
// DB implements ProbeCLI.DB // DB implements ProbeCLI.DB
func (cli *FakeProbeCLI) DB() db.Session { func (cli *FakeProbeCLI) DB() *database.Database {
return cli.FakeDB return cli.FakeDB
} }

View File

@ -17,10 +17,31 @@ import (
"github.com/upper/db/v4" "github.com/upper/db/v4"
) )
// Open returns a new database instance
func Open(dbpath string) (*Database, error) {
db, err := Connect(dbpath)
if err != nil {
return nil, err
}
return &Database{
sess: db,
}, nil
}
// Database is a database instance to store measurements
type Database struct {
sess db.Session
}
// Session returns the database session
func (d *Database) Session() db.Session {
return d.sess
}
// ListMeasurements given a result ID // ListMeasurements given a result ID
func ListMeasurements(sess db.Session, resultID int64) ([]MeasurementURLNetwork, error) { func (d *Database) ListMeasurements(resultID int64) ([]MeasurementURLNetwork, error) {
measurements := []MeasurementURLNetwork{} measurements := []MeasurementURLNetwork{}
req := sess.SQL().Select( req := d.sess.SQL().Select(
db.Raw("networks.*"), db.Raw("networks.*"),
db.Raw("urls.*"), db.Raw("urls.*"),
db.Raw("measurements.*"), db.Raw("measurements.*"),
@ -39,12 +60,12 @@ func ListMeasurements(sess db.Session, resultID int64) ([]MeasurementURLNetwork,
} }
// GetMeasurementJSON returns a map[string]interface{} given a database and a measurementID // GetMeasurementJSON returns a map[string]interface{} given a database and a measurementID
func GetMeasurementJSON(sess db.Session, measurementID int64) (map[string]interface{}, error) { func (d *Database) GetMeasurementJSON(measurementID int64) (map[string]interface{}, error) {
var ( var (
measurement MeasurementURLNetwork measurement MeasurementURLNetwork
msmtJSON map[string]interface{} msmtJSON map[string]interface{}
) )
req := sess.SQL().Select( req := d.sess.SQL().Select(
db.Raw("urls.*"), db.Raw("urls.*"),
db.Raw("measurements.*"), db.Raw("measurements.*"),
).From("measurements"). ).From("measurements").
@ -102,10 +123,10 @@ func GetMeasurementJSON(sess db.Session, measurementID int64) (map[string]interf
} }
// ListResults return the list of results // ListResults return the list of results
func ListResults(sess db.Session) ([]ResultNetwork, []ResultNetwork, error) { func (d *Database) ListResults() ([]ResultNetwork, []ResultNetwork, error) {
doneResults := []ResultNetwork{} doneResults := []ResultNetwork{}
incompleteResults := []ResultNetwork{} incompleteResults := []ResultNetwork{}
req := sess.SQL().Select( req := d.sess.SQL().Select(
db.Raw("networks.network_name"), db.Raw("networks.network_name"),
db.Raw("networks.network_type"), db.Raw("networks.network_type"),
db.Raw("networks.ip"), db.Raw("networks.ip"),
@ -165,9 +186,9 @@ func ListResults(sess db.Session) ([]ResultNetwork, []ResultNetwork, error) {
// DeleteResult will delete a particular result and the relative measurement on // DeleteResult will delete a particular result and the relative measurement on
// disk. // disk.
func DeleteResult(sess db.Session, resultID int64) error { func (d *Database) DeleteResult(resultID int64) error {
var result Result var result Result
res := sess.Collection("results").Find("result_id", resultID) res := d.sess.Collection("results").Find("result_id", resultID)
if err := res.One(&result); err != nil { if err := res.One(&result); err != nil {
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
return err return err
@ -185,8 +206,8 @@ func DeleteResult(sess db.Session, resultID int64) error {
} }
// UpdateUploadedStatus will check if all the measurements inside of a given result set have been uploaded and if so will set the is_uploaded flag to true // UpdateUploadedStatus will check if all the measurements inside of a given result set have been uploaded and if so will set the is_uploaded flag to true
func UpdateUploadedStatus(sess db.Session, result *Result) error { func (d *Database) UpdateUploadedStatus(result *Result) error {
err := sess.Tx(func(tx db.Session) error { err := d.sess.Tx(func(tx db.Session) error {
uploadedTotal := UploadedTotalCount{} uploadedTotal := UploadedTotalCount{}
req := tx.SQL().Select( req := tx.SQL().Select(
db.Raw("SUM(measurements.measurement_is_uploaded)"), db.Raw("SUM(measurements.measurement_is_uploaded)"),
@ -222,7 +243,7 @@ func UpdateUploadedStatus(sess db.Session, result *Result) error {
// CreateMeasurement writes the measurement to the database a returns a pointer // CreateMeasurement writes the measurement to the database a returns a pointer
// to the Measurement // to the Measurement
func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) { func (d *Database) CreateMeasurement(reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) {
// TODO we should look into generating this file path in a more robust way. // TODO we should look into generating this file path in a more robust way.
// If there are two identical test_names in the same test group there is // If there are two identical test_names in the same test group there is
// going to be a clash of test_name // going to be a clash of test_name
@ -240,7 +261,7 @@ func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string
TestKeys: "", TestKeys: "",
} }
newID, err := sess.Collection("measurements").Insert(msmt) newID, err := d.sess.Collection("measurements").Insert(msmt)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "creating measurement") return nil, errors.Wrap(err, "creating measurement")
} }
@ -250,7 +271,7 @@ func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string
// CreateResult writes the Result to the database a returns a pointer // CreateResult writes the Result to the database a returns a pointer
// to the Result // to the Result
func CreateResult(sess db.Session, homePath string, testGroupName string, networkID int64) (*Result, error) { func (d *Database) CreateResult(homePath string, testGroupName string, networkID int64) (*Result, error) {
startTime := time.Now().UTC() startTime := time.Now().UTC()
p, err := makeResultsDir(homePath, testGroupName, startTime) p, err := makeResultsDir(homePath, testGroupName, startTime)
@ -266,7 +287,7 @@ func CreateResult(sess db.Session, homePath string, testGroupName string, networ
result.MeasurementDir = p result.MeasurementDir = p
log.Debugf("Creating result %v", result) log.Debugf("Creating result %v", result)
newID, err := sess.Collection("results").Insert(result) newID, err := d.sess.Collection("results").Insert(result)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "creating result") return nil, errors.Wrap(err, "creating result")
} }
@ -275,7 +296,7 @@ func CreateResult(sess db.Session, homePath string, testGroupName string, networ
} }
// CreateNetwork will create a new network in the network table // CreateNetwork will create a new network in the network table
func CreateNetwork(sess db.Session, loc engine.LocationProvider) (*Network, error) { func (d *Database) CreateNetwork(loc engine.LocationProvider) (*Network, error) {
network := Network{ network := Network{
ASN: loc.ProbeASN(), ASN: loc.ProbeASN(),
CountryCode: loc.ProbeCC(), CountryCode: loc.ProbeCC(),
@ -284,7 +305,7 @@ func CreateNetwork(sess db.Session, loc engine.LocationProvider) (*Network, erro
NetworkType: "wifi", NetworkType: "wifi",
IP: loc.ProbeIP(), IP: loc.ProbeIP(),
} }
newID, err := sess.Collection("networks").Insert(network) newID, err := d.sess.Collection("networks").Insert(network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -296,10 +317,10 @@ func CreateNetwork(sess db.Session, loc engine.LocationProvider) (*Network, erro
// CreateOrUpdateURL will create a new URL entry to the urls table if it doesn't // CreateOrUpdateURL will create a new URL entry to the urls table if it doesn't
// exists, otherwise it will update the category code of the one already in // exists, otherwise it will update the category code of the one already in
// there. // there.
func CreateOrUpdateURL(sess db.Session, urlStr string, categoryCode string, countryCode string) (int64, error) { func (d *Database) CreateOrUpdateURL(urlStr string, categoryCode string, countryCode string) (int64, error) {
var url URL var url URL
err := sess.Tx(func(tx db.Session) error { err := d.sess.Tx(func(tx db.Session) error {
res := tx.Collection("urls").Find( res := tx.Collection("urls").Find(
db.Cond{"url": urlStr, "url_country_code": countryCode}, db.Cond{"url": urlStr, "url_country_code": countryCode},
) )
@ -336,7 +357,7 @@ func CreateOrUpdateURL(sess db.Session, urlStr string, categoryCode string, coun
} }
// AddTestKeys writes the summary to the measurement // AddTestKeys writes the summary to the measurement
func AddTestKeys(sess db.Session, msmt *Measurement, tk interface{}) error { func (d *Database) AddTestKeys(msmt *Measurement, tk interface{}) error {
var ( var (
isAnomaly bool isAnomaly bool
isAnomalyValid bool isAnomalyValid bool
@ -357,10 +378,15 @@ func AddTestKeys(sess db.Session, msmt *Measurement, tk interface{}) error {
msmt.TestKeys = string(tkBytes) msmt.TestKeys = string(tkBytes)
msmt.IsAnomaly = sql.NullBool{Bool: isAnomaly, Valid: isAnomalyValid} msmt.IsAnomaly = sql.NullBool{Bool: isAnomaly, Valid: isAnomalyValid}
err = sess.Collection("measurements").Find("measurement_id", msmt.ID).Update(msmt) err = d.sess.Collection("measurements").Find("measurement_id", msmt.ID).Update(msmt)
if err != nil { if err != nil {
log.WithError(err).Error("failed to update measurement") log.WithError(err).Error("failed to update measurement")
return errors.Wrap(err, "updating measurement") return errors.Wrap(err, "updating measurement")
} }
return nil return nil
} }
// Close closes the database session
func (d *Database) Close() error {
return d.sess.Close()
}

View File

@ -43,6 +43,31 @@ func (lp *locationInfo) ResolverIP() string {
return lp.resolverIP return lp.resolverIP
} }
func TestNewDatabase(t *testing.T) {
t.Run("with empty path", func(t *testing.T) {
dbpath := ""
database, err := Open(dbpath)
if database != nil {
t.Fatal("unexpected database instance")
}
if err.Error() != "Expecting file:// connection scheme." {
t.Fatal(err)
}
})
t.Run("with valid path", func(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "dbtest")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
_, err = Open(tmpfile.Name())
if err != nil {
t.Fatal(err)
}
})
}
func TestMeasurementWorkflow(t *testing.T) { func TestMeasurementWorkflow(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "dbtest") tmpfile, err := ioutil.TempFile("", "dbtest")
if err != nil { if err != nil {
@ -56,7 +81,7 @@ func TestMeasurementWorkflow(t *testing.T) {
} }
defer os.RemoveAll(tmpdir) defer os.RemoveAll(tmpdir)
sess, err := Connect(tmpfile.Name()) database, err := Open(tmpfile.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -66,12 +91,13 @@ func TestMeasurementWorkflow(t *testing.T) {
countryCode: "IT", countryCode: "IT",
networkName: "Unknown", networkName: "Unknown",
} }
network, err := CreateNetwork(sess, &location) sess := database.Session()
network, err := database.CreateNetwork(&location)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
result, err := CreateResult(sess, tmpdir, "websites", network.ID) result, err := database.CreateResult(tmpdir, "websites", network.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -82,7 +108,7 @@ func TestMeasurementWorkflow(t *testing.T) {
msmtFilePath := tmpdir msmtFilePath := tmpdir
urlID := sql.NullInt64{Int64: 0, Valid: false} urlID := sql.NullInt64{Int64: 0, Valid: false}
m1, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) m1, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -93,7 +119,7 @@ func TestMeasurementWorkflow(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
m2, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) m2, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -107,7 +133,7 @@ func TestMeasurementWorkflow(t *testing.T) {
if m2.ResultID != m1.ResultID { if m2.ResultID != m1.ResultID {
t.Error("result_id mismatch") t.Error("result_id mismatch")
} }
err = UpdateUploadedStatus(sess, result) err = database.UpdateUploadedStatus(result)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -122,7 +148,7 @@ func TestMeasurementWorkflow(t *testing.T) {
t.Error("result should be marked as not uploaded") t.Error("result should be marked as not uploaded")
} }
done, incomplete, err := ListResults(sess) done, incomplete, err := database.ListResults()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -141,7 +167,7 @@ func TestMeasurementWorkflow(t *testing.T) {
t.Error("there should be a total of 1 anomalies in the result") t.Error("there should be a total of 1 anomalies in the result")
} }
msmts, err := ListMeasurements(sess, resultID) msmts, err := database.ListMeasurements(resultID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -163,7 +189,7 @@ func TestDeleteResult(t *testing.T) {
} }
defer os.RemoveAll(tmpdir) defer os.RemoveAll(tmpdir)
sess, err := Connect(tmpfile.Name()) database, err := Open(tmpfile.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -173,12 +199,13 @@ func TestDeleteResult(t *testing.T) {
countryCode: "IT", countryCode: "IT",
networkName: "Unknown", networkName: "Unknown",
} }
network, err := CreateNetwork(sess, &location) sess := database.Session()
network, err := database.CreateNetwork(&location)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
result, err := CreateResult(sess, tmpdir, "websites", network.ID) result, err := database.CreateResult(tmpdir, "websites", network.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -189,7 +216,7 @@ func TestDeleteResult(t *testing.T) {
msmtFilePath := tmpdir msmtFilePath := tmpdir
urlID := sql.NullInt64{Int64: 0, Valid: false} urlID := sql.NullInt64{Int64: 0, Valid: false}
m1, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) m1, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -204,7 +231,7 @@ func TestDeleteResult(t *testing.T) {
t.Error("result_id mismatch") t.Error("result_id mismatch")
} }
err = DeleteResult(sess, resultID) err = database.DeleteResult(resultID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -223,7 +250,7 @@ func TestDeleteResult(t *testing.T) {
t.Fatal("measurements should be zero") t.Fatal("measurements should be zero")
} }
err = DeleteResult(sess, 20) err = database.DeleteResult(20)
if err != db.ErrNoMoreRows { if err != db.ErrNoMoreRows {
t.Fatal(err) t.Fatal(err)
} }
@ -236,7 +263,7 @@ func TestNetworkCreate(t *testing.T) {
} }
defer os.Remove(tmpfile.Name()) defer os.Remove(tmpfile.Name())
sess, err := Connect(tmpfile.Name()) database, err := Open(tmpfile.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -253,12 +280,12 @@ func TestNetworkCreate(t *testing.T) {
networkName: "Fufnet", networkName: "Fufnet",
} }
_, err = CreateNetwork(sess, &l1) _, err = database.CreateNetwork(&l1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = CreateNetwork(sess, &l2) _, err = database.CreateNetwork(&l2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -272,32 +299,32 @@ func TestURLCreation(t *testing.T) {
} }
defer os.Remove(tmpfile.Name()) defer os.Remove(tmpfile.Name())
sess, err := Connect(tmpfile.Name()) database, err := Open(tmpfile.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
newID1, err := CreateOrUpdateURL(sess, "https://google.com", "GMB", "XX") newID1, err := database.CreateOrUpdateURL("https://google.com", "GMB", "XX")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
newID2, err := CreateOrUpdateURL(sess, "https://google.com", "SRCH", "XX") newID2, err := database.CreateOrUpdateURL("https://google.com", "SRCH", "XX")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
newID3, err := CreateOrUpdateURL(sess, "https://facebook.com", "GRP", "XX") newID3, err := database.CreateOrUpdateURL("https://facebook.com", "GRP", "XX")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
newID4, err := CreateOrUpdateURL(sess, "https://facebook.com", "GMP", "XX") newID4, err := database.CreateOrUpdateURL("https://facebook.com", "GMP", "XX")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
newID5, err := CreateOrUpdateURL(sess, "https://google.com", "SRCH", "XX") newID5, err := database.CreateOrUpdateURL("https://google.com", "SRCH", "XX")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -351,22 +378,22 @@ func TestGetMeasurementJSON(t *testing.T) {
} }
defer os.RemoveAll(tmpdir) defer os.RemoveAll(tmpdir)
sess, err := Connect(tmpfile.Name()) database, err := Open(tmpfile.Name())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
sess := database.Session()
location := locationInfo{ location := locationInfo{
asn: 0, asn: 0,
countryCode: "IT", countryCode: "IT",
networkName: "Unknown", networkName: "Unknown",
} }
network, err := CreateNetwork(sess, &location) network, err := database.CreateNetwork(&location)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
result, err := CreateResult(sess, tmpdir, "websites", network.ID) result, err := database.CreateResult(tmpdir, "websites", network.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -377,7 +404,7 @@ func TestGetMeasurementJSON(t *testing.T) {
msmtFilePath := tmpdir msmtFilePath := tmpdir
urlID := sql.NullInt64{Int64: 0, Valid: false} urlID := sql.NullInt64{Int64: 0, Valid: false}
msmt, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) msmt, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -387,7 +414,7 @@ func TestGetMeasurementJSON(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tk, err := GetMeasurementJSON(sess, msmt.ID) tk, err := database.GetMeasurementJSON(msmt.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }