refactor(tracex): convert to unit testing (#781)
The exercise already allowed me to notice issues such as fields not being properly initialized by savers. This is one of the last steps before moving tracex away from the internal/netx package and into the internal package. See https://github.com/ooni/probe/issues/2121
This commit is contained in:
parent
6212daa54a
commit
d397036073
|
@ -119,7 +119,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
|
stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -195,7 +195,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
|
stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -271,7 +271,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
|
stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -347,7 +347,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
|
stxp, ok := sr.Txp.(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the DNS transport we expected")
|
t.Fatal("not the DNS transport we expected")
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,7 +187,7 @@ func NewHTTPTransport(config Config) model.HTTPTransport {
|
||||||
txp = &netxlite.HTTPTransportLogger{Logger: config.Logger, HTTPTransport: txp}
|
txp = &netxlite.HTTPTransportLogger{Logger: config.Logger, HTTPTransport: txp}
|
||||||
}
|
}
|
||||||
if config.HTTPSaver != nil {
|
if config.HTTPSaver != nil {
|
||||||
txp = &tracex.SaverTransactionHTTPTransport{
|
txp = &tracex.HTTPTransportSaver{
|
||||||
HTTPTransport: txp, Saver: config.HTTPSaver}
|
HTTPTransport: txp, Saver: config.HTTPSaver}
|
||||||
}
|
}
|
||||||
return txp
|
return txp
|
||||||
|
|
|
@ -126,7 +126,7 @@ func TestNewResolverWithSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
sr, ok := ir.Resolver.(*tracex.SaverResolver)
|
sr, ok := ir.Resolver.(*tracex.ResolverSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
|
@ -332,7 +332,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
|
||||||
if rtd.TLSHandshaker == nil {
|
if rtd.TLSHandshaker == nil {
|
||||||
t.Fatal("invalid TLSHandshaker")
|
t.Fatal("invalid TLSHandshaker")
|
||||||
}
|
}
|
||||||
sth, ok := rtd.TLSHandshaker.(*tracex.SaverTLSHandshaker)
|
sth, ok := rtd.TLSHandshaker.(*tracex.TLSHandshakerSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the TLSHandshaker we expected")
|
t.Fatal("not the TLSHandshaker we expected")
|
||||||
}
|
}
|
||||||
|
@ -504,7 +504,7 @@ func TestNewWithSaver(t *testing.T) {
|
||||||
txp := netx.NewHTTPTransport(netx.Config{
|
txp := netx.NewHTTPTransport(netx.Config{
|
||||||
HTTPSaver: saver,
|
HTTPSaver: saver,
|
||||||
})
|
})
|
||||||
stxptxp, ok := txp.(*tracex.SaverTransactionHTTPTransport)
|
stxptxp, ok := txp.(*tracex.HTTPTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -622,7 +622,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(*tracex.SaverDNSTransport)
|
txp, ok := r.Transport().(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -659,7 +659,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(*tracex.SaverDNSTransport)
|
txp, ok := r.Transport().(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -700,7 +700,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(*tracex.SaverDNSTransport)
|
txp, ok := r.Transport().(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
@ -745,7 +745,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the resolver we expected")
|
t.Fatal("not the resolver we expected")
|
||||||
}
|
}
|
||||||
txp, ok := r.Transport().(*tracex.SaverDNSTransport)
|
txp, ok := r.Transport().(*tracex.DNSTransportSaver)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatal("not the transport we expected")
|
t.Fatal("not the transport we expected")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package tracex
|
package tracex
|
||||||
|
|
||||||
|
//
|
||||||
|
// Code to generate the OONI archival data format from events
|
||||||
|
//
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -43,8 +47,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewTCPConnectList creates a new TCPConnectList
|
// NewTCPConnectList creates a new TCPConnectList
|
||||||
func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry {
|
func NewTCPConnectList(begin time.Time, events []Event) (out []TCPConnectEntry) {
|
||||||
var out []TCPConnectEntry
|
|
||||||
for _, wrapper := range events {
|
for _, wrapper := range events {
|
||||||
if _, ok := wrapper.(*EventConnectOperation); !ok {
|
if _, ok := wrapper.(*EventConnectOperation); !ok {
|
||||||
continue
|
continue
|
||||||
|
@ -60,13 +63,14 @@ func NewTCPConnectList(begin time.Time, events []Event) []TCPConnectEntry {
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Port: iport,
|
Port: iport,
|
||||||
Status: TCPConnectStatus{
|
Status: TCPConnectStatus{
|
||||||
|
Blocked: nil, // only used by Web Connectivity
|
||||||
Failure: NewFailure(event.Err),
|
Failure: NewFailure(event.Err),
|
||||||
Success: event.Err == nil,
|
Success: event.Err == nil,
|
||||||
},
|
},
|
||||||
T: event.Time.Sub(begin).Seconds(),
|
T: event.Time.Sub(begin).Seconds(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFailure creates a failure nullable string from the given error
|
// NewFailure creates a failure nullable string from the given error
|
||||||
|
@ -101,11 +105,9 @@ func NewFailedOperation(err error) *string {
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpAddHeaders(
|
// httpAddHeaders adds the headers inside source into destList and destMap.
|
||||||
source http.Header,
|
func httpAddHeaders(source http.Header, destList *[]HTTPHeader,
|
||||||
destList *[]HTTPHeader,
|
destMap *map[string]MaybeBinaryValue) {
|
||||||
destMap *map[string]MaybeBinaryValue,
|
|
||||||
) {
|
|
||||||
*destList = []HTTPHeader{}
|
*destList = []HTTPHeader{}
|
||||||
*destMap = make(map[string]model.ArchivalMaybeBinaryData)
|
*destMap = make(map[string]model.ArchivalMaybeBinaryData)
|
||||||
for key, values := range source {
|
for key, values := range source {
|
||||||
|
@ -122,32 +124,28 @@ func httpAddHeaders(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Sorting helps with unit testing (map keys are unordered)
|
||||||
sort.Slice(*destList, func(i, j int) bool {
|
sort.Slice(*destList, func(i, j int) bool {
|
||||||
return (*destList)[i].Key < (*destList)[j].Key
|
return (*destList)[i].Key < (*destList)[j].Key
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequestList returns the list for "requests"
|
// NewRequestList returns the list for "requests"
|
||||||
func NewRequestList(begin time.Time, events []Event) []RequestEntry {
|
func NewRequestList(begin time.Time, events []Event) (out []RequestEntry) {
|
||||||
// OONI wants the last request to appear first
|
// OONI wants the last request to appear first
|
||||||
var out []RequestEntry
|
|
||||||
tmp := newRequestList(begin, events)
|
tmp := newRequestList(begin, events)
|
||||||
for i := len(tmp) - 1; i >= 0; i-- {
|
for i := len(tmp) - 1; i >= 0; i-- {
|
||||||
out = append(out, tmp[i])
|
out = append(out, tmp[i])
|
||||||
}
|
}
|
||||||
return out
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestList(begin time.Time, events []Event) []RequestEntry {
|
func newRequestList(begin time.Time, events []Event) (out []RequestEntry) {
|
||||||
var (
|
|
||||||
out []RequestEntry
|
|
||||||
entry RequestEntry
|
|
||||||
)
|
|
||||||
for _, wrapper := range events {
|
for _, wrapper := range events {
|
||||||
ev := wrapper.Value()
|
ev := wrapper.Value()
|
||||||
switch wrapper.(type) {
|
switch wrapper.(type) {
|
||||||
case *EventHTTPTransactionDone:
|
case *EventHTTPTransactionDone:
|
||||||
entry = RequestEntry{}
|
entry := RequestEntry{}
|
||||||
entry.T = ev.Time.Sub(begin).Seconds()
|
entry.T = ev.Time.Sub(begin).Seconds()
|
||||||
httpAddHeaders(
|
httpAddHeaders(
|
||||||
ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
|
ev.HTTPRequestHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
|
||||||
|
@ -164,15 +162,14 @@ func newRequestList(begin time.Time, events []Event) []RequestEntry {
|
||||||
out = append(out, entry)
|
out = append(out, entry)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type dnsQueryType string
|
type dnsQueryType string
|
||||||
|
|
||||||
// NewDNSQueriesList returns a list of DNS queries.
|
// NewDNSQueriesList returns a list of DNS queries.
|
||||||
func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry {
|
func NewDNSQueriesList(begin time.Time, events []Event) (out []DNSQueryEntry) {
|
||||||
// TODO(bassosimone): add support for CNAME lookups.
|
// TODO(bassosimone): add support for CNAME lookups.
|
||||||
var out []DNSQueryEntry
|
|
||||||
for _, wrapper := range events {
|
for _, wrapper := range events {
|
||||||
if _, ok := wrapper.(*EventResolveDone); !ok {
|
if _, ok := wrapper.(*EventResolveDone); !ok {
|
||||||
continue
|
continue
|
||||||
|
@ -199,7 +196,7 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry {
|
||||||
out = append(out, entry)
|
out = append(out, entry)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (qtype dnsQueryType) ipOfType(addr string) bool {
|
func (qtype dnsQueryType) ipOfType(addr string) bool {
|
||||||
|
@ -214,6 +211,8 @@ func (qtype dnsQueryType) ipOfType(addr string) bool {
|
||||||
|
|
||||||
func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry {
|
func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry {
|
||||||
answer := DNSAnswerEntry{AnswerType: string(qtype)}
|
answer := DNSAnswerEntry{AnswerType: string(qtype)}
|
||||||
|
// Figuring out the ASN and the org here is not just a service to whoever
|
||||||
|
// is reading a JSON: Web Connectivity also depends on it!
|
||||||
asn, org, _ := geolocate.LookupASN(addr)
|
asn, org, _ := geolocate.LookupASN(addr)
|
||||||
answer.ASN = int64(asn)
|
answer.ASN = int64(asn)
|
||||||
answer.ASOrgName = org
|
answer.ASOrgName = org
|
||||||
|
@ -237,9 +236,8 @@ func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev *EventValue) DNSQue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetworkEventsList returns a list of DNS queries.
|
// NewNetworkEventsList returns a list of network events.
|
||||||
func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent {
|
func NewNetworkEventsList(begin time.Time, events []Event) (out []NetworkEvent) {
|
||||||
var out []NetworkEvent
|
|
||||||
for _, wrapper := range events {
|
for _, wrapper := range events {
|
||||||
ev := wrapper.Value()
|
ev := wrapper.Value()
|
||||||
switch wrapper.(type) {
|
switch wrapper.(type) {
|
||||||
|
@ -281,7 +279,7 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent {
|
||||||
NumBytes: int64(ev.NumBytes),
|
NumBytes: int64(ev.NumBytes),
|
||||||
T: ev.Time.Sub(begin).Seconds(),
|
T: ev.Time.Sub(begin).Seconds(),
|
||||||
})
|
})
|
||||||
default:
|
default: // For example, "tls_handshake_done" (used in data analysis!)
|
||||||
out = append(out, NetworkEvent{
|
out = append(out, NetworkEvent{
|
||||||
Failure: NewFailure(ev.Err),
|
Failure: NewFailure(ev.Err),
|
||||||
Operation: wrapper.Name(),
|
Operation: wrapper.Name(),
|
||||||
|
@ -289,15 +287,14 @@ func NewNetworkEventsList(begin time.Time, events []Event) []NetworkEvent {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return out
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTLSHandshakesList creates a new TLSHandshakesList
|
// NewTLSHandshakesList creates a new TLSHandshakesList
|
||||||
func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake {
|
func NewTLSHandshakesList(begin time.Time, events []Event) (out []TLSHandshake) {
|
||||||
var out []TLSHandshake
|
|
||||||
for _, wrapper := range events {
|
for _, wrapper := range events {
|
||||||
switch wrapper.(type) {
|
switch wrapper.(type) {
|
||||||
case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // ok
|
case *EventQUICHandshakeDone, *EventTLSHandshakeDone: // interested
|
||||||
default:
|
default:
|
||||||
continue // not interested
|
continue // not interested
|
||||||
}
|
}
|
||||||
|
@ -314,12 +311,12 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake {
|
||||||
TLSVersion: ev.TLSVersion,
|
TLSVersion: ev.TLSVersion,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
|
func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
|
||||||
for _, e := range in {
|
for _, entry := range in {
|
||||||
out = append(out, MaybeBinaryValue{Value: string(e.Raw)})
|
out = append(out, MaybeBinaryValue{Value: string(entry.Raw)})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,7 +14,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSQueryIPOfType(t *testing.T) {
|
func TestDNSQueryType(t *testing.T) {
|
||||||
|
t.Run("ipOfType", func(t *testing.T) {
|
||||||
type expectation struct {
|
type expectation struct {
|
||||||
qtype dnsQueryType
|
qtype dnsQueryType
|
||||||
ip string
|
ip string
|
||||||
|
@ -51,6 +51,7 @@ func TestDNSQueryIPOfType(t *testing.T) {
|
||||||
t.Fatalf("failure for %+v", exp)
|
t.Fatalf("failure for %+v", exp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTCPConnectList(t *testing.T) {
|
func TestNewTCPConnectList(t *testing.T) {
|
||||||
|
@ -74,7 +75,7 @@ func TestNewTCPConnectList(t *testing.T) {
|
||||||
name: "realistic run",
|
name: "realistic run",
|
||||||
args: args{
|
args: args{
|
||||||
begin: begin,
|
begin: begin,
|
||||||
events: []Event{&EventResolveDone{&EventValue{
|
events: []Event{&EventResolveDone{&EventValue{ // skipped because not relevant
|
||||||
Addresses: []string{"8.8.8.8", "8.8.4.4"},
|
Addresses: []string{"8.8.8.8", "8.8.4.4"},
|
||||||
Hostname: "dns.google.com",
|
Hostname: "dns.google.com",
|
||||||
Time: begin.Add(100 * time.Millisecond),
|
Time: begin.Add(100 * time.Millisecond),
|
||||||
|
@ -86,7 +87,7 @@ func TestNewTCPConnectList(t *testing.T) {
|
||||||
}}, &EventConnectOperation{&EventValue{
|
}}, &EventConnectOperation{&EventValue{
|
||||||
Address: "8.8.8.8:853",
|
Address: "8.8.8.8:853",
|
||||||
Duration: 55 * time.Millisecond,
|
Duration: 55 * time.Millisecond,
|
||||||
Proto: "udp",
|
Proto: "udp", // this one should be skipped because it's UDP
|
||||||
Time: begin.Add(130 * time.Millisecond),
|
Time: begin.Add(130 * time.Millisecond),
|
||||||
}}, &EventConnectOperation{&EventValue{
|
}}, &EventConnectOperation{&EventValue{
|
||||||
Address: "8.8.4.4:53",
|
Address: "8.8.4.4:53",
|
||||||
|
@ -115,8 +116,9 @@ func TestNewTCPConnectList(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := NewTCPConnectList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) {
|
got := NewTCPConnectList(tt.args.begin, tt.args.events)
|
||||||
t.Error(cmp.Diff(got, tt.want))
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -143,6 +145,7 @@ func TestNewRequestList(t *testing.T) {
|
||||||
name: "realistic run",
|
name: "realistic run",
|
||||||
args: args{
|
args: args{
|
||||||
begin: begin,
|
begin: begin,
|
||||||
|
// Two round trips so we can test the sorting expected by OONI
|
||||||
events: []Event{&EventHTTPTransactionDone{&EventValue{
|
events: []Event{&EventHTTPTransactionDone{&EventValue{
|
||||||
HTTPRequestHeaders: http.Header{
|
HTTPRequestHeaders: http.Header{
|
||||||
"User-Agent": []string{"miniooni/0.1.0-dev"},
|
"User-Agent": []string{"miniooni/0.1.0-dev"},
|
||||||
|
@ -286,8 +289,9 @@ func TestNewRequestList(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := NewRequestList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) {
|
got := NewRequestList(tt.args.begin, tt.args.events)
|
||||||
t.Error(cmp.Diff(tt.want, got))
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -320,12 +324,12 @@ func TestNewDNSQueriesList(t *testing.T) {
|
||||||
Hostname: "dns.google.com",
|
Hostname: "dns.google.com",
|
||||||
Proto: "dot",
|
Proto: "dot",
|
||||||
Time: begin.Add(100 * time.Millisecond),
|
Time: begin.Add(100 * time.Millisecond),
|
||||||
}}, &EventConnectOperation{&EventValue{
|
}}, &EventConnectOperation{&EventValue{ // skipped because not relevant
|
||||||
Address: "8.8.8.8:853",
|
Address: "8.8.8.8:853",
|
||||||
Duration: 30 * time.Millisecond,
|
Duration: 30 * time.Millisecond,
|
||||||
Proto: "tcp",
|
Proto: "tcp",
|
||||||
Time: begin.Add(130 * time.Millisecond),
|
Time: begin.Add(130 * time.Millisecond),
|
||||||
}}, &EventConnectOperation{&EventValue{
|
}}, &EventConnectOperation{&EventValue{ // skipped because not relevant
|
||||||
Address: "8.8.4.4:53",
|
Address: "8.8.4.4:53",
|
||||||
Duration: 50 * time.Millisecond,
|
Duration: 50 * time.Millisecond,
|
||||||
Err: io.EOF,
|
Err: io.EOF,
|
||||||
|
@ -452,6 +456,10 @@ func TestNewNetworkEventsList(t *testing.T) {
|
||||||
Err: websocket.ErrBadHandshake,
|
Err: websocket.ErrBadHandshake,
|
||||||
NumBytes: 4114,
|
NumBytes: 4114,
|
||||||
Time: begin.Add(14 * time.Millisecond),
|
Time: begin.Add(14 * time.Millisecond),
|
||||||
|
}}, &EventResolveStart{&EventValue{
|
||||||
|
// We expect this event to be logged event though it's not a typical I/O
|
||||||
|
// event (it seems these extra events are useful for debugging)
|
||||||
|
Time: begin.Add(15 * time.Millisecond),
|
||||||
}}},
|
}}},
|
||||||
},
|
},
|
||||||
want: []NetworkEvent{{
|
want: []NetworkEvent{{
|
||||||
|
@ -482,12 +490,16 @@ func TestNewNetworkEventsList(t *testing.T) {
|
||||||
NumBytes: 4114,
|
NumBytes: 4114,
|
||||||
Operation: netxlite.WriteToOperation,
|
Operation: netxlite.WriteToOperation,
|
||||||
T: 0.014,
|
T: 0.014,
|
||||||
|
}, {
|
||||||
|
Operation: "resolve_start",
|
||||||
|
T: 0.015,
|
||||||
}},
|
}},
|
||||||
}}
|
}}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := NewNetworkEventsList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) {
|
got := NewNetworkEventsList(tt.args.begin, tt.args.events)
|
||||||
t.Error(cmp.Diff(got, tt.want))
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -511,13 +523,14 @@ func TestNewTLSHandshakesList(t *testing.T) {
|
||||||
},
|
},
|
||||||
want: nil,
|
want: nil,
|
||||||
}, {
|
}, {
|
||||||
name: "realistic run",
|
name: "realistic run with TLS",
|
||||||
args: args{
|
args: args{
|
||||||
begin: begin,
|
begin: begin,
|
||||||
events: []Event{&EventTLSHandshakeDone{&EventValue{
|
events: []Event{&EventTLSHandshakeDone{&EventValue{
|
||||||
Address: "131.252.210.176:443",
|
Address: "131.252.210.176:443",
|
||||||
Err: io.EOF,
|
Err: io.EOF,
|
||||||
NoTLSVerify: false,
|
NoTLSVerify: false,
|
||||||
|
Proto: "tcp",
|
||||||
TLSCipherSuite: "SUITE",
|
TLSCipherSuite: "SUITE",
|
||||||
TLSNegotiatedProto: "h2",
|
TLSNegotiatedProto: "h2",
|
||||||
TLSPeerCerts: []*x509.Certificate{{
|
TLSPeerCerts: []*x509.Certificate{{
|
||||||
|
@ -545,11 +558,57 @@ func TestNewTLSHandshakesList(t *testing.T) {
|
||||||
T: 0.055,
|
T: 0.055,
|
||||||
TLSVersion: "TLSv1.3",
|
TLSVersion: "TLSv1.3",
|
||||||
}},
|
}},
|
||||||
|
}, {
|
||||||
|
name: "realistic run with QUIC",
|
||||||
|
args: args{
|
||||||
|
begin: begin,
|
||||||
|
events: []Event{&EventQUICHandshakeDone{&EventValue{
|
||||||
|
Address: "131.252.210.176:443",
|
||||||
|
Err: io.EOF,
|
||||||
|
NoTLSVerify: false,
|
||||||
|
Proto: "quic",
|
||||||
|
TLSCipherSuite: "SUITE",
|
||||||
|
TLSNegotiatedProto: "h3",
|
||||||
|
TLSPeerCerts: []*x509.Certificate{{
|
||||||
|
Raw: []byte("deadbeef"),
|
||||||
|
}, {
|
||||||
|
Raw: []byte("abad1dea"),
|
||||||
|
}},
|
||||||
|
TLSServerName: "x.org",
|
||||||
|
TLSVersion: "TLSv1.3",
|
||||||
|
Time: begin.Add(55 * time.Millisecond),
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
want: []TLSHandshake{{
|
||||||
|
Address: "131.252.210.176:443",
|
||||||
|
CipherSuite: "SUITE",
|
||||||
|
Failure: NewFailure(io.EOF),
|
||||||
|
NegotiatedProtocol: "h3",
|
||||||
|
NoTLSVerify: false,
|
||||||
|
PeerCertificates: []MaybeBinaryValue{{
|
||||||
|
Value: "deadbeef",
|
||||||
|
}, {
|
||||||
|
Value: "abad1dea",
|
||||||
|
}},
|
||||||
|
ServerName: "x.org",
|
||||||
|
T: 0.055,
|
||||||
|
TLSVersion: "TLSv1.3",
|
||||||
|
}},
|
||||||
|
}, {
|
||||||
|
name: "realistic run with no suitable events",
|
||||||
|
args: args{
|
||||||
|
begin: begin,
|
||||||
|
events: []Event{&EventResolveStart{&EventValue{
|
||||||
|
Time: begin.Add(55 * time.Millisecond),
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
}}
|
}}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := NewTLSHandshakesList(tt.args.begin, tt.args.events); !reflect.DeepEqual(got, tt.want) {
|
got := NewTLSHandshakesList(tt.args.begin, tt.args.events)
|
||||||
t.Error(cmp.Diff(got, tt.want))
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SaverDialer saves events occurring during the dial
|
// DialerSaver saves events occurring during the dial
|
||||||
type SaverDialer struct {
|
type DialerSaver struct {
|
||||||
// Dialer is the underlying dialer,
|
// Dialer is the underlying dialer,
|
||||||
Dialer model.Dialer
|
Dialer model.Dialer
|
||||||
|
|
||||||
|
@ -28,26 +28,26 @@ func (s *Saver) NewConnectObserver() model.DialerWrapper {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil // valid DialerWrapper according to netxlite's docs
|
return nil // valid DialerWrapper according to netxlite's docs
|
||||||
}
|
}
|
||||||
return &saverDialerWrapper{
|
return &dialerConnectObserver{
|
||||||
saver: s,
|
saver: s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type saverDialerWrapper struct {
|
type dialerConnectObserver struct {
|
||||||
saver *Saver
|
saver *Saver
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DialerWrapper = &saverDialerWrapper{}
|
var _ model.DialerWrapper = &dialerConnectObserver{}
|
||||||
|
|
||||||
func (w *saverDialerWrapper) WrapDialer(d model.Dialer) model.Dialer {
|
func (w *dialerConnectObserver) WrapDialer(d model.Dialer) model.Dialer {
|
||||||
return &SaverDialer{
|
return &DialerSaver{
|
||||||
Dialer: d,
|
Dialer: d,
|
||||||
Saver: w.saver,
|
Saver: w.saver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext implements Dialer.DialContext
|
// DialContext implements Dialer.DialContext
|
||||||
func (d *SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *DialerSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
|
@ -61,13 +61,13 @@ func (d *SaverDialer) DialContext(ctx context.Context, network, address string)
|
||||||
return conn, err
|
return conn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SaverDialer) CloseIdleConnections() {
|
func (d *DialerSaver) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaverConnDialer wraps the returned connection such that we
|
// DialerConnSaver wraps the returned connection such that we
|
||||||
// collect all the read/write events that occur.
|
// collect all the read/write events that occur.
|
||||||
type SaverConnDialer struct {
|
type DialerConnSaver struct {
|
||||||
// Dialer is the underlying dialer
|
// Dialer is the underlying dialer
|
||||||
Dialer model.Dialer
|
Dialer model.Dialer
|
||||||
|
|
||||||
|
@ -82,70 +82,78 @@ func (s *Saver) NewReadWriteObserver() model.DialerWrapper {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil // valid DialerWrapper according to netxlite's docs
|
return nil // valid DialerWrapper according to netxlite's docs
|
||||||
}
|
}
|
||||||
return &saverReadWriteWrapper{
|
return &dialerReadWriteObserver{
|
||||||
saver: s,
|
saver: s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type saverReadWriteWrapper struct {
|
type dialerReadWriteObserver struct {
|
||||||
saver *Saver
|
saver *Saver
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DialerWrapper = &saverReadWriteWrapper{}
|
var _ model.DialerWrapper = &dialerReadWriteObserver{}
|
||||||
|
|
||||||
func (w *saverReadWriteWrapper) WrapDialer(d model.Dialer) model.Dialer {
|
func (w *dialerReadWriteObserver) WrapDialer(d model.Dialer) model.Dialer {
|
||||||
return &SaverConnDialer{
|
return &DialerConnSaver{
|
||||||
Dialer: d,
|
Dialer: d,
|
||||||
Saver: w.saver,
|
Saver: w.saver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext implements Dialer.DialContext
|
// DialContext implements Dialer.DialContext
|
||||||
func (d *SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
func (d *DialerConnSaver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &saverConn{saver: d.Saver, Conn: conn}, nil
|
return &dialerConnWrapper{saver: d.Saver, Conn: conn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SaverConnDialer) CloseIdleConnections() {
|
func (d *DialerConnSaver) CloseIdleConnections() {
|
||||||
d.Dialer.CloseIdleConnections()
|
d.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
type saverConn struct {
|
type dialerConnWrapper struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
saver *Saver
|
saver *Saver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *saverConn) Read(p []byte) (int, error) {
|
func (c *dialerConnWrapper) Read(p []byte) (int, error) {
|
||||||
|
proto := c.Conn.RemoteAddr().Network()
|
||||||
|
remoteAddr := c.Conn.RemoteAddr().String()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
count, err := c.Conn.Read(p)
|
count, err := c.Conn.Read(p)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
c.saver.Write(&EventReadOperation{&EventValue{
|
c.saver.Write(&EventReadOperation{&EventValue{
|
||||||
|
Address: remoteAddr,
|
||||||
Data: p[:count],
|
Data: p[:count],
|
||||||
Duration: stop.Sub(start),
|
Duration: stop.Sub(start),
|
||||||
Err: err,
|
Err: err,
|
||||||
NumBytes: count,
|
NumBytes: count,
|
||||||
|
Proto: proto,
|
||||||
Time: stop,
|
Time: stop,
|
||||||
}})
|
}})
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *saverConn) Write(p []byte) (int, error) {
|
func (c *dialerConnWrapper) Write(p []byte) (int, error) {
|
||||||
|
proto := c.Conn.RemoteAddr().Network()
|
||||||
|
remoteAddr := c.Conn.RemoteAddr().String()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
count, err := c.Conn.Write(p)
|
count, err := c.Conn.Write(p)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
c.saver.Write(&EventWriteOperation{&EventValue{
|
c.saver.Write(&EventWriteOperation{&EventValue{
|
||||||
|
Address: remoteAddr,
|
||||||
Data: p[:count],
|
Data: p[:count],
|
||||||
Duration: stop.Sub(start),
|
Duration: stop.Sub(start),
|
||||||
Err: err,
|
Err: err,
|
||||||
NumBytes: count,
|
NumBytes: count,
|
||||||
|
Proto: proto,
|
||||||
Time: stop,
|
Time: stop,
|
||||||
}})
|
}})
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.Dialer = &SaverDialer{}
|
var _ model.Dialer = &DialerSaver{}
|
||||||
var _ model.Dialer = &SaverConnDialer{}
|
var _ model.Dialer = &DialerConnSaver{}
|
||||||
var _ net.Conn = &saverConn{}
|
var _ net.Conn = &dialerConnWrapper{}
|
||||||
|
|
|
@ -12,10 +12,27 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSaverDialerFailure(t *testing.T) {
|
func TestDialerConnectObserver(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
obs := &dialerConnectObserver{
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
dialer := &mocks.Dialer{}
|
||||||
|
out := obs.WrapDialer(dialer)
|
||||||
|
dialSaver := out.(*DialerSaver)
|
||||||
|
if dialSaver.Dialer != dialer {
|
||||||
|
t.Fatal("invalid dialer")
|
||||||
|
}
|
||||||
|
if dialSaver.Saver != saver {
|
||||||
|
t.Fatal("invalid saver")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialerSaver(t *testing.T) {
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
dlr := &SaverDialer{
|
dlr := &DialerSaver{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
@ -52,12 +69,48 @@ func TestSaverDialerFailure(t *testing.T) {
|
||||||
if !ev[0].Value().Time.Before(time.Now()) {
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
t.Fatal("unexpected Time")
|
t.Fatal("unexpected Time")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
child := &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := &DialerSaver{
|
||||||
|
Dialer: child,
|
||||||
|
Saver: &Saver{},
|
||||||
|
}
|
||||||
|
dialer.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaverConnDialerFailure(t *testing.T) {
|
func TestDialerReadWriteObserver(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
obs := &dialerReadWriteObserver{
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
dialer := &mocks.Dialer{}
|
||||||
|
out := obs.WrapDialer(dialer)
|
||||||
|
dialSaver := out.(*DialerConnSaver)
|
||||||
|
if dialSaver.Dialer != dialer {
|
||||||
|
t.Fatal("invalid dialer")
|
||||||
|
}
|
||||||
|
if dialSaver.Saver != saver {
|
||||||
|
t.Fatal("invalid saver")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialerConnSaver(t *testing.T) {
|
||||||
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
dlr := &SaverConnDialer{
|
dlr := &DialerConnSaver{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
|
@ -72,28 +125,16 @@ func TestSaverConnDialerFailure(t *testing.T) {
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
t.Fatal("expected nil conn here")
|
t.Fatal("expected nil conn here")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestSaverConnDialerSuccess(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
origConn := &mocks.Conn{}
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
dlr := &SaverConnDialer{
|
dlr := &DialerConnSaver{
|
||||||
Dialer: &SaverDialer{
|
Dialer: &DialerSaver{
|
||||||
Dialer: &mocks.Dialer{
|
Dialer: &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return origConn, nil
|
||||||
MockRead: func(b []byte) (int, error) {
|
|
||||||
return 0, io.EOF
|
|
||||||
},
|
|
||||||
MockWrite: func(b []byte) (int, error) {
|
|
||||||
return 0, io.EOF
|
|
||||||
},
|
|
||||||
MockClose: func() error {
|
|
||||||
return io.EOF
|
|
||||||
},
|
|
||||||
MockLocalAddr: func() net.Addr {
|
|
||||||
return &net.TCPAddr{Port: 12345}
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
|
@ -104,35 +145,135 @@ func TestSaverConnDialerSuccess(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
conn.Read(nil)
|
cw := conn.(*dialerConnWrapper)
|
||||||
conn.Write(nil)
|
if cw.Conn != origConn {
|
||||||
conn.Close()
|
t.Fatal("unexpected conn")
|
||||||
events := saver.Read()
|
|
||||||
if len(events) != 3 {
|
|
||||||
t.Fatal("unexpected number of events saved", len(events))
|
|
||||||
}
|
}
|
||||||
if events[0].Name() != "connect" {
|
})
|
||||||
t.Fatal("expected a connect event")
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
child := &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
}
|
}
|
||||||
saverCheckConnectEvent(t, &events[0])
|
dialer := &DialerConnSaver{
|
||||||
if events[1].Name() != "read" {
|
Dialer: child,
|
||||||
t.Fatal("expected a read event")
|
Saver: &Saver{},
|
||||||
}
|
}
|
||||||
saverCheckReadEvent(t, &events[1])
|
dialer.CloseIdleConnections()
|
||||||
if events[2].Name() != "write" {
|
if !called {
|
||||||
t.Fatal("expected a write event")
|
t.Fatal("not called")
|
||||||
}
|
}
|
||||||
saverCheckWriteEvent(t, &events[2])
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func saverCheckConnectEvent(t *testing.T, ev *Event) {
|
func TestDialerConnWrapper(t *testing.T) {
|
||||||
// TODO(bassosimone): implement
|
t.Run("Read", func(t *testing.T) {
|
||||||
}
|
baseConn := &mocks.Conn{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
|
},
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "www.google.com:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
saver := &Saver{}
|
||||||
|
conn := &dialerConnWrapper{
|
||||||
|
Conn: baseConn,
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
data := make([]byte, 155)
|
||||||
|
count, err := conn.Read(data)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("unexpected count")
|
||||||
|
}
|
||||||
|
ev := saver.Read()
|
||||||
|
if len(ev) != 1 {
|
||||||
|
t.Fatal("expected a single event here")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Address != "www.google.com:443" {
|
||||||
|
t.Fatal("unexpected Address")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Duration <= 0 {
|
||||||
|
t.Fatal("unexpected Duration")
|
||||||
|
}
|
||||||
|
if !errors.Is(ev[0].Value().Err, io.EOF) {
|
||||||
|
t.Fatal("unexpected Err")
|
||||||
|
}
|
||||||
|
if ev[0].Name() != netxlite.ReadOperation {
|
||||||
|
t.Fatal("unexpected Name")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Proto != "tcp" {
|
||||||
|
t.Fatal("unexpected Proto")
|
||||||
|
}
|
||||||
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
|
t.Fatal("unexpected Time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
func saverCheckReadEvent(t *testing.T, ev *Event) {
|
t.Run("Write", func(t *testing.T) {
|
||||||
// TODO(bassosimone): implement
|
baseConn := &mocks.Conn{
|
||||||
}
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return 0, io.EOF
|
||||||
func saverCheckWriteEvent(t *testing.T, ev *Event) {
|
},
|
||||||
// TODO(bassosimone): implement
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "www.google.com:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
saver := &Saver{}
|
||||||
|
conn := &dialerConnWrapper{
|
||||||
|
Conn: baseConn,
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
data := make([]byte, 155)
|
||||||
|
count, err := conn.Write(data)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("unexpected count")
|
||||||
|
}
|
||||||
|
ev := saver.Read()
|
||||||
|
if len(ev) != 1 {
|
||||||
|
t.Fatal("expected a single event here")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Address != "www.google.com:443" {
|
||||||
|
t.Fatal("unexpected Address")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Duration <= 0 {
|
||||||
|
t.Fatal("unexpected Duration")
|
||||||
|
}
|
||||||
|
if !errors.Is(ev[0].Value().Err, io.EOF) {
|
||||||
|
t.Fatal("unexpected Err")
|
||||||
|
}
|
||||||
|
if ev[0].Name() != netxlite.WriteOperation {
|
||||||
|
t.Fatal("unexpected Name")
|
||||||
|
}
|
||||||
|
if ev[0].Value().Proto != "tcp" {
|
||||||
|
t.Fatal("unexpected Proto")
|
||||||
|
}
|
||||||
|
if !ev[0].Value().Time.Before(time.Now()) {
|
||||||
|
t.Fatal("unexpected Time")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package tracex
|
package tracex
|
||||||
|
|
||||||
|
//
|
||||||
|
// All the possible events
|
||||||
|
//
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
|
@ -27,11 +27,17 @@ func httpCloneRequestHeaders(req *http.Request) http.Header {
|
||||||
return header
|
return header
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaverTransactionHTTPTransport is a RoundTripper that saves
|
// HTTPTransportSaver is a RoundTripper that saves
|
||||||
// events related to the HTTP transaction
|
// events related to the HTTP transaction
|
||||||
type SaverTransactionHTTPTransport struct {
|
type HTTPTransportSaver struct {
|
||||||
model.HTTPTransport
|
// HTTPTransport is the MANDATORY underlying HTTP transport.
|
||||||
|
HTTPTransport model.HTTPTransport
|
||||||
|
|
||||||
|
// Saver is the MANDATORY saver to use.
|
||||||
Saver *Saver
|
Saver *Saver
|
||||||
|
|
||||||
|
// SnapshotSize is the OPTIONAL maximum body snapshot size (if not set, we'll
|
||||||
|
// use 1<<17, which we've been using since the ooni/netx days)
|
||||||
SnapshotSize int64
|
SnapshotSize int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,7 +46,11 @@ type SaverTransactionHTTPTransport struct {
|
||||||
//
|
//
|
||||||
// The maxBodySnapshotSize argument controls the maximum size of the
|
// The maxBodySnapshotSize argument controls the maximum size of the
|
||||||
// body snapshot that we collect along with the HTTP round trip.
|
// body snapshot that we collect along with the HTTP round trip.
|
||||||
func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (txp *HTTPTransportSaver) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
|
||||||
|
// TODO(bassosimone): we're currently using the started time for
|
||||||
|
// the transaction done event, which contrasts with what we do for
|
||||||
|
// every other event. What does the spec say?
|
||||||
|
|
||||||
started := time.Now()
|
started := time.Now()
|
||||||
txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{
|
txp.Saver.Write(&EventHTTPTransactionStart{&EventValue{
|
||||||
|
@ -92,7 +102,15 @@ func (txp *SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Re
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *SaverTransactionHTTPTransport) snapshotSize() int64 {
|
func (txp *HTTPTransportSaver) CloseIdleConnections() {
|
||||||
|
txp.HTTPTransport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (txp *HTTPTransportSaver) Network() string {
|
||||||
|
return txp.HTTPTransport.Network()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (txp *HTTPTransportSaver) snapshotSize() int64 {
|
||||||
if txp.SnapshotSize > 0 {
|
if txp.SnapshotSize > 0 {
|
||||||
return txp.SnapshotSize
|
return txp.SnapshotSize
|
||||||
}
|
}
|
||||||
|
@ -104,4 +122,4 @@ type httpReadableAgainBody struct {
|
||||||
io.Closer
|
io.Closer
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.HTTPTransport = &SaverTransactionHTTPTransport{}
|
var _ model.HTTPTransport = &HTTPTransportSaver{}
|
||||||
|
|
|
@ -16,8 +16,42 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSaverTransactionHTTPTransport(t *testing.T) {
|
func TestHTTPTransportSaver(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
child := &mocks.HTTPTransport{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := &HTTPTransportSaver{
|
||||||
|
HTTPTransport: child,
|
||||||
|
Saver: &Saver{},
|
||||||
|
}
|
||||||
|
dialer.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Network", func(t *testing.T) {
|
||||||
|
expected := "antani"
|
||||||
|
child := &mocks.HTTPTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := &HTTPTransportSaver{
|
||||||
|
HTTPTransport: child,
|
||||||
|
Saver: &Saver{},
|
||||||
|
}
|
||||||
|
if dialer.Network() != expected {
|
||||||
|
t.Fatal("unexpected Network")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) {
|
startServer := func(t *testing.T, action filtering.HTTPAction) (net.Listener, *url.URL) {
|
||||||
server := &filtering.HTTPProxy{
|
server := &filtering.HTTPProxy{
|
||||||
OnIncomingHost: func(host string) filtering.HTTPAction {
|
OnIncomingHost: func(host string) filtering.HTTPAction {
|
||||||
|
@ -38,7 +72,7 @@ func TestSaverTransactionHTTPTransport(t *testing.T) {
|
||||||
|
|
||||||
measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) {
|
measureHTTP := func(t *testing.T, URL *url.URL) (*http.Response, *Saver, error) {
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
txp := &SaverTransactionHTTPTransport{
|
txp := &HTTPTransportSaver{
|
||||||
HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger),
|
HTTPTransport: netxlite.NewHTTPTransportStdlib(model.DiscardLogger),
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
}
|
}
|
||||||
|
@ -178,17 +212,17 @@ func TestSaverTransactionHTTPTransport(t *testing.T) {
|
||||||
|
|
||||||
// Sometimes useful for testing
|
// Sometimes useful for testing
|
||||||
/*
|
/*
|
||||||
dumplog := func(t *testing.T, ev Event) {
|
dump := func(t *testing.T, ev Event) {
|
||||||
data, _ := json.MarshalIndent(ev.Value(), " ", " ")
|
data, _ := json.MarshalIndent(ev.Value(), " ", " ")
|
||||||
t.Log(string(data))
|
t.Log(string(data))
|
||||||
t.FailNow()
|
t.Fail()
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
t.Run("on error reading the response body", func(t *testing.T) {
|
t.Run("on error reading the response body", func(t *testing.T) {
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
txp := SaverTransactionHTTPTransport{
|
txp := HTTPTransportSaver{
|
||||||
HTTPTransport: &mocks.HTTPTransport{
|
HTTPTransport: &mocks.HTTPTransport{
|
||||||
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
MockRoundTrip: func(req *http.Request) (*http.Response, error) {
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
|
@ -238,6 +272,7 @@ func TestSaverTransactionHTTPTransport(t *testing.T) {
|
||||||
t.Fatal("invalid error")
|
t.Fatal("invalid error")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPCloneRequestHeaders(t *testing.T) {
|
func TestHTTPCloneRequestHeaders(t *testing.T) {
|
||||||
|
|
|
@ -15,8 +15,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QUICHandshakeSaver saves events occurring during the QUIC handshake.
|
// QUICDialerSaver saves events occurring during the QUIC handshake.
|
||||||
type QUICHandshakeSaver struct {
|
type QUICDialerSaver struct {
|
||||||
// QUICDialer is the wrapped dialer
|
// QUICDialer is the wrapped dialer
|
||||||
QUICDialer model.QUICDialer
|
QUICDialer model.QUICDialer
|
||||||
|
|
||||||
|
@ -33,14 +33,14 @@ func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return qd
|
return qd
|
||||||
}
|
}
|
||||||
return &QUICHandshakeSaver{
|
return &QUICDialerSaver{
|
||||||
QUICDialer: qd,
|
QUICDialer: qd,
|
||||||
Saver: s,
|
Saver: s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContext implements QUICDialer.DialContext
|
// DialContext implements QUICDialer.DialContext
|
||||||
func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
|
func (h *QUICDialerSaver) DialContext(ctx context.Context, network string,
|
||||||
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
// TODO(bassosimone): in the future we probably want to also save
|
// TODO(bassosimone): in the future we probably want to also save
|
||||||
|
@ -58,9 +58,11 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(bassosimone): here we should save the peer certs
|
// TODO(bassosimone): here we should save the peer certs
|
||||||
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
|
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
|
||||||
|
Address: host,
|
||||||
Duration: stop.Sub(start),
|
Duration: stop.Sub(start),
|
||||||
Err: err,
|
Err: err,
|
||||||
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
||||||
|
Proto: network,
|
||||||
TLSNextProtos: tlsCfg.NextProtos,
|
TLSNextProtos: tlsCfg.NextProtos,
|
||||||
TLSServerName: tlsCfg.ServerName,
|
TLSServerName: tlsCfg.ServerName,
|
||||||
Time: stop,
|
Time: stop,
|
||||||
|
@ -69,8 +71,10 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
|
||||||
}
|
}
|
||||||
state := quicConnectionState(sess)
|
state := quicConnectionState(sess)
|
||||||
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
|
h.Saver.Write(&EventQUICHandshakeDone{&EventValue{
|
||||||
|
Address: host,
|
||||||
Duration: stop.Sub(start),
|
Duration: stop.Sub(start),
|
||||||
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
||||||
|
Proto: network,
|
||||||
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
|
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
|
||||||
TLSNegotiatedProto: state.NegotiatedProtocol,
|
TLSNegotiatedProto: state.NegotiatedProtocol,
|
||||||
TLSNextProtos: tlsCfg.NextProtos,
|
TLSNextProtos: tlsCfg.NextProtos,
|
||||||
|
@ -82,7 +86,7 @@ func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
|
||||||
return sess, nil
|
return sess, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *QUICHandshakeSaver) CloseIdleConnections() {
|
func (h *QUICDialerSaver) CloseIdleConnections() {
|
||||||
h.QUICDialer.CloseIdleConnections()
|
h.QUICDialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,15 +125,15 @@ func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, erro
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pconn = &udpLikeConnSaver{
|
pconn = &quicPacketConnWrapper{
|
||||||
UDPLikeConn: pconn,
|
UDPLikeConn: pconn,
|
||||||
saver: qls.Saver,
|
saver: qls.Saver,
|
||||||
}
|
}
|
||||||
return pconn, nil
|
return pconn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// udpLikeConnSaver saves I/O events
|
// quicPacketConnWrapper saves I/O events
|
||||||
type udpLikeConnSaver struct {
|
type quicPacketConnWrapper struct {
|
||||||
// UDPLikeConn is the wrapped underlying conn
|
// UDPLikeConn is the wrapped underlying conn
|
||||||
model.UDPLikeConn
|
model.UDPLikeConn
|
||||||
|
|
||||||
|
@ -137,7 +141,7 @@ type udpLikeConnSaver struct {
|
||||||
saver *Saver
|
saver *Saver
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) {
|
func (c *quicPacketConnWrapper) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
|
@ -152,7 +156,7 @@ func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) {
|
func (c *quicPacketConnWrapper) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
|
@ -171,13 +175,13 @@ func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||||
return n, addr, err
|
return n, addr, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpLikeConnSaver) safeAddrString(addr net.Addr) (out string) {
|
func (c *quicPacketConnWrapper) safeAddrString(addr net.Addr) (out string) {
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
out = addr.String()
|
out = addr.String()
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.QUICDialer = &QUICHandshakeSaver{}
|
var _ model.QUICDialer = &QUICDialerSaver{}
|
||||||
var _ model.QUICListener = &QUICListenerSaver{}
|
var _ model.QUICListener = &QUICListenerSaver{}
|
||||||
var _ model.UDPLikeConn = &udpLikeConnSaver{}
|
var _ model.UDPLikeConn = &quicPacketConnWrapper{}
|
||||||
|
|
|
@ -3,122 +3,180 @@ package tracex
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/quictesting"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type MockDialer struct {
|
func TestQUICDialerSaver(t *testing.T) {
|
||||||
Dialer model.QUICDialer
|
t.Run("DialContext", func(t *testing.T) {
|
||||||
Sess quic.EarlyConnection
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MockDialer) DialContext(ctx context.Context, network, host string,
|
checkStartEventFields := func(t *testing.T, value *EventValue) {
|
||||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
if value.Address != "8.8.8.8:443" {
|
||||||
if d.Dialer != nil {
|
t.Fatal("invalid Address")
|
||||||
return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
}
|
||||||
|
if !value.NoTLSVerify {
|
||||||
|
t.Fatal("expected NoTLSVerify to be true")
|
||||||
|
}
|
||||||
|
if value.Proto != "udp" {
|
||||||
|
t.Fatal("wrong protocol")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(value.TLSNextProtos, []string{"h3"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
if value.TLSServerName != "dns.google" {
|
||||||
|
t.Fatal("invalid TLSServerName")
|
||||||
|
}
|
||||||
|
if value.Time.IsZero() {
|
||||||
|
t.Fatal("expected non zero time")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return d.Sess, d.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeSaverSuccess(t *testing.T) {
|
checkStartedEvent := func(t *testing.T, ev Event) {
|
||||||
nextprotos := []string{"h3"}
|
if _, good := ev.(*EventQUICHandshakeStart); !good {
|
||||||
servername := quictesting.Domain
|
t.Fatal("invalid event type")
|
||||||
tlsConf := &tls.Config{
|
|
||||||
NextProtos: nextprotos,
|
|
||||||
ServerName: servername,
|
|
||||||
}
|
}
|
||||||
|
value := ev.Value()
|
||||||
|
checkStartEventFields(t, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) {
|
||||||
|
if value.Duration <= 0 {
|
||||||
|
t.Fatal("expected non-zero duration")
|
||||||
|
}
|
||||||
|
if value.Err != nil {
|
||||||
|
t.Fatal("expected no error here")
|
||||||
|
}
|
||||||
|
if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" {
|
||||||
|
t.Fatal("invalid cipher suite")
|
||||||
|
}
|
||||||
|
if value.TLSNegotiatedProto != "h3" {
|
||||||
|
t.Fatal("invalid negotiated protocol")
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
if value.TLSVersion != "TLSv1.3" {
|
||||||
|
t.Fatal("invalid TLS version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) {
|
||||||
|
if _, good := ev.(*EventQUICHandshakeDone); !good {
|
||||||
|
t.Fatal("invalid event type")
|
||||||
|
}
|
||||||
|
value := ev.Value()
|
||||||
|
checkStartEventFields(t, value)
|
||||||
|
fun(t, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{
|
returnedConn := &mocks.QUICEarlyConnection{
|
||||||
QUICListener: &netxlite.QUICListenerStdlib{},
|
MockConnectionState: func() quic.ConnectionState {
|
||||||
|
cs := quic.ConnectionState{}
|
||||||
|
cs.TLS.ConnectionState.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
|
||||||
|
cs.TLS.NegotiatedProtocol = "h3"
|
||||||
|
cs.TLS.PeerCertificates = []*x509.Certificate{}
|
||||||
|
cs.TLS.Version = tls.VersionTLS13
|
||||||
|
return cs
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := saver.WrapQUICDialer(&mocks.QUICDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
return returnedConn, nil
|
||||||
|
},
|
||||||
})
|
})
|
||||||
sess, err := dlr.DialContext(context.Background(), "udp",
|
ctx := context.Background()
|
||||||
quictesting.Endpoint("443"), tlsConf, &quic.Config{})
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h3"},
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
quicConfig := &quic.Config{}
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if sess == nil {
|
if conn == nil {
|
||||||
t.Fatal("unexpected nil sess")
|
t.Fatal("expected non-nil conn")
|
||||||
}
|
}
|
||||||
ev := saver.Read()
|
events := saver.Read()
|
||||||
if len(ev) != 2 {
|
if len(events) != 2 {
|
||||||
t.Fatal("unexpected number of events")
|
t.Fatal("expected two events")
|
||||||
}
|
}
|
||||||
if ev[0].Name() != "quic_handshake_start" {
|
checkStartedEvent(t, events[0])
|
||||||
t.Fatal("unexpected Name")
|
checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess)
|
||||||
}
|
|
||||||
if ev[0].Value().TLSServerName != quictesting.Domain {
|
|
||||||
t.Fatal("unexpected TLSServerName")
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) {
|
|
||||||
t.Fatal("unexpected TLSNextProtos")
|
|
||||||
}
|
|
||||||
if ev[0].Value().Time.After(time.Now()) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Err != nil {
|
|
||||||
t.Fatal("unexpected Err", ev[1].Value().Err)
|
|
||||||
}
|
|
||||||
if ev[1].Name() != "quic_handshake_done" {
|
|
||||||
t.Fatal("unexpected Name")
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) {
|
|
||||||
t.Fatal("unexpected TLSNextProtos")
|
|
||||||
}
|
|
||||||
if ev[1].Value().TLSServerName != quictesting.Domain {
|
|
||||||
t.Fatal("unexpected TLSServerName")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Time.Before(ev[0].Value().Time) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandshakeSaverHostNameError(t *testing.T) {
|
|
||||||
nextprotos := []string{"h3"}
|
|
||||||
servername := "example.com"
|
|
||||||
tlsConf := &tls.Config{
|
|
||||||
NextProtos: nextprotos,
|
|
||||||
ServerName: servername,
|
|
||||||
}
|
|
||||||
saver := &Saver{}
|
|
||||||
dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{
|
|
||||||
QUICListener: &netxlite.QUICListenerStdlib{},
|
|
||||||
})
|
})
|
||||||
sess, err := dlr.DialContext(context.Background(), "udp",
|
|
||||||
quictesting.Endpoint("443"), tlsConf, &quic.Config{})
|
checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) {
|
||||||
if err == nil {
|
if value.Duration <= 0 {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected non-zero duration")
|
||||||
}
|
}
|
||||||
if sess != nil {
|
if value.Err == nil {
|
||||||
t.Fatal("expected nil sess here")
|
t.Fatal("expected non-nil error here")
|
||||||
}
|
|
||||||
for _, ev := range saver.Read() {
|
|
||||||
if ev.Name() != "quic_handshake_done" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ev.Value().NoTLSVerify == true {
|
|
||||||
t.Fatal("expected NoTLSVerify to be false")
|
|
||||||
}
|
|
||||||
if !strings.HasSuffix(ev.Value().Err.Error(), "tls: handshake failure") {
|
|
||||||
t.Fatal("unexpected error", ev.Value().Err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
saver := &Saver{}
|
||||||
|
dialer := saver.WrapQUICDialer(&mocks.QUICDialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string,
|
||||||
|
tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h3"},
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
quicConfig := &quic.Config{}
|
||||||
|
conn, err := dialer.DialContext(ctx, "udp", "8.8.8.8:443", tlsConfig, quicConfig)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
events := saver.Read()
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatal("expected two events")
|
||||||
|
}
|
||||||
|
checkStartedEvent(t, events[0])
|
||||||
|
checkDoneEvent(t, events[1], checkDoneEventFieldsFailure)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
child := &mocks.QUICDialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dialer := &QUICDialerSaver{
|
||||||
|
QUICDialer: child,
|
||||||
|
Saver: &Saver{},
|
||||||
|
}
|
||||||
|
dialer.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQUICListenerSaverCannotListen(t *testing.T) {
|
func TestQUICListenerSaver(t *testing.T) {
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
qls := saver.WrapQUICListener(&mocks.QUICListener{
|
qls := saver.WrapQUICListener(&mocks.QUICListener{
|
||||||
|
@ -137,48 +195,254 @@ func TestQUICListenerSaverCannotListen(t *testing.T) {
|
||||||
if pconn != nil {
|
if pconn != nil {
|
||||||
t.Fatal("expected nil pconn here")
|
t.Fatal("expected nil pconn here")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestSystemDialerSuccessWithReadWrite(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
// This is the most common use case for collecting reads, writes
|
|
||||||
tlsConf := &tls.Config{
|
|
||||||
NextProtos: []string{"h3"},
|
|
||||||
ServerName: quictesting.Domain,
|
|
||||||
}
|
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
systemdialer := &netxlite.QUICDialerQUICGo{
|
returnedConn := &mocks.UDPLikeConn{}
|
||||||
QUICListener: saver.WrapQUICListener(&netxlite.QUICListenerStdlib{}),
|
qls := saver.WrapQUICListener(&mocks.QUICListener{
|
||||||
}
|
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
|
||||||
_, err := systemdialer.DialContext(context.Background(), "udp",
|
return returnedConn, nil
|
||||||
quictesting.Endpoint("443"), tlsConf, &quic.Config{})
|
},
|
||||||
|
})
|
||||||
|
pconn, err := qls.Listen(&net.UDPAddr{
|
||||||
|
IP: []byte{},
|
||||||
|
Port: 8080,
|
||||||
|
Zone: "",
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
ev := saver.Read()
|
wconn := pconn.(*quicPacketConnWrapper)
|
||||||
if len(ev) < 2 {
|
if wconn.UDPLikeConn != returnedConn {
|
||||||
t.Fatal("unexpected number of events")
|
t.Fatal("invalid underlying connection")
|
||||||
}
|
|
||||||
last := len(ev) - 1
|
|
||||||
for idx := 1; idx < last; idx++ {
|
|
||||||
if ev[idx].Value().Data == nil {
|
|
||||||
t.Fatal("unexpected Data")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().Err != nil {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().NumBytes <= 0 {
|
|
||||||
t.Fatal("unexpected NumBytes")
|
|
||||||
}
|
|
||||||
switch ev[idx].Name() {
|
|
||||||
case netxlite.ReadFromOperation, netxlite.WriteToOperation:
|
|
||||||
default:
|
|
||||||
t.Fatal("unexpected Name")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) {
|
|
||||||
t.Fatal("unexpected Time", ev[idx].Value().Time, ev[idx-1].Value().Time)
|
|
||||||
}
|
}
|
||||||
|
if wconn.saver != saver {
|
||||||
|
t.Fatal("invalid saver")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQUICPacketConnWrapper(t *testing.T) {
|
||||||
|
t.Run("ReadFrom", func(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
saver := &Saver{}
|
||||||
|
conn := &quicPacketConnWrapper{
|
||||||
|
UDPLikeConn: &mocks.UDPLikeConn{
|
||||||
|
MockReadFrom: func(p []byte) (int, net.Addr, error) {
|
||||||
|
return 0, nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
buf := make([]byte, 1<<17)
|
||||||
|
count, addr, err := conn.ReadFrom(buf)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if addr != nil {
|
||||||
|
t.Fatal("invalid addr")
|
||||||
|
}
|
||||||
|
events := saver.Read()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("invalid number of events")
|
||||||
|
}
|
||||||
|
ev0 := events[0]
|
||||||
|
if _, good := ev0.(*EventReadFromOperation); !good {
|
||||||
|
t.Fatal("invalid event type")
|
||||||
|
}
|
||||||
|
value := ev0.Value()
|
||||||
|
if value.Address != "" {
|
||||||
|
t.Fatal("invalid Address")
|
||||||
|
}
|
||||||
|
if len(value.Data) != 0 {
|
||||||
|
t.Fatal("invalid Data")
|
||||||
|
}
|
||||||
|
if value.Duration <= 0 {
|
||||||
|
t.Fatal("expected nonzero duration")
|
||||||
|
}
|
||||||
|
if !errors.Is(value.Err, expected) {
|
||||||
|
t.Fatal("unexpected value.Err", value.Err)
|
||||||
|
}
|
||||||
|
if value.NumBytes != 0 {
|
||||||
|
t.Fatal("expected NumBytes")
|
||||||
|
}
|
||||||
|
if value.Time.IsZero() {
|
||||||
|
t.Fatal("expected nonzero Time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
expected := []byte{1, 2, 3, 4}
|
||||||
|
saver := &Saver{}
|
||||||
|
expectedAddr := &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "8.8.8.8:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn := &quicPacketConnWrapper{
|
||||||
|
UDPLikeConn: &mocks.UDPLikeConn{
|
||||||
|
MockReadFrom: func(p []byte) (int, net.Addr, error) {
|
||||||
|
copy(p, expected)
|
||||||
|
return len(expected), expectedAddr, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
buf := make([]byte, 1<<17)
|
||||||
|
count, addr, err := conn.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if count != 4 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
if addr != expectedAddr {
|
||||||
|
t.Fatal("invalid addr")
|
||||||
|
}
|
||||||
|
events := saver.Read()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("invalid number of events")
|
||||||
|
}
|
||||||
|
ev0 := events[0]
|
||||||
|
if _, good := ev0.(*EventReadFromOperation); !good {
|
||||||
|
t.Fatal("invalid event type")
|
||||||
|
}
|
||||||
|
value := ev0.Value()
|
||||||
|
if value.Address != "8.8.8.8:443" {
|
||||||
|
t.Fatal("invalid Address")
|
||||||
|
}
|
||||||
|
if len(value.Data) != 4 {
|
||||||
|
t.Fatal("invalid Data")
|
||||||
|
}
|
||||||
|
if value.Duration <= 0 {
|
||||||
|
t.Fatal("expected nonzero duration")
|
||||||
|
}
|
||||||
|
if value.Err != nil {
|
||||||
|
t.Fatal("unexpected value.Err", value.Err)
|
||||||
|
}
|
||||||
|
if value.NumBytes != 4 {
|
||||||
|
t.Fatal("expected NumBytes")
|
||||||
|
}
|
||||||
|
if value.Time.IsZero() {
|
||||||
|
t.Fatal("expected nonzero Time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WriteTo", func(t *testing.T) {
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
saver := &Saver{}
|
||||||
|
conn := &quicPacketConnWrapper{
|
||||||
|
UDPLikeConn: &mocks.UDPLikeConn{
|
||||||
|
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
|
||||||
|
return 0, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
destAddr := &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "8.8.8.8:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
buf := make([]byte, 7)
|
||||||
|
count, err := conn.WriteTo(buf, destAddr)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
events := saver.Read()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("invalid number of events")
|
||||||
|
}
|
||||||
|
ev0 := events[0]
|
||||||
|
if _, good := ev0.(*EventWriteToOperation); !good {
|
||||||
|
t.Fatal("invalid event type")
|
||||||
|
}
|
||||||
|
value := ev0.Value()
|
||||||
|
if value.Address != "8.8.8.8:443" {
|
||||||
|
t.Fatal("invalid Address")
|
||||||
|
}
|
||||||
|
if len(value.Data) != 0 {
|
||||||
|
t.Fatal("invalid Data")
|
||||||
|
}
|
||||||
|
if value.Duration <= 0 {
|
||||||
|
t.Fatal("expected nonzero duration")
|
||||||
|
}
|
||||||
|
if !errors.Is(value.Err, expected) {
|
||||||
|
t.Fatal("unexpected value.Err", value.Err)
|
||||||
|
}
|
||||||
|
if value.NumBytes != 0 {
|
||||||
|
t.Fatal("expected NumBytes")
|
||||||
|
}
|
||||||
|
if value.Time.IsZero() {
|
||||||
|
t.Fatal("expected nonzero Time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
conn := &quicPacketConnWrapper{
|
||||||
|
UDPLikeConn: &mocks.UDPLikeConn{
|
||||||
|
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
|
||||||
|
return 1, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
saver: saver,
|
||||||
|
}
|
||||||
|
destAddr := &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "8.8.8.8:443"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
buf := make([]byte, 7)
|
||||||
|
count, err := conn.WriteTo(buf, destAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if count != 1 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
events := saver.Read()
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatal("invalid number of events")
|
||||||
|
}
|
||||||
|
ev0 := events[0]
|
||||||
|
if _, good := ev0.(*EventWriteToOperation); !good {
|
||||||
|
t.Fatal("invalid event type")
|
||||||
|
}
|
||||||
|
value := ev0.Value()
|
||||||
|
if value.Address != "8.8.8.8:443" {
|
||||||
|
t.Fatal("invalid Address")
|
||||||
|
}
|
||||||
|
if len(value.Data) != 1 {
|
||||||
|
t.Fatal("invalid Data")
|
||||||
|
}
|
||||||
|
if value.Duration <= 0 {
|
||||||
|
t.Fatal("expected nonzero duration")
|
||||||
|
}
|
||||||
|
if value.Err != nil {
|
||||||
|
t.Fatal("unexpected value.Err", value.Err)
|
||||||
|
}
|
||||||
|
if value.NumBytes != 1 {
|
||||||
|
t.Fatal("expected NumBytes")
|
||||||
|
}
|
||||||
|
if value.Time.IsZero() {
|
||||||
|
t.Fatal("expected nonzero Time")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,8 +12,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SaverResolver is a resolver that saves events.
|
// ResolverSaver is a resolver that saves events.
|
||||||
type SaverResolver struct {
|
type ResolverSaver struct {
|
||||||
// Resolver is the underlying resolver.
|
// Resolver is the underlying resolver.
|
||||||
Resolver model.Resolver
|
Resolver model.Resolver
|
||||||
|
|
||||||
|
@ -30,14 +30,14 @@ func (s *Saver) WrapResolver(r model.Resolver) model.Resolver {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
return &SaverResolver{
|
return &ResolverSaver{
|
||||||
Resolver: r,
|
Resolver: r,
|
||||||
Saver: s,
|
Saver: s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHost implements Resolver.LookupHost
|
// LookupHost implements Resolver.LookupHost
|
||||||
func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
func (r *ResolverSaver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
r.Saver.Write(&EventResolveStart{&EventValue{
|
r.Saver.Write(&EventResolveStart{&EventValue{
|
||||||
Address: r.Resolver.Address(),
|
Address: r.Resolver.Address(),
|
||||||
|
@ -59,30 +59,30 @@ func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]stri
|
||||||
return addrs, err
|
return addrs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SaverResolver) Network() string {
|
func (r *ResolverSaver) Network() string {
|
||||||
return r.Resolver.Network()
|
return r.Resolver.Network()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SaverResolver) Address() string {
|
func (r *ResolverSaver) Address() string {
|
||||||
return r.Resolver.Address()
|
return r.Resolver.Address()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SaverResolver) CloseIdleConnections() {
|
func (r *ResolverSaver) CloseIdleConnections() {
|
||||||
r.Resolver.CloseIdleConnections()
|
r.Resolver.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SaverResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
func (r *ResolverSaver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||||
// TODO(bassosimone): we should probably implement this method
|
// TODO(bassosimone): we should probably implement this method
|
||||||
return r.Resolver.LookupHTTPS(ctx, domain)
|
return r.Resolver.LookupHTTPS(ctx, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SaverResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
|
func (r *ResolverSaver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||||
// TODO(bassosimone): we should probably implement this method
|
// TODO(bassosimone): we should probably implement this method
|
||||||
return r.Resolver.LookupNS(ctx, domain)
|
return r.Resolver.LookupNS(ctx, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaverDNSTransport is a DNS transport that saves events.
|
// DNSTransportSaver is a DNS transport that saves events.
|
||||||
type SaverDNSTransport struct {
|
type DNSTransportSaver struct {
|
||||||
// DNSTransport is the underlying DNS transport.
|
// DNSTransport is the underlying DNS transport.
|
||||||
DNSTransport model.DNSTransport
|
DNSTransport model.DNSTransport
|
||||||
|
|
||||||
|
@ -99,14 +99,14 @@ func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return txp
|
return txp
|
||||||
}
|
}
|
||||||
return &SaverDNSTransport{
|
return &DNSTransportSaver{
|
||||||
DNSTransport: txp,
|
DNSTransport: txp,
|
||||||
Saver: s,
|
Saver: s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements RoundTripper.RoundTrip
|
// RoundTrip implements RoundTripper.RoundTrip
|
||||||
func (txp *SaverDNSTransport) RoundTrip(
|
func (txp *DNSTransportSaver) RoundTrip(
|
||||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{
|
txp.Saver.Write(&EventDNSRoundTripStart{&EventValue{
|
||||||
|
@ -129,19 +129,19 @@ func (txp *SaverDNSTransport) RoundTrip(
|
||||||
return response, err
|
return response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *SaverDNSTransport) Network() string {
|
func (txp *DNSTransportSaver) Network() string {
|
||||||
return txp.DNSTransport.Network()
|
return txp.DNSTransport.Network()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *SaverDNSTransport) Address() string {
|
func (txp *DNSTransportSaver) Address() string {
|
||||||
return txp.DNSTransport.Address()
|
return txp.DNSTransport.Address()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *SaverDNSTransport) CloseIdleConnections() {
|
func (txp *DNSTransportSaver) CloseIdleConnections() {
|
||||||
txp.DNSTransport.CloseIdleConnections()
|
txp.DNSTransport.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *SaverDNSTransport) RequiresPadding() bool {
|
func (txp *DNSTransportSaver) RequiresPadding() bool {
|
||||||
return txp.DNSTransport.RequiresPadding()
|
return txp.DNSTransport.RequiresPadding()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,5 +157,5 @@ func dnsMaybeResponseBytes(response model.DNSResponse) []byte {
|
||||||
return response.Bytes()
|
return response.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.Resolver = &SaverResolver{}
|
var _ model.Resolver = &ResolverSaver{}
|
||||||
var _ model.DNSTransport = &SaverDNSTransport{}
|
var _ model.DNSTransport = &DNSTransportSaver{}
|
||||||
|
|
|
@ -14,10 +14,11 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSaverResolverFailure(t *testing.T) {
|
func TestResolverSaver(t *testing.T) {
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("no such host")
|
expected := errors.New("no such host")
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
reso := saver.WrapResolver(NewFakeResolverWithExplicitError(expected))
|
reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected))
|
||||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -56,12 +57,12 @@ func TestSaverResolverFailure(t *testing.T) {
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
t.Fatal("the saved time is wrong")
|
t.Fatal("the saved time is wrong")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestSaverResolverSuccess(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
reso := saver.WrapResolver(NewFakeResolverWithResult(expected))
|
reso := saver.WrapResolver(newFakeResolverWithResult(expected))
|
||||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("expected nil error here")
|
t.Fatal("expected nil error here")
|
||||||
|
@ -100,9 +101,11 @@ func TestSaverResolverSuccess(t *testing.T) {
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
t.Fatal("the saved time is wrong")
|
t.Fatal("the saved time is wrong")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaverDNSTransportFailure(t *testing.T) {
|
func TestDNSTransportSaver(t *testing.T) {
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
expected := errors.New("no such host")
|
expected := errors.New("no such host")
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
|
txp := saver.WrapDNSTransport(&mocks.DNSTransport{
|
||||||
|
@ -160,9 +163,9 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
t.Fatal("the saved time is wrong")
|
t.Fatal("the saved time is wrong")
|
||||||
}
|
}
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestSaverDNSTransportSuccess(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
expected := []byte{0xef, 0xbe, 0xad, 0xde}
|
expected := []byte{0xef, 0xbe, 0xad, 0xde}
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
response := &mocks.DNSResponse{
|
response := &mocks.DNSResponse{
|
||||||
|
@ -225,9 +228,10 @@ func TestSaverDNSTransportSuccess(t *testing.T) {
|
||||||
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
if !ev[1].Value().Time.After(ev[0].Value().Time) {
|
||||||
t.Fatal("the saved time is wrong")
|
t.Fatal("the saved time is wrong")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFakeResolverWithExplicitError(err error) model.Resolver {
|
func newFakeResolverWithExplicitError(err error) model.Resolver {
|
||||||
runtimex.PanicIfNil(err, "passed nil error")
|
runtimex.PanicIfNil(err, "passed nil error")
|
||||||
return &mocks.Resolver{
|
return &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
|
@ -251,7 +255,7 @@ func NewFakeResolverWithExplicitError(err error) model.Resolver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFakeResolverWithResult(r []string) model.Resolver {
|
func newFakeResolverWithResult(r []string) model.Resolver {
|
||||||
return &mocks.Resolver{
|
return &mocks.Resolver{
|
||||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||||
return r, nil
|
return r, nil
|
||||||
|
|
|
@ -3,9 +3,12 @@ package tracex
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSaver(t *testing.T) {
|
func TestSaver(t *testing.T) {
|
||||||
|
t.Run("concurrent writes followed by read", func(t *testing.T) {
|
||||||
saver := Saver{}
|
saver := Saver{}
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
const parallel = 10
|
const parallel = 10
|
||||||
|
@ -21,4 +24,67 @@ func TestSaver(t *testing.T) {
|
||||||
if len(ev) != parallel {
|
if len(ev) != parallel {
|
||||||
t.Fatal("unexpected number of events read")
|
t.Fatal("unexpected number of events read")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewConnectObserver", func(t *testing.T) {
|
||||||
|
t.Run("nil Saver", func(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
obs := saver.NewConnectObserver()
|
||||||
|
if obs != nil {
|
||||||
|
t.Fatal("expected nil observer")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonnnil Saver", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
obs := saver.NewConnectObserver()
|
||||||
|
underlying := obs.(*dialerConnectObserver)
|
||||||
|
if underlying.saver != saver {
|
||||||
|
t.Fatal("invalid saver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NewReadWriteObserver", func(t *testing.T) {
|
||||||
|
t.Run("nil Saver", func(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
obs := saver.NewReadWriteObserver()
|
||||||
|
if obs != nil {
|
||||||
|
t.Fatal("expected nil observer")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonnnil Saver", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
obs := saver.NewReadWriteObserver()
|
||||||
|
underlying := obs.(*dialerReadWriteObserver)
|
||||||
|
if underlying.saver != saver {
|
||||||
|
t.Fatal("invalid saver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("WrapQUICDialer", func(t *testing.T) {
|
||||||
|
t.Run("nil Saver", func(t *testing.T) {
|
||||||
|
var saver *Saver
|
||||||
|
base := &mocks.QUICDialer{}
|
||||||
|
qd := saver.WrapQUICDialer(base)
|
||||||
|
if qd != base {
|
||||||
|
t.Fatal("unexpected returned QUICDialer")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nonnnil Saver", func(t *testing.T) {
|
||||||
|
saver := &Saver{}
|
||||||
|
base := &mocks.QUICDialer{}
|
||||||
|
qd := saver.WrapQUICDialer(base)
|
||||||
|
underlying := qd.(*QUICDialerSaver)
|
||||||
|
if underlying.Saver != saver {
|
||||||
|
t.Fatal("invalid Saver")
|
||||||
|
}
|
||||||
|
if underlying.QUICDialer != base {
|
||||||
|
t.Fatal("invalid QUICDialer")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,8 +16,8 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SaverTLSHandshaker saves events occurring during the TLS handshake.
|
// TLSHandshakerSaver saves events occurring during the TLS handshake.
|
||||||
type SaverTLSHandshaker struct {
|
type TLSHandshakerSaver struct {
|
||||||
// TLSHandshaker is the underlying TLS handshaker.
|
// TLSHandshaker is the underlying TLS handshaker.
|
||||||
TLSHandshaker model.TLSHandshaker
|
TLSHandshaker model.TLSHandshaker
|
||||||
|
|
||||||
|
@ -34,23 +34,26 @@ func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return thx
|
return thx
|
||||||
}
|
}
|
||||||
return &SaverTLSHandshaker{
|
return &TLSHandshakerSaver{
|
||||||
TLSHandshaker: thx,
|
TLSHandshaker: thx,
|
||||||
Saver: s,
|
Saver: s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake implements model.TLSHandshaker.Handshake
|
// Handshake implements model.TLSHandshaker.Handshake
|
||||||
func (h *SaverTLSHandshaker) Handshake(
|
func (h *TLSHandshakerSaver) Handshake(
|
||||||
ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
proto := conn.RemoteAddr().Network()
|
||||||
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
h.Saver.Write(&EventTLSHandshakeStart{&EventValue{
|
h.Saver.Write(&EventTLSHandshakeStart{&EventValue{
|
||||||
|
Address: remoteAddr,
|
||||||
NoTLSVerify: config.InsecureSkipVerify,
|
NoTLSVerify: config.InsecureSkipVerify,
|
||||||
|
Proto: proto,
|
||||||
TLSNextProtos: config.NextProtos,
|
TLSNextProtos: config.NextProtos,
|
||||||
TLSServerName: config.ServerName,
|
TLSServerName: config.ServerName,
|
||||||
Time: start,
|
Time: start,
|
||||||
}})
|
}})
|
||||||
remoteAddr := conn.RemoteAddr().String()
|
|
||||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
h.Saver.Write(&EventTLSHandshakeDone{&EventValue{
|
h.Saver.Write(&EventTLSHandshakeDone{&EventValue{
|
||||||
|
@ -58,6 +61,7 @@ func (h *SaverTLSHandshaker) Handshake(
|
||||||
Duration: stop.Sub(start),
|
Duration: stop.Sub(start),
|
||||||
Err: err,
|
Err: err,
|
||||||
NoTLSVerify: config.InsecureSkipVerify,
|
NoTLSVerify: config.InsecureSkipVerify,
|
||||||
|
Proto: proto,
|
||||||
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
|
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
|
||||||
TLSNegotiatedProto: state.NegotiatedProtocol,
|
TLSNegotiatedProto: state.NegotiatedProtocol,
|
||||||
TLSNextProtos: config.NextProtos,
|
TLSNextProtos: config.NextProtos,
|
||||||
|
@ -69,7 +73,7 @@ func (h *SaverTLSHandshaker) Handshake(
|
||||||
return tlsconn, state, err
|
return tlsconn, state, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.TLSHandshaker = &SaverTLSHandshaker{}
|
var _ model.TLSHandshaker = &TLSHandshakerSaver{}
|
||||||
|
|
||||||
// tlsPeerCerts returns the certificates presented by the peer regardless
|
// tlsPeerCerts returns the certificates presented by the peer regardless
|
||||||
// of whether the TLS handshake was successful
|
// of whether the TLS handshake was successful
|
||||||
|
|
|
@ -3,291 +3,250 @@ package tracex
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"reflect"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
func TestTLSHandshakerSaver(t *testing.T) {
|
||||||
// This is the most common use case for collecting reads, writes
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skip test in short mode")
|
|
||||||
}
|
|
||||||
nextprotos := []string{"h2"}
|
|
||||||
saver := &Saver{}
|
|
||||||
tlsdlr := &netxlite.TLSDialerLegacy{
|
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
|
||||||
Dialer: netxlite.NewDialerWithResolver(
|
|
||||||
model.DiscardLogger,
|
|
||||||
netxlite.NewResolverStdlib(model.DiscardLogger),
|
|
||||||
saver.NewReadWriteObserver(),
|
|
||||||
),
|
|
||||||
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
|
|
||||||
}
|
|
||||||
// Implementation note: we don't close the connection here because it is
|
|
||||||
// very handy to have the last event being the end of the handshake
|
|
||||||
_, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ev := saver.Read()
|
|
||||||
if len(ev) < 4 {
|
|
||||||
// it's a bit tricky to be sure about the right number of
|
|
||||||
// events because network conditions may influence that
|
|
||||||
t.Fatal("unexpected number of events")
|
|
||||||
}
|
|
||||||
if ev[0].Name() != "tls_handshake_start" {
|
|
||||||
t.Fatal("unexpected Name")
|
|
||||||
}
|
|
||||||
if ev[0].Value().TLSServerName != "www.google.com" {
|
|
||||||
t.Fatal("unexpected TLSServerName")
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) {
|
|
||||||
t.Fatal("unexpected TLSNextProtos")
|
|
||||||
}
|
|
||||||
if ev[0].Value().Time.After(time.Now()) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
last := len(ev) - 1
|
|
||||||
for idx := 1; idx < last; idx++ {
|
|
||||||
if ev[idx].Value().Data == nil {
|
|
||||||
t.Fatal("unexpected Data")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().Err != nil {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().NumBytes <= 0 {
|
|
||||||
t.Fatal("unexpected NumBytes")
|
|
||||||
}
|
|
||||||
switch ev[idx].Name() {
|
|
||||||
case netxlite.ReadOperation, netxlite.WriteOperation:
|
|
||||||
default:
|
|
||||||
t.Fatal("unexpected Name")
|
|
||||||
}
|
|
||||||
if ev[idx].Value().Time.Before(ev[idx-1].Value().Time) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ev[last].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[last].Value().Err != nil {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[last].Name() != "tls_handshake_done" {
|
|
||||||
t.Fatal("unexpected Name")
|
|
||||||
}
|
|
||||||
if ev[last].Value().TLSCipherSuite == "" {
|
|
||||||
t.Fatal("unexpected TLSCipherSuite")
|
|
||||||
}
|
|
||||||
if ev[last].Value().TLSNegotiatedProto != "h2" {
|
|
||||||
t.Fatal("unexpected TLSNegotiatedProto")
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(ev[last].Value().TLSNextProtos, nextprotos) {
|
|
||||||
t.Fatal("unexpected TLSNextProtos")
|
|
||||||
}
|
|
||||||
if ev[last].Value().TLSPeerCerts == nil {
|
|
||||||
t.Fatal("unexpected TLSPeerCerts")
|
|
||||||
}
|
|
||||||
if ev[last].Value().TLSServerName != "www.google.com" {
|
|
||||||
t.Fatal("unexpected TLSServerName")
|
|
||||||
}
|
|
||||||
if ev[last].Value().TLSVersion == "" {
|
|
||||||
t.Fatal("unexpected TLSVersion")
|
|
||||||
}
|
|
||||||
if ev[last].Value().Time.Before(ev[last-1].Value().Time) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
t.Run("Handshake", func(t *testing.T) {
|
||||||
if testing.Short() {
|
checkStartEventFields := func(t *testing.T, value *EventValue) {
|
||||||
t.Skip("skip test in short mode")
|
if value.Address != "8.8.8.8:443" {
|
||||||
|
t.Fatal("invalid Address")
|
||||||
}
|
}
|
||||||
nextprotos := []string{"h2"}
|
if !value.NoTLSVerify {
|
||||||
saver := &Saver{}
|
t.Fatal("expected NoTLSVerify to be true")
|
||||||
tlsdlr := &netxlite.TLSDialerLegacy{
|
|
||||||
Config: &tls.Config{NextProtos: nextprotos},
|
|
||||||
Dialer: &netxlite.DialerSystem{},
|
|
||||||
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
|
|
||||||
}
|
}
|
||||||
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
|
if value.Proto != "tcp" {
|
||||||
if err != nil {
|
t.Fatal("wrong protocol")
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
conn.Close()
|
if diff := cmp.Diff(value.TLSNextProtos, []string{"h2"}); diff != "" {
|
||||||
ev := saver.Read()
|
t.Fatal(diff)
|
||||||
if len(ev) != 2 {
|
|
||||||
t.Fatal("unexpected number of events")
|
|
||||||
}
|
}
|
||||||
if ev[0].Name() != "tls_handshake_start" {
|
if value.TLSServerName != "dns.google" {
|
||||||
t.Fatal("unexpected Name")
|
t.Fatal("invalid TLSServerName")
|
||||||
}
|
}
|
||||||
if ev[0].Value().TLSServerName != "www.google.com" {
|
if value.Time.IsZero() {
|
||||||
t.Fatal("unexpected TLSServerName")
|
t.Fatal("expected non zero time")
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(ev[0].Value().TLSNextProtos, nextprotos) {
|
|
||||||
t.Fatal("unexpected TLSNextProtos")
|
|
||||||
}
|
}
|
||||||
if ev[0].Value().Time.After(time.Now()) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Duration <= 0 {
|
|
||||||
t.Fatal("unexpected Duration")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Err != nil {
|
|
||||||
t.Fatal("unexpected Err")
|
|
||||||
}
|
|
||||||
if ev[1].Name() != "tls_handshake_done" {
|
|
||||||
t.Fatal("unexpected Name")
|
|
||||||
}
|
|
||||||
if ev[1].Value().TLSCipherSuite == "" {
|
|
||||||
t.Fatal("unexpected TLSCipherSuite")
|
|
||||||
}
|
|
||||||
if ev[1].Value().TLSNegotiatedProto != "h2" {
|
|
||||||
t.Fatal("unexpected TLSNegotiatedProto")
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(ev[1].Value().TLSNextProtos, nextprotos) {
|
|
||||||
t.Fatal("unexpected TLSNextProtos")
|
|
||||||
}
|
|
||||||
if ev[1].Value().TLSPeerCerts == nil {
|
|
||||||
t.Fatal("unexpected TLSPeerCerts")
|
|
||||||
}
|
|
||||||
if ev[1].Value().TLSServerName != "www.google.com" {
|
|
||||||
t.Fatal("unexpected TLSServerName")
|
|
||||||
}
|
|
||||||
if ev[1].Value().TLSVersion == "" {
|
|
||||||
t.Fatal("unexpected TLSVersion")
|
|
||||||
}
|
|
||||||
if ev[1].Value().Time.Before(ev[0].Value().Time) {
|
|
||||||
t.Fatal("unexpected Time")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
checkStartedEvent := func(t *testing.T, ev Event) {
|
||||||
if testing.Short() {
|
if _, good := ev.(*EventTLSHandshakeStart); !good {
|
||||||
t.Skip("skip test in short mode")
|
t.Fatal("invalid event type")
|
||||||
}
|
}
|
||||||
saver := &Saver{}
|
value := ev.Value()
|
||||||
tlsdlr := &netxlite.TLSDialerLegacy{
|
checkStartEventFields(t, value)
|
||||||
Dialer: &netxlite.DialerSystem{},
|
|
||||||
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
|
|
||||||
}
|
}
|
||||||
conn, err := tlsdlr.DialTLSContext(
|
|
||||||
context.Background(), "tcp", "wrong.host.badssl.com:443")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected an error here")
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil conn here")
|
|
||||||
}
|
|
||||||
for _, ev := range saver.Read() {
|
|
||||||
if ev.Name() != "tls_handshake_done" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ev.Value().NoTLSVerify == true {
|
|
||||||
t.Fatal("expected NoTLSVerify to be false")
|
|
||||||
}
|
|
||||||
if len(ev.Value().TLSPeerCerts) < 1 {
|
|
||||||
t.Fatal("expected at least a certificate here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
checkDoneEventFieldsSuccess := func(t *testing.T, value *EventValue) {
|
||||||
if testing.Short() {
|
if value.Duration <= 0 {
|
||||||
t.Skip("skip test in short mode")
|
t.Fatal("expected non-zero duration")
|
||||||
}
|
}
|
||||||
saver := &Saver{}
|
if value.Err != nil {
|
||||||
tlsdlr := &netxlite.TLSDialerLegacy{
|
t.Fatal("expected no error here")
|
||||||
Dialer: &netxlite.DialerSystem{},
|
|
||||||
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
|
|
||||||
}
|
}
|
||||||
conn, err := tlsdlr.DialTLSContext(
|
if value.TLSCipherSuite != "TLS_RSA_WITH_RC4_128_SHA" {
|
||||||
context.Background(), "tcp", "expired.badssl.com:443")
|
t.Fatal("invalid cipher suite")
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected an error here")
|
|
||||||
}
|
}
|
||||||
if conn != nil {
|
if value.TLSNegotiatedProto != "h2" {
|
||||||
t.Fatal("expected nil conn here")
|
t.Fatal("invalid negotiated protocol")
|
||||||
}
|
}
|
||||||
for _, ev := range saver.Read() {
|
if diff := cmp.Diff(value.TLSPeerCerts, []*x509.Certificate{}); diff != "" {
|
||||||
if ev.Name() != "tls_handshake_done" {
|
t.Fatal(diff)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if ev.Value().NoTLSVerify == true {
|
if value.TLSVersion != "TLSv1.3" {
|
||||||
t.Fatal("expected NoTLSVerify to be false")
|
t.Fatal("invalid TLS version")
|
||||||
}
|
|
||||||
if len(ev.Value().TLSPeerCerts) < 1 {
|
|
||||||
t.Fatal("expected at least a certificate here")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
checkDoneEvent := func(t *testing.T, ev Event, fun func(t *testing.T, value *EventValue)) {
|
||||||
if testing.Short() {
|
if _, good := ev.(*EventTLSHandshakeDone); !good {
|
||||||
t.Skip("skip test in short mode")
|
t.Fatal("invalid event type")
|
||||||
}
|
}
|
||||||
saver := &Saver{}
|
value := ev.Value()
|
||||||
tlsdlr := &netxlite.TLSDialerLegacy{
|
checkStartEventFields(t, value)
|
||||||
Dialer: &netxlite.DialerSystem{},
|
fun(t, value)
|
||||||
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
|
|
||||||
}
|
}
|
||||||
conn, err := tlsdlr.DialTLSContext(
|
|
||||||
context.Background(), "tcp", "self-signed.badssl.com:443")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected an error here")
|
|
||||||
}
|
|
||||||
if conn != nil {
|
|
||||||
t.Fatal("expected nil conn here")
|
|
||||||
}
|
|
||||||
for _, ev := range saver.Read() {
|
|
||||||
if ev.Name() != "tls_handshake_done" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ev.Value().NoTLSVerify == true {
|
|
||||||
t.Fatal("expected NoTLSVerify to be false")
|
|
||||||
}
|
|
||||||
if len(ev.Value().TLSPeerCerts) < 1 {
|
|
||||||
t.Fatal("expected at least a certificate here")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
t.Run("on success", func(t *testing.T) {
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("skip test in short mode")
|
|
||||||
}
|
|
||||||
saver := &Saver{}
|
saver := &Saver{}
|
||||||
tlsdlr := &netxlite.TLSDialerLegacy{
|
returnedConnState := tls.ConnectionState{
|
||||||
Config: &tls.Config{InsecureSkipVerify: true},
|
CipherSuite: tls.TLS_RSA_WITH_RC4_128_SHA,
|
||||||
Dialer: &netxlite.DialerSystem{},
|
NegotiatedProtocol: "h2",
|
||||||
TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
|
PeerCertificates: []*x509.Certificate{},
|
||||||
|
Version: tls.VersionTLS13,
|
||||||
}
|
}
|
||||||
conn, err := tlsdlr.DialTLSContext(
|
returnedConn := &mocks.TLSConn{
|
||||||
context.Background(), "tcp", "self-signed.badssl.com:443")
|
MockConnectionState: func() tls.ConnectionState {
|
||||||
|
return returnedConnState
|
||||||
|
},
|
||||||
|
}
|
||||||
|
thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn,
|
||||||
|
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return returnedConn, returnedConnState, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
tcpConn := &mocks.Conn{
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "8.8.8.8:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
t.Fatal("expected non-nil conn here")
|
t.Fatal("expected non-nil conn")
|
||||||
}
|
}
|
||||||
conn.Close()
|
events := saver.Read()
|
||||||
for _, ev := range saver.Read() {
|
if len(events) != 2 {
|
||||||
if ev.Name() != "tls_handshake_done" {
|
t.Fatal("expected two events")
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if ev.Value().NoTLSVerify != true {
|
checkStartedEvent(t, events[0])
|
||||||
t.Fatal("expected NoTLSVerify to be true")
|
checkDoneEvent(t, events[1], checkDoneEventFieldsSuccess)
|
||||||
|
})
|
||||||
|
|
||||||
|
checkDoneEventFieldsFailure := func(t *testing.T, value *EventValue) {
|
||||||
|
if value.Duration <= 0 {
|
||||||
|
t.Fatal("expected non-zero duration")
|
||||||
}
|
}
|
||||||
if len(ev.Value().TLSPeerCerts) < 1 {
|
if value.Err == nil {
|
||||||
t.Fatal("expected at least a certificate here")
|
t.Fatal("expected non-nil error here")
|
||||||
}
|
}
|
||||||
|
if value.TLSCipherSuite != "" {
|
||||||
|
t.Fatal("invalid TLS cipher suite")
|
||||||
|
}
|
||||||
|
if value.TLSNegotiatedProto != "" {
|
||||||
|
t.Fatal("invalid negotiated proto")
|
||||||
|
}
|
||||||
|
if len(value.TLSPeerCerts) > 0 {
|
||||||
|
t.Fatal("expected no peer certs")
|
||||||
|
}
|
||||||
|
if value.TLSVersion != "" {
|
||||||
|
t.Fatal("invalid TLS version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
saver := &Saver{}
|
||||||
|
thx := saver.WrapTLSHandshaker(&mocks.TLSHandshaker{
|
||||||
|
MockHandshake: func(ctx context.Context, conn net.Conn,
|
||||||
|
config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||||
|
return nil, tls.ConnectionState{}, expected
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ctx := context.Background()
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
ServerName: "dns.google",
|
||||||
|
}
|
||||||
|
tcpConn := &mocks.Conn{
|
||||||
|
MockRemoteAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockString: func() string {
|
||||||
|
return "8.8.8.8:443"
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "tcp"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, _, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
events := saver.Read()
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatal("expected two events")
|
||||||
|
}
|
||||||
|
checkStartedEvent(t, events[0])
|
||||||
|
checkDoneEvent(t, events[1], checkDoneEventFieldsFailure)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_tlsPeerCerts(t *testing.T) {
|
||||||
|
cert0 := &x509.Certificate{Raw: []byte{1, 2, 3, 4}}
|
||||||
|
type args struct {
|
||||||
|
state tls.ConnectionState
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []*x509.Certificate
|
||||||
|
}{{
|
||||||
|
name: "no error",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{
|
||||||
|
PeerCertificates: []*x509.Certificate{cert0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []*x509.Certificate{cert0},
|
||||||
|
}, {
|
||||||
|
name: "all empty",
|
||||||
|
args: args{},
|
||||||
|
want: nil,
|
||||||
|
}, {
|
||||||
|
name: "x509.HostnameError",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{},
|
||||||
|
err: x509.HostnameError{
|
||||||
|
Certificate: cert0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []*x509.Certificate{cert0},
|
||||||
|
}, {
|
||||||
|
name: "x509.UnknownAuthorityError",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{},
|
||||||
|
err: x509.UnknownAuthorityError{
|
||||||
|
Cert: cert0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []*x509.Certificate{cert0},
|
||||||
|
}, {
|
||||||
|
name: "x509.CertificateInvalidError",
|
||||||
|
args: args{
|
||||||
|
state: tls.ConnectionState{},
|
||||||
|
err: x509.CertificateInvalidError{
|
||||||
|
Cert: cert0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: []*x509.Certificate{cert0},
|
||||||
|
}}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := tlsPeerCerts(tt.args.state, tt.args.err)
|
||||||
|
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user