From 28aabe0947acfae9ba50a0eadcd93e38c098caff Mon Sep 17 00:00:00 2001 From: DecFox <33030671+DecFox@users.noreply.github.com> Date: Wed, 16 Nov 2022 20:21:41 +0530 Subject: [PATCH] feat: introduce database type (#982) See https://github.com/ooni/probe/issues/2352 Co-authored-by: decfox --- cmd/ooniprobe/internal/cli/list/list.go | 5 +- cmd/ooniprobe/internal/cli/reset/reset.go | 2 +- cmd/ooniprobe/internal/cli/rm/rm.go | 12 +-- cmd/ooniprobe/internal/cli/show/show.go | 3 +- cmd/ooniprobe/internal/nettests/dnscheck.go | 2 +- cmd/ooniprobe/internal/nettests/nettests.go | 26 +++--- .../internal/nettests/nettests_test.go | 6 +- cmd/ooniprobe/internal/nettests/run.go | 12 +-- .../internal/nettests/stunreachability.go | 2 +- .../internal/nettests/web_connectivity.go | 2 +- cmd/ooniprobe/internal/ooni/ooni.go | 9 +- cmd/ooniprobe/internal/oonitest/oonitest.go | 6 +- internal/database/actions.go | 66 +++++++++----- internal/database/actions_test.go | 85 ++++++++++++------- 14 files changed, 144 insertions(+), 94 deletions(-) diff --git a/cmd/ooniprobe/internal/cli/list/list.go b/cmd/ooniprobe/internal/cli/list/list.go index 4d847a0..ea812fc 100644 --- a/cmd/ooniprobe/internal/cli/list/list.go +++ b/cmd/ooniprobe/internal/cli/list/list.go @@ -7,7 +7,6 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/output" - "github.com/ooni/probe-cli/v3/internal/database" ) func init() { @@ -20,7 +19,7 @@ func init() { return err } if *resultID > 0 { - measurements, err := database.ListMeasurements(probeCLI.DB(), *resultID) + measurements, err := probeCLI.DB().ListMeasurements(*resultID) if err != nil { log.WithError(err).Error("failed to list measurements") return err @@ -63,7 +62,7 @@ func init() { } output.MeasurementSummary(msmtSummary) } else { - doneResults, incompleteResults, err := database.ListResults(probeCLI.DB()) + doneResults, incompleteResults, err := probeCLI.DB().ListResults() if err != nil { log.WithError(err).Error("failed to list results") return err diff --git a/cmd/ooniprobe/internal/cli/reset/reset.go b/cmd/ooniprobe/internal/cli/reset/reset.go index 7841303..9ba37a0 100644 --- a/cmd/ooniprobe/internal/cli/reset/reset.go +++ b/cmd/ooniprobe/internal/cli/reset/reset.go @@ -25,7 +25,7 @@ func init() { log.WithError(err).Error("failed to close the DB") return err } - if *force == true { + if *force { os.RemoveAll(ctx.Home()) log.Infof("Deleted %s", ctx.Home()) } else { diff --git a/cmd/ooniprobe/internal/cli/rm/rm.go b/cmd/ooniprobe/internal/cli/rm/rm.go index 4dafb54..d7f9d1f 100644 --- a/cmd/ooniprobe/internal/cli/rm/rm.go +++ b/cmd/ooniprobe/internal/cli/rm/rm.go @@ -12,7 +12,7 @@ import ( "github.com/upper/db/v4" ) -func deleteAll(sess db.Session, skipInteractive bool) error { +func deleteAll(d *database.Database, skipInteractive bool) error { if skipInteractive == false { answer := "" confirm := &survey.Select{ @@ -25,21 +25,21 @@ func deleteAll(sess db.Session, skipInteractive bool) error { return errors.New("canceled by user") } } - doneResults, incompleteResults, err := database.ListResults(sess) + doneResults, incompleteResults, err := d.ListResults() if err != nil { log.WithError(err).Error("failed to list results") return err } cnt := 0 for _, result := range incompleteResults { - err = database.DeleteResult(sess, result.Result.ID) + err = d.DeleteResult(result.Result.ID) if err == db.ErrNoMoreRows { log.WithError(err).Errorf("failed to delete result #%d", result.Result.ID) } cnt++ } for _, result := range doneResults { - err = database.DeleteResult(sess, result.Result.ID) + err = d.DeleteResult(result.Result.ID) if err == db.ErrNoMoreRows { log.WithError(err).Errorf("failed to delete result #%d", result.Result.ID) } @@ -68,7 +68,7 @@ func init() { } if *yes == true { - err = database.DeleteResult(ctx.DB(), *resultID) + err = ctx.DB().DeleteResult(*resultID) if err == db.ErrNoMoreRows { return errors.New("result not found") } @@ -84,7 +84,7 @@ func init() { if answer == "false" { return errors.New("canceled by user") } - err = database.DeleteResult(ctx.DB(), *resultID) + err = ctx.DB().DeleteResult(*resultID) if err == db.ErrNoMoreRows { return errors.New("result not found") } diff --git a/cmd/ooniprobe/internal/cli/show/show.go b/cmd/ooniprobe/internal/cli/show/show.go index 973c053..f093b2d 100644 --- a/cmd/ooniprobe/internal/cli/show/show.go +++ b/cmd/ooniprobe/internal/cli/show/show.go @@ -5,7 +5,6 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/cli/root" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/output" - "github.com/ooni/probe-cli/v3/internal/database" ) func init() { @@ -17,7 +16,7 @@ func init() { log.WithError(err).Error("failed to initialize root context") return err } - msmt, err := database.GetMeasurementJSON(ctx.DB(), *msmtID) + msmt, err := ctx.DB().GetMeasurementJSON(*msmtID) if err != nil { log.Errorf("error: %v", err) return err diff --git a/cmd/ooniprobe/internal/nettests/dnscheck.go b/cmd/ooniprobe/internal/nettests/dnscheck.go index 54bf283..beebcd3 100644 --- a/cmd/ooniprobe/internal/nettests/dnscheck.go +++ b/cmd/ooniprobe/internal/nettests/dnscheck.go @@ -25,7 +25,7 @@ func (n DNSCheck) lookupURLs(ctl *Controller) ([]string, error) { if err != nil { return nil, err } - return ctl.BuildAndSetInputIdxMap(ctl.Probe.DB(), testlist) + return ctl.BuildAndSetInputIdxMap(testlist) } // Run starts the nettest. diff --git a/cmd/ooniprobe/internal/nettests/nettests.go b/cmd/ooniprobe/internal/nettests/nettests.go index a13b168..f872205 100644 --- a/cmd/ooniprobe/internal/nettests/nettests.go +++ b/cmd/ooniprobe/internal/nettests/nettests.go @@ -14,7 +14,6 @@ import ( engine "github.com/ooni/probe-cli/v3/internal/engine" "github.com/ooni/probe-cli/v3/internal/model" "github.com/pkg/errors" - "github.com/upper/db/v4" ) // Nettest interface. Every Nettest should implement this. @@ -90,14 +89,13 @@ type Controller struct { // - on success, a list of strings containing URLs to test; // // - on failure, an error. -func (c *Controller) BuildAndSetInputIdxMap( - sess db.Session, testlist []model.OOAPIURLInfo) ([]string, error) { +func (c *Controller) BuildAndSetInputIdxMap(testlist []model.OOAPIURLInfo) ([]string, error) { var urls []string urlIDMap := make(map[int64]int64) for idx, url := range testlist { log.Debugf("Going over URL %d", idx) - urlID, err := database.CreateOrUpdateURL( - sess, url.URL, url.CategoryCode, url.CountryCode, + urlID, err := c.Probe.DB().CreateOrUpdateURL( + url.URL, url.CategoryCode, url.CountryCode, ) if err != nil { log.Error("failed to add to the URL table") @@ -124,6 +122,7 @@ func (c *Controller) SetNettestIndex(i, n int) { // This function will continue to run in most cases but will // immediately halt if something's wrong with the file system. func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error { + db := c.Probe.DB() // This will configure the controller as handler for the callbacks // called by ooni/probe-engine/experiment.Experiment. builder.SetCallbacks(model.ExperimentCallbacks(c)) @@ -168,6 +167,7 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error log.Debug("disabling maxRuntime with user-provided input") maxRuntime = 0 } + sess := db.Session() start := time.Now() c.ntStartTime = start for idx, input := range inputs { @@ -187,8 +187,8 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error urlID = sql.NullInt64{Int64: c.inputIdxMap[idx64], Valid: true} } - msmt, err := database.CreateMeasurement( - c.Probe.DB(), reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID, + msmt, err := db.CreateMeasurement( + reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID, ) if err != nil { return errors.Wrap(err, "failed to create measurement") @@ -201,7 +201,7 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error measurement, err := exp.MeasureWithContext(context.Background(), input) if err != nil { log.WithError(err).Debug(color.RedString("failure.measurement")) - if err := c.msmts[idx64].Failed(c.Probe.DB(), err.Error()); err != nil { + if err := c.msmts[idx64].Failed(sess, err.Error()); err != nil { return errors.Wrap(err, "failed to mark measurement as failed") } // Since https://github.com/ooni/probe-cli/pull/527, the Measure @@ -221,10 +221,10 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error // bit of a spew in the logs, perhaps, but stopping seems less efficient. if err := exp.SubmitAndUpdateMeasurementContext(context.Background(), measurement); err != nil { log.Debug(color.RedString("failure.measurement_submission")) - if err := c.msmts[idx64].UploadFailed(c.Probe.DB(), err.Error()); err != nil { + if err := c.msmts[idx64].UploadFailed(sess, err.Error()); err != nil { return errors.Wrap(err, "failed to mark upload as failed") } - } else if err := c.msmts[idx64].UploadSucceeded(c.Probe.DB()); err != nil { + } else if err := c.msmts[idx64].UploadSucceeded(sess); err != nil { return errors.Wrap(err, "failed to mark upload as succeeded") } else { // Everything went OK, don't save to disk @@ -238,7 +238,7 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error } } - if err := c.msmts[idx64].Done(c.Probe.DB()); err != nil { + if err := c.msmts[idx64].Done(sess); err != nil { return errors.Wrap(err, "failed to mark measurement as done") } @@ -253,11 +253,11 @@ func (c *Controller) Run(builder model.ExperimentBuilder, inputs []string) error continue } log.Debugf("Fetching: %d %v", idx, c.msmts[idx64]) - if err := database.AddTestKeys(c.Probe.DB(), c.msmts[idx64], tk); err != nil { + if err := db.AddTestKeys(c.msmts[idx64], tk); err != nil { return errors.Wrap(err, "failed to add test keys to summary") } } - database.UpdateUploadedStatus(c.Probe.DB(), c.res) + db.UpdateUploadedStatus(c.res) log.Debugf("status.end") return nil } diff --git a/cmd/ooniprobe/internal/nettests/nettests_test.go b/cmd/ooniprobe/internal/nettests/nettests_test.go index 6d8d5bf..0b4d94d 100644 --- a/cmd/ooniprobe/internal/nettests/nettests_test.go +++ b/cmd/ooniprobe/internal/nettests/nettests_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni" - "github.com/ooni/probe-cli/v3/internal/database" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -53,11 +52,12 @@ func TestRun(t *testing.T) { if err != nil { t.Fatal(err) } - network, err := database.CreateNetwork(probe.DB(), sess) + db := probe.DB() + network, err := db.CreateNetwork(sess) if err != nil { t.Fatal(err) } - res, err := database.CreateResult(probe.DB(), probe.Home(), "middlebox", network.ID) + res, err := db.CreateResult(probe.Home(), "middlebox", network.ID) if err != nil { t.Fatal(err) } diff --git a/cmd/ooniprobe/internal/nettests/run.go b/cmd/ooniprobe/internal/nettests/run.go index eef842c..1281501 100644 --- a/cmd/ooniprobe/internal/nettests/run.go +++ b/cmd/ooniprobe/internal/nettests/run.go @@ -8,7 +8,6 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni" - "github.com/ooni/probe-cli/v3/internal/database" "github.com/ooni/probe-cli/v3/internal/model" "github.com/pkg/errors" ) @@ -72,7 +71,8 @@ func RunGroup(config RunGroupConfig) error { log.WithError(err).Error("Failed to lookup the location of the probe") return err } - network, err := database.CreateNetwork(config.Probe.DB(), sess) + db := config.Probe.DB() + network, err := db.CreateNetwork(sess) if err != nil { log.WithError(err).Error("Failed to create the network row") return err @@ -89,8 +89,8 @@ func RunGroup(config RunGroupConfig) error { } log.Debugf("Running test group %s", group.Label) - result, err := database.CreateResult( - config.Probe.DB(), config.Probe.Home(), config.GroupName, network.ID) + result, err := db.CreateResult( + config.Probe.Home(), config.GroupName, network.ID) if err != nil { log.Errorf("DB result error: %s", err) return err @@ -131,8 +131,8 @@ func RunGroup(config RunGroupConfig) error { if err != nil { os.Remove(result.MeasurementDir) } - - if err = result.Finished(config.Probe.DB()); err != nil { + dbSess := db.Session() + if err = result.Finished(dbSess); err != nil { return err } return nil diff --git a/cmd/ooniprobe/internal/nettests/stunreachability.go b/cmd/ooniprobe/internal/nettests/stunreachability.go index 6b8fe1c..186bb7b 100644 --- a/cmd/ooniprobe/internal/nettests/stunreachability.go +++ b/cmd/ooniprobe/internal/nettests/stunreachability.go @@ -25,7 +25,7 @@ func (n STUNReachability) lookupURLs(ctl *Controller) ([]string, error) { if err != nil { return nil, err } - return ctl.BuildAndSetInputIdxMap(ctl.Probe.DB(), testlist) + return ctl.BuildAndSetInputIdxMap(testlist) } // Run starts the nettest. diff --git a/cmd/ooniprobe/internal/nettests/web_connectivity.go b/cmd/ooniprobe/internal/nettests/web_connectivity.go index fd681c3..fc643c0 100644 --- a/cmd/ooniprobe/internal/nettests/web_connectivity.go +++ b/cmd/ooniprobe/internal/nettests/web_connectivity.go @@ -31,7 +31,7 @@ func (n WebConnectivity) lookupURLs(ctl *Controller, categories []string) ([]str if err != nil { return nil, err } - return ctl.BuildAndSetInputIdxMap(ctl.Probe.DB(), testlist) + return ctl.BuildAndSetInputIdxMap(testlist) } // WebConnectivity test implementation diff --git a/cmd/ooniprobe/internal/ooni/ooni.go b/cmd/ooniprobe/internal/ooni/ooni.go index 338863d..2501ca6 100644 --- a/cmd/ooniprobe/internal/ooni/ooni.go +++ b/cmd/ooniprobe/internal/ooni/ooni.go @@ -19,7 +19,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/legacy/assetsdir" "github.com/ooni/probe-cli/v3/internal/model" "github.com/pkg/errors" - "github.com/upper/db/v4" ) // DefaultSoftwareName is the default software name. @@ -33,7 +32,7 @@ var logger = log.WithFields(log.Fields{ // ProbeCLI is the OONI Probe CLI context. type ProbeCLI interface { Config() *config.Config - DB() db.Session + DB() *database.Database IsBatch() bool Home() string TempDir() string @@ -53,7 +52,7 @@ type ProbeEngine interface { // Probe contains the ooniprobe CLI context. type Probe struct { config *config.Config - db db.Session + db *database.Database isBatch bool home string @@ -86,7 +85,7 @@ func (p *Probe) Config() *config.Config { } // DB returns the database we're using -func (p *Probe) DB() db.Session { +func (p *Probe) DB() *database.Database { return p.db } @@ -180,7 +179,7 @@ func (p *Probe) Init(softwareName, softwareVersion, proxy string) error { p.dbPath = utils.DBDir(p.home, "main") log.Debugf("Connecting to database sqlite3://%s", p.dbPath) - db, err := database.Connect(p.dbPath) + db, err := database.Open(p.dbPath) if err != nil { return err } diff --git a/cmd/ooniprobe/internal/oonitest/oonitest.go b/cmd/ooniprobe/internal/oonitest/oonitest.go index 678c2b3..a915c8f 100644 --- a/cmd/ooniprobe/internal/oonitest/oonitest.go +++ b/cmd/ooniprobe/internal/oonitest/oonitest.go @@ -8,8 +8,8 @@ import ( "github.com/apex/log" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/config" "github.com/ooni/probe-cli/v3/cmd/ooniprobe/internal/ooni" + "github.com/ooni/probe-cli/v3/internal/database" "github.com/ooni/probe-cli/v3/internal/model" - "github.com/upper/db/v4" ) // FakeOutput allows to fake the output package. @@ -28,7 +28,7 @@ func (fo *FakeOutput) SectionTitle(s string) { // FakeProbeCLI fakes ooni.ProbeCLI type FakeProbeCLI struct { FakeConfig *config.Config - FakeDB db.Session + FakeDB *database.Database FakeIsBatch bool FakeHome string FakeTempDir string @@ -42,7 +42,7 @@ func (cli *FakeProbeCLI) Config() *config.Config { } // DB implements ProbeCLI.DB -func (cli *FakeProbeCLI) DB() db.Session { +func (cli *FakeProbeCLI) DB() *database.Database { return cli.FakeDB } diff --git a/internal/database/actions.go b/internal/database/actions.go index c94b9f1..185cc2e 100644 --- a/internal/database/actions.go +++ b/internal/database/actions.go @@ -17,10 +17,31 @@ import ( "github.com/upper/db/v4" ) +// Open returns a new database instance +func Open(dbpath string) (*Database, error) { + db, err := Connect(dbpath) + if err != nil { + return nil, err + } + return &Database{ + sess: db, + }, nil +} + +// Database is a database instance to store measurements +type Database struct { + sess db.Session +} + +// Session returns the database session +func (d *Database) Session() db.Session { + return d.sess +} + // ListMeasurements given a result ID -func ListMeasurements(sess db.Session, resultID int64) ([]MeasurementURLNetwork, error) { +func (d *Database) ListMeasurements(resultID int64) ([]MeasurementURLNetwork, error) { measurements := []MeasurementURLNetwork{} - req := sess.SQL().Select( + req := d.sess.SQL().Select( db.Raw("networks.*"), db.Raw("urls.*"), db.Raw("measurements.*"), @@ -39,12 +60,12 @@ func ListMeasurements(sess db.Session, resultID int64) ([]MeasurementURLNetwork, } // GetMeasurementJSON returns a map[string]interface{} given a database and a measurementID -func GetMeasurementJSON(sess db.Session, measurementID int64) (map[string]interface{}, error) { +func (d *Database) GetMeasurementJSON(measurementID int64) (map[string]interface{}, error) { var ( measurement MeasurementURLNetwork msmtJSON map[string]interface{} ) - req := sess.SQL().Select( + req := d.sess.SQL().Select( db.Raw("urls.*"), db.Raw("measurements.*"), ).From("measurements"). @@ -102,10 +123,10 @@ func GetMeasurementJSON(sess db.Session, measurementID int64) (map[string]interf } // ListResults return the list of results -func ListResults(sess db.Session) ([]ResultNetwork, []ResultNetwork, error) { +func (d *Database) ListResults() ([]ResultNetwork, []ResultNetwork, error) { doneResults := []ResultNetwork{} incompleteResults := []ResultNetwork{} - req := sess.SQL().Select( + req := d.sess.SQL().Select( db.Raw("networks.network_name"), db.Raw("networks.network_type"), db.Raw("networks.ip"), @@ -165,9 +186,9 @@ func ListResults(sess db.Session) ([]ResultNetwork, []ResultNetwork, error) { // DeleteResult will delete a particular result and the relative measurement on // disk. -func DeleteResult(sess db.Session, resultID int64) error { +func (d *Database) DeleteResult(resultID int64) error { var result Result - res := sess.Collection("results").Find("result_id", resultID) + res := d.sess.Collection("results").Find("result_id", resultID) if err := res.One(&result); err != nil { if err == db.ErrNoMoreRows { return err @@ -185,8 +206,8 @@ func DeleteResult(sess db.Session, resultID int64) error { } // UpdateUploadedStatus will check if all the measurements inside of a given result set have been uploaded and if so will set the is_uploaded flag to true -func UpdateUploadedStatus(sess db.Session, result *Result) error { - err := sess.Tx(func(tx db.Session) error { +func (d *Database) UpdateUploadedStatus(result *Result) error { + err := d.sess.Tx(func(tx db.Session) error { uploadedTotal := UploadedTotalCount{} req := tx.SQL().Select( db.Raw("SUM(measurements.measurement_is_uploaded)"), @@ -222,7 +243,7 @@ func UpdateUploadedStatus(sess db.Session, result *Result) error { // CreateMeasurement writes the measurement to the database a returns a pointer // to the Measurement -func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) { +func (d *Database) CreateMeasurement(reportID sql.NullString, testName string, measurementDir string, idx int, resultID int64, urlID sql.NullInt64) (*Measurement, error) { // TODO we should look into generating this file path in a more robust way. // If there are two identical test_names in the same test group there is // going to be a clash of test_name @@ -240,7 +261,7 @@ func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string TestKeys: "", } - newID, err := sess.Collection("measurements").Insert(msmt) + newID, err := d.sess.Collection("measurements").Insert(msmt) if err != nil { return nil, errors.Wrap(err, "creating measurement") } @@ -250,7 +271,7 @@ func CreateMeasurement(sess db.Session, reportID sql.NullString, testName string // CreateResult writes the Result to the database a returns a pointer // to the Result -func CreateResult(sess db.Session, homePath string, testGroupName string, networkID int64) (*Result, error) { +func (d *Database) CreateResult(homePath string, testGroupName string, networkID int64) (*Result, error) { startTime := time.Now().UTC() p, err := makeResultsDir(homePath, testGroupName, startTime) @@ -266,7 +287,7 @@ func CreateResult(sess db.Session, homePath string, testGroupName string, networ result.MeasurementDir = p log.Debugf("Creating result %v", result) - newID, err := sess.Collection("results").Insert(result) + newID, err := d.sess.Collection("results").Insert(result) if err != nil { return nil, errors.Wrap(err, "creating result") } @@ -275,7 +296,7 @@ func CreateResult(sess db.Session, homePath string, testGroupName string, networ } // CreateNetwork will create a new network in the network table -func CreateNetwork(sess db.Session, loc engine.LocationProvider) (*Network, error) { +func (d *Database) CreateNetwork(loc engine.LocationProvider) (*Network, error) { network := Network{ ASN: loc.ProbeASN(), CountryCode: loc.ProbeCC(), @@ -284,7 +305,7 @@ func CreateNetwork(sess db.Session, loc engine.LocationProvider) (*Network, erro NetworkType: "wifi", IP: loc.ProbeIP(), } - newID, err := sess.Collection("networks").Insert(network) + newID, err := d.sess.Collection("networks").Insert(network) if err != nil { return nil, err } @@ -296,10 +317,10 @@ func CreateNetwork(sess db.Session, loc engine.LocationProvider) (*Network, erro // CreateOrUpdateURL will create a new URL entry to the urls table if it doesn't // exists, otherwise it will update the category code of the one already in // there. -func CreateOrUpdateURL(sess db.Session, urlStr string, categoryCode string, countryCode string) (int64, error) { +func (d *Database) CreateOrUpdateURL(urlStr string, categoryCode string, countryCode string) (int64, error) { var url URL - err := sess.Tx(func(tx db.Session) error { + err := d.sess.Tx(func(tx db.Session) error { res := tx.Collection("urls").Find( db.Cond{"url": urlStr, "url_country_code": countryCode}, ) @@ -336,7 +357,7 @@ func CreateOrUpdateURL(sess db.Session, urlStr string, categoryCode string, coun } // AddTestKeys writes the summary to the measurement -func AddTestKeys(sess db.Session, msmt *Measurement, tk interface{}) error { +func (d *Database) AddTestKeys(msmt *Measurement, tk interface{}) error { var ( isAnomaly bool isAnomalyValid bool @@ -357,10 +378,15 @@ func AddTestKeys(sess db.Session, msmt *Measurement, tk interface{}) error { msmt.TestKeys = string(tkBytes) msmt.IsAnomaly = sql.NullBool{Bool: isAnomaly, Valid: isAnomalyValid} - err = sess.Collection("measurements").Find("measurement_id", msmt.ID).Update(msmt) + err = d.sess.Collection("measurements").Find("measurement_id", msmt.ID).Update(msmt) if err != nil { log.WithError(err).Error("failed to update measurement") return errors.Wrap(err, "updating measurement") } return nil } + +// Close closes the database session +func (d *Database) Close() error { + return d.sess.Close() +} diff --git a/internal/database/actions_test.go b/internal/database/actions_test.go index c7bc0ef..bea8e07 100644 --- a/internal/database/actions_test.go +++ b/internal/database/actions_test.go @@ -43,6 +43,31 @@ func (lp *locationInfo) ResolverIP() string { return lp.resolverIP } +func TestNewDatabase(t *testing.T) { + t.Run("with empty path", func(t *testing.T) { + dbpath := "" + database, err := Open(dbpath) + if database != nil { + t.Fatal("unexpected database instance") + } + if err.Error() != "Expecting file:// connection scheme." { + t.Fatal(err) + } + }) + + t.Run("with valid path", func(t *testing.T) { + tmpfile, err := ioutil.TempFile("", "dbtest") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + _, err = Open(tmpfile.Name()) + if err != nil { + t.Fatal(err) + } + }) +} + func TestMeasurementWorkflow(t *testing.T) { tmpfile, err := ioutil.TempFile("", "dbtest") if err != nil { @@ -56,7 +81,7 @@ func TestMeasurementWorkflow(t *testing.T) { } defer os.RemoveAll(tmpdir) - sess, err := Connect(tmpfile.Name()) + database, err := Open(tmpfile.Name()) if err != nil { t.Fatal(err) } @@ -66,12 +91,13 @@ func TestMeasurementWorkflow(t *testing.T) { countryCode: "IT", networkName: "Unknown", } - network, err := CreateNetwork(sess, &location) + sess := database.Session() + network, err := database.CreateNetwork(&location) if err != nil { t.Fatal(err) } - result, err := CreateResult(sess, tmpdir, "websites", network.ID) + result, err := database.CreateResult(tmpdir, "websites", network.ID) if err != nil { t.Fatal(err) } @@ -82,7 +108,7 @@ func TestMeasurementWorkflow(t *testing.T) { msmtFilePath := tmpdir urlID := sql.NullInt64{Int64: 0, Valid: false} - m1, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) + m1, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID) if err != nil { t.Fatal(err) } @@ -93,7 +119,7 @@ func TestMeasurementWorkflow(t *testing.T) { t.Fatal(err) } - m2, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) + m2, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID) if err != nil { t.Fatal(err) } @@ -107,7 +133,7 @@ func TestMeasurementWorkflow(t *testing.T) { if m2.ResultID != m1.ResultID { t.Error("result_id mismatch") } - err = UpdateUploadedStatus(sess, result) + err = database.UpdateUploadedStatus(result) if err != nil { t.Fatal(err) } @@ -122,7 +148,7 @@ func TestMeasurementWorkflow(t *testing.T) { t.Error("result should be marked as not uploaded") } - done, incomplete, err := ListResults(sess) + done, incomplete, err := database.ListResults() if err != nil { t.Fatal(err) } @@ -141,7 +167,7 @@ func TestMeasurementWorkflow(t *testing.T) { t.Error("there should be a total of 1 anomalies in the result") } - msmts, err := ListMeasurements(sess, resultID) + msmts, err := database.ListMeasurements(resultID) if err != nil { t.Fatal(err) } @@ -163,7 +189,7 @@ func TestDeleteResult(t *testing.T) { } defer os.RemoveAll(tmpdir) - sess, err := Connect(tmpfile.Name()) + database, err := Open(tmpfile.Name()) if err != nil { t.Fatal(err) } @@ -173,12 +199,13 @@ func TestDeleteResult(t *testing.T) { countryCode: "IT", networkName: "Unknown", } - network, err := CreateNetwork(sess, &location) + sess := database.Session() + network, err := database.CreateNetwork(&location) if err != nil { t.Fatal(err) } - result, err := CreateResult(sess, tmpdir, "websites", network.ID) + result, err := database.CreateResult(tmpdir, "websites", network.ID) if err != nil { t.Fatal(err) } @@ -189,7 +216,7 @@ func TestDeleteResult(t *testing.T) { msmtFilePath := tmpdir urlID := sql.NullInt64{Int64: 0, Valid: false} - m1, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) + m1, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID) if err != nil { t.Fatal(err) } @@ -204,7 +231,7 @@ func TestDeleteResult(t *testing.T) { t.Error("result_id mismatch") } - err = DeleteResult(sess, resultID) + err = database.DeleteResult(resultID) if err != nil { t.Fatal(err) } @@ -223,7 +250,7 @@ func TestDeleteResult(t *testing.T) { t.Fatal("measurements should be zero") } - err = DeleteResult(sess, 20) + err = database.DeleteResult(20) if err != db.ErrNoMoreRows { t.Fatal(err) } @@ -236,7 +263,7 @@ func TestNetworkCreate(t *testing.T) { } defer os.Remove(tmpfile.Name()) - sess, err := Connect(tmpfile.Name()) + database, err := Open(tmpfile.Name()) if err != nil { t.Fatal(err) } @@ -253,12 +280,12 @@ func TestNetworkCreate(t *testing.T) { networkName: "Fufnet", } - _, err = CreateNetwork(sess, &l1) + _, err = database.CreateNetwork(&l1) if err != nil { t.Fatal(err) } - _, err = CreateNetwork(sess, &l2) + _, err = database.CreateNetwork(&l2) if err != nil { t.Fatal(err) } @@ -272,32 +299,32 @@ func TestURLCreation(t *testing.T) { } defer os.Remove(tmpfile.Name()) - sess, err := Connect(tmpfile.Name()) + database, err := Open(tmpfile.Name()) if err != nil { t.Fatal(err) } - newID1, err := CreateOrUpdateURL(sess, "https://google.com", "GMB", "XX") + newID1, err := database.CreateOrUpdateURL("https://google.com", "GMB", "XX") if err != nil { t.Fatal(err) } - newID2, err := CreateOrUpdateURL(sess, "https://google.com", "SRCH", "XX") + newID2, err := database.CreateOrUpdateURL("https://google.com", "SRCH", "XX") if err != nil { t.Fatal(err) } - newID3, err := CreateOrUpdateURL(sess, "https://facebook.com", "GRP", "XX") + newID3, err := database.CreateOrUpdateURL("https://facebook.com", "GRP", "XX") if err != nil { t.Fatal(err) } - newID4, err := CreateOrUpdateURL(sess, "https://facebook.com", "GMP", "XX") + newID4, err := database.CreateOrUpdateURL("https://facebook.com", "GMP", "XX") if err != nil { t.Fatal(err) } - newID5, err := CreateOrUpdateURL(sess, "https://google.com", "SRCH", "XX") + newID5, err := database.CreateOrUpdateURL("https://google.com", "SRCH", "XX") if err != nil { t.Fatal(err) } @@ -351,22 +378,22 @@ func TestGetMeasurementJSON(t *testing.T) { } defer os.RemoveAll(tmpdir) - sess, err := Connect(tmpfile.Name()) + database, err := Open(tmpfile.Name()) if err != nil { t.Fatal(err) } - + sess := database.Session() location := locationInfo{ asn: 0, countryCode: "IT", networkName: "Unknown", } - network, err := CreateNetwork(sess, &location) + network, err := database.CreateNetwork(&location) if err != nil { t.Fatal(err) } - result, err := CreateResult(sess, tmpdir, "websites", network.ID) + result, err := database.CreateResult(tmpdir, "websites", network.ID) if err != nil { t.Fatal(err) } @@ -377,7 +404,7 @@ func TestGetMeasurementJSON(t *testing.T) { msmtFilePath := tmpdir urlID := sql.NullInt64{Int64: 0, Valid: false} - msmt, err := CreateMeasurement(sess, reportID, testName, msmtFilePath, 0, resultID, urlID) + msmt, err := database.CreateMeasurement(reportID, testName, msmtFilePath, 0, resultID, urlID) if err != nil { t.Fatal(err) } @@ -387,7 +414,7 @@ func TestGetMeasurementJSON(t *testing.T) { t.Fatal(err) } - tk, err := GetMeasurementJSON(sess, msmt.ID) + tk, err := database.GetMeasurementJSON(msmt.ID) if err != nil { t.Fatal(err) }