refactor(ooni): introduce interfaces for testability

This commit is contained in:
Simone Basso 2020-11-13 20:07:30 +01:00
parent c55f67273e
commit fa803300bb
13 changed files with 129 additions and 70 deletions

View File

@ -12,28 +12,28 @@ func init() {
cmd.Action(func(_ *kingpin.ParseContext) error { cmd.Action(func(_ *kingpin.ParseContext) error {
output.SectionTitle("GeoIP lookup") output.SectionTitle("GeoIP lookup")
ctx, err := root.Init() probeCLI, err := root.Init()
if err != nil { if err != nil {
return err return err
} }
sess, err := ctx.NewSession() engine, err := probeCLI.NewProbeEngine()
if err != nil { if err != nil {
return err return err
} }
defer sess.Close() defer engine.Close()
err = sess.MaybeLookupLocation() err = engine.MaybeLookupLocation()
if err != nil { if err != nil {
return err return err
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"type": "table", "type": "table",
"asn": sess.ProbeASNString(), "asn": engine.ProbeASNString(),
"network_name": sess.ProbeNetworkName(), "network_name": engine.ProbeNetworkName(),
"country_code": sess.ProbeCC(), "country_code": engine.ProbeCC(),
"ip": sess.ProbeIP(), "ip": engine.ProbeIP(),
}).Info("Looked up your location") }).Info("Looked up your location")
return nil return nil

View File

@ -10,16 +10,16 @@ func init() {
cmd := root.Command("info", "Display information about OONI Probe") cmd := root.Command("info", "Display information about OONI Probe")
cmd.Action(func(_ *kingpin.ParseContext) error { cmd.Action(func(_ *kingpin.ParseContext) error {
ctx, err := root.Init() probeCLI, err := root.Init()
if err != nil { if err != nil {
log.Errorf("%s", err) log.Errorf("%s", err)
return err return err
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"path": ctx.Home, "path": probeCLI.Home(),
}).Info("Home") }).Info("Home")
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"path": ctx.TempDir, "path": probeCLI.TempDir(),
}).Info("TempDir") }).Info("TempDir")
return nil return nil

View File

@ -12,13 +12,13 @@ func init() {
cmd := root.Command("list", "List results") cmd := root.Command("list", "List results")
resultID := cmd.Arg("id", "the id of the result to list measurements for").Int64() resultID := cmd.Arg("id", "the id of the result to list measurements for").Int64()
cmd.Action(func(_ *kingpin.ParseContext) error { cmd.Action(func(_ *kingpin.ParseContext) error {
ctx, err := root.Init() probeCLI, err := root.Init()
if err != nil { if err != nil {
log.WithError(err).Error("failed to initialize root context") log.WithError(err).Error("failed to initialize root context")
return err return err
} }
if *resultID > 0 { if *resultID > 0 {
measurements, err := database.ListMeasurements(ctx.DB, *resultID) measurements, err := database.ListMeasurements(probeCLI.DB(), *resultID)
if err != nil { if err != nil {
log.WithError(err).Error("failed to list measurements") log.WithError(err).Error("failed to list measurements")
return err return err
@ -61,7 +61,7 @@ func init() {
} }
output.MeasurementSummary(msmtSummary) output.MeasurementSummary(msmtSummary)
} else { } else {
doneResults, incompleteResults, err := database.ListResults(ctx.DB) doneResults, incompleteResults, err := database.ListResults(probeCLI.DB())
if err != nil { if err != nil {
log.WithError(err).Error("failed to list results") log.WithError(err).Error("failed to list results")
return err return err
@ -91,11 +91,11 @@ func init() {
netCount := make(map[uint]int) netCount := make(map[uint]int)
output.SectionTitle("Results") output.SectionTitle("Results")
for idx, result := range doneResults { for idx, result := range doneResults {
totalCount, anmlyCount, err := database.GetMeasurementCounts(ctx.DB, result.Result.ID) totalCount, anmlyCount, err := database.GetMeasurementCounts(probeCLI.DB(), result.Result.ID)
if err != nil { if err != nil {
log.WithError(err).Error("failed to list measurement counts") log.WithError(err).Error("failed to list measurement counts")
} }
testKeys, err := database.GetResultTestKeys(ctx.DB, result.Result.ID) testKeys, err := database.GetResultTestKeys(probeCLI.DB(), result.Result.ID)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get testKeys") log.WithError(err).Error("failed to get testKeys")
} }

View File

@ -138,11 +138,11 @@ func Onboarding(config *config.Config) error {
// MaybeOnboarding will run the onboarding process only if the informed consent // MaybeOnboarding will run the onboarding process only if the informed consent
// config option is set to false // config option is set to false
func MaybeOnboarding(probe *ooni.Probe) error { func MaybeOnboarding(probe *ooni.Probe) error {
if probe.Config.InformedConsent == false { if probe.Config().InformedConsent == false {
if probe.IsBatch == true { if probe.IsBatch() == true {
return errors.New("cannot run onboarding in batch mode") return errors.New("cannot run onboarding in batch mode")
} }
if err := Onboarding(probe.Config); err != nil { if err := Onboarding(probe.Config()); err != nil {
return errors.Wrap(err, "onboarding") return errors.Wrap(err, "onboarding")
} }
} }
@ -161,20 +161,20 @@ func init() {
} }
if *yes == true { if *yes == true {
probe.Config.Lock() probe.Config().Lock()
probe.Config.InformedConsent = true probe.Config().InformedConsent = true
probe.Config.Unlock() probe.Config().Unlock()
if err := probe.Config.Write(); err != nil { if err := probe.Config().Write(); err != nil {
log.WithError(err).Error("failed to write config file") log.WithError(err).Error("failed to write config file")
return err return err
} }
return nil return nil
} }
if probe.IsBatch == true { if probe.IsBatch() == true {
return errors.New("cannot do onboarding in batch mode") return errors.New("cannot do onboarding in batch mode")
} }
return Onboarding(probe.Config) return Onboarding(probe.Config())
}) })
} }

View File

@ -20,16 +20,16 @@ func init() {
} }
// We need to first the DB otherwise the DB will be rewritten on close when // We need to first the DB otherwise the DB will be rewritten on close when
// we delete the home directory. // we delete the home directory.
err = ctx.DB.Close() err = ctx.DB().Close()
if err != nil { if err != nil {
log.WithError(err).Error("failed to close the DB") log.WithError(err).Error("failed to close the DB")
return err return err
} }
if *force == true { if *force == true {
os.RemoveAll(ctx.Home) os.RemoveAll(ctx.Home())
log.Infof("Deleted %s", ctx.Home) log.Infof("Deleted %s", ctx.Home())
} else { } else {
log.Infof("Run with --force to delete %s", ctx.Home) log.Infof("Run with --force to delete %s", ctx.Home())
} }
return nil return nil

View File

@ -65,11 +65,11 @@ func init() {
} }
if *all == true { if *all == true {
return deleteAll(ctx.DB, *yes) return deleteAll(ctx.DB(), *yes)
} }
if *yes == true { if *yes == true {
err = database.DeleteResult(ctx.DB, *resultID) err = database.DeleteResult(ctx.DB(), *resultID)
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
return errors.New("result not found") return errors.New("result not found")
} }
@ -85,7 +85,7 @@ func init() {
if answer == "false" { if answer == "false" {
return errors.New("canceled by user") return errors.New("canceled by user")
} }
err = database.DeleteResult(ctx.DB, *resultID) err = database.DeleteResult(ctx.DB(), *resultID)
if err == db.ErrNoMoreRows { if err == db.ErrNoMoreRows {
return errors.New("result not found") return errors.New("result not found")
} }

View File

@ -57,7 +57,7 @@ func init() {
return nil, err return nil, err
} }
if *isBatch { if *isBatch {
probe.IsBatch = true probe.SetIsBatch(true)
} }
return probe, nil return probe, nil

View File

@ -31,7 +31,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro
log.WithError(err).Error("Failed to lookup the location of the probe") log.WithError(err).Error("Failed to lookup the location of the probe")
return err return err
} }
network, err = database.CreateNetwork(ctx.DB, sess) network, err = database.CreateNetwork(ctx.DB(), sess)
if err != nil { if err != nil {
log.WithError(err).Error("Failed to create the network row") log.WithError(err).Error("Failed to create the network row")
return err return err
@ -48,7 +48,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro
} }
log.Debugf("Running test group %s", group.Label) log.Debugf("Running test group %s", group.Label)
result, err := database.CreateResult(ctx.DB, ctx.Home, tg, network.ID) result, err := database.CreateResult(ctx.DB(), ctx.Home(), tg, network.ID)
if err != nil { if err != nil {
log.Errorf("DB result error: %s", err) log.Errorf("DB result error: %s", err)
return err return err
@ -69,7 +69,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro
} }
} }
if err = result.Finished(ctx.DB); err != nil { if err = result.Finished(ctx.DB()); err != nil {
return err return err
} }
return nil return nil
@ -102,7 +102,7 @@ func init() {
} }
if *noCollector == true { if *noCollector == true {
probe.Config.Sharing.UploadResults = false probe.Config().Sharing.UploadResults = false
} }
return nil return nil
}) })

