diff --git a/pkg/oonimkall/eventemitter.go b/pkg/oonimkall/eventemitter.go index 17b9366..83b19b5 100644 --- a/pkg/oonimkall/eventemitter.go +++ b/pkg/oonimkall/eventemitter.go @@ -1,16 +1,16 @@ package oonimkall -import "time" - // eventEmitter emits event on a channel type eventEmitter struct { disabled map[string]bool + eof <-chan interface{} out chan<- *event } // newEventEmitter creates a new Emitter -func newEventEmitter(disabledEvents []string, out chan<- *event) *eventEmitter { - ee := &eventEmitter{out: out} +func newEventEmitter(disabledEvents []string, out chan<- *event, + eof <-chan interface{}) *eventEmitter { + ee := &eventEmitter{eof: eof, out: out} ee.disabled = make(map[string]bool) for _, eventname := range disabledEvents { ee.disabled[eventname] = true @@ -38,13 +38,10 @@ func (ee *eventEmitter) Emit(key string, value interface{}) { if ee.disabled[key] { return } - const maxSendTimeout = 250 * time.Millisecond - timer := time.NewTimer(maxSendTimeout) - defer timer.Stop() + // Prevent this goroutine from blocking on `ee.out` if the caller + // has already told us it's not going to accept more events. select { case ee.out <- &event{Key: key, Value: value}: - // good, we've been able to send the new event - case <-timer.C: - // oops, we've timed out sending + case <-ee.eof: } } diff --git a/pkg/oonimkall/eventemitter_test.go b/pkg/oonimkall/eventemitter_test.go index a9183d6..e0c18e5 100644 --- a/pkg/oonimkall/eventemitter_test.go +++ b/pkg/oonimkall/eventemitter_test.go @@ -4,15 +4,22 @@ import "testing" func TestDisabledEvents(t *testing.T) { out := make(chan *event) - emitter := newEventEmitter([]string{"log"}, out) + eof := make(chan interface{}) + emitter := newEventEmitter([]string{"log"}, out, eof) go func() { emitter.Emit("log", eventLog{Message: "foo"}) - close(out) + close(eof) }() var count int64 - for ev := range out { - if ev.Key == "log" { - count++ +Loop: + for { + select { + case ev := <-out: + if ev.Key == "log" { + count++ + } + case <-eof: + break Loop } } if count > 0 { @@ -22,18 +29,25 @@ func TestDisabledEvents(t *testing.T) { func TestEmitFailureStartup(t *testing.T) { out := make(chan *event) - emitter := newEventEmitter([]string{}, out) + eof := make(chan interface{}) + emitter := newEventEmitter([]string{}, out, eof) go func() { emitter.EmitFailureStartup("mocked error") - close(out) + close(eof) }() var found bool - for ev := range out { - if ev.Key == "failure.startup" { - evv := ev.Value.(eventFailure) // panic if not castable - if evv.Failure == "mocked error" { - found = true +Loop: + for { + select { + case ev := <-out: + if ev.Key == "failure.startup" { + evv := ev.Value.(eventFailure) // panic if not castable + if evv.Failure == "mocked error" { + found = true + } } + case <-eof: + break Loop } } if !found { @@ -43,18 +57,25 @@ func TestEmitFailureStartup(t *testing.T) { func TestEmitStatusProgress(t *testing.T) { out := make(chan *event) - emitter := newEventEmitter([]string{}, out) + eof := make(chan interface{}) + emitter := newEventEmitter([]string{}, out, eof) go func() { emitter.EmitStatusProgress(0.7, "foo") - close(out) + close(eof) }() var found bool - for ev := range out { - if ev.Key == "status.progress" { - evv := ev.Value.(eventStatusProgress) // panic if not castable - if evv.Message == "foo" && evv.Percentage == 0.7 { - found = true +Loop: + for { + select { + case ev := <-out: + if ev.Key == "status.progress" { + evv := ev.Value.(eventStatusProgress) // panic if not castable + if evv.Message == "foo" && evv.Percentage == 0.7 { + found = true + } } + case <-eof: + break Loop } } if !found { diff --git a/pkg/oonimkall/runner.go b/pkg/oonimkall/runner.go index cb76cd4..316f115 100644 --- a/pkg/oonimkall/runner.go +++ b/pkg/oonimkall/runner.go @@ -38,7 +38,9 @@ const ( // run runs the task specified by settings.Name until completion. This is the // top-level API that should be called by oonimkall. func run(ctx context.Context, settings *settings, out chan<- *event) { - r := newRunner(settings, out) + eof := make(chan interface{}) + defer close(eof) // tell the emitter to not emit anymore. + r := newRunner(settings, out, eof) r.Run(ctx) } @@ -51,9 +53,9 @@ type runner struct { } // newRunner creates a new task runner -func newRunner(settings *settings, out chan<- *event) *runner { +func newRunner(settings *settings, out chan<- *event, eof <-chan interface{}) *runner { return &runner{ - emitter: newEventEmitter(settings.DisabledEvents, out), + emitter: newEventEmitter(settings.DisabledEvents, out, eof), out: out, settings: settings, } diff --git a/pkg/oonimkall/runner_internal_test.go b/pkg/oonimkall/runner_internal_test.go index 9b83af3..0f3f604 100644 --- a/pkg/oonimkall/runner_internal_test.go +++ b/pkg/oonimkall/runner_internal_test.go @@ -44,32 +44,39 @@ func TestRunnerMaybeLookupLocationFailure(t *testing.T) { Version: 1, } seench := make(chan int64) + eof := make(chan interface{}) go func() { var seen int64 - for ev := range out { - switch ev.Key { - case "failure.ip_lookup", "failure.asn_lookup", - "failure.cc_lookup", "failure.resolver_lookup": - seen++ - case "status.progress": - evv := ev.Value.(eventStatusProgress) - if evv.Percentage >= 0.2 { - panic(fmt.Sprintf("too much progress: %+v", ev)) + Loop: + for { + select { + case ev := <-out: + switch ev.Key { + case "failure.ip_lookup", "failure.asn_lookup", + "failure.cc_lookup", "failure.resolver_lookup": + seen++ + case "status.progress": + evv := ev.Value.(eventStatusProgress) + if evv.Percentage >= 0.2 { + panic(fmt.Sprintf("too much progress: %+v", ev)) + } + case "status.queued", "status.started", "status.end": + default: + panic(fmt.Sprintf("unexpected key: %s - %+v", ev.Key, ev.Value)) } - case "status.queued", "status.started", "status.end": - default: - panic(fmt.Sprintf("unexpected key: %s - %+v", ev.Key, ev.Value)) + case <-eof: + break Loop } } seench <- seen }() expected := errors.New("mocked error") - r := newRunner(settings, out) + r := newRunner(settings, out, eof) r.maybeLookupLocation = func(*engine.Session) error { return expected } r.Run(context.Background()) - close(out) + close(eof) if n := <-seench; n != 4 { t.Fatal("unexpected number of events") } diff --git a/pkg/oonimkall/task.go b/pkg/oonimkall/task.go index de59b3b..2262750 100644 --- a/pkg/oonimkall/task.go +++ b/pkg/oonimkall/task.go @@ -59,9 +59,11 @@ import ( // running as subsequent Tasks to reuse the Session connections // created with the OONI probe services backends. type Task struct { - cancel context.CancelFunc - isdone *atomicx.Int64 - out chan *event + cancel context.CancelFunc + isdone *atomicx.Int64 + isstarted chan interface{} // for testing + isstopped chan interface{} // for testing + out chan *event } // StartTask starts an asynchronous task. The input argument is a @@ -74,13 +76,17 @@ func StartTask(input string) (*Task, error) { const bufsiz = 128 // common case: we don't want runner to block ctx, cancel := context.WithCancel(context.Background()) task := &Task{ - cancel: cancel, - isdone: &atomicx.Int64{}, - out: make(chan *event, bufsiz), + cancel: cancel, + isdone: &atomicx.Int64{}, + isstarted: make(chan interface{}), + isstopped: make(chan interface{}), + out: make(chan *event, bufsiz), } go func() { + close(task.isstarted) run(ctx, &settings, task.out) task.out <- nil // signal that we're done w/o closing the channel + close(task.isstopped) }() return task, nil } @@ -107,10 +113,6 @@ func (t *Task) IsDone() bool { return t.isdone.Load() != 0 } -func (t *Task) isRunning() bool { - return !t.IsDone() -} - // Interrupt interrupts the task. func (t *Task) Interrupt() { t.cancel() diff --git a/pkg/oonimkall/task_integration_test.go b/pkg/oonimkall/task_integration_test.go index 4a6b1bf..8ac50b1 100644 --- a/pkg/oonimkall/task_integration_test.go +++ b/pkg/oonimkall/task_integration_test.go @@ -290,7 +290,7 @@ func TestInterruptExampleWithInput(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } - t.Skip("Skipping broken test; see https://github.com/ooni/probe-cli/v3/internal/engine/issues/992") + t.Skip("Skipping broken test; see https://github.com/ooni/probe-engine/issues/992") task, err := StartTask(`{ "assets_dir": "../testdata/oonimkall/assets", "inputs": [ @@ -494,7 +494,9 @@ func TestPrivacyAndScrubbing(t *testing.T) { } } -func TestNonblock(t *testing.T) { +func TestNonblockWithFewEvents(t *testing.T) { + // This test tests whether we won't block for a small + // number of events emitted by the task if testing.Short() { t.Skip("skip test in short mode") } @@ -511,16 +513,16 @@ func TestNonblock(t *testing.T) { if err != nil { t.Fatal(err) } - if !task.isRunning() { - t.Fatal("The runner should be running at this point") - } - // If the task blocks because it emits too much events, this test - // will run forever and will be killed. Because we have room for up - // to 128 events in the buffer, we should hopefully be fine. - for task.isRunning() { - time.Sleep(time.Second) - } + // Wait for the task thread to start + <-task.isstarted + // Wait for the task thread to complete + <-task.isstopped + var count int for !task.IsDone() { task.WaitForNextEvent() + count++ + } + if count < 5 { + t.Fatal("too few events") } }