diff --git a/internal/cli/run/run.go b/internal/cli/run/run.go index 65be9e8..3817623 100644 --- a/internal/cli/run/run.go +++ b/internal/cli/run/run.go @@ -13,13 +13,20 @@ import ( "github.com/ooni/probe-cli/internal/ooni" ) -func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) error { - if ctx.IsTerminated() == true { +type runNettestGroupConfig struct { + tg string + ctx *ooni.Probe + inputFiles []string + inputs []string +} + +func runNettestGroup(config runNettestGroupConfig) error { + if config.ctx.IsTerminated() == true { log.Debugf("context is terminated, stopping runNettestGroup early") return nil } - sess, err := ctx.NewSession() + sess, err := config.ctx.NewSession() if err != nil { log.WithError(err).Error("Failed to create a measurement session") return err @@ -31,7 +38,7 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro log.WithError(err).Error("Failed to lookup the location of the probe") return err } - network, err = database.CreateNetwork(ctx.DB(), sess) + network, err := database.CreateNetwork(config.ctx.DB(), sess) if err != nil { log.WithError(err).Error("Failed to create the network row") return err @@ -41,35 +48,38 @@ func runNettestGroup(tg string, ctx *ooni.Probe, network *database.Network) erro return err } - group, ok := nettests.NettestGroups[tg] + group, ok := nettests.NettestGroups[config.tg] if !ok { - log.Errorf("No test group named %s", tg) + log.Errorf("No test group named %s", config.tg) return errors.New("invalid test group name") } log.Debugf("Running test group %s", group.Label) - result, err := database.CreateResult(ctx.DB(), ctx.Home(), tg, network.ID) + result, err := database.CreateResult( + config.ctx.DB(), config.ctx.Home(), config.tg, network.ID) if err != nil { log.Errorf("DB result error: %s", err) return err } - ctx.ListenForSignals() - ctx.MaybeListenForStdinClosed() + config.ctx.ListenForSignals() + config.ctx.MaybeListenForStdinClosed() for i, nt := range group.Nettests { - if ctx.IsTerminated() == true { + if config.ctx.IsTerminated() == true { log.Debugf("context is terminated, stopping group.Nettests early") break } log.Debugf("Running test %T", nt) - ctl := nettests.NewController(nt, ctx, result, sess) + ctl := nettests.NewController(nt, config.ctx, result, sess) + ctl.InputFiles = config.inputFiles + ctl.Inputs = config.inputs ctl.SetNettestIndex(i, len(group.Nettests)) if err = nt.Run(ctl); err != nil { log.WithError(err).Errorf("Failed to run %s", group.Label) } } - if err = result.Finished(ctx.DB()); err != nil { + if err = result.Finished(config.ctx.DB()); err != nil { return err } return nil @@ -80,7 +90,6 @@ func init() { var nettestGroupNamesBlue []string var probe *ooni.Probe - var network *database.Network for name := range nettests.NettestGroups { nettestGroupNamesBlue = append(nettestGroupNamesBlue, color.BlueString(name)) @@ -108,30 +117,50 @@ func init() { }) websitesCmd := cmd.Command("websites", "") + inputFile := websitesCmd.Flag("input-file", "File containing input URLs").Strings() + input := websitesCmd.Flag("input", "Test the specified URL").Strings() websitesCmd.Action(func(_ *kingpin.ParseContext) error { - return runNettestGroup("websites", probe, network) + return runNettestGroup(runNettestGroupConfig{ + tg: "websites", + ctx: probe, + inputFiles: *inputFile, + inputs: *input, + }) }) imCmd := cmd.Command("im", "") imCmd.Action(func(_ *kingpin.ParseContext) error { - return runNettestGroup("im", probe, network) + return runNettestGroup(runNettestGroupConfig{ + tg: "im", + ctx: probe, + }) }) performanceCmd := cmd.Command("performance", "") performanceCmd.Action(func(_ *kingpin.ParseContext) error { - return runNettestGroup("performance", probe, network) + return runNettestGroup(runNettestGroupConfig{ + tg: "performance", + ctx: probe, + }) }) middleboxCmd := cmd.Command("middlebox", "") middleboxCmd.Action(func(_ *kingpin.ParseContext) error { - return runNettestGroup("middlebox", probe, network) + return runNettestGroup(runNettestGroupConfig{ + tg: "middlebox", + ctx: probe, + }) }) circumventionCmd := cmd.Command("circumvention", "") circumventionCmd.Action(func(_ *kingpin.ParseContext) error { - return runNettestGroup("circumvention", probe, network) + return runNettestGroup(runNettestGroupConfig{ + tg: "circumvention", + ctx: probe, + }) }) allCmd := cmd.Command("all", "").Default() allCmd.Action(func(_ *kingpin.ParseContext) error { log.Infof("Running %s tests", color.BlueString("all")) for tg := range nettests.NettestGroups { - if err := runNettestGroup(tg, probe, network); err != nil { + group := runNettestGroupConfig{tg: tg, ctx: probe} + if err := runNettestGroup(group); err != nil { log.WithError(err).Errorf("failed to run %s", tg) } } diff --git a/internal/nettests/nettests.go b/internal/nettests/nettests.go index 02886ae..225ca08 100644 --- a/internal/nettests/nettests.go +++ b/internal/nettests/nettests.go @@ -44,6 +44,15 @@ type Controller struct { msmts map[int64]*database.Measurement inputIdxMap map[int64]int64 // Used to map mk idx to database id + // InputFiles optionally contains the names of the input + // files to read inputs from (only for nettests that take + // inputs, of course) + InputFiles []string + + // Inputs contains inputs to be tested. These are specified + // using the command line using the --input flag. + Inputs []string + // numInputs is the total number of inputs numInputs int diff --git a/internal/nettests/web_connectivity.go b/internal/nettests/web_connectivity.go index 7f33e2d..6a27b8b 100644 --- a/internal/nettests/web_connectivity.go +++ b/internal/nettests/web_connectivity.go @@ -12,6 +12,8 @@ func lookupURLs(ctl *Controller, limit int64, categories []string) ([]string, ma inputloader := engine.NewInputLoader(engine.InputLoaderConfig{ InputPolicy: engine.InputRequired, Session: ctl.Session, + SourceFiles: ctl.InputFiles, + StaticInputs: ctl.Inputs, URLCategories: categories, URLLimit: limit, })