refactor(tracex): start applying recent code conventions (#773)

The code that is now into the tracex package was written a long
time ago, so let's start to make it more in line with the coding
style of packages that were written more recently.

I didn't apply all the changes I'd like to apply in a single diff
and for now I am committing just this diff.

Broadly, what we need to do is:

1. improve documentation

2. ~always use pointer receivers (object receives have the issue
that they are not mutable by accident meaning that you can mutate
them but their state do not change after the call returns, which
is potentially a source of bugs in case you later refactor to use
a pointer receiver, so always use pointer receivers)

3. ~always avoid embedding (let's say we want to avoid embedding
for types we define and it's instead fine to embed types that are
defined in the stdlib: if later we add a new method, we will not
see a broken build and we'll probably forget to add the new method
to all wrappers -- conversely, if we're wrapping rather than
embedding, we'll see a broken build and act accordingly)

4. prefer unit tests and group tests by type being tested rather
than using a flat structure for tests

There's a coverage slippage that I'll compensate in a follow-up diff where I'll focus on unit testing.

Reference issue: https://github.com/ooni/probe/issues/2121
This commit is contained in:
Simone Basso 2022-06-01 07:44:54 +02:00 committed by GitHub
parent bbcd2e2280
commit f4f3ed7c42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 346 additions and 237 deletions

View File

@ -119,7 +119,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -195,7 +195,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -271,7 +271,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }
@ -347,7 +347,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
stxp, ok := sr.Txp.(tracex.SaverDNSTransport) stxp, ok := sr.Txp.(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the DNS transport we expected") t.Fatal("not the DNS transport we expected")
} }

View File

@ -104,9 +104,7 @@ func NewResolver(config Config) model.Resolver {
Resolver: r, Resolver: r,
} }
} }
if config.ResolveSaver != nil { r = config.ResolveSaver.WrapResolver(r) // WAI when config.ResolveSaver==nil
r = tracex.SaverResolver{Resolver: r, Saver: config.ResolveSaver}
}
return &netxlite.ResolverIDNA{Resolver: r} return &netxlite.ResolverIDNA{Resolver: r}
} }
@ -129,23 +127,14 @@ func NewQUICDialer(config Config) model.QUICDialer {
if config.FullResolver == nil { if config.FullResolver == nil {
config.FullResolver = NewResolver(config) config.FullResolver = NewResolver(config)
} }
ql := netxlite.NewQUICListener() ql := config.ReadWriteSaver.WrapQUICListener(netxlite.NewQUICListener())
if config.ReadWriteSaver != nil {
ql = &tracex.QUICListenerSaver{
QUICListener: ql,
Saver: config.ReadWriteSaver,
}
}
var logger model.DebugLogger = model.DiscardLogger var logger model.DebugLogger = model.DiscardLogger
if config.Logger != nil { if config.Logger != nil {
logger = config.Logger logger = config.Logger
} }
extensions := []netxlite.QUICDialerWrapper{ extensions := []netxlite.QUICDialerWrapper{
func(dialer model.QUICDialer) model.QUICDialer { func(dialer model.QUICDialer) model.QUICDialer {
if config.TLSSaver != nil { return config.TLSSaver.WrapQUICDialer(dialer) // robust to nil TLSSaver
dialer = tracex.QUICHandshakeSaver{Saver: config.TLSSaver, QUICDialer: dialer}
}
return dialer
}, },
} }
return netxlite.NewQUICDialerWithResolver(ql, logger, config.FullResolver, extensions...) return netxlite.NewQUICDialerWithResolver(ql, logger, config.FullResolver, extensions...)
@ -161,9 +150,7 @@ func NewTLSDialer(config Config) model.TLSDialer {
if config.Logger != nil { if config.Logger != nil {
h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h} h = &netxlite.TLSHandshakerLogger{DebugLogger: config.Logger, TLSHandshaker: h}
} }
if config.TLSSaver != nil { h = config.TLSSaver.WrapTLSHandshaker(h) // behaves with nil TLSSaver
h = tracex.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver}
}
if config.TLSConfig == nil { if config.TLSConfig == nil {
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}} config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
} }
@ -284,12 +271,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
httpClient := &http.Client{Transport: NewHTTPTransport(config)} httpClient := &http.Client{Transport: NewHTTPTransport(config)}
var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride( var txp model.DNSTransport = netxlite.NewDNSOverHTTPSTransportWithHostOverride(
httpClient, URL, hostOverride) httpClient, URL, hostOverride)
if config.ResolveSaver != nil { txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
txp = tracex.SaverDNSTransport{
DNSTransport: txp,
Saver: config.ResolveSaver,
}
}
return netxlite.NewSerialResolver(txp), nil return netxlite.NewSerialResolver(txp), nil
case "udp": case "udp":
dialer := NewDialer(config) dialer := NewDialer(config)
@ -299,12 +281,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
} }
var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport( var txp model.DNSTransport = netxlite.NewDNSOverUDPTransport(
dialer, endpoint) dialer, endpoint)
if config.ResolveSaver != nil { txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
txp = tracex.SaverDNSTransport{
DNSTransport: txp,
Saver: config.ResolveSaver,
}
}
return netxlite.NewSerialResolver(txp), nil return netxlite.NewSerialResolver(txp), nil
case "dot": case "dot":
config.TLSConfig.NextProtos = []string{"dot"} config.TLSConfig.NextProtos = []string{"dot"}
@ -315,12 +292,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
} }
var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport( var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
tlsDialer.DialTLSContext, endpoint) tlsDialer.DialTLSContext, endpoint)
if config.ResolveSaver != nil { txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
txp = tracex.SaverDNSTransport{
DNSTransport: txp,
Saver: config.ResolveSaver,
}
}
return netxlite.NewSerialResolver(txp), nil return netxlite.NewSerialResolver(txp), nil
case "tcp": case "tcp":
dialer := NewDialer(config) dialer := NewDialer(config)
@ -330,12 +302,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
} }
var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport( var txp model.DNSTransport = netxlite.NewDNSOverTCPTransport(
dialer.DialContext, endpoint) dialer.DialContext, endpoint)
if config.ResolveSaver != nil { txp = config.ResolveSaver.WrapDNSTransport(txp) // safe when config.ResolveSaver == nil
txp = tracex.SaverDNSTransport{
DNSTransport: txp,
Saver: config.ResolveSaver,
}
}
return netxlite.NewSerialResolver(txp), nil return netxlite.NewSerialResolver(txp), nil
default: default:
return nil, errors.New("unsupported resolver scheme") return nil, errors.New("unsupported resolver scheme")

