diff --git a/internal/engine/experiment/urlgetter/configurer_test.go b/internal/engine/experiment/urlgetter/configurer_test.go index 2c27c5d..bdb5212 100644 --- a/internal/engine/experiment/urlgetter/configurer_test.go +++ b/internal/engine/experiment/urlgetter/configurer_test.go @@ -119,7 +119,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the DNS transport we expected") } @@ -195,7 +195,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the DNS transport we expected") } @@ -271,7 +271,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T) if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the DNS transport we expected") } @@ -347,7 +347,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the DNS transport we expected") } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 088cf70..ae85470 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -104,9 +104,7 @@ func NewResolver(config Config) model.Resolver { Resolver: r, } } - if config.ResolveSaver != nil { - r = tracex.SaverResolver{Resolver: r, Saver: config.ResolveSaver} - } + r = config.ResolveSaver.WrapResolver(r) // WAI when config.ResolveSaver==nil return &netxlite.ResolverIDNA{Resolver: r} } @@ -129,23 +127,14 @@ func NewQUICDialer(config Config) model.QUICDialer { if config.FullResolver == nil { config.FullResolver = NewResolver(config) } - ql := netxlite.NewQUICListener() - if config.ReadWriteSaver != nil { - ql = &tracex.QUICListenerSaver{ - QUICListener: ql, - Saver: config.ReadWriteSaver, - } - } + ql := config.ReadWriteSaver.WrapQUICListener(netxlite.NewQUICListener()) var logger model.DebugLogger = model.DiscardLogger if config.Logger != nil { logger = config.Logger } extensions := []netxlite.QUICDialerWrapper{ func(dialer model.QUICDialer) model.QUICDialer { - if config.TLSSaver != nil { - dialer = tracex.QUICHandshakeSaver{Saver: config.TLSSaver, QUICDialer: dialer} - } - return dialer + return config.TLSSaver.WrapQUICDialer(dialer) // robust to nil TLSSaver }, } return netxlite.NewQUICDialerWithResolver(ql, logger, config.FullResolver, extensions...) @@ -161,9 +150,7 @@ func NewTLSDialer(config Config) model.TLSDialer { if config.Logger != nil { h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h} } - if config.TLSSaver != nil { - h = tracex.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver} - } + h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver if config.TLSConfig == nil { config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}} } @@ -284,12 +271,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, httpClient := &http.Client{Transport: NewHTTPTransport(config)} var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride( httpClient, URL, hostOverride) - if config.ResolveSaver != nil { - txp = tracex.SaverDNSTransport{ - DNSTransport: txp, - Saver: config.ResolveSaver, - } - } + txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil return netxlite.NewSerialResolver(txp), nil case "udp": dialer := NewDialer(config) @@ -299,12 +281,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, } var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport( dialer, endpoint) - if config.ResolveSaver != nil { - txp = tracex.SaverDNSTransport{ - DNSTransport: txp, - Saver: config.ResolveSaver, - } - } + txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil return netxlite.NewSerialResolver(txp), nil case "dot": config.TLSConfig.NextProtos = []string{"dot"} @@ -315,12 +292,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, } var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport( tlsDialer.DialTLSContext, endpoint) - if config.ResolveSaver != nil { - txp = tracex.SaverDNSTransport{ - DNSTransport: txp, - Saver: config.ResolveSaver, - } - } + txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil return netxlite.NewSerialResolver(txp), nil case "tcp": dialer := NewDialer(config) @@ -330,12 +302,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, } var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport( dialer.DialContext, endpoint) - if config.ResolveSaver != nil { - txp = tracex.SaverDNSTransport{ - DNSTransport: txp, - Saver: config.ResolveSaver, - } - } + txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil return netxlite.NewSerialResolver(txp), nil default: return nil, errors.New("unsupported resolver scheme") diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 642a71f..3cf2ea5 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -126,7 +126,7 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - sr, ok := ir.Resolver.(tracex.SaverResolver) + sr, ok := ir.Resolver.(*tracex.SaverResolver) if !ok { t.Fatal("not the resolver we expected") } @@ -332,7 +332,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - sth, ok := rtd.TLSHandshaker.(tracex.SaverTLSHandshaker) + sth, ok := rtd.TLSHandshaker.(*tracex.SaverTLSHandshaker) if !ok { t.Fatal("not the TLSHandshaker we expected") } @@ -633,7 +633,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } @@ -670,7 +670,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } @@ -711,7 +711,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } @@ -756,7 +756,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.SaverDNSTransport) if !ok { t.Fatal("not the transport we expected") } diff --git a/internal/engine/netx/tracex/archival.go b/internal/engine/netx/tracex/archival.go index 4a75836..fe8714c 100644 --- a/internal/engine/netx/tracex/archival.go +++ b/internal/engine/netx/tracex/archival.go @@ -15,7 +15,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// Compatibility types +// Compatibility types. Most experiments still use these names. type ( ExtSpec = model.ArchivalExtSpec TCPConnectEntry = model.ArchivalTCPConnectResult @@ -32,7 +32,7 @@ type ( NetworkEvent = model.ArchivalNetworkEvent ) -// Compatibility variables +// Compatibility variables. Most experiments still use these names. var ( ExtDNS = model.ArchivalExtDNS ExtNetevents = model.ArchivalExtNetevents @@ -100,7 +100,7 @@ func NewFailedOperation(err error) *string { return &s } -func addheaders( +func httpAddHeaders( source http.Header, destList *[]HTTPHeader, destMap *map[string]MaybeBinaryValue, @@ -150,14 +150,14 @@ func newRequestList(begin time.Time, events []Event) []RequestEntry { entry.Request.BodyIsTruncated = ev.DataIsTruncated case "http_request_metadata": entry.Request.Headers = make(map[string]MaybeBinaryValue) - addheaders( + httpAddHeaders( ev.HTTPHeaders, &entry.Request.HeadersList, &entry.Request.Headers) entry.Request.Method = ev.HTTPMethod entry.Request.URL = ev.HTTPURL entry.Request.Transport = ev.Transport case "http_response_metadata": entry.Response.Headers = make(map[string]MaybeBinaryValue) - addheaders( + httpAddHeaders( ev.HTTPHeaders, &entry.Response.HeadersList, &entry.Response.Headers) entry.Response.Code = int64(ev.HTTPStatusCode) entry.Response.Locations = ev.HTTPHeaders.Values("Location") @@ -183,11 +183,11 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { continue } for _, qtype := range []dnsQueryType{"A", "AAAA"} { - entry := qtype.makequeryentry(begin, ev) + entry := qtype.makeQueryEntry(begin, ev) for _, addr := range ev.Addresses { - if qtype.ipoftype(addr) { + if qtype.ipOfType(addr) { entry.Answers = append( - entry.Answers, qtype.makeanswerentry(addr)) + entry.Answers, qtype.makeAnswerEntry(addr)) } } if len(entry.Answers) <= 0 && ev.Err == nil { @@ -206,7 +206,7 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { return out } -func (qtype dnsQueryType) ipoftype(addr string) bool { +func (qtype dnsQueryType) ipOfType(addr string) bool { switch qtype { case "A": return !strings.Contains(addr, ":") @@ -216,7 +216,7 @@ func (qtype dnsQueryType) ipoftype(addr string) bool { return false } -func (qtype dnsQueryType) makeanswerentry(addr string) DNSAnswerEntry { +func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry { answer := DNSAnswerEntry{AnswerType: string(qtype)} asn, org, _ := geolocate.LookupASN(addr) answer.ASN = int64(asn) @@ -230,7 +230,7 @@ func (qtype dnsQueryType) makeanswerentry(addr string) DNSAnswerEntry { return answer } -func (qtype dnsQueryType) makequeryentry(begin time.Time, ev Event) DNSQueryEntry { +func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev Event) DNSQueryEntry { return DNSQueryEntry{ Engine: ev.Proto, Failure: NewFailure(ev.Err), @@ -315,7 +315,7 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { Failure: NewFailure(ev.Err), NegotiatedProtocol: ev.TLSNegotiatedProto, NoTLSVerify: ev.NoTLSVerify, - PeerCertificates: makePeerCerts(ev.TLSPeerCerts), + PeerCertificates: tlsMakePeerCerts(ev.TLSPeerCerts), ServerName: ev.TLSServerName, T: ev.Time.Sub(begin).Seconds(), TLSVersion: ev.TLSVersion, @@ -324,7 +324,7 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { return out } -func makePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) { +func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) { for _, e := range in { out = append(out, MaybeBinaryValue{Value: string(e.Raw)}) } diff --git a/internal/engine/netx/tracex/archival_test.go b/internal/engine/netx/tracex/archival_test.go index dad3c58..a437d96 100644 --- a/internal/engine/netx/tracex/archival_test.go +++ b/internal/engine/netx/tracex/archival_test.go @@ -47,7 +47,7 @@ func TestDNSQueryIPOfType(t *testing.T) { output: false, }} for _, exp := range expectations { - if exp.qtype.ipoftype(exp.ip) != exp.output { + if exp.qtype.ipOfType(exp.ip) != exp.output { t.Fatalf("failure for %+v", exp) } } diff --git a/internal/engine/netx/tracex/dialer.go b/internal/engine/netx/tracex/dialer.go index ec6d025..e6abdc9 100644 --- a/internal/engine/netx/tracex/dialer.go +++ b/internal/engine/netx/tracex/dialer.go @@ -1,5 +1,9 @@ package tracex +// +// TCP and connected UDP sockets +// + import ( "context" "net" @@ -11,7 +15,10 @@ import ( // SaverDialer saves events occurring during the dial type SaverDialer struct { - model.Dialer + // Dialer is the underlying dialer, + Dialer model.Dialer + + // Saver saves events. Saver *Saver } @@ -31,10 +38,17 @@ func (d *SaverDialer) DialContext(ctx context.Context, network, address string) return conn, err } +func (d *SaverDialer) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} + // SaverConnDialer wraps the returned connection such that we // collect all the read/write events that occur. type SaverConnDialer struct { - model.Dialer + // Dialer is the underlying dialer + Dialer model.Dialer + + // Saver saves events Saver *Saver } @@ -47,6 +61,10 @@ func (d *SaverConnDialer) DialContext(ctx context.Context, network, address stri return &saverConn{saver: d.Saver, Conn: conn}, nil } +func (d *SaverConnDialer) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} + type saverConn struct { net.Conn saver *Saver diff --git a/internal/engine/netx/tracex/doc.go b/internal/engine/netx/tracex/doc.go index 399e1a1..7a2b467 100644 --- a/internal/engine/netx/tracex/doc.go +++ b/internal/engine/netx/tracex/doc.go @@ -1,2 +1,8 @@ -// Package tracex contains code to perform measurements using tracing. +// Package tracex performs measurements using tracing. To use tracing means +// that we'll wrap netx data types (e.g., a Dialer) with equivalent data types +// saving events into a Saver data struture. Then we will use the data types +// normally (e.g., call the Dialer's DialContet method and then use the +// resulting connection). When done, we will extract the trace containing +// all the events that occurred from the saver and process it to determine +// what happened during the measurement itself. package tracex diff --git a/internal/engine/netx/tracex/event.go b/internal/engine/netx/tracex/event.go index cfdc4b3..3509d6c 100644 --- a/internal/engine/netx/tracex/event.go +++ b/internal/engine/netx/tracex/event.go @@ -1,9 +1,7 @@ package tracex import ( - "crypto/tls" "crypto/x509" - "errors" "net/http" "time" ) @@ -36,25 +34,3 @@ type Event struct { Time time.Time `json:",omitempty"` Transport string `json:",omitempty"` } - -// PeerCerts returns the certificates presented by the peer regardless -// of whether the TLS handshake was successful -func PeerCerts(state tls.ConnectionState, err error) []*x509.Certificate { - var x509HostnameError x509.HostnameError - if errors.As(err, &x509HostnameError) { - // Test case: https://wrong.host.badssl.com/ - return []*x509.Certificate{x509HostnameError.Certificate} - } - var x509UnknownAuthorityError x509.UnknownAuthorityError - if errors.As(err, &x509UnknownAuthorityError) { - // Test case: https://self-signed.badssl.com/. This error has - // never been among the ones returned by MK. - return []*x509.Certificate{x509UnknownAuthorityError.Cert} - } - var x509CertificateInvalidError x509.CertificateInvalidError - if errors.As(err, &x509CertificateInvalidError) { - // Test case: https://expired.badssl.com/ - return []*x509.Certificate{x509CertificateInvalidError.Cert} - } - return state.PeerCertificates -} diff --git a/internal/engine/netx/tracex/http.go b/internal/engine/netx/tracex/http.go index c7db07c..03d15f9 100644 --- a/internal/engine/netx/tracex/http.go +++ b/internal/engine/netx/tracex/http.go @@ -1,5 +1,9 @@ package tracex +// +// HTTP +// + import ( "bytes" "context" @@ -21,7 +25,7 @@ type SaverMetadataHTTPTransport struct { // RoundTrip implements RoundTripper.RoundTrip func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { txp.Saver.Write(Event{ - HTTPHeaders: txp.CloneHeaders(req), + HTTPHeaders: httpCloneHeaders(req), HTTPMethod: req.Method, HTTPURL: req.URL.String(), Transport: txp.HTTPTransport.Network(), @@ -41,10 +45,10 @@ func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Respon return resp, err } -// CloneHeaders returns a clone of the headers where we have +// httpCCloneHeaders returns a clone of the headers where we have // also set the host header, which normally is not set by // golang until it serializes the request itself. -func (txp SaverMetadataHTTPTransport) CloneHeaders(req *http.Request) http.Header { +func httpCloneHeaders(req *http.Request) http.Header { header := req.Header.Clone() if req.Host != "" { header.Set("Host", req.Host) @@ -92,11 +96,11 @@ func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response, snapsize = txp.SnapshotSize } if req.Body != nil { - data, err := saverSnapRead(req.Context(), req.Body, snapsize) + data, err := httpSaverSnapRead(req.Context(), req.Body, snapsize) if err != nil { return nil, err } - req.Body = saverCompose(data, req.Body) + req.Body = httpSaverCompose(data, req.Body) txp.Saver.Write(Event{ DataIsTruncated: len(data) >= snapsize, Data: data, @@ -108,12 +112,12 @@ func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response, if err != nil { return nil, err } - data, err := saverSnapRead(req.Context(), resp.Body, snapsize) + data, err := httpSaverSnapRead(req.Context(), resp.Body, snapsize) if err != nil { resp.Body.Close() return nil, err } - resp.Body = saverCompose(data, resp.Body) + resp.Body = httpSaverCompose(data, resp.Body) txp.Saver.Write(Event{ DataIsTruncated: len(data) >= snapsize, Data: data, @@ -123,15 +127,15 @@ func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response, return resp, nil } -func saverSnapRead(ctx context.Context, r io.ReadCloser, snapsize int) ([]byte, error) { +func httpSaverSnapRead(ctx context.Context, r io.ReadCloser, snapsize int) ([]byte, error) { return netxlite.ReadAllContext(ctx, io.LimitReader(r, int64(snapsize))) } -func saverCompose(data []byte, r io.ReadCloser) io.ReadCloser { - return saverReadCloser{Closer: r, Reader: io.MultiReader(bytes.NewReader(data), r)} +func httpSaverCompose(data []byte, r io.ReadCloser) io.ReadCloser { + return httpSaverReadCloser{Closer: r, Reader: io.MultiReader(bytes.NewReader(data), r)} } -type saverReadCloser struct { +type httpSaverReadCloser struct { io.Closer io.Reader } diff --git a/internal/engine/netx/tracex/http_test.go b/internal/engine/netx/tracex/http_test.go index 4b854e5..c0285a0 100644 --- a/internal/engine/netx/tracex/http_test.go +++ b/internal/engine/netx/tracex/http_test.go @@ -394,8 +394,7 @@ func TestCloneHeaders(t *testing.T) { }, Header: http.Header{}, } - txp := SaverMetadataHTTPTransport{} - header := txp.CloneHeaders(req) + header := httpCloneHeaders(req) if header.Get("Host") != "www.example.com" { t.Fatal("did not set Host header correctly") } @@ -409,8 +408,7 @@ func TestCloneHeaders(t *testing.T) { }, Header: http.Header{}, } - txp := SaverMetadataHTTPTransport{} - header := txp.CloneHeaders(req) + header := httpCloneHeaders(req) if header.Get("Host") != "www.kernel.org" { t.Fatal("did not set Host header correctly") } diff --git a/internal/engine/netx/tracex/quic.go b/internal/engine/netx/tracex/quic.go index be60c85..6a11de9 100644 --- a/internal/engine/netx/tracex/quic.go +++ b/internal/engine/netx/tracex/quic.go @@ -1,5 +1,9 @@ package tracex +// +// QUIC +// + import ( "context" "crypto/tls" @@ -11,14 +15,32 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// QUICHandshakeSaver saves events occurring during the handshake +// QUICHandshakeSaver saves events occurring during the QUIC handshake. type QUICHandshakeSaver struct { + // QUICDialer is the wrapped dialer + QUICDialer model.QUICDialer + + // Saver saves events Saver *Saver - model.QUICDialer } -// DialContext implements ContextDialer.DialContext -func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string, +// WrapQUICDialer wraps a model.QUICDialer with a QUICHandshakeSaver that will +// save the QUIC handshake results into this Saver. +// +// When this function is invoked on a nil Saver, it will directly return +// the original QUICDialer without any wrapping. +func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer { + if s == nil { + return qd + } + return &QUICHandshakeSaver{ + QUICDialer: qd, + Saver: s, + } +} + +// DialContext implements QUICDialer.DialContext +func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { start := time.Now() // TODO(bassosimone): in the future we probably want to also save @@ -35,6 +57,7 @@ func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string, sess, err := h.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg) stop := time.Now() if err != nil { + // TODO(bassosimone): here we should save the peer certs h.Saver.Write(Event{ Duration: stop.Sub(start), Err: err, @@ -54,7 +77,7 @@ func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string, TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSNegotiatedProto: state.NegotiatedProtocol, TLSNextProtos: tlsCfg.NextProtos, - TLSPeerCerts: PeerCerts(state, err), + TLSPeerCerts: tlsPeerCerts(state, err), TLSServerName: tlsCfg.ServerName, TLSVersion: netxlite.TLSVersionString(state.Version), Time: stop, @@ -62,6 +85,10 @@ func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string, return sess, nil } +func (h *QUICHandshakeSaver) CloseIdleConnections() { + h.QUICDialer.CloseIdleConnections() +} + // quicConnectionState returns the ConnectionState of a QUIC Session. func quicConnectionState(sess quic.EarlyConnection) tls.ConnectionState { return sess.ConnectionState().TLS.ConnectionState @@ -70,32 +97,50 @@ func quicConnectionState(sess quic.EarlyConnection) tls.ConnectionState { // QUICListenerSaver is a QUICListener that also implements saving events. type QUICListenerSaver struct { // QUICListener is the underlying QUICListener. - model.QUICListener + QUICListener model.QUICListener // Saver is the underlying Saver. Saver *Saver } +// WrapQUICListener wraps a model.QUICDialer with a QUICListenerSaver that will +// save the QUIC I/O packet conn events into this Saver. +// +// When this function is invoked on a nil Saver, it will directly return +// the original QUICListener without any wrapping. +func (s *Saver) WrapQUICListener(ql model.QUICListener) model.QUICListener { + if s == nil { + return ql + } + return &QUICListenerSaver{ + QUICListener: ql, + Saver: s, + } +} + // Listen implements QUICListener.Listen. func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) { pconn, err := qls.QUICListener.Listen(addr) if err != nil { return nil, err } - return &saverUDPConn{ + pconn = &udpLikeConnSaver{ UDPLikeConn: pconn, saver: qls.Saver, - }, nil + } + return pconn, nil } -type saverUDPConn struct { +// udpLikeConnSaver saves I/O events +type udpLikeConnSaver struct { + // UDPLikeConn is the wrapped underlying conn model.UDPLikeConn + + // Saver saves events saver *Saver } -var _ model.UDPLikeConn = &saverUDPConn{} - -func (c *saverUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { +func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) { start := time.Now() count, err := c.UDPLikeConn.WriteTo(p, addr) stop := time.Now() @@ -111,7 +156,7 @@ func (c *saverUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) { return count, err } -func (c *saverUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) { start := time.Now() n, addr, err := c.UDPLikeConn.ReadFrom(b) stop := time.Now() @@ -131,9 +176,13 @@ func (c *saverUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { return n, addr, err } -func (c *saverUDPConn) safeAddrString(addr net.Addr) (out string) { +func (c *udpLikeConnSaver) safeAddrString(addr net.Addr) (out string) { if addr != nil { out = addr.String() } return } + +var _ model.QUICDialer = &QUICHandshakeSaver{} +var _ model.QUICListener = &QUICListenerSaver{} +var _ model.UDPLikeConn = &udpLikeConnSaver{} diff --git a/internal/engine/netx/tracex/quic_test.go b/internal/engine/netx/tracex/quic_test.go index 33058ab..143de7f 100644 --- a/internal/engine/netx/tracex/quic_test.go +++ b/internal/engine/netx/tracex/quic_test.go @@ -39,12 +39,9 @@ func TestHandshakeSaverSuccess(t *testing.T) { ServerName: servername, } saver := &Saver{} - dlr := QUICHandshakeSaver{ - QUICDialer: &netxlite.QUICDialerQUICGo{ - QUICListener: &netxlite.QUICListenerStdlib{}, - }, - Saver: saver, - } + dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, + }) sess, err := dlr.DialContext(context.Background(), "udp", quictesting.Endpoint("443"), tlsConf, &quic.Config{}) if err != nil { @@ -97,12 +94,9 @@ func TestHandshakeSaverHostNameError(t *testing.T) { ServerName: servername, } saver := &Saver{} - dlr := QUICHandshakeSaver{ - QUICDialer: &netxlite.QUICDialerQUICGo{ - QUICListener: &netxlite.QUICListenerStdlib{}, - }, - Saver: saver, - } + dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ + QUICListener: &netxlite.QUICListenerStdlib{}, + }) sess, err := dlr.DialContext(context.Background(), "udp", quictesting.Endpoint("443"), tlsConf, &quic.Config{}) if err == nil { @@ -126,14 +120,12 @@ func TestHandshakeSaverHostNameError(t *testing.T) { func TestQUICListenerSaverCannotListen(t *testing.T) { expected := errors.New("mocked error") - qls := &QUICListenerSaver{ - QUICListener: &mocks.QUICListener{ - MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { - return nil, expected - }, + saver := &Saver{} + qls := saver.WrapQUICListener(&mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return nil, expected }, - Saver: &Saver{}, - } + }) pconn, err := qls.Listen(&net.UDPAddr{ IP: []byte{}, Port: 8080, @@ -155,10 +147,7 @@ func TestSystemDialerSuccessWithReadWrite(t *testing.T) { } saver := &Saver{} systemdialer := &netxlite.QUICDialerQUICGo{ - QUICListener: &QUICListenerSaver{ - QUICListener: &netxlite.QUICListenerStdlib{}, - Saver: saver, - }, + QUICListener: saver.WrapQUICListener(&netxlite.QUICListenerStdlib{}), } _, err := systemdialer.DialContext(context.Background(), "udp", quictesting.Endpoint("443"), tlsConf, &quic.Config{}) diff --git a/internal/engine/netx/tracex/resolver.go b/internal/engine/netx/tracex/resolver.go index 4bbdc95..8e5c7c8 100644 --- a/internal/engine/netx/tracex/resolver.go +++ b/internal/engine/netx/tracex/resolver.go @@ -1,20 +1,43 @@ package tracex +// +// DNS lookup and round trip +// + import ( "context" + "net" "time" "github.com/ooni/probe-cli/v3/internal/model" ) -// SaverResolver is a resolver that saves events +// SaverResolver is a resolver that saves events. type SaverResolver struct { - model.Resolver + // Resolver is the underlying resolver. + Resolver model.Resolver + + // Saver saves events. Saver *Saver } +// WrapResolver wraps a model.Resolver with a SaverResolver that will save +// the DNS lookup results into this Saver. +// +// When this function is invoked on a nil Saver, it will directly return +// the original Resolver without any wrapping. +func (s *Saver) WrapResolver(r model.Resolver) model.Resolver { + if s == nil { + return r + } + return &SaverResolver{ + Resolver: r, + Saver: s, + } +} + // LookupHost implements Resolver.LookupHost -func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { +func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { start := time.Now() r.Saver.Write(Event{ Address: r.Resolver.Address(), @@ -38,49 +61,105 @@ func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]strin return addrs, err } -// SaverDNSTransport is a DNS transport that saves events +func (r *SaverResolver) Network() string { + return r.Resolver.Network() +} + +func (r *SaverResolver) Address() string { + return r.Resolver.Address() +} + +func (r *SaverResolver) CloseIdleConnections() { + r.Resolver.CloseIdleConnections() +} + +func (r *SaverResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + // TODO(bassosimone): we should probably implement this method + return r.Resolver.LookupHTTPS(ctx, domain) +} + +func (r *SaverResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { + // TODO(bassosimone): we should probably implement this method + return r.Resolver.LookupNS(ctx, domain) +} + +// SaverDNSTransport is a DNS transport that saves events. type SaverDNSTransport struct { - model.DNSTransport + // DNSTransport is the underlying DNS transport. + DNSTransport model.DNSTransport + + // Saver saves events. Saver *Saver } +// WrapDNSTransport wraps a model.DNSTransport with a SaverDNSTransport that +// will save the DNS round trip results into this Saver. +// +// When this function is invoked on a nil Saver, it will directly return +// the original DNSTransport without any wrapping. +func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { + if s == nil { + return txp + } + return &SaverDNSTransport{ + DNSTransport: txp, + Saver: s, + } +} + // RoundTrip implements RoundTripper.RoundTrip -func (txp SaverDNSTransport) RoundTrip( +func (txp *SaverDNSTransport) RoundTrip( ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { start := time.Now() txp.Saver.Write(Event{ - Address: txp.Address(), - DNSQuery: txp.maybeQueryBytes(query), + Address: txp.DNSTransport.Address(), + DNSQuery: dnsMaybeQueryBytes(query), Name: "dns_round_trip_start", - Proto: txp.Network(), + Proto: txp.DNSTransport.Network(), Time: start, }) response, err := txp.DNSTransport.RoundTrip(ctx, query) stop := time.Now() txp.Saver.Write(Event{ - Address: txp.Address(), - DNSQuery: txp.maybeQueryBytes(query), - DNSReply: txp.maybeResponseBytes(response), + Address: txp.DNSTransport.Address(), + DNSQuery: dnsMaybeQueryBytes(query), + DNSReply: dnsMaybeResponseBytes(response), Duration: stop.Sub(start), Err: err, Name: "dns_round_trip_done", - Proto: txp.Network(), + Proto: txp.DNSTransport.Network(), Time: stop, }) return response, err } -func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte { +func (txp *SaverDNSTransport) Network() string { + return txp.DNSTransport.Network() +} + +func (txp *SaverDNSTransport) Address() string { + return txp.DNSTransport.Address() +} + +func (txp *SaverDNSTransport) CloseIdleConnections() { + txp.DNSTransport.CloseIdleConnections() +} + +func (txp *SaverDNSTransport) RequiresPadding() bool { + return txp.DNSTransport.RequiresPadding() +} + +func dnsMaybeQueryBytes(query model.DNSQuery) []byte { data, _ := query.Bytes() return data } -func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte { +func dnsMaybeResponseBytes(response model.DNSResponse) []byte { if response == nil { return nil } return response.Bytes() } -var _ model.Resolver = SaverResolver{} -var _ model.DNSTransport = SaverDNSTransport{} +var _ model.Resolver = &SaverResolver{} +var _ model.DNSTransport = &SaverDNSTransport{} diff --git a/internal/engine/netx/tracex/resolver_test.go b/internal/engine/netx/tracex/resolver_test.go index 73a0590..643a0a2 100644 --- a/internal/engine/netx/tracex/resolver_test.go +++ b/internal/engine/netx/tracex/resolver_test.go @@ -17,10 +17,7 @@ import ( func TestSaverResolverFailure(t *testing.T) { expected := errors.New("no such host") saver := &Saver{} - reso := SaverResolver{ - Resolver: NewFakeResolverWithExplicitError(expected), - Saver: saver, - } + reso := saver.WrapResolver(NewFakeResolverWithExplicitError(expected)) addrs, err := reso.LookupHost(context.Background(), "www.google.com") if !errors.Is(err, expected) { t.Fatal("not the error we expected") @@ -64,10 +61,7 @@ func TestSaverResolverFailure(t *testing.T) { func TestSaverResolverSuccess(t *testing.T) { expected := []string{"8.8.8.8", "8.8.4.4"} saver := &Saver{} - reso := SaverResolver{ - Resolver: NewFakeResolverWithResult(expected), - Saver: saver, - } + reso := saver.WrapResolver(NewFakeResolverWithResult(expected)) addrs, err := reso.LookupHost(context.Background(), "www.google.com") if err != nil { t.Fatal("expected nil error here") @@ -111,20 +105,17 @@ func TestSaverResolverSuccess(t *testing.T) { func TestSaverDNSTransportFailure(t *testing.T) { expected := errors.New("no such host") saver := &Saver{} - txp := SaverDNSTransport{ - DNSTransport: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return nil, expected - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expected }, - Saver: saver, - } + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { @@ -179,20 +170,17 @@ func TestSaverDNSTransportSuccess(t *testing.T) { return expected }, } - txp := SaverDNSTransport{ - DNSTransport: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return response, nil - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return response, nil }, - Saver: saver, - } + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} query := &mocks.DNSQuery{ MockBytes: func() ([]byte, error) { diff --git a/internal/engine/netx/tracex/saver.go b/internal/engine/netx/tracex/saver.go index 53d9f8e..174f1b1 100644 --- a/internal/engine/netx/tracex/saver.go +++ b/internal/engine/netx/tracex/saver.go @@ -1,11 +1,19 @@ package tracex +// +// Saver implementation +// + import "sync" -// The Saver saves a trace +// The Saver saves a trace. The zero value of this type +// is valid and can be used without initializtion. type Saver struct { + // ops contains the saved events. ops []Event - mu sync.Mutex + + // mu provides mutual exclusion. + mu sync.Mutex } // Read reads and returns events inside the trace. It advances diff --git a/internal/engine/netx/tracex/saver_test.go b/internal/engine/netx/tracex/saver_test.go index b63e642..f412253 100644 --- a/internal/engine/netx/tracex/saver_test.go +++ b/internal/engine/netx/tracex/saver_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestGood(t *testing.T) { +func TestSaver(t *testing.T) { saver := Saver{} var wg sync.WaitGroup const parallel = 10 diff --git a/internal/engine/netx/tracex/tls.go b/internal/engine/netx/tracex/tls.go index 621c316..9ce4098 100644 --- a/internal/engine/netx/tracex/tls.go +++ b/internal/engine/netx/tracex/tls.go @@ -1,8 +1,14 @@ package tracex +// +// TLS +// + import ( "context" "crypto/tls" + "crypto/x509" + "errors" "net" "time" @@ -10,16 +16,33 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// SaverTLSHandshaker saves events occurring during the handshake +// SaverTLSHandshaker saves events occurring during the TLS handshake. type SaverTLSHandshaker struct { - model.TLSHandshaker + // TLSHandshaker is the underlying TLS handshaker. + TLSHandshaker model.TLSHandshaker + + // Saver is the saver in which to save events. Saver *Saver } -// Handshake implements TLSHandshaker.Handshake -func (h SaverTLSHandshaker) Handshake( - ctx context.Context, conn net.Conn, config *tls.Config, -) (net.Conn, tls.ConnectionState, error) { +// WrapTLSHandshaker wraps a model.TLSHandshaker with a SaverTLSHandshaker +// that will save the TLS handshake results into this Saver. +// +// When this function is invoked on a nil Saver, it will directly return +// the original TLSHandshaker without any wrapping. +func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker { + if s == nil { + return thx + } + return &SaverTLSHandshaker{ + TLSHandshaker: thx, + Saver: s, + } +} + +// Handshake implements model.TLSHandshaker.Handshake +func (h *SaverTLSHandshaker) Handshake( + ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { start := time.Now() h.Saver.Write(Event{ Name: "tls_handshake_start", @@ -40,7 +63,7 @@ func (h SaverTLSHandshaker) Handshake( TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSNegotiatedProto: state.NegotiatedProtocol, TLSNextProtos: config.NextProtos, - TLSPeerCerts: PeerCerts(state, err), + TLSPeerCerts: tlsPeerCerts(state, err), TLSServerName: config.ServerName, TLSVersion: netxlite.TLSVersionString(state.Version), Time: stop, @@ -48,4 +71,26 @@ func (h SaverTLSHandshaker) Handshake( return tlsconn, state, err } -var _ model.TLSHandshaker = SaverTLSHandshaker{} +var _ model.TLSHandshaker = &SaverTLSHandshaker{} + +// tlsPeerCerts returns the certificates presented by the peer regardless +// of whether the TLS handshake was successful +func tlsPeerCerts(state tls.ConnectionState, err error) []*x509.Certificate { + var x509HostnameError x509.HostnameError + if errors.As(err, &x509HostnameError) { + // Test case: https://wrong.host.badssl.com/ + return []*x509.Certificate{x509HostnameError.Certificate} + } + var x509UnknownAuthorityError x509.UnknownAuthorityError + if errors.As(err, &x509UnknownAuthorityError) { + // Test case: https://self-signed.badssl.com/. This error has + // never been among the ones returned by MK. + return []*x509.Certificate{x509UnknownAuthorityError.Cert} + } + var x509CertificateInvalidError x509.CertificateInvalidError + if errors.As(err, &x509CertificateInvalidError) { + // Test case: https://expired.badssl.com/ + return []*x509.Certificate{x509CertificateInvalidError.Cert} + } + return state.PeerCertificates +} diff --git a/internal/engine/netx/tracex/tls_test.go b/internal/engine/netx/tracex/tls_test.go index c5093ca..c6f9e98 100644 --- a/internal/engine/netx/tracex/tls_test.go +++ b/internal/engine/netx/tracex/tls_test.go @@ -30,10 +30,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { } }, ), - TLSHandshaker: SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - Saver: saver, - }, + TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } // Implementation note: we don't close the connection here because it is // very handy to have the last event being the end of the handshake @@ -121,12 +118,9 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) { nextprotos := []string{"h2"} saver := &Saver{} tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: netxlite.DefaultDialer, - TLSHandshaker: SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - Saver: saver, - }, + Config: &tls.Config{NextProtos: nextprotos}, + Dialer: &netxlite.DialerSystem{}, + TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") if err != nil { @@ -187,11 +181,8 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) { } saver := &Saver{} tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: netxlite.DefaultDialer, - TLSHandshaker: SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - Saver: saver, - }, + Dialer: &netxlite.DialerSystem{}, + TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "wrong.host.badssl.com:443") @@ -220,11 +211,8 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { } saver := &Saver{} tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: netxlite.DefaultDialer, - TLSHandshaker: SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - Saver: saver, - }, + Dialer: &netxlite.DialerSystem{}, + TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "expired.badssl.com:443") @@ -253,11 +241,8 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) { } saver := &Saver{} tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: netxlite.DefaultDialer, - TLSHandshaker: SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - Saver: saver, - }, + Dialer: &netxlite.DialerSystem{}, + TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "self-signed.badssl.com:443") @@ -286,12 +271,9 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { } saver := &Saver{} tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{InsecureSkipVerify: true}, - Dialer: netxlite.DefaultDialer, - TLSHandshaker: SaverTLSHandshaker{ - TLSHandshaker: &netxlite.TLSHandshakerConfigurable{}, - Saver: saver, - }, + Config: &tls.Config{InsecureSkipVerify: true}, + Dialer: &netxlite.DialerSystem{}, + TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "self-signed.badssl.com:443")