refactor(tracex): convert to unit testing (#781)

The exercise already allowed me to notice issues such as fields not
being properly initialized by savers.

This is one of the last steps before moving tracex away from the
internal/netx package and into the internal package.

See https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso 2022-06-01 23:15:47 +02:00 committed by GitHub
parent 6212daa54a
commit d397036073
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1693 additions and 1130 deletions

View File

@ -119,7 +119,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -195,7 +195,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -271,7 +271,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -347,7 +347,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }

View File

@ -187,7 +187,7 @@ func NewHTTPTransport(config Config) model.HTTPTransport {
txp = &netxlite.HTTPTransportLogger{Logger: config.Logger, HTTPTransport: txp} txp = &netxlite.HTTPTransportLogger{Logger: config.Logger, HTTPTransport: txp}
} }
if config.HTTPSaver != nil { if config.HTTPSaver != nil {
txp = &tracex.SaverTransactionHTTPTransport{ txp = &tracex.HTTPTransportSaver{
HTTPTransport: txp, Saver: config.HTTPSaver} HTTPTransport: txp, Saver: config.HTTPSaver}
} }
return txp return txp

View File

@ -126,7 +126,7 @@ func TestNewResolverWithSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
sr, ok := ir.Resolver.(*tracex.SaverResolver) sr, ok := ir.Resolver.(*tracex.ResolverSaver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -332,7 +332,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
if rtd.TLSHandshaker == nil { if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker") t.Fatal("invalid TLSHandshaker")
} }
sth, ok := rtd.TLSHandshaker.(*tracex.SaverTLSHandshaker) sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver)
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
@ -504,7 +504,7 @@ func TestNewWithSaver(t *testing.T) {
txp := netx.NewHTTPTransport(netx.Config{ txp := netx.NewHTTPTransport(netx.Config{
HTTPSaver: saver, HTTPSaver: saver,
}) })
stxptxp, ok := txp.(*tracex.SaverTransactionHTTPTransport) stxptxp, ok := txp.(*tracex.HTTPTransportSaver)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -622,7 +622,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(*tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -659,7 +659,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(*tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -700,7 +700,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(*tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -745,7 +745,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(*tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.DNSTransportSaver)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }

View File

@ -1,5 +1,9 @@
package tracex package tracex
//
// Code to generate the OONI archival data format from events
//
import ( import (
"crypto/x509" "crypto/x509"
"errors" "errors"
@ -43,8 +47,7 @@ var (
) )
// NewTCPConnectList creates a new TCPConnectList // NewTCPConnectList creates a new TCPConnectList
func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry { func NewTCPConnectList(begin time.Time, events []Event) (out []TCPConnectEntry) {
var out []TCPConnectEntry
for _, wrapper := range events { for _, wrapper := range events {
if _, ok := wrapper.(*EventConnectOperation); !ok { if _, ok := wrapper.(*EventConnectOperation); !ok {
continue continue
@ -60,13 +63,14 @@ func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry {
IP: ip, IP: ip,
Port: iport, Port: iport,
Status: TCPConnectStatus{ Status: TCPConnectStatus{
Blocked: nil, // only used by Web Connectivity
Failure: NewFailure(event.Err), Failure: NewFailure(event.Err),
Success: event.Err == nil, Success: event.Err == nil,
}, },
T: event.Time.Sub(begin).Seconds(), T: event.Time.Sub(begin).Seconds(),
}) })
} }
return out return
} }
// NewFailure creates a failure nullable string from the given error // NewFailure creates a failure nullable string from the given error
@ -101,11 +105,9 @@ func NewFailedOperation(err error) *string {
return &s return &s
} }
func httpAddHeaders( // httpAddHeaders adds the headers inside source into destList and destMap.
source http.Header, func httpAddHeaders(source http.Header, destList *[]HTTPHeader,
destList *[]HTTPHeader, destMap *map[string]MaybeBinaryValue) {
destMap *map[string]MaybeBinaryValue,
) {
*destList = []HTTPHeader{} *destList = []HTTPHeader{}
*destMap = make(map[string]model.ArchivalMaybeBinaryData) *destMap = make(map[string]model.ArchivalMaybeBinaryData)
for key, values := range source { for key, values := range source {
@ -122,32 +124,28 @@ func httpAddHeaders(
}) })
} }
} }
// Sorting helps with unit testing (map keys are unordered)
sort.Slice(*destList, func(i, j int) bool { sort.Slice(*destList, func(i, j int) bool {
return (*destList)[i].Key < (*destList)[j].Key return (*destList)[i].Key < (*destList)[j].Key
}) })
} }
// NewRequestList returns the list for "requests" // NewRequestList returns the list for "requests"
func NewRequestList(begin time.Time, events []Event) []RequestEntry { func NewRequestList(begin time.Time, events []Event) (out []RequestEntry) {
// OONI wants the last request to appear first // OONI wants the last request to appear first
var out []RequestEntry
tmp := newRequestList(begin, events) tmp := newRequestList(begin, events)
for i := len(tmp) - 1; i >= 0; i-- { for i := len(tmp) - 1; i >= 0; i-- {
out = append(out, tmp[i]) out = append(out, tmp[i])
} }
return out return
} }
func newRequestList(begin time.Time, events []Event) []RequestEntry { func newRequestList(begin time.Time, events []Event) (out []RequestEntry) {
var (
out []RequestEntry
entry RequestEntry
)
for _, wrapper := range events { for _, wrapper := range events {
ev := wrapper.Value() ev := wrapper.Value()
switch wrapper.(type) { switch wrapper.(type) {
case *EventHTTPTransactionDone: case *EventHTTPTransactionDone:
entry = RequestEntry{} entry := RequestEntry{}
entry.T = ev.Time.Sub(begin).Seconds() entry.T = ev.Time.Sub(begin).Seconds()
httpAddHeaders( httpAddHeaders(
ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers) ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
@ -164,15 +162,14 @@ func newRequestList(begin time.Time, events []Event) []RequestEntry {
out = append(out, entry) out = append(out, entry)
} }
} }
return out return
} }
type dnsQueryType string type dnsQueryType string
// NewDNSQueriesList returns a list of DNS queries. // NewDNSQueriesList returns a list of DNS queries.
func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry { func NewDNSQueriesList(begin time.Time, events []Event) (out []DNSQueryEntry) {
// TODO(bassosimone): add support for CNAME lookups. // TODO(bassosimone): add support for CNAME lookups.
var out []DNSQueryEntry
for _, wrapper := range events { for _, wrapper := range events {
if _, ok := wrapper.(*EventResolveDone); !ok { if _, ok := wrapper.(*EventResolveDone); !ok {
continue continue
@ -199,7 +196,7 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry {
out = append(out, entry) out = append(out, entry)
} }
} }
return out return
} }
func (qtype dnsQueryType) ipOfType(addr string) bool { func (qtype dnsQueryType) ipOfType(addr string) bool {
@ -214,6 +211,8 @@ func (qtype dnsQueryType) ipOfType(addr string) bool {
func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry { func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry {
answer := DNSAnswerEntry{AnswerType: string(qtype)} answer := DNSAnswerEntry{AnswerType: string(qtype)}
// Figuring out the ASN and the org here is not just a service to whoever
// is reading a JSON: Web Connectivity also depends on it!
asn, org, _ := geolocate.LookupASN(addr) asn, org, _ := geolocate.LookupASN(addr)
answer.ASN = int64(asn) answer.ASN = int64(asn)
answer.ASOrgName = org answer.ASOrgName = org
@ -237,9 +236,8 @@ func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev *EventValue) DNSQue
} }
} }
// NewNetworkEventsList returns a list of DNS queries. // NewNetworkEventsList returns a list of network events.
func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent { func NewNetworkEventsList(begin time.Time, events []Event) (out []NetworkEvent) {
var out []NetworkEvent
for _, wrapper := range events { for _, wrapper := range events {
ev := wrapper.Value() ev := wrapper.Value()
switch wrapper.(type) { switch wrapper.(type) {
@ -281,7 +279,7 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent {
NumBytes: int64(ev.NumBytes), NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(), T: ev.Time.Sub(begin).Seconds(),
}) })
default: default: // For example, "tls_handshake_done" (used in data analysis!)
out = append(out, NetworkEvent{ out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err), Failure: NewFailure(ev.Err),
Operation: wrapper.Name(), Operation: wrapper.Name(),
@ -289,15 +287,14 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent {
}) })
} }
} }
return out return
} }
// NewTLSHandshakesList creates a new TLSHandshakesList // NewTLSHandshakesList creates a new TLSHandshakesList
func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake { func NewTLSHandshakesList(begin time.Time, events []Event) (out []TLSHandshake) {
var out []TLSHandshake
for _, wrapper := range events { for _, wrapper := range events {
switch wrapper.(type) { switch wrapper.(type) {
case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // ok case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // interested
default: default:
continue // not interested continue // not interested
} }
@ -314,12 +311,12 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake {
TLSVersion: ev.TLSVersion, TLSVersion: ev.TLSVersion,
}) })
} }
return out return
} }
func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) { func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
for _, e := range in { for _, entry := range in {
out = append(out, MaybeBinaryValue{Value: string(e.Raw)}) out = append(out, MaybeBinaryValue{Value: string(entry.Raw)})
} }
return return
} }

View File

