diff --git a/internal/engine/experiment/urlgetter/configurer_test.go b/internal/engine/experiment/urlgetter/configurer_test.go index bdb5212..bc3ff2a 100644 --- a/internal/engine/experiment/urlgetter/configurer_test.go +++ b/internal/engine/experiment/urlgetter/configurer_test.go @@ -119,7 +119,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } @@ -195,7 +195,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } @@ -271,7 +271,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T) if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } @@ -347,7 +347,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) + stxp, ok := sr.Txp.(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the DNS transport we expected") } diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 3b07e56..29ae609 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -187,7 +187,7 @@ func NewHTTPTransport(config Config) model.HTTPTransport { txp = &netxlite.HTTPTransportLogger{Logger: config.Logger, HTTPTransport: txp} } if config.HTTPSaver != nil { - txp = &tracex.SaverTransactionHTTPTransport{ + txp = &tracex.HTTPTransportSaver{ HTTPTransport: txp, Saver: config.HTTPSaver} } return txp diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index d51bdf8..8a454fe 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -126,7 +126,7 @@ func TestNewResolverWithSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - sr, ok := ir.Resolver.(*tracex.SaverResolver) + sr, ok := ir.Resolver.(*tracex.ResolverSaver) if !ok { t.Fatal("not the resolver we expected") } @@ -332,7 +332,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) { if rtd.TLSHandshaker == nil { t.Fatal("invalid TLSHandshaker") } - sth, ok := rtd.TLSHandshaker.(*tracex.SaverTLSHandshaker) + sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver) if !ok { t.Fatal("not the TLSHandshaker we expected") } @@ -504,7 +504,7 @@ func TestNewWithSaver(t *testing.T) { txp := netx.NewHTTPTransport(netx.Config{ HTTPSaver: saver, }) - stxptxp, ok := txp.(*tracex.SaverTransactionHTTPTransport) + stxptxp, ok := txp.(*tracex.HTTPTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -622,7 +622,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -659,7 +659,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -700,7 +700,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } @@ -745,7 +745,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(*tracex.SaverDNSTransport) + txp, ok := r.Transport().(*tracex.DNSTransportSaver) if !ok { t.Fatal("not the transport we expected") } diff --git a/internal/engine/netx/tracex/archival.go b/internal/engine/netx/tracex/archival.go index 15dd5e2..9f9a3e6 100644 --- a/internal/engine/netx/tracex/archival.go +++ b/internal/engine/netx/tracex/archival.go @@ -1,5 +1,9 @@ package tracex +// +// Code to generate the OONI archival data format from events +// + import ( "crypto/x509" "errors" @@ -43,8 +47,7 @@ var ( ) // NewTCPConnectList creates a new TCPConnectList -func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry { - var out []TCPConnectEntry +func NewTCPConnectList(begin time.Time, events []Event) (out []TCPConnectEntry) { for _, wrapper := range events { if _, ok := wrapper.(*EventConnectOperation); !ok { continue @@ -60,13 +63,14 @@ func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry { IP: ip, Port: iport, Status: TCPConnectStatus{ + Blocked: nil, // only used by Web Connectivity Failure: NewFailure(event.Err), Success: event.Err == nil, }, T: event.Time.Sub(begin).Seconds(), }) } - return out + return } // NewFailure creates a failure nullable string from the given error @@ -101,11 +105,9 @@ func NewFailedOperation(err error) *string { return &s } -func httpAddHeaders( - source http.Header, - destList *[]HTTPHeader, - destMap *map[string]MaybeBinaryValue, -) { +// httpAddHeaders adds the headers inside source into destList and destMap. +func httpAddHeaders(source http.Header, destList *[]HTTPHeader, + destMap *map[string]MaybeBinaryValue) { *destList = []HTTPHeader{} *destMap = make(map[string]model.ArchivalMaybeBinaryData) for key, values := range source { @@ -122,32 +124,28 @@ func httpAddHeaders( }) } } + // Sorting helps with unit testing (map keys are unordered) sort.Slice(*destList, func(i, j int) bool { return (*destList)[i].Key < (*destList)[j].Key }) } // NewRequestList returns the list for "requests" -func NewRequestList(begin time.Time, events []Event) []RequestEntry { +func NewRequestList(begin time.Time, events []Event) (out []RequestEntry) { // OONI wants the last request to appear first - var out []RequestEntry tmp := newRequestList(begin, events) for i := len(tmp) - 1; i >= 0; i-- { out = append(out, tmp[i]) } - return out + return } -func newRequestList(begin time.Time, events []Event) []RequestEntry { - var ( - out []RequestEntry - entry RequestEntry - ) +func newRequestList(begin time.Time, events []Event) (out []RequestEntry) { for _, wrapper := range events { ev := wrapper.Value() switch wrapper.(type) { case *EventHTTPTransactionDone: - entry = RequestEntry{} + entry := RequestEntry{} entry.T = ev.Time.Sub(begin).Seconds() httpAddHeaders( ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers) @@ -164,15 +162,14 @@ func newRequestList(begin time.Time, events []Event) []RequestEntry { out = append(out, entry) } } - return out + return } type dnsQueryType string // NewDNSQueriesList returns a list of DNS queries. -func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { +func NewDNSQueriesList(begin time.Time, events []Event) (out []DNSQueryEntry) { // TODO(bassosimone): add support for CNAME lookups. - var out []DNSQueryEntry for _, wrapper := range events { if _, ok := wrapper.(*EventResolveDone); !ok { continue @@ -199,7 +196,7 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { out = append(out, entry) } } - return out + return } func (qtype dnsQueryType) ipOfType(addr string) bool { @@ -214,6 +211,8 @@ func (qtype dnsQueryType) ipOfType(addr string) bool { func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry { answer := DNSAnswerEntry{AnswerType: string(qtype)} + // Figuring out the ASN and the org here is not just a service to whoever + // is reading a JSON: Web Connectivity also depends on it! asn, org, _ := geolocate.LookupASN(addr) answer.ASN = int64(asn) answer.ASOrgName = org @@ -237,9 +236,8 @@ func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev *EventValue) DNSQue } } -// NewNetworkEventsList returns a list of DNS queries. -func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { - var out []NetworkEvent +// NewNetworkEventsList returns a list of network events. +func NewNetworkEventsList(begin time.Time, events []Event) (out []NetworkEvent) { for _, wrapper := range events { ev := wrapper.Value() switch wrapper.(type) { @@ -281,7 +279,7 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { NumBytes: int64(ev.NumBytes), T: ev.Time.Sub(begin).Seconds(), }) - default: + default: // For example, "tls_handshake_done" (used in data analysis!) out = append(out, NetworkEvent{ Failure: NewFailure(ev.Err), Operation: wrapper.Name(), @@ -289,15 +287,14 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { }) } } - return out + return } // NewTLSHandshakesList creates a new TLSHandshakesList -func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { - var out []TLSHandshake +func NewTLSHandshakesList(begin time.Time, events []Event) (out []TLSHandshake) { for _, wrapper := range events { switch wrapper.(type) { - case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // ok + case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // interested default: continue // not interested } @@ -314,12 +311,12 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { TLSVersion: ev.TLSVersion, }) } - return out + return } func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) { - for _, e := range in { - out = append(out, MaybeBinaryValue{Value: string(e.Raw)}) + for _, entry := range in { + out = append(out, MaybeBinaryValue{Value: string(entry.Raw)}) } return } diff --git a/internal/engine/netx/tracex/archival_test.go b/internal/engine/netx/tracex/archival_test.go index 5b9d8ed..ade7391 100644 --- a/internal/engine/netx/tracex/archival_test.go +++ b/internal/engine/netx/tracex/archival_test.go @@ -6,7 +6,6 @@ import ( "errors" "io" "net/http" - "reflect" "testing" "time" @@ -15,42 +14,44 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestDNSQueryIPOfType(t *testing.T) { - type expectation struct { - qtype dnsQueryType - ip string - output bool - } - var expectations = []expectation{{ - qtype: "A", - ip: "8.8.8.8", - output: true, - }, { - qtype: "A", - ip: "2a00:1450:4002:801::2004", - output: false, - }, { - qtype: "AAAA", - ip: "8.8.8.8", - output: false, - }, { - qtype: "AAAA", - ip: "2a00:1450:4002:801::2004", - output: true, - }, { - qtype: "ANTANI", - ip: "2a00:1450:4002:801::2004", - output: false, - }, { - qtype: "ANTANI", - ip: "8.8.8.8", - output: false, - }} - for _, exp := range expectations { - if exp.qtype.ipOfType(exp.ip) != exp.output { - t.Fatalf("failure for %+v", exp) +func TestDNSQueryType(t *testing.T) { + t.Run("ipOfType", func(t *testing.T) { + type expectation struct { + qtype dnsQueryType + ip string + output bool } - } + var expectations = []expectation{{ + qtype: "A", + ip: "8.8.8.8", + output: true, + }, { + qtype: "A", + ip: "2a00:1450:4002:801::2004", + output: false, + }, { + qtype: "AAAA", + ip: "8.8.8.8", + output: false, + }, { + qtype: "AAAA", + ip: "2a00:1450:4002:801::2004", + output: true, + }, { + qtype: "ANTANI", + ip: "2a00:1450:4002:801::2004", + output: false, + }, { + qtype: "ANTANI", + ip: "8.8.8.8", + output: false, + }} + for _, exp := range expectations { + if exp.qtype.ipOfType(exp.ip) != exp.output { + t.Fatalf("failure for %+v", exp) + } + } + }) } func TestNewTCPConnectList(t *testing.T) { @@ -74,7 +75,7 @@ func TestNewTCPConnectList(t *testing.T) { name: "realistic run", args: args{ begin: begin, - events: []Event{&EventResolveDone{&EventValue{ + events: []Event{&EventResolveDone{&EventValue{ // skipped because not relevant Addresses: []string{"8.8.8.8", "8.8.4.4"}, Hostname: "dns.google.com", Time: begin.Add(100 * time.Millisecond), @@ -86,7 +87,7 @@ func TestNewTCPConnectList(t *testing.T) { }}, &EventConnectOperation{&EventValue{ Address: "8.8.8.8:853", Duration: 55 * time.Millisecond, - Proto: "udp", + Proto: "udp", // this one should be skipped because it's UDP Time: begin.Add(130 * time.Millisecond), }}, &EventConnectOperation{&EventValue{ Address: "8.8.4.4:53", @@ -115,8 +116,9 @@ func TestNewTCPConnectList(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewTCPConnectList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(got, tt.want)) + got := NewTCPConnectList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } @@ -143,6 +145,7 @@ func TestNewRequestList(t *testing.T) { name: "realistic run", args: args{ begin: begin, + // Two round trips so we can test the sorting expected by OONI events: []Event{&EventHTTPTransactionDone{&EventValue{ HTTPRequestHeaders: http.Header{ "User-Agent": []string{"miniooni/0.1.0-dev"}, @@ -286,8 +289,9 @@ func TestNewRequestList(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewRequestList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(tt.want, got)) + got := NewRequestList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } @@ -320,12 +324,12 @@ func TestNewDNSQueriesList(t *testing.T) { Hostname: "dns.google.com", Proto: "dot", Time: begin.Add(100 * time.Millisecond), - }}, &EventConnectOperation{&EventValue{ + }}, &EventConnectOperation{&EventValue{ // skipped because not relevant Address: "8.8.8.8:853", Duration: 30 * time.Millisecond, Proto: "tcp", Time: begin.Add(130 * time.Millisecond), - }}, &EventConnectOperation{&EventValue{ + }}, &EventConnectOperation{&EventValue{ // skipped because not relevant Address: "8.8.4.4:53", Duration: 50 * time.Millisecond, Err: io.EOF, @@ -452,6 +456,10 @@ func TestNewNetworkEventsList(t *testing.T) { Err: websocket.ErrBadHandshake, NumBytes: 4114, Time: begin.Add(14 * time.Millisecond), + }}, &EventResolveStart{&EventValue{ + // We expect this event to be logged event though it's not a typical I/O + // event (it seems these extra events are useful for debugging) + Time: begin.Add(15 * time.Millisecond), }}}, }, want: []NetworkEvent{{ @@ -482,12 +490,16 @@ func TestNewNetworkEventsList(t *testing.T) { NumBytes: 4114, Operation: netxlite.WriteToOperation, T: 0.014, + }, { + Operation: "resolve_start", + T: 0.015, }}, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewNetworkEventsList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(got, tt.want)) + got := NewNetworkEventsList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } @@ -511,13 +523,14 @@ func TestNewTLSHandshakesList(t *testing.T) { }, want: nil, }, { - name: "realistic run", + name: "realistic run with TLS", args: args{ begin: begin, events: []Event{&EventTLSHandshakeDone{&EventValue{ Address: "131.252.210.176:443", Err: io.EOF, NoTLSVerify: false, + Proto: "tcp", TLSCipherSuite: "SUITE", TLSNegotiatedProto: "h2", TLSPeerCerts: []*x509.Certificate{{ @@ -545,11 +558,57 @@ func TestNewTLSHandshakesList(t *testing.T) { T: 0.055, TLSVersion: "TLSv1.3", }}, + }, { + name: "realistic run with QUIC", + args: args{ + begin: begin, + events: []Event{&EventQUICHandshakeDone{&EventValue{ + Address: "131.252.210.176:443", + Err: io.EOF, + NoTLSVerify: false, + Proto: "quic", + TLSCipherSuite: "SUITE", + TLSNegotiatedProto: "h3", + TLSPeerCerts: []*x509.Certificate{{ + Raw: []byte("deadbeef"), + }, { + Raw: []byte("abad1dea"), + }}, + TLSServerName: "x.org", + TLSVersion: "TLSv1.3", + Time: begin.Add(55 * time.Millisecond), + }}}, + }, + want: []TLSHandshake{{ + Address: "131.252.210.176:443", + CipherSuite: "SUITE", + Failure: NewFailure(io.EOF), + NegotiatedProtocol: "h3", + NoTLSVerify: false, + PeerCertificates: []MaybeBinaryValue{{ + Value: "deadbeef", + }, { + Value: "abad1dea", + }}, + ServerName: "x.org", + T: 0.055, + TLSVersion: "TLSv1.3", + }}, + }, { + name: "realistic run with no suitable events", + args: args{ + begin: begin, + events: []Event{&EventResolveStart{&EventValue{ + Time: begin.Add(55 * time.Millisecond), + }}}, + }, + want: nil, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewTLSHandshakesList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { - t.Error(cmp.Diff(got, tt.want)) + got := NewTLSHandshakesList(tt.args.begin, tt.args.events) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) } }) } diff --git a/internal/engine/netx/tracex/dialer.go b/internal/engine/netx/tracex/dialer.go index 5096b64..8820850 100644 --- a/internal/engine/netx/tracex/dialer.go +++ b/internal/engine/netx/tracex/dialer.go @@ -12,8 +12,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// SaverDialer saves events occurring during the dial -type SaverDialer struct { +// DialerSaver saves events occurring during the dial +type DialerSaver struct { // Dialer is the underlying dialer, Dialer model.Dialer @@ -28,26 +28,26 @@ func (s *Saver) NewConnectObserver() model.DialerWrapper { if s == nil { return nil // valid DialerWrapper according to netxlite's docs } - return &saverDialerWrapper{ + return &dialerConnectObserver{ saver: s, } } -type saverDialerWrapper struct { +type dialerConnectObserver struct { saver *Saver } -var _ model.DialerWrapper = &saverDialerWrapper{} +var _ model.DialerWrapper = &dialerConnectObserver{} -func (w *saverDialerWrapper) WrapDialer(d model.Dialer) model.Dialer { - return &SaverDialer{ +func (w *dialerConnectObserver) WrapDialer(d model.Dialer) model.Dialer { + return &DialerSaver{ Dialer: d, Saver: w.saver, } } // DialContext implements Dialer.DialContext -func (d *SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { start := time.Now() conn, err := d.Dialer.DialContext(ctx, network, address) stop := time.Now() @@ -61,13 +61,13 @@ func (d *SaverDialer) DialContext(ctx context.Context, network, address string) return conn, err } -func (d *SaverDialer) CloseIdleConnections() { +func (d *DialerSaver) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -// SaverConnDialer wraps the returned connection such that we +// DialerConnSaver wraps the returned connection such that we // collect all the read/write events that occur. -type SaverConnDialer struct { +type DialerConnSaver struct { // Dialer is the underlying dialer Dialer model.Dialer @@ -82,70 +82,78 @@ func (s *Saver) NewReadWriteObserver() model.DialerWrapper { if s == nil { return nil // valid DialerWrapper according to netxlite's docs } - return &saverReadWriteWrapper{ + return &dialerReadWriteObserver{ saver: s, } } -type saverReadWriteWrapper struct { +type dialerReadWriteObserver struct { saver *Saver } -var _ model.DialerWrapper = &saverReadWriteWrapper{} +var _ model.DialerWrapper = &dialerReadWriteObserver{} -func (w *saverReadWriteWrapper) WrapDialer(d model.Dialer) model.Dialer { - return &SaverConnDialer{ +func (w *dialerReadWriteObserver) WrapDialer(d model.Dialer) model.Dialer { + return &DialerConnSaver{ Dialer: d, Saver: w.saver, } } // DialContext implements Dialer.DialContext -func (d *SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *DialerConnSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { return nil, err } - return &saverConn{saver: d.Saver, Conn: conn}, nil + return &dialerConnWrapper{saver: d.Saver, Conn: conn}, nil } -func (d *SaverConnDialer) CloseIdleConnections() { +func (d *DialerConnSaver) CloseIdleConnections() { d.Dialer.CloseIdleConnections() } -type saverConn struct { +type dialerConnWrapper struct { net.Conn saver *Saver } -func (c *saverConn) Read(p []byte) (int, error) { +func (c *dialerConnWrapper) Read(p []byte) (int, error) { + proto := c.Conn.RemoteAddr().Network() + remoteAddr := c.Conn.RemoteAddr().String() start := time.Now() count, err := c.Conn.Read(p) stop := time.Now() c.saver.Write(&EventReadOperation{&EventValue{ + Address: remoteAddr, Data: p[:count], Duration: stop.Sub(start), Err: err, NumBytes: count, + Proto: proto, Time: stop, }}) return count, err } -func (c *saverConn) Write(p []byte) (int, error) { +func (c *dialerConnWrapper) Write(p []byte) (int, error) { + proto := c.Conn.RemoteAddr().Network() + remoteAddr := c.Conn.RemoteAddr().String() start := time.Now() count, err := c.Conn.Write(p) stop := time.Now() c.saver.Write(&EventWriteOperation{&EventValue{ + Address: remoteAddr, Data: p[:count], Duration: stop.Sub(start), Err: err, NumBytes: count, + Proto: proto, Time: stop, }}) return count, err } -var _ model.Dialer = &SaverDialer{} -var _ model.Dialer = &SaverConnDialer{} -var _ net.Conn = &saverConn{} +var _ model.Dialer = &DialerSaver{} +var _ model.Dialer = &DialerConnSaver{} +var _ net.Conn = &dialerConnWrapper{} diff --git a/internal/engine/netx/tracex/dialer_test.go b/internal/engine/netx/tracex/dialer_test.go index 8608faa..2447ce4 100644 --- a/internal/engine/netx/tracex/dialer_test.go +++ b/internal/engine/netx/tracex/dialer_test.go @@ -12,127 +12,268 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestSaverDialerFailure(t *testing.T) { - expected := errors.New("mocked error") +func TestDialerConnectObserver(t *testing.T) { saver := &Saver{} - dlr := &SaverDialer{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }, - }, - Saver: saver, + obs := &dialerConnectObserver{ + saver: saver, } - conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, expected) { - t.Fatal("expected another error here") + dialer := &mocks.Dialer{} + out := obs.WrapDialer(dialer) + dialSaver := out.(*DialerSaver) + if dialSaver.Dialer != dialer { + t.Fatal("invalid dialer") } - if conn != nil { - t.Fatal("expected nil conn here") - } - ev := saver.Read() - if len(ev) != 1 { - t.Fatal("expected a single event here") - } - if ev[0].Value().Address != "www.google.com:443" { - t.Fatal("unexpected Address") - } - if ev[0].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if !errors.Is(ev[0].Value().Err, expected) { - t.Fatal("unexpected Err") - } - if ev[0].Name() != netxlite.ConnectOperation { - t.Fatal("unexpected Name") - } - if ev[0].Value().Proto != "tcp" { - t.Fatal("unexpected Proto") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("unexpected Time") + if dialSaver.Saver != saver { + t.Fatal("invalid saver") } } -func TestSaverConnDialerFailure(t *testing.T) { - expected := errors.New("mocked error") - saver := &Saver{} - dlr := &SaverConnDialer{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return nil, expected - }, - }, - Saver: saver, - } - conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } -} - -func TestSaverConnDialerSuccess(t *testing.T) { - saver := &Saver{} - dlr := &SaverConnDialer{ - Dialer: &SaverDialer{ +func TestDialerSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + dlr := &DialerSaver{ Dialer: &mocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { - return &mocks.Conn{ - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockWrite: func(b []byte) (int, error) { - return 0, io.EOF - }, - MockClose: func() error { - return io.EOF - }, - MockLocalAddr: func() net.Addr { - return &net.TCPAddr{Port: 12345} - }, - }, nil + return nil, expected }, }, Saver: saver, - }, - Saver: saver, - } - conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal("not the error we expected", err) - } - conn.Read(nil) - conn.Write(nil) - conn.Close() - events := saver.Read() - if len(events) != 3 { - t.Fatal("unexpected number of events saved", len(events)) - } - if events[0].Name() != "connect" { - t.Fatal("expected a connect event") - } - saverCheckConnectEvent(t, &events[0]) - if events[1].Name() != "read" { - t.Fatal("expected a read event") - } - saverCheckReadEvent(t, &events[1]) - if events[2].Name() != "write" { - t.Fatal("expected a write event") - } - saverCheckWriteEvent(t, &events[2]) + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, expected) { + t.Fatal("expected another error here") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + ev := saver.Read() + if len(ev) != 1 { + t.Fatal("expected a single event here") + } + if ev[0].Value().Address != "www.google.com:443" { + t.Fatal("unexpected Address") + } + if ev[0].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[0].Value().Err, expected) { + t.Fatal("unexpected Err") + } + if ev[0].Name() != netxlite.ConnectOperation { + t.Fatal("unexpected Name") + } + if ev[0].Value().Proto != "tcp" { + t.Fatal("unexpected Proto") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("unexpected Time") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := &DialerSaver{ + Dialer: child, + Saver: &Saver{}, + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) } -func saverCheckConnectEvent(t *testing.T, ev *Event) { - // TODO(bassosimone): implement +func TestDialerReadWriteObserver(t *testing.T) { + saver := &Saver{} + obs := &dialerReadWriteObserver{ + saver: saver, + } + dialer := &mocks.Dialer{} + out := obs.WrapDialer(dialer) + dialSaver := out.(*DialerConnSaver) + if dialSaver.Dialer != dialer { + t.Fatal("invalid dialer") + } + if dialSaver.Saver != saver { + t.Fatal("invalid saver") + } } -func saverCheckReadEvent(t *testing.T, ev *Event) { - // TODO(bassosimone): implement +func TestDialerConnSaver(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + dlr := &DialerConnSaver{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, + }, + Saver: saver, + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + + t.Run("on success", func(t *testing.T) { + origConn := &mocks.Conn{} + saver := &Saver{} + dlr := &DialerConnSaver{ + Dialer: &DialerSaver{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return origConn, nil + }, + }, + Saver: saver, + }, + Saver: saver, + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if err != nil { + t.Fatal("not the error we expected", err) + } + cw := conn.(*dialerConnWrapper) + if cw.Conn != origConn { + t.Fatal("unexpected conn") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := &DialerConnSaver{ + Dialer: child, + Saver: &Saver{}, + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) } -func saverCheckWriteEvent(t *testing.T, ev *Event) { - // TODO(bassosimone): implement +func TestDialerConnWrapper(t *testing.T) { + t.Run("Read", func(t *testing.T) { + baseConn := &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "www.google.com:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + saver := &Saver{} + conn := &dialerConnWrapper{ + Conn: baseConn, + saver: saver, + } + data := make([]byte, 155) + count, err := conn.Read(data) + if !errors.Is(err, io.EOF) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + ev := saver.Read() + if len(ev) != 1 { + t.Fatal("expected a single event here") + } + if ev[0].Value().Address != "www.google.com:443" { + t.Fatal("unexpected Address") + } + if ev[0].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[0].Value().Err, io.EOF) { + t.Fatal("unexpected Err") + } + if ev[0].Name() != netxlite.ReadOperation { + t.Fatal("unexpected Name") + } + if ev[0].Value().Proto != "tcp" { + t.Fatal("unexpected Proto") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("unexpected Time") + } + }) + + t.Run("Write", func(t *testing.T) { + baseConn := &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "www.google.com:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + saver := &Saver{} + conn := &dialerConnWrapper{ + Conn: baseConn, + saver: saver, + } + data := make([]byte, 155) + count, err := conn.Write(data) + if !errors.Is(err, io.EOF) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("unexpected count") + } + ev := saver.Read() + if len(ev) != 1 { + t.Fatal("expected a single event here") + } + if ev[0].Value().Address != "www.google.com:443" { + t.Fatal("unexpected Address") + } + if ev[0].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[0].Value().Err, io.EOF) { + t.Fatal("unexpected Err") + } + if ev[0].Name() != netxlite.WriteOperation { + t.Fatal("unexpected Name") + } + if ev[0].Value().Proto != "tcp" { + t.Fatal("unexpected Proto") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("unexpected Time") + } + }) } diff --git a/internal/engine/netx/tracex/event.go b/internal/engine/netx/tracex/event.go index 41a23af..48f2b71 100644 --- a/internal/engine/netx/tracex/event.go +++ b/internal/engine/netx/tracex/event.go @@ -1,5 +1,9 @@ package tracex +// +// All the possible events +// + import ( "crypto/x509" "net/http" diff --git a/internal/engine/netx/tracex/http.go b/internal/engine/netx/tracex/http.go index 744a305..9d745aa 100644 --- a/internal/engine/netx/tracex/http.go +++ b/internal/engine/netx/tracex/http.go @@ -27,11 +27,17 @@ func httpCloneRequestHeaders(req *http.Request) http.Header { return header } -// SaverTransactionHTTPTransport is a RoundTripper that saves +// HTTPTransportSaver is a RoundTripper that saves // events related to the HTTP transaction -type SaverTransactionHTTPTransport struct { - model.HTTPTransport - Saver *Saver +type HTTPTransportSaver struct { + // HTTPTransport is the MANDATORY underlying HTTP transport. + HTTPTransport model.HTTPTransport + + // Saver is the MANDATORY saver to use. + Saver *Saver + + // SnapshotSize is the OPTIONAL maximum body snapshot size (if not set, we'll + // use 1<<17, which we've been using since the ooni/netx days) SnapshotSize int64 } @@ -40,7 +46,11 @@ type SaverTransactionHTTPTransport struct { // // The maxBodySnapshotSize argument controls the maximum size of the // body snapshot that we collect along with the HTTP round trip. -func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { +func (txp *HTTPTransportSaver) RoundTrip(req *http.Request) (*http.Response, error) { + + // TODO(bassosimone): we're currently using the started time for + // the transaction done event, which contrasts with what we do for + // every other event. What does the spec say? started := time.Now() txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{ @@ -92,7 +102,15 @@ func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Re return resp, nil } -func (txp *SaverTransactionHTTPTransport) snapshotSize() int64 { +func (txp *HTTPTransportSaver) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() +} + +func (txp *HTTPTransportSaver) Network() string { + return txp.HTTPTransport.Network() +} + +func (txp *HTTPTransportSaver) snapshotSize() int64 { if txp.SnapshotSize > 0 { return txp.SnapshotSize } @@ -104,4 +122,4 @@ type httpReadableAgainBody struct { io.Closer } -var _ model.HTTPTransport = &SaverTransactionHTTPTransport{} +var _ model.HTTPTransport = &HTTPTransportSaver{} diff --git a/internal/engine/netx/tracex/http_test.go b/internal/engine/netx/tracex/http_test.go index 1ed764f..0efcae3 100644 --- a/internal/engine/netx/tracex/http_test.go +++ b/internal/engine/netx/tracex/http_test.go @@ -16,227 +16,262 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" ) -func TestSaverTransactionHTTPTransport(t *testing.T) { +func TestHTTPTransportSaver(t *testing.T) { - startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) { - server := &filtering.HTTPProxy{ - OnIncomingHost: func(host string) filtering.HTTPAction { - return action + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + called = true }, } - listener, err := server.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + dialer := &HTTPTransportSaver{ + HTTPTransport: child, + Saver: &Saver{}, } - URL := &url.URL{ - Scheme: "http", - Host: listener.Addr().String(), - Path: "/", - } - return listener, URL - } - - measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) { - saver := &Saver{} - txp := &SaverTransactionHTTPTransport{ - HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger), - Saver: saver, - } - req, err := http.NewRequest("GET", URL.String(), nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("User-Agent", "miniooni") - resp, err := txp.RoundTrip(req) - return resp, saver, err - } - - validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) { - if value.HTTPMethod != "GET" { - t.Fatal("invalid method") - } - if value.HTTPRequestHeaders.Get("Host") != URL.Host { - t.Fatal("invalid Host header") - } - if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" { - t.Fatal("invalid User-Agent header") - } - if value.HTTPURL != URL.String() { - t.Fatal("invalid URL") - } - if value.Time.IsZero() { - t.Fatal("expected nonzero Time") - } - if value.Transport != "tcp" { - t.Fatal("expected Transport to be tcp") - } - } - - validateRequest := func(t *testing.T, ev Event, URL *url.URL) { - if _, good := ev.(*EventHTTPTransactionStart); !good { - t.Fatal("invalid event type") - } - if ev.Name() != "http_transaction_start" { - t.Fatal("invalid event name") - } - value := ev.Value() - validateRequestFields(t, value, URL) - } - - validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) { - if _, good := ev.(*EventHTTPTransactionDone); !good { - t.Fatal("invalid event type") - } - if ev.Name() != "http_transaction_done" { - t.Fatal("invalid event name") - } - value := ev.Value() - validateRequestFields(t, value, URL) - if value.Duration <= 0 { - t.Fatal("expected nonzero duration") - } - if len(value.HTTPResponseHeaders) <= 0 { - t.Fatal("expected at least one response header") - } - if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) { - t.Fatal("unexpected value for response body") - } - if value.HTTPStatusCode != 451 { - t.Fatal("unexpected status code") - } - } - - t.Run("on success", func(t *testing.T) { - listener, URL := startServer(t, filtering.HTTPAction451) - defer listener.Close() - resp, saver, err := measureHTTP(t, URL) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != 451 { - t.Fatal("unexpected status code", resp.StatusCode) - } - events := saver.Read() - if len(events) != 2 { - t.Fatal("unexpected number of events") - } - validateRequest(t, events[0], URL) - validateResponseSuccess(t, events[1], URL) - data, err := netxlite.ReadAllContext(context.Background(), resp.Body) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, filtering.HTTPBlockpage451) { - t.Fatal("we cannot re-read the same body") + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") } }) - validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) { - if _, good := ev.(*EventHTTPTransactionDone); !good { - t.Fatal("invalid event type") + t.Run("Network", func(t *testing.T) { + expected := "antani" + child := &mocks.HTTPTransport{ + MockNetwork: func() string { + return expected + }, } - if ev.Name() != "http_transaction_done" { - t.Fatal("invalid event name") + dialer := &HTTPTransportSaver{ + HTTPTransport: child, + Saver: &Saver{}, } - value := ev.Value() - validateRequestFields(t, value, URL) - if value.Duration <= 0 { - t.Fatal("expected nonzero duration") + if dialer.Network() != expected { + t.Fatal("unexpected Network") } - if value.Err.Error() != "connection_reset" { - t.Fatal("unexpected Err value") - } - if len(value.HTTPResponseHeaders) > 0 { - t.Fatal("expected zero response headers") - } - if !bytes.Equal(value.HTTPResponseBody, nil) { - t.Fatal("unexpected value for response body") - } - if value.HTTPStatusCode != 0 { - t.Fatal("unexpected status code") - } - } - - t.Run("on round trip failure", func(t *testing.T) { - listener, URL := startServer(t, filtering.HTTPActionReset) - defer listener.Close() - resp, saver, err := measureHTTP(t, URL) - if err == nil || err.Error() != "connection_reset" { - t.Fatal("unexpected err", err) - } - if resp != nil { - t.Fatal("expected nil response") - } - events := saver.Read() - if len(events) != 2 { - t.Fatal("unexpected number of events") - } - validateRequest(t, events[0], URL) - validateResponseFailure(t, events[1], URL) }) - // Sometimes useful for testing - /* - dumplog := func(t *testing.T, ev Event) { - data, _ := json.MarshalIndent(ev.Value(), " ", " ") - t.Log(string(data)) - t.FailNow() + t.Run("RoundTrip", func(t *testing.T) { + startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) { + server := &filtering.HTTPProxy{ + OnIncomingHost: func(host string) filtering.HTTPAction { + return action + }, + } + listener, err := server.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + URL := &url.URL{ + Scheme: "http", + Host: listener.Addr().String(), + Path: "/", + } + return listener, URL } - */ - t.Run("on error reading the response body", func(t *testing.T) { - saver := &Saver{} - expected := errors.New("mocked error") - txp := SaverTransactionHTTPTransport{ - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - Header: http.Header{ - "Server": {"antani"}, - }, - StatusCode: 200, - Body: io.NopCloser(&mocks.Reader{ - MockRead: func(b []byte) (int, error) { - return 0, expected + measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) { + saver := &Saver{} + txp := &HTTPTransportSaver{ + HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger), + Saver: saver, + } + req, err := http.NewRequest("GET", URL.String(), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("User-Agent", "miniooni") + resp, err := txp.RoundTrip(req) + return resp, saver, err + } + + validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) { + if value.HTTPMethod != "GET" { + t.Fatal("invalid method") + } + if value.HTTPRequestHeaders.Get("Host") != URL.Host { + t.Fatal("invalid Host header") + } + if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" { + t.Fatal("invalid User-Agent header") + } + if value.HTTPURL != URL.String() { + t.Fatal("invalid URL") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + if value.Transport != "tcp" { + t.Fatal("expected Transport to be tcp") + } + } + + validateRequest := func(t *testing.T, ev Event, URL *url.URL) { + if _, good := ev.(*EventHTTPTransactionStart); !good { + t.Fatal("invalid event type") + } + if ev.Name() != "http_transaction_start" { + t.Fatal("invalid event name") + } + value := ev.Value() + validateRequestFields(t, value, URL) + } + + validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) { + if _, good := ev.(*EventHTTPTransactionDone); !good { + t.Fatal("invalid event type") + } + if ev.Name() != "http_transaction_done" { + t.Fatal("invalid event name") + } + value := ev.Value() + validateRequestFields(t, value, URL) + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if len(value.HTTPResponseHeaders) <= 0 { + t.Fatal("expected at least one response header") + } + if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) { + t.Fatal("unexpected value for response body") + } + if value.HTTPStatusCode != 451 { + t.Fatal("unexpected status code") + } + } + + t.Run("on success", func(t *testing.T) { + listener, URL := startServer(t, filtering.HTTPAction451) + defer listener.Close() + resp, saver, err := measureHTTP(t, URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 451 { + t.Fatal("unexpected status code", resp.StatusCode) + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("unexpected number of events") + } + validateRequest(t, events[0], URL) + validateResponseSuccess(t, events[1], URL) + data, err := netxlite.ReadAllContext(context.Background(), resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, filtering.HTTPBlockpage451) { + t.Fatal("we cannot re-read the same body") + } + }) + + validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) { + if _, good := ev.(*EventHTTPTransactionDone); !good { + t.Fatal("invalid event type") + } + if ev.Name() != "http_transaction_done" { + t.Fatal("invalid event name") + } + value := ev.Value() + validateRequestFields(t, value, URL) + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if value.Err.Error() != "connection_reset" { + t.Fatal("unexpected Err value") + } + if len(value.HTTPResponseHeaders) > 0 { + t.Fatal("expected zero response headers") + } + if !bytes.Equal(value.HTTPResponseBody, nil) { + t.Fatal("unexpected value for response body") + } + if value.HTTPStatusCode != 0 { + t.Fatal("unexpected status code") + } + } + + t.Run("on round trip failure", func(t *testing.T) { + listener, URL := startServer(t, filtering.HTTPActionReset) + defer listener.Close() + resp, saver, err := measureHTTP(t, URL) + if err == nil || err.Error() != "connection_reset" { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil response") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("unexpected number of events") + } + validateRequest(t, events[0], URL) + validateResponseFailure(t, events[1], URL) + }) + + // Sometimes useful for testing + /* + dump := func(t *testing.T, ev Event) { + data, _ := json.MarshalIndent(ev.Value(), " ", " ") + t.Log(string(data)) + t.Fail() + } + */ + + t.Run("on error reading the response body", func(t *testing.T) { + saver := &Saver{} + expected := errors.New("mocked error") + txp := HTTPTransportSaver{ + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Header: http.Header{ + "Server": {"antani"}, }, - }), - }, nil + StatusCode: 200, + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }), + }, nil + }, + MockNetwork: func() string { + return "tcp" + }, }, - MockNetwork: func() string { - return "tcp" - }, - }, - SnapshotSize: 4, - Saver: saver, - } - URL := &url.URL{ - Scheme: "http", - Host: "127.0.0.1:9050", - } - req, err := http.NewRequest("GET", URL.String(), nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("User-Agent", "miniooni") - resp, err := txp.RoundTrip(req) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("expected nil response") - } - ev := saver.Read() - validateRequest(t, ev[0], URL) - if ev[1].Value().HTTPStatusCode != 200 { - t.Fatal("invalid status code") - } - if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" { - t.Fatal("invalid Server header") - } - if ev[1].Value().Err.Error() != "unknown_failure: mocked error" { - t.Fatal("invalid error") - } + SnapshotSize: 4, + Saver: saver, + } + URL := &url.URL{ + Scheme: "http", + Host: "127.0.0.1:9050", + } + req, err := http.NewRequest("GET", URL.String(), nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("User-Agent", "miniooni") + resp, err := txp.RoundTrip(req) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response") + } + ev := saver.Read() + validateRequest(t, ev[0], URL) + if ev[1].Value().HTTPStatusCode != 200 { + t.Fatal("invalid status code") + } + if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" { + t.Fatal("invalid Server header") + } + if ev[1].Value().Err.Error() != "unknown_failure: mocked error" { + t.Fatal("invalid error") + } + }) }) } diff --git a/internal/engine/netx/tracex/quic.go b/internal/engine/netx/tracex/quic.go index fb4f2c7..2162ec6 100644 --- a/internal/engine/netx/tracex/quic.go +++ b/internal/engine/netx/tracex/quic.go @@ -15,8 +15,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// QUICHandshakeSaver saves events occurring during the QUIC handshake. -type QUICHandshakeSaver struct { +// QUICDialerSaver saves events occurring during the QUIC handshake. +type QUICDialerSaver struct { // QUICDialer is the wrapped dialer QUICDialer model.QUICDialer @@ -33,14 +33,14 @@ func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer { if s == nil { return qd } - return &QUICHandshakeSaver{ + return &QUICDialerSaver{ QUICDialer: qd, Saver: s, } } // DialContext implements QUICDialer.DialContext -func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, +func (h *QUICDialerSaver) DialContext(ctx context.Context, network string, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { start := time.Now() // TODO(bassosimone): in the future we probably want to also save @@ -58,9 +58,11 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, if err != nil { // TODO(bassosimone): here we should save the peer certs h.Saver.Write(&EventQUICHandshakeDone{&EventValue{ + Address: host, Duration: stop.Sub(start), Err: err, NoTLSVerify: tlsCfg.InsecureSkipVerify, + Proto: network, TLSNextProtos: tlsCfg.NextProtos, TLSServerName: tlsCfg.ServerName, Time: stop, @@ -69,8 +71,10 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, } state := quicConnectionState(sess) h.Saver.Write(&EventQUICHandshakeDone{&EventValue{ + Address: host, Duration: stop.Sub(start), NoTLSVerify: tlsCfg.InsecureSkipVerify, + Proto: network, TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSNegotiatedProto: state.NegotiatedProtocol, TLSNextProtos: tlsCfg.NextProtos, @@ -82,7 +86,7 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, return sess, nil } -func (h *QUICHandshakeSaver) CloseIdleConnections() { +func (h *QUICDialerSaver) CloseIdleConnections() { h.QUICDialer.CloseIdleConnections() } @@ -121,15 +125,15 @@ func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, erro if err != nil { return nil, err } - pconn = &udpLikeConnSaver{ + pconn = &quicPacketConnWrapper{ UDPLikeConn: pconn, saver: qls.Saver, } return pconn, nil } -// udpLikeConnSaver saves I/O events -type udpLikeConnSaver struct { +// quicPacketConnWrapper saves I/O events +type quicPacketConnWrapper struct { // UDPLikeConn is the wrapped underlying conn model.UDPLikeConn @@ -137,7 +141,7 @@ type udpLikeConnSaver struct { saver *Saver } -func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) { +func (c *quicPacketConnWrapper) WriteTo(p []byte, addr net.Addr) (int, error) { start := time.Now() count, err := c.UDPLikeConn.WriteTo(p, addr) stop := time.Now() @@ -152,7 +156,7 @@ func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) { return count, err } -func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *quicPacketConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) { start := time.Now() n, addr, err := c.UDPLikeConn.ReadFrom(b) stop := time.Now() @@ -171,13 +175,13 @@ func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) { return n, addr, err } -func (c *udpLikeConnSaver) safeAddrString(addr net.Addr) (out string) { +func (c *quicPacketConnWrapper) safeAddrString(addr net.Addr) (out string) { if addr != nil { out = addr.String() } return } -var _ model.QUICDialer = &QUICHandshakeSaver{} +var _ model.QUICDialer = &QUICDialerSaver{} var _ model.QUICListener = &QUICListenerSaver{} -var _ model.UDPLikeConn = &udpLikeConnSaver{} +var _ model.UDPLikeConn = &quicPacketConnWrapper{} diff --git a/internal/engine/netx/tracex/quic_test.go b/internal/engine/netx/tracex/quic_test.go index 804cbb6..4f766ed 100644 --- a/internal/engine/netx/tracex/quic_test.go +++ b/internal/engine/netx/tracex/quic_test.go @@ -3,182 +3,446 @@ package tracex import ( "context" "crypto/tls" + "crypto/x509" "errors" "net" - "reflect" - "strings" "testing" - "time" + "github.com/google/go-cmp/cmp" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" - "github.com/ooni/probe-cli/v3/internal/netxlite" - "github.com/ooni/probe-cli/v3/internal/netxlite/quictesting" ) -type MockDialer struct { - Dialer model.QUICDialer - Sess quic.EarlyConnection - Err error -} +func TestQUICDialerSaver(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { -func (d MockDialer) DialContext(ctx context.Context, network, host string, - tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { - if d.Dialer != nil { - return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg) - } - return d.Sess, d.Err -} + checkStartEventFields := func(t *testing.T, value *EventValue) { + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if !value.NoTLSVerify { + t.Fatal("expected NoTLSVerify to be true") + } + if value.Proto != "udp" { + t.Fatal("wrong protocol") + } + if diff := cmp.Diff(value.TLSNextProtos, []string{"h3"}); diff != "" { + t.Fatal(diff) + } + if value.TLSServerName != "dns.google" { + t.Fatal("invalid TLSServerName") + } + if value.Time.IsZero() { + t.Fatal("expected non zero time") + } + } -func TestHandshakeSaverSuccess(t *testing.T) { - nextprotos := []string{"h3"} - servername := quictesting.Domain - tlsConf := &tls.Config{ - NextProtos: nextprotos, - ServerName: servername, - } - saver := &Saver{} - dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ - QUICListener: &netxlite.QUICListenerStdlib{}, + checkStartedEvent := func(t *testing.T, ev Event) { + if _, good := ev.(*EventQUICHandshakeStart); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) + } + + checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err != nil { + t.Fatal("expected no error here") + } + if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" { + t.Fatal("invalid cipher suite") + } + if value.TLSNegotiatedProto != "h3" { + t.Fatal("invalid negotiated protocol") + } + if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" { + t.Fatal(diff) + } + if value.TLSVersion != "TLSv1.3" { + t.Fatal("invalid TLS version") + } + } + + checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) { + if _, good := ev.(*EventQUICHandshakeDone); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) + fun(t, value) + } + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + returnedConn := &mocks.QUICEarlyConnection{ + MockConnectionState: func() quic.ConnectionState { + cs := quic.ConnectionState{} + cs.TLS.ConnectionState.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA + cs.TLS.NegotiatedProtocol = "h3" + cs.TLS.PeerCertificates = []*x509.Certificate{} + cs.TLS.Version = tls.VersionTLS13 + return cs + }, + } + dialer := saver.WrapQUICDialer(&mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return returnedConn, nil + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h3"}, + ServerName: "dns.google", + } + quicConfig := &quic.Config{} + conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess) + }) + + checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err == nil { + t.Fatal("expected non-nil error here") + } + } + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + dialer := saver.WrapQUICDialer(&mocks.QUICDialer{ + MockDialContext: func(ctx context.Context, network, address string, + tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { + return nil, expected + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h3"}, + ServerName: "dns.google", + } + quicConfig := &quic.Config{} + conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsFailure) + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.QUICDialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + dialer := &QUICDialerSaver{ + QUICDialer: child, + Saver: &Saver{}, + } + dialer.CloseIdleConnections() + if !called { + t.Fatal("not called") + } }) - sess, err := dlr.DialContext(context.Background(), "udp", - quictesting.Endpoint("443"), tlsConf, &quic.Config{}) - if err != nil { - t.Fatal("unexpected error", err) - } - if sess == nil { - t.Fatal("unexpected nil sess") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("unexpected number of events") - } - if ev[0].Name() != "quic_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].Value().TLSServerName != quictesting.Domain { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Value().Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err", ev[1].Value().Err) - } - if ev[1].Name() != "quic_handshake_done" { - t.Fatal("unexpected Name") - } - if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[1].Value().TLSServerName != quictesting.Domain { - t.Fatal("unexpected TLSServerName") - } - if ev[1].Value().Time.Before(ev[0].Value().Time) { - t.Fatal("unexpected Time") - } } -func TestHandshakeSaverHostNameError(t *testing.T) { - nextprotos := []string{"h3"} - servername := "example.com" - tlsConf := &tls.Config{ - NextProtos: nextprotos, - ServerName: servername, - } - saver := &Saver{} - dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ - QUICListener: &netxlite.QUICListenerStdlib{}, +func TestQUICListenerSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + qls := saver.WrapQUICListener(&mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return nil, expected + }, + }) + pconn, err := qls.Listen(&net.UDPAddr{ + IP: []byte{}, + Port: 8080, + Zone: "", + }) + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if pconn != nil { + t.Fatal("expected nil pconn here") + } }) - sess, err := dlr.DialContext(context.Background(), "udp", - quictesting.Endpoint("443"), tlsConf, &quic.Config{}) - if err == nil { - t.Fatal("expected an error here") - } - if sess != nil { - t.Fatal("expected nil sess here") - } - for _, ev := range saver.Read() { - if ev.Name() != "quic_handshake_done" { - continue + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + returnedConn := &mocks.UDPLikeConn{} + qls := saver.WrapQUICListener(&mocks.QUICListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return returnedConn, nil + }, + }) + pconn, err := qls.Listen(&net.UDPAddr{ + IP: []byte{}, + Port: 8080, + Zone: "", + }) + if err != nil { + t.Fatal(err) } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") + wconn := pconn.(*quicPacketConnWrapper) + if wconn.UDPLikeConn != returnedConn { + t.Fatal("invalid underlying connection") } - if !strings.HasSuffix(ev.Value().Err.Error(), "tls: handshake failure") { - t.Fatal("unexpected error", ev.Value().Err) + if wconn.saver != saver { + t.Fatal("invalid saver") } - } + }) } -func TestQUICListenerSaverCannotListen(t *testing.T) { - expected := errors.New("mocked error") - saver := &Saver{} - qls := saver.WrapQUICListener(&mocks.QUICListener{ - MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { - return nil, expected - }, - }) - pconn, err := qls.Listen(&net.UDPAddr{ - IP: []byte{}, - Port: 8080, - Zone: "", - }) - if !errors.Is(err, expected) { - t.Fatal("unexpected error", err) - } - if pconn != nil { - t.Fatal("expected nil pconn here") - } -} +func TestQUICPacketConnWrapper(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { -func TestSystemDialerSuccessWithReadWrite(t *testing.T) { - // This is the most common use case for collecting reads, writes - tlsConf := &tls.Config{ - NextProtos: []string{"h3"}, - ServerName: quictesting.Domain, - } - saver := &Saver{} - systemdialer := &netxlite.QUICDialerQUICGo{ - QUICListener: saver.WrapQUICListener(&netxlite.QUICListenerStdlib{}), - } - _, err := systemdialer.DialContext(context.Background(), "udp", - quictesting.Endpoint("443"), tlsConf, &quic.Config{}) - if err != nil { - t.Fatal(err) - } - ev := saver.Read() - if len(ev) < 2 { - t.Fatal("unexpected number of events") - } - last := len(ev) - 1 - for idx := 1; idx < last; idx++ { - if ev[idx].Value().Data == nil { - t.Fatal("unexpected Data") - } - if ev[idx].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[idx].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[idx].Value().NumBytes <= 0 { - t.Fatal("unexpected NumBytes") - } - switch ev[idx].Name() { - case netxlite.ReadFromOperation, netxlite.WriteToOperation: - default: - t.Fatal("unexpected Name") - } - if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) { - t.Fatal("unexpected Time", ev[idx].Value().Time, ev[idx-1].Value().Time) - } - } + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + return 0, nil, expected + }, + }, + saver: saver, + } + buf := make([]byte, 1<<17) + count, addr, err := conn.ReadFrom(buf) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("invalid count") + } + if addr != nil { + t.Fatal("invalid addr") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventReadFromOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "" { + t.Fatal("invalid Address") + } + if len(value.Data) != 0 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if !errors.Is(value.Err, expected) { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 0 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + + t.Run("on success", func(t *testing.T) { + expected := []byte{1, 2, 3, 4} + saver := &Saver{} + expectedAddr := &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + MockNetwork: func() string { + return "udp" + }, + } + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + copy(p, expected) + return len(expected), expectedAddr, nil + }, + }, + saver: saver, + } + buf := make([]byte, 1<<17) + count, addr, err := conn.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + if count != 4 { + t.Fatal("invalid count") + } + if addr != expectedAddr { + t.Fatal("invalid addr") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventReadFromOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if len(value.Data) != 4 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if value.Err != nil { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 4 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + }) + + t.Run("WriteTo", func(t *testing.T) { + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 0, expected + }, + }, + saver: saver, + } + destAddr := &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + } + buf := make([]byte, 7) + count, err := conn.WriteTo(buf, destAddr) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("invalid count") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventWriteToOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if len(value.Data) != 0 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if !errors.Is(value.Err, expected) { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 0 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + conn := &quicPacketConnWrapper{ + UDPLikeConn: &mocks.UDPLikeConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 1, nil + }, + }, + saver: saver, + } + destAddr := &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + } + buf := make([]byte, 7) + count, err := conn.WriteTo(buf, destAddr) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("invalid count") + } + events := saver.Read() + if len(events) != 1 { + t.Fatal("invalid number of events") + } + ev0 := events[0] + if _, good := ev0.(*EventWriteToOperation); !good { + t.Fatal("invalid event type") + } + value := ev0.Value() + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if len(value.Data) != 1 { + t.Fatal("invalid Data") + } + if value.Duration <= 0 { + t.Fatal("expected nonzero duration") + } + if value.Err != nil { + t.Fatal("unexpected value.Err", value.Err) + } + if value.NumBytes != 1 { + t.Fatal("expected NumBytes") + } + if value.Time.IsZero() { + t.Fatal("expected nonzero Time") + } + }) + }) } diff --git a/internal/engine/netx/tracex/resolver.go b/internal/engine/netx/tracex/resolver.go index db5afd3..62cf1f5 100644 --- a/internal/engine/netx/tracex/resolver.go +++ b/internal/engine/netx/tracex/resolver.go @@ -12,8 +12,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// SaverResolver is a resolver that saves events. -type SaverResolver struct { +// ResolverSaver is a resolver that saves events. +type ResolverSaver struct { // Resolver is the underlying resolver. Resolver model.Resolver @@ -30,14 +30,14 @@ func (s *Saver) WrapResolver(r model.Resolver) model.Resolver { if s == nil { return r } - return &SaverResolver{ + return &ResolverSaver{ Resolver: r, Saver: s, } } // LookupHost implements Resolver.LookupHost -func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { +func (r *ResolverSaver) LookupHost(ctx context.Context, hostname string) ([]string, error) { start := time.Now() r.Saver.Write(&EventResolveStart{&EventValue{ Address: r.Resolver.Address(), @@ -59,30 +59,30 @@ func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]stri return addrs, err } -func (r *SaverResolver) Network() string { +func (r *ResolverSaver) Network() string { return r.Resolver.Network() } -func (r *SaverResolver) Address() string { +func (r *ResolverSaver) Address() string { return r.Resolver.Address() } -func (r *SaverResolver) CloseIdleConnections() { +func (r *ResolverSaver) CloseIdleConnections() { r.Resolver.CloseIdleConnections() } -func (r *SaverResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { +func (r *ResolverSaver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { // TODO(bassosimone): we should probably implement this method return r.Resolver.LookupHTTPS(ctx, domain) } -func (r *SaverResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { +func (r *ResolverSaver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { // TODO(bassosimone): we should probably implement this method return r.Resolver.LookupNS(ctx, domain) } -// SaverDNSTransport is a DNS transport that saves events. -type SaverDNSTransport struct { +// DNSTransportSaver is a DNS transport that saves events. +type DNSTransportSaver struct { // DNSTransport is the underlying DNS transport. DNSTransport model.DNSTransport @@ -99,14 +99,14 @@ func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport { if s == nil { return txp } - return &SaverDNSTransport{ + return &DNSTransportSaver{ DNSTransport: txp, Saver: s, } } // RoundTrip implements RoundTripper.RoundTrip -func (txp *SaverDNSTransport) RoundTrip( +func (txp *DNSTransportSaver) RoundTrip( ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { start := time.Now() txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{ @@ -129,19 +129,19 @@ func (txp *SaverDNSTransport) RoundTrip( return response, err } -func (txp *SaverDNSTransport) Network() string { +func (txp *DNSTransportSaver) Network() string { return txp.DNSTransport.Network() } -func (txp *SaverDNSTransport) Address() string { +func (txp *DNSTransportSaver) Address() string { return txp.DNSTransport.Address() } -func (txp *SaverDNSTransport) CloseIdleConnections() { +func (txp *DNSTransportSaver) CloseIdleConnections() { txp.DNSTransport.CloseIdleConnections() } -func (txp *SaverDNSTransport) RequiresPadding() bool { +func (txp *DNSTransportSaver) RequiresPadding() bool { return txp.DNSTransport.RequiresPadding() } @@ -157,5 +157,5 @@ func dnsMaybeResponseBytes(response model.DNSResponse) []byte { return response.Bytes() } -var _ model.Resolver = &SaverResolver{} -var _ model.DNSTransport = &SaverDNSTransport{} +var _ model.Resolver = &ResolverSaver{} +var _ model.DNSTransport = &DNSTransportSaver{} diff --git a/internal/engine/netx/tracex/resolver_test.go b/internal/engine/netx/tracex/resolver_test.go index dd1eede..631ee64 100644 --- a/internal/engine/netx/tracex/resolver_test.go +++ b/internal/engine/netx/tracex/resolver_test.go @@ -14,220 +14,224 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestSaverResolverFailure(t *testing.T) { - expected := errors.New("no such host") - saver := &Saver{} - reso := saver.WrapResolver(NewFakeResolverWithExplicitError(expected)) - addrs, err := reso.LookupHost(context.Background(), "www.google.com") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if ev[0].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[0].Name() != "resolve_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if ev[1].Value().Addresses != nil { - t.Fatal("unexpected Addresses") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if !errors.Is(ev[1].Value().Err, expected) { - t.Fatal("unexpected Err") - } - if ev[1].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[1].Name() != "resolve_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } -} - -func TestSaverResolverSuccess(t *testing.T) { - expected := []string{"8.8.8.8", "8.8.4.4"} - saver := &Saver{} - reso := saver.WrapResolver(NewFakeResolverWithResult(expected)) - addrs, err := reso.LookupHost(context.Background(), "www.google.com") - if err != nil { - t.Fatal("expected nil error here") - } - if !reflect.DeepEqual(addrs, expected) { - t.Fatal("not the result we expected") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if ev[0].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[0].Name() != "resolve_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { - t.Fatal("unexpected Addresses") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Value().Hostname != "www.google.com" { - t.Fatal("unexpected Hostname") - } - if ev[1].Name() != "resolve_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } -} - -func TestSaverDNSTransportFailure(t *testing.T) { - expected := errors.New("no such host") - saver := &Saver{} - txp := saver.WrapDNSTransport(&mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return nil, expected - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, +func TestResolverSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("no such host") + saver := &Saver{} + reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) + addrs, err := reso.LookupHost(context.Background(), "www.google.com") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if ev[0].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[0].Name() != "resolve_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if ev[1].Value().Addresses != nil { + t.Fatal("unexpected Addresses") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[1].Value().Err, expected) { + t.Fatal("unexpected Err") + } + if ev[1].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[1].Name() != "resolve_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } }) - rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return rawQuery, nil - }, - } - reply, err := txp.RoundTrip(context.Background(), query) - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[0].Name() != "dns_round_trip_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[1].Value().DNSResponse != nil { - t.Fatal("unexpected DNSReply") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if !errors.Is(ev[1].Value().Err, expected) { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "dns_round_trip_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } -} -func TestSaverDNSTransportSuccess(t *testing.T) { - expected := []byte{0xef, 0xbe, 0xad, 0xde} - saver := &Saver{} - response := &mocks.DNSResponse{ - MockBytes: func() []byte { - return expected - }, - } - txp := saver.WrapDNSTransport(&mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - return response, nil - }, - MockNetwork: func() string { - return "fake" - }, - MockAddress: func() string { - return "" - }, + t.Run("on success", func(t *testing.T) { + expected := []string{"8.8.8.8", "8.8.4.4"} + saver := &Saver{} + reso := saver.WrapResolver(newFakeResolverWithResult(expected)) + addrs, err := reso.LookupHost(context.Background(), "www.google.com") + if err != nil { + t.Fatal("expected nil error here") + } + if !reflect.DeepEqual(addrs, expected) { + t.Fatal("not the result we expected") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if ev[0].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[0].Name() != "resolve_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { + t.Fatal("unexpected Addresses") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err != nil { + t.Fatal("unexpected Err") + } + if ev[1].Value().Hostname != "www.google.com" { + t.Fatal("unexpected Hostname") + } + if ev[1].Name() != "resolve_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } }) - rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return rawQuery, nil - }, - } - reply, err := txp.RoundTrip(context.Background(), query) - if err != nil { - t.Fatal("we expected nil error here") - } - if !bytes.Equal(reply.Bytes(), expected) { - t.Fatal("expected another reply here") - } - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("expected number of events") - } - if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if ev[0].Name() != "dns_round_trip_start" { - t.Fatal("unexpected name") - } - if !ev[0].Value().Time.Before(time.Now()) { - t.Fatal("the saved time is wrong") - } - if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { - t.Fatal("unexpected DNSQuery") - } - if !bytes.Equal(ev[1].Value().DNSResponse, expected) { - t.Fatal("unexpected DNSReply") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "dns_round_trip_done" { - t.Fatal("unexpected name") - } - if !ev[1].Value().Time.After(ev[0].Value().Time) { - t.Fatal("the saved time is wrong") - } } -func NewFakeResolverWithExplicitError(err error) model.Resolver { +func TestDNSTransportSaver(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("no such host") + saver := &Saver{} + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expected + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } + reply, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[0].Name() != "dns_round_trip_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[1].Value().DNSResponse != nil { + t.Fatal("unexpected DNSReply") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if !errors.Is(ev[1].Value().Err, expected) { + t.Fatal("unexpected Err") + } + if ev[1].Name() != "dns_round_trip_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) + + t.Run("on success", func(t *testing.T) { + expected := []byte{0xef, 0xbe, 0xad, 0xde} + saver := &Saver{} + response := &mocks.DNSResponse{ + MockBytes: func() []byte { + return expected + }, + } + txp := saver.WrapDNSTransport(&mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return response, nil + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, + }) + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } + reply, err := txp.RoundTrip(context.Background(), query) + if err != nil { + t.Fatal("we expected nil error here") + } + if !bytes.Equal(reply.Bytes(), expected) { + t.Fatal("expected another reply here") + } + ev := saver.Read() + if len(ev) != 2 { + t.Fatal("expected number of events") + } + if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if ev[0].Name() != "dns_round_trip_start" { + t.Fatal("unexpected name") + } + if !ev[0].Value().Time.Before(time.Now()) { + t.Fatal("the saved time is wrong") + } + if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { + t.Fatal("unexpected DNSQuery") + } + if !bytes.Equal(ev[1].Value().DNSResponse, expected) { + t.Fatal("unexpected DNSReply") + } + if ev[1].Value().Duration <= 0 { + t.Fatal("unexpected Duration") + } + if ev[1].Value().Err != nil { + t.Fatal("unexpected Err") + } + if ev[1].Name() != "dns_round_trip_done" { + t.Fatal("unexpected name") + } + if !ev[1].Value().Time.After(ev[0].Value().Time) { + t.Fatal("the saved time is wrong") + } + }) +} + +func newFakeResolverWithExplicitError(err error) model.Resolver { runtimex.PanicIfNil(err, "passed nil error") return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -251,7 +255,7 @@ func NewFakeResolverWithExplicitError(err error) model.Resolver { } } -func NewFakeResolverWithResult(r []string) model.Resolver { +func newFakeResolverWithResult(r []string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return r, nil diff --git a/internal/engine/netx/tracex/saver_test.go b/internal/engine/netx/tracex/saver_test.go index 784e168..3f925e4 100644 --- a/internal/engine/netx/tracex/saver_test.go +++ b/internal/engine/netx/tracex/saver_test.go @@ -3,22 +3,88 @@ package tracex import ( "sync" "testing" + + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestSaver(t *testing.T) { - saver := Saver{} - var wg sync.WaitGroup - const parallel = 10 - wg.Add(parallel) - for idx := 0; idx < parallel; idx++ { - go func() { - saver.Write(&EventReadFromOperation{&EventValue{}}) - wg.Done() - }() - } - wg.Wait() - ev := saver.Read() - if len(ev) != parallel { - t.Fatal("unexpected number of events read") - } + t.Run("concurrent writes followed by read", func(t *testing.T) { + saver := Saver{} + var wg sync.WaitGroup + const parallel = 10 + wg.Add(parallel) + for idx := 0; idx < parallel; idx++ { + go func() { + saver.Write(&EventReadFromOperation{&EventValue{}}) + wg.Done() + }() + } + wg.Wait() + ev := saver.Read() + if len(ev) != parallel { + t.Fatal("unexpected number of events read") + } + }) + + t.Run("NewConnectObserver", func(t *testing.T) { + t.Run("nil Saver", func(t *testing.T) { + var saver *Saver + obs := saver.NewConnectObserver() + if obs != nil { + t.Fatal("expected nil observer") + } + }) + + t.Run("nonnnil Saver", func(t *testing.T) { + saver := &Saver{} + obs := saver.NewConnectObserver() + underlying := obs.(*dialerConnectObserver) + if underlying.saver != saver { + t.Fatal("invalid saver") + } + }) + }) + + t.Run("NewReadWriteObserver", func(t *testing.T) { + t.Run("nil Saver", func(t *testing.T) { + var saver *Saver + obs := saver.NewReadWriteObserver() + if obs != nil { + t.Fatal("expected nil observer") + } + }) + + t.Run("nonnnil Saver", func(t *testing.T) { + saver := &Saver{} + obs := saver.NewReadWriteObserver() + underlying := obs.(*dialerReadWriteObserver) + if underlying.saver != saver { + t.Fatal("invalid saver") + } + }) + }) + + t.Run("WrapQUICDialer", func(t *testing.T) { + t.Run("nil Saver", func(t *testing.T) { + var saver *Saver + base := &mocks.QUICDialer{} + qd := saver.WrapQUICDialer(base) + if qd != base { + t.Fatal("unexpected returned QUICDialer") + } + }) + + t.Run("nonnnil Saver", func(t *testing.T) { + saver := &Saver{} + base := &mocks.QUICDialer{} + qd := saver.WrapQUICDialer(base) + underlying := qd.(*QUICDialerSaver) + if underlying.Saver != saver { + t.Fatal("invalid Saver") + } + if underlying.QUICDialer != base { + t.Fatal("invalid QUICDialer") + } + }) + }) } diff --git a/internal/engine/netx/tracex/tls.go b/internal/engine/netx/tracex/tls.go index 32eb51c..02ea359 100644 --- a/internal/engine/netx/tracex/tls.go +++ b/internal/engine/netx/tracex/tls.go @@ -16,8 +16,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -// SaverTLSHandshaker saves events occurring during the TLS handshake. -type SaverTLSHandshaker struct { +// TLSHandshakerSaver saves events occurring during the TLS handshake. +type TLSHandshakerSaver struct { // TLSHandshaker is the underlying TLS handshaker. TLSHandshaker model.TLSHandshaker @@ -34,23 +34,26 @@ func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker { if s == nil { return thx } - return &SaverTLSHandshaker{ + return &TLSHandshakerSaver{ TLSHandshaker: thx, Saver: s, } } // Handshake implements model.TLSHandshaker.Handshake -func (h *SaverTLSHandshaker) Handshake( +func (h *TLSHandshakerSaver) Handshake( ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { + proto := conn.RemoteAddr().Network() + remoteAddr := conn.RemoteAddr().String() start := time.Now() h.Saver.Write(&EventTLSHandshakeStart{&EventValue{ + Address: remoteAddr, NoTLSVerify: config.InsecureSkipVerify, + Proto: proto, TLSNextProtos: config.NextProtos, TLSServerName: config.ServerName, Time: start, }}) - remoteAddr := conn.RemoteAddr().String() tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) stop := time.Now() h.Saver.Write(&EventTLSHandshakeDone{&EventValue{ @@ -58,6 +61,7 @@ func (h *SaverTLSHandshaker) Handshake( Duration: stop.Sub(start), Err: err, NoTLSVerify: config.InsecureSkipVerify, + Proto: proto, TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSNegotiatedProto: state.NegotiatedProtocol, TLSNextProtos: config.NextProtos, @@ -69,7 +73,7 @@ func (h *SaverTLSHandshaker) Handshake( return tlsconn, state, err } -var _ model.TLSHandshaker = &SaverTLSHandshaker{} +var _ model.TLSHandshaker = &TLSHandshakerSaver{} // tlsPeerCerts returns the certificates presented by the peer regardless // of whether the TLS handshake was successful diff --git a/internal/engine/netx/tracex/tls_test.go b/internal/engine/netx/tracex/tls_test.go index 96da0c6..fc0eaf5 100644 --- a/internal/engine/netx/tracex/tls_test.go +++ b/internal/engine/netx/tracex/tls_test.go @@ -3,291 +3,250 @@ package tracex import ( "context" "crypto/tls" - "reflect" + "crypto/x509" + "errors" + "net" "testing" - "time" - "github.com/ooni/probe-cli/v3/internal/model" - "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) -func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { - // This is the most common use case for collecting reads, writes - if testing.Short() { - t.Skip("skip test in short mode") - } - nextprotos := []string{"h2"} - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: netxlite.NewDialerWithResolver( - model.DiscardLogger, - netxlite.NewResolverStdlib(model.DiscardLogger), - saver.NewReadWriteObserver(), - ), - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - // Implementation note: we don't close the connection here because it is - // very handy to have the last event being the end of the handshake - _, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - ev := saver.Read() - if len(ev) < 4 { - // it's a bit tricky to be sure about the right number of - // events because network conditions may influence that - t.Fatal("unexpected number of events") - } - if ev[0].Name() != "tls_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Value().Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - last := len(ev) - 1 - for idx := 1; idx < last; idx++ { - if ev[idx].Value().Data == nil { - t.Fatal("unexpected Data") +func TestTLSHandshakerSaver(t *testing.T) { + + t.Run("Handshake", func(t *testing.T) { + checkStartEventFields := func(t *testing.T, value *EventValue) { + if value.Address != "8.8.8.8:443" { + t.Fatal("invalid Address") + } + if !value.NoTLSVerify { + t.Fatal("expected NoTLSVerify to be true") + } + if value.Proto != "tcp" { + t.Fatal("wrong protocol") + } + if diff := cmp.Diff(value.TLSNextProtos, []string{"h2"}); diff != "" { + t.Fatal(diff) + } + if value.TLSServerName != "dns.google" { + t.Fatal("invalid TLSServerName") + } + if value.Time.IsZero() { + t.Fatal("expected non zero time") + } } - if ev[idx].Value().Duration <= 0 { - t.Fatal("unexpected Duration") + + checkStartedEvent := func(t *testing.T, ev Event) { + if _, good := ev.(*EventTLSHandshakeStart); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) } - if ev[idx].Value().Err != nil { - t.Fatal("unexpected Err") + + checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err != nil { + t.Fatal("expected no error here") + } + if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" { + t.Fatal("invalid cipher suite") + } + if value.TLSNegotiatedProto != "h2" { + t.Fatal("invalid negotiated protocol") + } + if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" { + t.Fatal(diff) + } + if value.TLSVersion != "TLSv1.3" { + t.Fatal("invalid TLS version") + } } - if ev[idx].Value().NumBytes <= 0 { - t.Fatal("unexpected NumBytes") + + checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) { + if _, good := ev.(*EventTLSHandshakeDone); !good { + t.Fatal("invalid event type") + } + value := ev.Value() + checkStartEventFields(t, value) + fun(t, value) } - switch ev[idx].Name() { - case netxlite.ReadOperation, netxlite.WriteOperation: - default: - t.Fatal("unexpected Name") + + t.Run("on success", func(t *testing.T) { + saver := &Saver{} + returnedConnState := tls.ConnectionState{ + CipherSuite: tls.TLS_RSA_WITH_RC4_128_SHA, + NegotiatedProtocol: "h2", + PeerCertificates: []*x509.Certificate{}, + Version: tls.VersionTLS13, + } + returnedConn := &mocks.TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return returnedConnState + }, + } + thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, + config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return returnedConn, returnedConnState, nil + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + ServerName: "dns.google", + } + tcpConn := &mocks.Conn{ + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess) + }) + + checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) { + if value.Duration <= 0 { + t.Fatal("expected non-zero duration") + } + if value.Err == nil { + t.Fatal("expected non-nil error here") + } + if value.TLSCipherSuite != "" { + t.Fatal("invalid TLS cipher suite") + } + if value.TLSNegotiatedProto != "" { + t.Fatal("invalid negotiated proto") + } + if len(value.TLSPeerCerts) > 0 { + t.Fatal("expected no peer certs") + } + if value.TLSVersion != "" { + t.Fatal("invalid TLS version") + } } - if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) { - t.Fatal("unexpected Time") - } - } - if ev[last].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[last].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[last].Name() != "tls_handshake_done" { - t.Fatal("unexpected Name") - } - if ev[last].Value().TLSCipherSuite == "" { - t.Fatal("unexpected TLSCipherSuite") - } - if ev[last].Value().TLSNegotiatedProto != "h2" { - t.Fatal("unexpected TLSNegotiatedProto") - } - if !reflect.DeepEqual(ev[last].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[last].Value().TLSPeerCerts == nil { - t.Fatal("unexpected TLSPeerCerts") - } - if ev[last].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if ev[last].Value().TLSVersion == "" { - t.Fatal("unexpected TLSVersion") - } - if ev[last].Value().Time.Before(ev[last-1].Value().Time) { - t.Fatal("unexpected Time") - } + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + saver := &Saver{} + thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{ + MockHandshake: func(ctx context.Context, conn net.Conn, + config *tls.Config) (net.Conn, tls.ConnectionState, error) { + return nil, tls.ConnectionState{}, expected + }, + }) + ctx := context.Background() + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + ServerName: "dns.google", + } + tcpConn := &mocks.Conn{ + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockString: func() string { + return "8.8.8.8:443" + }, + MockNetwork: func() string { + return "tcp" + }, + } + }, + } + conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + events := saver.Read() + if len(events) != 2 { + t.Fatal("expected two events") + } + checkStartedEvent(t, events[0]) + checkDoneEvent(t, events[1], checkDoneEventFieldsFailure) + }) + }) } -func TestSaverTLSHandshakerSuccess(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") +func Test_tlsPeerCerts(t *testing.T) { + cert0 := &x509.Certificate{Raw: []byte{1, 2, 3, 4}} + type args struct { + state tls.ConnectionState + err error } - nextprotos := []string{"h2"} - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{NextProtos: nextprotos}, - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") - if err != nil { - t.Fatal(err) - } - conn.Close() - ev := saver.Read() - if len(ev) != 2 { - t.Fatal("unexpected number of events") - } - if ev[0].Name() != "tls_handshake_start" { - t.Fatal("unexpected Name") - } - if ev[0].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[0].Value().Time.After(time.Now()) { - t.Fatal("unexpected Time") - } - if ev[1].Value().Duration <= 0 { - t.Fatal("unexpected Duration") - } - if ev[1].Value().Err != nil { - t.Fatal("unexpected Err") - } - if ev[1].Name() != "tls_handshake_done" { - t.Fatal("unexpected Name") - } - if ev[1].Value().TLSCipherSuite == "" { - t.Fatal("unexpected TLSCipherSuite") - } - if ev[1].Value().TLSNegotiatedProto != "h2" { - t.Fatal("unexpected TLSNegotiatedProto") - } - if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) { - t.Fatal("unexpected TLSNextProtos") - } - if ev[1].Value().TLSPeerCerts == nil { - t.Fatal("unexpected TLSPeerCerts") - } - if ev[1].Value().TLSServerName != "www.google.com" { - t.Fatal("unexpected TLSServerName") - } - if ev[1].Value().TLSVersion == "" { - t.Fatal("unexpected TLSVersion") - } - if ev[1].Value().Time.Before(ev[0].Value().Time) { - t.Fatal("unexpected Time") - } -} - -func TestSaverTLSHandshakerHostnameError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "wrong.host.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "expired.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerAuthorityError(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "self-signed.badssl.com:443") - if err == nil { - t.Fatal("expected an error here") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify == true { - t.Fatal("expected NoTLSVerify to be false") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } - } -} - -func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - saver := &Saver{} - tlsdlr := &netxlite.TLSDialerLegacy{ - Config: &tls.Config{InsecureSkipVerify: true}, - Dialer: &netxlite.DialerSystem{}, - TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), - } - conn, err := tlsdlr.DialTLSContext( - context.Background(), "tcp", "self-signed.badssl.com:443") - if err != nil { - t.Fatal(err) - } - if conn == nil { - t.Fatal("expected non-nil conn here") - } - conn.Close() - for _, ev := range saver.Read() { - if ev.Name() != "tls_handshake_done" { - continue - } - if ev.Value().NoTLSVerify != true { - t.Fatal("expected NoTLSVerify to be true") - } - if len(ev.Value().TLSPeerCerts) < 1 { - t.Fatal("expected at least a certificate here") - } + tests := []struct { + name string + args args + want []*x509.Certificate + }{{ + name: "no error", + args: args{ + state: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert0}, + }, + }, + want: []*x509.Certificate{cert0}, + }, { + name: "all empty", + args: args{}, + want: nil, + }, { + name: "x509.HostnameError", + args: args{ + state: tls.ConnectionState{}, + err: x509.HostnameError{ + Certificate: cert0, + }, + }, + want: []*x509.Certificate{cert0}, + }, { + name: "x509.UnknownAuthorityError", + args: args{ + state: tls.ConnectionState{}, + err: x509.UnknownAuthorityError{ + Cert: cert0, + }, + }, + want: []*x509.Certificate{cert0}, + }, { + name: "x509.CertificateInvalidError", + args: args{ + state: tls.ConnectionState{}, + err: x509.CertificateInvalidError{ + Cert: cert0, + }, + }, + want: []*x509.Certificate{cert0}, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tlsPeerCerts(tt.args.state, tt.args.err) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatal(diff) + } + }) } }