View File

@ -126,7 +126,7 @@ func TestNewResolverWithSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
sr, ok := ir.Resolver.(tracex.SaverResolver) sr, ok := ir.Resolver.(*tracex.SaverResolver)
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
@ -332,7 +332,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
if rtd.TLSHandshaker == nil { if rtd.TLSHandshaker == nil {
t.Fatal("invalid TLSHandshaker") t.Fatal("invalid TLSHandshaker")
} }
sth, ok := rtd.TLSHandshaker.(tracex.SaverTLSHandshaker) sth, ok := rtd.TLSHandshaker.(*tracex.SaverTLSHandshaker)
if !ok { if !ok {
t.Fatal("not the TLSHandshaker we expected") t.Fatal("not the TLSHandshaker we expected")
} }
@ -633,7 +633,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -670,7 +670,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -711,7 +711,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }
@ -756,7 +756,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
if !ok { if !ok {
t.Fatal("not the resolver we expected") t.Fatal("not the resolver we expected")
} }
txp, ok := r.Transport().(tracex.SaverDNSTransport) txp, ok := r.Transport().(*tracex.SaverDNSTransport)
if !ok { if !ok {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
} }

View File

@ -15,7 +15,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
// Compatibility types // Compatibility types. Most experiments still use these names.
type ( type (
ExtSpec = model.ArchivalExtSpec ExtSpec = model.ArchivalExtSpec
TCPConnectEntry = model.ArchivalTCPConnectResult TCPConnectEntry = model.ArchivalTCPConnectResult
@ -32,7 +32,7 @@ type (
NetworkEvent = model.ArchivalNetworkEvent NetworkEvent = model.ArchivalNetworkEvent
) )
// Compatibility variables // Compatibility variables. Most experiments still use these names.
var ( var (
ExtDNS = model.ArchivalExtDNS ExtDNS = model.ArchivalExtDNS
ExtNetevents = model.ArchivalExtNetevents ExtNetevents = model.ArchivalExtNetevents
@ -100,7 +100,7 @@ func NewFailedOperation(err error) *string {
return &s return &s
} }
func addheaders( func httpAddHeaders(
source http.Header, source http.Header,
destList *[]HTTPHeader, destList *[]HTTPHeader,
destMap *map[string]MaybeBinaryValue, destMap *map[string]MaybeBinaryValue,
@ -150,14 +150,14 @@ func newRequestList(begin time.Time, events []Event) []RequestEntry {
entry.Request.BodyIsTruncated = ev.DataIsTruncated entry.Request.BodyIsTruncated = ev.DataIsTruncated
case "http_request_metadata": case "http_request_metadata":
entry.Request.Headers = make(map[string]MaybeBinaryValue) entry.Request.Headers = make(map[string]MaybeBinaryValue)
addheaders( httpAddHeaders(
ev.HTTPHeaders, &entry.Request.HeadersList, &entry.Request.Headers) ev.HTTPHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
entry.Request.Method = ev.HTTPMethod entry.Request.Method = ev.HTTPMethod
entry.Request.URL = ev.HTTPURL entry.Request.URL = ev.HTTPURL
entry.Request.Transport = ev.Transport entry.Request.Transport = ev.Transport
case "http_response_metadata": case "http_response_metadata":
entry.Response.Headers = make(map[string]MaybeBinaryValue) entry.Response.Headers = make(map[string]MaybeBinaryValue)
addheaders( httpAddHeaders(
ev.HTTPHeaders, &entry.Response.HeadersList, &entry.Response.Headers) ev.HTTPHeaders, &entry.Response.HeadersList, &entry.Response.Headers)
entry.Response.Code = int64(ev.HTTPStatusCode) entry.Response.Code = int64(ev.HTTPStatusCode)
entry.Response.Locations = ev.HTTPHeaders.Values("Location") entry.Response.Locations = ev.HTTPHeaders.Values("Location")
@ -183,11 +183,11 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry {
continue continue
} }
for _, qtype := range []dnsQueryType{"A", "AAAA"} { for _, qtype := range []dnsQueryType{"A", "AAAA"} {
entry := qtype.makequeryentry(begin, ev) entry := qtype.makeQueryEntry(begin, ev)
for _, addr := range ev.Addresses { for _, addr := range ev.Addresses {
if qtype.ipoftype(addr) { if qtype.ipOfType(addr) {
entry.Answers = append( entry.Answers = append(
entry.Answers, qtype.makeanswerentry(addr)) entry.Answers, qtype.makeAnswerEntry(addr))
} }
} }
if len(entry.Answers) <= 0 && ev.Err == nil { if len(entry.Answers) <= 0 && ev.Err == nil {
@ -206,7 +206,7 @@ func NewDNSQueriesList(begin time.Time, events []Event) []DNSQueryEntry {
return out return out
} }
func (qtype dnsQueryType) ipoftype(addr string) bool { func (qtype dnsQueryType) ipOfType(addr string) bool {
switch qtype { switch qtype {
case "A": case "A":
return !strings.Contains(addr, ":") return !strings.Contains(addr, ":")
@ -216,7 +216,7 @@ func (qtype dnsQueryType) ipoftype(addr string) bool {
return false return false
} }
func (qtype dnsQueryType) makeanswerentry(addr string) DNSAnswerEntry { func (qtype dnsQueryType) makeAnswerEntry(addr string) DNSAnswerEntry {
answer := DNSAnswerEntry{AnswerType: string(qtype)} answer := DNSAnswerEntry{AnswerType: string(qtype)}
asn, org, _ := geolocate.LookupASN(addr) asn, org, _ := geolocate.LookupASN(addr)
answer.ASN = int64(asn) answer.ASN = int64(asn)
@ -230,7 +230,7 @@ func (qtype dnsQueryType) makeanswerentry(addr string) DNSAnswerEntry {
return answer return answer
} }
func (qtype dnsQueryType) makequeryentry(begin time.Time, ev Event) DNSQueryEntry { func (qtype dnsQueryType) makeQueryEntry(begin time.Time, ev Event) DNSQueryEntry {
return DNSQueryEntry{ return DNSQueryEntry{
Engine: ev.Proto, Engine: ev.Proto,
Failure: NewFailure(ev.Err), Failure: NewFailure(ev.Err),
@ -315,7 +315,7 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake {
Failure: NewFailure(ev.Err), Failure: NewFailure(ev.Err),
NegotiatedProtocol: ev.TLSNegotiatedProto, NegotiatedProtocol: ev.TLSNegotiatedProto,
NoTLSVerify: ev.NoTLSVerify, NoTLSVerify: ev.NoTLSVerify,
PeerCertificates: makePeerCerts(ev.TLSPeerCerts), PeerCertificates: tlsMakePeerCerts(ev.TLSPeerCerts),
ServerName: ev.TLSServerName, ServerName: ev.TLSServerName,
T: ev.Time.Sub(begin).Seconds(), T: ev.Time.Sub(begin).Seconds(),
TLSVersion: ev.TLSVersion, TLSVersion: ev.TLSVersion,
@ -324,7 +324,7 @@ func NewTLSHandshakesList(begin time.Time, events []Event) []TLSHandshake {
return out return out
} }
func makePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) { func tlsMakePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
for _, e := range in { for _, e := range in {
out = append(out, MaybeBinaryValue{Value: string(e.Raw)}) out = append(out, MaybeBinaryValue{Value: string(e.Raw)})
} }

View File

@ -47,7 +47,7 @@ func TestDNSQueryIPOfType(t *testing.T) {
output: false, output: false,
}} }}
for _, exp := range expectations { for _, exp := range expectations {
if exp.qtype.ipoftype(exp.ip) != exp.output { if exp.qtype.ipOfType(exp.ip) != exp.output {
t.Fatalf("failure for %+v", exp) t.Fatalf("failure for %+v", exp)
} }
} }

View File

@ -1,5 +1,9 @@
package tracex package tracex
//
// TCP and connected UDP sockets
//
import ( import (
"context" "context"
"net" "net"
@ -11,7 +15,10 @@ import (
// SaverDialer saves events occurring during the dial // SaverDialer saves events occurring during the dial
type SaverDialer struct { type SaverDialer struct {
model.Dialer // Dialer is the underlying dialer,
Dialer model.Dialer
// Saver saves events.
Saver *Saver Saver *Saver
} }
@ -31,10 +38,17 @@ func (d *SaverDialer) DialContext(ctx context.Context, network, address string)
return conn, err return conn, err
} }
func (d *SaverDialer) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// SaverConnDialer wraps the returned connection such that we // SaverConnDialer wraps the returned connection such that we
// collect all the read/write events that occur. // collect all the read/write events that occur.
type SaverConnDialer struct { type SaverConnDialer struct {
model.Dialer // Dialer is the underlying dialer
Dialer model.Dialer
// Saver saves events
Saver *Saver Saver *Saver
} }
@ -47,6 +61,10 @@ func (d *SaverConnDialer) DialContext(ctx context.Context, network, address stri
return &saverConn{saver: d.Saver, Conn: conn}, nil return &saverConn{saver: d.Saver, Conn: conn}, nil
} }
func (d *SaverConnDialer) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
type saverConn struct { type saverConn struct {
net.Conn net.Conn
saver *Saver saver *Saver

View File

@ -1,2 +1,8 @@
// Package tracex contains code to perform measurements using tracing. // Package tracex performs measurements using tracing. To use tracing means
// that we'll wrap netx data types (e.g., a Dialer) with equivalent data types
// saving events into a Saver data struture. Then we will use the data types
// normally (e.g., call the Dialer's DialContet method and then use the
// resulting connection). When done, we will extract the trace containing
// all the events that occurred from the saver and process it to determine
// what happened during the measurement itself.
package tracex package tracex

View File

@ -1,9 +1,7 @@
package tracex package tracex
import ( import (
"crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"net/http" "net/http"
"time" "time"
) )
@ -36,25 +34,3 @@ type Event struct {
Time time.Time `json:",omitempty"` Time time.Time `json:",omitempty"`
Transport string `json:",omitempty"` Transport string `json:",omitempty"`
} }
// PeerCerts returns the certificates presented by the peer regardless
// of whether the TLS handshake was successful
func PeerCerts(state tls.ConnectionState, err error) []*x509.Certificate {
var x509HostnameError x509.HostnameError
if errors.As(err, &x509HostnameError) {
// Test case: https://wrong.host.badssl.com/
return []*x509.Certificate{x509HostnameError.Certificate}
}
var x509UnknownAuthorityError x509.UnknownAuthorityError
if errors.As(err, &x509UnknownAuthorityError) {
// Test case: https://self-signed.badssl.com/. This error has
// never been among the ones returned by MK.
return []*x509.Certificate{x509UnknownAuthorityError.Cert}
}
var x509CertificateInvalidError x509.CertificateInvalidError
if errors.As(err, &x509CertificateInvalidError) {
// Test case: https://expired.badssl.com/
return []*x509.Certificate{x509CertificateInvalidError.Cert}
}
return state.PeerCertificates
}

View File

@ -1,5 +1,9 @@
package tracex package tracex
//
// HTTP
//
import ( import (
"bytes" "bytes"
"context" "context"
@ -21,7 +25,7 @@ type SaverMetadataHTTPTransport struct {
// RoundTrip implements RoundTripper.RoundTrip // RoundTrip implements RoundTripper.RoundTrip
func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
txp.Saver.Write(Event{ txp.Saver.Write(Event{
HTTPHeaders: txp.CloneHeaders(req), HTTPHeaders: httpCloneHeaders(req),
HTTPMethod: req.Method, HTTPMethod: req.Method,
HTTPURL: req.URL.String(), HTTPURL: req.URL.String(),
Transport: txp.HTTPTransport.Network(), Transport: txp.HTTPTransport.Network(),
@ -41,10 +45,10 @@ func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Respon
return resp, err return resp, err
} }
// CloneHeaders returns a clone of the headers where we have // httpCCloneHeaders returns a clone of the headers where we have
// also set the host header, which normally is not set by // also set the host header, which normally is not set by
// golang until it serializes the request itself. // golang until it serializes the request itself.
func (txp SaverMetadataHTTPTransport) CloneHeaders(req *http.Request) http.Header { func httpCloneHeaders(req *http.Request) http.Header {
header := req.Header.Clone() header := req.Header.Clone()
if req.Host != "" { if req.Host != "" {
header.Set("Host", req.Host) header.Set("Host", req.Host)
@ -92,11 +96,11 @@ func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response,
snapsize = txp.SnapshotSize snapsize = txp.SnapshotSize
} }
if req.Body != nil { if req.Body != nil {
data, err := saverSnapRead(req.Context(), req.Body, snapsize) data, err := httpSaverSnapRead(req.Context(), req.Body, snapsize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Body = saverCompose(data, req.Body) req.Body = httpSaverCompose(data, req.Body)
txp.Saver.Write(Event{ txp.Saver.Write(Event{
DataIsTruncated: len(data) >= snapsize, DataIsTruncated: len(data) >= snapsize,
Data: data, Data: data,
@ -108,12 +112,12 @@ func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response,
if err != nil { if err != nil {
return nil, err return nil, err
} }
data, err := saverSnapRead(req.Context(), resp.Body, snapsize) data, err := httpSaverSnapRead(req.Context(), resp.Body, snapsize)
if err != nil { if err != nil {
resp.Body.Close() resp.Body.Close()
return nil, err return nil, err
} }
resp.Body = saverCompose(data, resp.Body) resp.Body = httpSaverCompose(data, resp.Body)
txp.Saver.Write(Event{ txp.Saver.Write(Event{
DataIsTruncated: len(data) >= snapsize, DataIsTruncated: len(data) >= snapsize,
Data: data, Data: data,
@ -123,15 +127,15 @@ func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response,
return resp, nil return resp, nil
} }
func saverSnapRead(ctx context.Context, r io.ReadCloser, snapsize int) ([]byte, error) { func httpSaverSnapRead(ctx context.Context, r io.ReadCloser, snapsize int) ([]byte, error) {
return netxlite.ReadAllContext(ctx, io.LimitReader(r, int64(snapsize))) return netxlite.ReadAllContext(ctx, io.LimitReader(r, int64(snapsize)))
} }
func saverCompose(data []byte, r io.ReadCloser) io.ReadCloser { func httpSaverCompose(data []byte, r io.ReadCloser) io.ReadCloser {
return saverReadCloser{Closer: r, Reader: io.MultiReader(bytes.NewReader(data), r)} return httpSaverReadCloser{Closer: r, Reader: io.MultiReader(bytes.NewReader(data), r)}
} }
type saverReadCloser struct { type httpSaverReadCloser struct {
io.Closer io.Closer
io.Reader io.Reader
} }

View File

@ -394,8 +394,7 @@ func TestCloneHeaders(t *testing.T) {
}, },
Header: http.Header{}, Header: http.Header{},
} }
txp := SaverMetadataHTTPTransport{} header := httpCloneHeaders(req)
header := txp.CloneHeaders(req)
if header.Get("Host") != "www.example.com" { if header.Get("Host") != "www.example.com" {
t.Fatal("did not set Host header correctly") t.Fatal("did not set Host header correctly")
} }
@ -409,8 +408,7 @@ func TestCloneHeaders(t *testing.T) {
}, },
Header: http.Header{}, Header: http.Header{},
} }
txp := SaverMetadataHTTPTransport{} header := httpCloneHeaders(req)
header := txp.CloneHeaders(req)
if header.Get("Host") != "www.kernel.org" { if header.Get("Host") != "www.kernel.org" {
t.Fatal("did not set Host header correctly") t.Fatal("did not set Host header correctly")
} }

View File

@ -1,5 +1,9 @@
package tracex package tracex
//
// QUIC
//
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
@ -11,14 +15,32 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
// QUICHandshakeSaver saves events occurring during the handshake // QUICHandshakeSaver saves events occurring during the QUIC handshake.
type QUICHandshakeSaver struct { type QUICHandshakeSaver struct {
// QUICDialer is the wrapped dialer
QUICDialer model.QUICDialer
// Saver saves events
Saver *Saver Saver *Saver
model.QUICDialer
} }
// DialContext implements ContextDialer.DialContext // WrapQUICDialer wraps a model.QUICDialer with a QUICHandshakeSaver that will
func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string, // save the QUIC handshake results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original QUICDialer without any wrapping.
func (s *Saver) WrapQUICDialer(qd model.QUICDialer) model.QUICDialer {
if s == nil {
return qd
}
return &QUICHandshakeSaver{
QUICDialer: qd,
Saver: s,
}
}
// DialContext implements QUICDialer.DialContext
func (h *QUICHandshakeSaver) DialContext(ctx context.Context, network string,
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
start := time.Now() start := time.Now()
// TODO(bassosimone): in the future we probably want to also save // TODO(bassosimone): in the future we probably want to also save
@ -35,6 +57,7 @@ func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string,
sess, err := h.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg) sess, err := h.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
stop := time.Now() stop := time.Now()
if err != nil { if err != nil {
// TODO(bassosimone): here we should save the peer certs
h.Saver.Write(Event{ h.Saver.Write(Event{
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
@ -54,7 +77,7 @@ func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string,
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol, TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: tlsCfg.NextProtos, TLSNextProtos: tlsCfg.NextProtos,
TLSPeerCerts: PeerCerts(state, err), TLSPeerCerts: tlsPeerCerts(state, err),
TLSServerName: tlsCfg.ServerName, TLSServerName: tlsCfg.ServerName,
TLSVersion: netxlite.TLSVersionString(state.Version), TLSVersion: netxlite.TLSVersionString(state.Version),
Time: stop, Time: stop,
@ -62,6 +85,10 @@ func (h QUICHandshakeSaver) DialContext(ctx context.Context, network string,
return sess, nil return sess, nil
} }
func (h *QUICHandshakeSaver) CloseIdleConnections() {
h.QUICDialer.CloseIdleConnections()
}
// quicConnectionState returns the ConnectionState of a QUIC Session. // quicConnectionState returns the ConnectionState of a QUIC Session.
func quicConnectionState(sess quic.EarlyConnection) tls.ConnectionState { func quicConnectionState(sess quic.EarlyConnection) tls.ConnectionState {
return sess.ConnectionState().TLS.ConnectionState return sess.ConnectionState().TLS.ConnectionState
@ -70,32 +97,50 @@ func quicConnectionState(sess quic.EarlyConnection) tls.ConnectionState {
// QUICListenerSaver is a QUICListener that also implements saving events. // QUICListenerSaver is a QUICListener that also implements saving events.
type QUICListenerSaver struct { type QUICListenerSaver struct {
// QUICListener is the underlying QUICListener. // QUICListener is the underlying QUICListener.
model.QUICListener QUICListener model.QUICListener
// Saver is the underlying Saver. // Saver is the underlying Saver.
Saver *Saver Saver *Saver
} }
// WrapQUICListener wraps a model.QUICDialer with a QUICListenerSaver that will
// save the QUIC I/O packet conn events into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original QUICListener without any wrapping.
func (s *Saver) WrapQUICListener(ql model.QUICListener) model.QUICListener {
if s == nil {
return ql
}
return &QUICListenerSaver{
QUICListener: ql,
Saver: s,
}
}
// Listen implements QUICListener.Listen. // Listen implements QUICListener.Listen.
func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) { func (qls *QUICListenerSaver) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr) pconn, err := qls.QUICListener.Listen(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &saverUDPConn{ pconn = &udpLikeConnSaver{
UDPLikeConn: pconn, UDPLikeConn: pconn,
saver: qls.Saver, saver: qls.Saver,
}, nil }
return pconn, nil
} }
type saverUDPConn struct { // udpLikeConnSaver saves I/O events
type udpLikeConnSaver struct {
// UDPLikeConn is the wrapped underlying conn
model.UDPLikeConn model.UDPLikeConn
// Saver saves events
saver *Saver saver *Saver
} }
var _ model.UDPLikeConn = &saverUDPConn{} func (c *udpLikeConnSaver) WriteTo(p []byte, addr net.Addr) (int, error) {
func (c *saverUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) {
start := time.Now() start := time.Now()
count, err := c.UDPLikeConn.WriteTo(p, addr) count, err := c.UDPLikeConn.WriteTo(p, addr)
stop := time.Now() stop := time.Now()
@ -111,7 +156,7 @@ func (c *saverUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return count, err return count, err
} }
func (c *saverUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *udpLikeConnSaver) ReadFrom(b []byte) (int, net.Addr, error) {
start := time.Now() start := time.Now()
n, addr, err := c.UDPLikeConn.ReadFrom(b) n, addr, err := c.UDPLikeConn.ReadFrom(b)
stop := time.Now() stop := time.Now()
@ -131,9 +176,13 @@ func (c *saverUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
return n, addr, err return n, addr, err
} }
func (c *saverUDPConn) safeAddrString(addr net.Addr) (out string) { func (c *udpLikeConnSaver) safeAddrString(addr net.Addr) (out string) {
if addr != nil { if addr != nil {
out = addr.String() out = addr.String()
} }
return return
} }
var _ model.QUICDialer = &QUICHandshakeSaver{}
var _ model.QUICListener = &QUICListenerSaver{}
var _ model.UDPLikeConn = &udpLikeConnSaver{}

View File

@ -39,12 +39,9 @@ func TestHandshakeSaverSuccess(t *testing.T) {
ServerName: servername, ServerName: servername,
} }
saver := &Saver{} saver := &Saver{}
dlr := QUICHandshakeSaver{ dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{
QUICDialer: &netxlite.QUICDialerQUICGo{ QUICListener: &netxlite.QUICListenerStdlib{},
QUICListener: &netxlite.QUICListenerStdlib{}, })
},
Saver: saver,
}
sess, err := dlr.DialContext(context.Background(), "udp", sess, err := dlr.DialContext(context.Background(), "udp",
quictesting.Endpoint("443"), tlsConf, &quic.Config{}) quictesting.Endpoint("443"), tlsConf, &quic.Config{})
if err != nil { if err != nil {
@ -97,12 +94,9 @@ func TestHandshakeSaverHostNameError(t *testing.T) {
ServerName: servername, ServerName: servername,
} }
saver := &Saver{} saver := &Saver{}
dlr := QUICHandshakeSaver{ dlr := saver.WrapQUICDialer(&netxlite.QUICDialerQUICGo{
QUICDialer: &netxlite.QUICDialerQUICGo{ QUICListener: &netxlite.QUICListenerStdlib{},
QUICListener: &netxlite.QUICListenerStdlib{}, })
},
Saver: saver,
}
sess, err := dlr.DialContext(context.Background(), "udp", sess, err := dlr.DialContext(context.Background(), "udp",
quictesting.Endpoint("443"), tlsConf, &quic.Config{}) quictesting.Endpoint("443"), tlsConf, &quic.Config{})
if err == nil { if err == nil {
@ -126,14 +120,12 @@ func TestHandshakeSaverHostNameError(t *testing.T) {
func TestQUICListenerSaverCannotListen(t *testing.T) { func TestQUICListenerSaverCannotListen(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
qls := &QUICListenerSaver{ saver := &Saver{}
QUICListener: &mocks.QUICListener{ qls := saver.WrapQUICListener(&mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
return nil, expected return nil, expected
},
}, },
Saver: &Saver{}, })
}
pconn, err := qls.Listen(&net.UDPAddr{ pconn, err := qls.Listen(&net.UDPAddr{
IP: []byte{}, IP: []byte{},
Port: 8080, Port: 8080,
@ -155,10 +147,7 @@ func TestSystemDialerSuccessWithReadWrite(t *testing.T) {
} }
saver := &Saver{} saver := &Saver{}
systemdialer := &netxlite.QUICDialerQUICGo{ systemdialer := &netxlite.QUICDialerQUICGo{
QUICListener: &QUICListenerSaver{ QUICListener: saver.WrapQUICListener(&netxlite.QUICListenerStdlib{}),
QUICListener: &netxlite.QUICListenerStdlib{},
Saver: saver,
},
} }
_, err := systemdialer.DialContext(context.Background(), "udp", _, err := systemdialer.DialContext(context.Background(), "udp",
quictesting.Endpoint("443"), tlsConf, &quic.Config{}) quictesting.Endpoint("443"), tlsConf, &quic.Config{})

View File

@ -1,20 +1,43 @@
package tracex package tracex
//
// DNS lookup and round trip
//
import ( import (
"context" "context"
"net"
"time" "time"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
// SaverResolver is a resolver that saves events // SaverResolver is a resolver that saves events.
type SaverResolver struct { type SaverResolver struct {
model.Resolver // Resolver is the underlying resolver.
Resolver model.Resolver
// Saver saves events.
Saver *Saver Saver *Saver
} }
// WrapResolver wraps a model.Resolver with a SaverResolver that will save
// the DNS lookup results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original Resolver without any wrapping.
func (s *Saver) WrapResolver(r model.Resolver) model.Resolver {
if s == nil {
return r
}
return &SaverResolver{
Resolver: r,
Saver: s,
}
}
// LookupHost implements Resolver.LookupHost // LookupHost implements Resolver.LookupHost
func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { func (r *SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
start := time.Now() start := time.Now()
r.Saver.Write(Event{ r.Saver.Write(Event{
Address: r.Resolver.Address(), Address: r.Resolver.Address(),
@ -38,49 +61,105 @@ func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]strin
return addrs, err return addrs, err
} }
// SaverDNSTransport is a DNS transport that saves events func (r *SaverResolver) Network() string {
return r.Resolver.Network()
}
func (r *SaverResolver) Address() string {
return r.Resolver.Address()
}
func (r *SaverResolver) CloseIdleConnections() {
r.Resolver.CloseIdleConnections()
}
func (r *SaverResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
// TODO(bassosimone): we should probably implement this method
return r.Resolver.LookupHTTPS(ctx, domain)
}
func (r *SaverResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
// TODO(bassosimone): we should probably implement this method
return r.Resolver.LookupNS(ctx, domain)
}
// SaverDNSTransport is a DNS transport that saves events.
type SaverDNSTransport struct { type SaverDNSTransport struct {
model.DNSTransport // DNSTransport is the underlying DNS transport.
DNSTransport model.DNSTransport
// Saver saves events.
Saver *Saver Saver *Saver
} }
// WrapDNSTransport wraps a model.DNSTransport with a SaverDNSTransport that
// will save the DNS round trip results into this Saver.
//
// When this function is invoked on a nil Saver, it will directly return
// the original DNSTransport without any wrapping.
func (s *Saver) WrapDNSTransport(txp model.DNSTransport) model.DNSTransport {
if s == nil {
return txp
}
return &SaverDNSTransport{
DNSTransport: txp,
Saver: s,
}
}
// RoundTrip implements RoundTripper.RoundTrip // RoundTrip implements RoundTripper.RoundTrip
func (txp SaverDNSTransport) RoundTrip( func (txp *SaverDNSTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
start := time.Now() start := time.Now()
txp.Saver.Write(Event{ txp.Saver.Write(Event{
Address: txp.Address(), Address: txp.DNSTransport.Address(),
DNSQuery: txp.maybeQueryBytes(query), DNSQuery: dnsMaybeQueryBytes(query),
Name: "dns_round_trip_start", Name: "dns_round_trip_start",
Proto: txp.Network(), Proto: txp.DNSTransport.Network(),
Time: start, Time: start,
}) })
response, err := txp.DNSTransport.RoundTrip(ctx, query) response, err := txp.DNSTransport.RoundTrip(ctx, query)
stop := time.Now() stop := time.Now()
txp.Saver.Write(Event{ txp.Saver.Write(Event{
Address: txp.Address(), Address: txp.DNSTransport.Address(),
DNSQuery: txp.maybeQueryBytes(query), DNSQuery: dnsMaybeQueryBytes(query),
DNSReply: txp.maybeResponseBytes(response), DNSReply: dnsMaybeResponseBytes(response),
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
Name: "dns_round_trip_done", Name: "dns_round_trip_done",
Proto: txp.Network(), Proto: txp.DNSTransport.Network(),
Time: stop, Time: stop,
}) })
return response, err return response, err
} }
func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte { func (txp *SaverDNSTransport) Network() string {
return txp.DNSTransport.Network()
}
func (txp *SaverDNSTransport) Address() string {
return txp.DNSTransport.Address()
}
func (txp *SaverDNSTransport) CloseIdleConnections() {
txp.DNSTransport.CloseIdleConnections()
}
func (txp *SaverDNSTransport) RequiresPadding() bool {
return txp.DNSTransport.RequiresPadding()
}
func dnsMaybeQueryBytes(query model.DNSQuery) []byte {
data, _ := query.Bytes() data, _ := query.Bytes()
return data return data
} }
func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte { func dnsMaybeResponseBytes(response model.DNSResponse) []byte {
if response == nil { if response == nil {
return nil return nil
} }
return response.Bytes() return response.Bytes()
} }
var _ model.Resolver = SaverResolver{} var _ model.Resolver = &SaverResolver{}
var _ model.DNSTransport = SaverDNSTransport{} var _ model.DNSTransport = &SaverDNSTransport{}

View File

@ -17,10 +17,7 @@ import (
func TestSaverResolverFailure(t *testing.T) { func TestSaverResolverFailure(t *testing.T) {
expected := errors.New("no such host") expected := errors.New("no such host")
saver := &Saver{} saver := &Saver{}
reso := SaverResolver{ reso := saver.WrapResolver(NewFakeResolverWithExplicitError(expected))
Resolver: NewFakeResolverWithExplicitError(expected),
Saver: saver,
}
addrs, err := reso.LookupHost(context.Background(), "www.google.com") addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -64,10 +61,7 @@ func TestSaverResolverFailure(t *testing.T) {
func TestSaverResolverSuccess(t *testing.T) { func TestSaverResolverSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"} expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &Saver{} saver := &Saver{}
reso := SaverResolver{ reso := saver.WrapResolver(NewFakeResolverWithResult(expected))
Resolver: NewFakeResolverWithResult(expected),
Saver: saver,
}
addrs, err := reso.LookupHost(context.Background(), "www.google.com") addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if err != nil { if err != nil {
t.Fatal("expected nil error here") t.Fatal("expected nil error here")
@ -111,20 +105,17 @@ func TestSaverResolverSuccess(t *testing.T) {
func TestSaverDNSTransportFailure(t *testing.T) { func TestSaverDNSTransportFailure(t *testing.T) {
expected := errors.New("no such host") expected := errors.New("no such host")
saver := &Saver{} saver := &Saver{}
txp := SaverDNSTransport{ txp := saver.WrapDNSTransport(&mocks.DNSTransport{
DNSTransport: &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected
return nil, expected
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
}, },
Saver: saver, MockNetwork: func() string {
} return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{ query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) { MockBytes: func() ([]byte, error) {
@ -179,20 +170,17 @@ func TestSaverDNSTransportSuccess(t *testing.T) {
return expected return expected
}, },
} }
txp := SaverDNSTransport{ txp := saver.WrapDNSTransport(&mocks.DNSTransport{
DNSTransport: &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return response, nil
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
}, },
Saver: saver, MockNetwork: func() string {
} return "fake"
},
MockAddress: func() string {
return ""
},
})
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{ query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) { MockBytes: func() ([]byte, error) {

View File

@ -1,11 +1,19 @@
package tracex package tracex
//
// Saver implementation
//
import "sync" import "sync"
// The Saver saves a trace // The Saver saves a trace. The zero value of this type
// is valid and can be used without initializtion.
type Saver struct { type Saver struct {
// ops contains the saved events.
ops []Event ops []Event
mu sync.Mutex
// mu provides mutual exclusion.
mu sync.Mutex
} }
// Read reads and returns events inside the trace. It advances // Read reads and returns events inside the trace. It advances

View File

@ -5,7 +5,7 @@ import (
"testing" "testing"
) )
func TestGood(t *testing.T) { func TestSaver(t *testing.T) {
saver := Saver{} saver := Saver{}
var wg sync.WaitGroup var wg sync.WaitGroup
const parallel = 10 const parallel = 10

View File

@ -1,8 +1,14 @@
package tracex package tracex
//
// TLS
//
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors"
"net" "net"
"time" "time"
@ -10,16 +16,33 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
// SaverTLSHandshaker saves events occurring during the handshake // SaverTLSHandshaker saves events occurring during the TLS handshake.
type SaverTLSHandshaker struct { type SaverTLSHandshaker struct {
model.TLSHandshaker // TLSHandshaker is the underlying TLS handshaker.
TLSHandshaker model.TLSHandshaker
// Saver is the saver in which to save events.
Saver *Saver Saver *Saver
} }
// Handshake implements TLSHandshaker.Handshake // WrapTLSHandshaker wraps a model.TLSHandshaker with a SaverTLSHandshaker
func (h SaverTLSHandshaker) Handshake( // that will save the TLS handshake results into this Saver.
ctx context.Context, conn net.Conn, config *tls.Config, //
) (net.Conn, tls.ConnectionState, error) { // When this function is invoked on a nil Saver, it will directly return
// the original TLSHandshaker without any wrapping.
func (s *Saver) WrapTLSHandshaker(thx model.TLSHandshaker) model.TLSHandshaker {
if s == nil {
return thx
}
return &SaverTLSHandshaker{
TLSHandshaker: thx,
Saver: s,
}
}
// Handshake implements model.TLSHandshaker.Handshake
func (h *SaverTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
start := time.Now() start := time.Now()
h.Saver.Write(Event{ h.Saver.Write(Event{
Name: "tls_handshake_start", Name: "tls_handshake_start",
@ -40,7 +63,7 @@ func (h SaverTLSHandshaker) Handshake(
TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite), TLSCipherSuite: netxlite.TLSCipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol, TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: config.NextProtos, TLSNextProtos: config.NextProtos,
TLSPeerCerts: PeerCerts(state, err), TLSPeerCerts: tlsPeerCerts(state, err),
TLSServerName: config.ServerName, TLSServerName: config.ServerName,
TLSVersion: netxlite.TLSVersionString(state.Version), TLSVersion: netxlite.TLSVersionString(state.Version),
Time: stop, Time: stop,
@ -48,4 +71,26 @@ func (h SaverTLSHandshaker) Handshake(
return tlsconn, state, err return tlsconn, state, err
} }
var _ model.TLSHandshaker = SaverTLSHandshaker{} var _ model.TLSHandshaker = &SaverTLSHandshaker{}
// tlsPeerCerts returns the certificates presented by the peer regardless
// of whether the TLS handshake was successful
func tlsPeerCerts(state tls.ConnectionState, err error) []*x509.Certificate {
var x509HostnameError x509.HostnameError
if errors.As(err, &x509HostnameError) {
// Test case: https://wrong.host.badssl.com/
return []*x509.Certificate{x509HostnameError.Certificate}
}
var x509UnknownAuthorityError x509.UnknownAuthorityError
if errors.As(err, &x509UnknownAuthorityError) {
// Test case: https://self-signed.badssl.com/. This error has
// never been among the ones returned by MK.
return []*x509.Certificate{x509UnknownAuthorityError.Cert}
}
var x509CertificateInvalidError x509.CertificateInvalidError
if errors.As(err, &x509CertificateInvalidError) {
// Test case: https://expired.badssl.com/
return []*x509.Certificate{x509CertificateInvalidError.Cert}
}
return state.PeerCertificates
}

View File

@ -30,10 +30,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
} }
}, },
), ),
TLSHandshaker: SaverTLSHandshaker{ TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
} }
// Implementation note: we don't close the connection here because it is // Implementation note: we don't close the connection here because it is
// very handy to have the last event being the end of the handshake // very handy to have the last event being the end of the handshake
@ -121,12 +118,9 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
nextprotos := []string{"h2"} nextprotos := []string{"h2"}
saver := &Saver{} saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{ tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{NextProtos: nextprotos}, Config: &tls.Config{NextProtos: nextprotos},
Dialer: netxlite.DefaultDialer, Dialer: &netxlite.DialerSystem{},
TLSHandshaker: SaverTLSHandshaker{ TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
} }
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
if err != nil { if err != nil {
@ -187,11 +181,8 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
} }
saver := &Saver{} saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{ tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer, Dialer: &netxlite.DialerSystem{},
TLSHandshaker: SaverTLSHandshaker{ TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
} }
conn, err := tlsdlr.DialTLSContext( conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "wrong.host.badssl.com:443") context.Background(), "tcp", "wrong.host.badssl.com:443")
@ -220,11 +211,8 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
} }
saver := &Saver{} saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{ tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer, Dialer: &netxlite.DialerSystem{},
TLSHandshaker: SaverTLSHandshaker{ TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
} }
conn, err := tlsdlr.DialTLSContext( conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "expired.badssl.com:443") context.Background(), "tcp", "expired.badssl.com:443")
@ -253,11 +241,8 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
} }
saver := &Saver{} saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{ tlsdlr := &netxlite.TLSDialerLegacy{
Dialer: netxlite.DefaultDialer, Dialer: &netxlite.DialerSystem{},
TLSHandshaker: SaverTLSHandshaker{ TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
} }
conn, err := tlsdlr.DialTLSContext( conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443") context.Background(), "tcp", "self-signed.badssl.com:443")
@ -286,12 +271,9 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
} }
saver := &Saver{} saver := &Saver{}
tlsdlr := &netxlite.TLSDialerLegacy{ tlsdlr := &netxlite.TLSDialerLegacy{
Config: &tls.Config{InsecureSkipVerify: true}, Config: &tls.Config{InsecureSkipVerify: true},
Dialer: netxlite.DefaultDialer, Dialer: &netxlite.DialerSystem{},
TLSHandshaker: SaverTLSHandshaker{ TLSHandshaker: saver.WrapTLSHandshaker(&netxlite.TLSHandshakerConfigurable{}),
TLSHandshaker: &netxlite.TLSHandshakerConfigurable{},
Saver: saver,
},
} }
conn, err := tlsdlr.DialTLSContext( conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443") context.Background(), "tcp", "self-signed.badssl.com:443")