@ -6,7 +6,6 @@ import (
"errors" "errors"
"io" "io"
"net/http" "net/http"
"reflect"
"testing" "testing"
"time" "time"
@ -15,42 +14,44 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
func TestDNSQueryIPOfType(t *testing.T) { func TestDNSQueryType(t *testing.T) {
type expectation struct { t.Run("ipOfType", func(t *testing.T) {
qtype dnsQueryType type expectation struct {
ip string qtype dnsQueryType
output bool ip string
} output bool
var expectations = []expectation{{
qtype: "A",
ip: "8.8.8.8",
output: true,
}, {
qtype: "A",
ip: "2a00:1450:4002:801::2004",
output: false,
}, {
qtype: "AAAA",
ip: "8.8.8.8",
output: false,
}, {
qtype: "AAAA",
ip: "2a00:1450:4002:801::2004",
output: true,
}, {
qtype: "ANTANI",
ip: "2a00:1450:4002:801::2004",
output: false,
}, {
qtype: "ANTANI",
ip: "8.8.8.8",
output: false,
}}
for _, exp := range expectations {
if exp.qtype.ipOfType(exp.ip) != exp.output {
t.Fatalf("failure for %+v", exp)
} }
} var expectations = []expectation{{
qtype: "A",
ip: "8.8.8.8",
output: true,
}, {
qtype: "A",
ip: "2a00:1450:4002:801::2004",
output: false,
}, {
qtype: "AAAA",
ip: "8.8.8.8",
output: false,
}, {
qtype: "AAAA",
ip: "2a00:1450:4002:801::2004",
output: true,
}, {
qtype: "ANTANI",
ip: "2a00:1450:4002:801::2004",
output: false,
}, {
qtype: "ANTANI",
ip: "8.8.8.8",
output: false,
}}
for _, exp := range expectations {
if exp.qtype.ipOfType(exp.ip) != exp.output {
t.Fatalf("failure for %+v", exp)
}
}
})
} }
func TestNewTCPConnectList(t *testing.T) { func TestNewTCPConnectList(t *testing.T) {
@ -74,7 +75,7 @@ func TestNewTCPConnectList(t *testing.T) {
name: "realistic run", name: "realistic run",
args: args{ args: args{
begin: begin, begin: begin,
events: []Event{&EventResolveDone{&EventValue{ events: []Event{&EventResolveDone{&EventValue{ // skipped because not relevant
Addresses: []string{"8.8.8.8", "8.8.4.4"}, Addresses: []string{"8.8.8.8", "8.8.4.4"},
Hostname: "dns.google.com", Hostname: "dns.google.com",
Time: begin.Add(100 * time.Millisecond), Time: begin.Add(100 * time.Millisecond),
@ -86,7 +87,7 @@ func TestNewTCPConnectList(t *testing.T) {
}}, &EventConnectOperation{&EventValue{ }}, &EventConnectOperation{&EventValue{
Address: "8.8.8.8:853", Address: "8.8.8.8:853",
Duration: 55 * time.Millisecond, Duration: 55 * time.Millisecond,
Proto: "udp", Proto: "udp", // this one should be skipped because it's UDP
Time: begin.Add(130 * time.Millisecond), Time: begin.Add(130 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{ }}, &EventConnectOperation{&EventValue{
Address: "8.8.4.4:53", Address: "8.8.4.4:53",
@ -115,8 +116,9 @@ func TestNewTCPConnectList(t *testing.T) {
}} }}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := NewTCPConnectList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { got := NewTCPConnectList(tt.args.begin, tt.args.events)
t.Error(cmp.Diff(got, tt.want)) if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
} }
}) })
} }
@ -143,6 +145,7 @@ func TestNewRequestList(t *testing.T) {
name: "realistic run", name: "realistic run",
args: args{ args: args{
begin: begin, begin: begin,
// Two round trips so we can test the sorting expected by OONI
events: []Event{&EventHTTPTransactionDone{&EventValue{ events: []Event{&EventHTTPTransactionDone{&EventValue{
HTTPRequestHeaders: http.Header{ HTTPRequestHeaders: http.Header{
"User-Agent": []string{"miniooni/0.1.0-dev"}, "User-Agent": []string{"miniooni/0.1.0-dev"},
@ -286,8 +289,9 @@ func TestNewRequestList(t *testing.T) {
}} }}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := NewRequestList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { got := NewRequestList(tt.args.begin, tt.args.events)
t.Error(cmp.Diff(tt.want, got)) if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
} }
}) })
} }
@ -320,12 +324,12 @@ func TestNewDNSQueriesList(t *testing.T) {
Hostname: "dns.google.com", Hostname: "dns.google.com",
Proto: "dot", Proto: "dot",
Time: begin.Add(100 * time.Millisecond), Time: begin.Add(100 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{ }}, &EventConnectOperation{&EventValue{ // skipped because not relevant
Address: "8.8.8.8:853", Address: "8.8.8.8:853",
Duration: 30 * time.Millisecond, Duration: 30 * time.Millisecond,
Proto: "tcp", Proto: "tcp",
Time: begin.Add(130 * time.Millisecond), Time: begin.Add(130 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{ }}, &EventConnectOperation{&EventValue{ // skipped because not relevant
Address: "8.8.4.4:53", Address: "8.8.4.4:53",
Duration: 50 * time.Millisecond, Duration: 50 * time.Millisecond,
Err: io.EOF, Err: io.EOF,
@ -452,6 +456,10 @@ func TestNewNetworkEventsList(t *testing.T) {
Err: websocket.ErrBadHandshake, Err: websocket.ErrBadHandshake,
NumBytes: 4114, NumBytes: 4114,
Time: begin.Add(14 * time.Millisecond), Time: begin.Add(14 * time.Millisecond),
}}, &EventResolveStart{&EventValue{
// We expect this event to be logged event though it's not a typical I/O
// event (it seems these extra events are useful for debugging)
Time: begin.Add(15 * time.Millisecond),
}}}, }}},
}, },
want: []NetworkEvent{{ want: []NetworkEvent{{
@ -482,12 +490,16 @@ func TestNewNetworkEventsList(t *testing.T) {
NumBytes: 4114, NumBytes: 4114,
Operation: netxlite.WriteToOperation, Operation: netxlite.WriteToOperation,
T: 0.014, T: 0.014,
}, {
Operation: "resolve_start",
T: 0.015,
}}, }},
}} }}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := NewNetworkEventsList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { got := NewNetworkEventsList(tt.args.begin, tt.args.events)
t.Error(cmp.Diff(got, tt.want)) if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
} }
}) })
} }
@ -511,13 +523,14 @@ func TestNewTLSHandshakesList(t *testing.T) {
}, },
want: nil, want: nil,
}, { }, {
name: "realistic run", name: "realistic run with TLS",
args: args{ args: args{
begin: begin, begin: begin,
events: []Event{&EventTLSHandshakeDone{&EventValue{ events: []Event{&EventTLSHandshakeDone{&EventValue{
Address: "131.252.210.176:443", Address: "131.252.210.176:443",
Err: io.EOF, Err: io.EOF,
NoTLSVerify: false, NoTLSVerify: false,
Proto: "tcp",
TLSCipherSuite: "SUITE", TLSCipherSuite: "SUITE",
TLSNegotiatedProto: "h2", TLSNegotiatedProto: "h2",
TLSPeerCerts: []*x509.Certificate{{ TLSPeerCerts: []*x509.Certificate{{
@ -545,11 +558,57 @@ func TestNewTLSHandshakesList(t *testing.T) {
T: 0.055, T: 0.055,
TLSVersion: "TLSv1.3", TLSVersion: "TLSv1.3",
}}, }},
}, {
name: "realistic run with QUIC",
args: args{
begin: begin,
events: []Event{&EventQUICHandshakeDone{&EventValue{
Address: "131.252.210.176:443",
Err: io.EOF,
NoTLSVerify: false,
Proto: "quic",
TLSCipherSuite: "SUITE",
TLSNegotiatedProto: "h3",
TLSPeerCerts: []*x509.Certificate{{
Raw: []byte("deadbeef"),
}, {
Raw: []byte("abad1dea"),
}},
TLSServerName: "x.org",
TLSVersion: "TLSv1.3",
Time: begin.Add(55 * time.Millisecond),
}}},
},
want: []TLSHandshake{{
Address: "131.252.210.176:443",
CipherSuite: "SUITE",
Failure: NewFailure(io.EOF),
NegotiatedProtocol: "h3",
NoTLSVerify: false,
PeerCertificates: []MaybeBinaryValue{{
Value: "deadbeef",
}, {
Value: "abad1dea",
}},
ServerName: "x.org",
T: 0.055,
TLSVersion: "TLSv1.3",
}},
}, {
name: "realistic run with no suitable events",
args: args{
begin: begin,
events: []Event{&EventResolveStart{&EventValue{
Time: begin.Add(55 * time.Millisecond),
}}},
},
want: nil,
}} }}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := NewTLSHandshakesList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) { got := NewTLSHandshakesList(tt.args.begin, tt.args.events)
t.Error(cmp.Diff(got, tt.want)) if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
} }
}) })
} }

View File

