diff --git a/internal/archival/quic_test.go b/internal/archival/quic_test.go index f7a0341..caabccf 100644 --- a/internal/archival/quic_test.go +++ b/internal/archival/quic_test.go @@ -1,3 +1,5 @@ +//go:build !go1.18 + package archival import ( diff --git a/internal/archival/quic_test_go118.go b/internal/archival/quic_test_go118.go new file mode 100644 index 0000000..77a5ca3 --- /dev/null +++ b/internal/archival/quic_test_go118.go @@ -0,0 +1,465 @@ +//go:build go1.18 + +package archival + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/lucas-clemente/quic-go" + "github.com/marten-seemann/qtls-go1-18" // it's annoying to depend on that + "github.com/ooni/probe-cli/v3/internal/fakefill" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +func TestSaverWriteTo(t *testing.T) { + // newAddr creates an new net.Addr for testing. + newAddr := func(endpoint string) net.Addr { + return &mocks.Addr{ + MockString: func() string { + return endpoint + }, + MockNetwork: func() string { + return "udp" + }, + } + } + + // newConn is a helper function for creating a new connection. + newConn := func(numBytes int, err error) model.UDPLikeConn { + return &mocks.UDPLikeConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + time.Sleep(time.Microsecond) + return numBytes, err + }, + } + } + + t.Run("on success", func(t *testing.T) { + const mockedEndpoint = "8.8.4.4:443" + const mockedNumBytes = 128 + addr := newAddr(mockedEndpoint) + conn := newConn(mockedNumBytes, nil) + saver := NewSaver() + v := &SingleNetworkEventValidator{ + ExpectedCount: mockedNumBytes, + ExpectedErr: nil, + ExpectedNetwork: "udp", + ExpectedOp: netxlite.WriteToOperation, + ExpectedEpnt: mockedEndpoint, + Saver: saver, + } + buf := make([]byte, 1024) + count, err := saver.WriteTo(conn, buf, addr) + if err != nil { + t.Fatal(err) + } + if count != mockedNumBytes { + t.Fatal("invalid count") + } + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) + + t.Run("on failure", func(t *testing.T) { + const mockedEndpoint = "8.8.4.4:443" + mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF) + addr := newAddr(mockedEndpoint) + conn := newConn(0, mockedError) + saver := NewSaver() + v := &SingleNetworkEventValidator{ + ExpectedCount: 0, + ExpectedErr: mockedError, + ExpectedNetwork: "udp", + ExpectedOp: netxlite.WriteToOperation, + ExpectedEpnt: mockedEndpoint, + Saver: saver, + } + buf := make([]byte, 1024) + count, err := saver.WriteTo(conn, buf, addr) + if !errors.Is(err, mockedError) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("invalid count") + } + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +func TestSaverReadFrom(t *testing.T) { + // newAddr creates an new net.Addr for testing. + newAddr := func(endpoint string) net.Addr { + return &mocks.Addr{ + MockString: func() string { + return endpoint + }, + MockNetwork: func() string { + return "udp" + }, + } + } + + // newConn is a helper function for creating a new connection. + newConn := func(numBytes int, addr net.Addr, err error) model.UDPLikeConn { + return &mocks.UDPLikeConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + time.Sleep(time.Microsecond) + return numBytes, addr, err + }, + } + } + + t.Run("on success", func(t *testing.T) { + const mockedEndpoint = "8.8.4.4:443" + const mockedNumBytes = 128 + expectedAddr := newAddr(mockedEndpoint) + conn := newConn(mockedNumBytes, expectedAddr, nil) + saver := NewSaver() + v := &SingleNetworkEventValidator{ + ExpectedCount: mockedNumBytes, + ExpectedErr: nil, + ExpectedNetwork: "udp", + ExpectedOp: netxlite.ReadFromOperation, + ExpectedEpnt: mockedEndpoint, + Saver: saver, + } + buf := make([]byte, 1024) + count, addr, err := saver.ReadFrom(conn, buf) + if err != nil { + t.Fatal(err) + } + if expectedAddr.Network() != addr.Network() { + t.Fatal("invalid addr.Network") + } + if expectedAddr.String() != addr.String() { + t.Fatal("invalid addr.String") + } + if count != mockedNumBytes { + t.Fatal("invalid count") + } + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) + + t.Run("on failure", func(t *testing.T) { + mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF) + conn := newConn(0, nil, mockedError) + saver := NewSaver() + v := &SingleNetworkEventValidator{ + ExpectedCount: 0, + ExpectedErr: mockedError, + ExpectedNetwork: "udp", + ExpectedOp: netxlite.ReadFromOperation, + ExpectedEpnt: "", + Saver: saver, + } + buf := make([]byte, 1024) + count, addr, err := saver.ReadFrom(conn, buf) + if !errors.Is(err, mockedError) { + t.Fatal(err) + } + if addr != nil { + t.Fatal("invalid addr") + } + if count != 0 { + t.Fatal("invalid count") + } + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) +} + +func TestSaverQUICDialContext(t *testing.T) { + // newQUICDialer creates a new QUICDialer for testing. + newQUICDialer := func(qconn quic.EarlyConnection, err error) model.QUICDialer { + return &mocks.QUICDialer{ + MockDialContext: func( + ctx context.Context, network, address string, tlsConfig *tls.Config, + quicConfig *quic.Config) (quic.EarlyConnection, error) { + time.Sleep(time.Microsecond) + return qconn, err + }, + } + } + + // newQUICConnection creates a new quic.EarlyConnection for testing. + newQUICConnection := func(handshakeComplete context.Context, state tls.ConnectionState) quic.EarlyConnection { + return &mocks.QUICEarlyConnection{ + MockHandshakeComplete: func() context.Context { + return handshakeComplete + }, + MockConnectionState: func() quic.ConnectionState { + return quic.ConnectionState{ + TLS: qtls.ConnectionStateWith0RTT{ + ConnectionState: state, + }, + } + }, + MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error { + return nil + }, + } + } + + t.Run("on success", func(t *testing.T) { + handshakeCtx := context.Background() + handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx) + handshakeCancel() // simulate a completed handshake + const expectedNetwork = "udp" + const mockedEndpoint = "8.8.4.4:443" + saver := NewSaver() + var peerCerts [][]byte + ff := &fakefill.Filler{} + ff.Fill(&peerCerts) + if len(peerCerts) < 1 { + t.Fatal("did not fill peerCerts") + } + v := &SingleQUICTLSHandshakeValidator{ + ExpectedALPN: []string{"h3"}, + ExpectedSNI: "dns.google", + ExpectedSkipVerify: true, + // + ExpectedCipherSuite: tls.TLS_AES_128_GCM_SHA256, + ExpectedNegotiatedProtocol: "h3", + ExpectedPeerCerts: peerCerts, + ExpectedVersion: tls.VersionTLS13, + // + ExpectedNetwork: "quic", + ExpectedRemoteAddr: mockedEndpoint, + // + QUICConfig: &quic.Config{}, + // + ExpectedFailure: nil, + Saver: saver, + } + qconn := newQUICConnection(handshakeCtx, v.NewTLSConnectionState()) + dialer := newQUICDialer(qconn, nil) + ctx := context.Background() + qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork, + mockedEndpoint, v.NewTLSConfig(), v.QUICConfig) + if err != nil { + t.Fatal(err) + } + if qconn == nil { + t.Fatal("expected nil qconn") + } + qconn.CloseWithError(0, "") + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) + + t.Run("on handshake timeout", func(t *testing.T) { + handshakeCtx := context.Background() + handshakeCtx, handshakeCancel := context.WithCancel(handshakeCtx) + defer handshakeCancel() + const expectedNetwork = "udp" + const mockedEndpoint = "8.8.4.4:443" + saver := NewSaver() + v := &SingleQUICTLSHandshakeValidator{ + ExpectedALPN: []string{"h3"}, + ExpectedSNI: "dns.google", + ExpectedSkipVerify: true, + // + ExpectedCipherSuite: 0, + ExpectedNegotiatedProtocol: "", + ExpectedPeerCerts: nil, + ExpectedVersion: 0, + // + ExpectedNetwork: "quic", + ExpectedRemoteAddr: mockedEndpoint, + // + QUICConfig: &quic.Config{}, + // + ExpectedFailure: context.DeadlineExceeded, + Saver: saver, + } + qconn := newQUICConnection(handshakeCtx, tls.ConnectionState{}) + dialer := newQUICDialer(qconn, nil) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Microsecond) + defer cancel() + qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork, + mockedEndpoint, v.NewTLSConfig(), v.QUICConfig) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal("unexpected error") + } + if qconn != nil { + t.Fatal("expected nil connection") + } + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) + + t.Run("on other error", func(t *testing.T) { + mockedError := netxlite.NewTopLevelGenericErrWrapper(io.EOF) + const expectedNetwork = "udp" + const mockedEndpoint = "8.8.4.4:443" + saver := NewSaver() + v := &SingleQUICTLSHandshakeValidator{ + ExpectedALPN: []string{"h3"}, + ExpectedSNI: "dns.google", + ExpectedSkipVerify: true, + // + ExpectedCipherSuite: 0, + ExpectedNegotiatedProtocol: "", + ExpectedPeerCerts: nil, + ExpectedVersion: 0, + // + ExpectedNetwork: "quic", + ExpectedRemoteAddr: mockedEndpoint, + // + QUICConfig: &quic.Config{}, + // + ExpectedFailure: mockedError, + Saver: saver, + } + dialer := newQUICDialer(nil, mockedError) + ctx := context.Background() + qconn, err := saver.QUICDialContext(ctx, dialer, expectedNetwork, + mockedEndpoint, v.NewTLSConfig(), v.QUICConfig) + if !errors.Is(err, mockedError) { + t.Fatal("unexpected error") + } + if qconn != nil { + t.Fatal("expected nil connection") + } + if err := v.Validate(); err != nil { + t.Fatal(err) + } + }) + + // TODO(bassosimone): here we're not testing the case in which + // the certificate is invalid for the required SNI. + // + // We need first to figure out whether this is what happens + // when we validate for QUIC in such cases. If that's the case + // indeed, then we can write the tests. + + t.Run("on x509.HostnameError", func(t *testing.T) { + t.Skip("test not implemented") + }) + + t.Run("on x509.UnknownAuthorityError", func(t *testing.T) { + t.Skip("test not implemented") + }) + + t.Run("on x509.CertificateInvalidError", func(t *testing.T) { + t.Skip("test not implemented") + }) +} + +type SingleQUICTLSHandshakeValidator struct { + // related to the tls.Config + ExpectedALPN []string + ExpectedSNI string + ExpectedSkipVerify bool + + // related to the tls.ConnectionState + ExpectedCipherSuite uint16 + ExpectedNegotiatedProtocol string + ExpectedPeerCerts [][]byte + ExpectedVersion uint16 + + // related to the mocked conn (TLS) / dial params (QUIC) + ExpectedNetwork string + ExpectedRemoteAddr string + + // tells us whether we're using QUIC + QUICConfig *quic.Config + + // other fields + ExpectedFailure error + Saver *Saver +} + +func (v *SingleQUICTLSHandshakeValidator) NewTLSConfig() *tls.Config { + return &tls.Config{ + NextProtos: v.ExpectedALPN, + ServerName: v.ExpectedSNI, + InsecureSkipVerify: v.ExpectedSkipVerify, + } +} + +func (v *SingleQUICTLSHandshakeValidator) NewTLSConnectionState() tls.ConnectionState { + var state tls.ConnectionState + if v.ExpectedCipherSuite != 0 { + state.CipherSuite = v.ExpectedCipherSuite + } + if v.ExpectedNegotiatedProtocol != "" { + state.NegotiatedProtocol = v.ExpectedNegotiatedProtocol + } + for _, cert := range v.ExpectedPeerCerts { + state.PeerCertificates = append(state.PeerCertificates, &x509.Certificate{ + Raw: cert, + }) + } + if v.ExpectedVersion != 0 { + state.Version = v.ExpectedVersion + } + return state +} + +func (v *SingleQUICTLSHandshakeValidator) Validate() error { + trace := v.Saver.MoveOutTrace() + var entries []*QUICTLSHandshakeEvent + if v.QUICConfig != nil { + entries = trace.QUICHandshake + } else { + entries = trace.TLSHandshake + } + if len(entries) != 1 { + return errors.New("expected to see a single entry") + } + entry := entries[0] + if diff := cmp.Diff(entry.ALPN, v.ExpectedALPN); diff != "" { + return errors.New(diff) + } + if entry.CipherSuite != netxlite.TLSCipherSuiteString(v.ExpectedCipherSuite) { + return errors.New("unexpected .CipherSuite") + } + if !errors.Is(entry.Failure, v.ExpectedFailure) { + return errors.New("unexpected .Failure") + } + if !entry.Finished.After(entry.Started) { + return errors.New(".Finished is not after .Started") + } + if entry.NegotiatedProto != v.ExpectedNegotiatedProtocol { + return errors.New("unexpected .NegotiatedProto") + } + if entry.Network != v.ExpectedNetwork { + return errors.New("unexpected .Network") + } + if diff := cmp.Diff(entry.PeerCerts, v.ExpectedPeerCerts); diff != "" { + return errors.New("unexpected .PeerCerts") + } + if entry.RemoteAddr != v.ExpectedRemoteAddr { + return errors.New("unexpected .RemoteAddr") + } + if entry.SNI != v.ExpectedSNI { + return errors.New("unexpected .ServerName") + } + if entry.SkipVerify != v.ExpectedSkipVerify { + return errors.New("unexpected .SkipVerify") + } + if entry.TLSVersion != netxlite.TLSVersionString(v.ExpectedVersion) { + return errors.New("unexpected .Version") + } + return nil +} diff --git a/internal/engine/allexperiments.go b/internal/engine/allexperiments.go index d9e7702..8e31b6d 100644 --- a/internal/engine/allexperiments.go +++ b/internal/engine/allexperiments.go @@ -5,6 +5,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/experiment/dash" "github.com/ooni/probe-cli/v3/internal/engine/experiment/dnscheck" + "github.com/ooni/probe-cli/v3/internal/engine/experiment/dnsping" "github.com/ooni/probe-cli/v3/internal/engine/experiment/example" "github.com/ooni/probe-cli/v3/internal/engine/experiment/fbmessenger" "github.com/ooni/probe-cli/v3/internal/engine/experiment/hhfm" @@ -57,6 +58,18 @@ var experimentsByName = map[string]func(*Session) *ExperimentBuilder{ } }, + "dnsping": func(session *Session) *ExperimentBuilder { + return &ExperimentBuilder{ + build: func(config interface{}) *Experiment { + return NewExperiment(session, dnsping.NewExperimentMeasurer( + *config.(*dnsping.Config), + )) + }, + config: &dnsping.Config{}, + inputPolicy: InputOrStaticDefault, + } + }, + "example": func(session *Session) *ExperimentBuilder { return &ExperimentBuilder{ build: func(config interface{}) *Experiment { diff --git a/internal/engine/experiment/dnsping/dnsping.go b/internal/engine/experiment/dnsping/dnsping.go new file mode 100644 index 0000000..bc77b94 --- /dev/null +++ b/internal/engine/experiment/dnsping/dnsping.go @@ -0,0 +1,201 @@ +// Package dnsping is the experimental dnsping experiment. +// +// See https://github.com/ooni/spec/blob/master/nettests/ts-035-dnsping.md. +package dnsping + +import ( + "context" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/ooni/probe-cli/v3/internal/measurex" + "github.com/ooni/probe-cli/v3/internal/model" +) + +const ( + testName = "dnsping" + testVersion = "0.1.0" +) + +// Config contains the experiment configuration. +type Config struct { + // Delay is the delay between each repetition (in milliseconds). + Delay int64 `ooni:"number of milliseconds to wait before sending each ping"` + + // Domains is the space-separated list of domains to measure. + Domains string `ooni:"space-separated list of domains to measure"` + + // Repetitions is the number of repetitions for each ping. + Repetitions int64 `ooni:"number of times to repeat the measurement"` +} + +func (c *Config) delay() time.Duration { + if c.Delay > 0 { + return time.Duration(c.Delay) * time.Millisecond + } + return time.Second +} + +func (c Config) repetitions() int64 { + if c.Repetitions > 0 { + return c.Repetitions + } + return 10 +} + +func (c Config) domains() string { + if c.Domains != "" { + return c.Domains + } + return "edge-chat.instagram.com example.com" +} + +// TestKeys contains the experiment results. +type TestKeys struct { + Pings []*SinglePing `json:"pings"` +} + +// TODO(bassosimone): save more data once the dnsping improvements at +// github.com/bassosimone/websteps-illustrated contains have been merged +// into this repository. When this happens, we'll able to save raw +// queries and network events of each individual query. + +// SinglePing contains the results of a single ping. +type SinglePing struct { + Queries []*measurex.ArchivalDNSLookupEvent `json:"queries"` +} + +// Measurer performs the measurement. +type Measurer struct { + config Config +} + +// ExperimentName implements ExperimentMeasurer.ExperiExperimentName. +func (m *Measurer) ExperimentName() string { + return testName +} + +// ExperimentVersion implements ExperimentMeasurer.ExperimentVersion. +func (m *Measurer) ExperimentVersion() string { + return testVersion +} + +var ( + // errNoInputProvided indicates you didn't provide any input + errNoInputProvided = errors.New("not input provided") + + // errInputIsNotAnURL indicates that input is not an URL + errInputIsNotAnURL = errors.New("input is not an URL") + + // errInvalidScheme indicates that the scheme is invalid + errInvalidScheme = errors.New("scheme must be udp") + + // errMissingPort indicates that there is no port. + errMissingPort = errors.New("the URL must include a port") +) + +// Run implements ExperimentMeasurer.Run. +func (m *Measurer) Run( + ctx context.Context, + sess model.ExperimentSession, + measurement *model.Measurement, + callbacks model.ExperimentCallbacks, +) error { + if measurement.Input == "" { + return errNoInputProvided + } + parsed, err := url.Parse(string(measurement.Input)) + if err != nil { + return fmt.Errorf("%w: %s", errInputIsNotAnURL, err.Error()) + } + if parsed.Scheme != "udp" { + return errInvalidScheme + } + if parsed.Port() == "" { + return errMissingPort + } + tk := new(TestKeys) + measurement.TestKeys = tk + mxmx := measurex.NewMeasurerWithDefaultSettings() + out := make(chan *measurex.DNSMeasurement) + domains := strings.Split(m.config.domains(), " ") + for _, domain := range domains { + go m.dnsPingLoop(ctx, mxmx, parsed.Host, domain, out) + } + // The following multiplication could overflow but we're always using small + // numbers so it's fine for us not to bother with checking for that. + // + // We emit two results (A and AAAA) for each domain and repetition. + numResults := int(m.config.repetitions()) * len(domains) * 2 + for len(tk.Pings) < numResults { + meas := <-out + // TODO(bassosimone): when we merge the improvements at + // https://github.com/bassosimone/websteps-illustrated it + // will become unnecessary to split with query type + // as we're doing below. + queries := measurex.NewArchivalDNSLookupEventList(meas.LookupHost) + tk.Pings = append(tk.Pings, m.onlyQueryWithType(queries, "A")...) + tk.Pings = append(tk.Pings, m.onlyQueryWithType(queries, "AAAA")...) + } + return nil // return nil so we always submit the measurement +} + +// onlyQueryWithType returns only the queries with the given type. +func (m *Measurer) onlyQueryWithType( + in []*measurex.ArchivalDNSLookupEvent, kind string) (out []*SinglePing) { + for _, query := range in { + if query.QueryType == kind { + out = append(out, &SinglePing{ + Queries: []*measurex.ArchivalDNSLookupEvent{query}, + }) + } + } + return +} + +// dnsPingLoop sends all the ping requests and emits the results onto the out channel. +func (m *Measurer) dnsPingLoop(ctx context.Context, mxmx *measurex.Measurer, + address string, domain string, out chan<- *measurex.DNSMeasurement) { + ticker := time.NewTicker(m.config.delay()) + defer ticker.Stop() + for i := int64(0); i < m.config.repetitions(); i++ { + go m.dnsPingAsync(ctx, mxmx, address, domain, out) + <-ticker.C + } +} + +// dnsPingAsync performs a DNS ping and emits the result onto the out channel. +func (m *Measurer) dnsPingAsync(ctx context.Context, mxmx *measurex.Measurer, + address string, domain string, out chan<- *measurex.DNSMeasurement) { + out <- m.dnsRoundTrip(ctx, mxmx, address, domain) +} + +// dnsRoundTrip performs a round trip and returns the results to the caller. +func (m *Measurer) dnsRoundTrip(ctx context.Context, mxmx *measurex.Measurer, + address string, domain string) *measurex.DNSMeasurement { + // TODO(bassosimone): make the timeout user-configurable + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + return mxmx.LookupHostUDP(ctx, domain, address) +} + +// NewExperimentMeasurer creates a new ExperimentMeasurer. +func NewExperimentMeasurer(config Config) model.ExperimentMeasurer { + return &Measurer{config: config} +} + +// SummaryKeys contains summary keys for this experiment. +// +// Note that this structure is part of the ABI contract with ooniprobe +// therefore we should be careful when changing it. +type SummaryKeys struct { + IsAnomaly bool `json:"-"` +} + +// GetSummaryKeys implements model.ExperimentMeasurer.GetSummaryKeys. +func (m Measurer) GetSummaryKeys(measurement *model.Measurement) (interface{}, error) { + return SummaryKeys{IsAnomaly: false}, nil +} diff --git a/internal/engine/experiment/dnsping/dnsping_test.go b/internal/engine/experiment/dnsping/dnsping_test.go new file mode 100644 index 0000000..8f8702f --- /dev/null +++ b/internal/engine/experiment/dnsping/dnsping_test.go @@ -0,0 +1,159 @@ +package dnsping + +import ( + "context" + "errors" + "log" + "net" + "net/url" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/engine/mockable" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +func TestConfig_domains(t *testing.T) { + c := Config{} + if c.domains() != "edge-chat.instagram.com example.com" { + t.Fatal("invalid default domains list") + } +} + +func TestConfig_repetitions(t *testing.T) { + c := Config{} + if c.repetitions() != 10 { + t.Fatal("invalid default number of repetitions") + } +} + +func TestConfig_delay(t *testing.T) { + c := Config{} + if c.delay() != time.Second { + t.Fatal("invalid default delay") + } +} + +func TestMeasurer_run(t *testing.T) { + // expectedPings is the expected number of pings + const expectedPings = 4 + + // runHelper is an helper function to run this set of tests. + runHelper := func(input string) (*model.Measurement, model.ExperimentMeasurer, error) { + m := NewExperimentMeasurer(Config{ + Domains: "example.com", + Delay: 1, // millisecond + Repetitions: expectedPings, + }) + if m.ExperimentName() != "dnsping" { + t.Fatal("invalid experiment name") + } + if m.ExperimentVersion() != "0.1.0" { + t.Fatal("invalid experiment version") + } + ctx := context.Background() + meas := &model.Measurement{ + Input: model.MeasurementTarget(input), + } + sess := &mockable.Session{ + MockableLogger: model.DiscardLogger, + } + callbacks := model.NewPrinterCallbacks(model.DiscardLogger) + err := m.Run(ctx, sess, meas, callbacks) + return meas, m, err + } + + t.Run("with empty input", func(t *testing.T) { + _, _, err := runHelper("") + if !errors.Is(err, errNoInputProvided) { + t.Fatal("unexpected error", err) + } + }) + + t.Run("with invalid URL", func(t *testing.T) { + _, _, err := runHelper("\t") + if !errors.Is(err, errInputIsNotAnURL) { + t.Fatal("unexpected error", err) + } + }) + + t.Run("with invalid scheme", func(t *testing.T) { + _, _, err := runHelper("https://8.8.8.8:443/") + if !errors.Is(err, errInvalidScheme) { + t.Fatal("unexpected error", err) + } + }) + + t.Run("with missing port", func(t *testing.T) { + _, _, err := runHelper("udp://8.8.8.8") + if !errors.Is(err, errMissingPort) { + t.Fatal("unexpected error", err) + } + }) + + t.Run("with local listener", func(t *testing.T) { + srvrURL, dnsListener, err := startDNSServer() + if err != nil { + log.Fatal(err) + } + defer dnsListener.Close() + meas, m, err := runHelper(srvrURL) + if err != nil { + t.Fatal(err) + } + tk := meas.TestKeys.(*TestKeys) + if len(tk.Pings) != expectedPings*2 { // account for A & AAAA pings + t.Fatal("unexpected number of pings") + } + ask, err := m.GetSummaryKeys(meas) + if err != nil { + t.Fatal("cannot obtain summary") + } + summary := ask.(SummaryKeys) + if summary.IsAnomaly { + t.Fatal("expected no anomaly") + } + }) +} + +// startDNSServer starts a local DNS server. +func startDNSServer() (string, net.PacketConn, error) { + dnsListener, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + return "", nil, err + } + go runDNSServer(dnsListener) + URL := &url.URL{ + Scheme: "udp", + Host: dnsListener.LocalAddr().String(), + Path: "/", + } + return URL.String(), dnsListener, nil +} + +// runDNSServer runs the DNS server. +func runDNSServer(dnsListener net.PacketConn) { + ds := &dns.Server{ + Handler: &dnsHandler{}, + Net: "udp", + PacketConn: dnsListener, + } + err := ds.ActivateAndServe() + if !errors.Is(err, net.ErrClosed) { + runtimex.PanicOnError(err, "ActivateAndServe failed") + } +} + +// dnsHandler handles DNS requests. +type dnsHandler struct{} + +// ServeDNS serves a DNS request +func (h *dnsHandler) ServeDNS(rw dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.Compress = true + m.MsgHdr.RecursionAvailable = true + m.SetRcode(req, dns.RcodeServerFailure) + rw.WriteMsg(m) +} diff --git a/internal/engine/experiment/simplequicping/simplequicping_test.go b/internal/engine/experiment/simplequicping/simplequicping_test.go index 8750b86..b937bad 100644 --- a/internal/engine/experiment/simplequicping/simplequicping_test.go +++ b/internal/engine/experiment/simplequicping/simplequicping_test.go @@ -98,11 +98,11 @@ func TestMeasurer_run(t *testing.T) { }) t.Run("with local listener", func(t *testing.T) { - srvrURL, err := startEchoServer() + srvrURL, listener, err := startEchoServer() if err != nil { log.Fatal(err) } - t.Log(srvrURL) + defer listener.Close() meas, m, err := runHelper(srvrURL) if err != nil { t.Fatal(err) @@ -127,10 +127,10 @@ func TestMeasurer_run(t *testing.T) { // SPDX-License-Identifier: MIT // // See https://github.com/lucas-clemente/quic-go/blob/v0.27.0/example/echo/echo.go#L34 -func startEchoServer() (string, error) { +func startEchoServer() (string, quic.Listener, error) { listener, err := quic.ListenAddr("127.0.0.1:0", generateTLSConfig(), nil) if err != nil { - return "", err + return "", nil, err } go echoWorkerMain(listener) URL := &url.URL{ @@ -138,7 +138,7 @@ func startEchoServer() (string, error) { Host: listener.Addr().String(), Path: "/", } - return URL.String(), nil + return URL.String(), listener, nil } // Worker used by startEchoServer to accept a quic connection. @@ -147,16 +147,17 @@ func startEchoServer() (string, error) { // // See https://github.com/lucas-clemente/quic-go/blob/v0.27.0/example/echo/echo.go#L34 func echoWorkerMain(listener quic.Listener) { - defer listener.Close() - conn, err := listener.Accept(context.Background()) - if err != nil { - panic(err) + for { + conn, err := listener.Accept(context.Background()) + if err != nil { + return + } + stream, err := conn.AcceptStream(context.Background()) + if err != nil { + continue + } + stream.Close() } - stream, err := conn.AcceptStream(context.Background()) - if err != nil { - panic(err) - } - stream.Close() } // Setup a bare-bones TLS config for the server.