From d3970360733cd5f0d7bd4b595d4be5b9c5bdca02 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 1 Jun 2022 23:15:47 +0200 Subject: [PATCH] refactor(tracex): convert to unit testing (#781) The exercise already allowed me to notice issues such as fields not being properly initialized by savers. This is one of the last steps before moving tracex away from the internal/netx package and into the internal package. See https://github.com/ooni/probe/issues/2121 --- .../experiment/urlgetter/configurer_test.go | 8 +- internal/engine/netx/netx.go | 2 +- internal/engine/netx/netx_test.go | 14 +- internal/engine/netx/tracex/archival.go | 61 +- internal/engine/netx/tracex/archival_test.go | 157 +++-- internal/engine/netx/tracex/dialer.go | 58 +- internal/engine/netx/tracex/dialer_test.go | 353 +++++++---- internal/engine/netx/tracex/event.go | 4 + internal/engine/netx/tracex/http.go | 32 +- internal/engine/netx/tracex/http_test.go | 445 +++++++------- internal/engine/netx/tracex/quic.go | 30 +- internal/engine/netx/tracex/quic_test.go | 578 +++++++++++++----- internal/engine/netx/tracex/resolver.go | 38 +- internal/engine/netx/tracex/resolver_test.go | 424 ++++++------- internal/engine/netx/tracex/saver_test.go | 96 ++- internal/engine/netx/tracex/tls.go | 16 +- internal/engine/netx/tracex/tls_test.go | 507 +++++++-------- 17 files changed, 1693 insertions(+), 1130 deletions(-) diff --git a/internal/engine/experiment/urlgetter/configurer_test.go b/internal/engine/experiment/urlgetter/configurer_test.go index bdb5212..bc3ff2a 100644 --- a/internal/engine/experiment/urlgetter/configurer_test.go +++ b/internal/engine/experiment/urlgetter/configurer_test.go @@ -119,7 +119,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } @@ -195,7 +195,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } @@ -271,7 +271,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T) if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } @@ -347,7 +347,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 3b07e56..29ae609 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -187,7 +187,7 @@ func NewHTTPTransport(config Config) model.HTTPTransport { txp = &netxlite.HTTPTransportLogger{Logger: config.Logger, HTTPTransport: txp} } if config.HTTPSaver != nil { - txp = &tracex.SaverTransactionHTTPTransport{ + txp = &tracex.HTTPTransportSaver{ HTTPTransport: txp, Saver: config.HTTPSaver} } return txp diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index d51bdf8..8a454fe 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -126,7 +126,7 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - sr, ok := ir.Resolver.(*tracex.SaverResolver) + sr, ok := ir.Resolver.(*tracex.ResolverSaver) if !ok { t.Fatal("not the resolver we expected") } @@ -332,7 +332,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - sth, ok := rtd.TLSHandshaker.(*tracex.SaverTLSHandshaker) + sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver) if !ok { t.Fatal("not the TLSHandshaker we expected") } @@ -504,7 +504,7 @@ func TestNewWithSaver(t *testing.T) { txp := netx.NewHTTPTransport(netx.Config{ HTTPSaver: saver, }) - stxptxp, ok := txp.(*tracex.SaverTransactionHTTPTransport) + stxptxp, ok := txp.(*tracex.HTTPTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -622,7 +622,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -659,7 +659,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -700,7 +700,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -745,7 +745,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } diff --git a/internal/engine/netx/tracex/archival.go b/internal/engine/netx/tracex/archival.go index 15dd5e2..9f9a3e6 100644 --- a/internal/engine/netx/tracex/archival.go +++ b/internal/engine/netx/tracex/archival.go @@ -1,5 +1,9 @@ package tracex +// +// Code to generate the OONI archival data format from events +// + import ( "crypto/x509" "errors" @@ -43,8 +47,7 @@ var ( ) // NewTCPConnectList creates a new TCPConnectList -func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry { - var out []TCPConnectEntry +func NewTCPConnectList(begin time.Time, events []Event) (out []TCPConnectEntry) { for _, wrapper := range events { if _, ok := wrapper.(*EventConnectOperation); !ok { continue @@ -60,13 +63,14 @@ func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry { IP: ip, Port: iport, Status: TCPConnectStatus{ + Blocked: nil, // only used by Web Connectivity Failure: NewFailure(event.Err), Success: event.Err == nil, }, T: event.Time.Sub(begin).Seconds(), }) } - return out + return } // NewFailure creates a failure nullable string from the given error @@ -101,11 +105,9 @@ func NewFailedOperation(err error) *string { return &s } -func httpAddHeaders( - source http.Header, - destList *[]HTTPHeader, - destMap *map[string]MaybeBinaryValue, -) { +// httpAddHeaders adds the headers inside source into destList and destMap. +func httpAddHeaders(source http.Header, destList *[]HTTPHeader, + destMap *map[string]MaybeBinaryValue) { *destList = []HTTPHeader{} *destMap = make(map[string]model.ArchivalMaybeBinaryData) for key, values := range source { @@ -122,32 +124,28 @@ func httpAddHeaders( }) } } + // Sorting helps with unit testing (map keys are unordered) sort.Slice(*destList, func(i, j int) bool { return (*destList)[i].Key < (*destList)[j].Key }) } // NewRequestList returns the list for "requests" -func NewRequestList(begin time.Time, events []Event) []RequestEntry { +func NewRequestList(begin time.Time, events []Event) (out []RequestEntry) { // OONI wants the last request to appear first - var out []RequestEntry tmp := newRequestList(begin, events) for i := len(tmp) - 1; i >= 0; i-- { out = append(out, tmp[i]) } - return out + return } -func newRequestList(begin time.Time, events []Event) []RequestEntry { - var ( - out []RequestEntry - entry RequestEntry - ) +func newRequestList(begin time.Time, events []Event) (out []RequestEntry) { for _, wrapper := range events { ev := wrapper.Value() switch wrapper.(type) { case *EventHTTPTransactionDone: - entry = RequestEntry{} + entry := RequestEntry{} entry.T = ev.Time.Sub(begin).Seconds() httpAddHeaders( ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers) @@ -164,15 +162,14 @@ func newRequestList(begin time.Time, events []Event) []RequestEntry { out = append(out, entry) } } - return out + return } type dnsQueryType string // NewDNSQueriesList returns a list of DNS queries. -func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { +func NewDNSQueriesList(begin time.Time, events []Event) (out []DNSQueryEntry) { // TODO(bassosimone): add support for CNAME lookups. - var out []DNSQueryEntry for _, wrapper := range events { if _, ok := wrapper.(*EventResolveDone); !ok { continue @@ -199,7 +196,7 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { out = append(out, entry) } } - return out + return } func (qtype dnsQueryType) ipOfType(addr string) bool { @@ -214,6 +211,8 @@ func (qtype dnsQueryType) ipOfType(addr string) bool { func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry { answer := DNSAnswerEntry{AnswerType: string(qtype)} + // Figuring out the ASN and the org here is not just a service to whoever + // is reading a JSON: Web Connectivity also depends on it! asn, org, _ := geolocate.LookupASN(addr) answer.ASN = int64(asn) answer.ASOrgName = org @@ -237,9 +236,8 @@ func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev *EventValue) DNSQue } } -// NewNetworkEventsList returns a list of DNS queries. -func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { - var out []NetworkEvent +// NewNetworkEventsList returns a list of network events. +func NewNetworkEventsList(begin time.Time, events []Event) (out []NetworkEvent) { for _, wrapper := range events { ev := wrapper.Value() switch wrapper.(type) { @@ -281,7 +279,7 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { NumBytes: int64(ev.NumBytes), T: ev.Time.Sub(begin).Seconds(), }) - default: + default: // For example, "tls_handshake_done" (used in data analysis!) out = append(out, NetworkEvent{ Failure: NewFailure(ev.Err), Operation: wrapper.Name(), @@ -289,15 +287,14 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { }) } } - return out + return } // NewTLSHandshakesList creates a new TLSHandshakesList -func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { - var out []TLSHandshake +func NewTLSHandshakesList(begin time.Time, events []Event) (out []TLSHandshake) { for _, wrapper := range events { switch wrapper.(type) { - case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // ok + case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // interested default: continue // not interested } @@ -314,12 +311,12 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { TLSVersion: ev.TLSVersion, }) } - return out + return } func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) { - for _, e := range in { - out = append(out, MaybeBinaryValue{Value: string(e.Raw)}) + for _, entry := range in { + out = append(out, MaybeBinaryValue{Value: string(entry.Raw)}) } return } diff --git a/internal/engine/netx/tracex/archival_test.go b/internal/engine/netx/tracex/archival_test.go index 5b9d8ed..ade7391 100644 --- a/internal/engine/netx/tracex/archival_test.go +++ b/internal/engine/netx/tracex/archival_test.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net/http" - "reflect" "testing" "time" @@ -15,42 +14,44 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestDNSQueryIPOfType(t *testing.T) { - type expectation struct { - qtype dnsQueryType - ip string - output bool - } - var expectations = []expectation{{ - qtype: "A", - ip: "8.8.8.8", - output: true, - }, { - qtype: "A", - ip: "2a00:1450:4002:801::2004", - output: false, - }, { - qtype: "AAAA", - ip: "8.8.8.8", - output: false, - }, { - qtype: "AAAA", - ip: "2a00:1450:4002:801::2004", - output: true, - }, { - qtype: "ANTANI", - ip: "2a00:1450:4002:801::2004", - output: false, - }, { - qtype: "ANTANI", - ip: "8.8.8.8", - output: false, - }} - for _, exp := range expectations { - if exp.qtype.ipOfType(exp.ip) != exp.output { - t.Fatalf("failure for %+v", exp) +func TestDNSQueryType(t *testing.T) { + t.Run("ipOfType", func(t *testing.T) { + type expectation struct { + qtype dnsQueryType + ip string + output bool } - } + var expectations = []expectation{{ + qtype: "A", + ip: "8.8.8.8", + output: true, + }, { + qtype: "A", + ip: "2a00:1450:4002:801::2004", + output: false, + }, { + qtype: "AAAA", + ip: "8.8.8.8", + output: false, + }, { + qtype: "AAAA", + ip: "2a00:1450:4002:801::2004", + output: true, + }, { + qtype: "ANTANI", + ip: "2a00:1450:4002:801::2004", + output: false, + }, { + qtype: "ANTANI", + ip: "8.8.8.8", + output: false, + }} + for _, exp := range expectations { + if exp.qtype.ipOfType(exp.ip) != exp.output { + t.Fatalf("failure for %+v", exp) + } + } + }) } func TestNewTCPConnectList(t *testing.T) { @@ -74,7 +75,7 @@ func TestNewTCPConnectList(t *testing.T) { name: "realistic run", args: args{ begin: begin, - events: []Event{&EventResolveDone{&EventValue{ + events: []Event{&EventResolveDone{&EventValue{ // skipped because not relevant Addresses: []string{"8.8.8.8", "8.8.4.4"}, Hostname: "dns.google.com", Time: begin.Add(100 * time.Millisecond), @@ -86,7 +87,7 @@ func TestNewTCPConnectList(t *testing.T) { }}, &EventConnectOperation{&EventValue{ Address: "8.8.8.8:853", Duration: 55 * time.Millisecond, - Proto: "udp", + Proto: "udp", // this one should be skipped because it's UDP Time: begin.Add(130 * time.Millisecond), }}, &EventConnectOperation{&EventValue{ Address: "8.8.4.4:53", @@ -115,8 +116,9 @@ func TestNewTCPConnectList(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewTCPConnectList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(got, tt.want)) + got := NewTCPConnectList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } @@ -143,6 +145,7 @@ func TestNewRequestList(t *testing.T) { name: "realistic run", args: args{ begin: begin, + // Two round trips so we can test the sorting expected by OONI events: []Event{&EventHTTPTransactionDone{&EventValue{ HTTPRequestHeaders: http.Header{ "User-Agent": []string{"miniooni/0.1.0-dev"}, @@ -286,8 +289,9 @@ func TestNewRequestList(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewRequestList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(tt.want, got)) + got := NewRequestList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } @@ -320,12 +324,12 @@ func TestNewDNSQueriesList(t *testing.T) { Hostname: "dns.google.com", Proto: "dot", Time: begin.Add(100 * time.Millisecond), - }}, &EventConnectOperation{&EventValue{ + }}, &EventConnectOperation{&EventValue{ // skipped because not relevant Address: "8.8.8.8:853", Duration: 30 * time.Millisecond, Proto: "tcp", Time: begin.Add(130 * time.Millisecond), - }}, &EventConnectOperation{&EventValue{ + }}, &EventConnectOperation{&EventValue{ // skipped because not relevant Address: "8.8.4.4:53", Duration: 50 * time.Millisecond, Err: io.EOF, @@ -452,6 +456,10 @@ func TestNewNetworkEventsList(t *testing.T) { Err: websocket.ErrBadHandshake, NumBytes: 4114, Time: begin.Add(14 * time.Millisecond), + }}, &EventResolveStart{&EventValue{ + // We expect this event to be logged event though it's not a typical I/O + // event (it seems these extra events are useful for debugging) + Time: begin.Add(15 * time.Millisecond), }}}, }, want: []NetworkEvent{{ @@ -482,12 +490,16 @@ func TestNewNetworkEventsList(t *testing.T) { NumBytes: 4114, Operation: netxlite.WriteToOperation, T: 0.014, + }, { + Operation: "resolve_start", + T: 0.015, }}, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewNetworkEventsList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(got, tt.want)) + got := NewNetworkEventsList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } @@ -511,13 +523,14 @@ func TestNewTLSHandshakesList(t *testing.T) { }, want: nil, }, { - name: "realistic run", + name: "realistic run with TLS", args: args{ begin: begin, events: []Event{&EventTLSHandshakeDone{&EventValue{ Address: "131.252.210.176:443", Err: io.EOF, NoTLSVerify: false, + Proto: "tcp", TLSCipherSuite: "SUITE", TLSNegotiatedProto: "h2", TLSPeerCerts: []*x509.Certificate{{ @@ -545,11 +558,57 @@ func TestNewTLSHandshakesList(t *testing.T) { T: 0.055, TLSVersion: "TLSv1.3", }}, + }, { + name: "realistic run with QUIC", + args: args{ + begin: begin, + events: []Event{&EventQUICHandshakeDone{&EventValue{ + Address: "131.252.210.176:443", + Err: io.EOF, + NoTLSVerify: false, + Proto: "quic", + TLSCipherSuite: "SUITE", + TLSNegotiatedProto: "h3", + TLSPeerCerts: []*x509.Certificate{{ + Raw: []byte("deadbeef"), + }, { + Raw: []byte("abad1dea"), + }}, + TLSServerName: "x.org", + TLSVersion: "TLSv1.3", + Time: begin.Add(55 * time.Millisecond), + }}}, + }, + want: []TLSHandshake{{ + Address: "131.252.210.176:443", + CipherSuite: "SUITE", + Failure: NewFailure(io.EOF), + NegotiatedProtocol: "h3", + NoTLSVerify: false, + PeerCertificates: []MaybeBinaryValue{{ + Value: "deadbeef", + }, { + Value: "abad1dea", + }}, + ServerName: "x.org", + T: 0.055, + TLSVersion: "TLSv1.3", + }}, + }, { + name: "realistic run with no suitable events", + args: args{ + begin: begin, + events: []Event{&EventResolveStart{&EventValue{ + Time: begin.Add(55 * time.Millisecond), + }}}, + }, + want: nil, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewTLSHandshakesList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(got, tt.want)) + got := NewTLSHandshakesList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } diff --git a/internal/engine/netx/tracex/dialer.go b/internal/engine/netx/tracex/dialer.go index 5096b64..8820850 100644 --- a/internal/engine/netx/tracex/dialer.go +++ b/internal/engine/netx/tracex/dialer.go @@ -12,8 +12,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// SaverDialer saves events occurring during the dial -type SaverDialer struct { +// DialerSaver saves events occurring during the dial +type DialerSaver struct { // Dialer is the underlying dialer, Dialer model.Dialer @@ -28,26 +28,26 @@ func (s *Saver) NewConnectObserver() model.DialerWrapper { if s == nil { return nil // valid DialerWrapper according to netxlite's docs } - return &saverDialerWrapper{ + return &dialerConnectObserver{ saver: s, } } -type saverDialerWrapper struct { +type dialerConnectObserver struct { saver *Saver } -var _ model.DialerWrapper = &saverDialerWrapper{} +var _ model.DialerWrapper = &dialerConnectObserver{} -func (w *saverDialerWrapper) WrapDialer(d model.Dialer) model.Dialer { - return &SaverDialer{ +func (w *dialerConnectObserver) WrapDialer(d model.Dialer) model.Dialer { + return &DialerSaver{ Dialer: d, Saver: w.saver, } } // DialContext implements Dialer.DialContext -func (d *SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { start := time.Now() conn, err := d.Dialer.DialContext(ctx, network, address) stop := time.Now() @@ -61,13 +61,13 @@ func (d *SaverDialer) DialContext(ctx context.Context, network, address string) return conn, err } -func (d *SaverDialer) CloseIdleConnections() { +func (d *DialerSaver) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -// SaverConnDialer wraps the returned connection such that we +// DialerConnSaver wraps the returned connection such that we // collect all the read/write events that occur. -type SaverConnDialer struct { +type DialerConnSaver struct { // Dialer is the underlying dialer Dialer model.Dialer @@ -82,70 +82,78 @@ func (s *Saver) NewReadWriteObserver() model.DialerWrapper { if s == nil { return nil // valid DialerWrapper according to netxlite's docs } - return &saverReadWriteWrapper{ + return &dialerReadWriteObserver{ saver: s, } } -type saverReadWriteWrapper struct { +type dialerReadWriteObserver struct { saver *Saver } -var _ model.DialerWrapper = &saverReadWriteWrapper{} +var _ model.DialerWrapper = &dialerReadWriteObserver{} -func (w *saverReadWriteWrapper) WrapDialer(d model.Dialer) model.Dialer { - return &SaverConnDialer{ +func (w *dialerReadWriteObserver) WrapDialer(d model.Dialer) model.Dialer { + return &DialerConnSaver{ Dialer: d, Saver: w.saver, } } // DialContext implements Dialer.DialContext -func (d *SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerConnSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { return nil, err } - return &saverConn{saver: d.Saver, Conn: conn}, nil + return &dialerConnWrapper{saver: d.Saver, Conn: conn}, nil } -func (d *SaverConnDialer) CloseIdleConnections() { +func (d *DialerConnSaver) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -type saverConn struct { +type dialerConnWrapper struct { net.Conn saver *Saver } -func (c *saverConn) Read(p []byte) (int, error) { +func (c *dialerConnWrapper) Read(p []byte) (int, error) { + proto := c.Conn.RemoteAddr().Network() + remoteAddr := c.Conn.RemoteAddr().String() start := time.Now() count, err := c.Conn.Read(p) stop := time.Now() c.saver.Write(&EventReadOperation{&EventValue{ + Address: remoteAddr, Data: p[:count], Duration: stop.Sub(start), Err: err, NumBytes: count, + Proto: proto, Time: stop, }}) return count, err } -func (c *saverConn) Write(p []byte) (int, error) { +func (c *dialerConnWrapper) Write(p []byte) (int, error) { + proto := c.Conn.RemoteAddr().Network() + remoteAddr := c.Conn.RemoteAddr().String() start := time.Now() count, err := c.Conn.Write(p) stop := time.Now() c.saver.Write(&EventWriteOperation{&EventValue{ + Address: remoteAddr, Data: p[:count], Duration: stop.Sub(start), Err: err, NumBytes: count, + Proto: proto, Time: stop, }}) return count, err } -var _ model.Dialer = &SaverDialer{} -var _ model.Dialer = &SaverConnDialer{} -var _ net.Conn = &saverConn{} +var _ model.Dialer = &DialerSaver{} +var _ model.Dialer = &DialerConnSaver{} +var _ net.Conn = &dialerConnWrapper{} diff --git a/internal/engine/netx/tracex/dialer_test.go b/internal/engine/netx/tracex/dialer_test.go index 8608faa..2447ce4 100644 --- a/internal/engine/netx/tracex/dialer_test.go +++ b/internal/engine/netx/tracex/dialer_test.go @@ -12,127 +12,268 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestSaverDialerFailure(t *testing.T) { - expected := errors.New("mocked error") +func TestDialerConnectObserver(t *testing.T) { saver := &Saver{} - dlr := &SaverDialer{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }, - }, - Saver: saver, + obs := &dialerConnectObserver{ + saver: saver, } - conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, expected) { - t.Fatal("expected another error here") + dialer := &mocks.Dialer{} + out := obs.WrapDialer(dialer) + dialSaver := out.(*DialerSaver) + if dialSaver.Dialer != dialer { + t.Fatal("invalid dialer") } - if conn != nil { - t.Fatal("expected nil conn here") - } - ev := saver.Read() - if len(ev) != 1 { - t.Fatal("expected a single event here") - } - if ev[0].Value().Address != "www.google.com:443" { - t.Fatal("unexpected Address") - } - if ev[0].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if !errors.Is(ev[0].Value().Err, expected) { - t.Fatal("unexpected Err") - } - if ev[0].Name() != netxlite.ConnectOperation { - t.Fatal("unexpected Name") - } - if ev[0].Value().Proto != "tcp" { - t.Fatal("unexpected Proto") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("unexpected Time") + if dialSaver.Saver != saver { + t.Fatal("invalid saver") } } -func TestSaverConnDialerFailure(t *testing.T) { - expected := errors.New("mocked error") - saver := &Saver{} - dlr := &SaverConnDialer{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }, - }, - Saver: saver, - } - conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } -} - -func TestSaverConnDialerSuccess(t *testing.T) { - saver := &Saver{} - dlr := &SaverConnDialer{ - Dialer: &SaverDialer{ +func TestDialerSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + dlr := &DialerSaver{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return &mocks.Conn{ - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockClose: func() error { - return io.EOF - }, - MockLocalAddr: func() net.Addr { - return &net.TCPAddr{Port: 12345} - }, - }, nil + return nil, expected }, }, Saver: saver, - }, - Saver: saver, - } - conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal("not the error we expected", err) - } - conn.Read(nil) - conn.Write(nil) - conn.Close() - events := saver.Read() - if len(events) != 3 { - t.Fatal("unexpected number of events saved", len(events)) - } - if events[0].Name() != "connect" { - t.Fatal("expected a connect event") - } - saverCheckConnectEvent(t, &events[0]) - if events[1].Name() != "read" { - t.Fatal("expected a read event") - } - saverCheckReadEvent(t, &events[1]) - if events[2].Name() != "write" { - t.Fatal("expected a write event") - } - saverCheckWriteEvent(t, &events[2]) + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, expected) { + t.Fatal("expected another error here") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + ev := saver.Read() + if len(ev) != 1 { + t.Fatal("expected a single event here") + } + if ev[0].Value().Address != "www.google.com:443" { + t.Fatal("unexpected Address") + } + if ev[0].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[0].Value().Err, expected) { + t.Fatal("unexpected Err") + } + if ev[0].Name() != netxlite.ConnectOperation { + t.Fatal("unexpected Name") + } + if ev[0].Value().Proto != "tcp" { + t.Fatal("unexpected Proto") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("unexpected Time") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := &DialerSaver{ + Dialer: child, + Saver: &Saver{}, + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) } -func saverCheckConnectEvent(t *testing.T, ev *Event) { - // TODO(bassosimone): implement +func TestDialerReadWriteObserver(t *testing.T) { + saver := &Saver{} + obs := &dialerReadWriteObserver{ + saver: saver, + } + dialer := &mocks.Dialer{} + out := obs.WrapDialer(dialer) + dialSaver := out.(*DialerConnSaver) + if dialSaver.Dialer != dialer { + t.Fatal("invalid dialer") + } + if dialSaver.Saver != saver { + t.Fatal("invalid saver") + } } -func saverCheckReadEvent(t *testing.T, ev *Event) { - // TODO(bassosimone): implement +func TestDialerConnSaver(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + dlr := &DialerConnSaver{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, + }, + Saver: saver, + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + + t.Run("on success", func(t *testing.T) { + origConn := &mocks.Conn{} + saver := &Saver{} + dlr := &DialerConnSaver{ + Dialer: &DialerSaver{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return origConn, nil + }, + }, + Saver: saver, + }, + Saver: saver, + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if err != nil { + t.Fatal("not the error we expected", err) + } + cw := conn.(*dialerConnWrapper) + if cw.Conn != origConn { + t.Fatal("unexpected conn") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := &DialerConnSaver{ + Dialer: child, + Saver: &Saver{}, + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) } -func saverCheckWriteEvent(t *testing.T, ev *Event) { - // TODO(bassosimone): implement +func TestDialerConnWrapper(t *testing.T) { + t.Run("Read", func(t *testing.T) { + baseConn := &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "www.google.com:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + saver := &Saver{} + conn := &dialerConnWrapper{ + Conn: baseConn, + saver: saver, + } + data := make([]byte, 155) + count, err := conn.Read(data) + if !errors.Is(err, io.EOF) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + ev := saver.Read() + if len(ev) != 1 { + t.Fatal("expected a single event here") + } + if ev[0].Value().Address != "www.google.com:443" { + t.Fatal("unexpected Address") + } + if ev[0].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[0].Value().Err, io.EOF) { + t.Fatal("unexpected Err") + } + if ev[0].Name() != netxlite.ReadOperation { + t.Fatal("unexpected Name") + } + if ev[0].Value().Proto != "tcp" { + t.Fatal("unexpected Proto") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("unexpected Time") + } + }) + + t.Run("Write", func(t *testing.T) { + baseConn := &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "www.google.com:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + saver := &Saver{} + conn := &dialerConnWrapper{ + Conn: baseConn, + saver: saver, + } + data := make([]byte, 155) + count, err := conn.Write(data) + if !errors.Is(err, io.EOF) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + ev := saver.Read() + if len(ev) != 1 { + t.Fatal("expected a single event here") + } + if ev[0].Value().Address != "www.google.com:443" { + t.Fatal("unexpected Address") + } + if ev[0].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[0].Value().Err, io.EOF) { + t.Fatal("unexpected Err") + } + if ev[0].Name() != netxlite.WriteOperation { + t.Fatal("unexpected Name") + } + if ev[0].Value().Proto != "tcp" { + t.Fatal("unexpected Proto") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("unexpected Time") + } + }) } diff --git a/internal/engine/netx/tracex/event.go b/internal/engine/netx/tracex/event.go index 41a23af..48f2b71 100644 --- a/internal/engine/netx/tracex/event.go +++ b/internal/engine/netx/tracex/event.go @@ -1,5 +1,9 @@ package tracex +// +// All the possible events +// + import ( "crypto/x509" "net/http" diff --git a/internal/engine/netx/tracex/http.go b/internal/engine/netx/tracex/http.go index 744a305..9d745aa 100644 --- a/internal/engine/netx/tracex/http.go +++ b/internal/engine/netx/tracex/http.go @@ -27,11 +27,17 @@ func httpCloneRequestHeaders(req *http.Request) http.Header { return header } -// SaverTransactionHTTPTransport is a RoundTripper that saves +// HTTPTransportSaver is a RoundTripper that saves // events related to the HTTP transaction -type SaverTransactionHTTPTransport struct { - model.HTTPTransport - Saver *Saver +type HTTPTransportSaver struct { + // HTTPTransport is the MANDATORY underlying HTTP transport. + HTTPTransport model.HTTPTransport + + // Saver is the MANDATORY saver to use. + Saver *Saver + + // SnapshotSize is the OPTIONAL maximum body snapshot size (if not set, we'll + // use 1<<17, which we've been using since the ooni/netx days) SnapshotSize int64 } @@ -40,7 +46,11 @@ type SaverTransactionHTTPTransport struct { // // The maxBodySnapshotSize argument controls the maximum size of the // body snapshot that we collect along with the HTTP round trip. -func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (txp *HTTPTransportSaver) RoundTrip(req *http.Request) (*http.Response, error) { + + // TODO(bassosimone): we're currently using the started time for + // the transaction done event, which contrasts with what we do for + // every other event. What does the spec say? started := time.Now() txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{ @@ -92,7 +102,15 @@ func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Re return resp, nil } -func (txp *SaverTransactionHTTPTransport) snapshotSize() int64 { +func (txp *HTTPTransportSaver) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() +} + +func (txp *HTTPTransportSaver) Network() string { + return txp.HTTPTransport.Network() +} + +func (txp *HTTPTransportSaver) snapshotSize() int64 { if txp.SnapshotSize > 0 { return txp.SnapshotSize } @@ -104,4 +122,4 @@ type httpReadableAgainBody struct { io.Closer } -var _ model.HTTPTransport = &SaverTransactionHTTPTransport{} +var _ model.HTTPTransport = &HTTPTransportSaver{} diff --git a/internal/engine/netx/tracex/http_test.go b/internal/engine/netx/tracex/http_test.go index 1ed764f..0efcae3 100644 --- a/internal/engine/netx/tracex/http_test.go +++ b/internal/engine/netx/tracex/http_test.go @@ -16,227 +16,262 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" ) -func TestSaverTransactionHTTPTransport(t *testing.T) { +func TestHTTPTransportSaver(t *testing.T) { - startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) { - server := &filtering.HTTPProxy{ - OnIncomingHost: func(host string) filtering.HTTPAction { - return action + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + called = true }, } - listener, err := server.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + dialer := &HTTPTransportSaver{ + HTTPTransport: child, + Saver: &Saver{}, } - URL := &url.URL{ - Scheme: "http", - Host: listener.Addr().String(), - Path: "/", - } - return listener, URL - } - - measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) { - saver := &Saver{} - txp := &SaverTransactionHTTPTransport{ - HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger), - Saver: saver, - } - req, err := http.NewRequest("GET", URL.String(), nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("User-Agent", "miniooni") - resp, err := txp.RoundTrip(req) - return resp, saver, err - } - - validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) { - if value.HTTPMethod != "GET" { - t.Fatal("invalid method") - } - if value.HTTPRequestHeaders.Get("Host") != URL.Host { - t.Fatal("invalid Host header") - } - if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" { - t.Fatal("invalid User-Agent header") - } - if value.HTTPURL != URL.String() { - t.Fatal("invalid URL") - } - if value.Time.IsZero() { - t.Fatal("expected nonzero Time") - } - if value.Transport != "tcp" { - t.Fatal("expected Transport to be tcp") - } - } - - validateRequest := func(t *testing.T, ev Event, URL *url.URL) { - if _, good := ev.(*EventHTTPTransactionStart); !good { - t.Fatal("invalid event type") - } - if ev.Name() != "http_transaction_start" { - t.Fatal("invalid event name") - } - value := ev.Value() - validateRequestFields(t, value, URL) - } - - validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) { - if _, good := ev.(*EventHTTPTransactionDone); !good { - t.Fatal("invalid event type") - } - if ev.Name() != "http_transaction_done" { - t.Fatal("invalid event name") - } - value := ev.Value() - validateRequestFields(t, value, URL) - if value.Duration <= 0 { - t.Fatal("expected nonzero duration") - } - if len(value.HTTPResponseHeaders) <= 0 { - t.Fatal("expected at least one response header") - } - if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) { - t.Fatal("unexpected value for response body") - } - if value.HTTPStatusCode != 451 { - t.Fatal("unexpected status code") - } - } - - t.Run("on success", func(t *testing.T) { - listener, URL := startServer(t, filtering.HTTPAction451) - defer listener.Close() - resp, saver, err := measureHTTP(t, URL) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != 451 { - t.Fatal("unexpected status code", resp.StatusCode) - } - events := saver.Read() - if len(events) != 2 { - t.Fatal("unexpected number of events") - } - validateRequest(t, events[0], URL) - validateResponseSuccess(t, events[1], URL) - data, err := netxlite.ReadAllContext(context.Background(), resp.Body) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, filtering.HTTPBlockpage451) { - t.Fatal("we cannot re-read the same body") + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") } }) - validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) { - if _, good := ev.(*EventHTTPTransactionDone); !good { - t.Fatal("invalid event type") + t.Run("Network", func(t *testing.T) { + expected := "antani" + child := &mocks.HTTPTransport{ + MockNetwork: func() string { + return expected + }, } - if ev.Name() != "http_transaction_done" { - t.Fatal("invalid event name") + dialer := &HTTPTransportSaver{ + HTTPTransport: child, + Saver: &Saver{}, } - value := ev.Value() - validateRequestFields(t, value, URL) - if value.Duration <= 0 { - t.Fatal("expected nonzero duration") + if dialer.Network() != expected { + t.Fatal("unexpected Network") } - if value.Err.Error() != "connection_reset" { - t.Fatal("unexpected Err value") - } - if len(value.HTTPResponseHeaders) > 0 { - t.Fatal("expected zero response headers") - } - if !bytes.Equal(value.HTTPResponseBody, nil) { - t.Fatal("unexpected value for response body") - } - if value.HTTPStatusCode != 0 { - t.Fatal("unexpected status code") - } - } - - t.Run("on round trip failure", func(t *testing.T) { - listener, URL := startServer(t, filtering.HTTPActionReset) - defer listener.Close() - resp, saver, err := measureHTTP(t, URL) - if err == nil || err.Error() != "connection_reset" { - t.Fatal("unexpected err", err) - } - if resp != nil { - t.Fatal("expected nil response") - } - events := saver.Read() - if len(events) != 2 { - t.Fatal("unexpected number of events") - } - validateRequest(t, events[0], URL) - validateResponseFailure(t, events[1], URL) }) - // Sometimes useful for testing - /* - dumplog := func(t *testing.T, ev Event) { - data, _ := json.MarshalIndent(ev.Value(), " ", " ") - t.Log(string(data)) - t.FailNow() + t.Run("RoundTrip", func(t *testing.T) { + startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) { + server := &filtering.HTTPProxy{ + OnIncomingHost: func(host string) filtering.HTTPAction { + return action + }, + } + listener, err := server.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + URL := &url.URL{ + Scheme: "http", + Host: listener.Addr().String(), + Path: "/", + } + return listener, URL } - */ - t.Run("on error reading the response body", func(t *testing.T) { - saver := &Saver{} - expected := errors.New("mocked error") - txp := SaverTransactionHTTPTransport{ - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - Header: http.Header{ - "Server": {"antani"}, - }, - StatusCode: 200, - Body: io.NopCloser(&mocks.Reader{ - MockRead: func(b []byte) (int, error) { - return 0, expected + measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) { + saver := &Saver{} + txp := &HTTPTransportSaver{ + HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger), + Saver: saver, + } + req, err := http.NewRequest("GET", URL.String(), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("User-Agent", "miniooni") + resp, err := txp.RoundTrip(req) + return resp, saver, err + } + + validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) { + if value.HTTPMethod != "GET" { + t.Fatal("invalid method") + } + if value.HTTPRequestHeaders.Get("Host") != URL.Host { + t.Fatal("invalid Host header") + } + if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" { + t.Fatal("invalid User-Agent header") + } + if value.HTTPURL != URL.String() { + t.Fatal("invalid URL") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + if value.Transport != "tcp" { + t.Fatal("expected Transport to be tcp") + } + } + + validateRequest := func(t *testing.T, ev Event, URL *url.URL) { + if _, good := ev.(*EventHTTPTransactionStart); !good { + t.Fatal("invalid event type") + } + if ev.Name() != "http_transaction_start" { + t.Fatal("invalid event name") + } + value := ev.Value() + validateRequestFields(t, value, URL) + } + + validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) { + if _, good := ev.(*EventHTTPTransactionDone); !good { + t.Fatal("invalid event type") + } + if ev.Name() != "http_transaction_done" { + t.Fatal("invalid event name") + } + value := ev.Value() + validateRequestFields(t, value, URL) + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if len(value.HTTPResponseHeaders) <= 0 { + t.Fatal("expected at least one response header") + } + if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) { + t.Fatal("unexpected value for response body") + } + if value.HTTPStatusCode != 451 { + t.Fatal("unexpected status code") + } + } + + t.Run("on success", func(t *testing.T) { + listener, URL := startServer(t, filtering.HTTPAction451) + defer listener.Close() + resp, saver, err := measureHTTP(t, URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 451 { + t.Fatal("unexpected status code", resp.StatusCode) + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("unexpected number of events") + } + validateRequest(t, events[0], URL) + validateResponseSuccess(t, events[1], URL) + data, err := netxlite.ReadAllContext(context.Background(), resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, filtering.HTTPBlockpage451) { + t.Fatal("we cannot re-read the same body") + } + }) + + validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) { + if _, good := ev.(*EventHTTPTransactionDone); !good { + t.Fatal("invalid event type") + } + if ev.Name() != "http_transaction_done" { + t.Fatal("invalid event name") + } + value := ev.Value() + validateRequestFields(t, value, URL) + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if value.Err.Error() != "connection_reset" { + t.Fatal("unexpected Err value") + } + if len(value.HTTPResponseHeaders) > 0 { + t.Fatal("expected zero response headers") + } + if !bytes.Equal(value.HTTPResponseBody, nil) { + t.Fatal("unexpected value for response body") + } + if value.HTTPStatusCode != 0 { + t.Fatal("unexpected status code") + } + } + + t.Run("on round trip failure", func(t *testing.T) { + listener, URL := startServer(t, filtering.HTTPActionReset) + defer listener.Close() + resp, saver, err := measureHTTP(t, URL) + if err == nil || err.Error() != "connection_reset" { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("unexpected number of events") + } + validateRequest(t, events[0], URL) + validateResponseFailure(t, events[1], URL) + }) + + // Sometimes useful for testing + /* + dump := func(t *testing.T, ev Event) { + data, _ := json.MarshalIndent(ev.Value(), " ", " ") + t.Log(string(data)) + t.Fail() + } + */ + + t.Run("on error reading the response body", func(t *testing.T) { + saver := &Saver{} + expected := errors.New("mocked error") + txp := HTTPTransportSaver{ + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Header: http.Header{ + "Server": {"antani"}, }, - }), - }, nil + StatusCode: 200, + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }), + }, nil + }, + MockNetwork: func() string { + return "tcp" + }, }, - MockNetwork: func() string { - return "tcp" - }, - }, - SnapshotSize: 4, - Saver: saver, - } - URL := &url.URL{ - Scheme: "http", - Host: "127.0.0.1:9050", - } - req, err := http.NewRequest("GET", URL.String(), nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("User-Agent", "miniooni") - resp, err := txp.RoundTrip(req) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("expected nil response") - } - ev := saver.Read() - validateRequest(t, ev[0], URL) - if ev[1].Value().HTTPStatusCode != 200 { - t.Fatal("invalid status code") - } - if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" { - t.Fatal("invalid Server header") - } - if ev[1].Value().Err.Error() != "unknown_failure: mocked error" { - t.Fatal("invalid error") - } + SnapshotSize: 4, + Saver: saver, + } + URL := &url.URL{ + Scheme: "http", + Host: "127.0.0.1:9050", + } + req, err := http.NewRequest("GET", URL.String(), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("User-Agent", "miniooni") + resp, err := txp.RoundTrip(req) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response") + } + ev := saver.Read() + validateRequest(t, ev[0], URL) + if ev[1].Value().HTTPStatusCode != 200 { + t.Fatal("invalid status code") + } + if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" { + t.Fatal("invalid Server header") + } + if ev[1].Value().Err.Error() != "unknown_failure: mocked error" { + t.Fatal("invalid error") + } + }) }) } diff --git a/internal/engine/netx/tracex/quic.go b/internal/engine/netx/tracex/quic.go index fb4f2c7..2162ec6 100644 --- a/internal/engine/netx/tracex/quic.go +++ b/internal/engine/netx/tracex/quic.go @@ -15,8 +15,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// QUICHandshakeSaver saves events occurring during the QUIC handshake. -type QUICHandshakeSaver struct { +// QUICDialerSaver saves events occurring during the QUIC handshake. +type QUICDialerSaver struct { // QUICDialer is the wrapped dialer QUICDialer model.QUICDialer @@ -33,14 +33,14 @@ func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer { if s == nil { return qd } - return &QUICHandshakeSaver{ + return &QUICDialerSaver{ QUICDialer: qd, Saver: s, } } // DialContext implements QUICDialer.DialContext -func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, +func (h *QUICDialerSaver) DialContext(ctx context.Context, network string, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { start := time.Now() // TODO(bassosimone): in the future we probably want to also save @@ -58,9 +58,11 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, if err != nil { // TODO(bassosimone): here we should save the peer certs h.Saver.Write(&EventQUICHandshakeDone{&EventValue{ + Address: host, Duration: stop.Sub(start), Err: err, NoTLSVerify: tlsCfg.InsecureSkipVerify, + Proto: network, TLSNextProtos: tlsCfg.NextProtos, TLSServerName: tlsCfg.ServerName, Time: stop, @@ -69,8 +71,10 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, } state := quicConnectionState(sess) h.Saver.Write(&EventQUICHandshakeDone{&EventValue{ + Address: host, Duration: stop.Sub(start), NoTLSVerify: tlsCfg.InsecureSkipVerify, + Proto: network, TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSNegotiatedProto: state.NegotiatedProtocol, TLSNextProtos: tlsCfg.NextProtos, @@ -82,7 +86,7 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, return sess, nil } -func (h *QUICHandshakeSaver) CloseIdleConnections() { +func (h *QUICDialerSaver) CloseIdleConnections() { h.QUICDialer.CloseIdleConnections() } @@ -121,15 +125,15 @@ func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, erro if err != nil { return nil, err } - pconn = &udpLikeConnSaver{ + pconn = &quicPacketConnWrapper{ UDPLikeConn: pconn, saver: qls.Saver, } return pconn, nil } -// udpLikeConnSaver saves I/O events -type udpLikeConnSaver struct { +// quicPacketConnWrapper saves I/O events +type quicPacketConnWrapper struct { // UDPLikeConn is the wrapped underlying conn model.UDPLikeConn @@ -137,7 +141,7 @@ type udpLikeConnSaver struct { saver *Saver } -func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) { +func (c *quicPacketConnWrapper) WriteTo(p []byte, addr net.Addr) (int, error) { start := time.Now() count, err := c.UDPLikeConn.WriteTo(p, addr) stop := time.Now() @@ -152,7 +156,7 @@ func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) { return count, err } -func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *quicPacketConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) { start := time.Now() n, addr, err := c.UDPLikeConn.ReadFrom(b) stop := time.Now() @@ -171,13 +175,13 @@ func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) { return n, addr, err } -func (c *udpLikeConnSaver) safeAddrString(addr net.Addr) (out string) { +func (c *quicPacketConnWrapper) safeAddrString(addr net.Addr) (out string) { if addr != nil { out = addr.String() } return } -var _ model.QUICDialer = &QUICHandshakeSaver{} +var _ model.QUICDialer = &QUICDialerSaver{} var _ model.QUICListener = &QUICListenerSaver{} -var _ model.UDPLikeConn = &udpLikeConnSaver{} +var _ model.UDPLikeConn = &quicPacketConnWrapper{} diff --git a/internal/engine/netx/tracex/quic_test.go b/internal/engine/netx/tracex/quic_test.go index 804cbb6..4f766ed 100644 --- a/internal/engine/netx/tracex/quic_test.go +++ b/internal/engine/netx/tracex/quic_test.go @@ -3,182 +3,446 @@ package tracex import ( "context" "crypto/tls" + "crypto/x509" "errors" "net" - "reflect" - "strings" "testing" - "time" + "github.com/google/go-cmp/cmp" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" - "github.com/ooni/probe-cli/v3/internal/netxlite" - "github.com/ooni/probe-cli/v3/internal/netxlite/quictesting" ) -type MockDialer struct { - Dialer model.QUICDialer - Sess quic.EarlyConnection - Err error -} +func TestQUICDialerSaver(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { -func (d MockDialer) DialContext(ctx context.Context, network, host string, - tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - if d.Dialer != nil { - return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg) - } - return d.Sess, d.Err -} + checkStartEventFields := func(t *testing.T, value *EventValue) { + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if !value.NoTLSVerify { + t.Fatal("expected NoTLSVerify to be true") + } + if value.Proto != "udp" { + t.Fatal("wrong protocol") + } + if diff := cmp.Diff(value.TLSNextProtos, []string{"h3"}); diff != "" { + t.Fatal(diff) + } + if value.TLSServerName != "dns.google" { + t.Fatal("invalid TLSServerName") + } + if value.Time.IsZero() { + t.Fatal("expected non zero time") + } + } -func TestHandshakeSaverSuccess(t *testing.T) { - nextprotos := []string{"h3"} - servername := quictesting.Domain - tlsConf := &tls.Config{ - NextProtos: nextprotos, - ServerName: servername, - } - saver := &Saver{} - dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ - QUICListener: &netxlite.QUICListenerStdlib{}, + checkStartedEvent := func(t *testing.T, ev Event) { + if _, good := ev.(*EventQUICHandshakeStart); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) + } + + checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err != nil { + t.Fatal("expected no error here") + } + if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" { + t.Fatal("invalid cipher suite") + } + if value.TLSNegotiatedProto != "h3" { + t.Fatal("invalid negotiated protocol") + } + if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" { + t.Fatal(diff) + } + if value.TLSVersion != "TLSv1.3" { + t.Fatal("invalid TLS version") + } + } + + checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) { + if _, good := ev.(*EventQUICHandshakeDone); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) + fun(t, value) + } + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + returnedConn := &mocks.QUICEarlyConnection{ + MockConnectionState: func() quic.ConnectionState { + cs := quic.ConnectionState{} + cs.TLS.ConnectionState.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA + cs.TLS.NegotiatedProtocol = "h3" + cs.TLS.PeerCertificates = []*x509.Certificate{} + cs.TLS.Version = tls.VersionTLS13 + return cs + }, + } + dialer := saver.WrapQUICDialer(&mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return returnedConn, nil + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h3"}, + ServerName: "dns.google", + } + quicConfig := &quic.Config{} + conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess) + }) + + checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err == nil { + t.Fatal("expected non-nil error here") + } + } + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + dialer := saver.WrapQUICDialer(&mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return nil, expected + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h3"}, + ServerName: "dns.google", + } + quicConfig := &quic.Config{} + conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsFailure) + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := &QUICDialerSaver{ + QUICDialer: child, + Saver: &Saver{}, + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } }) - sess, err := dlr.DialContext(context.Background(), "udp", - quictesting.Endpoint("443"), tlsConf, &quic.Config{}) - if err != nil { - t.Fatal("unexpected error", err) - } - if sess == nil { - t.Fatal("unexpected nil sess") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("unexpected number of events") - } - if ev[0].Name() != "quic_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].Value().TLSServerName != quictesting.Domain { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Value().Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err", ev[1].Value().Err) - } - if ev[1].Name() != "quic_handshake_done" { - t.Fatal("unexpected Name") - } - if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[1].Value().TLSServerName != quictesting.Domain { - t.Fatal("unexpected TLSServerName") - } - if ev[1].Value().Time.Before(ev[0].Value().Time) { - t.Fatal("unexpected Time") - } } -func TestHandshakeSaverHostNameError(t *testing.T) { - nextprotos := []string{"h3"} - servername := "example.com" - tlsConf := &tls.Config{ - NextProtos: nextprotos, - ServerName: servername, - } - saver := &Saver{} - dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ - QUICListener: &netxlite.QUICListenerStdlib{}, +func TestQUICListenerSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + qls := saver.WrapQUICListener(&mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return nil, expected + }, + }) + pconn, err := qls.Listen(&net.UDPAddr{ + IP: []byte{}, + Port: 8080, + Zone: "", + }) + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if pconn != nil { + t.Fatal("expected nil pconn here") + } }) - sess, err := dlr.DialContext(context.Background(), "udp", - quictesting.Endpoint("443"), tlsConf, &quic.Config{}) - if err == nil { - t.Fatal("expected an error here") - } - if sess != nil { - t.Fatal("expected nil sess here") - } - for _, ev := range saver.Read() { - if ev.Name() != "quic_handshake_done" { - continue + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + returnedConn := &mocks.UDPLikeConn{} + qls := saver.WrapQUICListener(&mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return returnedConn, nil + }, + }) + pconn, err := qls.Listen(&net.UDPAddr{ + IP: []byte{}, + Port: 8080, + Zone: "", + }) + if err != nil { + t.Fatal(err) } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") + wconn := pconn.(*quicPacketConnWrapper) + if wconn.UDPLikeConn != returnedConn { + t.Fatal("invalid underlying connection") } - if !strings.HasSuffix(ev.Value().Err.Error(), "tls: handshake failure") { - t.Fatal("unexpected error", ev.Value().Err) + if wconn.saver != saver { + t.Fatal("invalid saver") } - } + }) } -func TestQUICListenerSaverCannotListen(t *testing.T) { - expected := errors.New("mocked error") - saver := &Saver{} - qls := saver.WrapQUICListener(&mocks.QUICListener{ - MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { - return nil, expected - }, - }) - pconn, err := qls.Listen(&net.UDPAddr{ - IP: []byte{}, - Port: 8080, - Zone: "", - }) - if !errors.Is(err, expected) { - t.Fatal("unexpected error", err) - } - if pconn != nil { - t.Fatal("expected nil pconn here") - } -} +func TestQUICPacketConnWrapper(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { -func TestSystemDialerSuccessWithReadWrite(t *testing.T) { - // This is the most common use case for collecting reads, writes - tlsConf := &tls.Config{ - NextProtos: []string{"h3"}, - ServerName: quictesting.Domain, - } - saver := &Saver{} - systemdialer := &netxlite.QUICDialerQUICGo{ - QUICListener: saver.WrapQUICListener(&netxlite.QUICListenerStdlib{}), - } - _, err := systemdialer.DialContext(context.Background(), "udp", - quictesting.Endpoint("443"), tlsConf, &quic.Config{}) - if err != nil { - t.Fatal(err) - } - ev := saver.Read() - if len(ev) < 2 { - t.Fatal("unexpected number of events") - } - last := len(ev) - 1 - for idx := 1; idx < last; idx++ { - if ev[idx].Value().Data == nil { - t.Fatal("unexpected Data") - } - if ev[idx].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[idx].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[idx].Value().NumBytes <= 0 { - t.Fatal("unexpected NumBytes") - } - switch ev[idx].Name() { - case netxlite.ReadFromOperation, netxlite.WriteToOperation: - default: - t.Fatal("unexpected Name") - } - if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) { - t.Fatal("unexpected Time", ev[idx].Value().Time, ev[idx-1].Value().Time) - } - } + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + return 0, nil, expected + }, + }, + saver: saver, + } + buf := make([]byte, 1<<17) + count, addr, err := conn.ReadFrom(buf) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("invalid count") + } + if addr != nil { + t.Fatal("invalid addr") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventReadFromOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "" { + t.Fatal("invalid Address") + } + if len(value.Data) != 0 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if !errors.Is(value.Err, expected) { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 0 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + + t.Run("on success", func(t *testing.T) { + expected := []byte{1, 2, 3, 4} + saver := &Saver{} + expectedAddr := &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + MockNetwork: func() string { + return "udp" + }, + } + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + copy(p, expected) + return len(expected), expectedAddr, nil + }, + }, + saver: saver, + } + buf := make([]byte, 1<<17) + count, addr, err := conn.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if count != 4 { + t.Fatal("invalid count") + } + if addr != expectedAddr { + t.Fatal("invalid addr") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventReadFromOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if len(value.Data) != 4 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if value.Err != nil { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 4 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + }) + + t.Run("WriteTo", func(t *testing.T) { + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 0, expected + }, + }, + saver: saver, + } + destAddr := &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + } + buf := make([]byte, 7) + count, err := conn.WriteTo(buf, destAddr) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("invalid count") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventWriteToOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if len(value.Data) != 0 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if !errors.Is(value.Err, expected) { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 0 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 1, nil + }, + }, + saver: saver, + } + destAddr := &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + } + buf := make([]byte, 7) + count, err := conn.WriteTo(buf, destAddr) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("invalid count") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventWriteToOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if len(value.Data) != 1 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if value.Err != nil { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 1 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + }) } diff --git a/internal/engine/netx/tracex/resolver.go b/internal/engine/netx/tracex/resolver.go index db5afd3..62cf1f5 100644 --- a/internal/engine/netx/tracex/resolver.go +++ b/internal/engine/netx/tracex/resolver.go @@ -12,8 +12,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// SaverResolver is a resolver that saves events. -type SaverResolver struct { +// ResolverSaver is a resolver that saves events. +type ResolverSaver struct { // Resolver is the underlying resolver. Resolver model.Resolver @@ -30,14 +30,14 @@ func (s *Saver) WrapResolver(r model.Resolver) model.Resolver { if s == nil { return r } - return &SaverResolver{ + return &ResolverSaver{ Resolver: r, Saver: s, } } // LookupHost implements Resolver.LookupHost -func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { +func (r *ResolverSaver) LookupHost(ctx context.Context, hostname string) ([]string, error) { start := time.Now() r.Saver.Write(&EventResolveStart{&EventValue{ Address: r.Resolver.Address(), @@ -59,30 +59,30 @@ func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]stri return addrs, err } -func (r *SaverResolver) Network() string { +func (r *ResolverSaver) Network() string { return r.Resolver.Network() } -func (r *SaverResolver) Address() string { +func (r *ResolverSaver) Address() string { return r.Resolver.Address() } -func (r *SaverResolver) CloseIdleConnections() { +func (r *ResolverSaver) CloseIdleConnections() { r.Resolver.CloseIdleConnections() } -func (r *SaverResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { +func (r *ResolverSaver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { // TODO(bassosimone): we should probably implement this method return r.Resolver.LookupHTTPS(ctx, domain) } -func (r *SaverResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { +func (r *ResolverSaver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { // TODO(bassosimone): we should probably implement this method return r.Resolver.LookupNS(ctx, domain) } -// SaverDNSTransport is a DNS transport that saves events. -type SaverDNSTransport struct { +// DNSTransportSaver is a DNS transport that saves events. +type DNSTransportSaver struct { // DNSTransport is the underlying DNS transport. DNSTransport model.DNSTransport @@ -99,14 +99,14 @@ func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { if s == nil { return txp } - return &SaverDNSTransport{ + return &DNSTransportSaver{ DNSTransport: txp, Saver: s, } } // RoundTrip implements RoundTripper.RoundTrip -func (txp *SaverDNSTransport) RoundTrip( +func (txp *DNSTransportSaver) RoundTrip( ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { start := time.Now() txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{ @@ -129,19 +129,19 @@ func (txp *SaverDNSTransport) RoundTrip( return response, err } -func (txp *SaverDNSTransport) Network() string { +func (txp *DNSTransportSaver) Network() string { return txp.DNSTransport.Network() } -func (txp *SaverDNSTransport) Address() string { +func (txp *DNSTransportSaver) Address() string { return txp.DNSTransport.Address() } -func (txp *SaverDNSTransport) CloseIdleConnections() { +func (txp *DNSTransportSaver) CloseIdleConnections() { txp.DNSTransport.CloseIdleConnections() } -func (txp *SaverDNSTransport) RequiresPadding() bool { +func (txp *DNSTransportSaver) RequiresPadding() bool { return txp.DNSTransport.RequiresPadding() } @@ -157,5 +157,5 @@ func dnsMaybeResponseBytes(response model.DNSResponse) []byte { return response.Bytes() } -var _ model.Resolver = &SaverResolver{} -var _ model.DNSTransport = &SaverDNSTransport{} +var _ model.Resolver = &ResolverSaver{} +var _ model.DNSTransport = &DNSTransportSaver{} diff --git a/internal/engine/netx/tracex/resolver_test.go b/internal/engine/netx/tracex/resolver_test.go index dd1eede..631ee64 100644 --- a/internal/engine/netx/tracex/resolver_test.go +++ b/internal/engine/netx/tracex/resolver_test.go @@ -14,220 +14,224 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestSaverResolverFailure(t *testing.T) { - expected := errors.New("no such host") - saver := &Saver{} - reso := saver.WrapResolver(NewFakeResolverWithExplicitError(expected)) - addrs, err := reso.LookupHost(context.Background(), "www.google.com") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if ev[0].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[0].Name() != "resolve_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if ev[1].Value().Addresses != nil { - t.Fatal("unexpected Addresses") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if !errors.Is(ev[1].Value().Err, expected) { - t.Fatal("unexpected Err") - } - if ev[1].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[1].Name() != "resolve_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } -} - -func TestSaverResolverSuccess(t *testing.T) { - expected := []string{"8.8.8.8", "8.8.4.4"} - saver := &Saver{} - reso := saver.WrapResolver(NewFakeResolverWithResult(expected)) - addrs, err := reso.LookupHost(context.Background(), "www.google.com") - if err != nil { - t.Fatal("expected nil error here") - } - if !reflect.DeepEqual(addrs, expected) { - t.Fatal("not the result we expected") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if ev[0].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[0].Name() != "resolve_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { - t.Fatal("unexpected Addresses") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[1].Name() != "resolve_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } -} - -func TestSaverDNSTransportFailure(t *testing.T) { - expected := errors.New("no such host") - saver := &Saver{} - txp := saver.WrapDNSTransport(&mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return nil, expected - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, +func TestResolverSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("no such host") + saver := &Saver{} + reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) + addrs, err := reso.LookupHost(context.Background(), "www.google.com") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if ev[0].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[0].Name() != "resolve_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if ev[1].Value().Addresses != nil { + t.Fatal("unexpected Addresses") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[1].Value().Err, expected) { + t.Fatal("unexpected Err") + } + if ev[1].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[1].Name() != "resolve_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } }) - rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return rawQuery, nil - }, - } - reply, err := txp.RoundTrip(context.Background(), query) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[0].Name() != "dns_round_trip_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[1].Value().DNSResponse != nil { - t.Fatal("unexpected DNSReply") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if !errors.Is(ev[1].Value().Err, expected) { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "dns_round_trip_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } -} -func TestSaverDNSTransportSuccess(t *testing.T) { - expected := []byte{0xef, 0xbe, 0xad, 0xde} - saver := &Saver{} - response := &mocks.DNSResponse{ - MockBytes: func() []byte { - return expected - }, - } - txp := saver.WrapDNSTransport(&mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return response, nil - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, + t.Run("on success", func(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} + saver := &Saver{} + reso := saver.WrapResolver(newFakeResolverWithResult(expected)) + addrs, err := reso.LookupHost(context.Background(), "www.google.com") + if err != nil { + t.Fatal("expected nil error here") + } + if !reflect.DeepEqual(addrs, expected) { + t.Fatal("not the result we expected") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if ev[0].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[0].Name() != "resolve_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { + t.Fatal("unexpected Addresses") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err != nil { + t.Fatal("unexpected Err") + } + if ev[1].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[1].Name() != "resolve_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } }) - rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return rawQuery, nil - }, - } - reply, err := txp.RoundTrip(context.Background(), query) - if err != nil { - t.Fatal("we expected nil error here") - } - if !bytes.Equal(reply.Bytes(), expected) { - t.Fatal("expected another reply here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[0].Name() != "dns_round_trip_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if !bytes.Equal(ev[1].Value().DNSResponse, expected) { - t.Fatal("unexpected DNSReply") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "dns_round_trip_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } } -func NewFakeResolverWithExplicitError(err error) model.Resolver { +func TestDNSTransportSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("no such host") + saver := &Saver{} + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expected + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } + reply, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[0].Name() != "dns_round_trip_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[1].Value().DNSResponse != nil { + t.Fatal("unexpected DNSReply") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[1].Value().Err, expected) { + t.Fatal("unexpected Err") + } + if ev[1].Name() != "dns_round_trip_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) + + t.Run("on success", func(t *testing.T) { + expected := []byte{0xef, 0xbe, 0xad, 0xde} + saver := &Saver{} + response := &mocks.DNSResponse{ + MockBytes: func() []byte { + return expected + }, + } + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return response, nil + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } + reply, err := txp.RoundTrip(context.Background(), query) + if err != nil { + t.Fatal("we expected nil error here") + } + if !bytes.Equal(reply.Bytes(), expected) { + t.Fatal("expected another reply here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[0].Name() != "dns_round_trip_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if !bytes.Equal(ev[1].Value().DNSResponse, expected) { + t.Fatal("unexpected DNSReply") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err != nil { + t.Fatal("unexpected Err") + } + if ev[1].Name() != "dns_round_trip_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) +} + +func newFakeResolverWithExplicitError(err error) model.Resolver { runtimex.PanicIfNil(err, "passed nil error") return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -251,7 +255,7 @@ func NewFakeResolverWithExplicitError(err error) model.Resolver { } } -func NewFakeResolverWithResult(r []string) model.Resolver { +func newFakeResolverWithResult(r []string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return r, nil diff --git a/internal/engine/netx/tracex/saver_test.go b/internal/engine/netx/tracex/saver_test.go index 784e168..3f925e4 100644 --- a/internal/engine/netx/tracex/saver_test.go +++ b/internal/engine/netx/tracex/saver_test.go @@ -3,22 +3,88 @@ package tracex import ( "sync" "testing" + + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestSaver(t *testing.T) { - saver := Saver{} - var wg sync.WaitGroup - const parallel = 10 - wg.Add(parallel) - for idx := 0; idx < parallel; idx++ { - go func() { - saver.Write(&EventReadFromOperation{&EventValue{}}) - wg.Done() - }() - } - wg.Wait() - ev := saver.Read() - if len(ev) != parallel { - t.Fatal("unexpected number of events read") - } + t.Run("concurrent writes followed by read", func(t *testing.T) { + saver := Saver{} + var wg sync.WaitGroup + const parallel = 10 + wg.Add(parallel) + for idx := 0; idx < parallel; idx++ { + go func() { + saver.Write(&EventReadFromOperation{&EventValue{}}) + wg.Done() + }() + } + wg.Wait() + ev := saver.Read() + if len(ev) != parallel { + t.Fatal("unexpected number of events read") + } + }) + + t.Run("NewConnectObserver", func(t *testing.T) { + t.Run("nil Saver", func(t *testing.T) { + var saver *Saver + obs := saver.NewConnectObserver() + if obs != nil { + t.Fatal("expected nil observer") + } + }) + + t.Run("nonnnil Saver", func(t *testing.T) { + saver := &Saver{} + obs := saver.NewConnectObserver() + underlying := obs.(*dialerConnectObserver) + if underlying.saver != saver { + t.Fatal("invalid saver") + } + }) + }) + + t.Run("NewReadWriteObserver", func(t *testing.T) { + t.Run("nil Saver", func(t *testing.T) { + var saver *Saver + obs := saver.NewReadWriteObserver() + if obs != nil { + t.Fatal("expected nil observer") + } + }) + + t.Run("nonnnil Saver", func(t *testing.T) { + saver := &Saver{} + obs := saver.NewReadWriteObserver() + underlying := obs.(*dialerReadWriteObserver) + if underlying.saver != saver { + t.Fatal("invalid saver") + } + }) + }) + + t.Run("WrapQUICDialer", func(t *testing.T) { + t.Run("nil Saver", func(t *testing.T) { + var saver *Saver + base := &mocks.QUICDialer{} + qd := saver.WrapQUICDialer(base) + if qd != base { + t.Fatal("unexpected returned QUICDialer") + } + }) + + t.Run("nonnnil Saver", func(t *testing.T) { + saver := &Saver{} + base := &mocks.QUICDialer{} + qd := saver.WrapQUICDialer(base) + underlying := qd.(*QUICDialerSaver) + if underlying.Saver != saver { + t.Fatal("invalid Saver") + } + if underlying.QUICDialer != base { + t.Fatal("invalid QUICDialer") + } + }) + }) } diff --git a/internal/engine/netx/tracex/tls.go b/internal/engine/netx/tracex/tls.go index 32eb51c..02ea359 100644 --- a/internal/engine/netx/tracex/tls.go +++ b/internal/engine/netx/tracex/tls.go @@ -16,8 +16,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// SaverTLSHandshaker saves events occurring during the TLS handshake. -type SaverTLSHandshaker struct { +// TLSHandshakerSaver saves events occurring during the TLS handshake. +type TLSHandshakerSaver struct { // TLSHandshaker is the underlying TLS handshaker. TLSHandshaker model.TLSHandshaker @@ -34,23 +34,26 @@ func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker { if s == nil { return thx } - return &SaverTLSHandshaker{ + return &TLSHandshakerSaver{ TLSHandshaker: thx, Saver: s, } } // Handshake implements model.TLSHandshaker.Handshake -func (h *SaverTLSHandshaker) Handshake( +func (h *TLSHandshakerSaver) Handshake( ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + proto := conn.RemoteAddr().Network() + remoteAddr := conn.RemoteAddr().String() start := time.Now() h.Saver.Write(&EventTLSHandshakeStart{&EventValue{ + Address: remoteAddr, NoTLSVerify: config.InsecureSkipVerify, + Proto: proto, TLSNextProtos: config.NextProtos, TLSServerName: config.ServerName, Time: start, }}) - remoteAddr := conn.RemoteAddr().String() tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) stop := time.Now() h.Saver.Write(&EventTLSHandshakeDone{&EventValue{ @@ -58,6 +61,7 @@ func (h *SaverTLSHandshaker) Handshake( Duration: stop.Sub(start), Err: err, NoTLSVerify: config.InsecureSkipVerify, + Proto: proto, TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSNegotiatedProto: state.NegotiatedProtocol, TLSNextProtos: config.NextProtos, @@ -69,7 +73,7 @@ func (h *SaverTLSHandshaker) Handshake( return tlsconn, state, err } -var _ model.TLSHandshaker = &SaverTLSHandshaker{} +var _ model.TLSHandshaker = &TLSHandshakerSaver{} // tlsPeerCerts returns the certificates presented by the peer regardless // of whether the TLS handshake was successful diff --git a/internal/engine/netx/tracex/tls_test.go b/internal/engine/netx/tracex/tls_test.go index 96da0c6..fc0eaf5 100644 --- a/internal/engine/netx/tracex/tls_test.go +++ b/internal/engine/netx/tracex/tls_test.go @@ -3,291 +3,250 @@ package tracex import ( "context" "crypto/tls" - "reflect" + "crypto/x509" + "errors" + "net" "testing" - "time" - "github.com/ooni/probe-cli/v3/internal/model" - "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) -func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { - // This is the most common use case for collecting reads, writes - if testing.Short() { - t.Skip("skip test in short mode") - } - nextprotos := []string{"h2"} - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: netxlite.NewDialerWithResolver( - model.DiscardLogger, - netxlite.NewResolverStdlib(model.DiscardLogger), - saver.NewReadWriteObserver(), - ), - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - // Implementation note: we don't close the connection here because it is - // very handy to have the last event being the end of the handshake - _, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - ev := saver.Read() - if len(ev) < 4 { - // it's a bit tricky to be sure about the right number of - // events because network conditions may influence that - t.Fatal("unexpected number of events") - } - if ev[0].Name() != "tls_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Value().Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - last := len(ev) - 1 - for idx := 1; idx < last; idx++ { - if ev[idx].Value().Data == nil { - t.Fatal("unexpected Data") +func TestTLSHandshakerSaver(t *testing.T) { + + t.Run("Handshake", func(t *testing.T) { + checkStartEventFields := func(t *testing.T, value *EventValue) { + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if !value.NoTLSVerify { + t.Fatal("expected NoTLSVerify to be true") + } + if value.Proto != "tcp" { + t.Fatal("wrong protocol") + } + if diff := cmp.Diff(value.TLSNextProtos, []string{"h2"}); diff != "" { + t.Fatal(diff) + } + if value.TLSServerName != "dns.google" { + t.Fatal("invalid TLSServerName") + } + if value.Time.IsZero() { + t.Fatal("expected non zero time") + } } - if ev[idx].Value().Duration <= 0 { - t.Fatal("unexpected Duration") + + checkStartedEvent := func(t *testing.T, ev Event) { + if _, good := ev.(*EventTLSHandshakeStart); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) } - if ev[idx].Value().Err != nil { - t.Fatal("unexpected Err") + + checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err != nil { + t.Fatal("expected no error here") + } + if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" { + t.Fatal("invalid cipher suite") + } + if value.TLSNegotiatedProto != "h2" { + t.Fatal("invalid negotiated protocol") + } + if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" { + t.Fatal(diff) + } + if value.TLSVersion != "TLSv1.3" { + t.Fatal("invalid TLS version") + } } - if ev[idx].Value().NumBytes <= 0 { - t.Fatal("unexpected NumBytes") + + checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) { + if _, good := ev.(*EventTLSHandshakeDone); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) + fun(t, value) } - switch ev[idx].Name() { - case netxlite.ReadOperation, netxlite.WriteOperation: - default: - t.Fatal("unexpected Name") + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + returnedConnState := tls.ConnectionState{ + CipherSuite: tls.TLS_RSA_WITH_RC4_128_SHA, + NegotiatedProtocol: "h2", + PeerCertificates: []*x509.Certificate{}, + Version: tls.VersionTLS13, + } + returnedConn := &mocks.TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return returnedConnState + }, + } + thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, + config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return returnedConn, returnedConnState, nil + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + ServerName: "dns.google", + } + tcpConn := &mocks.Conn{ + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess) + }) + + checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err == nil { + t.Fatal("expected non-nil error here") + } + if value.TLSCipherSuite != "" { + t.Fatal("invalid TLS cipher suite") + } + if value.TLSNegotiatedProto != "" { + t.Fatal("invalid negotiated proto") + } + if len(value.TLSPeerCerts) > 0 { + t.Fatal("expected no peer certs") + } + if value.TLSVersion != "" { + t.Fatal("invalid TLS version") + } } - if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) { - t.Fatal("unexpected Time") - } - } - if ev[last].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[last].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[last].Name() != "tls_handshake_done" { - t.Fatal("unexpected Name") - } - if ev[last].Value().TLSCipherSuite == "" { - t.Fatal("unexpected TLSCipherSuite") - } - if ev[last].Value().TLSNegotiatedProto != "h2" { - t.Fatal("unexpected TLSNegotiatedProto") - } - if !reflect.DeepEqual(ev[last].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[last].Value().TLSPeerCerts == nil { - t.Fatal("unexpected TLSPeerCerts") - } - if ev[last].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if ev[last].Value().TLSVersion == "" { - t.Fatal("unexpected TLSVersion") - } - if ev[last].Value().Time.Before(ev[last-1].Value().Time) { - t.Fatal("unexpected Time") - } + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, + config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expected + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + ServerName: "dns.google", + } + tcpConn := &mocks.Conn{ + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsFailure) + }) + }) } -func TestSaverTLSHandshakerSuccess(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") +func Test_tlsPeerCerts(t *testing.T) { + cert0 := &x509.Certificate{Raw: []byte{1, 2, 3, 4}} + type args struct { + state tls.ConnectionState + err error } - nextprotos := []string{"h2"} - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - conn.Close() - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("unexpected number of events") - } - if ev[0].Name() != "tls_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Value().Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "tls_handshake_done" { - t.Fatal("unexpected Name") - } - if ev[1].Value().TLSCipherSuite == "" { - t.Fatal("unexpected TLSCipherSuite") - } - if ev[1].Value().TLSNegotiatedProto != "h2" { - t.Fatal("unexpected TLSNegotiatedProto") - } - if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[1].Value().TLSPeerCerts == nil { - t.Fatal("unexpected TLSPeerCerts") - } - if ev[1].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if ev[1].Value().TLSVersion == "" { - t.Fatal("unexpected TLSVersion") - } - if ev[1].Value().Time.Before(ev[0].Value().Time) { - t.Fatal("unexpected Time") - } -} - -func TestSaverTLSHandshakerHostnameError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "wrong.host.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "expired.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerAuthorityError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "self-signed.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{InsecureSkipVerify: true}, - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "self-signed.badssl.com:443") - if err != nil { - t.Fatal(err) - } - if conn == nil { - t.Fatal("expected non-nil conn here") - } - conn.Close() - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify != true { - t.Fatal("expected NoTLSVerify to be true") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } + tests := []struct { + name string + args args + want []*x509.Certificate + }{{ + name: "no error", + args: args{ + state: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert0}, + }, + }, + want: []*x509.Certificate{cert0}, + }, { + name: "all empty", + args: args{}, + want: nil, + }, { + name: "x509.HostnameError", + args: args{ + state: tls.ConnectionState{}, + err: x509.HostnameError{ + Certificate: cert0, + }, + }, + want: []*x509.Certificate{cert0}, + }, { + name: "x509.UnknownAuthorityError", + args: args{ + state: tls.ConnectionState{}, + err: x509.UnknownAuthorityError{ + Cert: cert0, + }, + }, + want: []*x509.Certificate{cert0}, + }, { + name: "x509.CertificateInvalidError", + args: args{ + state: tls.ConnectionState{}, + err: x509.CertificateInvalidError{ + Cert: cert0, + }, + }, + want: []*x509.Certificate{cert0}, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tlsPeerCerts(tt.args.state, tt.args.err) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) } }