@ -12,8 +12,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
// SaverDialer saves events occurring during the dial // DialerSaver saves events occurring during the dial
type SaverDialer struct { type DialerSaver struct {
// Dialer is the underlying dialer, // Dialer is the underlying dialer,
Dialer model.Dialer Dialer model.Dialer
@ -28,26 +28,26 @@ func (s *Saver) NewConnectObserver() model.DialerWrapper {
if s == nil { if s == nil {
return nil // valid DialerWrapper according to netxlite's docs return nil // valid DialerWrapper according to netxlite's docs
} }
return &saverDialerWrapper{ return &dialerConnectObserver{
saver: s, saver: s,
} }
} }
type saverDialerWrapper struct { type dialerConnectObserver struct {
saver *Saver saver *Saver
} }
var _ model.DialerWrapper = &saverDialerWrapper{} var _ model.DialerWrapper = &dialerConnectObserver{}
func (w *saverDialerWrapper) WrapDialer(d model.Dialer) model.Dialer { func (w *dialerConnectObserver) WrapDialer(d model.Dialer) model.Dialer {
return &SaverDialer{ return &DialerSaver{
Dialer: d, Dialer: d,
Saver: w.saver, Saver: w.saver,
} }
} }
// DialContext implements Dialer.DialContext // DialContext implements Dialer.DialContext
func (d *SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *DialerSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
start := time.Now() start := time.Now()
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
stop := time.Now() stop := time.Now()
@ -61,13 +61,13 @@ func (d *SaverDialer) DialContext(ctx context.Context, network, address string)
return conn, err return conn, err
} }
func (d *SaverDialer) CloseIdleConnections() { func (d *DialerSaver) CloseIdleConnections() {
d.Dialer.CloseIdleConnections() d.Dialer.CloseIdleConnections()
} }
// SaverConnDialer wraps the returned connection such that we // DialerConnSaver wraps the returned connection such that we
// collect all the read/write events that occur. // collect all the read/write events that occur.
type SaverConnDialer struct { type DialerConnSaver struct {
// Dialer is the underlying dialer // Dialer is the underlying dialer
Dialer model.Dialer Dialer model.Dialer
@ -82,70 +82,78 @@ func (s *Saver) NewReadWriteObserver() model.DialerWrapper {
if s == nil { if s == nil {
return nil // valid DialerWrapper according to netxlite's docs return nil // valid DialerWrapper according to netxlite's docs
} }
return &saverReadWriteWrapper{ return &dialerReadWriteObserver{
saver: s, saver: s,
} }
} }
type saverReadWriteWrapper struct { type dialerReadWriteObserver struct {
saver *Saver saver *Saver
} }
var _ model.DialerWrapper = &saverReadWriteWrapper{} var _ model.DialerWrapper = &dialerReadWriteObserver{}
func (w *saverReadWriteWrapper) WrapDialer(d model.Dialer) model.Dialer { func (w *dialerReadWriteObserver) WrapDialer(d model.Dialer) model.Dialer {
return &SaverConnDialer{ return &DialerConnSaver{
Dialer: d, Dialer: d,
Saver: w.saver, Saver: w.saver,
} }
} }
// DialContext implements Dialer.DialContext // DialContext implements Dialer.DialContext
func (d *SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *DialerConnSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &saverConn{saver: d.Saver, Conn: conn}, nil return &dialerConnWrapper{saver: d.Saver, Conn: conn}, nil
} }
func (d *SaverConnDialer) CloseIdleConnections() { func (d *DialerConnSaver) CloseIdleConnections() {
d.Dialer.CloseIdleConnections() d.Dialer.CloseIdleConnections()
} }
type saverConn struct { type dialerConnWrapper struct {
net.Conn net.Conn
saver *Saver saver *Saver
} }
func (c *saverConn) Read(p []byte) (int, error) { func (c *dialerConnWrapper) Read(p []byte) (int, error) {
proto := c.Conn.RemoteAddr().Network()
remoteAddr := c.Conn.RemoteAddr().String()
start := time.Now() start := time.Now()
count, err := c.Conn.Read(p) count, err := c.Conn.Read(p)
stop := time.Now() stop := time.Now()
c.saver.Write(&EventReadOperation{&EventValue{ c.saver.Write(&EventReadOperation{&EventValue{
Address: remoteAddr,
Data: p[:count], Data: p[:count],
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
NumBytes: count, NumBytes: count,
Proto: proto,
Time: stop, Time: stop,
}}) }})
return count, err return count, err
} }
func (c *saverConn) Write(p []byte) (int, error) { func (c *dialerConnWrapper) Write(p []byte) (int, error) {
proto := c.Conn.RemoteAddr().Network()
remoteAddr := c.Conn.RemoteAddr().String()
start := time.Now() start := time.Now()
count, err := c.Conn.Write(p) count, err := c.Conn.Write(p)
stop := time.Now() stop := time.Now()
c.saver.Write(&EventWriteOperation{&EventValue{ c.saver.Write(&EventWriteOperation{&EventValue{
Address: remoteAddr,
Data: p[:count], Data: p[:count],
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
NumBytes: count, NumBytes: count,
Proto: proto,
Time: stop, Time: stop,
}}) }})
return count, err return count, err
} }
var _ model.Dialer = &SaverDialer{} var _ model.Dialer = &DialerSaver{}
var _ model.Dialer = &SaverConnDialer{} var _ model.Dialer = &DialerConnSaver{}
var _ net.Conn = &saverConn{} var _ net.Conn = &dialerConnWrapper{}

View File

