diff --git a/internal/cli/run/run.go b/internal/cli/run/run.go index 3d7df37..88785c3 100644 --- a/internal/cli/run/run.go +++ b/internal/cli/run/run.go @@ -3,9 +3,7 @@ package run import ( "errors" "fmt" - "path/filepath" "strings" - "time" "github.com/alecthomas/kingpin" "github.com/apex/log" @@ -14,7 +12,6 @@ import ( "github.com/ooni/probe-cli/internal/database" "github.com/ooni/probe-cli/nettests" "github.com/ooni/probe-cli/nettests/groups" - "github.com/ooni/probe-cli/utils" ) func init() { @@ -55,24 +52,13 @@ func init() { return err } - network := database.Network{ - ASN: ctx.Location.ASN, - CountryCode: ctx.Location.CountryCode, - NetworkName: ctx.Location.NetworkName, - IP: ctx.Location.IP, - } - newID, err := ctx.DB.Collection("networks").Insert(network) + network, err := database.CreateNetwork(ctx.DB, ctx.Location) if err != nil { log.WithError(err).Error("Failed to create the network row") return nil } - network.ID = newID.(int64) - result, err := database.CreateResult(ctx.DB, ctx.Home, database.Result{ - TestGroupName: *nettestGroup, - StartTime: time.Now().UTC(), - NetworkID: network.ID, - }) + result, err := database.CreateResult(ctx.DB, ctx.Home, *nettestGroup, network.ID) if err != nil { log.Errorf("DB result error: %s", err) return err @@ -80,16 +66,13 @@ func init() { for _, nt := range group.Nettests { log.Debugf("Running test %T", nt) - msmtPath := filepath.Join(ctx.TempDir, - fmt.Sprintf("msmt-%T-%s.jsonl", nt, - time.Now().UTC().Format(utils.ResultTimestamp))) - - ctl := nettests.NewController(nt, ctx, result, msmtPath) + ctl := nettests.NewController(nt, ctx, result) if err = nt.Run(ctl); err != nil { log.WithError(err).Errorf("Failed to run %s", group.Label) return err } } + if err = result.Finished(ctx.DB, group.Summary); err != nil { return err } diff --git a/internal/database/actions.go b/internal/database/actions.go index 2b3090e..851ad4e 100644 --- a/internal/database/actions.go +++ b/internal/database/actions.go @@ -1,6 +1,7 @@ 
package database import ( + "database/sql" "time" "github.com/apex/log" @@ -117,41 +118,65 @@ func ListResults(db sqlbuilder.Database) ([]*Result, []*Result, error) { // CreateMeasurement writes the measurement to the database a returns a pointer // to the Measurement -func CreateMeasurement(sess sqlbuilder.Database, m Measurement, i string) (*Measurement, error) { - col := sess.Collection("measurements") +func CreateMeasurement(sess sqlbuilder.Database, reportID sql.NullString, testName string, resultID int64, reportFilePath string, urlID sql.NullInt64) (*Measurement, error) { + msmt := Measurement{ + ReportID: reportID, + TestName: testName, + ResultID: resultID, + ReportFilePath: reportFilePath, + URLID: urlID, + // XXX Do we want to have this be part of something else? + StartTime: time.Now().UTC(), + TestKeys: "", + } - // XXX Do we want to have this be part of something else? - m.StartTime = time.Now().UTC() - m.TestKeys = "" - - // XXX insert also the URL and stuff - //m.Input = i - //m.State = "active" - - newID, err := col.Insert(m) + newID, err := sess.Collection("measurements").Insert(msmt) if err != nil { return nil, errors.Wrap(err, "creating measurement") } - m.ID = newID.(int64) - return &m, nil + msmt.ID = newID.(int64) + return &msmt, nil } // CreateResult writes the Result to the database a returns a pointer // to the Result -func CreateResult(sess sqlbuilder.Database, homePath string, r Result) (*Result, error) { - log.Debugf("Creating result %v", r) +func CreateResult(sess sqlbuilder.Database, homePath string, testGroupName string, networkID int64) (*Result, error) { + startTime := time.Now().UTC() - col := sess.Collection("results") - - p, err := utils.MakeResultsDir(homePath, r.TestGroupName, r.StartTime) + p, err := utils.MakeResultsDir(homePath, testGroupName, startTime) if err != nil { return nil, err } - r.MeasurementDir = p - newID, err := col.Insert(r) + + result := Result{ + TestGroupName: testGroupName, + StartTime: startTime, + 
NetworkID: networkID, + } + result.MeasurementDir = p + log.Debugf("Creating result %v", result) + + newID, err := sess.Collection("results").Insert(result) if err != nil { return nil, errors.Wrap(err, "creating result") } - r.ID = newID.(int64) - return &r, nil + result.ID = newID.(int64) + return &result, nil +} + +// CreateNetwork will create a new network in the network table +func CreateNetwork(sess sqlbuilder.Database, location *utils.LocationInfo) (*Network, error) { + network := Network{ + ASN: location.ASN, + CountryCode: location.CountryCode, + NetworkName: location.NetworkName, + IP: location.IP, + } + newID, err := sess.Collection("networks").Insert(network) + if err != nil { + return nil, errors.Wrap(err, "creating network") + } + + network.ID = newID.(int64) + return &network, nil } diff --git a/internal/database/actions_test.go b/internal/database/actions_test.go index 1712562..6cb1ec9 100644 --- a/internal/database/actions_test.go +++ b/internal/database/actions_test.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "os" "testing" - "time" ) func TestMeasurementWorkflow(t *testing.T) { @@ -25,21 +24,18 @@ func TestMeasurementWorkflow(t *testing.T) { if err != nil { t.Error(err) } - result, err := CreateResult(sess, tmpdir, Result{ - TestGroupName: "websites", - StartTime: time.Now().UTC(), - }) + result, err := CreateResult(sess, tmpdir, "websites", 0) if err != nil { t.Fatal(err) } - msmtTemplate := Measurement{ - ReportID: sql.NullString{String: "", Valid: false}, - TestName: "antani", - ResultID: result.ID, - ReportFilePath: tmpdir, - } - m1, err := CreateMeasurement(sess, msmtTemplate, "") + reportID := sql.NullString{String: "", Valid: false} + testName := "antani" + resultID := result.ID + reportFilePath := tmpdir + urlID := sql.NullInt64{Int64: 0, Valid: false} + + m1, err := CreateMeasurement(sess, reportID, testName, resultID, reportFilePath, urlID) if err != nil { t.Fatal(err) } diff --git a/internal/database/models.go b/internal/database/models.go index ffb641a..8cfe1e8 100644 --- 
a/internal/database/models.go +++ b/internal/database/models.go @@ -23,7 +23,7 @@ type Network struct { // URL represents URLs from the testing lists type URL struct { ID int64 `db:"id"` - URL int64 `db:"url"` + URL string `db:"url"` CategoryCode string `db:"category_code"` CountryCode string `db:"country_code"` } @@ -42,7 +42,7 @@ type Measurement struct { UploadFailureMsg sql.NullString `db:"upload_failure_msg,omitempty"` IsRerun bool `db:"is_rerun"` ReportID sql.NullString `db:"report_id,omitempty"` - URLID string `db:"url_id"` // Used to reference URL + URLID sql.NullInt64 `db:"url_id,omitempty"` // Used to reference URL MeasurementID sql.NullInt64 `db:"measurement_id,omitempty"` IsAnomaly sql.NullBool `db:"is_anomaly,omitempty"` // FIXME we likely want to support JSON. See: https://github.com/upper/db/issues/462 diff --git a/nettests/nettests.go b/nettests/nettests.go index 61915e1..398b60f 100644 --- a/nettests/nettests.go +++ b/nettests/nettests.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "time" "github.com/apex/log" "github.com/fatih/color" @@ -24,7 +25,10 @@ type Nettest interface { } // NewController creates a nettest controller -func NewController(nt Nettest, ctx *ooni.Context, res *database.Result, msmtPath string) *Controller { +func NewController(nt Nettest, ctx *ooni.Context, res *database.Result) *Controller { + msmtPath := filepath.Join(ctx.TempDir, + fmt.Sprintf("msmt-%T-%s.jsonl", nt, + time.Now().UTC().Format(utils.ResultTimestamp))) return &Controller{ Ctx: ctx, nt: nt, @@ -36,11 +40,12 @@ func NewController(nt Nettest, ctx *ooni.Context, res *database.Result, msmtPath // Controller is passed to the run method of every Nettest // each nettest instance has one controller type Controller struct { - Ctx *ooni.Context - res *database.Result - nt Nettest - msmts map[int64]*database.Measurement - msmtPath string // XXX maybe we can drop this and just use a temporary file + Ctx *ooni.Context + res *database.Result + nt Nettest + msmts 
map[int64]*database.Measurement + msmtPath string // XXX maybe we can drop this and just use a temporary file + inputIdxMap map[int64]int64 // Used to map mk idx to database id } func getCaBundlePath() string { @@ -51,6 +56,11 @@ func getCaBundlePath() string { return "/etc/ssl/cert.pem" } +func (c *Controller) SetInputIdxMap(inputIdxMap map[int64]int64) error { + c.inputIdxMap = inputIdxMap + return nil +} + // Init should be called once to initialise the nettest func (c *Controller) Init(nt *mk.Nettest) error { log.Debugf("Init: %v", nt) @@ -58,12 +68,11 @@ func (c *Controller) Init(nt *mk.Nettest) error { c.msmts = make(map[int64]*database.Measurement) - msmtTemplate := database.Measurement{ - ReportID: sql.NullString{String: "", Valid: false}, - TestName: nt.Name, - ResultID: c.res.ID, - ReportFilePath: c.msmtPath, - } + // These values are shared by every measurement + reportID := sql.NullString{String: "", Valid: false} + testName := nt.Name + resultID := c.res.ID + reportFilePath := c.msmtPath // This is to workaround homedirs having UTF-8 characters in them. 
// See: https://github.com/measurement-kit/measurement-kit/issues/1635 @@ -157,7 +166,7 @@ func (c *Controller) Init(nt *mk.Nettest) error { nt.On("status.report_created", func(e mk.Event) { log.Debugf("%s", e.Key) - msmtTemplate.ReportID = sql.NullString{String: e.Value.ReportID, Valid: true} + reportID = sql.NullString{String: e.Value.ReportID, Valid: true} }) nt.On("status.geoip_lookup", func(e mk.Event) { @@ -175,7 +184,11 @@ func (c *Controller) Init(nt *mk.Nettest) error { log.Debugf(color.RedString(e.Key)) idx := e.Value.Idx - msmt, err := database.CreateMeasurement(c.Ctx.DB, msmtTemplate, e.Value.Input) + urlID := sql.NullInt64{Int64: 0, Valid: false} + if c.inputIdxMap != nil { + urlID = sql.NullInt64{Int64: c.inputIdxMap[idx], Valid: true} + } + msmt, err := database.CreateMeasurement(c.Ctx.DB, reportID, testName, resultID, reportFilePath, urlID) if err != nil { log.WithError(err).Error("Failed to create measurement") return diff --git a/nettests/websites/web_connectivity.go b/nettests/websites/web_connectivity.go index d149207..b4db6d4 100644 --- a/nettests/websites/web_connectivity.go +++ b/nettests/websites/web_connectivity.go @@ -6,7 +6,9 @@ import ( "io/ioutil" "net/http" + "github.com/apex/log" "github.com/measurement-kit/go-measurement-kit" + "github.com/ooni/probe-cli/internal/database" "github.com/ooni/probe-cli/nettests" "github.com/pkg/errors" ) @@ -14,6 +16,7 @@ import ( // URLInfo contains the URL and the citizenlab category code for that URL type URLInfo struct { URL string `json:"url"` + CountryCode string `json:"country_code"` CategoryCode string `json:"category_code"` } @@ -24,10 +27,11 @@ type URLResponse struct { const orchestrateBaseURL = "https://events.proteus.test.ooni.io" -func lookupURLs(ctl *nettests.Controller) ([]string, error) { +func lookupURLs(ctl *nettests.Controller) ([]string, map[int64]int64, error) { var ( - parsed = new(URLResponse) - urls []string + parsed = new(URLResponse) + urls []string + urlIDMap = make(map[int64]int64) // must be non-nil: the function assigns into it per URL 
) // XXX pass in the configuration for category codes reqURL := fmt.Sprintf("%s/api/v1/urls?probe_cc=%s", @@ -36,22 +40,59 @@ func lookupURLs(ctl *nettests.Controller) ([]string, error) { resp, err := http.Get(reqURL) if err != nil { - return urls, errors.Wrap(err, "failed to perform request") + return urls, urlIDMap, errors.Wrap(err, "failed to perform request") } + defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return urls, errors.Wrap(err, "failed to read response body") + return urls, urlIDMap, errors.Wrap(err, "failed to read response body") } err = json.Unmarshal([]byte(body), &parsed) if err != nil { - return urls, errors.Wrap(err, "failed to parse json") + return urls, urlIDMap, errors.Wrap(err, "failed to parse json") } - for _, url := range parsed.Results { + for idx, url := range parsed.Results { + var urlID int64 + + res, err := ctl.Ctx.DB.Update("urls").Set( + "url", url.URL, + "category_code", url.CategoryCode, + "country_code", url.CountryCode, + ).Where("url = ? AND country_code = ?", url.URL, url.CountryCode).Exec() + + if err != nil { + log.Error("Failed to write to the URL table") + } else { + affected, err := res.RowsAffected() + + if err != nil { + log.Error("Failed to get affected row count") + } else if affected == 0 { + newID, err := ctl.Ctx.DB.Collection("urls").Insert( + database.URL{ + URL: url.URL, + CategoryCode: url.CategoryCode, + CountryCode: url.CountryCode, + }) + if err != nil { + log.Error("Failed to insert into the URLs table") + } + urlID, _ = newID.(int64) // newID may be nil when the insert failed; urlID then stays 0 + } else { + lastID, err := res.LastInsertId() + if err != nil { + log.Error("failed to get URL ID") + } + urlID = lastID // NOTE(review): LastInsertId after an UPDATE may not be the id of the updated row; confirm + } + } + + urlIDMap[int64(idx)] = urlID urls = append(urls, url.URL) } - return urls, nil + return urls, urlIDMap, nil } // WebConnectivity test implementation @@ -63,10 +104,11 @@ func (n WebConnectivity) Run(ctl *nettests.Controller) error { nt := mk.NewNettest("WebConnectivity") ctl.Init(nt) - urls, err := lookupURLs(ctl) + urls, urlIDMap, err := 
lookupURLs(ctl) if err != nil { return err } + ctl.SetInputIdxMap(urlIDMap) nt.Options.Inputs = urls return nt.Run()