refactor: move tracex outside of engine/netx (#782)

* refactor: move tracex outside of engine/netx

Consistently with https://github.com/ooni/probe/issues/2121 and
https://github.com/ooni/probe/issues/2115, we can now move tracex
outside of engine/netx. The main reason why this makes sense now
is that the package is now changed significantly from the one
that we imported from ooni/probe-engine.

We have improved its implementation, which had not been touched
significantly for quite some time, and converted it to unit
testing. I will document tomorrow some extra work I'd like to
do with this package but likely could not do $soon.

* go fmt

* regen tutorials
This commit is contained in:
Simone Basso
2022-06-02 00:50:55 +02:00
committed by GitHub
parent d397036073
commit 58adb68b2c
50 changed files with 34 additions and 34 deletions
+322
View File
@@ -0,0 +1,322 @@
package tracex
//
// Code to generate the OONI archival data format from events
//
import (
"crypto/x509"
"errors"
"net"
"net/http"
"sort"
"strconv"
"strings"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// Compatibility types. Most experiments still use these names.
type (
ExtSpec = model.ArchivalExtSpec
TCPConnectEntry = model.ArchivalTCPConnectResult
TCPConnectStatus = model.ArchivalTCPConnectStatus
MaybeBinaryValue = model.ArchivalMaybeBinaryData
DNSQueryEntry = model.ArchivalDNSLookupResult
DNSAnswerEntry = model.ArchivalDNSAnswer
TLSHandshake = model.ArchivalTLSOrQUICHandshakeResult
HTTPBody = model.ArchivalHTTPBody
HTTPHeader = model.ArchivalHTTPHeader
RequestEntry = model.ArchivalHTTPRequestResult
HTTPRequest = model.ArchivalHTTPRequest
HTTPResponse = model.ArchivalHTTPResponse
NetworkEvent = model.ArchivalNetworkEvent
)
// Compatibility variables. Most experiments still use these names.
var (
ExtDNS = model.ArchivalExtDNS
ExtNetevents = model.ArchivalExtNetevents
ExtHTTP = model.ArchivalExtHTTP
ExtTCPConnect = model.ArchivalExtTCPConnect
ExtTLSHandshake = model.ArchivalExtTLSHandshake
ExtTunnel = model.ArchivalExtTunnel
)
// NewTCPConnectList creates a new TCPConnectList
func NewTCPConnectList(begin time.Time, events []Event) (out []TCPConnectEntry) {
for _, wrapper := range events {
if _, ok := wrapper.(*EventConnectOperation); !ok {
continue
}
event := wrapper.Value()
if event.Proto != "tcp" {
continue
}
// We assume Go is passing us legit data structures
ip, sport, _ := net.SplitHostPort(event.Address)
iport, _ := strconv.Atoi(sport)
out = append(out, TCPConnectEntry{
IP: ip,
Port: iport,
Status: TCPConnectStatus{
Blocked: nil, // only used by Web Connectivity
Failure: NewFailure(event.Err),
Success: event.Err == nil,
},
T: event.Time.Sub(begin).Seconds(),
})
}
return
}
// NewFailure creates a failure nullable string from the given error
func NewFailure(err error) *string {
if err == nil {
return nil
}
// The following code guarantees that the error is always wrapped even
// when we could not actually hit our code that does the wrapping. A case
// in which this happen is with context deadline for HTTP.
err = netxlite.NewTopLevelGenericErrWrapper(err)
errWrapper := err.(*netxlite.ErrWrapper)
s := errWrapper.Failure
if s == "" {
s = "unknown_failure: errWrapper.Failure is empty"
}
return &s
}
// NewFailedOperation creates a failed operation string from the given error.
func NewFailedOperation(err error) *string {
if err == nil {
return nil
}
var (
errWrapper *netxlite.ErrWrapper
s = netxlite.UnknownOperation
)
if errors.As(err, &errWrapper) && errWrapper.Operation != "" {
s = errWrapper.Operation
}
return &s
}
// httpAddHeaders adds the headers inside source into destList and destMap.
func httpAddHeaders(source http.Header, destList *[]HTTPHeader,
destMap *map[string]MaybeBinaryValue) {
*destList = []HTTPHeader{}
*destMap = make(map[string]model.ArchivalMaybeBinaryData)
for key, values := range source {
for index, value := range values {
value := MaybeBinaryValue{Value: value}
// With the map representation we can only represent a single
// value for every key. Hence the list representation.
if index == 0 {
(*destMap)[key] = value
}
*destList = append(*destList, HTTPHeader{
Key: key,
Value: value,
})
}
}
// Sorting helps with unit testing (map keys are unordered)
sort.Slice(*destList, func(i, j int) bool {
return (*destList)[i].Key < (*destList)[j].Key
})
}
// NewRequestList returns the list for "requests"
func NewRequestList(begin time.Time, events []Event) (out []RequestEntry) {
// OONI wants the last request to appear first
tmp := newRequestList(begin, events)
for i := len(tmp) - 1; i >= 0; i-- {
out = append(out, tmp[i])
}
return
}
func newRequestList(begin time.Time, events []Event) (out []RequestEntry) {
for _, wrapper := range events {
ev := wrapper.Value()
switch wrapper.(type) {
case *EventHTTPTransactionDone:
entry := RequestEntry{}
entry.T = ev.Time.Sub(begin).Seconds()
httpAddHeaders(
ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
entry.Request.Method = ev.HTTPMethod
entry.Request.URL = ev.HTTPURL
entry.Request.Transport = ev.Transport
httpAddHeaders(
ev.HTTPResponseHeaders, &entry.Response.HeadersList, &entry.Response.Headers)
entry.Response.Code = int64(ev.HTTPStatusCode)
entry.Response.Locations = ev.HTTPResponseHeaders.Values("Location")
entry.Response.Body.Value = string(ev.HTTPResponseBody)
entry.Response.BodyIsTruncated = ev.HTTPResponseBodyIsTruncated
entry.Failure = NewFailure(ev.Err)
out = append(out, entry)
}
}
return
}
type dnsQueryType string
// NewDNSQueriesList returns a list of DNS queries.
func NewDNSQueriesList(begin time.Time, events []Event) (out []DNSQueryEntry) {
// TODO(bassosimone): add support for CNAME lookups.
for _, wrapper := range events {
if _, ok := wrapper.(*EventResolveDone); !ok {
continue
}
ev := wrapper.Value()
for _, qtype := range []dnsQueryType{"A", "AAAA"} {
entry := qtype.makeQueryEntry(begin, ev)
for _, addr := range ev.Addresses {
if qtype.ipOfType(addr) {
entry.Answers = append(
entry.Answers, qtype.makeAnswerEntry(addr))
}
}
if len(entry.Answers) <= 0 && ev.Err == nil {
// This allows us to skip cases where the server does not have
// an IPv6 address but has an IPv4 address. Instead, when we
// receive an error, we want to track its existence. The main
// issue here is that we are cheating, because we are creating
// entries representing queries, but we don't know what the
// resolver actually did, especially the system resolver. So,
// this output is just our best guess.
continue
}
out = append(out, entry)
}
}
return
}
func (qtype dnsQueryType) ipOfType(addr string) bool {
switch qtype {
case "A":
return !strings.Contains(addr, ":")
case "AAAA":
return strings.Contains(addr, ":")
}
return false
}
func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry {
answer := DNSAnswerEntry{AnswerType: string(qtype)}
// Figuring out the ASN and the org here is not just a service to whoever
// is reading a JSON: Web Connectivity also depends on it!
asn, org, _ := geolocate.LookupASN(addr)
answer.ASN = int64(asn)
answer.ASOrgName = org
switch qtype {
case "A":
answer.IPv4 = addr
case "AAAA":
answer.IPv6 = addr
}
return answer
}
func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev *EventValue) DNSQueryEntry {
return DNSQueryEntry{
Engine: ev.Proto,
Failure: NewFailure(ev.Err),
Hostname: ev.Hostname,
QueryType: string(qtype),
ResolverAddress: ev.Address,
T: ev.Time.Sub(begin).Seconds(),
}
}
// NewNetworkEventsList returns a list of network events.
func NewNetworkEventsList(begin time.Time, events []Event) (out []NetworkEvent) {
for _, wrapper := range events {
ev := wrapper.Value()
switch wrapper.(type) {
case *EventConnectOperation:
out = append(out, NetworkEvent{
Address: ev.Address,
Failure: NewFailure(ev.Err),
Operation: wrapper.Name(),
Proto: ev.Proto,
T: ev.Time.Sub(begin).Seconds(),
})
case *EventReadOperation:
out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err),
Operation: wrapper.Name(),
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
case *EventWriteOperation:
out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err),
Operation: wrapper.Name(),
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
case *EventReadFromOperation:
out = append(out, NetworkEvent{
Address: ev.Address,
Failure: NewFailure(ev.Err),
Operation: wrapper.Name(),
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
case *EventWriteToOperation:
out = append(out, NetworkEvent{
Address: ev.Address,
Failure: NewFailure(ev.Err),
Operation: wrapper.Name(),
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
default: // For example, "tls_handshake_done" (used in data analysis!)
out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err),
Operation: wrapper.Name(),
T: ev.Time.Sub(begin).Seconds(),
})
}
}
return
}
// NewTLSHandshakesList creates a new TLSHandshakesList
func NewTLSHandshakesList(begin time.Time, events []Event) (out []TLSHandshake) {
for _, wrapper := range events {
switch wrapper.(type) {
case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // interested
default:
continue // not interested
}
ev := wrapper.Value()
out = append(out, TLSHandshake{
Address: ev.Address,
CipherSuite: ev.TLSCipherSuite,
Failure: NewFailure(ev.Err),
NegotiatedProtocol: ev.TLSNegotiatedProto,
NoTLSVerify: ev.NoTLSVerify,
PeerCertificates: tlsMakePeerCerts(ev.TLSPeerCerts),
ServerName: ev.TLSServerName,
T: ev.Time.Sub(begin).Seconds(),
TLSVersion: ev.TLSVersion,
})
}
return
}
func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
for _, entry := range in {
out = append(out, MaybeBinaryValue{Value: string(entry.Raw)})
}
return
}
+757
View File
@@ -0,0 +1,757 @@
package tracex
import (
"context"
"crypto/x509"
"errors"
"io"
"net/http"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/websocket"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestDNSQueryType(t *testing.T) {
t.Run("ipOfType", func(t *testing.T) {
type expectation struct {
qtype dnsQueryType
ip string
output bool
}
var expectations = []expectation{{
qtype: "A",
ip: "8.8.8.8",
output: true,
}, {
qtype: "A",
ip: "2a00:1450:4002:801::2004",
output: false,
}, {
qtype: "AAAA",
ip: "8.8.8.8",
output: false,
}, {
qtype: "AAAA",
ip: "2a00:1450:4002:801::2004",
output: true,
}, {
qtype: "ANTANI",
ip: "2a00:1450:4002:801::2004",
output: false,
}, {
qtype: "ANTANI",
ip: "8.8.8.8",
output: false,
}}
for _, exp := range expectations {
if exp.qtype.ipOfType(exp.ip) != exp.output {
t.Fatalf("failure for %+v", exp)
}
}
})
}
func TestNewTCPConnectList(t *testing.T) {
begin := time.Now()
type args struct {
begin time.Time
events []Event
}
tests := []struct {
name string
args args
want []TCPConnectEntry
}{{
name: "empty run",
args: args{
begin: begin,
events: nil,
},
want: nil,
}, {
name: "realistic run",
args: args{
begin: begin,
events: []Event{&EventResolveDone{&EventValue{ // skipped because not relevant
Addresses: []string{"8.8.8.8", "8.8.4.4"},
Hostname: "dns.google.com",
Time: begin.Add(100 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{
Address: "8.8.8.8:853",
Duration: 30 * time.Millisecond,
Proto: "tcp",
Time: begin.Add(130 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{
Address: "8.8.8.8:853",
Duration: 55 * time.Millisecond,
Proto: "udp", // this one should be skipped because it's UDP
Time: begin.Add(130 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{
Address: "8.8.4.4:53",
Duration: 50 * time.Millisecond,
Err: io.EOF,
Proto: "tcp",
Time: begin.Add(180 * time.Millisecond),
}}},
},
want: []TCPConnectEntry{{
IP: "8.8.8.8",
Port: 853,
Status: TCPConnectStatus{
Success: true,
},
T: 0.13,
}, {
IP: "8.8.4.4",
Port: 53,
Status: TCPConnectStatus{
Failure: NewFailure(io.EOF),
Success: false,
},
T: 0.18,
}},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewTCPConnectList(tt.args.begin, tt.args.events)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
}
})
}
}
func TestNewRequestList(t *testing.T) {
begin := time.Now()
type args struct {
begin time.Time
events []Event
}
tests := []struct {
name string
args args
want []RequestEntry
}{{
name: "empty run",
args: args{
begin: begin,
events: nil,
},
want: nil,
}, {
name: "realistic run",
args: args{
begin: begin,
// Two round trips so we can test the sorting expected by OONI
events: []Event{&EventHTTPTransactionDone{&EventValue{
HTTPRequestHeaders: http.Header{
"User-Agent": []string{"miniooni/0.1.0-dev"},
},
HTTPMethod: "POST",
HTTPURL: "https://www.example.com/submit",
HTTPResponseHeaders: http.Header{
"Server": []string{"miniooni/0.1.0-dev"},
},
HTTPStatusCode: 200,
HTTPResponseBody: []byte("{}"),
HTTPResponseBodyIsTruncated: false,
Time: begin.Add(10 * time.Millisecond),
}}, &EventHTTPTransactionDone{&EventValue{
HTTPRequestHeaders: http.Header{
"User-Agent": []string{"miniooni/0.1.0-dev"},
},
HTTPMethod: "GET",
HTTPURL: "https://www.example.com/result",
Err: io.EOF,
Time: begin.Add(20 * time.Millisecond),
}}},
},
want: []RequestEntry{{
Failure: NewFailure(io.EOF),
Request: HTTPRequest{
HeadersList: []HTTPHeader{{
Key: "User-Agent",
Value: MaybeBinaryValue{
Value: "miniooni/0.1.0-dev",
},
}},
Headers: map[string]MaybeBinaryValue{
"User-Agent": {Value: "miniooni/0.1.0-dev"},
},
Method: "GET",
URL: "https://www.example.com/result",
},
Response: HTTPResponse{
HeadersList: []HTTPHeader{},
Headers: make(map[string]MaybeBinaryValue),
},
T: 0.02,
}, {
Request: HTTPRequest{
Body: MaybeBinaryValue{
Value: "",
},
HeadersList: []HTTPHeader{{
Key: "User-Agent",
Value: MaybeBinaryValue{
Value: "miniooni/0.1.0-dev",
},
}},
Headers: map[string]MaybeBinaryValue{
"User-Agent": {Value: "miniooni/0.1.0-dev"},
},
Method: "POST",
URL: "https://www.example.com/submit",
},
Response: HTTPResponse{
Body: MaybeBinaryValue{
Value: "{}",
},
Code: 200,
HeadersList: []HTTPHeader{{
Key: "Server",
Value: MaybeBinaryValue{
Value: "miniooni/0.1.0-dev",
},
}},
Headers: map[string]MaybeBinaryValue{
"Server": {Value: "miniooni/0.1.0-dev"},
},
Locations: nil,
},
T: 0.01,
}},
}, {
// for an example of why we need to sort headers, see
// https://github.com/ooni/probe-engine/pull/751/checks?check_run_id=853562310
name: "run with redirect and headers to sort",
args: args{
begin: begin,
events: []Event{&EventHTTPTransactionDone{&EventValue{
HTTPRequestHeaders: http.Header{
"User-Agent": []string{"miniooni/0.1.0-dev"},
},
HTTPMethod: "GET",
HTTPURL: "https://www.example.com/",
HTTPResponseHeaders: http.Header{
"Server": []string{"miniooni/0.1.0-dev"},
"Location": []string{"https://x.example.com", "https://y.example.com"},
},
HTTPStatusCode: 302,
Time: begin.Add(10 * time.Millisecond),
}}},
},
want: []RequestEntry{{
Request: HTTPRequest{
HeadersList: []HTTPHeader{{
Key: "User-Agent",
Value: MaybeBinaryValue{
Value: "miniooni/0.1.0-dev",
},
}},
Headers: map[string]MaybeBinaryValue{
"User-Agent": {Value: "miniooni/0.1.0-dev"},
},
Method: "GET",
URL: "https://www.example.com/",
},
Response: HTTPResponse{
Code: 302,
HeadersList: []HTTPHeader{{
Key: "Location",
Value: MaybeBinaryValue{
Value: "https://x.example.com",
},
}, {
Key: "Location",
Value: MaybeBinaryValue{
Value: "https://y.example.com",
},
}, {
Key: "Server",
Value: MaybeBinaryValue{
Value: "miniooni/0.1.0-dev",
},
}},
Headers: map[string]MaybeBinaryValue{
"Server": {Value: "miniooni/0.1.0-dev"},
"Location": {Value: "https://x.example.com"},
},
Locations: []string{
"https://x.example.com", "https://y.example.com",
},
},
T: 0.01,
}},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewRequestList(tt.args.begin, tt.args.events)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
}
})
}
}
func TestNewDNSQueriesList(t *testing.T) {
begin := time.Now()
type args struct {
begin time.Time
events []Event
}
tests := []struct {
name string
args args
want []DNSQueryEntry
}{{
name: "empty run",
args: args{
begin: begin,
events: nil,
},
want: nil,
}, {
name: "realistic run",
args: args{
begin: begin,
events: []Event{&EventResolveDone{&EventValue{
Address: "1.1.1.1:853",
Addresses: []string{"8.8.8.8", "8.8.4.4"},
Hostname: "dns.google.com",
Proto: "dot",
Time: begin.Add(100 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{ // skipped because not relevant
Address: "8.8.8.8:853",
Duration: 30 * time.Millisecond,
Proto: "tcp",
Time: begin.Add(130 * time.Millisecond),
}}, &EventConnectOperation{&EventValue{ // skipped because not relevant
Address: "8.8.4.4:53",
Duration: 50 * time.Millisecond,
Err: io.EOF,
Proto: "tcp",
Time: begin.Add(180 * time.Millisecond),
}}},
},
want: []DNSQueryEntry{{
Answers: []DNSAnswerEntry{{
ASN: 15169,
ASOrgName: "Google LLC",
AnswerType: "A",
IPv4: "8.8.8.8",
}, {
ASN: 15169,
ASOrgName: "Google LLC",
AnswerType: "A",
IPv4: "8.8.4.4",
}},
Engine: "dot",
Hostname: "dns.google.com",
QueryType: "A",
ResolverAddress: "1.1.1.1:853",
T: 0.1,
}},
}, {
name: "run with IPv6 results",
args: args{
begin: begin,
events: []Event{&EventResolveDone{&EventValue{
Addresses: []string{"2001:4860:4860::8888"},
Hostname: "dns.google.com",
Time: begin.Add(200 * time.Millisecond),
}}},
},
want: []DNSQueryEntry{{
Answers: []DNSAnswerEntry{{
ASN: 15169,
ASOrgName: "Google LLC",
AnswerType: "AAAA",
IPv6: "2001:4860:4860::8888",
}},
Hostname: "dns.google.com",
QueryType: "AAAA",
T: 0.2,
}},
}, {
name: "run with errors",
args: args{
begin: begin,
events: []Event{&EventResolveDone{&EventValue{
Err: &netxlite.ErrWrapper{Failure: netxlite.FailureDNSNXDOMAINError},
Hostname: "dns.google.com",
Time: begin.Add(200 * time.Millisecond),
}}},
},
want: []DNSQueryEntry{{
Answers: nil,
Failure: NewFailure(
&netxlite.ErrWrapper{Failure: netxlite.FailureDNSNXDOMAINError}),
Hostname: "dns.google.com",
QueryType: "A",
T: 0.2,
}, {
Answers: nil,
Failure: NewFailure(
&netxlite.ErrWrapper{Failure: netxlite.FailureDNSNXDOMAINError}),
Hostname: "dns.google.com",
QueryType: "AAAA",
T: 0.2,
}},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewDNSQueriesList(tt.args.begin, tt.args.events)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
}
})
}
}
func TestNewNetworkEventsList(t *testing.T) {
begin := time.Now()
type args struct {
begin time.Time
events []Event
}
tests := []struct {
name string
args args
want []NetworkEvent
}{{
name: "empty run",
args: args{
begin: begin,
events: nil,
},
want: nil,
}, {
name: "realistic run",
args: args{
begin: begin,
events: []Event{&EventConnectOperation{&EventValue{
Address: "8.8.8.8:853",
Err: io.EOF,
Proto: "tcp",
Time: begin.Add(7 * time.Millisecond),
}}, &EventReadOperation{&EventValue{
Err: context.Canceled,
NumBytes: 7117,
Time: begin.Add(11 * time.Millisecond),
}}, &EventReadFromOperation{&EventValue{
Address: "8.8.8.8:853",
Err: context.Canceled,
NumBytes: 7117,
Time: begin.Add(11 * time.Millisecond),
}}, &EventWriteOperation{&EventValue{
Err: websocket.ErrBadHandshake,
NumBytes: 4114,
Time: begin.Add(14 * time.Millisecond),
}}, &EventWriteToOperation{&EventValue{
Address: "8.8.8.8:853",
Err: websocket.ErrBadHandshake,
NumBytes: 4114,
Time: begin.Add(14 * time.Millisecond),
}}, &EventResolveStart{&EventValue{
// We expect this event to be logged event though it's not a typical I/O
// event (it seems these extra events are useful for debugging)
Time: begin.Add(15 * time.Millisecond),
}}},
},
want: []NetworkEvent{{
Address: "8.8.8.8:853",
Failure: NewFailure(io.EOF),
Operation: netxlite.ConnectOperation,
Proto: "tcp",
T: 0.007,
}, {
Failure: NewFailure(context.Canceled),
NumBytes: 7117,
Operation: netxlite.ReadOperation,
T: 0.011,
}, {
Address: "8.8.8.8:853",
Failure: NewFailure(context.Canceled),
NumBytes: 7117,
Operation: netxlite.ReadFromOperation,
T: 0.011,
}, {
Failure: NewFailure(websocket.ErrBadHandshake),
NumBytes: 4114,
Operation: netxlite.WriteOperation,
T: 0.014,
}, {
Address: "8.8.8.8:853",
Failure: NewFailure(websocket.ErrBadHandshake),
NumBytes: 4114,
Operation: netxlite.WriteToOperation,
T: 0.014,
}, {
Operation: "resolve_start",
T: 0.015,
}},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewNetworkEventsList(tt.args.begin, tt.args.events)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
}
})
}
}
func TestNewTLSHandshakesList(t *testing.T) {
begin := time.Now()
type args struct {
begin time.Time
events []Event
}
tests := []struct {
name string
args args
want []TLSHandshake
}{{
name: "empty run",
args: args{
begin: begin,
events: nil,
},
want: nil,
}, {
name: "realistic run with TLS",
args: args{
begin: begin,
events: []Event{&EventTLSHandshakeDone{&EventValue{
Address: "131.252.210.176:443",
Err: io.EOF,
NoTLSVerify: false,
Proto: "tcp",
TLSCipherSuite: "SUITE",
TLSNegotiatedProto: "h2",
TLSPeerCerts: []*x509.Certificate{{
Raw: []byte("deadbeef"),
}, {
Raw: []byte("abad1dea"),
}},
TLSServerName: "x.org",
TLSVersion: "TLSv1.3",
Time: begin.Add(55 * time.Millisecond),
}}},
},
want: []TLSHandshake{{
Address: "131.252.210.176:443",
CipherSuite: "SUITE",
Failure: NewFailure(io.EOF),
NegotiatedProtocol: "h2",
NoTLSVerify: false,
PeerCertificates: []MaybeBinaryValue{{
Value: "deadbeef",
}, {
Value: "abad1dea",
}},
ServerName: "x.org",
T: 0.055,
TLSVersion: "TLSv1.3",
}},
}, {
name: "realistic run with QUIC",
args: args{
begin: begin,
events: []Event{&EventQUICHandshakeDone{&EventValue{
Address: "131.252.210.176:443",
Err: io.EOF,
NoTLSVerify: false,
Proto: "quic",
TLSCipherSuite: "SUITE",
TLSNegotiatedProto: "h3",
TLSPeerCerts: []*x509.Certificate{{
Raw: []byte("deadbeef"),
}, {
Raw: []byte("abad1dea"),
}},
TLSServerName: "x.org",
TLSVersion: "TLSv1.3",
Time: begin.Add(55 * time.Millisecond),
}}},
},
want: []TLSHandshake{{
Address: "131.252.210.176:443",
CipherSuite: "SUITE",
Failure: NewFailure(io.EOF),
NegotiatedProtocol: "h3",
NoTLSVerify: false,
PeerCertificates: []MaybeBinaryValue{{
Value: "deadbeef",
}, {
Value: "abad1dea",
}},
ServerName: "x.org",
T: 0.055,
TLSVersion: "TLSv1.3",
}},
}, {
name: "realistic run with no suitable events",
args: args{
begin: begin,
events: []Event{&EventResolveStart{&EventValue{
Time: begin.Add(55 * time.Millisecond),
}}},
},
want: nil,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewTLSHandshakesList(tt.args.begin, tt.args.events)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
}
})
}
}
func TestNewFailure(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
want *string
}{{
name: "when error is nil",
args: args{
err: nil,
},
want: nil,
}, {
name: "when error is wrapped and failure meaningful",
args: args{
err: &netxlite.ErrWrapper{
Failure: netxlite.FailureConnectionRefused,
},
},
want: func() *string {
s := netxlite.FailureConnectionRefused
return &s
}(),
}, {
name: "when error is wrapped and failure is not meaningful",
args: args{
err: &netxlite.ErrWrapper{},
},
want: func() *string {
s := "unknown_failure: errWrapper.Failure is empty"
return &s
}(),
}, {
name: "when error is not wrapped but wrappable",
args: args{err: io.EOF},
want: func() *string {
s := "eof_error"
return &s
}(),
}, {
name: "when the error is not wrapped and not wrappable",
args: args{
err: errors.New("use of closed socket 127.0.0.1:8080->10.0.0.1:22"),
},
want: func() *string {
s := "unknown_failure: use of closed socket [scrubbed]->[scrubbed]"
return &s
}(),
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewFailure(tt.args.err)
if tt.want == nil && got == nil {
return
}
if tt.want == nil && got != nil {
t.Errorf("NewFailure: want %+v, got %s", tt.want, *got)
return
}
if tt.want != nil && got == nil {
t.Errorf("NewFailure: want %s, got %+v", *tt.want, got)
return
}
if *tt.want != *got {
t.Errorf("NewFailure: want %s, got %s", *tt.want, *got)
return
}
})
}
}
func TestNewFailedOperation(t *testing.T) {
type args struct {
err error
}
tests := []struct {
name string
args args
want *string
}{{
name: "With no error",
args: args{
err: nil, // explicit
},
want: nil, // explicit
}, {
name: "With wrapped error and non-empty operation",
args: args{
err: &netxlite.ErrWrapper{
Failure: netxlite.FailureConnectionRefused,
Operation: netxlite.ConnectOperation,
},
},
want: (func() *string {
s := netxlite.ConnectOperation
return &s
})(),
}, {
name: "With wrapped error and empty operation",
args: args{
err: &netxlite.ErrWrapper{
Failure: netxlite.FailureConnectionRefused,
},
},
want: (func() *string {
s := netxlite.UnknownOperation
return &s
})(),
}, {
name: "With non wrapped error",
args: args{
err: io.EOF,
},
want: (func() *string {
s := netxlite.UnknownOperation
return &s
})(),
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewFailedOperation(tt.args.err)
if got == nil && tt.want == nil {
return
}
if got == nil && tt.want != nil {
t.Errorf("NewFailedOperation() = %v, want %v", got, tt.want)
return
}
if got != nil && tt.want == nil {
t.Errorf("NewFailedOperation() = %v, want %v", got, tt.want)
return
}
if got != nil && tt.want != nil && *got != *tt.want {
t.Errorf("NewFailedOperation() = %v, want %v", got, tt.want)
return
}
})
}
}
+159
View File
@@ -0,0 +1,159 @@
package tracex
//
// TCP and connected UDP sockets
//
import (
"context"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DialerSaver saves events occurring during the dial
type DialerSaver struct {
// Dialer is the underlying dialer,
Dialer model.Dialer
// Saver saves events.
Saver *Saver
}
// NewConnectObserver returns a DialerWrapper that observes the
// connect event. This function will return nil, which is a valid
// DialerWrapper for netxlite.WrapDialer, if Saver is nil.
func (s *Saver) NewConnectObserver() model.DialerWrapper {
if s == nil {
return nil // valid DialerWrapper according to netxlite's docs
}
return &dialerConnectObserver{
saver: s,
}
}
type dialerConnectObserver struct {
saver *Saver
}
var _ model.DialerWrapper = &dialerConnectObserver{}
func (w *dialerConnectObserver) WrapDialer(d model.Dialer) model.Dialer {
return &DialerSaver{
Dialer: d,
Saver: w.saver,
}
}
// DialContext implements Dialer.DialContext
func (d *DialerSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
start := time.Now()
conn, err := d.Dialer.DialContext(ctx, network, address)
stop := time.Now()
d.Saver.Write(&EventConnectOperation{&EventValue{
Address: address,
Duration: stop.Sub(start),
Err: err,
Proto: network,
Time: stop,
}})
return conn, err
}
func (d *DialerSaver) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// DialerConnSaver wraps the returned connection such that we
// collect all the read/write events that occur.
type DialerConnSaver struct {
// Dialer is the underlying dialer
Dialer model.Dialer
// Saver saves events
Saver *Saver
}
// NewReadWriteObserver returns a DialerWrapper that observes the
// I/O events. This function will return nil, which is a valid
// DialerWrapper for netxlite.WrapDialer, if Saver is nil.
func (s *Saver) NewReadWriteObserver() model.DialerWrapper {
if s == nil {
return nil // valid DialerWrapper according to netxlite's docs
}
return &dialerReadWriteObserver{
saver: s,
}
}
type dialerReadWriteObserver struct {
saver *Saver
}
var _ model.DialerWrapper = &dialerReadWriteObserver{}
func (w *dialerReadWriteObserver) WrapDialer(d model.Dialer) model.Dialer {
return &DialerConnSaver{
Dialer: d,
Saver: w.saver,
}
}
// DialContext implements Dialer.DialContext
func (d *DialerConnSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
return &dialerConnWrapper{saver: d.Saver, Conn: conn}, nil
}
func (d *DialerConnSaver) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
type dialerConnWrapper struct {
net.Conn
saver *Saver
}
func (c *dialerConnWrapper) Read(p []byte) (int, error) {
proto := c.Conn.RemoteAddr().Network()
remoteAddr := c.Conn.RemoteAddr().String()
start := time.Now()
count, err := c.Conn.Read(p)
stop := time.Now()
c.saver.Write(&EventReadOperation{&EventValue{
Address: remoteAddr,
Data: p[:count],
Duration: stop.Sub(start),
Err: err,
NumBytes: count,
Proto: proto,
Time: stop,
}})
return count, err
}
func (c *dialerConnWrapper) Write(p []byte) (int, error) {
proto := c.Conn.RemoteAddr().Network()
remoteAddr := c.Conn.RemoteAddr().String()
start := time.Now()
count, err := c.Conn.Write(p)
stop := time.Now()
c.saver.Write(&EventWriteOperation{&EventValue{
Address: remoteAddr,
Data: p[:count],
Duration: stop.Sub(start),
Err: err,
NumBytes: count,
Proto: proto,
Time: stop,
}})
return count, err
}
var _ model.Dialer = &DialerSaver{}
var _ model.Dialer = &DialerConnSaver{}
var _ net.Conn = &dialerConnWrapper{}
+279
View File
@@ -0,0 +1,279 @@
package tracex
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestDialerConnectObserver(t *testing.T) {
saver := &Saver{}
obs := &dialerConnectObserver{
saver: saver,
}
dialer := &mocks.Dialer{}
out := obs.WrapDialer(dialer)
dialSaver := out.(*DialerSaver)
if dialSaver.Dialer != dialer {
t.Fatal("invalid dialer")
}
if dialSaver.Saver != saver {
t.Fatal("invalid saver")
}
}
func TestDialerSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
dlr := &DialerSaver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("expected another error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Value().Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Value().Err, expected) {
t.Fatal("unexpected Err")
}
if ev[0].Name() != netxlite.ConnectOperation {
t.Fatal("unexpected Name")
}
if ev[0].Value().Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &DialerSaver{
Dialer: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}
func TestDialerReadWriteObserver(t *testing.T) {
saver := &Saver{}
obs := &dialerReadWriteObserver{
saver: saver,
}
dialer := &mocks.Dialer{}
out := obs.WrapDialer(dialer)
dialSaver := out.(*DialerConnSaver)
if dialSaver.Dialer != dialer {
t.Fatal("invalid dialer")
}
if dialSaver.Saver != saver {
t.Fatal("invalid saver")
}
}
func TestDialerConnSaver(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
dlr := &DialerConnSaver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, expected
},
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
})
t.Run("on success", func(t *testing.T) {
origConn := &mocks.Conn{}
saver := &Saver{}
dlr := &DialerConnSaver{
Dialer: &DialerSaver{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return origConn, nil
},
},
Saver: saver,
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if err != nil {
t.Fatal("not the error we expected", err)
}
cw := conn.(*dialerConnWrapper)
if cw.Conn != origConn {
t.Fatal("unexpected conn")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &DialerConnSaver{
Dialer: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}
func TestDialerConnWrapper(t *testing.T) {
t.Run("Read", func(t *testing.T) {
baseConn := &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, io.EOF
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "www.google.com:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
saver := &Saver{}
conn := &dialerConnWrapper{
Conn: baseConn,
saver: saver,
}
data := make([]byte, 155)
count, err := conn.Read(data)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Value().Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Value().Err, io.EOF) {
t.Fatal("unexpected Err")
}
if ev[0].Name() != netxlite.ReadOperation {
t.Fatal("unexpected Name")
}
if ev[0].Value().Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
})
t.Run("Write", func(t *testing.T) {
baseConn := &mocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "www.google.com:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
saver := &Saver{}
conn := &dialerConnWrapper{
Conn: baseConn,
saver: saver,
}
data := make([]byte, 155)
count, err := conn.Write(data)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("unexpected count")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Value().Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Value().Err, io.EOF) {
t.Fatal("unexpected Err")
}
if ev[0].Name() != netxlite.WriteOperation {
t.Fatal("unexpected Name")
}
if ev[0].Value().Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
})
}
+8
View File
@@ -0,0 +1,8 @@
// Package tracex performs measurements using tracing. To use tracing means
// that we'll wrap netx data types (e.g., a Dialer) with equivalent data types
// saving events into a Saver data struture. Then we will use the data types
// normally (e.g., call the Dialer's DialContet method and then use the
// resulting connection). When done, we will extract the trace containing
// all the events that occurred from the saver and process it to determine
// what happened during the measurement itself.
package tracex
+247
View File
@@ -0,0 +1,247 @@
package tracex
//
// All the possible events
//
import (
"crypto/x509"
"net/http"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// Event is one of the events within a trace.
type Event interface {
// Value returns the event value
Value() *EventValue
// Name returns the event name
Name() string
}
// EventTLSHandshakeStart is the beginning of the TLS handshake.
type EventTLSHandshakeStart struct {
V *EventValue
}
func (ev *EventTLSHandshakeStart) Value() *EventValue {
return ev.V
}
func (ev *EventTLSHandshakeStart) Name() string {
return "tls_handshake_start"
}
// EventTLSHandshakeDone is the end of the TLS handshake.
type EventTLSHandshakeDone struct {
V *EventValue
}
func (ev *EventTLSHandshakeDone) Value() *EventValue {
return ev.V
}
func (ev *EventTLSHandshakeDone) Name() string {
return "tls_handshake_done"
}
// EventResolveStart is the beginning of a DNS lookup operation.
type EventResolveStart struct {
V *EventValue
}
func (ev *EventResolveStart) Value() *EventValue {
return ev.V
}
func (ev *EventResolveStart) Name() string {
return "resolve_start"
}
// EventResolveDone is the end of a DNS lookup operation.
type EventResolveDone struct {
V *EventValue
}
func (ev *EventResolveDone) Value() *EventValue {
return ev.V
}
func (ev *EventResolveDone) Name() string {
return "resolve_done"
}
// EventDNSRoundTripStart is the start of a DNS round trip.
type EventDNSRoundTripStart struct {
V *EventValue
}
func (ev *EventDNSRoundTripStart) Value() *EventValue {
return ev.V
}
func (ev *EventDNSRoundTripStart) Name() string {
return "dns_round_trip_start"
}
// EventDNSRoundTripDone is the end of a DNS round trip.
type EventDNSRoundTripDone struct {
V *EventValue
}
func (ev *EventDNSRoundTripDone) Value() *EventValue {
return ev.V
}
func (ev *EventDNSRoundTripDone) Name() string {
return "dns_round_trip_done"
}
// EventQUICHandshakeStart is the start of a QUIC handshake.
type EventQUICHandshakeStart struct {
V *EventValue
}
func (ev *EventQUICHandshakeStart) Value() *EventValue {
return ev.V
}
func (ev *EventQUICHandshakeStart) Name() string {
return "quic_handshake_start"
}
// EventQUICHandshakeDone is the end of a QUIC handshake.
type EventQUICHandshakeDone struct {
V *EventValue
}
func (ev *EventQUICHandshakeDone) Value() *EventValue {
return ev.V
}
func (ev *EventQUICHandshakeDone) Name() string {
return "quic_handshake_done"
}
// EventWriteToOperation summarizes the WriteTo operation.
type EventWriteToOperation struct {
V *EventValue
}
func (ev *EventWriteToOperation) Value() *EventValue {
return ev.V
}
func (ev *EventWriteToOperation) Name() string {
return netxlite.WriteToOperation
}
// EventReadFromOperation summarizes the ReadFrom operation.
type EventReadFromOperation struct {
V *EventValue
}
func (ev *EventReadFromOperation) Value() *EventValue {
return ev.V
}
func (ev *EventReadFromOperation) Name() string {
return netxlite.ReadFromOperation
}
// EventHTTPTransactionStart is the beginning of an HTTP transaction.
type EventHTTPTransactionStart struct {
V *EventValue
}
func (ev *EventHTTPTransactionStart) Value() *EventValue {
return ev.V
}
func (ev *EventHTTPTransactionStart) Name() string {
return "http_transaction_start"
}
// EventHTTPTransactionDone is the end of an HTTP transaction.
type EventHTTPTransactionDone struct {
V *EventValue
}
func (ev *EventHTTPTransactionDone) Value() *EventValue {
return ev.V
}
func (ev *EventHTTPTransactionDone) Name() string {
return "http_transaction_done"
}
// EventConnectOperation contains information about the connect operation.
type EventConnectOperation struct {
V *EventValue
}
func (ev *EventConnectOperation) Value() *EventValue {
return ev.V
}
func (ev *EventConnectOperation) Name() string {
return netxlite.ConnectOperation
}
// EventReadOperation contains information about a read operation.
type EventReadOperation struct {
V *EventValue
}
func (ev *EventReadOperation) Value() *EventValue {
return ev.V
}
func (ev *EventReadOperation) Name() string {
return netxlite.ReadOperation
}
// EventWriteOperation contains information about a write operation.
type EventWriteOperation struct {
V *EventValue
}
func (ev *EventWriteOperation) Value() *EventValue {
return ev.V
}
func (ev *EventWriteOperation) Name() string {
return netxlite.WriteOperation
}
// Event is one of the events within a trace
type EventValue struct {
Addresses []string `json:",omitempty"`
Address string `json:",omitempty"`
DNSQuery []byte `json:",omitempty"`
DNSResponse []byte `json:",omitempty"`
Data []byte `json:",omitempty"`
Duration time.Duration `json:",omitempty"`
Err error `json:",omitempty"`
HTTPMethod string `json:",omitempty"`
HTTPRequestHeaders http.Header `json:",omitempty"`
HTTPResponseHeaders http.Header `json:",omitempty"`
HTTPResponseBody []byte `json:",omitempty"`
HTTPResponseBodyIsTruncated bool `json:",omitempty"`
HTTPStatusCode int `json:",omitempty"`
HTTPURL string `json:",omitempty"`
Hostname string `json:",omitempty"`
NoTLSVerify bool `json:",omitempty"`
NumBytes int `json:",omitempty"`
Proto string `json:",omitempty"`
TLSServerName string `json:",omitempty"`
TLSCipherSuite string `json:",omitempty"`
TLSNegotiatedProto string `json:",omitempty"`
TLSNextProtos []string `json:",omitempty"`
TLSPeerCerts []*x509.Certificate `json:",omitempty"`
TLSVersion string `json:",omitempty"`
Time time.Time `json:",omitempty"`
Transport string `json:",omitempty"`
}
+125
View File
@@ -0,0 +1,125 @@
package tracex
//
// HTTP
//
import (
"bytes"
"io"
"net/http"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// httpCloneRequestHeaders returns a clone of the headers where we have
// also set the host header, which normally is not set by
// golang until it serializes the request itself.
func httpCloneRequestHeaders(req *http.Request) http.Header {
header := req.Header.Clone()
if req.Host != "" {
header.Set("Host", req.Host)
} else {
header.Set("Host", req.URL.Host)
}
return header
}
// HTTPTransportSaver is a RoundTripper that saves
// events related to the HTTP transaction
type HTTPTransportSaver struct {
// HTTPTransport is the MANDATORY underlying HTTP transport.
HTTPTransport model.HTTPTransport
// Saver is the MANDATORY saver to use.
Saver *Saver
// SnapshotSize is the OPTIONAL maximum body snapshot size (if not set, we'll
// use 1<<17, which we've been using since the ooni/netx days)
SnapshotSize int64
}
// HTTPRoundTrip performs the round trip with the given transport and
// the given arguments and saves the results into the saver.
//
// The maxBodySnapshotSize argument controls the maximum size of the
// body snapshot that we collect along with the HTTP round trip.
func (txp *HTTPTransportSaver) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO(bassosimone): we're currently using the started time for
// the transaction done event, which contrasts with what we do for
// every other event. What does the spec say?
started := time.Now()
txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{
HTTPRequestHeaders: httpCloneRequestHeaders(req),
HTTPMethod: req.Method,
HTTPURL: req.URL.String(),
Transport: txp.HTTPTransport.Network(),
Time: started,
}})
ev := &EventValue{
HTTPRequestHeaders: httpCloneRequestHeaders(req),
HTTPMethod: req.Method,
HTTPURL: req.URL.String(),
Transport: txp.HTTPTransport.Network(),
Time: started,
}
defer txp.Saver.Write(&EventHTTPTransactionDone{ev})
resp, err := txp.HTTPTransport.RoundTrip(req)
if err != nil {
ev.Duration = time.Since(started)
ev.Err = err
return nil, err
}
ev.HTTPStatusCode = resp.StatusCode
ev.HTTPResponseHeaders = resp.Header.Clone()
maxBodySnapshotSize := txp.snapshotSize()
r := io.LimitReader(resp.Body, maxBodySnapshotSize)
body, err := netxlite.ReadAllContext(req.Context(), r)
if err != nil {
ev.Duration = time.Since(started)
ev.Err = err
return nil, err
}
resp.Body = &httpReadableAgainBody{ // allow for reading again the whole body
Reader: io.MultiReader(bytes.NewReader(body), resp.Body),
Closer: resp.Body,
}
ev.Duration = time.Since(started)
ev.HTTPResponseBody = body
ev.HTTPResponseBodyIsTruncated = int64(len(body)) >= maxBodySnapshotSize
return resp, nil
}
func (txp *HTTPTransportSaver) CloseIdleConnections() {
txp.HTTPTransport.CloseIdleConnections()
}
func (txp *HTTPTransportSaver) Network() string {
return txp.HTTPTransport.Network()
}
func (txp *HTTPTransportSaver) snapshotSize() int64 {
if txp.SnapshotSize > 0 {
return txp.SnapshotSize
}
return 1 << 17
}
type httpReadableAgainBody struct {
io.Reader
io.Closer
}
var _ model.HTTPTransport = &HTTPTransportSaver{}
+306
View File
@@ -0,0 +1,306 @@
package tracex
import (
"bytes"
"context"
"errors"
"io"
"net"
"net/http"
"net/url"
"testing"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
)
func TestHTTPTransportSaver(t *testing.T) {
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.HTTPTransport{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &HTTPTransportSaver{
HTTPTransport: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("Network", func(t *testing.T) {
expected := "antani"
child := &mocks.HTTPTransport{
MockNetwork: func() string {
return expected
},
}
dialer := &HTTPTransportSaver{
HTTPTransport: child,
Saver: &Saver{},
}
if dialer.Network() != expected {
t.Fatal("unexpected Network")
}
})
t.Run("RoundTrip", func(t *testing.T) {
startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) {
server := &filtering.HTTPProxy{
OnIncomingHost: func(host string) filtering.HTTPAction {
return action
},
}
listener, err := server.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
URL := &url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: "/",
}
return listener, URL
}
measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) {
saver := &Saver{}
txp := &HTTPTransportSaver{
HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger),
Saver: saver,
}
req, err := http.NewRequest("GET", URL.String(), nil)
if err != nil {
t.Fatal(err)
}
req.Header.Add("User-Agent", "miniooni")
resp, err := txp.RoundTrip(req)
return resp, saver, err
}
validateRequestFields := func(t *testing.T, value *EventValue, URL *url.URL) {
if value.HTTPMethod != "GET" {
t.Fatal("invalid method")
}
if value.HTTPRequestHeaders.Get("Host") != URL.Host {
t.Fatal("invalid Host header")
}
if value.HTTPRequestHeaders.Get("User-Agent") != "miniooni" {
t.Fatal("invalid User-Agent header")
}
if value.HTTPURL != URL.String() {
t.Fatal("invalid URL")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
if value.Transport != "tcp" {
t.Fatal("expected Transport to be tcp")
}
}
validateRequest := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionStart); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_start" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
}
validateResponseSuccess := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionDone); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_done" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if len(value.HTTPResponseHeaders) <= 0 {
t.Fatal("expected at least one response header")
}
if !bytes.Equal(value.HTTPResponseBody, filtering.HTTPBlockpage451) {
t.Fatal("unexpected value for response body")
}
if value.HTTPStatusCode != 451 {
t.Fatal("unexpected status code")
}
}
t.Run("on success", func(t *testing.T) {
listener, URL := startServer(t, filtering.HTTPAction451)
defer listener.Close()
resp, saver, err := measureHTTP(t, URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 451 {
t.Fatal("unexpected status code", resp.StatusCode)
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
validateRequest(t, events[0], URL)
validateResponseSuccess(t, events[1], URL)
data, err := netxlite.ReadAllContext(context.Background(), resp.Body)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, filtering.HTTPBlockpage451) {
t.Fatal("we cannot re-read the same body")
}
})
validateResponseFailure := func(t *testing.T, ev Event, URL *url.URL) {
if _, good := ev.(*EventHTTPTransactionDone); !good {
t.Fatal("invalid event type")
}
if ev.Name() != "http_transaction_done" {
t.Fatal("invalid event name")
}
value := ev.Value()
validateRequestFields(t, value, URL)
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if value.Err.Error() != "connection_reset" {
t.Fatal("unexpected Err value")
}
if len(value.HTTPResponseHeaders) > 0 {
t.Fatal("expected zero response headers")
}
if !bytes.Equal(value.HTTPResponseBody, nil) {
t.Fatal("unexpected value for response body")
}
if value.HTTPStatusCode != 0 {
t.Fatal("unexpected status code")
}
}
t.Run("on round trip failure", func(t *testing.T) {
listener, URL := startServer(t, filtering.HTTPActionReset)
defer listener.Close()
resp, saver, err := measureHTTP(t, URL)
if err == nil || err.Error() != "connection_reset" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
validateRequest(t, events[0], URL)
validateResponseFailure(t, events[1], URL)
})
// Sometimes useful for testing
/*
dump := func(t *testing.T, ev Event) {
data, _ := json.MarshalIndent(ev.Value(), " ", " ")
t.Log(string(data))
t.Fail()
}
*/
t.Run("on error reading the response body", func(t *testing.T) {
saver := &Saver{}
expected := errors.New("mocked error")
txp := HTTPTransportSaver{
HTTPTransport: &mocks.HTTPTransport{
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
return &http.Response{
Header: http.Header{
"Server": {"antani"},
},
StatusCode: 200,
Body: io.NopCloser(&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}),
}, nil
},
MockNetwork: func() string {
return "tcp"
},
},
SnapshotSize: 4,
Saver: saver,
}
URL := &url.URL{
Scheme: "http",
Host: "127.0.0.1:9050",
}
req, err := http.NewRequest("GET", URL.String(), nil)
if err != nil {
t.Fatal(err)
}
req.Header.Add("User-Agent", "miniooni")
resp, err := txp.RoundTrip(req)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response")
}
ev := saver.Read()
validateRequest(t, ev[0], URL)
if ev[1].Value().HTTPStatusCode != 200 {
t.Fatal("invalid status code")
}
if ev[1].Value().HTTPResponseHeaders.Get("Server") != "antani" {
t.Fatal("invalid Server header")
}
if ev[1].Value().Err.Error() != "unknown_failure: mocked error" {
t.Fatal("invalid error")
}
})
})
}
func TestHTTPCloneRequestHeaders(t *testing.T) {
t.Run("with req.Host set", func(t *testing.T) {
req := &http.Request{
Host: "www.example.com",
URL: &url.URL{
Host: "www.kernel.org",
},
Header: http.Header{},
}
header := httpCloneRequestHeaders(req)
if header.Get("Host") != "www.example.com" {
t.Fatal("did not set Host header correctly")
}
})
t.Run("with only req.URL.Host set", func(t *testing.T) {
req := &http.Request{
Host: "",
URL: &url.URL{
Host: "www.kernel.org",
},
Header: http.Header{},
}
header := httpCloneRequestHeaders(req)
if header.Get("Host") != "www.kernel.org" {
t.Fatal("did not set Host header correctly")
}
})
}
+187
View File
@@ -0,0 +1,187 @@
package tracex
//
// QUIC
//
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// QUICDialerSaver saves events occurring during the QUIC handshake.
type QUICDialerSaver struct {
// QUICDialer is the wrapped dialer
QUICDialer model.QUICDialer
// Saver saves events
Saver *Saver
}
// WrapQUICDialer wraps a model.QUICDialer with a QUICHandshakeSaver that will
// save the QUIC handshake results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original QUICDialer without any wrapping.
func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer {
if s == nil {
return qd
}
return &QUICDialerSaver{
QUICDialer: qd,
Saver: s,
}
}
// DialContext implements QUICDialer.DialContext
func (h *QUICDialerSaver) DialContext(ctx context.Context, network string,
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
start := time.Now()
// TODO(bassosimone): in the future we probably want to also save
// information about what versions we're willing to accept.
h.Saver.Write(&EventQUICHandshakeStart{&EventValue{
Address: host,
NoTLSVerify: tlsCfg.InsecureSkipVerify,
Proto: network,
TLSNextProtos: tlsCfg.NextProtos,
TLSServerName: tlsCfg.ServerName,
Time: start,
}})
sess, err := h.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
stop := time.Now()
if err != nil {
// TODO(bassosimone): here we should save the peer certs
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
Address: host,
Duration: stop.Sub(start),
Err: err,
NoTLSVerify: tlsCfg.InsecureSkipVerify,
Proto: network,
TLSNextProtos: tlsCfg.NextProtos,
TLSServerName: tlsCfg.ServerName,
Time: stop,
}})
return nil, err
}
state := quicConnectionState(sess)
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
Address: host,
Duration: stop.Sub(start),
NoTLSVerify: tlsCfg.InsecureSkipVerify,
Proto: network,
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: tlsCfg.NextProtos,
TLSPeerCerts: tlsPeerCerts(state, err),
TLSServerName: tlsCfg.ServerName,
TLSVersion: netxlite.TLSVersionString(state.Version),
Time: stop,
}})
return sess, nil
}
func (h *QUICDialerSaver) CloseIdleConnections() {
h.QUICDialer.CloseIdleConnections()
}
// quicConnectionState returns the ConnectionState of a QUIC Session.
func quicConnectionState(sess quic.EarlyConnection) tls.ConnectionState {
return sess.ConnectionState().TLS.ConnectionState
}
// QUICListenerSaver is a QUICListener that also implements saving events.
type QUICListenerSaver struct {
// QUICListener is the underlying QUICListener.
QUICListener model.QUICListener
// Saver is the underlying Saver.
Saver *Saver
}
// WrapQUICListener wraps a model.QUICDialer with a QUICListenerSaver that will
// save the QUIC I/O packet conn events into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original QUICListener without any wrapping.
func (s *Saver) WrapQUICListener(ql model.QUICListener) model.QUICListener {
if s == nil {
return ql
}
return &QUICListenerSaver{
QUICListener: ql,
Saver: s,
}
}
// Listen implements QUICListener.Listen.
func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr)
if err != nil {
return nil, err
}
pconn = &quicPacketConnWrapper{
UDPLikeConn: pconn,
saver: qls.Saver,
}
return pconn, nil
}
// quicPacketConnWrapper saves I/O events
type quicPacketConnWrapper struct {
// UDPLikeConn is the wrapped underlying conn
model.UDPLikeConn
// Saver saves events
saver *Saver
}
func (c *quicPacketConnWrapper) WriteTo(p []byte, addr net.Addr) (int, error) {
start := time.Now()
count, err := c.UDPLikeConn.WriteTo(p, addr)
stop := time.Now()
c.saver.Write(&EventWriteToOperation{&EventValue{
Address: addr.String(),
Data: p[:count],
Duration: stop.Sub(start),
Err: err,
NumBytes: count,
Time: stop,
}})
return count, err
}
func (c *quicPacketConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) {
start := time.Now()
n, addr, err := c.UDPLikeConn.ReadFrom(b)
stop := time.Now()
var data []byte
if n > 0 {
data = b[:n]
}
c.saver.Write(&EventReadFromOperation{&EventValue{
Address: c.safeAddrString(addr),
Data: data,
Duration: stop.Sub(start),
Err: err,
NumBytes: n,
Time: stop,
}})
return n, addr, err
}
func (c *quicPacketConnWrapper) safeAddrString(addr net.Addr) (out string) {
if addr != nil {
out = addr.String()
}
return
}
var _ model.QUICDialer = &QUICDialerSaver{}
var _ model.QUICListener = &QUICListenerSaver{}
var _ model.UDPLikeConn = &quicPacketConnWrapper{}
+448
View File
@@ -0,0 +1,448 @@
package tracex
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestQUICDialerSaver(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
checkStartEventFields := func(t *testing.T, value *EventValue) {
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if !value.NoTLSVerify {
t.Fatal("expected NoTLSVerify to be true")
}
if value.Proto != "udp" {
t.Fatal("wrong protocol")
}
if diff := cmp.Diff(value.TLSNextProtos, []string{"h3"}); diff != "" {
t.Fatal(diff)
}
if value.TLSServerName != "dns.google" {
t.Fatal("invalid TLSServerName")
}
if value.Time.IsZero() {
t.Fatal("expected non zero time")
}
}
checkStartedEvent := func(t *testing.T, ev Event) {
if _, good := ev.(*EventQUICHandshakeStart); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
}
checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err != nil {
t.Fatal("expected no error here")
}
if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" {
t.Fatal("invalid cipher suite")
}
if value.TLSNegotiatedProto != "h3" {
t.Fatal("invalid negotiated protocol")
}
if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" {
t.Fatal(diff)
}
if value.TLSVersion != "TLSv1.3" {
t.Fatal("invalid TLS version")
}
}
checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) {
if _, good := ev.(*EventQUICHandshakeDone); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
fun(t, value)
}
t.Run("on success", func(t *testing.T) {
saver := &Saver{}
returnedConn := &mocks.QUICEarlyConnection{
MockConnectionState: func() quic.ConnectionState {
cs := quic.ConnectionState{}
cs.TLS.ConnectionState.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
cs.TLS.NegotiatedProtocol = "h3"
cs.TLS.PeerCertificates = []*x509.Certificate{}
cs.TLS.Version = tls.VersionTLS13
return cs
},
}
dialer := saver.WrapQUICDialer(&mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return returnedConn, nil
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h3"},
ServerName: "dns.google",
}
quicConfig := &quic.Config{}
conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess)
})
checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err == nil {
t.Fatal("expected non-nil error here")
}
}
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
dialer := saver.WrapQUICDialer(&mocks.QUICDialer{
MockDialContext: func(ctx context.Context, network, address string,
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
return nil, expected
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h3"},
ServerName: "dns.google",
}
quicConfig := &quic.Config{}
conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsFailure)
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
child := &mocks.QUICDialer{
MockCloseIdleConnections: func() {
called = true
},
}
dialer := &QUICDialerSaver{
QUICDialer: child,
Saver: &Saver{},
}
dialer.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
}
func TestQUICListenerSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
qls := saver.WrapQUICListener(&mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
return nil, expected
},
})
pconn, err := qls.Listen(&net.UDPAddr{
IP: []byte{},
Port: 8080,
Zone: "",
})
if !errors.Is(err, expected) {
t.Fatal("unexpected error", err)
}
if pconn != nil {
t.Fatal("expected nil pconn here")
}
})
t.Run("on success", func(t *testing.T) {
saver := &Saver{}
returnedConn := &mocks.UDPLikeConn{}
qls := saver.WrapQUICListener(&mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
return returnedConn, nil
},
})
pconn, err := qls.Listen(&net.UDPAddr{
IP: []byte{},
Port: 8080,
Zone: "",
})
if err != nil {
t.Fatal(err)
}
wconn := pconn.(*quicPacketConnWrapper)
if wconn.UDPLikeConn != returnedConn {
t.Fatal("invalid underlying connection")
}
if wconn.saver != saver {
t.Fatal("invalid saver")
}
})
}
func TestQUICPacketConnWrapper(t *testing.T) {
t.Run("ReadFrom", func(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockReadFrom: func(p []byte) (int, net.Addr, error) {
return 0, nil, expected
},
},
saver: saver,
}
buf := make([]byte, 1<<17)
count, addr, err := conn.ReadFrom(buf)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("invalid count")
}
if addr != nil {
t.Fatal("invalid addr")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventReadFromOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "" {
t.Fatal("invalid Address")
}
if len(value.Data) != 0 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if !errors.Is(value.Err, expected) {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 0 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
t.Run("on success", func(t *testing.T) {
expected := []byte{1, 2, 3, 4}
saver := &Saver{}
expectedAddr := &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
MockNetwork: func() string {
return "udp"
},
}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockReadFrom: func(p []byte) (int, net.Addr, error) {
copy(p, expected)
return len(expected), expectedAddr, nil
},
},
saver: saver,
}
buf := make([]byte, 1<<17)
count, addr, err := conn.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
if count != 4 {
t.Fatal("invalid count")
}
if addr != expectedAddr {
t.Fatal("invalid addr")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventReadFromOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if len(value.Data) != 4 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if value.Err != nil {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 4 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
})
t.Run("WriteTo", func(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 0, expected
},
},
saver: saver,
}
destAddr := &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
}
buf := make([]byte, 7)
count, err := conn.WriteTo(buf, destAddr)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("invalid count")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventWriteToOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if len(value.Data) != 0 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if !errors.Is(value.Err, expected) {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 0 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
t.Run("on success", func(t *testing.T) {
saver := &Saver{}
conn := &quicPacketConnWrapper{
UDPLikeConn: &mocks.UDPLikeConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 1, nil
},
},
saver: saver,
}
destAddr := &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
}
buf := make([]byte, 7)
count, err := conn.WriteTo(buf, destAddr)
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatal("invalid count")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("invalid number of events")
}
ev0 := events[0]
if _, good := ev0.(*EventWriteToOperation); !good {
t.Fatal("invalid event type")
}
value := ev0.Value()
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if len(value.Data) != 1 {
t.Fatal("invalid Data")
}
if value.Duration <= 0 {
t.Fatal("expected nonzero duration")
}
if value.Err != nil {
t.Fatal("unexpected value.Err", value.Err)
}
if value.NumBytes != 1 {
t.Fatal("expected NumBytes")
}
if value.Time.IsZero() {
t.Fatal("expected nonzero Time")
}
})
})
}
+161
View File
@@ -0,0 +1,161 @@
package tracex
//
// DNS lookup and round trip
//
import (
"context"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
)
// ResolverSaver is a resolver that saves events.
type ResolverSaver struct {
// Resolver is the underlying resolver.
Resolver model.Resolver
// Saver saves events.
Saver *Saver
}
// WrapResolver wraps a model.Resolver with a SaverResolver that will save
// the DNS lookup results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original Resolver without any wrapping.
func (s *Saver) WrapResolver(r model.Resolver) model.Resolver {
if s == nil {
return r
}
return &ResolverSaver{
Resolver: r,
Saver: s,
}
}
// LookupHost implements Resolver.LookupHost
func (r *ResolverSaver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
start := time.Now()
r.Saver.Write(&EventResolveStart{&EventValue{
Address: r.Resolver.Address(),
Hostname: hostname,
Proto: r.Resolver.Network(),
Time: start,
}})
addrs, err := r.Resolver.LookupHost(ctx, hostname)
stop := time.Now()
r.Saver.Write(&EventResolveDone{&EventValue{
Addresses: addrs,
Address: r.Resolver.Address(),
Duration: stop.Sub(start),
Err: err,
Hostname: hostname,
Proto: r.Resolver.Network(),
Time: stop,
}})
return addrs, err
}
func (r *ResolverSaver) Network() string {
return r.Resolver.Network()
}
func (r *ResolverSaver) Address() string {
return r.Resolver.Address()
}
func (r *ResolverSaver) CloseIdleConnections() {
r.Resolver.CloseIdleConnections()
}
func (r *ResolverSaver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
// TODO(bassosimone): we should probably implement this method
return r.Resolver.LookupHTTPS(ctx, domain)
}
func (r *ResolverSaver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
// TODO(bassosimone): we should probably implement this method
return r.Resolver.LookupNS(ctx, domain)
}
// DNSTransportSaver is a DNS transport that saves events.
type DNSTransportSaver struct {
// DNSTransport is the underlying DNS transport.
DNSTransport model.DNSTransport
// Saver saves events.
Saver *Saver
}
// WrapDNSTransport wraps a model.DNSTransport with a SaverDNSTransport that
// will save the DNS round trip results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original DNSTransport without any wrapping.
func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
if s == nil {
return txp
}
return &DNSTransportSaver{
DNSTransport: txp,
Saver: s,
}
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp *DNSTransportSaver) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
start := time.Now()
txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{
Address: txp.DNSTransport.Address(),
DNSQuery: dnsMaybeQueryBytes(query),
Proto: txp.DNSTransport.Network(),
Time: start,
}})
response, err := txp.DNSTransport.RoundTrip(ctx, query)
stop := time.Now()
txp.Saver.Write(&EventDNSRoundTripDone{&EventValue{
Address: txp.DNSTransport.Address(),
DNSQuery: dnsMaybeQueryBytes(query),
DNSResponse: dnsMaybeResponseBytes(response),
Duration: stop.Sub(start),
Err: err,
Proto: txp.DNSTransport.Network(),
Time: stop,
}})
return response, err
}
func (txp *DNSTransportSaver) Network() string {
return txp.DNSTransport.Network()
}
func (txp *DNSTransportSaver) Address() string {
return txp.DNSTransport.Address()
}
func (txp *DNSTransportSaver) CloseIdleConnections() {
txp.DNSTransport.CloseIdleConnections()
}
func (txp *DNSTransportSaver) RequiresPadding() bool {
return txp.DNSTransport.RequiresPadding()
}
func dnsMaybeQueryBytes(query model.DNSQuery) []byte {
data, _ := query.Bytes()
return data
}
func dnsMaybeResponseBytes(response model.DNSResponse) []byte {
if response == nil {
return nil
}
return response.Bytes()
}
var _ model.Resolver = &ResolverSaver{}
var _ model.DNSTransport = &DNSTransportSaver{}
+279
View File
@@ -0,0 +1,279 @@
package tracex
import (
"bytes"
"context"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
func TestResolverSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("no such host")
saver := &Saver{}
reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected))
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if ev[1].Value().Addresses != nil {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Value().Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
t.Run("on success", func(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &Saver{}
reso := saver.WrapResolver(newFakeResolverWithResult(expected))
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal("expected nil error here")
}
if !reflect.DeepEqual(addrs, expected) {
t.Fatal("not the result we expected")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name() != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !reflect.DeepEqual(ev[1].Value().Addresses, expected) {
t.Fatal("unexpected Addresses")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Value().Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name() != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
}
func TestDNSTransportSaver(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("no such host")
saver := &Saver{}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].Value().DNSResponse != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Value().Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
t.Run("on success", func(t *testing.T) {
expected := []byte{0xef, 0xbe, 0xad, 0xde}
saver := &Saver{}
response := &mocks.DNSResponse{
MockBytes: func() []byte {
return expected
},
}
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply.Bytes(), expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name() != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Value().Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].Value().DNSResponse, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Value().Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Value().Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name() != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Value().Time.After(ev[0].Value().Time) {
t.Fatal("the saved time is wrong")
}
})
}
func newFakeResolverWithExplicitError(err error) model.Resolver {
runtimex.PanicIfNil(err, "passed nil error")
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, err
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
MockCloseIdleConnections: func() {
// nothing
},
MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
return nil, errors.New("not implemented")
},
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
return nil, errors.New("not implemented")
},
}
}
func newFakeResolverWithResult(r []string) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return r, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
MockCloseIdleConnections: func() {
// nothing
},
MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
return nil, errors.New("not implemented")
},
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
return nil, errors.New("not implemented")
},
}
}
+35
View File
@@ -0,0 +1,35 @@
package tracex
//
// Saver implementation
//
import "sync"
// The Saver saves a trace. The zero value of this type
// is valid and can be used without initialization.
type Saver struct {
// ops contains the saved events.
ops []Event
// mu provides mutual exclusion.
mu sync.Mutex
}
// Read reads and returns events inside the trace. It advances
// the read pointer so you won't see such events again.
func (s *Saver) Read() []Event {
s.mu.Lock()
defer s.mu.Unlock()
v := s.ops
s.ops = nil
return v
}
// Write adds the given event to the trace. A subsequent call
// to Read will read this event.
func (s *Saver) Write(ev Event) {
s.mu.Lock()
defer s.mu.Unlock()
s.ops = append(s.ops, ev)
}
+90
View File
@@ -0,0 +1,90 @@
package tracex
import (
"sync"
"testing"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestSaver(t *testing.T) {
t.Run("concurrent writes followed by read", func(t *testing.T) {
saver := Saver{}
var wg sync.WaitGroup
const parallel = 10
wg.Add(parallel)
for idx := 0; idx < parallel; idx++ {
go func() {
saver.Write(&EventReadFromOperation{&EventValue{}})
wg.Done()
}()
}
wg.Wait()
ev := saver.Read()
if len(ev) != parallel {
t.Fatal("unexpected number of events read")
}
})
t.Run("NewConnectObserver", func(t *testing.T) {
t.Run("nil Saver", func(t *testing.T) {
var saver *Saver
obs := saver.NewConnectObserver()
if obs != nil {
t.Fatal("expected nil observer")
}
})
t.Run("nonnnil Saver", func(t *testing.T) {
saver := &Saver{}
obs := saver.NewConnectObserver()
underlying := obs.(*dialerConnectObserver)
if underlying.saver != saver {
t.Fatal("invalid saver")
}
})
})
t.Run("NewReadWriteObserver", func(t *testing.T) {
t.Run("nil Saver", func(t *testing.T) {
var saver *Saver
obs := saver.NewReadWriteObserver()
if obs != nil {
t.Fatal("expected nil observer")
}
})
t.Run("nonnnil Saver", func(t *testing.T) {
saver := &Saver{}
obs := saver.NewReadWriteObserver()
underlying := obs.(*dialerReadWriteObserver)
if underlying.saver != saver {
t.Fatal("invalid saver")
}
})
})
t.Run("WrapQUICDialer", func(t *testing.T) {
t.Run("nil Saver", func(t *testing.T) {
var saver *Saver
base := &mocks.QUICDialer{}
qd := saver.WrapQUICDialer(base)
if qd != base {
t.Fatal("unexpected returned QUICDialer")
}
})
t.Run("nonnnil Saver", func(t *testing.T) {
saver := &Saver{}
base := &mocks.QUICDialer{}
qd := saver.WrapQUICDialer(base)
underlying := qd.(*QUICDialerSaver)
if underlying.Saver != saver {
t.Fatal("invalid Saver")
}
if underlying.QUICDialer != base {
t.Fatal("invalid QUICDialer")
}
})
})
}
+98
View File
@@ -0,0 +1,98 @@
package tracex
//
// TLS
//
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// TLSHandshakerSaver saves events occurring during the TLS handshake.
type TLSHandshakerSaver struct {
// TLSHandshaker is the underlying TLS handshaker.
TLSHandshaker model.TLSHandshaker
// Saver is the saver in which to save events.
Saver *Saver
}
// WrapTLSHandshaker wraps a model.TLSHandshaker with a SaverTLSHandshaker
// that will save the TLS handshake results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original TLSHandshaker without any wrapping.
func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker {
if s == nil {
return thx
}
return &TLSHandshakerSaver{
TLSHandshaker: thx,
Saver: s,
}
}
// Handshake implements model.TLSHandshaker.Handshake
func (h *TLSHandshakerSaver) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
proto := conn.RemoteAddr().Network()
remoteAddr := conn.RemoteAddr().String()
start := time.Now()
h.Saver.Write(&EventTLSHandshakeStart{&EventValue{
Address: remoteAddr,
NoTLSVerify: config.InsecureSkipVerify,
Proto: proto,
TLSNextProtos: config.NextProtos,
TLSServerName: config.ServerName,
Time: start,
}})
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
stop := time.Now()
h.Saver.Write(&EventTLSHandshakeDone{&EventValue{
Address: remoteAddr,
Duration: stop.Sub(start),
Err: err,
NoTLSVerify: config.InsecureSkipVerify,
Proto: proto,
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: config.NextProtos,
TLSPeerCerts: tlsPeerCerts(state, err),
TLSServerName: config.ServerName,
TLSVersion: netxlite.TLSVersionString(state.Version),
Time: stop,
}})
return tlsconn, state, err
}
var _ model.TLSHandshaker = &TLSHandshakerSaver{}
// tlsPeerCerts returns the certificates presented by the peer regardless
// of whether the TLS handshake was successful
func tlsPeerCerts(state tls.ConnectionState, err error) []*x509.Certificate {
var x509HostnameError x509.HostnameError
if errors.As(err, &x509HostnameError) {
// Test case: https://wrong.host.badssl.com/
return []*x509.Certificate{x509HostnameError.Certificate}
}
var x509UnknownAuthorityError x509.UnknownAuthorityError
if errors.As(err, &x509UnknownAuthorityError) {
// Test case: https://self-signed.badssl.com/. This error has
// never been among the ones returned by MK.
return []*x509.Certificate{x509UnknownAuthorityError.Cert}
}
var x509CertificateInvalidError x509.CertificateInvalidError
if errors.As(err, &x509CertificateInvalidError) {
// Test case: https://expired.badssl.com/
return []*x509.Certificate{x509CertificateInvalidError.Cert}
}
return state.PeerCertificates
}
+252
View File
@@ -0,0 +1,252 @@
package tracex
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestTLSHandshakerSaver(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
checkStartEventFields := func(t *testing.T, value *EventValue) {
if value.Address != "8.8.8.8:443" {
t.Fatal("invalid Address")
}
if !value.NoTLSVerify {
t.Fatal("expected NoTLSVerify to be true")
}
if value.Proto != "tcp" {
t.Fatal("wrong protocol")
}
if diff := cmp.Diff(value.TLSNextProtos, []string{"h2"}); diff != "" {
t.Fatal(diff)
}
if value.TLSServerName != "dns.google" {
t.Fatal("invalid TLSServerName")
}
if value.Time.IsZero() {
t.Fatal("expected non zero time")
}
}
checkStartedEvent := func(t *testing.T, ev Event) {
if _, good := ev.(*EventTLSHandshakeStart); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
}
checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err != nil {
t.Fatal("expected no error here")
}
if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" {
t.Fatal("invalid cipher suite")
}
if value.TLSNegotiatedProto != "h2" {
t.Fatal("invalid negotiated protocol")
}
if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" {
t.Fatal(diff)
}
if value.TLSVersion != "TLSv1.3" {
t.Fatal("invalid TLS version")
}
}
checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) {
if _, good := ev.(*EventTLSHandshakeDone); !good {
t.Fatal("invalid event type")
}
value := ev.Value()
checkStartEventFields(t, value)
fun(t, value)
}
t.Run("on success", func(t *testing.T) {
saver := &Saver{}
returnedConnState := tls.ConnectionState{
CipherSuite: tls.TLS_RSA_WITH_RC4_128_SHA,
NegotiatedProtocol: "h2",
PeerCertificates: []*x509.Certificate{},
Version: tls.VersionTLS13,
}
returnedConn := &mocks.TLSConn{
MockConnectionState: func() tls.ConnectionState {
return returnedConnState
},
}
thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return returnedConn, returnedConnState, nil
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
ServerName: "dns.google",
}
tcpConn := &mocks.Conn{
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess)
})
checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) {
if value.Duration <= 0 {
t.Fatal("expected non-zero duration")
}
if value.Err == nil {
t.Fatal("expected non-nil error here")
}
if value.TLSCipherSuite != "" {
t.Fatal("invalid TLS cipher suite")
}
if value.TLSNegotiatedProto != "" {
t.Fatal("invalid negotiated proto")
}
if len(value.TLSPeerCerts) > 0 {
t.Fatal("expected no peer certs")
}
if value.TLSVersion != "" {
t.Fatal("invalid TLS version")
}
}
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
saver := &Saver{}
thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn,
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expected
},
})
ctx := context.Background()
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
ServerName: "dns.google",
}
tcpConn := &mocks.Conn{
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "8.8.8.8:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("expected two events")
}
checkStartedEvent(t, events[0])
checkDoneEvent(t, events[1], checkDoneEventFieldsFailure)
})
})
}
func Test_tlsPeerCerts(t *testing.T) {
cert0 := &x509.Certificate{Raw: []byte{1, 2, 3, 4}}
type args struct {
state tls.ConnectionState
err error
}
tests := []struct {
name string
args args
want []*x509.Certificate
}{{
name: "no error",
args: args{
state: tls.ConnectionState{
PeerCertificates: []*x509.Certificate{cert0},
},
},
want: []*x509.Certificate{cert0},
}, {
name: "all empty",
args: args{},
want: nil,
}, {
name: "x509.HostnameError",
args: args{
state: tls.ConnectionState{},
err: x509.HostnameError{
Certificate: cert0,
},
},
want: []*x509.Certificate{cert0},
}, {
name: "x509.UnknownAuthorityError",
args: args{
state: tls.ConnectionState{},
err: x509.UnknownAuthorityError{
Cert: cert0,
},
},
want: []*x509.Certificate{cert0},
}, {
name: "x509.CertificateInvalidError",
args: args{
state: tls.ConnectionState{},
err: x509.CertificateInvalidError{
Cert: cert0,
},
},
want: []*x509.Certificate{cert0},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tlsPeerCerts(tt.args.state, tt.args.err)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Fatal(diff)
}
})
}
}