@ -12,127 +12,268 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
func TestSaverDialerFailure(t *testing.T) { func TestDialerConnectObserver(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{} saver := &Saver{}
dlr := &SaverDialer{ obs := &dialerConnectObserver{
Dialer: &mocks.Dialer{ saver: saver,
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
},
Saver: saver,
} }
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") dialer := &mocks.Dialer{}
if !errors.Is(err, expected) { out := obs.WrapDialer(dialer)
t.Fatal("expected another error here") dialSaver := out.(*DialerSaver)
if dialSaver.Dialer != dialer {
t.Fatal("invalid dialer")
} }
if conn != nil { if dialSaver.Saver != saver {
t.Fatal("expected nil conn here") t.Fatal("invalid saver")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Value().Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Value().Err, expected) {
t.Fatal("unexpected Err")
}
if ev[0].Name() != netxlite.ConnectOperation {
t.Fatal("unexpected Name")
}
if ev[0].Value().Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
} }
} }
func TestSaverConnDialerFailure(t *testing.T) { func TestDialerSaver(t *testing.T) {
expected := errors.New("mocked error") t.Run("on failure", func(t *testing.T) {
saver := &Saver{} expected := errors.New("mocked error")
dlr := &SaverConnDialer{ saver := &Saver{}
Dialer: &mocks.Dialer{ dlr := &DialerSaver{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
func TestSaverConnDialerSuccess(t *testing.T) {
saver := &Saver{}
dlr := &SaverConnDialer{
Dialer: &SaverDialer{
Dialer: &mocks.Dialer{ Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return &mocks.Conn{ return nil, expected
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockClose: func() error {
return io.EOF
},
MockLocalAddr: func() net.Addr {
return &net.TCPAddr{Port: 12345}
},
}, nil
}, },
}, },
Saver: saver, Saver: saver,
}, }
Saver: saver, conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
} if !errors.Is(err, expected) {
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") t.Fatal("expected another error here")
if err != nil { }
t.Fatal("not the error we expected", err) if conn != nil {
} t.Fatal("expected nil conn here")
conn.Read(nil) }
conn.Write(nil) ev := saver.Read()
conn.Close() if len(ev) != 1 {
events := saver.Read() t.Fatal("expected a single event here")
if len(events) != 3 { }
t.Fatal("unexpected number of events saved", len(events)) if ev[0].Value().Address != "www.google.com:443" {
} t.Fatal("unexpected Address")
if events[0].Name() != "connect" { }
t.Fatal("expected a connect event") if ev[0].Value().Duration <= 0 {
} t.Fatal("unexpected Duration")
saverCheckConnectEvent(t, &events[0]) }
if events[1].Name() != "read" { if !errors.Is(ev[0].Value().Err, expected) {
t.Fatal("expected a read event") t.Fatal("unexpected Err")
} }
saverCheckReadEvent(t, &events[1]) if ev[0].Name() != netxlite.ConnectOperation {
if events[2].Name() != "write" { t.Fatal("unexpected Name")
t.Fatal("expected a write event") }
} if ev[0].Value().Proto != "tcp" {
saverCheckWriteEvent(t, &events[2]) t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &DialerSaver{
Dialer: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
} }
func saverCheckConnectEvent(t *testing.T, ev *Event) { func TestDialerReadWriteObserver(t *testing.T) {
// TODO(bassosimone): implement saver := &Saver{}
obs := &dialerReadWriteObserver{
saver: saver,
}
dialer := &mocks.Dialer{}
out := obs.WrapDialer(dialer)
dialSaver := out.(*DialerConnSaver)
if dialSaver.Dialer != dialer {
t.Fatal("invalid dialer")
}
if dialSaver.Saver != saver {
t.Fatal("invalid saver")
}
} }
func saverCheckReadEvent(t *testing.T, ev *Event) { func TestDialerConnSaver(t *testing.T) {
// TODO(bassosimone): implement t.Run("DialContext", func(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
dlr := &DialerConnSaver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
t.Run("on success", func(t *testing.T) {
origConn := &mocks.Conn{}
saver := &Saver{}
dlr := &DialerConnSaver{
Dialer: &DialerSaver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return origConn, nil
},
},
Saver: saver,
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if err != nil {
t.Fatal("not the error we expected", err)
}
cw := conn.(*dialerConnWrapper)
if cw.Conn != origConn {
t.Fatal("unexpected conn")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &DialerConnSaver{
Dialer: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
} }
func saverCheckWriteEvent(t *testing.T, ev *Event) { func TestDialerConnWrapper(t *testing.T) {
// TODO(bassosimone): implement t.Run("Read", func(t *testing.T) {
baseConn := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "www.google.com:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
saver := &Saver{}
conn := &dialerConnWrapper{
Conn: baseConn,
saver: saver,
}
data := make([]byte, 155)
count, err := conn.Read(data)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Value().Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Value().Err, io.EOF) {
t.Fatal("unexpected Err")
}
if ev[0].Name() != netxlite.ReadOperation {
t.Fatal("unexpected Name")
}
if ev[0].Value().Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
})
t.Run("Write", func(t *testing.T) {
baseConn := &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "www.google.com:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
saver := &Saver{}
conn := &dialerConnWrapper{
Conn: baseConn,
saver: saver,
}
data := make([]byte, 155)
count, err := conn.Write(data)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Value().Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Value().Err, io.EOF) {
t.Fatal("unexpected Err")
}
if ev[0].Name() != netxlite.WriteOperation {
t.Fatal("unexpected Name")
}
if ev[0].Value().Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
})
} }

View File

@ -1,5 +1,9 @@
package tracex package tracex
//
// All the possible events
//
import ( import (
"crypto/x509" "crypto/x509"
"net/http" "net/http"

View File

@ -27,11 +27,17 @@ func httpCloneRequestHeaders(req *http.Request) http.Header {
return header return header
} }
// SaverTransactionHTTPTransport is a RoundTripper that saves // HTTPTransportSaver is a RoundTripper that saves
// events related to the HTTP transaction // events related to the HTTP transaction
type SaverTransactionHTTPTransport struct { type HTTPTransportSaver struct {
model.HTTPTransport // HTTPTransport is the MANDATORY underlying HTTP transport.
Saver *Saver HTTPTransport model.HTTPTransport
// Saver is the MANDATORY saver to use.
Saver *Saver
// SnapshotSize is the OPTIONAL maximum body snapshot size (if not set, we'll
// use 1<<17, which we've been using since the ooni/netx days)
SnapshotSize int64 SnapshotSize int64
} }
@ -40,7 +46,11 @@ type SaverTransactionHTTPTransport struct {
// //
// The maxBodySnapshotSize argument controls the maximum size of the // The maxBodySnapshotSize argument controls the maximum size of the
// body snapshot that we collect along with the HTTP round trip. // body snapshot that we collect along with the HTTP round trip.
func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (txp *HTTPTransportSaver) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO(bassosimone): we're currently using the started time for
// the transaction done event, which contrasts with what we do for
// every other event. What does the spec say?
started := time.Now() started := time.Now()
txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{ txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{
@ -92,7 +102,15 @@ func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Re
return resp, nil return resp, nil
} }
func (txp *SaverTransactionHTTPTransport) snapshotSize() int64 { func (txp *HTTPTransportSaver) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
}
func (txp *HTTPTransportSaver) Network() string {
return txp.HTTPTransport.Network()
}
func (txp *HTTPTransportSaver) snapshotSize() int64 {
if txp.SnapshotSize > 0 { if txp.SnapshotSize > 0 {
return txp.SnapshotSize return txp.SnapshotSize
} }
@ -104,4 +122,4 @@ type httpReadableAgainBody struct {
io.Closer io.Closer
} }
var _ model.HTTPTransport = &SaverTransactionHTTPTransport{} var _ model.HTTPTransport = &HTTPTransportSaver{}

View File

@ -16,227 +16,262 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering" "github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
) )
func TestSaverTransactionHTTPTransport(t *testing.T) { func TestHTTPTransportSaver(t *testing.T) {
startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) { t.Run("CloseIdleConnections", func(t *testing.T) {
server := &filtering.HTTPProxy{ var called bool
OnIncomingHost: func(host string) filtering.HTTPAction { child := &mocks.HTTPTransport{
return action MockCloseIdleConnections: func() {
called = true
}, },
} }
listener, err := server.Start("127.0.0.1:0") dialer := &HTTPTransportSaver{
if err != nil { HTTPTransport: child,
t.Fatal(err) Saver: &Saver{},
} }
URL := &url.URL{ dialer.CloseIdleConnections()
Scheme: "http", if !called {
Host: listener.Addr().String(), t.Fatal("not called")
Path: "/",
}
return listener, URL
}
measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) {
saver := &Saver{}
txp := &SaverTransactionHTTPTransport{
HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger),
Saver: saver,
}
req, err := http.NewRequest("GET", URL.String(), nil)
if err != nil {
t.Fatal(err)
}
req.Header.Add("User-Agent", "miniooni")
resp, err := txp.RoundTrip(req)
return resp, saver, err
}
validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) {
if value.HTTPMethod != "GET" {
t.Fatal("invalid method")
}
if value.HTTPRequestHeaders.Get("Host") != URL.Host {
t.Fatal("invalid Host header")
}
if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" {
t.Fatal("invalid User-Agent header")
}
if value.HTTPURL != URL.String() {
t.Fatal("invalid URL")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
if value.Transport != "tcp" {
t.Fatal("expected Transport to be tcp")
}
}
validateRequest := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionStart); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_start" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
}
validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionDone); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_done" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if len(value.HTTPResponseHeaders) <= 0 {
t.Fatal("expected at least one response header")
}
if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) {
t.Fatal("unexpected value for response body")
}
if value.HTTPStatusCode != 451 {
t.Fatal("unexpected status code")
}
}
t.Run("on success", func(t *testing.T) {
listener, URL := startServer(t, filtering.HTTPAction451)
defer listener.Close()
resp, saver, err := measureHTTP(t, URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 451 {
t.Fatal("unexpected status code", resp.StatusCode)
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
validateRequest(t, events[0], URL)
validateResponseSuccess(t, events[1], URL)
data, err := netxlite.ReadAllContext(context.Background(), resp.Body)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
t.Fatal("we cannot re-read the same body")
} }
}) })
validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) { t.Run("Network", func(t *testing.T) {
if _, good := ev.(*EventHTTPTransactionDone); !good { expected := "antani"
t.Fatal("invalid event type") child := &mocks.HTTPTransport{
MockNetwork: func() string {
return expected
},
} }
if ev.Name() != "http_transaction_done" { dialer := &HTTPTransportSaver{
t.Fatal("invalid event name") HTTPTransport: child,
Saver: &Saver{},
} }
value := ev.Value() if dialer.Network() != expected {
validateRequestFields(t, value, URL) t.Fatal("unexpected Network")
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
} }
if value.Err.Error() != "connection_reset" {
t.Fatal("unexpected Err value")
}
if len(value.HTTPResponseHeaders) > 0 {
t.Fatal("expected zero response headers")
}
if !bytes.Equal(value.HTTPResponseBody, nil) {
t.Fatal("unexpected value for response body")
}
if value.HTTPStatusCode != 0 {
t.Fatal("unexpected status code")
}
}
t.Run("on round trip failure", func(t *testing.T) {
listener, URL := startServer(t, filtering.HTTPActionReset)
defer listener.Close()
resp, saver, err := measureHTTP(t, URL)
if err == nil || err.Error() != "connection_reset" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
validateRequest(t, events[0], URL)
validateResponseFailure(t, events[1], URL)
}) })
// Sometimes useful for testing t.Run("RoundTrip", func(t *testing.T) {
/* startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) {
dumplog := func(t *testing.T, ev Event) { server := &filtering.HTTPProxy{
data, _ := json.MarshalIndent(ev.Value(), " ", " ") OnIncomingHost: func(host string) filtering.HTTPAction {
t.Log(string(data)) return action
t.FailNow() },
}
listener, err := server.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
URL := &url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: "/",
}
return listener, URL
} }
*/
t.Run("on error reading the response body", func(t *testing.T) { measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) {
saver := &Saver{} saver := &Saver{}
expected := errors.New("mocked error") txp := &HTTPTransportSaver{
txp := SaverTransactionHTTPTransport{ HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger),
HTTPTransport: &mocks.HTTPTransport{ Saver: saver,
MockRoundTrip: func(req *http.Request) (*http.Response, error) { }
return &http.Response{ req, err := http.NewRequest("GET", URL.String(), nil)
Header: http.Header{ if err != nil {
"Server": {"antani"}, t.Fatal(err)
}, }
StatusCode: 200, req.Header.Add("User-Agent", "miniooni")
Body: io.NopCloser(&mocks.Reader{ resp, err := txp.RoundTrip(req)
MockRead: func(b []byte) (int, error) { return resp, saver, err
return 0, expected }
validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) {
if value.HTTPMethod != "GET" {
t.Fatal("invalid method")
}
if value.HTTPRequestHeaders.Get("Host") != URL.Host {
t.Fatal("invalid Host header")
}
if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" {
t.Fatal("invalid User-Agent header")
}
if value.HTTPURL != URL.String() {
t.Fatal("invalid URL")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
if value.Transport != "tcp" {
t.Fatal("expected Transport to be tcp")
}
}
validateRequest := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionStart); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_start" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
}
validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionDone); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_done" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if len(value.HTTPResponseHeaders) <= 0 {
t.Fatal("expected at least one response header")
}
if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) {
t.Fatal("unexpected value for response body")
}
if value.HTTPStatusCode != 451 {
t.Fatal("unexpected status code")
}
}
t.Run("on success", func(t *testing.T) {
listener, URL := startServer(t, filtering.HTTPAction451)
defer listener.Close()
resp, saver, err := measureHTTP(t, URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 451 {
t.Fatal("unexpected status code", resp.StatusCode)
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
validateRequest(t, events[0], URL)
validateResponseSuccess(t, events[1], URL)
data, err := netxlite.ReadAllContext(context.Background(), resp.Body)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
t.Fatal("we cannot re-read the same body")
}
})
validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionDone); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_done" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if value.Err.Error() != "connection_reset" {
t.Fatal("unexpected Err value")
}
if len(value.HTTPResponseHeaders) > 0 {
t.Fatal("expected zero response headers")
}
if !bytes.Equal(value.HTTPResponseBody, nil) {
t.Fatal("unexpected value for response body")
}
if value.HTTPStatusCode != 0 {
t.Fatal("unexpected status code")
}
}
t.Run("on round trip failure", func(t *testing.T) {
listener, URL := startServer(t, filtering.HTTPActionReset)
defer listener.Close()
resp, saver, err := measureHTTP(t, URL)
if err == nil || err.Error() != "connection_reset" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
validateRequest(t, events[0], URL)
validateResponseFailure(t, events[1], URL)
})
// Sometimes useful for testing
/*
dump := func(t *testing.T, ev Event) {
data, _ := json.MarshalIndent(ev.Value(), " ", " ")
t.Log(string(data))
t.Fail()
}
*/
t.Run("on error reading the response body", func(t *testing.T) {
saver := &Saver{}
expected := errors.New("mocked error")
txp := HTTPTransportSaver{
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{
Header: http.Header{
"Server": {"antani"},
}, },
}), StatusCode: 200,
}, nil Body: io.NopCloser(&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}),
}, nil
},
MockNetwork: func() string {
return "tcp"
},
}, },
MockNetwork: func() string { SnapshotSize: 4,
return "tcp" Saver: saver,
}, }
}, URL := &url.URL{
SnapshotSize: 4, Scheme: "http",
Saver: saver, Host: "127.0.0.1:9050",
} }
URL := &url.URL{ req, err := http.NewRequest("GET", URL.String(), nil)
Scheme: "http", if err != nil {
Host: "127.0.0.1:9050", t.Fatal(err)
} }
req, err := http.NewRequest("GET", URL.String(), nil) req.Header.Add("User-Agent", "miniooni")
if err != nil { resp, err := txp.RoundTrip(req)
t.Fatal(err) if !errors.Is(err, expected) {
} t.Fatal("not the error we expected")
req.Header.Add("User-Agent", "miniooni") }
resp, err := txp.RoundTrip(req) if resp != nil {
if !errors.Is(err, expected) { t.Fatal("expected nil response")
t.Fatal("not the error we expected") }
} ev := saver.Read()
if resp != nil { validateRequest(t, ev[0], URL)
t.Fatal("expected nil response") if ev[1].Value().HTTPStatusCode != 200 {
} t.Fatal("invalid status code")
ev := saver.Read() }
validateRequest(t, ev[0], URL) if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" {
if ev[1].Value().HTTPStatusCode != 200 { t.Fatal("invalid Server header")
t.Fatal("invalid status code") }
} if ev[1].Value().Err.Error() != "unknown_failure: mocked error" {
if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" { t.Fatal("invalid error")
t.Fatal("invalid Server header") }
} })
if ev[1].Value().Err.Error() != "unknown_failure: mocked error" {
t.Fatal("invalid error")
}
}) })
} }