View File

@ -17,7 +17,7 @@ func init() {
log.WithError(err).Error("failed to initialize root context") log.WithError(err).Error("failed to initialize root context")
return err return err
} }
msmt, err := database.GetMeasurementJSON(ctx.DB, *msmtID) msmt, err := database.GetMeasurementJSON(ctx.DB(), *msmtID)
if err != nil { if err != nil {
log.Errorf("error: %v", err) log.Errorf("error: %v", err)
return err return err

View File

@ -93,7 +93,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err
log.Debug(color.RedString("status.queued")) log.Debug(color.RedString("status.queued"))
log.Debug(color.RedString("status.started")) log.Debug(color.RedString("status.started"))
if c.Probe.Config.Sharing.UploadResults { if c.Probe.Config().Sharing.UploadResults {
if err := exp.OpenReport(); err != nil { if err := exp.OpenReport(); err != nil {
log.Debugf( log.Debugf(
"%s: %s", color.RedString("failure.report_create"), err.Error(), "%s: %s", color.RedString("failure.report_create"), err.Error(),
@ -120,7 +120,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err
} }
msmt, err := database.CreateMeasurement( msmt, err := database.CreateMeasurement(
c.Probe.DB, reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID, c.Probe.DB(), reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID,
) )
if err != nil { if err != nil {
return errors.Wrap(err, "failed to create measurement") return errors.Wrap(err, "failed to create measurement")
@ -133,7 +133,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err
measurement, err := exp.Measure(input) measurement, err := exp.Measure(input)
if err != nil { if err != nil {
log.WithError(err).Debug(color.RedString("failure.measurement")) log.WithError(err).Debug(color.RedString("failure.measurement"))
if err := c.msmts[idx64].Failed(c.Probe.DB, err.Error()); err != nil { if err := c.msmts[idx64].Failed(c.Probe.DB(), err.Error()); err != nil {
return errors.Wrap(err, "failed to mark measurement as failed") return errors.Wrap(err, "failed to mark measurement as failed")
} }
// Even with a failed measurement, we want to continue. We want to // Even with a failed measurement, we want to continue. We want to
@ -142,16 +142,16 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err
// undertsand what went wrong (censorship? bug? anomaly?). // undertsand what went wrong (censorship? bug? anomaly?).
} }
if c.Probe.Config.Sharing.UploadResults { if c.Probe.Config().Sharing.UploadResults {
// Implementation note: SubmitMeasurement will fail here if we did fail // Implementation note: SubmitMeasurement will fail here if we did fail
// to open the report but we still want to continue. There will be a // to open the report but we still want to continue. There will be a
// bit of a spew in the logs, perhaps, but stopping seems less efficient. // bit of a spew in the logs, perhaps, but stopping seems less efficient.
if err := exp.SubmitAndUpdateMeasurement(measurement); err != nil { if err := exp.SubmitAndUpdateMeasurement(measurement); err != nil {
log.Debug(color.RedString("failure.measurement_submission")) log.Debug(color.RedString("failure.measurement_submission"))
if err := c.msmts[idx64].UploadFailed(c.Probe.DB, err.Error()); err != nil { if err := c.msmts[idx64].UploadFailed(c.Probe.DB(), err.Error()); err != nil {
return errors.Wrap(err, "failed to mark upload as failed") return errors.Wrap(err, "failed to mark upload as failed")
} }
} else if err := c.msmts[idx64].UploadSucceeded(c.Probe.DB); err != nil { } else if err := c.msmts[idx64].UploadSucceeded(c.Probe.DB()); err != nil {
return errors.Wrap(err, "failed to mark upload as succeeded") return errors.Wrap(err, "failed to mark upload as succeeded")
} }
} }
@ -159,7 +159,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err
if err := exp.SaveMeasurement(measurement, msmt.MeasurementFilePath.String); err != nil { if err := exp.SaveMeasurement(measurement, msmt.MeasurementFilePath.String); err != nil {
return errors.Wrap(err, "failed to save measurement on disk") return errors.Wrap(err, "failed to save measurement on disk")
} }
if err := c.msmts[idx64].Done(c.Probe.DB); err != nil { if err := c.msmts[idx64].Done(c.Probe.DB()); err != nil {
return errors.Wrap(err, "failed to mark measurement as done") return errors.Wrap(err, "failed to mark measurement as done")
} }
@ -179,7 +179,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err
continue continue
} }
log.Debugf("Fetching: %d %v", idx, c.msmts[idx64]) log.Debugf("Fetching: %d %v", idx, c.msmts[idx64])
if err := database.AddTestKeys(c.Probe.DB, c.msmts[idx64], tk); err != nil { if err := database.AddTestKeys(c.Probe.DB(), c.msmts[idx64], tk); err != nil {
return errors.Wrap(err, "failed to add test keys to summary") return errors.Wrap(err, "failed to add test keys to summary")
} }
} }

View File

@ -38,11 +38,11 @@ func TestRun(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
network, err := database.CreateNetwork(probe.DB, sess) network, err := database.CreateNetwork(probe.DB(), sess)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := database.CreateResult(probe.DB, probe.Home, "middlebox", network.ID) res, err := database.CreateResult(probe.DB(), probe.Home(), "middlebox", network.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -19,7 +19,7 @@ func lookupURLs(ctl *Controller, limit int64, categories []string) ([]string, ma
for idx, url := range testlist.Result { for idx, url := range testlist.Result {
log.Debugf("Going over URL %d", idx) log.Debugf("Going over URL %d", idx)
urlID, err := database.CreateOrUpdateURL( urlID, err := database.CreateOrUpdateURL(
ctl.Probe.DB, url.URL, url.CategoryCode, url.CountryCode, ctl.Probe.DB(), url.URL, url.CategoryCode, url.CountryCode,
) )
if err != nil { if err != nil {
log.Error("failed to add to the URL table") log.Error("failed to add to the URL table")
@ -38,8 +38,8 @@ type WebConnectivity struct {
// Run starts the test // Run starts the test
func (n WebConnectivity) Run(ctl *Controller) error { func (n WebConnectivity) Run(ctl *Controller) error {
log.Debugf("Enabled category codes are the following %v", ctl.Probe.Config.Nettests.WebsitesEnabledCategoryCodes) log.Debugf("Enabled category codes are the following %v", ctl.Probe.Config().Nettests.WebsitesEnabledCategoryCodes)
urls, urlIDMap, err := lookupURLs(ctl, ctl.Probe.Config.Nettests.WebsitesURLLimit, ctl.Probe.Config.Nettests.WebsitesEnabledCategoryCodes) urls, urlIDMap, err := lookupURLs(ctl, ctl.Probe.Config().Nettests.WebsitesURLLimit, ctl.Probe.Config().Nettests.WebsitesEnabledCategoryCodes)
if err != nil { if err != nil {
return err return err
} }

View File

@ -19,14 +19,34 @@ import (
"upper.io/db.v3/lib/sqlbuilder" "upper.io/db.v3/lib/sqlbuilder"
) )
// ProbeCLI is the OONI Probe CLI context.
type ProbeCLI interface {
Config() *config.Config
DB() sqlbuilder.Database
IsBatch() bool
Home() string
TempDir() string
NewProbeEngine() (ProbeEngine, error)
}
// ProbeEngine is an instance of the OONI Probe engine.
type ProbeEngine interface {
Close() error
MaybeLookupLocation() error
ProbeASNString() string
ProbeCC() string
ProbeIP() string
ProbeNetworkName() string
}
// Probe contains the ooniprobe CLI context. // Probe contains the ooniprobe CLI context.
type Probe struct { type Probe struct {
Config *config.Config config *config.Config
DB sqlbuilder.Database db sqlbuilder.Database
IsBatch bool isBatch bool
Home string home string
TempDir string tempDir string
dbPath string dbPath string
configPath string configPath string
@ -41,6 +61,36 @@ type Probe struct {
softwareVersion string softwareVersion string
} }
// SetIsBatch sets the value of isBatch.
func (p *Probe) SetIsBatch(v bool) {
p.isBatch = v
}
// IsBatch returns whether we're running in batch mode.
func (p *Probe) IsBatch() bool {
return p.isBatch
}
// Config returns the configuration
func (p *Probe) Config() *config.Config {
return p.config
}
// DB returns the database we're using
func (p *Probe) DB() sqlbuilder.Database {
return p.db
}
// Home returns the home directory.
func (p *Probe) Home() string {
return p.home
}
// TempDir returns the temporary directory.
func (p *Probe) TempDir() string {
return p.tempDir
}
// IsTerminated checks to see if the isTerminatedAtomicInt is set to a non zero // IsTerminated checks to see if the isTerminatedAtomicInt is set to a non zero
// value and therefore we have received the signal to shutdown the running test // value and therefore we have received the signal to shutdown the running test
func (p *Probe) IsTerminated() bool { func (p *Probe) IsTerminated() bool {
@ -102,37 +152,37 @@ func (p *Probe) MaybeListenForStdinClosed() {
func (p *Probe) Init(softwareName, softwareVersion string) error { func (p *Probe) Init(softwareName, softwareVersion string) error {
var err error var err error
if err = MaybeInitializeHome(p.Home); err != nil { if err = MaybeInitializeHome(p.home); err != nil {
return err return err
} }
if p.configPath != "" { if p.configPath != "" {
log.Debugf("Reading config file from %s", p.configPath) log.Debugf("Reading config file from %s", p.configPath)
p.Config, err = config.ReadConfig(p.configPath) p.config, err = config.ReadConfig(p.configPath)
} else { } else {
log.Debug("Reading default config file") log.Debug("Reading default config file")
p.Config, err = InitDefaultConfig(p.Home) p.config, err = InitDefaultConfig(p.home)
} }
if err != nil { if err != nil {
return err return err
} }
if err = p.Config.MaybeMigrate(); err != nil { if err = p.config.MaybeMigrate(); err != nil {
return errors.Wrap(err, "migrating config") return errors.Wrap(err, "migrating config")
} }
p.dbPath = utils.DBDir(p.Home, "main") p.dbPath = utils.DBDir(p.home, "main")
log.Debugf("Connecting to database sqlite3://%s", p.dbPath) log.Debugf("Connecting to database sqlite3://%s", p.dbPath)
db, err := database.Connect(p.dbPath) db, err := database.Connect(p.dbPath)
if err != nil { if err != nil {
return err return err
} }
p.DB = db p.db = db
tempDir, err := ioutil.TempDir("", "ooni") tempDir, err := ioutil.TempDir("", "ooni")
if err != nil { if err != nil {
return errors.Wrap(err, "creating TempDir") return errors.Wrap(err, "creating TempDir")
} }
p.TempDir = tempDir p.tempDir = tempDir
p.softwareName = softwareName p.softwareName = softwareName
p.softwareVersion = softwareVersion p.softwareVersion = softwareVersion
@ -144,31 +194,40 @@ func (p *Probe) Init(softwareName, softwareVersion string) error {
// the session when done using it, by calling sess.Close(). // the session when done using it, by calling sess.Close().
func (p *Probe) NewSession() (*engine.Session, error) { func (p *Probe) NewSession() (*engine.Session, error) {
kvstore, err := engine.NewFileSystemKVStore( kvstore, err := engine.NewFileSystemKVStore(
utils.EngineDir(p.Home), utils.EngineDir(p.home),
) )
if err != nil { if err != nil {
return nil, errors.Wrap(err, "creating engine's kvstore") return nil, errors.Wrap(err, "creating engine's kvstore")
} }
return engine.NewSession(engine.SessionConfig{ return engine.NewSession(engine.SessionConfig{
AssetsDir: utils.AssetsDir(p.Home), AssetsDir: utils.AssetsDir(p.home),
KVStore: kvstore, KVStore: kvstore,
Logger: enginex.Logger, Logger: enginex.Logger,
PrivacySettings: model.PrivacySettings{ PrivacySettings: model.PrivacySettings{
IncludeASN: p.Config.Sharing.IncludeASN, IncludeASN: p.config.Sharing.IncludeASN,
IncludeCountry: true, IncludeCountry: true,
IncludeIP: p.Config.Sharing.IncludeIP, IncludeIP: p.config.Sharing.IncludeIP,
}, },
SoftwareName: p.softwareName, SoftwareName: p.softwareName,
SoftwareVersion: p.softwareVersion, SoftwareVersion: p.softwareVersion,
TempDir: p.TempDir, TempDir: p.tempDir,
}) })
} }
// NewProbeEngine creates a new ProbeEngine instance.
func (p *Probe) NewProbeEngine() (ProbeEngine, error) {
sess, err := p.NewSession()
if err != nil {
return nil, err
}
return sess, nil
}
// NewProbe creates a new probe instance. // NewProbe creates a new probe instance.
func NewProbe(configPath string, homePath string) *Probe { func NewProbe(configPath string, homePath string) *Probe {
return &Probe{ return &Probe{
Home: homePath, home: homePath,
Config: &config.Config{}, config: &config.Config{},
configPath: configPath, configPath: configPath,
isTerminatedAtomicInt: 0, isTerminatedAtomicInt: 0,
} }