diff --git a/internal/cli/run/run.go b/internal/cli/run/run.go index d65e019..d142cdc 100644 --- a/internal/cli/run/run.go +++ b/internal/cli/run/run.go @@ -43,7 +43,7 @@ func init() { fmt.Sprintf("msmt-%s-%T.jsonl", nt, time.Now().UTC().Format(time.RFC3339Nano))) - ctl := nettests.NewController(ctx, result, msmtPath) + ctl := nettests.NewController(nt, ctx, result, msmtPath) if err := nt.Run(ctl); err != nil { log.WithError(err).Errorf("Failed to run %s", group.Label) return err diff --git a/internal/database/models.go b/internal/database/models.go index c32bcca..fd98e36 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -8,23 +8,41 @@ import ( "github.com/pkg/errors" ) +// UpdateOne will run the specified update query and check that it only affected one row +func UpdateOne(db *sqlx.DB, query string, arg interface{}) error { + res, err := db.NamedExec(query, arg) + + if err != nil { + return errors.Wrap(err, "updating table") + } + count, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "updating table") + } + if count != 1 { + return errors.New("inconsistent update count") + } + return nil +} + // Measurement model type Measurement struct { ID int64 `db:"id"` Name string `db:"name"` StartTime time.Time `db:"start_time"` - EndTime time.Time `db:"end_time"` + Runtime float64 `db:"runtime"` Summary string `db:"summary"` // XXX this should be JSON - ASN int64 `db:"asn"` + ASN string `db:"asn"` IP string `db:"ip"` CountryCode string `db:"country"` State string `db:"state"` Failure string `db:"failure"` + UploadFailure string `db:"upload_failure"` + Uploaded bool `db:"uploaded"` ReportFilePath string `db:"report_file"` ReportID string `db:"report_id"` Input string `db:"input"` - MeasurementID string `db:"measurement_id"` - ResultID string `db:"result_id"` + ResultID int64 `db:"result_id"` } // SetGeoIPInfo for the Measurement @@ -32,12 +50,83 @@ func (m *Measurement) SetGeoIPInfo() error { return nil } +// Failed writes the error string to the measurement +func (m *Measurement) Failed(db *sqlx.DB, failure string) error { + m.Failure = failure + + err := UpdateOne(db, `UPDATE measurements + SET failure = :failure, state = :state + WHERE id = :id`, m) + if err != nil { + return errors.Wrap(err, "updating measurement") + } + return nil +} + +// Done marks the measurement as completed +func (m *Measurement) Done(db *sqlx.DB) error { + m.State = "done" + + err := UpdateOne(db, `UPDATE measurements + SET state = :state + WHERE id = :id`, m) + if err != nil { + return errors.Wrap(err, "updating measurement") + } + return nil +} + +// UploadFailed writes the error string for the upload failure to the measurement +func (m *Measurement) UploadFailed(db *sqlx.DB, failure string) error { + m.UploadFailure = failure + m.Uploaded = false + + err := UpdateOne(db, `UPDATE measurements + SET upload_failure = :upload_failure + WHERE id = :id`, m) + if err != nil { + return errors.Wrap(err, "updating measurement") + } + return nil +} + +// UploadSucceeded writes the error string for the upload failure to the measurement +func (m *Measurement) UploadSucceeded(db *sqlx.DB) error { + m.Uploaded = true + + err := UpdateOne(db, `UPDATE measurements + SET uploaded = :uploaded + WHERE id = :id`, m) + if err != nil { + return errors.Wrap(err, "updating measurement") + } + return nil +} + +// WriteSummary writes the summary to the measurement +func (m *Measurement) WriteSummary(db *sqlx.DB, summary string) error { + m.Summary = summary + + err := UpdateOne(db, `UPDATE measurements + SET summary = :summary + WHERE id = :id`, m) + if err != nil { + return errors.Wrap(err, "updating measurement") + } + return nil +} + // CreateMeasurement writes the measurement to the database a returns a pointer // to the Measurement -func CreateMeasurement(db *sqlx.DB, m Measurement) (*Measurement, error) { +func CreateMeasurement(db *sqlx.DB, m Measurement, i string) (*Measurement, error) { + // XXX Do we want to have this be part of something else? + m.StartTime = time.Now().UTC() + m.Input = i + m.State = "active" + res, err := db.NamedExec(`INSERT INTO measurements (name, start_time, - summary, asn, ip, country, + asn, ip, country, state, failure, report_file, report_id, input, measurement_id, result_id) @@ -86,19 +175,12 @@ func (r *Result) Finished(db *sqlx.DB) error { r.Runtime = float64(time.Now().Sub(r.started)) / float64(time.Microsecond) r.Done = true - res, err := db.NamedExec(`UPDATE results + err := UpdateOne(db, `UPDATE results SET done = true, runtime = :runtime WHERE id = :id`, r) if err != nil { return errors.Wrap(err, "updating result") } - count, err := res.RowsAffected() - if err != nil { - return errors.Wrap(err, "updating result") - } - if count != 1 { - return errors.New("inconsistent update count") - } return nil } diff --git a/nettests/nettests.go b/nettests/nettests.go index 9ce9262..007f639 100644 --- a/nettests/nettests.go +++ b/nettests/nettests.go @@ -1,6 +1,8 @@ package nettests import ( + "encoding/json" + "github.com/apex/log" "github.com/measurement-kit/go-measurement-kit" ooni "github.com/openobservatory/gooni" @@ -23,11 +25,12 @@ type NettestGroup struct { } // NewController creates a nettest controller -func NewController(ctx *ooni.Context, res *database.Result, msmtPath string) *Controller { +func NewController(nt Nettest, ctx *ooni.Context, res *database.Result, msmtPath string) *Controller { return &Controller{ - ctx, - res, - msmtPath, + Ctx: ctx, + nt: nt, + res: res, + msmtPath: msmtPath, } } @@ -36,6 +39,8 @@ func NewController(ctx *ooni.Context, res *database.Result, msmtPath string) *Co type Controller struct { Ctx *ooni.Context res *database.Result + nt Nettest + msmts map[int64]*database.Measurement msmtPath string } @@ -43,6 +48,16 @@ type Controller struct { func (c *Controller) Init(nt *mk.Nettest) error { log.Debugf("Init: %v", nt) + msmtTemplate := database.Measurement{ + ASN: "", + IP: "", + CountryCode: "", + ReportID: "", + Name: nt.Name, + ResultID: c.res.ID, + ReportFilePath: c.msmtPath, + } + log.Debugf("OutputPath: %s", c.msmtPath) nt.Options = mk.NettestOptions{ IncludeIP: c.Ctx.Config.Sharing.IncludeIP, @@ -84,10 +99,29 @@ func (c *Controller) Init(nt *mk.Nettest) error { nt.On("status.report_created", func(e mk.Event) { log.Debugf("%s", e.Key) + + msmtTemplate.ReportID = e.Value["report_id"].(string) }) nt.On("status.geoip_lookup", func(e mk.Event) { log.Debugf("%s", e.Key) + + msmtTemplate.ASN = e.Value["probe_asn"].(string) + msmtTemplate.IP = e.Value["probe_ip"].(string) + msmtTemplate.CountryCode = e.Value["probe_cc"].(string) + }) + + nt.On("status.measurement_started", func(e mk.Event) { + log.Debugf("%s", e.Key) + + idx := e.Value["idx"].(int64) + input := e.Value["input"].(string) + msmt, err := database.CreateMeasurement(c.Ctx.DB, msmtTemplate, input) + if err != nil { + log.WithError(err).Error("Failed to create measurement") + return + } + c.msmts[idx] = msmt }) nt.On("status.progress", func(e mk.Event) { @@ -102,18 +136,41 @@ func (c *Controller) Init(nt *mk.Nettest) error { nt.On("failure.measurement", func(e mk.Event) { log.Debugf("%s", e.Key) + + idx := e.Value["idx"].(int64) + failure := e.Value["failure"].(string) + c.msmts[idx].Failed(c.Ctx.DB, failure) }) - nt.On("failure.report_submission", func(e mk.Event) { + nt.On("failure.measurement_submission", func(e mk.Event) { log.Debugf("%s", e.Key) + + idx := e.Value["idx"].(int64) + failure := e.Value["failure"].(string) + c.msmts[idx].UploadFailed(c.Ctx.DB, failure) + }) + + nt.On("status.measurement_uploaded", func(e mk.Event) { + log.Debugf("%s", e.Key) + + idx := e.Value["idx"].(int64) + c.msmts[idx].UploadSucceeded(c.Ctx.DB) + }) + + nt.On("status.measurement_done", func(e mk.Event) { + log.Debugf("%s", e.Key) + + idx := e.Value["idx"].(int64) + c.msmts[idx].Done(c.Ctx.DB) }) nt.On("measurement", func(e mk.Event) { - c.OnEntry(e.Value["json_str"].(string)) + idx := e.Value["idx"].(int64) + c.OnEntry(idx, e.Value["json_str"].(string)) }) nt.On("end", func(e mk.Event) { - c.OnEntry(e.Value["json_str"].(string)) + log.Debugf("end") }) return nil @@ -124,9 +181,23 @@ func (c *Controller) OnProgress(perc float64, msg string) { log.Debugf("OnProgress: %f - %s", perc, msg) } +// Entry is an opaque measurement entry +type Entry struct { + TestKeys map[string]interface{} `json:"test_keys"` +} + // OnEntry should be called every time there is a new entry -func (c *Controller) OnEntry(jsonStr string) { +func (c *Controller) OnEntry(idx int64, jsonStr string) { log.Debugf("OnEntry: %s", jsonStr) + + var entry Entry + json.Unmarshal([]byte(jsonStr), &entry) + summary := c.nt.Summary(entry.TestKeys) + summaryBytes, err := json.Marshal(summary) + if err != nil { + log.WithError(err).Error("failed to serialize summary") + } + c.msmts[idx].WriteSummary(c.Ctx.DB, string(summaryBytes)) } // MKStart is the interface for the mk.Nettest Start() function