View File

@ -15,8 +15,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
// QUICHandshakeSaver saves events occurring during the QUIC handshake. // QUICDialerSaver saves events occurring during the QUIC handshake.
type QUICHandshakeSaver struct { type QUICDialerSaver struct {
// QUICDialer is the wrapped dialer // QUICDialer is the wrapped dialer
QUICDialer model.QUICDialer QUICDialer model.QUICDialer
@ -33,14 +33,14 @@ func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer {
if s == nil { if s == nil {
return qd return qd
} }
return &QUICHandshakeSaver{ return &QUICDialerSaver{
QUICDialer: qd, QUICDialer: qd,
Saver: s, Saver: s,
} }
} }
// DialContext implements QUICDialer.DialContext // DialContext implements QUICDialer.DialContext
func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string, func (h *QUICDialerSaver) DialContext(ctx context.Context, network string,
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
start := time.Now() start := time.Now()
// TODO(bassosimone): in the future we probably want to also save // TODO(bassosimone): in the future we probably want to also save
@ -58,9 +58,11 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
if err != nil { if err != nil {
// TODO(bassosimone): here we should save the peer certs // TODO(bassosimone): here we should save the peer certs
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{ h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
Address: host,
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
NoTLSVerify: tlsCfg.InsecureSkipVerify, NoTLSVerify: tlsCfg.InsecureSkipVerify,
Proto: network,
TLSNextProtos: tlsCfg.NextProtos, TLSNextProtos: tlsCfg.NextProtos,
TLSServerName: tlsCfg.ServerName, TLSServerName: tlsCfg.ServerName,
Time: stop, Time: stop,
@ -69,8 +71,10 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
} }
state := quicConnectionState(sess) state := quicConnectionState(sess)
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{ h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
Address: host,
Duration: stop.Sub(start), Duration: stop.Sub(start),
NoTLSVerify: tlsCfg.InsecureSkipVerify, NoTLSVerify: tlsCfg.InsecureSkipVerify,
Proto: network,
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol, TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: tlsCfg.NextProtos, TLSNextProtos: tlsCfg.NextProtos,
@ -82,7 +86,7 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
return sess, nil return sess, nil
} }
func (h *QUICHandshakeSaver) CloseIdleConnections() { func (h *QUICDialerSaver) CloseIdleConnections() {
h.QUICDialer.CloseIdleConnections() h.QUICDialer.CloseIdleConnections()
} }
@ -121,15 +125,15 @@ func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
pconn = &udpLikeConnSaver{ pconn = &quicPacketConnWrapper{
UDPLikeConn: pconn, UDPLikeConn: pconn,
saver: qls.Saver, saver: qls.Saver,
} }
return pconn, nil return pconn, nil
} }
// udpLikeConnSaver saves I/O events // quicPacketConnWrapper saves I/O events
type udpLikeConnSaver struct { type quicPacketConnWrapper struct {
// UDPLikeConn is the wrapped underlying conn // UDPLikeConn is the wrapped underlying conn
model.UDPLikeConn model.UDPLikeConn
@ -137,7 +141,7 @@ type udpLikeConnSaver struct {
saver *Saver saver *Saver
} }
func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) { func (c *quicPacketConnWrapper) WriteTo(p []byte, addr net.Addr) (int, error) {
start := time.Now() start := time.Now()
count, err := c.UDPLikeConn.WriteTo(p, addr) count, err := c.UDPLikeConn.WriteTo(p, addr)
stop := time.Now() stop := time.Now()
@ -152,7 +156,7 @@ func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) {
return count, err return count, err
} }
func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) { func (c *quicPacketConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) {
start := time.Now() start := time.Now()
n, addr, err := c.UDPLikeConn.ReadFrom(b) n, addr, err := c.UDPLikeConn.ReadFrom(b)
stop := time.Now() stop := time.Now()
@ -171,13 +175,13 @@ func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) {
return n, addr, err return n, addr, err
} }
func (c *udpLikeConnSaver) safeAddrString(addr net.Addr) (out string) { func (c *quicPacketConnWrapper) safeAddrString(addr net.Addr) (out string) {
if addr != nil { if addr != nil {
out = addr.String() out = addr.String()
} }
return return
} }
var _ model.QUICDialer = &QUICHandshakeSaver{} var _ model.QUICDialer = &QUICDialerSaver{}
var _ model.QUICListener = &QUICListenerSaver{} var _ model.QUICListener = &QUICListenerSaver{}
var _ model.UDPLikeConn = &udpLikeConnSaver{} var _ model.UDPLikeConn = &quicPacketConnWrapper{}

View File

