diff --git a/internal/engine/allexperiments.go b/internal/engine/allexperiments.go index 62619b9..06b899e 100644 --- a/internal/engine/allexperiments.go +++ b/internal/engine/allexperiments.go @@ -198,7 +198,7 @@ var experimentsByName = map[string]func(*Session) *ExperimentBuilder{ )) }, config: &stunreachability.Config{}, - inputPolicy: InputOptional, + inputPolicy: InputStrictlyRequired, } }, diff --git a/internal/engine/experiment/stunreachability/fake_test.go b/internal/engine/experiment/stunreachability/fake_test.go index d658228..47df024 100644 --- a/internal/engine/experiment/stunreachability/fake_test.go +++ b/internal/engine/experiment/stunreachability/fake_test.go @@ -6,6 +6,9 @@ import ( "time" ) +// TODO(bassosimone): we should use internal/netxlite/mocks rather +// than rolling out a custom type private to this package. + type FakeConn struct { ReadError error ReadData []byte diff --git a/internal/engine/experiment/stunreachability/stunreachability.go b/internal/engine/experiment/stunreachability/stunreachability.go index 5dadb1f..e017701 100644 --- a/internal/engine/experiment/stunreachability/stunreachability.go +++ b/internal/engine/experiment/stunreachability/stunreachability.go @@ -5,8 +5,10 @@ package stunreachability import ( "context" + "errors" "fmt" "net" + "net/url" "time" "github.com/ooni/probe-cli/v3/internal/engine/legacy/errorsx" @@ -20,7 +22,7 @@ import ( const ( testName = "stunreachability" - testVersion = "0.2.0" + testVersion = "0.3.0" ) // Config contains the experiment config. @@ -64,6 +66,15 @@ func wrap(err error) error { }.MaybeBuild() } +// errStunMissingInput means that the user did not provide any input +var errStunMissingInput = errors.New("stun: missing input") + +// errStunMissingPortInURL means the URL is missing the port +var errStunMissingPortInURL = errors.New("stun: missing port in URL") + +// errUnsupportedURLScheme means we don't support the URL scheme +var errUnsupportedURLScheme = errors.New("stun: unsupported URL scheme") + // Run implements ExperimentMeasurer.Run. func (m *Measurer) Run( ctx context.Context, sess model.ExperimentSession, @@ -72,7 +83,21 @@ func (m *Measurer) Run( tk := new(TestKeys) measurement.TestKeys = tk registerExtensions(measurement) - if err := wrap(tk.run(ctx, m.config, sess, measurement, callbacks)); err != nil { + input := string(measurement.Input) + if input == "" { + return errStunMissingInput + } + URL, err := url.Parse(input) + if err != nil { + return err + } + if URL.Port() == "" { + return errStunMissingPortInURL + } + if URL.Scheme != "stun" { + return errUnsupportedURLScheme + } + if err := wrap(tk.run(ctx, m.config, sess, measurement, callbacks, URL.Host)); err != nil { s := err.Error() tk.Failure = &s return err @@ -83,12 +108,8 @@ func (m *Measurer) Run( func (tk *TestKeys) run( ctx context.Context, config Config, sess model.ExperimentSession, measurement *model.Measurement, callbacks model.ExperimentCallbacks, + endpoint string, ) error { - const defaultAddress = "stun.l.google.com:19302" - endpoint := string(measurement.Input) - if endpoint == "" { - endpoint = defaultAddress - } callbacks.OnProgress(0, fmt.Sprintf("stunreachability: measuring: %s...", endpoint)) defer callbacks.OnProgress( 1, fmt.Sprintf("stunreachability: measuring: %s... done", endpoint)) diff --git a/internal/engine/experiment/stunreachability/stunreachability_internal_test.go b/internal/engine/experiment/stunreachability/stunreachability_internal_test.go deleted file mode 100644 index 8a5607c..0000000 --- a/internal/engine/experiment/stunreachability/stunreachability_internal_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package stunreachability - -import ( - "context" - "net" - - "github.com/pion/stun" -) - -func (c *Config) SetNewClient( - f func(conn stun.Connection, options ...stun.ClientOption) (*stun.Client, error)) { - c.newClient = f -} - -func (c *Config) SetDialContext( - f func(ctx context.Context, network, address string) (net.Conn, error)) { - c.dialContext = f -} diff --git a/internal/engine/experiment/stunreachability/stunreachability_test.go b/internal/engine/experiment/stunreachability/stunreachability_test.go index 2bbc1ba..4ae8fe9 100644 --- a/internal/engine/experiment/stunreachability/stunreachability_test.go +++ b/internal/engine/experiment/stunreachability/stunreachability_test.go @@ -1,37 +1,36 @@ -package stunreachability_test +package stunreachability import ( "context" "errors" "net" - "os" "strings" "testing" "github.com/apex/log" - "github.com/ooni/probe-cli/v3/internal/engine/experiment/stunreachability" "github.com/ooni/probe-cli/v3/internal/engine/mockable" "github.com/ooni/probe-cli/v3/internal/engine/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/pion/stun" ) +const ( + defaultEndpoint = "stun.ekiga.net:3478" + defaultInput = "stun://" + defaultEndpoint +) + func TestMeasurerExperimentNameVersion(t *testing.T) { - measurer := stunreachability.NewExperimentMeasurer(stunreachability.Config{}) + measurer := NewExperimentMeasurer(Config{}) if measurer.ExperimentName() != "stunreachability" { t.Fatal("unexpected ExperimentName") } - if measurer.ExperimentVersion() != "0.2.0" { + if measurer.ExperimentVersion() != "0.3.0" { t.Fatal("unexpected ExperimentVersion") } } -func TestRun(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") == "true" { - // See https://github.com/ooni/probe-engine/issues/874#issuecomment-679850652 - t.Skip("skipping broken test on GitHub Actions") - } - measurer := stunreachability.NewExperimentMeasurer(stunreachability.Config{}) +func TestRunWithoutInput(t *testing.T) { + measurer := NewExperimentMeasurer(Config{}) measurement := new(model.Measurement) err := measurer.Run( context.Background(), @@ -39,29 +38,60 @@ func TestRun(t *testing.T) { measurement, model.NewPrinterCallbacks(log.Log), ) - if err != nil { - t.Fatal(err) - } - tk := measurement.TestKeys.(*stunreachability.TestKeys) - if tk.Failure != nil { - t.Fatal("expected nil failure here") - } - if tk.Endpoint != "stun.l.google.com:19302" { - t.Fatal("unexpected endpoint") - } - if len(tk.NetworkEvents) <= 0 { - t.Fatal("no network events?!") - } - if len(tk.Queries) <= 0 { - t.Fatal("no DNS queries?!") + if !errors.Is(err, errStunMissingInput) { + t.Fatal("not the error we expected", err) } } -func TestRunCustomInput(t *testing.T) { - input := "stun.ekiga.net:3478" - measurer := stunreachability.NewExperimentMeasurer(stunreachability.Config{}) +func TestRunWithInvalidURL(t *testing.T) { + measurer := NewExperimentMeasurer(Config{}) measurement := new(model.Measurement) - measurement.Input = model.MeasurementTarget(input) + measurement.Input = model.MeasurementTarget("\t") // <- invalid URL + err := measurer.Run( + context.Background(), + &mockable.Session{}, + measurement, + model.NewPrinterCallbacks(log.Log), + ) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("not the error we expected", err) + } +} + +func TestRunWithNoPort(t *testing.T) { + measurer := NewExperimentMeasurer(Config{}) + measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget("stun://stun.ekiga.net") + err := measurer.Run( + context.Background(), + &mockable.Session{}, + measurement, + model.NewPrinterCallbacks(log.Log), + ) + if !errors.Is(err, errStunMissingPortInURL) { + t.Fatal("not the error we expected", err) + } +} + +func TestRunWithUnsupportedURLScheme(t *testing.T) { + measurer := NewExperimentMeasurer(Config{}) + measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget("https://stun.ekiga.net:3478") + err := measurer.Run( + context.Background(), + &mockable.Session{}, + measurement, + model.NewPrinterCallbacks(log.Log), + ) + if !errors.Is(err, errUnsupportedURLScheme) { + t.Fatal("not the error we expected", err) + } +} + +func TestRunWithInput(t *testing.T) { + measurer := NewExperimentMeasurer(Config{}) + measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget(defaultInput) err := measurer.Run( context.Background(), &mockable.Session{}, @@ -71,11 +101,11 @@ func TestRunCustomInput(t *testing.T) { if err != nil { t.Fatal(err) } - tk := measurement.TestKeys.(*stunreachability.TestKeys) + tk := measurement.TestKeys.(*TestKeys) if tk.Failure != nil { t.Fatal("expected nil failure here") } - if tk.Endpoint != input { + if tk.Endpoint != defaultEndpoint { t.Fatal("unexpected endpoint") } if len(tk.NetworkEvents) <= 0 { @@ -89,22 +119,23 @@ func TestRunCustomInput(t *testing.T) { func TestCancelledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail everything - measurer := stunreachability.NewExperimentMeasurer(stunreachability.Config{}) + measurer := NewExperimentMeasurer(Config{}) measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget(defaultInput) err := measurer.Run( ctx, &mockable.Session{}, measurement, model.NewPrinterCallbacks(log.Log), ) - if err.Error() != "interrupted" { - t.Fatal("not the error we expected") + if !errors.Is(err, context.Canceled) { + t.Fatal("not the error we expected", err) } - tk := measurement.TestKeys.(*stunreachability.TestKeys) + tk := measurement.TestKeys.(*TestKeys) if *tk.Failure != "interrupted" { t.Fatal("expected different failure here") } - if tk.Endpoint != "stun.l.google.com:19302" { + if tk.Endpoint != defaultEndpoint { t.Fatal("unexpected endpoint") } if len(tk.NetworkEvents) <= 0 { @@ -117,20 +148,20 @@ func TestCancelledContext(t *testing.T) { if err != nil { t.Fatal(err) } - if _, ok := sk.(stunreachability.SummaryKeys); !ok { + if _, ok := sk.(SummaryKeys); !ok { t.Fatal("invalid type for summary keys") } } func TestNewClientFailure(t *testing.T) { - config := &stunreachability.Config{} + config := &Config{} expected := errors.New("mocked error") - config.SetNewClient( - func(conn stun.Connection, options ...stun.ClientOption) (*stun.Client, error) { - return nil, expected - }) - measurer := stunreachability.NewExperimentMeasurer(*config) + config.newClient = func(conn stun.Connection, options ...stun.ClientOption) (*stun.Client, error) { + return nil, expected + } + measurer := NewExperimentMeasurer(*config) measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget(defaultInput) err := measurer.Run( context.Background(), &mockable.Session{}, @@ -140,11 +171,11 @@ func TestNewClientFailure(t *testing.T) { if !errors.Is(err, expected) { t.Fatal("not the error we expected") } - tk := measurement.TestKeys.(*stunreachability.TestKeys) + tk := measurement.TestKeys.(*TestKeys) if !strings.HasPrefix(*tk.Failure, "unknown_failure") { t.Fatal("expected different failure here") } - if tk.Endpoint != "stun.l.google.com:19302" { + if tk.Endpoint != defaultEndpoint { t.Fatal("unexpected endpoint") } if len(tk.NetworkEvents) <= 0 { @@ -156,15 +187,15 @@ func TestNewClientFailure(t *testing.T) { } func TestStartFailure(t *testing.T) { - config := &stunreachability.Config{} + config := &Config{} expected := errors.New("mocked error") - config.SetDialContext( - func(ctx context.Context, network, address string) (net.Conn, error) { - conn := &stunreachability.FakeConn{WriteError: expected} - return conn, nil - }) - measurer := stunreachability.NewExperimentMeasurer(*config) + config.dialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + conn := &FakeConn{WriteError: expected} + return conn, nil + } + measurer := NewExperimentMeasurer(*config) measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget(defaultInput) err := measurer.Run( context.Background(), &mockable.Session{}, @@ -174,11 +205,11 @@ func TestStartFailure(t *testing.T) { if !errors.Is(err, expected) { t.Fatal("not the error we expected") } - tk := measurement.TestKeys.(*stunreachability.TestKeys) + tk := measurement.TestKeys.(*TestKeys) if !strings.HasPrefix(*tk.Failure, "unknown_failure") { t.Fatal("expected different failure here") } - if tk.Endpoint != "stun.l.google.com:19302" { + if tk.Endpoint != defaultEndpoint { t.Fatal("unexpected endpoint") } // We're bypassing normal network with custom dial function @@ -194,15 +225,15 @@ func TestReadFailure(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } - config := &stunreachability.Config{} + config := &Config{} expected := errors.New("mocked error") - config.SetDialContext( - func(ctx context.Context, network, address string) (net.Conn, error) { - conn := &stunreachability.FakeConn{ReadError: expected} - return conn, nil - }) - measurer := stunreachability.NewExperimentMeasurer(*config) + config.dialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + conn := &FakeConn{ReadError: expected} + return conn, nil + } + measurer := NewExperimentMeasurer(*config) measurement := new(model.Measurement) + measurement.Input = model.MeasurementTarget(defaultInput) err := measurer.Run( context.Background(), &mockable.Session{}, @@ -212,11 +243,11 @@ func TestReadFailure(t *testing.T) { if !errors.Is(err, stun.ErrTransactionTimeOut) { t.Fatal("not the error we expected") } - tk := measurement.TestKeys.(*stunreachability.TestKeys) + tk := measurement.TestKeys.(*TestKeys) if *tk.Failure != netxlite.FailureGenericTimeoutError { t.Fatal("expected different failure here") } - if tk.Endpoint != "stun.l.google.com:19302" { + if tk.Endpoint != defaultEndpoint { t.Fatal("unexpected endpoint") } // We're bypassing normal network with custom dial function @@ -229,13 +260,13 @@ func TestReadFailure(t *testing.T) { } func TestSummaryKeysGeneric(t *testing.T) { - measurement := &model.Measurement{TestKeys: &stunreachability.TestKeys{}} - m := &stunreachability.Measurer{} + measurement := &model.Measurement{TestKeys: &TestKeys{}} + m := &Measurer{} osk, err := m.GetSummaryKeys(measurement) if err != nil { t.Fatal(err) } - sk := osk.(stunreachability.SummaryKeys) + sk := osk.(SummaryKeys) if sk.IsAnomaly { t.Fatal("invalid isAnomaly") }