diff --git a/internal/cli/geoip/geoip.go b/internal/cli/geoip/geoip.go index db6772e..7b653dc 100644 --- a/internal/cli/geoip/geoip.go +++ b/internal/cli/geoip/geoip.go @@ -4,38 +4,54 @@ import ( "github.com/alecthomas/kingpin" "github.com/apex/log" "github.com/ooni/probe-cli/internal/cli/root" + "github.com/ooni/probe-cli/internal/ooni" "github.com/ooni/probe-cli/internal/output" ) func init() { cmd := root.Command("geoip", "Perform a geoip lookup") - cmd.Action(func(_ *kingpin.ParseContext) error { - output.SectionTitle("GeoIP lookup") - ctx, err := root.Init() - if err != nil { - return err - } - - sess, err := ctx.NewSession() - if err != nil { - return err - } - defer sess.Close() - - err = sess.MaybeLookupLocation() - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "type": "table", - "asn": sess.ProbeASNString(), - "network_name": sess.ProbeNetworkName(), - "country_code": sess.ProbeCC(), - "ip": sess.ProbeIP(), - }).Info("Looked up your location") - - return nil + return dogeoip(defaultconfig) }) } + +type dogeoipconfig struct { + Logger log.Interface + NewProbeCLI func() (ooni.ProbeCLI, error) + SectionTitle func(string) +} + +var defaultconfig = dogeoipconfig{ + Logger: log.Log, + NewProbeCLI: root.NewProbeCLI, + SectionTitle: output.SectionTitle, +} + +func dogeoip(config dogeoipconfig) error { + config.SectionTitle("GeoIP lookup") + probeCLI, err := config.NewProbeCLI() + if err != nil { + return err + } + + engine, err := probeCLI.NewProbeEngine() + if err != nil { + return err + } + defer engine.Close() + + err = engine.MaybeLookupLocation() + if err != nil { + return err + } + + config.Logger.WithFields(log.Fields{ + "type": "table", + "asn": engine.ProbeASNString(), + "network_name": engine.ProbeNetworkName(), + "country_code": engine.ProbeCC(), + "ip": engine.ProbeIP(), + }).Info("Looked up your location") + + return nil +} diff --git a/internal/cli/geoip/geoip_test.go b/internal/cli/geoip/geoip_test.go new file mode 100644 index 0000000..b223a22 --- /dev/null +++ b/internal/cli/geoip/geoip_test.go @@ -0,0 +1,134 @@ +package geoip + +import ( + "errors" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/internal/ooni" + "github.com/ooni/probe-cli/internal/oonitest" +) + +func TestNewProbeCLIFailed(t *testing.T) { + fo := &oonitest.FakeOutput{} + expected := errors.New("mocked error") + err := dogeoip(dogeoipconfig{ + SectionTitle: fo.SectionTitle, + NewProbeCLI: func() (ooni.ProbeCLI, error) { + return nil, expected + }, + }) + if !errors.Is(err, expected) { + t.Fatalf("not the error we expected: %+v", err) + } + if len(fo.FakeSectionTitle) != 1 { + t.Fatal("invalid section title list size") + } + if fo.FakeSectionTitle[0] != "GeoIP lookup" { + t.Fatal("unexpected string") + } +} + +func TestNewProbeEngineFailed(t *testing.T) { + fo := &oonitest.FakeOutput{} + expected := errors.New("mocked error") + cli := &oonitest.FakeProbeCLI{ + FakeProbeEngineErr: expected, + } + err := dogeoip(dogeoipconfig{ + SectionTitle: fo.SectionTitle, + NewProbeCLI: func() (ooni.ProbeCLI, error) { + return cli, nil + }, + }) + if !errors.Is(err, expected) { + t.Fatalf("not the error we expected: %+v", err) + } + if len(fo.FakeSectionTitle) != 1 { + t.Fatal("invalid section title list size") + } + if fo.FakeSectionTitle[0] != "GeoIP lookup" { + t.Fatal("unexpected string") + } +} + +func TestMaybeLookupLocationFailed(t *testing.T) { + fo := &oonitest.FakeOutput{} + expected := errors.New("mocked error") + engine := &oonitest.FakeProbeEngine{ + FakeMaybeLookupLocation: expected, + } + cli := &oonitest.FakeProbeCLI{ + FakeProbeEnginePtr: engine, + } + err := dogeoip(dogeoipconfig{ + SectionTitle: fo.SectionTitle, + NewProbeCLI: func() (ooni.ProbeCLI, error) { + return cli, nil + }, + }) + if !errors.Is(err, expected) { + t.Fatalf("not the error we expected: %+v", err) + } + if len(fo.FakeSectionTitle) != 1 { + t.Fatal("invalid section title list size") + } + if fo.FakeSectionTitle[0] != "GeoIP lookup" { + t.Fatal("unexpected string") + } +} + +func TestMaybeLookupLocationSuccess(t *testing.T) { + fo := &oonitest.FakeOutput{} + engine := &oonitest.FakeProbeEngine{ + FakeProbeASNString: "AS30722", + FakeProbeCC: "IT", + FakeProbeNetworkName: "Vodafone Italia S.p.A.", + FakeProbeIP: "130.25.90.216", + } + cli := &oonitest.FakeProbeCLI{ + FakeProbeEnginePtr: engine, + } + handler := &oonitest.FakeLoggerHandler{} + err := dogeoip(dogeoipconfig{ + SectionTitle: fo.SectionTitle, + NewProbeCLI: func() (ooni.ProbeCLI, error) { + return cli, nil + }, + Logger: &log.Logger{ + Handler: handler, + Level: log.DebugLevel, + }, + }) + if err != nil { + t.Fatal(err) + } + if len(fo.FakeSectionTitle) != 1 { + t.Fatal("invalid section title list size") + } + if fo.FakeSectionTitle[0] != "GeoIP lookup" { + t.Fatal("unexpected string") + } + if len(handler.FakeEntries) != 1 { + t.Fatal("invalid number of written entries") + } + entry := handler.FakeEntries[0] + if entry.Level != log.InfoLevel { + t.Fatal("invalid log level") + } + if entry.Message != "Looked up your location" { + t.Fatal("invalid .Message") + } + if entry.Fields["asn"].(string) != "AS30722" { + t.Fatal("invalid asn") + } + if entry.Fields["country_code"].(string) != "IT" { + t.Fatal("invalid asn") + } + if entry.Fields["network_name"].(string) != "Vodafone Italia S.p.A." { + t.Fatal("invalid asn") + } + if entry.Fields["ip"].(string) != "130.25.90.216" { + t.Fatal("invalid asn") + } +} diff --git a/internal/cli/info/info.go b/internal/cli/info/info.go index 13f4834..17cd75d 100644 --- a/internal/cli/info/info.go +++ b/internal/cli/info/info.go @@ -4,24 +4,33 @@ import ( "github.com/alecthomas/kingpin" "github.com/apex/log" "github.com/ooni/probe-cli/internal/cli/root" + "github.com/ooni/probe-cli/internal/ooni" ) func init() { cmd := root.Command("info", "Display information about OONI Probe") - cmd.Action(func(_ *kingpin.ParseContext) error { - ctx, err := root.Init() - if err != nil { - log.Errorf("%s", err) - return err - } - log.WithFields(log.Fields{ - "path": ctx.Home, - }).Info("Home") - log.WithFields(log.Fields{ - "path": ctx.TempDir, - }).Info("TempDir") - - return nil + return doinfo(defaultconfig) }) } + +type doinfoconfig struct { + Logger log.Interface + NewProbeCLI func() (ooni.ProbeCLI, error) +} + +var defaultconfig = doinfoconfig{ + Logger: log.Log, + NewProbeCLI: root.NewProbeCLI, +} + +func doinfo(config doinfoconfig) error { + probeCLI, err := config.NewProbeCLI() + if err != nil { + config.Logger.Errorf("%s", err) + return err + } + config.Logger.WithFields(log.Fields{"path": probeCLI.Home()}).Info("Home") + config.Logger.WithFields(log.Fields{"path": probeCLI.TempDir()}).Info("TempDir") + return nil +} diff --git a/internal/cli/info/info_test.go b/internal/cli/info/info_test.go new file mode 100644 index 0000000..5339d87 --- /dev/null +++ b/internal/cli/info/info_test.go @@ -0,0 +1,80 @@ +package info + +import ( + "errors" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/internal/ooni" + "github.com/ooni/probe-cli/internal/oonitest" +) + +func TestNewProbeCLIFailed(t *testing.T) { + expected := errors.New("mocked error") + handler := &oonitest.FakeLoggerHandler{} + err := doinfo(doinfoconfig{ + NewProbeCLI: func() (ooni.ProbeCLI, error) { + return nil, expected + }, + Logger: &log.Logger{ + Handler: handler, + Level: log.DebugLevel, + }, + }) + if !errors.Is(err, expected) { + t.Fatalf("not the error we expected: %+v", err) + } + if len(handler.FakeEntries) != 1 { + t.Fatal("invalid number of log entries") + } + entry := handler.FakeEntries[0] + if entry.Level != log.ErrorLevel { + t.Fatal("invalid log level") + } + if entry.Message != "mocked error" { + t.Fatal("invalid .Message") + } +} + +func TestSuccess(t *testing.T) { + handler := &oonitest.FakeLoggerHandler{} + cli := &oonitest.FakeProbeCLI{ + FakeHome: "fakehome", + FakeTempDir: "faketempdir", + } + err := doinfo(doinfoconfig{ + NewProbeCLI: func() (ooni.ProbeCLI, error) { + return cli, nil + }, + Logger: &log.Logger{ + Handler: handler, + Level: log.DebugLevel, + }, + }) + if err != nil { + t.Fatal(err) + } + if len(handler.FakeEntries) != 2 { + t.Fatal("invalid number of log entries") + } + entry := handler.FakeEntries[0] + if entry.Level != log.InfoLevel { + t.Fatal("invalid log level") + } + if entry.Message != "Home" { + t.Fatal("invalid .Message") + } + if entry.Fields["path"].(string) != "fakehome" { + t.Fatal("invalid path") + } + entry = handler.FakeEntries[1] + if entry.Level != log.InfoLevel { + t.Fatal("invalid log level") + } + if entry.Message != "TempDir" { + t.Fatal("invalid .Message") + } + if entry.Fields["path"].(string) != "faketempdir" { + t.Fatal("invalid path") + } +} diff --git a/internal/cli/list/list.go b/internal/cli/list/list.go index 2638434..e95408c 100644 --- a/internal/cli/list/list.go +++ b/internal/cli/list/list.go @@ -12,13 +12,13 @@ func init() { cmd := root.Command("list", "List results") resultID := cmd.Arg("id", "the id of the result to list measurements for").Int64() cmd.Action(func(_ *kingpin.ParseContext) error { - ctx, err := root.Init() + probeCLI, err := root.Init() if err != nil { log.WithError(err).Error("failed to initialize root context") return err } if *resultID > 0 { - measurements, err := database.ListMeasurements(ctx.DB, *resultID) + measurements, err := database.ListMeasurements(probeCLI.DB(), *resultID) if err != nil { log.WithError(err).Error("failed to list measurements") return err @@ -61,7 +61,7 @@ func init() { } output.MeasurementSummary(msmtSummary) } else { - doneResults, incompleteResults, err := database.ListResults(ctx.DB) + doneResults, incompleteResults, err := database.ListResults(probeCLI.DB()) if err != nil { log.WithError(err).Error("failed to list results") return err @@ -91,11 +91,11 @@ func init() { netCount := make(map[uint]int) output.SectionTitle("Results") for idx, result := range doneResults { - totalCount, anmlyCount, err := database.GetMeasurementCounts(ctx.DB, result.Result.ID) + totalCount, anmlyCount, err := database.GetMeasurementCounts(probeCLI.DB(), result.Result.ID) if err != nil { log.WithError(err).Error("failed to list measurement counts") } - testKeys, err := database.GetResultTestKeys(ctx.DB, result.Result.ID) + testKeys, err := database.GetResultTestKeys(probeCLI.DB(), result.Result.ID) if err != nil { log.WithError(err).Error("failed to get testKeys") } diff --git a/internal/cli/onboard/onboard.go b/internal/cli/onboard/onboard.go index dba3591..1b420e4 100644 --- a/internal/cli/onboard/onboard.go +++ b/internal/cli/onboard/onboard.go @@ -138,11 +138,11 @@ func Onboarding(config *config.Config) error { // MaybeOnboarding will run the onboarding process only if the informed consent // config option is set to false func MaybeOnboarding(probe *ooni.Probe) error { - if probe.Config.InformedConsent == false { - if probe.IsBatch == true { + if probe.Config().InformedConsent == false { + if probe.IsBatch() == true { return errors.New("cannot run onboarding in batch mode") } - if err := Onboarding(probe.Config); err != nil { + if err := Onboarding(probe.Config()); err != nil { return errors.Wrap(err, "onboarding") } } @@ -161,20 +161,20 @@ func init() { } if *yes == true { - probe.Config.Lock() - probe.Config.InformedConsent = true - probe.Config.Unlock() + probe.Config().Lock() + probe.Config().InformedConsent = true + probe.Config().Unlock() - if err := probe.Config.Write(); err != nil { + if err := probe.Config().Write(); err != nil { log.WithError(err).Error("failed to write config file") return err } return nil } - if probe.IsBatch == true { + if probe.IsBatch() == true { return errors.New("cannot do onboarding in batch mode") } - return Onboarding(probe.Config) + return Onboarding(probe.Config()) }) } diff --git a/internal/cli/reset/reset.go b/internal/cli/reset/reset.go index 85ceba2..3a9ed8c 100644 --- a/internal/cli/reset/reset.go +++ b/internal/cli/reset/reset.go @@ -20,16 +20,16 @@ func init() { } // We need to first the DB otherwise the DB will be rewritten on close when // we delete the home directory. - err = ctx.DB.Close() + err = ctx.DB().Close() if err != nil { log.WithError(err).Error("failed to close the DB") return err } if *force == true { - os.RemoveAll(ctx.Home) - log.Infof("Deleted %s", ctx.Home) + os.RemoveAll(ctx.Home()) + log.Infof("Deleted %s", ctx.Home()) } else { - log.Infof("Run with --force to delete %s", ctx.Home) + log.Infof("Run with --force to delete %s", ctx.Home()) } return nil diff --git a/internal/cli/rm/rm.go b/internal/cli/rm/rm.go index adb7896..b631d97 100644 --- a/internal/cli/rm/rm.go +++ b/internal/cli/rm/rm.go @@ -65,11 +65,11 @@ func init() { } if *all == true { - return deleteAll(ctx.DB, *yes) + return deleteAll(ctx.DB(), *yes) } if *yes == true { - err = database.DeleteResult(ctx.DB, *resultID) + err = database.DeleteResult(ctx.DB(), *resultID) if err == db.ErrNoMoreRows { return errors.New("result not found") } @@ -85,7 +85,7 @@ func init() { if answer == "false" { return errors.New("canceled by user") } - err = database.DeleteResult(ctx.DB, *resultID) + err = database.DeleteResult(ctx.DB(), *resultID) if err == db.ErrNoMoreRows { return errors.New("result not found") } diff --git a/internal/cli/root/root.go b/internal/cli/root/root.go index d6e059e..bcaadbd 100644 --- a/internal/cli/root/root.go +++ b/internal/cli/root/root.go @@ -19,6 +19,15 @@ var Command = Cmd.Command // Init should be called by all subcommand that care to have a ooni.Context instance var Init func() (*ooni.Probe, error) +// NewProbeCLI is like Init but returns a ooni.ProbeCLI instead. +func NewProbeCLI() (ooni.ProbeCLI, error) { + probeCLI, err := Init() + if err != nil { + return nil, err + } + return probeCLI, nil +} + func init() { configPath := Cmd.Flag("config", "Set a custom config file path").Short('c').String() @@ -57,7 +66,7 @@ func init() { return nil, err } if *isBatch { - probe.IsBatch = true + probe.SetIsBatch(true) } return probe, nil diff --git a/internal/cli/run/run.go b/internal/cli/run/run.go index 95c04a5..65be9e8 100644 --- a/internal/cli/run/run.go +++ b/internal/cli/run/run.go @@ -31,7 +31,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro log.WithError(err).Error("Failed to lookup the location of the probe") return err } - network, err = database.CreateNetwork(ctx.DB, sess) + network, err = database.CreateNetwork(ctx.DB(), sess) if err != nil { log.WithError(err).Error("Failed to create the network row") return err @@ -48,7 +48,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro } log.Debugf("Running test group %s", group.Label) - result, err := database.CreateResult(ctx.DB, ctx.Home, tg, network.ID) + result, err := database.CreateResult(ctx.DB(), ctx.Home(), tg, network.ID) if err != nil { log.Errorf("DB result error: %s", err) return err @@ -69,7 +69,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro } } - if err = result.Finished(ctx.DB); err != nil { + if err = result.Finished(ctx.DB()); err != nil { return err } return nil @@ -102,7 +102,7 @@ func init() { } if *noCollector == true { - probe.Config.Sharing.UploadResults = false + probe.Config().Sharing.UploadResults = false } return nil }) diff --git a/internal/cli/show/show.go b/internal/cli/show/show.go index b2f7d66..bff94a0 100644 --- a/internal/cli/show/show.go +++ b/internal/cli/show/show.go @@ -17,7 +17,7 @@ func init() { log.WithError(err).Error("failed to initialize root context") return err } - msmt, err := database.GetMeasurementJSON(ctx.DB, *msmtID) + msmt, err := database.GetMeasurementJSON(ctx.DB(), *msmtID) if err != nil { log.Errorf("error: %v", err) return err diff --git a/internal/nettests/nettests.go b/internal/nettests/nettests.go index cfadf55..1ca9400 100644 --- a/internal/nettests/nettests.go +++ b/internal/nettests/nettests.go @@ -93,7 +93,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err log.Debug(color.RedString("status.queued")) log.Debug(color.RedString("status.started")) - if c.Probe.Config.Sharing.UploadResults { + if c.Probe.Config().Sharing.UploadResults { if err := exp.OpenReport(); err != nil { log.Debugf( "%s: %s", color.RedString("failure.report_create"), err.Error(), @@ -120,7 +120,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err } msmt, err := database.CreateMeasurement( - c.Probe.DB, reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID, + c.Probe.DB(), reportID, exp.Name(), c.res.MeasurementDir, idx, resultID, urlID, ) if err != nil { return errors.Wrap(err, "failed to create measurement") @@ -133,7 +133,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err measurement, err := exp.Measure(input) if err != nil { log.WithError(err).Debug(color.RedString("failure.measurement")) - if err := c.msmts[idx64].Failed(c.Probe.DB, err.Error()); err != nil { + if err := c.msmts[idx64].Failed(c.Probe.DB(), err.Error()); err != nil { return errors.Wrap(err, "failed to mark measurement as failed") } // Even with a failed measurement, we want to continue. We want to @@ -142,16 +142,16 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err // undertsand what went wrong (censorship? bug? anomaly?). } - if c.Probe.Config.Sharing.UploadResults { + if c.Probe.Config().Sharing.UploadResults { // Implementation note: SubmitMeasurement will fail here if we did fail // to open the report but we still want to continue. There will be a // bit of a spew in the logs, perhaps, but stopping seems less efficient. if err := exp.SubmitAndUpdateMeasurement(measurement); err != nil { log.Debug(color.RedString("failure.measurement_submission")) - if err := c.msmts[idx64].UploadFailed(c.Probe.DB, err.Error()); err != nil { + if err := c.msmts[idx64].UploadFailed(c.Probe.DB(), err.Error()); err != nil { return errors.Wrap(err, "failed to mark upload as failed") } - } else if err := c.msmts[idx64].UploadSucceeded(c.Probe.DB); err != nil { + } else if err := c.msmts[idx64].UploadSucceeded(c.Probe.DB()); err != nil { return errors.Wrap(err, "failed to mark upload as succeeded") } } @@ -159,7 +159,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err if err := exp.SaveMeasurement(measurement, msmt.MeasurementFilePath.String); err != nil { return errors.Wrap(err, "failed to save measurement on disk") } - if err := c.msmts[idx64].Done(c.Probe.DB); err != nil { + if err := c.msmts[idx64].Done(c.Probe.DB()); err != nil { return errors.Wrap(err, "failed to mark measurement as done") } @@ -179,7 +179,7 @@ func (c *Controller) Run(builder *engine.ExperimentBuilder, inputs []string) err continue } log.Debugf("Fetching: %d %v", idx, c.msmts[idx64]) - if err := database.AddTestKeys(c.Probe.DB, c.msmts[idx64], tk); err != nil { + if err := database.AddTestKeys(c.Probe.DB(), c.msmts[idx64], tk); err != nil { return errors.Wrap(err, "failed to add test keys to summary") } } diff --git a/internal/nettests/nettests_test.go b/internal/nettests/nettests_test.go index 7710a15..ac8e8c4 100644 --- a/internal/nettests/nettests_test.go +++ b/internal/nettests/nettests_test.go @@ -38,11 +38,11 @@ func TestRun(t *testing.T) { if err != nil { t.Fatal(err) } - network, err := database.CreateNetwork(probe.DB, sess) + network, err := database.CreateNetwork(probe.DB(), sess) if err != nil { t.Fatal(err) } - res, err := database.CreateResult(probe.DB, probe.Home, "middlebox", network.ID) + res, err := database.CreateResult(probe.DB(), probe.Home(), "middlebox", network.ID) if err != nil { t.Fatal(err) } diff --git a/internal/nettests/web_connectivity.go b/internal/nettests/web_connectivity.go index 7a8e793..23cb969 100644 --- a/internal/nettests/web_connectivity.go +++ b/internal/nettests/web_connectivity.go @@ -19,7 +19,7 @@ func lookupURLs(ctl *Controller, limit int64, categories []string) ([]string, ma for idx, url := range testlist.Result { log.Debugf("Going over URL %d", idx) urlID, err := database.CreateOrUpdateURL( - ctl.Probe.DB, url.URL, url.CategoryCode, url.CountryCode, + ctl.Probe.DB(), url.URL, url.CategoryCode, url.CountryCode, ) if err != nil { log.Error("failed to add to the URL table") @@ -38,8 +38,8 @@ type WebConnectivity struct { // Run starts the test func (n WebConnectivity) Run(ctl *Controller) error { - log.Debugf("Enabled category codes are the following %v", ctl.Probe.Config.Nettests.WebsitesEnabledCategoryCodes) - urls, urlIDMap, err := lookupURLs(ctl, ctl.Probe.Config.Nettests.WebsitesURLLimit, ctl.Probe.Config.Nettests.WebsitesEnabledCategoryCodes) + log.Debugf("Enabled category codes are the following %v", ctl.Probe.Config().Nettests.WebsitesEnabledCategoryCodes) + urls, urlIDMap, err := lookupURLs(ctl, ctl.Probe.Config().Nettests.WebsitesURLLimit, ctl.Probe.Config().Nettests.WebsitesEnabledCategoryCodes) if err != nil { return err } diff --git a/internal/ooni/ooni.go b/internal/ooni/ooni.go index 550073a..1ffe6e2 100644 --- a/internal/ooni/ooni.go +++ b/internal/ooni/ooni.go @@ -19,14 +19,34 @@ import ( "upper.io/db.v3/lib/sqlbuilder" ) +// ProbeCLI is the OONI Probe CLI context. +type ProbeCLI interface { + Config() *config.Config + DB() sqlbuilder.Database + IsBatch() bool + Home() string + TempDir() string + NewProbeEngine() (ProbeEngine, error) +} + +// ProbeEngine is an instance of the OONI Probe engine. +type ProbeEngine interface { + Close() error + MaybeLookupLocation() error + ProbeASNString() string + ProbeCC() string + ProbeIP() string + ProbeNetworkName() string +} + // Probe contains the ooniprobe CLI context. type Probe struct { - Config *config.Config - DB sqlbuilder.Database - IsBatch bool + config *config.Config + db sqlbuilder.Database + isBatch bool - Home string - TempDir string + home string + tempDir string dbPath string configPath string @@ -41,6 +61,36 @@ type Probe struct { softwareVersion string } +// SetIsBatch sets the value of isBatch. +func (p *Probe) SetIsBatch(v bool) { + p.isBatch = v +} + +// IsBatch returns whether we're running in batch mode. +func (p *Probe) IsBatch() bool { + return p.isBatch +} + +// Config returns the configuration +func (p *Probe) Config() *config.Config { + return p.config +} + +// DB returns the database we're using +func (p *Probe) DB() sqlbuilder.Database { + return p.db +} + +// Home returns the home directory. +func (p *Probe) Home() string { + return p.home +} + +// TempDir returns the temporary directory. +func (p *Probe) TempDir() string { + return p.tempDir +} + // IsTerminated checks to see if the isTerminatedAtomicInt is set to a non zero // value and therefore we have received the signal to shutdown the running test func (p *Probe) IsTerminated() bool { @@ -102,37 +152,37 @@ func (p *Probe) MaybeListenForStdinClosed() { func (p *Probe) Init(softwareName, softwareVersion string) error { var err error - if err = MaybeInitializeHome(p.Home); err != nil { + if err = MaybeInitializeHome(p.home); err != nil { return err } if p.configPath != "" { log.Debugf("Reading config file from %s", p.configPath) - p.Config, err = config.ReadConfig(p.configPath) + p.config, err = config.ReadConfig(p.configPath) } else { log.Debug("Reading default config file") - p.Config, err = InitDefaultConfig(p.Home) + p.config, err = InitDefaultConfig(p.home) } if err != nil { return err } - if err = p.Config.MaybeMigrate(); err != nil { + if err = p.config.MaybeMigrate(); err != nil { return errors.Wrap(err, "migrating config") } - p.dbPath = utils.DBDir(p.Home, "main") + p.dbPath = utils.DBDir(p.home, "main") log.Debugf("Connecting to database sqlite3://%s", p.dbPath) db, err := database.Connect(p.dbPath) if err != nil { return err } - p.DB = db + p.db = db tempDir, err := ioutil.TempDir("", "ooni") if err != nil { return errors.Wrap(err, "creating TempDir") } - p.TempDir = tempDir + p.tempDir = tempDir p.softwareName = softwareName p.softwareVersion = softwareVersion @@ -144,31 +194,40 @@ func (p *Probe) Init(softwareName, softwareVersion string) error { // the session when done using it, by calling sess.Close(). func (p *Probe) NewSession() (*engine.Session, error) { kvstore, err := engine.NewFileSystemKVStore( - utils.EngineDir(p.Home), + utils.EngineDir(p.home), ) if err != nil { return nil, errors.Wrap(err, "creating engine's kvstore") } return engine.NewSession(engine.SessionConfig{ - AssetsDir: utils.AssetsDir(p.Home), + AssetsDir: utils.AssetsDir(p.home), KVStore: kvstore, Logger: enginex.Logger, PrivacySettings: model.PrivacySettings{ - IncludeASN: p.Config.Sharing.IncludeASN, + IncludeASN: p.config.Sharing.IncludeASN, IncludeCountry: true, - IncludeIP: p.Config.Sharing.IncludeIP, + IncludeIP: p.config.Sharing.IncludeIP, }, SoftwareName: p.softwareName, SoftwareVersion: p.softwareVersion, - TempDir: p.TempDir, + TempDir: p.tempDir, }) } +// NewProbeEngine creates a new ProbeEngine instance. +func (p *Probe) NewProbeEngine() (ProbeEngine, error) { + sess, err := p.NewSession() + if err != nil { + return nil, err + } + return sess, nil +} + // NewProbe creates a new probe instance. func NewProbe(configPath string, homePath string) *Probe { return &Probe{ - Home: homePath, - Config: &config.Config{}, + home: homePath, + config: &config.Config{}, configPath: configPath, isTerminatedAtomicInt: 0, } diff --git a/internal/oonitest/oonitest.go b/internal/oonitest/oonitest.go new file mode 100644 index 0000000..7fa2abe --- /dev/null +++ b/internal/oonitest/oonitest.go @@ -0,0 +1,126 @@ +// Package oonitest contains code used for testing. +package oonitest + +import ( + "sync" + + "github.com/apex/log" + "github.com/ooni/probe-cli/internal/config" + "github.com/ooni/probe-cli/internal/ooni" + "upper.io/db.v3/lib/sqlbuilder" +) + +// FakeOutput allows to fake the output package. +type FakeOutput struct { + FakeSectionTitle []string + mu sync.Mutex +} + +// SectionTitle writes the section title. +func (fo *FakeOutput) SectionTitle(s string) { + fo.mu.Lock() + defer fo.mu.Unlock() + fo.FakeSectionTitle = append(fo.FakeSectionTitle, s) +} + +// FakeProbeCLI fakes ooni.ProbeCLI +type FakeProbeCLI struct { + FakeConfig *config.Config + FakeDB sqlbuilder.Database + FakeIsBatch bool + FakeHome string + FakeTempDir string + FakeProbeEnginePtr ooni.ProbeEngine + FakeProbeEngineErr error +} + +// Config implements ProbeCLI.Config +func (cli *FakeProbeCLI) Config() *config.Config { + return cli.FakeConfig +} + +// DB implements ProbeCLI.DB +func (cli *FakeProbeCLI) DB() sqlbuilder.Database { + return cli.FakeDB +} + +// IsBatch implements ProbeCLI.IsBatch +func (cli *FakeProbeCLI) IsBatch() bool { + return cli.FakeIsBatch +} + +// Home implements ProbeCLI.Home +func (cli *FakeProbeCLI) Home() string { + return cli.FakeHome +} + +// TempDir implements ProbeCLI.TempDir +func (cli *FakeProbeCLI) TempDir() string { + return cli.FakeTempDir +} + +// NewProbeEngine implements ProbeCLI.NewProbeEngine +func (cli *FakeProbeCLI) NewProbeEngine() (ooni.ProbeEngine, error) { + return cli.FakeProbeEnginePtr, cli.FakeProbeEngineErr +} + +var _ ooni.ProbeCLI = &FakeProbeCLI{} + +// FakeProbeEngine fakes ooni.ProbeEngine +type FakeProbeEngine struct { + FakeClose error + FakeMaybeLookupLocation error + FakeProbeASNString string + FakeProbeCC string + FakeProbeIP string + FakeProbeNetworkName string +} + +// Close implements ProbeEngine.Close +func (eng *FakeProbeEngine) Close() error { + return eng.FakeClose +} + +// MaybeLookupLocation implements ProbeEngine.MaybeLookupLocation +func (eng *FakeProbeEngine) MaybeLookupLocation() error { + return eng.FakeMaybeLookupLocation +} + +// ProbeASNString implements ProbeEngine.ProbeASNString +func (eng *FakeProbeEngine) ProbeASNString() string { + return eng.FakeProbeASNString +} + +// ProbeCC implements ProbeEngine.ProbeCC +func (eng *FakeProbeEngine) ProbeCC() string { + return eng.FakeProbeCC +} + +// ProbeIP implements ProbeEngine.ProbeIP +func (eng *FakeProbeEngine) ProbeIP() string { + return eng.FakeProbeIP +} + +// ProbeNetworkName implements ProbeEngine.ProbeNetworkName +func (eng *FakeProbeEngine) ProbeNetworkName() string { + return eng.FakeProbeNetworkName +} + +var _ ooni.ProbeEngine = &FakeProbeEngine{} + +// FakeLoggerHandler fakes apex.log.Handler. +type FakeLoggerHandler struct { + FakeEntries []*log.Entry + FakeErr error + mu sync.Mutex +} + +// HandleLog implements Handler.HandleLog. +func (handler *FakeLoggerHandler) HandleLog(entry *log.Entry) error { + handler.mu.Lock() + defer handler.mu.Unlock() + handler.FakeEntries = append(handler.FakeEntries, entry) + return handler.FakeErr +} + +var _ log.Handler = &FakeLoggerHandler{}