@ -3,182 +3,446 @@ package tracex
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors" "errors"
"net" "net"
"reflect"
"strings"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/netxlite/quictesting"
) )
type MockDialer struct { func TestQUICDialerSaver(t *testing.T) {
Dialer model.QUICDialer t.Run("DialContext", func(t *testing.T) {
Sess quic.EarlyConnection
Err error
}
func (d MockDialer) DialContext(ctx context.Context, network, host string, checkStartEventFields := func(t *testing.T, value *EventValue) {
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { if value.Address != "8.8.8.8:443" {
if d.Dialer != nil { t.Fatal("invalid Address")
return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg) }
} if !value.NoTLSVerify {
return d.Sess, d.Err t.Fatal("expected NoTLSVerify to be true")
} }
if value.Proto != "udp" {
t.Fatal("wrong protocol")
}
if diff := cmp.Diff(value.TLSNextProtos, []string{"h3"}); diff != "" {
t.Fatal(diff)
}
if value.TLSServerName != "dns.google" {
t.Fatal("invalid TLSServerName")
}
if value.Time.IsZero() {
t.Fatal("expected non zero time")
}
}
func TestHandshakeSaverSuccess(t *testing.T) { checkStartedEvent := func(t *testing.T, ev Event) {
nextprotos := []string{"h3"} if _, good := ev.(*EventQUICHandshakeStart); !good {
servername := quictesting.Domain t.Fatal("invalid event type")
tlsConf := &tls.Config{ }
NextProtos: nextprotos, value := ev.Value()
ServerName: servername, checkStartEventFields(t, value)
} }
saver := &Saver{}
dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) {
QUICListener: &netxlite.QUICListenerStdlib{}, if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err != nil {
t.Fatal("expected no error here")
}
if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" {
t.Fatal("invalid cipher suite")
}
if value.TLSNegotiatedProto != "h3" {
t.Fatal("invalid negotiated protocol")
}
if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" {
t.Fatal(diff)
}
if value.TLSVersion != "TLSv1.3" {
t.Fatal("invalid TLS version")
}
}
checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) {
if _, good := ev.(*EventQUICHandshakeDone); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
fun(t, value)
}
t.Run("on success", func(t *testing.T) {
saver := &Saver{}
returnedConn := &mocks.QUICEarlyConnection{
MockConnectionState: func() quic.ConnectionState {
cs := quic.ConnectionState{}
cs.TLS.ConnectionState.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
cs.TLS.NegotiatedProtocol = "h3"
cs.TLS.PeerCertificates = []*x509.Certificate{}
cs.TLS.Version = tls.VersionTLS13
return cs
},
}
dialer := saver.WrapQUICDialer(&mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return returnedConn, nil
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h3"},
ServerName: "dns.google",
}
quicConfig := &quic.Config{}
conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess)
})
checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err == nil {
t.Fatal("expected non-nil error here")
}
}
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
dialer := saver.WrapQUICDialer(&mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return nil, expected
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h3"},
ServerName: "dns.google",
}
quicConfig := &quic.Config{}
conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsFailure)
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.QUICDialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &QUICDialerSaver{
QUICDialer: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
}) })
sess, err := dlr.DialContext(context.Background(), "udp",
quictesting.Endpoint("443"), tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error", err)
}
if sess == nil {
t.Fatal("unexpected nil sess")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("unexpected number of events")
}
if ev[0].Name() != "quic_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].Value().TLSServerName != quictesting.Domain {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Value().Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err", ev[1].Value().Err)
}
if ev[1].Name() != "quic_handshake_done" {
t.Fatal("unexpected Name")
}
if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[1].Value().TLSServerName != quictesting.Domain {
t.Fatal("unexpected TLSServerName")
}
if ev[1].Value().Time.Before(ev[0].Value().Time) {
t.Fatal("unexpected Time")
}
} }
func TestHandshakeSaverHostNameError(t *testing.T) { func TestQUICListenerSaver(t *testing.T) {
nextprotos := []string{"h3"} t.Run("on failure", func(t *testing.T) {
servername := "example.com" expected := errors.New("mocked error")
tlsConf := &tls.Config{ saver := &Saver{}
NextProtos: nextprotos, qls := saver.WrapQUICListener(&mocks.QUICListener{
ServerName: servername, MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
} return nil, expected
saver := &Saver{} },
dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{ })
QUICListener: &netxlite.QUICListenerStdlib{}, pconn, err := qls.Listen(&net.UDPAddr{
IP: []byte{},
Port: 8080,
Zone: "",
})
if !errors.Is(err, expected) {
t.Fatal("unexpected error", err)
}
if pconn != nil {
t.Fatal("expected nil pconn here")
}
}) })
sess, err := dlr.DialContext(context.Background(), "udp",
quictesting.Endpoint("443"), tlsConf, &quic.Config{}) t.Run("on success", func(t *testing.T) {
if err == nil { saver := &Saver{}
t.Fatal("expected an error here") returnedConn := &mocks.UDPLikeConn{}
} qls := saver.WrapQUICListener(&mocks.QUICListener{
if sess != nil { MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
t.Fatal("expected nil sess here") return returnedConn, nil
} },
for _, ev := range saver.Read() { })
if ev.Name() != "quic_handshake_done" { pconn, err := qls.Listen(&net.UDPAddr{
continue IP: []byte{},
Port: 8080,
Zone: "",
})
if err != nil {
t.Fatal(err)
} }
if ev.Value().NoTLSVerify == true { wconn := pconn.(*quicPacketConnWrapper)
t.Fatal("expected NoTLSVerify to be false") if wconn.UDPLikeConn != returnedConn {
t.Fatal("invalid underlying connection")
} }
if !strings.HasSuffix(ev.Value().Err.Error(), "tls: handshake failure") { if wconn.saver != saver {
t.Fatal("unexpected error", ev.Value().Err) t.Fatal("invalid saver")
} }
} })
} }
func TestQUICListenerSaverCannotListen(t *testing.T) { func TestQUICPacketConnWrapper(t *testing.T) {
expected := errors.New("mocked error") t.Run("ReadFrom", func(t *testing.T) {
saver := &Saver{}
qls := saver.WrapQUICListener(&mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
return nil, expected
},
})
pconn, err := qls.Listen(&net.UDPAddr{
IP: []byte{},
Port: 8080,
Zone: "",
})
if !errors.Is(err, expected) {
t.Fatal("unexpected error", err)
}
if pconn != nil {
t.Fatal("expected nil pconn here")
}
}
func TestSystemDialerSuccessWithReadWrite(t *testing.T) { t.Run("on failure", func(t *testing.T) {
// This is the most common use case for collecting reads, writes expected := errors.New("mocked error")
tlsConf := &tls.Config{ saver := &Saver{}
NextProtos: []string{"h3"}, conn := &quicPacketConnWrapper{
ServerName: quictesting.Domain, UDPLikeConn: &mocks.UDPLikeConn{
} MockReadFrom: func(p []byte) (int, net.Addr, error) {
saver := &Saver{} return 0, nil, expected
systemdialer := &netxlite.QUICDialerQUICGo{ },
QUICListener: saver.WrapQUICListener(&netxlite.QUICListenerStdlib{}), },
} saver: saver,
_, err := systemdialer.DialContext(context.Background(), "udp", }
quictesting.Endpoint("443"), tlsConf, &quic.Config{}) buf := make([]byte, 1<<17)
if err != nil { count, addr, err := conn.ReadFrom(buf)
t.Fatal(err) if !errors.Is(err, expected) {
} t.Fatal("unexpected err", err)
ev := saver.Read() }
if len(ev) < 2 { if count != 0 {
t.Fatal("unexpected number of events") t.Fatal("invalid count")
} }
last := len(ev) - 1 if addr != nil {
for idx := 1; idx < last; idx++ { t.Fatal("invalid addr")
if ev[idx].Value().Data == nil { }
t.Fatal("unexpected Data") events := saver.Read()
} if len(events) != 1 {
if ev[idx].Value().Duration <= 0 { t.Fatal("invalid number of events")
t.Fatal("unexpected Duration") }
} ev0 := events[0]
if ev[idx].Value().Err != nil { if _, good := ev0.(*EventReadFromOperation); !good {
t.Fatal("unexpected Err") t.Fatal("invalid event type")
} }
if ev[idx].Value().NumBytes <= 0 { value := ev0.Value()
t.Fatal("unexpected NumBytes") if value.Address != "" {
} t.Fatal("invalid Address")
switch ev[idx].Name() { }
case netxlite.ReadFromOperation, netxlite.WriteToOperation: if len(value.Data) != 0 {
default: t.Fatal("invalid Data")
t.Fatal("unexpected Name") }
} if value.Duration <= 0 {
if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) { t.Fatal("expected nonzero duration")
t.Fatal("unexpected Time", ev[idx].Value().Time, ev[idx-1].Value().Time) }
} if !errors.Is(value.Err, expected) {
} t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 0 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
t.Run("on success", func(t *testing.T) {
expected := []byte{1, 2, 3, 4}
saver := &Saver{}
expectedAddr := &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
MockNetwork: func() string {
return "udp"
},
}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockReadFrom: func(p []byte) (int, net.Addr, error) {
copy(p, expected)
return len(expected), expectedAddr, nil
},
},
saver: saver,
}
buf := make([]byte, 1<<17)
count, addr, err := conn.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
if count != 4 {
t.Fatal("invalid count")
}
if addr != expectedAddr {
t.Fatal("invalid addr")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventReadFromOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if len(value.Data) != 4 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if value.Err != nil {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 4 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
})
t.Run("WriteTo", func(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 0, expected
},
},
saver: saver,
}
destAddr := &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
}
buf := make([]byte, 7)
count, err := conn.WriteTo(buf, destAddr)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("invalid count")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventWriteToOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if len(value.Data) != 0 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if !errors.Is(value.Err, expected) {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 0 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
t.Run("on success", func(t *testing.T) {
saver := &Saver{}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 1, nil
},
},
saver: saver,
}
destAddr := &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
}
buf := make([]byte, 7)
count, err := conn.WriteTo(buf, destAddr)
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatal("invalid count")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventWriteToOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if len(value.Data) != 1 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if value.Err != nil {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 1 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
})
} }

View File

@ -12,8 +12,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
// SaverResolver is a resolver that saves events. // ResolverSaver is a resolver that saves events.
type SaverResolver struct { type ResolverSaver struct {
// Resolver is the underlying resolver. // Resolver is the underlying resolver.
Resolver model.Resolver Resolver model.Resolver
@ -30,14 +30,14 @@ func (s *Saver) WrapResolver(r model.Resolver) model.Resolver {
if s == nil { if s == nil {
return r return r
} }
return &SaverResolver{ return &ResolverSaver{
Resolver: r, Resolver: r,
Saver: s, Saver: s,
} }
} }
// LookupHost implements Resolver.LookupHost // LookupHost implements Resolver.LookupHost
func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *ResolverSaver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
start := time.Now() start := time.Now()
r.Saver.Write(&EventResolveStart{&EventValue{ r.Saver.Write(&EventResolveStart{&EventValue{
Address: r.Resolver.Address(), Address: r.Resolver.Address(),
@ -59,30 +59,30 @@ func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]stri
return addrs, err return addrs, err
} }
func (r *SaverResolver) Network() string { func (r *ResolverSaver) Network() string {
return r.Resolver.Network() return r.Resolver.Network()
} }
func (r *SaverResolver) Address() string { func (r *ResolverSaver) Address() string {
return r.Resolver.Address() return r.Resolver.Address()
} }
func (r *SaverResolver) CloseIdleConnections() { func (r *ResolverSaver) CloseIdleConnections() {
r.Resolver.CloseIdleConnections() r.Resolver.CloseIdleConnections()
} }
func (r *SaverResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { func (r *ResolverSaver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
// TODO(bassosimone): we should probably implement this method // TODO(bassosimone): we should probably implement this method
return r.Resolver.LookupHTTPS(ctx, domain) return r.Resolver.LookupHTTPS(ctx, domain)
} }
func (r *SaverResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { func (r *ResolverSaver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
// TODO(bassosimone): we should probably implement this method // TODO(bassosimone): we should probably implement this method
return r.Resolver.LookupNS(ctx, domain) return r.Resolver.LookupNS(ctx, domain)
} }
// SaverDNSTransport is a DNS transport that saves events. // DNSTransportSaver is a DNS transport that saves events.
type SaverDNSTransport struct { type DNSTransportSaver struct {
// DNSTransport is the underlying DNS transport. // DNSTransport is the underlying DNS transport.
DNSTransport model.DNSTransport DNSTransport model.DNSTransport
@ -99,14 +99,14 @@ func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
if s == nil { if s == nil {
return txp return txp
} }
return &SaverDNSTransport{ return &DNSTransportSaver{
DNSTransport: txp, DNSTransport: txp,
Saver: s, Saver: s,
} }
} }
// RoundTrip implements RoundTripper.RoundTrip // RoundTrip implements RoundTripper.RoundTrip
func (txp *SaverDNSTransport) RoundTrip( func (txp *DNSTransportSaver) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
start := time.Now() start := time.Now()
txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{ txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{
@ -129,19 +129,19 @@ func (txp *SaverDNSTransport) RoundTrip(
return response, err return response, err
} }
func (txp *SaverDNSTransport) Network() string { func (txp *DNSTransportSaver) Network() string {
return txp.DNSTransport.Network() return txp.DNSTransport.Network()
} }
func (txp *SaverDNSTransport) Address() string { func (txp *DNSTransportSaver) Address() string {
return txp.DNSTransport.Address() return txp.DNSTransport.Address()
} }
func (txp *SaverDNSTransport) CloseIdleConnections() { func (txp *DNSTransportSaver) CloseIdleConnections() {
txp.DNSTransport.CloseIdleConnections() txp.DNSTransport.CloseIdleConnections()
} }
func (txp *SaverDNSTransport) RequiresPadding() bool { func (txp *DNSTransportSaver) RequiresPadding() bool {
return txp.DNSTransport.RequiresPadding() return txp.DNSTransport.RequiresPadding()
} }
@ -157,5 +157,5 @@ func dnsMaybeResponseBytes(response model.DNSResponse) []byte {
return response.Bytes() return response.Bytes()
} }
var _ model.Resolver = &SaverResolver{} var _ model.Resolver = &ResolverSaver{}
var _ model.DNSTransport = &SaverDNSTransport{} var _ model.DNSTransport = &DNSTransportSaver{}

View File

@ -14,220 +14,224 @@ import (
"github.com/ooni/probe-cli/v3/internal/runtimex" "github.com/ooni/probe-cli/v3/internal/runtimex"
) )
func TestSaverResolverFailure(t *testing.T) { func TestResolverSaver(t *testing.T) {
expected := errors.New("no such host") t.Run("on failure", func(t *testing.T) {
saver := &Saver{} expected := errors.New("no such host")
reso := saver.WrapResolver(NewFakeResolverWithExplicitError(expected)) saver := &Saver{}
addrs, err := reso.LookupHost(context.Background(), "www.google.com") reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected))
if !errors.Is(err, expected) { addrs, err := reso.LookupHost(context.Background(), "www.google.com")
t.Fatal("not the error we expected") if !errors.Is(err, expected) {
} t.Fatal("not the error we expected")
if addrs != nil { }
t.Fatal("expected nil address here") if addrs != nil {
} t.Fatal("expected nil address here")
ev := saver.Read() }
if len(ev) != 2 { ev := saver.Read()
t.Fatal("expected number of events") if len(ev) != 2 {
} t.Fatal("expected number of events")
if ev[0].Value().Hostname != "www.google.com" { }
t.Fatal("unexpected Hostname") if ev[0].Value().Hostname != "www.google.com" {
} t.Fatal("unexpected Hostname")
if ev[0].Name() != "resolve_start" { }
t.Fatal("unexpected name") if ev[0].Name() != "resolve_start" {
} t.Fatal("unexpected name")
if !ev[0].Value().Time.Before(time.Now()) { }
t.Fatal("the saved time is wrong") if !ev[0].Value().Time.Before(time.Now()) {
} t.Fatal("the saved time is wrong")
if ev[1].Value().Addresses != nil { }
t.Fatal("unexpected Addresses") if ev[1].Value().Addresses != nil {
} t.Fatal("unexpected Addresses")
if ev[1].Value().Duration <= 0 { }
t.Fatal("unexpected Duration") if ev[1].Value().Duration <= 0 {
} t.Fatal("unexpected Duration")
if !errors.Is(ev[1].Value().Err, expected) { }
t.Fatal("unexpected Err") if !errors.Is(ev[1].Value().Err, expected) {
} t.Fatal("unexpected Err")
if ev[1].Value().Hostname != "www.google.com" { }
t.Fatal("unexpected Hostname") if ev[1].Value().Hostname != "www.google.com" {
} t.Fatal("unexpected Hostname")
if ev[1].Name() != "resolve_done" { }
t.Fatal("unexpected name") if ev[1].Name() != "resolve_done" {
} t.Fatal("unexpected name")
if !ev[1].Value().Time.After(ev[0].Value().Time) { }
t.Fatal("the saved time is wrong") if !ev[1].Value().Time.After(ev[0].Value().Time) {
} t.Fatal("the saved time is wrong")
} }
func TestSaverResolverSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &Saver{}
reso := saver.WrapResolver(NewFakeResolverWithResult(expected))
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal("expected nil error here")
}
if !reflect.DeepEqual(addrs, expected) {
t.Fatal("not the result we expected")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverDNSTransportFailure(t *testing.T) {
expected := errors.New("no such host")
saver := &Saver{}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
}) })
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].Value().DNSResponse != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Value().Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverDNSTransportSuccess(t *testing.T) { t.Run("on success", func(t *testing.T) {
expected := []byte{0xef, 0xbe, 0xad, 0xde} expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &Saver{} saver := &Saver{}
response := &mocks.DNSResponse{ reso := saver.WrapResolver(newFakeResolverWithResult(expected))
MockBytes: func() []byte { addrs, err := reso.LookupHost(context.Background(), "www.google.com")
return expected if err != nil {
}, t.Fatal("expected nil error here")
} }
txp := saver.WrapDNSTransport(&mocks.DNSTransport{ if !reflect.DeepEqual(addrs, expected) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { t.Fatal("not the result we expected")
return response, nil }
}, ev := saver.Read()
MockNetwork: func() string { if len(ev) != 2 {
return "fake" t.Fatal("expected number of events")
}, }
MockAddress: func() string { if ev[0].Value().Hostname != "www.google.com" {
return "" t.Fatal("unexpected Hostname")
}, }
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
}) })
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply.Bytes(), expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
} }
func NewFakeResolverWithExplicitError(err error) model.Resolver { func TestDNSTransportSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("no such host")
saver := &Saver{}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].Value().DNSResponse != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Value().Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
t.Run("on success", func(t *testing.T) {
expected := []byte{0xef, 0xbe, 0xad, 0xde}
saver := &Saver{}
response := &mocks.DNSResponse{
MockBytes: func() []byte {
return expected
},
}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply.Bytes(), expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
}
func newFakeResolverWithExplicitError(err error) model.Resolver {
runtimex.PanicIfNil(err, "passed nil error") runtimex.PanicIfNil(err, "passed nil error")
return &mocks.Resolver{ return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
@ -251,7 +255,7 @@ func NewFakeResolverWithExplicitError(err error) model.Resolver {
} }
} }
func NewFakeResolverWithResult(r []string) model.Resolver { func newFakeResolverWithResult(r []string) model.Resolver {
return &mocks.Resolver{ return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return r, nil return r, nil

View File

@ -3,22 +3,88 @@ package tracex
import ( import (
"sync" "sync"
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestSaver(t *testing.T) { func TestSaver(t *testing.T) {
saver := Saver{} t.Run("concurrent writes followed by read", func(t *testing.T) {
var wg sync.WaitGroup saver := Saver{}
const parallel = 10 var wg sync.WaitGroup
wg.Add(parallel) const parallel = 10
for idx := 0; idx < parallel; idx++ { wg.Add(parallel)
go func() { for idx := 0; idx < parallel; idx++ {
saver.Write(&EventReadFromOperation{&EventValue{}}) go func() {
wg.Done() saver.Write(&EventReadFromOperation{&EventValue{}})
}() wg.Done()
} }()
wg.Wait() }
ev := saver.Read() wg.Wait()
if len(ev) != parallel { ev := saver.Read()
t.Fatal("unexpected number of events read") if len(ev) != parallel {
} t.Fatal("unexpected number of events read")
}
})
t.Run("NewConnectObserver", func(t *testing.T) {
t.Run("nil Saver", func(t *testing.T) {
var saver *Saver
obs := saver.NewConnectObserver()
if obs != nil {
t.Fatal("expected nil observer")
}
})
t.Run("nonnnil Saver", func(t *testing.T) {
saver := &Saver{}
obs := saver.NewConnectObserver()
underlying := obs.(*dialerConnectObserver)
if underlying.saver != saver {
t.Fatal("invalid saver")
}
})
})
t.Run("NewReadWriteObserver", func(t *testing.T) {
t.Run("nil Saver", func(t *testing.T) {
var saver *Saver
obs := saver.NewReadWriteObserver()
if obs != nil {
t.Fatal("expected nil observer")
}
})
t.Run("nonnnil Saver", func(t *testing.T) {
saver := &Saver{}
obs := saver.NewReadWriteObserver()
underlying := obs.(*dialerReadWriteObserver)
if underlying.saver != saver {
t.Fatal("invalid saver")
}
})
})
t.Run("WrapQUICDialer", func(t *testing.T) {
t.Run("nil Saver", func(t *testing.T) {
var saver *Saver
base := &mocks.QUICDialer{}
qd := saver.WrapQUICDialer(base)
if qd != base {
t.Fatal("unexpected returned QUICDialer")
}
})
t.Run("nonnnil Saver", func(t *testing.T) {
saver := &Saver{}
base := &mocks.QUICDialer{}
qd := saver.WrapQUICDialer(base)
underlying := qd.(*QUICDialerSaver)
if underlying.Saver != saver {
t.Fatal("invalid Saver")
}
if underlying.QUICDialer != base {
t.Fatal("invalid QUICDialer")
}
})
})
} }

View File

@ -16,8 +16,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
// SaverTLSHandshaker saves events occurring during the TLS handshake. // TLSHandshakerSaver saves events occurring during the TLS handshake.
type SaverTLSHandshaker struct { type TLSHandshakerSaver struct {
// TLSHandshaker is the underlying TLS handshaker. // TLSHandshaker is the underlying TLS handshaker.
TLSHandshaker model.TLSHandshaker TLSHandshaker model.TLSHandshaker
@ -34,23 +34,26 @@ func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker {
if s == nil { if s == nil {
return thx return thx
} }
return &SaverTLSHandshaker{ return &TLSHandshakerSaver{
TLSHandshaker: thx, TLSHandshaker: thx,
Saver: s, Saver: s,
} }
} }
// Handshake implements model.TLSHandshaker.Handshake // Handshake implements model.TLSHandshaker.Handshake
func (h *SaverTLSHandshaker) Handshake( func (h *TLSHandshakerSaver) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
proto := conn.RemoteAddr().Network()
remoteAddr := conn.RemoteAddr().String()
start := time.Now() start := time.Now()
h.Saver.Write(&EventTLSHandshakeStart{&EventValue{ h.Saver.Write(&EventTLSHandshakeStart{&EventValue{
Address: remoteAddr,
NoTLSVerify: config.InsecureSkipVerify, NoTLSVerify: config.InsecureSkipVerify,
Proto: proto,
TLSNextProtos: config.NextProtos, TLSNextProtos: config.NextProtos,
TLSServerName: config.ServerName, TLSServerName: config.ServerName,
Time: start, Time: start,
}}) }})
remoteAddr := conn.RemoteAddr().String()
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config) tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
stop := time.Now() stop := time.Now()
h.Saver.Write(&EventTLSHandshakeDone{&EventValue{ h.Saver.Write(&EventTLSHandshakeDone{&EventValue{
@ -58,6 +61,7 @@ func (h *SaverTLSHandshaker) Handshake(
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
NoTLSVerify: config.InsecureSkipVerify, NoTLSVerify: config.InsecureSkipVerify,
Proto: proto,
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol, TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: config.NextProtos, TLSNextProtos: config.NextProtos,
@ -69,7 +73,7 @@ func (h *SaverTLSHandshaker) Handshake(
return tlsconn, state, err return tlsconn, state, err
} }
var _ model.TLSHandshaker = &SaverTLSHandshaker{} var _ model.TLSHandshaker = &TLSHandshakerSaver{}
// tlsPeerCerts returns the certificates presented by the peer regardless // tlsPeerCerts returns the certificates presented by the peer regardless
// of whether the TLS handshake was successful // of whether the TLS handshake was successful

View File

@ -3,291 +3,250 @@ package tracex
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"reflect" "crypto/x509"
"errors"
"net"
"testing" "testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { func TestTLSHandshakerSaver(t *testing.T) {
// This is the most common use case for collecting reads, writes
if testing.Short() { t.Run("Handshake", func(t *testing.T) {
t.Skip("skip test in short mode") checkStartEventFields := func(t *testing.T, value *EventValue) {
} if value.Address != "8.8.8.8:443" {
nextprotos := []string{"h2"} t.Fatal("invalid Address")
saver := &Saver{} }
tlsdlr := &netxlite.TLSDialerLegacy{ if !value.NoTLSVerify {
Config: &tls.Config{NextProtos: nextprotos}, t.Fatal("expected NoTLSVerify to be true")
Dialer: netxlite.NewDialerWithResolver( }
model.DiscardLogger, if value.Proto != "tcp" {
netxlite.NewResolverStdlib(model.DiscardLogger), t.Fatal("wrong protocol")
saver.NewReadWriteObserver(), }
), if diff := cmp.Diff(value.TLSNextProtos, []string{"h2"}); diff != "" {
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), t.Fatal(diff)
} }
// Implementation note: we don't close the connection here because it is if value.TLSServerName != "dns.google" {
// very handy to have the last event being the end of the handshake t.Fatal("invalid TLSServerName")
_, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") }
if err != nil { if value.Time.IsZero() {
t.Fatal(err) t.Fatal("expected non zero time")
} }
ev := saver.Read()
if len(ev) < 4 {
// it's a bit tricky to be sure about the right number of
// events because network conditions may influence that
t.Fatal("unexpected number of events")
}
if ev[0].Name() != "tls_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].Value().TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Value().Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
last := len(ev) - 1
for idx := 1; idx < last; idx++ {
if ev[idx].Value().Data == nil {
t.Fatal("unexpected Data")
} }
if ev[idx].Value().Duration <= 0 {
t.Fatal("unexpected Duration") checkStartedEvent := func(t *testing.T, ev Event) {
if _, good := ev.(*EventTLSHandshakeStart); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
} }
if ev[idx].Value().Err != nil {
t.Fatal("unexpected Err") checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err != nil {
t.Fatal("expected no error here")
}
if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" {
t.Fatal("invalid cipher suite")
}
if value.TLSNegotiatedProto != "h2" {
t.Fatal("invalid negotiated protocol")
}
if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" {
t.Fatal(diff)
}
if value.TLSVersion != "TLSv1.3" {
t.Fatal("invalid TLS version")
}
} }
if ev[idx].Value().NumBytes <= 0 {
t.Fatal("unexpected NumBytes") checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) {
if _, good := ev.(*EventTLSHandshakeDone); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
fun(t, value)
} }
switch ev[idx].Name() {
case netxlite.ReadOperation, netxlite.WriteOperation: t.Run("on success", func(t *testing.T) {
default: saver := &Saver{}
t.Fatal("unexpected Name") returnedConnState := tls.ConnectionState{
CipherSuite: tls.TLS_RSA_WITH_RC4_128_SHA,
NegotiatedProtocol: "h2",
PeerCertificates: []*x509.Certificate{},
Version: tls.VersionTLS13,
}
returnedConn := &mocks.TLSConn{
MockConnectionState: func() tls.ConnectionState {
return returnedConnState
},
}
thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return returnedConn, returnedConnState, nil
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
ServerName: "dns.google",
}
tcpConn := &mocks.Conn{
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess)
})
checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err == nil {
t.Fatal("expected non-nil error here")
}
if value.TLSCipherSuite != "" {
t.Fatal("invalid TLS cipher suite")
}
if value.TLSNegotiatedProto != "" {
t.Fatal("invalid negotiated proto")
}
if len(value.TLSPeerCerts) > 0 {
t.Fatal("expected no peer certs")
}
if value.TLSVersion != "" {
t.Fatal("invalid TLS version")
}
} }
if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) {
t.Fatal("unexpected Time") t.Run("on failure", func(t *testing.T) {
} expected := errors.New("mocked error")
} saver := &Saver{}
if ev[last].Value().Duration <= 0 { thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{
t.Fatal("unexpected Duration") MockHandshake: func(ctx context.Context, conn net.Conn,
} config *tls.Config) (net.Conn, tls.ConnectionState, error) {
if ev[last].Value().Err != nil { return nil, tls.ConnectionState{}, expected
t.Fatal("unexpected Err") },
} })
if ev[last].Name() != "tls_handshake_done" { ctx := context.Background()
t.Fatal("unexpected Name") tlsConfig := &tls.Config{
} InsecureSkipVerify: true,
if ev[last].Value().TLSCipherSuite == "" { NextProtos: []string{"h2"},
t.Fatal("unexpected TLSCipherSuite") ServerName: "dns.google",
} }
if ev[last].Value().TLSNegotiatedProto != "h2" { tcpConn := &mocks.Conn{
t.Fatal("unexpected TLSNegotiatedProto") MockRemoteAddr: func() net.Addr {
} return &mocks.Addr{
if !reflect.DeepEqual(ev[last].Value().TLSNextProtos, nextprotos) { MockString: func() string {
t.Fatal("unexpected TLSNextProtos") return "8.8.8.8:443"
} },
if ev[last].Value().TLSPeerCerts == nil { MockNetwork: func() string {
t.Fatal("unexpected TLSPeerCerts") return "tcp"
} },
if ev[last].Value().TLSServerName != "www.google.com" { }
t.Fatal("unexpected TLSServerName") },
} }
if ev[last].Value().TLSVersion == "" { conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig)
t.Fatal("unexpected TLSVersion") if !errors.Is(err, expected) {
} t.Fatal("unexpected err", err)
if ev[last].Value().Time.Before(ev[last-1].Value().Time) { }
t.Fatal("unexpected Time") if conn != nil {
} t.Fatal("expected nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsFailure)
})
})
} }
func TestSaverTLSHandshakerSuccess(t *testing.T) { func Test_tlsPeerCerts(t *testing.T) {
if testing.Short() { cert0 := &x509.Certificate{Raw: []byte{1, 2, 3, 4}}
t.Skip("skip test in short mode") type args struct {
state tls.ConnectionState
err error
} }
nextprotos := []string{"h2"} tests := []struct {
saver := &Saver{} name string
tlsdlr := &netxlite.TLSDialerLegacy{ args args
Config: &tls.Config{NextProtos: nextprotos}, want []*x509.Certificate
Dialer: &netxlite.DialerSystem{}, }{{
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}), name: "no error",
} args: args{
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") state: tls.ConnectionState{
if err != nil { PeerCertificates: []*x509.Certificate{cert0},
t.Fatal(err) },
} },
conn.Close() want: []*x509.Certificate{cert0},
ev := saver.Read() }, {
if len(ev) != 2 { name: "all empty",
t.Fatal("unexpected number of events") args: args{},
} want: nil,
if ev[0].Name() != "tls_handshake_start" { }, {
t.Fatal("unexpected Name") name: "x509.HostnameError",
} args: args{
if ev[0].Value().TLSServerName != "www.google.com" { state: tls.ConnectionState{},
t.Fatal("unexpected TLSServerName") err: x509.HostnameError{
} Certificate: cert0,
if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) { },
t.Fatal("unexpected TLSNextProtos") },
} want: []*x509.Certificate{cert0},
if ev[0].Value().Time.After(time.Now()) { }, {
t.Fatal("unexpected Time") name: "x509.UnknownAuthorityError",
} args: args{
if ev[1].Value().Duration <= 0 { state: tls.ConnectionState{},
t.Fatal("unexpected Duration") err: x509.UnknownAuthorityError{
} Cert: cert0,
if ev[1].Value().Err != nil { },
t.Fatal("unexpected Err") },
} want: []*x509.Certificate{cert0},
if ev[1].Name() != "tls_handshake_done" { }, {
t.Fatal("unexpected Name") name: "x509.CertificateInvalidError",
} args: args{
if ev[1].Value().TLSCipherSuite == "" { state: tls.ConnectionState{},
t.Fatal("unexpected TLSCipherSuite") err: x509.CertificateInvalidError{
} Cert: cert0,
if ev[1].Value().TLSNegotiatedProto != "h2" { },
t.Fatal("unexpected TLSNegotiatedProto") },
} want: []*x509.Certificate{cert0},
if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) { }}
t.Fatal("unexpected TLSNextProtos") for _, tt := range tests {
} t.Run(tt.name, func(t *testing.T) {
if ev[1].Value().TLSPeerCerts == nil { got := tlsPeerCerts(tt.args.state, tt.args.err)
t.Fatal("unexpected TLSPeerCerts") if diff := cmp.Diff(tt.want, got); diff != "" {
} t.Fatal(diff)
if ev[1].Value().TLSServerName != "www.google.com" { }
t.Fatal("unexpected TLSServerName") })
}
if ev[1].Value().TLSVersion == "" {
t.Fatal("unexpected TLSVersion")
}
if ev[1].Value().Time.Before(ev[0].Value().Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverTLSHandshakerHostnameError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: &netxlite.DialerSystem{},
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "wrong.host.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name() != "tls_handshake_done" {
continue
}
if ev.Value().NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.Value().TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: &netxlite.DialerSystem{},
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "expired.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name() != "tls_handshake_done" {
continue
}
if ev.Value().NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.Value().TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: &netxlite.DialerSystem{},
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name() != "tls_handshake_done" {
continue
}
if ev.Value().NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.Value().TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: &netxlite.DialerSystem{},
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn here")
}
conn.Close()
for _, ev := range saver.Read() {
if ev.Name() != "tls_handshake_done" {
continue
}
if ev.Value().NoTLSVerify != true {
t.Fatal("expected NoTLSVerify to be true")
}
if len(ev.Value().TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
} }
} }