chore: merge probe-engine into probe-cli (#201)

This is how I did it:

1. `git clone https://github.com/ooni/probe-engine internal/engine`

2. ```
(cd internal/engine && git describe --tags)
v0.23.0
```

3. `nvim go.mod` (merging `go.mod` with `internal/engine/go.mod`

4. `rm -rf internal/.git internal/engine/go.{mod,sum}`

5. `git add internal/engine`

6. `find . -type f -name \*.go -exec sed -i 's@/ooni/probe-engine@/ooni/probe-cli/v3/internal/engine@g' {} \;`

7. `go build ./...` (passes)

8. `go test -race ./...` (temporary failure on RiseupVPN)

9. `go mod tidy`

10. this commit message

Once this piece of work is done, we can build a new version of `ooniprobe` that
is using `internal/engine` directly. We need to do more work to ensure all the
other functionality in `probe-engine` (e.g. making mobile packages) are still WAI.

Part of https://github.com/ooni/probe/issues/1335
This commit is contained in:
Simone Basso
2021-02-02 12:05:47 +01:00
committed by GitHub
parent b1ce300c8d
commit d57c78bc71
535 changed files with 66182 additions and 23 deletions
+23
View File
@@ -0,0 +1,23 @@
# Package github.com/ooni/probe-engine/netx
OONI extensions to the `net` and `net/http` packages. This code is
used by `ooni/probe-engine` as a low level library to collect
network, DNS, and HTTP events occurring during OONI measurements.
This library contains replacements for commonly used standard library
interfaces that facilitate seamless network measurements. By using
such replacements, as opposed to standard library interfaces, we can:
* save the timing of HTTP events (e.g. received response headers)
* save the timing and result of every Connect, Read, Write, Close operation
* save the timing and result of the TLS handshake (including certificates)
By default, this library uses the system resolver. In addition, it
is possible to configure alternative DNS transports and remote
servers. We support DNS over UDP, DNS over TCP, DNS over TLS (DoT),
and DNS over HTTPS (DoH). When using an alternative transport, we
are also able to intercept and save DNS messages, as well as any
other interaction with the remote server (e.g., the result of the
TLS handshake for DoT and DoH).
This package is a fork of [github.com/ooni/netx](https://github.com/ooni/netx).
+580
View File
@@ -0,0 +1,580 @@
// Package archival contains data formats used for archival.
//
// See https://github.com/ooni/spec.
package archival
import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"net"
"net/http"
"sort"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
"github.com/ooni/probe-cli/v3/internal/engine/model"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// ExtSpec describes a data format extension
type ExtSpec struct {
Name string // extension name
V int64 // extension version
}
// AddTo adds the current ExtSpec to the specified measurement
func (spec ExtSpec) AddTo(m *model.Measurement) {
if m.Extensions == nil {
m.Extensions = make(map[string]int64)
}
m.Extensions[spec.Name] = spec.V
}
var (
// ExtDNS is the version of df-002-dnst.md
ExtDNS = ExtSpec{Name: "dnst", V: 0}
// ExtNetevents is the version of df-008-netevents.md
ExtNetevents = ExtSpec{Name: "netevents", V: 0}
// ExtHTTP is the version of df-001-httpt.md
ExtHTTP = ExtSpec{Name: "httpt", V: 0}
// ExtTCPConnect is the version of df-005-tcpconnect.md
ExtTCPConnect = ExtSpec{Name: "tcpconnect", V: 0}
// ExtTLSHandshake is the version of df-006-tlshandshake.md
ExtTLSHandshake = ExtSpec{Name: "tlshandshake", V: 0}
// ExtTunnel is the version of df-009-tunnel.md
ExtTunnel = ExtSpec{Name: "tunnel", V: 0}
)
// TCPConnectStatus contains the TCP connect status.
//
// The Blocked field breaks the separation between measurement and analysis
// we have been enforcing for quite some time now. It is a legacy from the
// Web Connectivity experiment and it should be here because of that.
type TCPConnectStatus struct {
Blocked *bool `json:"blocked,omitempty"` // Web Connectivity only
Failure *string `json:"failure"`
Success bool `json:"success"`
}
// TCPConnectEntry contains one of the entries that are part
// of the "tcp_connect" key of a OONI report.
type TCPConnectEntry struct {
ConnID int64 `json:"conn_id,omitempty"`
DialID int64 `json:"dial_id,omitempty"`
IP string `json:"ip"`
Port int `json:"port"`
Status TCPConnectStatus `json:"status"`
T float64 `json:"t"`
TransactionID int64 `json:"transaction_id,omitempty"`
}
// NewTCPConnectList creates a new TCPConnectList
func NewTCPConnectList(begin time.Time, events []trace.Event) []TCPConnectEntry {
var out []TCPConnectEntry
for _, event := range events {
if event.Name != errorx.ConnectOperation {
continue
}
if event.Proto != "tcp" {
continue
}
// We assume Go is passing us legit data structures
ip, sport, _ := net.SplitHostPort(event.Address)
iport, _ := strconv.Atoi(sport)
out = append(out, TCPConnectEntry{
IP: ip,
Port: iport,
Status: TCPConnectStatus{
Failure: NewFailure(event.Err),
Success: event.Err == nil,
},
T: event.Time.Sub(begin).Seconds(),
})
}
return out
}
// NewFailure creates a failure nullable string from the given error
func NewFailure(err error) *string {
if err == nil {
return nil
}
// The following code guarantees that the error is always wrapped even
// when we could not actually hit our code that does the wrapping. A case
// in which this happen is with context deadline for HTTP.
err = errorx.SafeErrWrapperBuilder{
Error: err,
Operation: errorx.TopLevelOperation,
}.MaybeBuild()
errWrapper := err.(*errorx.ErrWrapper)
s := errWrapper.Failure
if s == "" {
s = "unknown_failure: errWrapper.Failure is empty"
}
return &s
}
// NewFailedOperation creates a failed operation string from the given error.
func NewFailedOperation(err error) *string {
if err == nil {
return nil
}
var (
errWrapper *errorx.ErrWrapper
s = errorx.UnknownOperation
)
if errors.As(err, &errWrapper) && errWrapper.Operation != "" {
s = errWrapper.Operation
}
return &s
}
// HTTPTor contains Tor information
type HTTPTor struct {
ExitIP *string `json:"exit_ip"`
ExitName *string `json:"exit_name"`
IsTor bool `json:"is_tor"`
}
// MaybeBinaryValue is a possibly binary string. We use this helper class
// to define a custom JSON encoder that allows us to choose the proper
// representation depending on whether the Value field is valid UTF-8 or not.
type MaybeBinaryValue struct {
Value string
}
// MarshalJSON marshals a string-like to JSON following the OONI spec that
// says that UTF-8 content is represened as string and non-UTF-8 content is
// instead represented using `{"format":"base64","data":"..."}`.
func (hb MaybeBinaryValue) MarshalJSON() ([]byte, error) {
if utf8.ValidString(hb.Value) {
return json.Marshal(hb.Value)
}
er := make(map[string]string)
er["format"] = "base64"
er["data"] = base64.StdEncoding.EncodeToString([]byte(hb.Value))
return json.Marshal(er)
}
// UnmarshalJSON is the opposite of MarshalJSON.
func (hb *MaybeBinaryValue) UnmarshalJSON(d []byte) error {
if err := json.Unmarshal(d, &hb.Value); err == nil {
return nil
}
er := make(map[string]string)
if err := json.Unmarshal(d, &er); err != nil {
return err
}
if v, ok := er["format"]; !ok || v != "base64" {
return errors.New("missing or invalid format field")
}
if _, ok := er["data"]; !ok {
return errors.New("missing data field")
}
b64, err := base64.StdEncoding.DecodeString(er["data"])
if err != nil {
return err
}
hb.Value = string(b64)
return nil
}
// HTTPBody is an HTTP body. As an implementation note, this type must be
// an alias for the MaybeBinaryValue type, otherwise the specific serialisation
// mechanism implemented by MaybeBinaryValue is not working.
type HTTPBody = MaybeBinaryValue
// HTTPHeader is a single HTTP header.
type HTTPHeader struct {
Key string
Value MaybeBinaryValue
}
// MarshalJSON marshals a single HTTP header to a tuple where the first
// element is a string and the second element is maybe-binary data.
func (hh HTTPHeader) MarshalJSON() ([]byte, error) {
if utf8.ValidString(hh.Value.Value) {
return json.Marshal([]string{hh.Key, hh.Value.Value})
}
value := make(map[string]string)
value["format"] = "base64"
value["data"] = base64.StdEncoding.EncodeToString([]byte(hh.Value.Value))
return json.Marshal([]interface{}{hh.Key, value})
}
// UnmarshalJSON is the opposite of MarshalJSON.
func (hh *HTTPHeader) UnmarshalJSON(d []byte) error {
var pair []interface{}
if err := json.Unmarshal(d, &pair); err != nil {
return err
}
if len(pair) != 2 {
return errors.New("unexpected pair length")
}
key, ok := pair[0].(string)
if !ok {
return errors.New("the key is not a string")
}
value, ok := pair[1].(string)
if !ok {
mapvalue, ok := pair[1].(map[string]interface{})
if !ok {
return errors.New("the value is neither a string nor a map[string]interface{}")
}
if _, ok := mapvalue["format"]; !ok {
return errors.New("missing format")
}
if v, ok := mapvalue["format"].(string); !ok || v != "base64" {
return errors.New("invalid format")
}
if _, ok := mapvalue["data"]; !ok {
return errors.New("missing data field")
}
v, ok := mapvalue["data"].(string)
if !ok {
return errors.New("the data field is not a string")
}
b64, err := base64.StdEncoding.DecodeString(v)
if err != nil {
return err
}
value = string(b64)
}
hh.Key, hh.Value = key, MaybeBinaryValue{Value: value}
return nil
}
// HTTPRequest contains an HTTP request.
//
// Headers are a map in Web Connectivity data format but
// we have added support for a list since January 2020.
type HTTPRequest struct {
Body HTTPBody `json:"body"`
BodyIsTruncated bool `json:"body_is_truncated"`
HeadersList []HTTPHeader `json:"headers_list"`
Headers map[string]MaybeBinaryValue `json:"headers"`
Method string `json:"method"`
Tor HTTPTor `json:"tor"`
Transport string `json:"x_transport"`
URL string `json:"url"`
}
// HTTPResponse contains an HTTP response.
//
// Headers are a map in Web Connectivity data format but
// we have added support for a list since January 2020.
type HTTPResponse struct {
Body HTTPBody `json:"body"`
BodyIsTruncated bool `json:"body_is_truncated"`
Code int64 `json:"code"`
HeadersList []HTTPHeader `json:"headers_list"`
Headers map[string]MaybeBinaryValue `json:"headers"`
// The following fields are not serialised but are useful to simplify
// analysing the measurements in telegram, whatsapp, etc.
Locations []string `json:"-"`
}
// RequestEntry is one of the entries that are part of
// the "requests" key of a OONI report.
type RequestEntry struct {
Failure *string `json:"failure"`
Request HTTPRequest `json:"request"`
Response HTTPResponse `json:"response"`
T float64 `json:"t"`
TransactionID int64 `json:"transaction_id,omitempty"`
}
func addheaders(
source http.Header,
destList *[]HTTPHeader,
destMap *map[string]MaybeBinaryValue,
) {
for key, values := range source {
for index, value := range values {
value := MaybeBinaryValue{Value: value}
// With the map representation we can only represent a single
// value for every key. Hence the list representation.
if index == 0 {
(*destMap)[key] = value
}
*destList = append(*destList, HTTPHeader{
Key: key,
Value: value,
})
}
}
sort.Slice(*destList, func(i, j int) bool {
return (*destList)[i].Key < (*destList)[j].Key
})
}
// NewRequestList returns the list for "requests"
func NewRequestList(begin time.Time, events []trace.Event) []RequestEntry {
// OONI wants the last request to appear first
var out []RequestEntry
tmp := newRequestList(begin, events)
for i := len(tmp) - 1; i >= 0; i-- {
out = append(out, tmp[i])
}
return out
}
func newRequestList(begin time.Time, events []trace.Event) []RequestEntry {
var (
out []RequestEntry
entry RequestEntry
)
for _, ev := range events {
switch ev.Name {
case "http_transaction_start":
entry = RequestEntry{}
entry.T = ev.Time.Sub(begin).Seconds()
case "http_request_body_snapshot":
entry.Request.Body.Value = string(ev.Data)
entry.Request.BodyIsTruncated = ev.DataIsTruncated
case "http_request_metadata":
entry.Request.Headers = make(map[string]MaybeBinaryValue)
addheaders(
ev.HTTPHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
entry.Request.Method = ev.HTTPMethod
entry.Request.URL = ev.HTTPURL
entry.Request.Transport = ev.Transport
case "http_response_metadata":
entry.Response.Headers = make(map[string]MaybeBinaryValue)
addheaders(
ev.HTTPHeaders, &entry.Response.HeadersList, &entry.Response.Headers)
entry.Response.Code = int64(ev.HTTPStatusCode)
entry.Response.Locations = ev.HTTPHeaders.Values("Location")
case "http_response_body_snapshot":
entry.Response.Body.Value = string(ev.Data)
entry.Response.BodyIsTruncated = ev.DataIsTruncated
case "http_transaction_done":
entry.Failure = NewFailure(ev.Err)
out = append(out, entry)
}
}
return out
}
// DNSAnswerEntry is the answer to a DNS query
type DNSAnswerEntry struct {
ASN int64 `json:"asn,omitempty"`
ASOrgName string `json:"as_org_name,omitempty"`
AnswerType string `json:"answer_type"`
Hostname string `json:"hostname,omitempty"`
IPv4 string `json:"ipv4,omitempty"`
IPv6 string `json:"ipv6,omitempty"`
TTL *uint32 `json:"ttl"`
}
// DNSQueryEntry is a DNS query with possibly an answer
type DNSQueryEntry struct {
Answers []DNSAnswerEntry `json:"answers"`
DialID int64 `json:"dial_id,omitempty"`
Engine string `json:"engine"`
Failure *string `json:"failure"`
Hostname string `json:"hostname"`
QueryType string `json:"query_type"`
ResolverHostname *string `json:"resolver_hostname"`
ResolverPort *string `json:"resolver_port"`
ResolverAddress string `json:"resolver_address"`
T float64 `json:"t"`
TransactionID int64 `json:"transaction_id,omitempty"`
}
type dnsQueryType string
// NewDNSQueriesList returns a list of DNS queries.
func NewDNSQueriesList(begin time.Time, events []trace.Event, dbpath string) []DNSQueryEntry {
// TODO(bassosimone): add support for CNAME lookups.
var out []DNSQueryEntry
for _, ev := range events {
if ev.Name != "resolve_done" {
continue
}
for _, qtype := range []dnsQueryType{"A", "AAAA"} {
entry := qtype.makequeryentry(begin, ev)
for _, addr := range ev.Addresses {
if qtype.ipoftype(addr) {
entry.Answers = append(
entry.Answers, qtype.makeanswerentry(addr, dbpath))
}
}
if len(entry.Answers) <= 0 && ev.Err == nil {
// This allows us to skip cases where the server does not have
// an IPv6 address but has an IPv4 address. Instead, when we
// receive an error, we want to track its existence. The main
// issue here is that we are cheating, because we are creating
// entries representing queries, but we don't know what the
// resolver actually did, especially the system resolver. So,
// this output is just our best guess.
continue
}
out = append(out, entry)
}
}
return out
}
func (qtype dnsQueryType) ipoftype(addr string) bool {
switch qtype {
case "A":
return strings.Contains(addr, ":") == false
case "AAAA":
return strings.Contains(addr, ":") == true
}
return false
}
func (qtype dnsQueryType) makeanswerentry(addr string, dbpath string) DNSAnswerEntry {
answer := DNSAnswerEntry{AnswerType: string(qtype)}
asn, org, _ := geolocate.LookupASN(dbpath, addr)
answer.ASN = int64(asn)
answer.ASOrgName = org
switch qtype {
case "A":
answer.IPv4 = addr
case "AAAA":
answer.IPv6 = addr
}
return answer
}
func (qtype dnsQueryType) makequeryentry(begin time.Time, ev trace.Event) DNSQueryEntry {
return DNSQueryEntry{
Engine: ev.Proto,
Failure: NewFailure(ev.Err),
Hostname: ev.Hostname,
QueryType: string(qtype),
ResolverAddress: ev.Address,
T: ev.Time.Sub(begin).Seconds(),
}
}
// NetworkEvent is a network event.
type NetworkEvent struct {
Address string `json:"address,omitempty"`
ConnID int64 `json:"conn_id,omitempty"`
DialID int64 `json:"dial_id,omitempty"`
Failure *string `json:"failure"`
NumBytes int64 `json:"num_bytes,omitempty"`
Operation string `json:"operation"`
Proto string `json:"proto,omitempty"`
T float64 `json:"t"`
TransactionID int64 `json:"transaction_id,omitempty"`
}
// NewNetworkEventsList returns a list of DNS queries.
func NewNetworkEventsList(begin time.Time, events []trace.Event) []NetworkEvent {
var out []NetworkEvent
for _, ev := range events {
if ev.Name == errorx.ConnectOperation {
out = append(out, NetworkEvent{
Address: ev.Address,
Failure: NewFailure(ev.Err),
Operation: ev.Name,
Proto: ev.Proto,
T: ev.Time.Sub(begin).Seconds(),
})
continue
}
if ev.Name == errorx.ReadOperation {
out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err),
Operation: ev.Name,
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
continue
}
if ev.Name == errorx.WriteOperation {
out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err),
Operation: ev.Name,
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
continue
}
if ev.Name == errorx.ReadFromOperation {
out = append(out, NetworkEvent{
Address: ev.Address,
Failure: NewFailure(ev.Err),
Operation: ev.Name,
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
continue
}
if ev.Name == errorx.WriteToOperation {
out = append(out, NetworkEvent{
Address: ev.Address,
Failure: NewFailure(ev.Err),
Operation: ev.Name,
NumBytes: int64(ev.NumBytes),
T: ev.Time.Sub(begin).Seconds(),
})
continue
}
out = append(out, NetworkEvent{
Failure: NewFailure(ev.Err),
Operation: ev.Name,
T: ev.Time.Sub(begin).Seconds(),
})
}
return out
}
// TLSHandshake contains TLS handshake data
type TLSHandshake struct {
CipherSuite string `json:"cipher_suite"`
ConnID int64 `json:"conn_id,omitempty"`
Failure *string `json:"failure"`
NegotiatedProtocol string `json:"negotiated_protocol"`
NoTLSVerify bool `json:"no_tls_verify"`
PeerCertificates []MaybeBinaryValue `json:"peer_certificates"`
ServerName string `json:"server_name"`
T float64 `json:"t"`
TLSVersion string `json:"tls_version"`
TransactionID int64 `json:"transaction_id,omitempty"`
}
// NewTLSHandshakesList creates a new TLSHandshakesList
func NewTLSHandshakesList(begin time.Time, events []trace.Event) []TLSHandshake {
var out []TLSHandshake
for _, ev := range events {
if !strings.Contains(ev.Name, "_handshake_done") {
continue
}
out = append(out, TLSHandshake{
CipherSuite: ev.TLSCipherSuite,
Failure: NewFailure(ev.Err),
NegotiatedProtocol: ev.TLSNegotiatedProto,
NoTLSVerify: ev.NoTLSVerify,
PeerCertificates: makePeerCerts(ev.TLSPeerCerts),
ServerName: ev.TLSServerName,
T: ev.Time.Sub(begin).Seconds(),
TLSVersion: ev.TLSVersion,
})
}
return out
}
func makePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
for _, e := range in {
out = append(out, MaybeBinaryValue{Value: string(e.Raw)})
}
return
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,8 @@
package archival
// DNSQueryType allows to access dnsQueryType from unit tests
type DNSQueryType = dnsQueryType
func (qtype dnsQueryType) IPOfType(addr string) bool {
return qtype.ipoftype(addr)
}
@@ -0,0 +1,54 @@
package bytecounter
import "github.com/ooni/probe-cli/v3/internal/engine/atomicx"
// Counter counts bytes sent and received.
type Counter struct {
Received *atomicx.Int64
Sent *atomicx.Int64
}
// New creates a new Counter.
func New() *Counter {
return &Counter{Received: atomicx.NewInt64(), Sent: atomicx.NewInt64()}
}
// CountBytesSent adds count to the bytes sent counter.
func (c *Counter) CountBytesSent(count int) {
c.Sent.Add(int64(count))
}
// CountKibiBytesSent adds 1024*count to the bytes sent counter.
func (c *Counter) CountKibiBytesSent(count float64) {
c.Sent.Add(int64(1024 * count))
}
// BytesSent returns the bytes sent so far.
func (c *Counter) BytesSent() int64 {
return c.Sent.Load()
}
// KibiBytesSent returns the KiB sent so far.
func (c *Counter) KibiBytesSent() float64 {
return float64(c.BytesSent()) / 1024
}
// CountBytesReceived adds count to the bytes received counter.
func (c *Counter) CountBytesReceived(count int) {
c.Received.Add(int64(count))
}
// CountKibiBytesReceived adds 1024*count to the bytes received counter.
func (c *Counter) CountKibiBytesReceived(count float64) {
c.Received.Add(int64(1024 * count))
}
// BytesReceived returns the bytes received so far.
func (c *Counter) BytesReceived() int64 {
return c.Received.Load()
}
// KibiBytesReceived returns the KiB received so far.
func (c *Counter) KibiBytesReceived() float64 {
return float64(c.BytesReceived()) / 1024
}
@@ -0,0 +1,27 @@
package bytecounter_test
import (
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
)
func TestGood(t *testing.T) {
counter := bytecounter.New()
counter.CountBytesReceived(16384)
counter.CountKibiBytesReceived(10)
counter.CountBytesSent(2048)
counter.CountKibiBytesSent(10)
if counter.BytesSent() != 12288 {
t.Fatal("invalid bytes sent")
}
if counter.BytesReceived() != 26624 {
t.Fatal("invalid bytes received")
}
if v := counter.KibiBytesSent(); v < 11.9 || v > 12.1 {
t.Fatal("invalid kibibytes sent")
}
if v := counter.KibiBytesReceived(); v < 25.9 || v > 26.1 {
t.Fatal("invalid kibibytes received")
}
}
@@ -0,0 +1,93 @@
package dialer
import (
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
)
// ByteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you
// should make sure that you insert this dialer in the dialing chain.
//
// Bug
//
// This implementation cannot properly account for the bytes that are sent by
// persistent connections, because they strick to the counters set when the
// connection was established. This typically means we miss the bytes sent and
// received when submitting a measurement. Such bytes are specifically not
// see by the experiment specific byte counter.
//
// For this reason, this implementation may be heavily changed/removed.
type ByteCounterDialer struct {
Dialer
}
// DialContext implements Dialer.DialContext
func (d ByteCounterDialer) DialContext(
ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
exp := ContextExperimentByteCounter(ctx)
sess := ContextSessionByteCounter(ctx)
if exp == nil && sess == nil {
return conn, nil // no point in wrapping
}
return byteCounterConnWrapper{Conn: conn, exp: exp, sess: sess}, nil
}
type byteCounterSessionKey struct{}
// ContextSessionByteCounter retrieves the session byte counter from the context
func ContextSessionByteCounter(ctx context.Context) *bytecounter.Counter {
counter, _ := ctx.Value(byteCounterSessionKey{}).(*bytecounter.Counter)
return counter
}
// WithSessionByteCounter assigns the session byte counter to the context
func WithSessionByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context {
return context.WithValue(ctx, byteCounterSessionKey{}, counter)
}
type byteCounterExperimentKey struct{}
// ContextExperimentByteCounter retrieves the experiment byte counter from the context
func ContextExperimentByteCounter(ctx context.Context) *bytecounter.Counter {
counter, _ := ctx.Value(byteCounterExperimentKey{}).(*bytecounter.Counter)
return counter
}
// WithExperimentByteCounter assigns the experiment byte counter to the context
func WithExperimentByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context {
return context.WithValue(ctx, byteCounterExperimentKey{}, counter)
}
type byteCounterConnWrapper struct {
net.Conn
exp *bytecounter.Counter
sess *bytecounter.Counter
}
func (c byteCounterConnWrapper) Read(p []byte) (int, error) {
count, err := c.Conn.Read(p)
if c.exp != nil {
c.exp.CountBytesReceived(count)
}
if c.sess != nil {
c.sess.CountBytesReceived(count)
}
return count, err
}
func (c byteCounterConnWrapper) Write(p []byte) (int, error) {
count, err := c.Conn.Write(p)
if c.exp != nil {
c.exp.CountBytesSent(count)
}
if c.sess != nil {
c.sess.CountBytesSent(count)
}
return count, err
}
@@ -0,0 +1,81 @@
package dialer_test
import (
"context"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
func dorequest(ctx context.Context, url string) error {
txp := http.DefaultTransport.(*http.Transport).Clone()
defer txp.CloseIdleConnections()
dialer := dialer.ByteCounterDialer{Dialer: new(net.Dialer)}
txp.DialContext = dialer.DialContext
client := &http.Client{Transport: txp}
req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil)
if err != nil {
return err
}
resp, err := client.Do(req)
if err != nil {
return err
}
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
return err
}
return resp.Body.Close()
}
func TestByteCounterNormalUsage(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
sess := bytecounter.New()
ctx := context.Background()
ctx = dialer.WithSessionByteCounter(ctx, sess)
if err := dorequest(ctx, "http://www.google.com"); err != nil {
t.Fatal(err)
}
exp := bytecounter.New()
ctx = dialer.WithExperimentByteCounter(ctx, exp)
if err := dorequest(ctx, "http://facebook.com"); err != nil {
t.Fatal(err)
}
if sess.Received.Load() <= exp.Received.Load() {
t.Fatal("session should have received more than experiment")
}
if sess.Sent.Load() <= exp.Sent.Load() {
t.Fatal("session should have sent more than experiment")
}
}
func TestByteCounterNoHandlers(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
ctx := context.Background()
if err := dorequest(ctx, "http://www.google.com"); err != nil {
t.Fatal(err)
}
if err := dorequest(ctx, "http://facebook.com"); err != nil {
t.Fatal(err)
}
}
func TestByteCounterConnectFailure(t *testing.T) {
dialer := dialer.ByteCounterDialer{Dialer: dialer.EOFDialer{}}
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
+29
View File
@@ -0,0 +1,29 @@
package dialer
import (
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
)
// Dialer is the interface we expect from a dialer
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Resolver is the interface we expect from a resolver
type Resolver interface {
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
}
func safeLocalAddress(conn net.Conn) (s string) {
if conn != nil && conn.LocalAddr() != nil {
s = conn.LocalAddr().String()
}
return
}
func safeConnID(network string, conn net.Conn) int64 {
return connid.Compute(network, safeLocalAddress(conn))
}
+73
View File
@@ -0,0 +1,73 @@
package dialer
import (
"context"
"errors"
"net"
"strings"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// DNSDialer is a dialer that uses the configured Resolver to resolver a
// domain name to IP addresses, and the configured Dialer to connect.
type DNSDialer struct {
Dialer
Resolver Resolver
}
// DialContext implements Dialer.DialContext.
func (d DNSDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
onlyhost, onlyport, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
ctx = dialid.WithDialID(ctx) // important to create before lookupHost
var addrs []string
addrs, err = d.LookupHost(ctx, onlyhost)
if err != nil {
return nil, err
}
var errorslist []error
for _, addr := range addrs {
target := net.JoinHostPort(addr, onlyport)
conn, err := d.Dialer.DialContext(ctx, network, target)
if err == nil {
return conn, nil
}
errorslist = append(errorslist, err)
}
return nil, ReduceErrors(errorslist)
}
// ReduceErrors finds a known error in a list of errors since it's probably most relevant
func ReduceErrors(errorslist []error) error {
if len(errorslist) == 0 {
return nil
}
// If we have a known error, let's consider this the real error
// since it's probably most relevant. Otherwise let's return the
// first considering that (1) local resolvers likely will give
// us IPv4 first and (2) also our resolver does that. So, in case
// the user has no IPv6 connectivity, an IPv6 error is going to
// appear later in the list of errors.
for _, err := range errorslist {
var wrapper *errorx.ErrWrapper
if errors.As(err, &wrapper) && !strings.HasPrefix(
err.Error(), "unknown_failure",
) {
return err
}
}
// TODO(bassosimone): handle this case in a better way
return errorslist[0]
}
// LookupHost implements Resolver.LookupHost
func (d DNSDialer) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
}
return d.Resolver.LookupHost(ctx, hostname)
}
+171
View File
@@ -0,0 +1,171 @@
package dialer_test
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
func TestDNSDialerNoPort(t *testing.T) {
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: new(net.Resolver)}
conn, err := dialer.DialContext(context.Background(), "tcp", "antani.ooni.nu")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected a nil conn here")
}
}
func TestDNSDialerLookupHostAddress(t *testing.T) {
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{
Err: errors.New("mocked error"),
}}
addrs, err := dialer.LookupHost(context.Background(), "1.1.1.1")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
t.Fatal("not the result we expected")
}
}
func TestDNSDialerLookupHostFailure(t *testing.T) {
expected := errors.New("mocked error")
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{
Err: expected,
}}
conn, err := dialer.DialContext(context.Background(), "tcp", "dns.google.com:853")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
}
type MockableResolver struct {
Addresses []string
Err error
}
func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
return r.Addresses, r.Err
}
func TestDNSDialerDialForSingleIPFails(t *testing.T) {
dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: new(net.Resolver)}
conn, err := dialer.DialContext(context.Background(), "tcp", "1.1.1.1:853")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
}
func TestDNSDialerDialForManyIPFails(t *testing.T) {
dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: MockableResolver{
Addresses: []string{"1.1.1.1", "8.8.8.8"},
}}
conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
}
func TestDNSDialerDialForManyIPSuccess(t *testing.T) {
dialer := dialer.DNSDialer{Dialer: dialer.EOFConnDialer{}, Resolver: MockableResolver{
Addresses: []string{"1.1.1.1", "8.8.8.8"},
}}
conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853")
if err != nil {
t.Fatal("expected nil error here")
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
conn.Close()
}
func TestDNSDialerDialSetsDialID(t *testing.T) {
saver := &handlers.SavingHandler{}
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: saver,
})
dialer := dialer.DNSDialer{Dialer: dialer.EmitterDialer{
Dialer: dialer.EOFConnDialer{},
}, Resolver: MockableResolver{
Addresses: []string{"1.1.1.1", "8.8.8.8"},
}}
conn, err := dialer.DialContext(ctx, "tcp", "dot.dns:853")
if err != nil {
t.Fatal("expected nil error here")
}
if conn == nil {
t.Fatal("expected non-nil conn")
}
conn.Close()
events := saver.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
for _, ev := range events {
if ev.Connect != nil && ev.Connect.DialID == 0 {
t.Fatal("unexpected DialID")
}
}
}
func TestReduceErrors(t *testing.T) {
t.Run("no errors", func(t *testing.T) {
result := dialer.ReduceErrors(nil)
if result != nil {
t.Fatal("wrong result")
}
})
t.Run("single error", func(t *testing.T) {
err := errors.New("mocked error")
result := dialer.ReduceErrors([]error{err})
if result != err {
t.Fatal("wrong result")
}
})
t.Run("multiple errors", func(t *testing.T) {
err1 := errors.New("mocked error #1")
err2 := errors.New("mocked error #2")
result := dialer.ReduceErrors([]error{err1, err2})
if result.Error() != "mocked error #1" {
t.Fatal("wrong result")
}
})
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
err1 := errors.New("mocked error #1")
err2 := &errorx.ErrWrapper{
Failure: "unknown_failure: antani",
}
err3 := &errorx.ErrWrapper{
Failure: errorx.FailureConnectionRefused,
}
err4 := errors.New("mocked error #3")
result := dialer.ReduceErrors([]error{err1, err2, err3, err4})
if result.Error() != errorx.FailureConnectionRefused {
t.Fatal("wrong result")
}
})
}
+103
View File
@@ -0,0 +1,103 @@
package dialer
import (
"context"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
)
// EmitterDialer is a Dialer that emits events
type EmitterDialer struct {
Dialer
}
// DialContext implements Dialer.DialContext
func (d EmitterDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
start := time.Now()
conn, err := d.Dialer.DialContext(ctx, network, address)
stop := time.Now()
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
Connect: &modelx.ConnectEvent{
ConnID: safeConnID(network, conn),
DialID: dialid.ContextDialID(ctx),
DurationSinceBeginning: stop.Sub(root.Beginning),
Error: err,
Network: network,
RemoteAddress: address,
SyscallDuration: stop.Sub(start),
TransactionID: transactionid.ContextTransactionID(ctx),
},
})
if err != nil {
return nil, err
}
return EmitterConn{
Conn: conn,
Beginning: root.Beginning,
Handler: root.Handler,
ID: safeConnID(network, conn),
}, nil
}
// EmitterConn is a net.Conn used to emit events
type EmitterConn struct {
net.Conn
Beginning time.Time
Handler modelx.Handler
ID int64
}
// Read implements net.Conn.Read
func (c EmitterConn) Read(b []byte) (n int, err error) {
start := time.Now()
n, err = c.Conn.Read(b)
stop := time.Now()
c.Handler.OnMeasurement(modelx.Measurement{
Read: &modelx.ReadEvent{
ConnID: c.ID,
DurationSinceBeginning: stop.Sub(c.Beginning),
Error: err,
NumBytes: int64(n),
SyscallDuration: stop.Sub(start),
},
})
return
}
// Write implements net.Conn.Write
func (c EmitterConn) Write(b []byte) (n int, err error) {
start := time.Now()
n, err = c.Conn.Write(b)
stop := time.Now()
c.Handler.OnMeasurement(modelx.Measurement{
Write: &modelx.WriteEvent{
ConnID: c.ID,
DurationSinceBeginning: stop.Sub(c.Beginning),
Error: err,
NumBytes: int64(n),
SyscallDuration: stop.Sub(start),
},
})
return
}
// Close implements net.Conn.Close
func (c EmitterConn) Close() (err error) {
start := time.Now()
err = c.Conn.Close()
stop := time.Now()
c.Handler.OnMeasurement(modelx.Measurement{
Close: &modelx.CloseEvent{
ConnID: c.ID,
DurationSinceBeginning: stop.Sub(c.Beginning),
Error: err,
SyscallDuration: stop.Sub(start),
},
})
return
}
+166
View File
@@ -0,0 +1,166 @@
package dialer_test
import (
"context"
"errors"
"io"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
func TestEmitterFailure(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
saver := &handlers.SavingHandler{}
ctx = modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: saver,
})
ctx = transactionid.WithTransactionID(ctx)
d := dialer.EmitterDialer{Dialer: dialer.EOFDialer{}}
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected a nil conn here")
}
events := saver.Read()
if len(events) != 1 {
t.Fatal("unexpected number of events saved")
}
if events[0].Connect == nil {
t.Fatal("expected non nil Connect")
}
conninfo := events[0].Connect
if conninfo.ConnID != 0 {
t.Fatal("unexpected ConnID value")
}
emitterCheckConnectEventCommon(t, conninfo, io.EOF)
}
func emitterCheckConnectEventCommon(
t *testing.T, conninfo *modelx.ConnectEvent, err error) {
if conninfo.DialID == 0 {
t.Fatal("unexpected DialID value")
}
if conninfo.DurationSinceBeginning == 0 {
t.Fatal("unexpected DurationSinceBeginning value")
}
if !errors.Is(conninfo.Error, err) {
t.Fatal("unexpected Error value")
}
if conninfo.Network != "tcp" {
t.Fatal("unexpected Network value")
}
if conninfo.RemoteAddress != "www.google.com:443" {
t.Fatal("unexpected Network value")
}
if conninfo.SyscallDuration == 0 {
t.Fatal("unexpected SyscallDuration value")
}
if conninfo.TransactionID == 0 {
t.Fatal("unexpected TransactionID value")
}
}
func TestEmitterSuccess(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
saver := &handlers.SavingHandler{}
ctx = modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: saver,
})
ctx = transactionid.WithTransactionID(ctx)
d := dialer.EmitterDialer{Dialer: dialer.EOFConnDialer{}}
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
if err != nil {
t.Fatal("we expected no error")
}
if conn == nil {
t.Fatal("expected a non-nil conn here")
}
conn.Read(nil)
conn.Write(nil)
conn.Close()
events := saver.Read()
if len(events) != 4 {
t.Fatal("unexpected number of events saved")
}
if events[0].Connect == nil {
t.Fatal("expected non nil Connect")
}
conninfo := events[0].Connect
if conninfo.ConnID == 0 {
t.Fatal("unexpected ConnID value")
}
emitterCheckConnectEventCommon(t, conninfo, nil)
if events[1].Read == nil {
t.Fatal("expected non nil Read")
}
emitterCheckReadEvent(t, events[1].Read)
if events[2].Write == nil {
t.Fatal("expected non nil Write")
}
emitterCheckWriteEvent(t, events[2].Write)
if events[3].Close == nil {
t.Fatal("expected non nil Close")
}
emitterCheckCloseEvent(t, events[3].Close)
}
func emitterCheckReadEvent(t *testing.T, ev *modelx.ReadEvent) {
if ev.ConnID == 0 {
t.Fatal("unexpected ConnID")
}
if ev.DurationSinceBeginning == 0 {
t.Fatal("unexpected DurationSinceBeginning")
}
if !errors.Is(ev.Error, io.EOF) {
t.Fatal("unexpected Error")
}
if ev.NumBytes != 0 {
t.Fatal("unexpected NumBytes")
}
if ev.SyscallDuration == 0 {
t.Fatal("unexpected SyscallDuration")
}
}
func emitterCheckWriteEvent(t *testing.T, ev *modelx.WriteEvent) {
if ev.ConnID == 0 {
t.Fatal("unexpected ConnID")
}
if ev.DurationSinceBeginning == 0 {
t.Fatal("unexpected DurationSinceBeginning")
}
if !errors.Is(ev.Error, io.EOF) {
t.Fatal("unexpected Error")
}
if ev.NumBytes != 0 {
t.Fatal("unexpected NumBytes")
}
if ev.SyscallDuration == 0 {
t.Fatal("unexpected SyscallDuration")
}
}
func emitterCheckCloseEvent(t *testing.T, ev *modelx.CloseEvent) {
if ev.ConnID == 0 {
t.Fatal("unexpected ConnID")
}
if ev.DurationSinceBeginning == 0 {
t.Fatal("unexpected DurationSinceBeginning")
}
if !errors.Is(ev.Error, io.EOF) {
t.Fatal("unexpected Error")
}
if ev.SyscallDuration == 0 {
t.Fatal("unexpected SyscallDuration")
}
}
+80
View File
@@ -0,0 +1,80 @@
package dialer
import (
"context"
"crypto/tls"
"io"
"net"
"time"
)
type EOFDialer struct{}
func (EOFDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return nil, io.EOF
}
type EOFConnDialer struct{}
func (EOFConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return EOFConn{}, nil
}
type EOFConn struct {
net.Conn
}
func (EOFConn) Read(p []byte) (int, error) {
time.Sleep(10 * time.Microsecond)
return 0, io.EOF
}
func (EOFConn) Write(p []byte) (int, error) {
time.Sleep(10 * time.Microsecond)
return 0, io.EOF
}
func (EOFConn) Close() error {
time.Sleep(10 * time.Microsecond)
return io.EOF
}
func (EOFConn) LocalAddr() net.Addr {
return EOFAddr{}
}
func (EOFConn) RemoteAddr() net.Addr {
return EOFAddr{}
}
func (EOFConn) SetDeadline(t time.Time) error {
return nil
}
func (EOFConn) SetReadDeadline(t time.Time) error {
return nil
}
func (EOFConn) SetWriteDeadline(t time.Time) error {
return nil
}
type EOFAddr struct{}
func (EOFAddr) Network() string {
return "tcp"
}
func (EOFAddr) String() string {
return "127.0.0.1:1234"
}
type EOFTLSHandshaker struct{}
func (EOFTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
time.Sleep(10 * time.Microsecond)
return nil, tls.ConnectionState{}, io.EOF
}
@@ -0,0 +1,75 @@
package dialer
import (
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// ErrorWrapperDialer is a dialer that performs err wrapping
type ErrorWrapperDialer struct {
Dialer
}
// DialContext implements Dialer.DialContext
func (d ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
dialID := dialid.ContextDialID(ctx)
conn, err := d.Dialer.DialContext(ctx, network, address)
err = errorx.SafeErrWrapperBuilder{
// ConnID does not make any sense if we've failed and the error
// does not make any sense (and is nil) if we succeded.
DialID: dialID,
Error: err,
Operation: errorx.ConnectOperation,
}.MaybeBuild()
if err != nil {
return nil, err
}
return &ErrorWrapperConn{
Conn: conn, ConnID: safeConnID(network, conn), DialID: dialID}, nil
}
// ErrorWrapperConn is a net.Conn that performs error wrapping.
type ErrorWrapperConn struct {
net.Conn
ConnID int64
DialID int64
}
// Read implements net.Conn.Read
func (c ErrorWrapperConn) Read(b []byte) (n int, err error) {
n, err = c.Conn.Read(b)
err = errorx.SafeErrWrapperBuilder{
ConnID: c.ConnID,
DialID: c.DialID,
Error: err,
Operation: errorx.ReadOperation,
}.MaybeBuild()
return
}
// Write implements net.Conn.Write
func (c ErrorWrapperConn) Write(b []byte) (n int, err error) {
n, err = c.Conn.Write(b)
err = errorx.SafeErrWrapperBuilder{
ConnID: c.ConnID,
DialID: c.DialID,
Error: err,
Operation: errorx.WriteOperation,
}.MaybeBuild()
return
}
// Close implements net.Conn.Close
func (c ErrorWrapperConn) Close() (err error) {
err = c.Conn.Close()
err = errorx.SafeErrWrapperBuilder{
ConnID: c.ConnID,
DialID: c.DialID,
Error: err,
Operation: errorx.CloseOperation,
}.MaybeBuild()
return
}
@@ -0,0 +1,66 @@
package dialer_test
import (
"context"
"errors"
"io"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
func TestErrorWrapperFailure(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFDialer{}}
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
if conn != nil {
t.Fatal("expected a nil conn here")
}
errorWrapperCheckErr(t, err, errorx.ConnectOperation)
}
func errorWrapperCheckErr(t *testing.T, err error, op string) {
if !errors.Is(err, io.EOF) {
t.Fatal("expected another error here")
}
var errWrapper *errorx.ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot cast to ErrWrapper")
}
if errWrapper.DialID == 0 {
t.Fatal("unexpected DialID")
}
if errWrapper.Operation != op {
t.Fatal("unexpected Operation")
}
if errWrapper.Failure != errorx.FailureEOFError {
t.Fatal("unexpected failure")
}
}
func TestErrorWrapperSuccess(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFConnDialer{}}
conn, err := d.DialContext(ctx, "tcp", "www.google.com")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn here")
}
count, err := conn.Read(nil)
errorWrapperCheckIOResult(t, count, err, errorx.ReadOperation)
count, err = conn.Write(nil)
errorWrapperCheckIOResult(t, count, err, errorx.WriteOperation)
err = conn.Close()
errorWrapperCheckErr(t, err, errorx.CloseOperation)
}
func errorWrapperCheckIOResult(t *testing.T, count int, err error, op string) {
if count != 0 {
t.Fatal("expected nil count here")
}
errorWrapperCheckErr(t, err, op)
}
+71
View File
@@ -0,0 +1,71 @@
package dialer
import (
"context"
"io"
"net"
"time"
)
type FakeDialer struct {
Conn net.Conn
Err error
}
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return d.Conn, d.Err
}
type FakeConn struct {
ReadError error
ReadData []byte
SetDeadlineError error
SetReadDeadlineError error
SetWriteDeadlineError error
WriteError error
}
func (c *FakeConn) Read(b []byte) (int, error) {
if len(c.ReadData) > 0 {
n := copy(b, c.ReadData)
c.ReadData = c.ReadData[n:]
return n, nil
}
if c.ReadError != nil {
return 0, c.ReadError
}
return 0, io.EOF
}
func (c *FakeConn) Write(b []byte) (n int, err error) {
if c.WriteError != nil {
return 0, c.WriteError
}
n = len(b)
return
}
func (*FakeConn) Close() (err error) {
return
}
func (*FakeConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (*FakeConn) RemoteAddr() net.Addr {
return &net.TCPAddr{}
}
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
return c.SetDeadlineError
}
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
return c.SetReadDeadlineError
}
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
return c.SetWriteDeadlineError
}
@@ -0,0 +1,57 @@
package dialer_test
import (
"context"
"net"
"net/http"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
func TestTLSDialerSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
log.SetLevel(log.DebugLevel)
dialer := dialer.TLSDialer{Dialer: new(net.Dialer),
TLSHandshaker: dialer.LoggingTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Logger: log.Log,
},
}
txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) {
// AlpineLinux edge is still using Go 1.13. We cannot switch to
// using DialTLSContext here as we'd like to until either Alpine
// switches to Go 1.14 or we drop the MK dependency.
return dialer.DialTLSContext(context.Background(), network, address)
}}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
}
func TestDNSDialerSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
log.SetLevel(log.DebugLevel)
dialer := dialer.DNSDialer{
Dialer: dialer.LoggingDialer{
Dialer: new(net.Dialer),
Logger: log.Log,
},
Resolver: new(net.Resolver),
}
txp := &http.Transport{DialContext: dialer.DialContext}
client := &http.Client{Transport: txp}
resp, err := client.Get("http://www.google.com")
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
}
+56
View File
@@ -0,0 +1,56 @@
package dialer
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
)
// Logger is the logger assumed by this package
type Logger interface {
Debugf(format string, v ...interface{})
Debug(message string)
}
// LoggingDialer is a Dialer with logging
type LoggingDialer struct {
Dialer
Logger Logger
}
// DialContext implements Dialer.DialContext
func (d LoggingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
d.Logger.Debugf("dial %s/%s...", address, network)
start := time.Now()
conn, err := d.Dialer.DialContext(ctx, network, address)
stop := time.Now()
d.Logger.Debugf("dial %s/%s... %+v in %s", address, network, err, stop.Sub(start))
return conn, err
}
// LoggingTLSHandshaker is a TLSHandshaker with logging
type LoggingTLSHandshaker struct {
TLSHandshaker
Logger Logger
}
// Handshake implements Handshaker.Handshake
func (h LoggingTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos)
start := time.Now()
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
stop := time.Now()
h.Logger.Debugf(
"tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName,
config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol,
tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version))
return tlsconn, state, err
}
var _ Dialer = LoggingDialer{}
var _ TLSHandshaker = LoggingTLSHandshaker{}
@@ -0,0 +1,42 @@
package dialer_test
import (
"context"
"crypto/tls"
"errors"
"io"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
func TestLoggingDialerFailure(t *testing.T) {
d := dialer.LoggingDialer{
Dialer: dialer.EOFDialer{},
Logger: log.Log,
}
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
func TestLoggingTLSHandshakerFailure(t *testing.T) {
h := dialer.LoggingTLSHandshaker{
TLSHandshaker: dialer.EOFTLSHandshaker{},
Logger: log.Log,
}
tlsconn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
ServerName: "www.google.com",
})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if tlsconn != nil {
t.Fatal("expected nil tlsconn here")
}
}
+94
View File
@@ -0,0 +1,94 @@
package dialer
import (
"context"
"errors"
"net"
"net/url"
"golang.org/x/net/proxy"
)
// ProxyDialer is a dialer that uses a proxy. If the ProxyURL is not configured, this
// dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally
// create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer.
//
// As a special case, you can force a proxy to be used only extemporarily. To this end,
// you can use the WithProxyURL function, to store the proxy URL in the context. This
// will take precedence over any otherwise configured proxy. The use case for this
// functionality is when you need a tunnel to contact OONI probe services.
type ProxyDialer struct {
Dialer
ProxyURL *url.URL
}
type proxyKey struct{}
// ContextProxyURL retrieves the proxy URL from the context. This is mainly used
// to force a tunnel when we fail contacting OONI probe services otherwise.
func ContextProxyURL(ctx context.Context) *url.URL {
url, _ := ctx.Value(proxyKey{}).(*url.URL)
return url
}
// WithProxyURL assigns the proxy URL to the context
func WithProxyURL(ctx context.Context, url *url.URL) context.Context {
return context.WithValue(ctx, proxyKey{}, url)
}
// DialContext implements Dialer.DialContext
func (d ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
url := ContextProxyURL(ctx) // context URL takes precendence
if url == nil {
url = d.ProxyURL
}
if url == nil {
return d.Dialer.DialContext(ctx, network, address)
}
if url.Scheme != "socks5" {
return nil, errors.New("Scheme is not socks5")
}
// the code at proxy/socks5.go never fails; see https://git.io/JfJ4g
child, _ := proxy.SOCKS5(
network, url.Host, nil, proxyDialerWrapper{Dialer: d.Dialer})
return d.dial(ctx, child, network, address)
}
func (d ProxyDialer) dial(
ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) {
connch := make(chan net.Conn)
errch := make(chan error, 1)
go func() {
conn, err := child.Dial(network, address)
if err != nil {
errch <- err
return
}
select {
case connch <- conn:
default:
conn.Close()
}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errch:
return nil, err
case conn := <-connch:
return conn, nil
}
}
// proxyDialerWrapper is required because SOCKS5 expects a Dialer.Dial type but internally
// it checks whether DialContext is available and prefers that. So, we need to use this
// structure to cast our inner Dialer the way in which SOCKS5 likes it.
//
// See https://git.io/JfJ4g.
type proxyDialerWrapper struct {
Dialer
}
func (d proxyDialerWrapper) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
@@ -0,0 +1,15 @@
package dialer
import (
"context"
"net"
"golang.org/x/net/proxy"
)
type ProxyDialerWrapper = proxyDialerWrapper
func (d ProxyDialer) DialContextWithDialer(
ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) {
return d.dial(ctx, child, network, address)
}
+152
View File
@@ -0,0 +1,152 @@
package dialer_test
import (
"context"
"errors"
"io"
"net/url"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
func TestProxyDialerDialContextNoProxyURL(t *testing.T) {
expected := errors.New("mocked error")
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{Err: expected},
}
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal(err)
}
if conn != nil {
t.Fatal("conn is not nil")
}
}
func TestProxyDialerContextTakesPrecedence(t *testing.T) {
expected := errors.New("mocked error")
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{Err: expected},
ProxyURL: &url.URL{Scheme: "antani"},
}
ctx := context.Background()
ctx = dialer.WithProxyURL(ctx, &url.URL{Scheme: "socks5", Host: "[::1]:443"})
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("conn is not nil")
}
}
func TestProxyDialerDialContextInvalidScheme(t *testing.T) {
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{},
ProxyURL: &url.URL{Scheme: "antani"},
}
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if err.Error() != "Scheme is not socks5" {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("conn is not nil")
}
}
func TestProxyDialerDialContextWithEOF(t *testing.T) {
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{
Err: io.EOF,
},
ProxyURL: &url.URL{Scheme: "socks5"},
}
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("conn is not nil")
}
}
func TestProxyDialerDialContextWithContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // immediately fail
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{
Err: io.EOF,
},
ProxyURL: &url.URL{Scheme: "socks5"},
}
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("conn is not nil")
}
}
func TestProxyDialerDialContextWithDialerSuccess(t *testing.T) {
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{
Conn: &dialer.FakeConn{
ReadError: io.EOF,
WriteError: io.EOF,
},
},
ProxyURL: &url.URL{Scheme: "socks5"},
}
conn, err := d.DialContextWithDialer(
context.Background(), dialer.ProxyDialerWrapper{
Dialer: d.Dialer,
}, "tcp", "www.google.com:443")
if err != nil {
t.Fatal(err)
}
conn.Close()
}
func TestProxyDialerDialContextWithDialerCanceledContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Stop immediately. The FakeDialer sleeps for some microseconds so
// it is much more likely we immediately exit with done context. The
// arm where we receive the conn is much less likely.
cancel()
d := dialer.ProxyDialer{
Dialer: dialer.FakeDialer{
Conn: &dialer.FakeConn{
ReadError: io.EOF,
WriteError: io.EOF,
},
},
ProxyURL: &url.URL{Scheme: "socks5"},
}
conn, err := d.DialContextWithDialer(
ctx, dialer.ProxyDialerWrapper{
Dialer: d.Dialer,
}, "tcp", "www.google.com:443")
if !errors.Is(err, context.Canceled) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
func TestProxyDialerWrapper(t *testing.T) {
d := dialer.ProxyDialerWrapper{
Dialer: dialer.FakeDialer{
Err: io.EOF,
},
}
conn, err := d.Dial("tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("conn is not nil")
}
}
+125
View File
@@ -0,0 +1,125 @@
package dialer
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// SaverDialer saves events occurring during the dial
type SaverDialer struct {
Dialer
Saver *trace.Saver
}
// DialContext implements Dialer.DialContext
func (d SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
start := time.Now()
conn, err := d.Dialer.DialContext(ctx, network, address)
stop := time.Now()
d.Saver.Write(trace.Event{
Address: address,
Duration: stop.Sub(start),
Err: err,
Name: errorx.ConnectOperation,
Proto: network,
Time: stop,
})
return conn, err
}
// SaverTLSHandshaker saves events occurring during the handshake
type SaverTLSHandshaker struct {
TLSHandshaker
Saver *trace.Saver
}
// Handshake implements TLSHandshaker.Handshake
func (h SaverTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
start := time.Now()
h.Saver.Write(trace.Event{
Name: "tls_handshake_start",
NoTLSVerify: config.InsecureSkipVerify,
TLSNextProtos: config.NextProtos,
TLSServerName: config.ServerName,
Time: start,
})
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
stop := time.Now()
h.Saver.Write(trace.Event{
Duration: stop.Sub(start),
Err: err,
Name: "tls_handshake_done",
NoTLSVerify: config.InsecureSkipVerify,
TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: config.NextProtos,
TLSPeerCerts: trace.PeerCerts(state, err),
TLSServerName: config.ServerName,
TLSVersion: tlsx.VersionString(state.Version),
Time: stop,
})
return tlsconn, state, err
}
// SaverConnDialer wraps the returned connection such that we
// collect all the read/write events that occur.
type SaverConnDialer struct {
Dialer
Saver *trace.Saver
}
// DialContext implements Dialer.DialContext
func (d SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
return saverConn{saver: d.Saver, Conn: conn}, nil
}
type saverConn struct {
net.Conn
saver *trace.Saver
}
func (c saverConn) Read(p []byte) (int, error) {
start := time.Now()
count, err := c.Conn.Read(p)
stop := time.Now()
c.saver.Write(trace.Event{
Data: p[:count],
Duration: stop.Sub(start),
Err: err,
NumBytes: count,
Name: errorx.ReadOperation,
Time: stop,
})
return count, err
}
func (c saverConn) Write(p []byte) (int, error) {
start := time.Now()
count, err := c.Conn.Write(p)
stop := time.Now()
c.saver.Write(trace.Event{
Data: p[:count],
Duration: stop.Sub(start),
Err: err,
NumBytes: count,
Name: errorx.WriteOperation,
Time: stop,
})
return count, err
}
var _ Dialer = SaverDialer{}
var _ TLSHandshaker = SaverTLSHandshaker{}
var _ net.Conn = saverConn{}
+371
View File
@@ -0,0 +1,371 @@
package dialer_test
import (
"context"
"crypto/tls"
"errors"
"net"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSaverDialerFailure(t *testing.T) {
expected := errors.New("mocked error")
saver := &trace.Saver{}
dlr := dialer.SaverDialer{
Dialer: dialer.FakeDialer{
Err: expected,
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("expected another error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected a single event here")
}
if ev[0].Address != "www.google.com:443" {
t.Fatal("unexpected Address")
}
if ev[0].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[0].Err, expected) {
t.Fatal("unexpected Err")
}
if ev[0].Name != errorx.ConnectOperation {
t.Fatal("unexpected Name")
}
if ev[0].Proto != "tcp" {
t.Fatal("unexpected Proto")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
}
func TestSaverConnDialerFailure(t *testing.T) {
expected := errors.New("mocked error")
saver := &trace.Saver{}
dlr := dialer.SaverConnDialer{
Dialer: dialer.FakeDialer{
Err: expected,
},
Saver: saver,
}
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
// This is the most common use case for collecting reads, writes
if testing.Short() {
t.Skip("skip test in short mode")
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := dialer.TLSDialer{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: dialer.SaverConnDialer{
Dialer: new(net.Dialer),
Saver: saver,
},
TLSHandshaker: dialer.SaverTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Saver: saver,
},
}
// Implementation note: we don't close the connection here because it is
// very handy to have the last event being the end of the handshake
_, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
if err != nil {
t.Fatal(err)
}
ev := saver.Read()
if len(ev) < 4 {
// it's a bit tricky to be sure about the right number of
// events because network conditions may influence that
t.Fatal("unexpected number of events")
}
if ev[0].Name != "tls_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
last := len(ev) - 1
for idx := 1; idx < last; idx++ {
if ev[idx].Data == nil {
t.Fatal("unexpected Data")
}
if ev[idx].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[idx].Err != nil {
t.Fatal("unexpected Err")
}
if ev[idx].NumBytes <= 0 {
t.Fatal("unexpected NumBytes")
}
switch ev[idx].Name {
case errorx.ReadOperation, errorx.WriteOperation:
default:
t.Fatal("unexpected Name")
}
if ev[idx].Time.Before(ev[idx-1].Time) {
t.Fatal("unexpected Time")
}
}
if ev[last].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[last].Err != nil {
t.Fatal("unexpected Err")
}
if ev[last].Name != "tls_handshake_done" {
t.Fatal("unexpected Name")
}
if ev[last].TLSCipherSuite == "" {
t.Fatal("unexpected TLSCipherSuite")
}
if ev[last].TLSNegotiatedProto != "h2" {
t.Fatal("unexpected TLSNegotiatedProto")
}
if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[last].TLSPeerCerts == nil {
t.Fatal("unexpected TLSPeerCerts")
}
if ev[last].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if ev[last].TLSVersion == "" {
t.Fatal("unexpected TLSVersion")
}
if ev[last].Time.Before(ev[last-1].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverTLSHandshakerSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
nextprotos := []string{"h2"}
saver := &trace.Saver{}
tlsdlr := dialer.TLSDialer{
Config: &tls.Config{NextProtos: nextprotos},
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SaverTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
if err != nil {
t.Fatal(err)
}
conn.Close()
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("unexpected number of events")
}
if ev[0].Name != "tls_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name != "tls_handshake_done" {
t.Fatal("unexpected Name")
}
if ev[1].TLSCipherSuite == "" {
t.Fatal("unexpected TLSCipherSuite")
}
if ev[1].TLSNegotiatedProto != "h2" {
t.Fatal("unexpected TLSNegotiatedProto")
}
if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[1].TLSPeerCerts == nil {
t.Fatal("unexpected TLSPeerCerts")
}
if ev[1].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if ev[1].TLSVersion == "" {
t.Fatal("unexpected TLSVersion")
}
if ev[1].Time.Before(ev[0].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverTLSHandshakerHostnameError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := dialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SaverTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "wrong.host.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := dialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SaverTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "expired.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := dialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SaverTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
tlsdlr := dialer.TLSDialer{
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SaverTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
Saver: saver,
},
}
conn, err := tlsdlr.DialTLSContext(
context.Background(), "tcp", "self-signed.badssl.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("expected non-nil conn here")
}
conn.Close()
for _, ev := range saver.Read() {
if ev.Name != "tls_handshake_done" {
continue
}
if ev.NoTLSVerify != true {
t.Fatal("expected NoTLSVerify to be true")
}
if len(ev.TLSPeerCerts) < 1 {
t.Fatal("expected at least a certificate here")
}
}
}
@@ -0,0 +1,21 @@
// +build !shaping
package dialer
import (
"context"
"net"
)
// ShapingDialer ensures we don't use too much bandwidth
// when using integration tests at GitHub. To select
// the implementation with shaping use `-tags shaping`.
type ShapingDialer struct {
Dialer
}
// DialContext implements Dialer.DialContext
func (d ShapingDialer) DialContext(
ctx context.Context, network, address string) (net.Conn, error) {
return d.Dialer.DialContext(ctx, network, address)
}
@@ -0,0 +1,40 @@
// +build shaping
package dialer
import (
"context"
"net"
"time"
)
// ShapingDialer ensures we don't use too much bandwidth
// when using integration tests at GitHub. To select
// the implementation with shaping use `-tags shaping`.
type ShapingDialer struct {
Dialer
}
// DialContext implements Dialer.DialContext
func (d ShapingDialer) DialContext(
ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
return &shapingConn{Conn: conn}, nil
}
type shapingConn struct {
net.Conn
}
func (c shapingConn) Read(p []byte) (int, error) {
time.Sleep(100 * time.Millisecond)
return c.Conn.Read(p)
}
func (c shapingConn) Write(p []byte) (int, error) {
time.Sleep(100 * time.Millisecond)
return c.Conn.Write(p)
}
@@ -0,0 +1,27 @@
package dialer_test
import (
"net"
"net/http"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
func TestGood(t *testing.T) {
txp := netx.NewHTTPTransport(netx.Config{
Dialer: dialer.ShapingDialer{
Dialer: new(net.Dialer),
},
})
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com/")
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("expected nil response here")
}
resp.Body.Close()
}
+24
View File
@@ -0,0 +1,24 @@
package dialer
import (
"context"
"net"
"time"
)
// TimeoutDialer is a Dialer that enforces a timeout
type TimeoutDialer struct {
Dialer
ConnectTimeout time.Duration // default: 30 seconds
}
// DialContext implements Dialer.DialContext
func (d TimeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
timeout := 30 * time.Second
if d.ConnectTimeout != 0 {
timeout = d.ConnectTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return d.Dialer.DialContext(ctx, network, address)
}
@@ -0,0 +1,34 @@
package dialer_test
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
type SlowDialer struct{}
func (SlowDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(30 * time.Second):
return nil, io.EOF
}
}
func TestTimeoutDialer(t *testing.T) {
d := dialer.TimeoutDialer{Dialer: SlowDialer{}, ConnectTimeout: time.Second}
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn here")
}
}
+141
View File
@@ -0,0 +1,141 @@
package dialer
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// TLSHandshaker is the generic TLS handshaker
type TLSHandshaker interface {
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}
// SystemTLSHandshaker is the system TLS handshaker.
type SystemTLSHandshaker struct{}
// Handshake implements Handshaker.Handshake
func (h SystemTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn := tls.Client(conn, config)
if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
}
// TimeoutTLSHandshaker is a TLSHandshaker with timeout
type TimeoutTLSHandshaker struct {
TLSHandshaker
HandshakeTimeout time.Duration // default: 10 second
}
// Handshake implements Handshaker.Handshake
func (h TimeoutTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
timeout := 10 * time.Second
if h.HandshakeTimeout != 0 {
timeout = h.HandshakeTimeout
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, tls.ConnectionState{}, err
}
tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config)
conn.SetDeadline(time.Time{})
return tlsconn, connstate, err
}
// ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error
type ErrorWrapperTLSHandshaker struct {
TLSHandshaker
}
// Handshake implements Handshaker.Handshake
func (h ErrorWrapperTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
err = errorx.SafeErrWrapperBuilder{
ConnID: connID,
Error: err,
Operation: errorx.TLSHandshakeOperation,
}.MaybeBuild()
return tlsconn, state, err
}
// EmitterTLSHandshaker emits events using the MeasurementRoot
type EmitterTLSHandshaker struct {
TLSHandshaker
}
// Handshake implements Handshaker.Handshake
func (h EmitterTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{
ConnID: connID,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
SNI: config.ServerName,
},
})
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
root.Handler.OnMeasurement(modelx.Measurement{
TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{
ConnID: connID,
ConnectionState: modelx.NewTLSConnectionState(state),
Error: err,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
},
})
return tlsconn, state, err
}
// TLSDialer is the TLS dialer
type TLSDialer struct {
Config *tls.Config
Dialer Dialer
TLSHandshaker TLSHandshaker
}
// DialTLSContext is like tls.DialTLS but with the signature of net.Dialer.DialContext
func (d TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
// Implementation note: when DialTLS is not set, the code in
// net/http will perform the handshake. Otherwise, if DialTLS
// is set, we will end up here. This code is still used when
// performing non-HTTP TLS-enabled dial operations.
host, _, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
config := d.Config
if config == nil {
config = new(tls.Config)
} else {
config = config.Clone()
}
if config.ServerName == "" {
config.ServerName = host
}
tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
conn.Close()
return nil, err
}
return tlsconn, nil
}
+277
View File
@@ -0,0 +1,277 @@
package dialer_test
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
func TestSystemTLSHandshakerEOFError(t *testing.T) {
h := dialer.SystemTLSHandshaker{}
conn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) {
h := dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
expected := errors.New("mocked error")
conn, _, err := h.Handshake(
context.Background(), &dialer.FakeConn{SetDeadlineError: expected},
new(tls.Config))
if !errors.Is(err, expected) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerEOFError(t *testing.T) {
h := dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
conn, _, err := h.Handshake(
context.Background(), dialer.EOFConn{}, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) {
h := dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
underlying := &SetDeadlineConn{}
conn, _, err := h.Handshake(
context.Background(), underlying, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
if len(underlying.deadlines) != 2 {
t.Fatal("SetDeadline not called twice")
}
if underlying.deadlines[0].Before(time.Now()) {
t.Fatal("the first SetDeadline call was incorrect")
}
if !underlying.deadlines[1].IsZero() {
t.Fatal("the second SetDeadline call was incorrect")
}
}
type SetDeadlineConn struct {
dialer.EOFConn
deadlines []time.Time
}
func (c *SetDeadlineConn) SetDeadline(t time.Time) error {
c.deadlines = append(c.deadlines, t)
return nil
}
func TestErrorWrapperTLSHandshakerFailure(t *testing.T) {
h := dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
conn, _, err := h.Handshake(
context.Background(), dialer.EOFConn{}, new(tls.Config))
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
var errWrapper *errorx.ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot cast to ErrWrapper")
}
if errWrapper.ConnID == 0 {
t.Fatal("unexpected ConnID")
}
if errWrapper.Failure != errorx.FailureEOFError {
t.Fatal("unexpected Failure")
}
if errWrapper.Operation != errorx.TLSHandshakeOperation {
t.Fatal("unexpected Operation")
}
}
func TestEmitterTLSHandshakerFailure(t *testing.T) {
saver := &handlers.SavingHandler{}
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: saver,
})
h := dialer.EmitterTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
conn, _, err := h.Handshake(ctx, dialer.EOFConn{}, &tls.Config{
ServerName: "www.kernel.org",
})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
events := saver.Read()
if len(events) != 2 {
t.Fatal("Wrong number of events")
}
if events[0].TLSHandshakeStart == nil {
t.Fatal("missing TLSHandshakeStart event")
}
if events[0].TLSHandshakeStart.ConnID == 0 {
t.Fatal("expected nonzero ConnID")
}
if events[0].TLSHandshakeStart.DurationSinceBeginning == 0 {
t.Fatal("expected nonzero DurationSinceBeginning")
}
if events[0].TLSHandshakeStart.SNI != "www.kernel.org" {
t.Fatal("expected nonzero SNI")
}
if events[1].TLSHandshakeDone == nil {
t.Fatal("missing TLSHandshakeDone event")
}
if events[1].TLSHandshakeDone.ConnID == 0 {
t.Fatal("expected nonzero ConnID")
}
if events[1].TLSHandshakeDone.DurationSinceBeginning == 0 {
t.Fatal("expected nonzero DurationSinceBeginning")
}
}
func TestTLSDialerFailureSplitHostPort(t *testing.T) {
dialer := dialer.TLSDialer{}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com") // missing port
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
func TestTLSDialerFailureDialing(t *testing.T) {
dialer := dialer.TLSDialer{Dialer: dialer.EOFDialer{}}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
func TestTLSDialerFailureHandshaking(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}}
dialer := dialer.TLSDialer{
Dialer: dialer.EOFConnDialer{},
TLSHandshaker: rec,
}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
if rec.SNI != "www.google.com" {
t.Fatal("unexpected SNI value")
}
}
func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}}
dialer := dialer.TLSDialer{
Config: &tls.Config{
ServerName: "x.org",
},
Dialer: dialer.EOFConnDialer{},
TLSHandshaker: rec,
}
conn, err := dialer.DialTLSContext(
context.Background(), "tcp", "www.google.com:443")
if !errors.Is(err, io.EOF) {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("connection is not nil")
}
if rec.SNI != "x.org" {
t.Fatal("unexpected SNI value")
}
}
type RecorderTLSHandshaker struct {
dialer.TLSHandshaker
SNI string
}
func (h *RecorderTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
h.SNI = config.ServerName
return h.TLSHandshaker.Handshake(ctx, conn, config)
}
func TestDialTLSContextGood(t *testing.T) {
dialer := dialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: dialer.SystemTLSHandshaker{},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err != nil {
t.Fatal(err)
}
if conn == nil {
t.Fatal("connection is nil")
}
conn.Close()
}
func TestDialTLSContextTimeout(t *testing.T) {
dialer := dialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: dialer.TimeoutTLSHandshaker{
TLSHandshaker: dialer.SystemTLSHandshaker{},
HandshakeTimeout: 10 * time.Microsecond,
},
},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err.Error() != errorx.FailureGenericTimeoutError {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("connection is not nil")
}
}
+322
View File
@@ -0,0 +1,322 @@
// Package errorx contains error extensions
package errorx
import (
"context"
"crypto/x509"
"errors"
"fmt"
"regexp"
"strings"
)
const (
// FailureConnectionRefused means ECONNREFUSED.
FailureConnectionRefused = "connection_refused"
// FailureConnectionReset means ECONNRESET.
FailureConnectionReset = "connection_reset"
// FailureDNSBogonError means we detected bogon in DNS reply.
FailureDNSBogonError = "dns_bogon_error"
// FailureDNSNXDOMAINError means we got NXDOMAIN in DNS reply.
FailureDNSNXDOMAINError = "dns_nxdomain_error"
// FailureEOFError means we got unexpected EOF on connection.
FailureEOFError = "eof_error"
// FailureGenericTimeoutError means we got some timer has expired.
FailureGenericTimeoutError = "generic_timeout_error"
// FailureInterrupted means that the user interrupted us.
FailureInterrupted = "interrupted"
// FailureNoCompatibleQUICVersion means that the server does not support the proposed QUIC version
FailureNoCompatibleQUICVersion = "quic_incompatible_version"
// FailureSSLInvalidHostname means we got certificate is not valid for SNI.
FailureSSLInvalidHostname = "ssl_invalid_hostname"
// FailureSSLUnknownAuthority means we cannot find CA validating certificate.
FailureSSLUnknownAuthority = "ssl_unknown_authority"
// FailureSSLInvalidCertificate means certificate experired or other
// sort of errors causing it to be invalid.
FailureSSLInvalidCertificate = "ssl_invalid_certificate"
// FailureJSONParseError indicates that we couldn't parse a JSON
FailureJSONParseError = "json_parse_error"
)
const (
// ResolveOperation is the operation where we resolve a domain name
ResolveOperation = "resolve"
// ConnectOperation is the operation where we do a TCP connect
ConnectOperation = "connect"
// TLSHandshakeOperation is the TLS handshake
TLSHandshakeOperation = "tls_handshake"
// QUICHandshakeOperation is the handshake to setup a QUIC connection
QUICHandshakeOperation = "quic_handshake"
// HTTPRoundTripOperation is the HTTP round trip
HTTPRoundTripOperation = "http_round_trip"
// CloseOperation is when we close a socket
CloseOperation = "close"
// ReadOperation is when we read from a socket
ReadOperation = "read"
// WriteOperation is when we write to a socket
WriteOperation = "write"
// ReadFromOperation is when we read from an UDP socket
ReadFromOperation = "read_from"
// WriteToOperation is when we write to an UDP socket
WriteToOperation = "write_to"
// UnknownOperation is when we cannot determine the operation
UnknownOperation = "unknown"
// TopLevelOperation is used when the failure happens at top level. This
// happens for example with urlgetter with a cancelled context.
TopLevelOperation = "top_level"
)
// ErrDNSBogon indicates that we found a bogon address. This is the
// correct value with which to initialize MeasurementRoot.ErrDNSBogon
// to tell this library to return an error when a bogon is found.
var ErrDNSBogon = errors.New("dns: detected bogon address")
// ErrWrapper is our error wrapper for Go errors. The key objective of
// this structure is to properly set Failure, which is also returned by
// the Error() method, so be one of the OONI defined strings.
type ErrWrapper struct {
// ConnID is the connection ID, or zero if not known.
ConnID int64
// DialID is the dial ID, or zero if not known.
DialID int64
// Failure is the OONI failure string. The failure strings are
// loosely backward compatible with Measurement Kit.
//
// This is either one of the FailureXXX strings or any other
// string like `unknown_failure ...`. The latter represents an
// error that we have not yet mapped to a failure.
Failure string
// Operation is the operation that failed. If possible, it
// SHOULD be a _major_ operation. Major operations are:
//
// - ResolveOperation: resolving a domain name failed
// - ConnectOperation: connecting to an IP failed
// - TLSHandshakeOperation: TLS handshaking failed
// - HTTPRoundTripOperation: other errors during round trip
//
// Because a network connection doesn't necessarily know
// what is the current major operation we also have the
// following _minor_ operations:
//
// - CloseOperation: CLOSE failed
// - ReadOperation: READ failed
// - WriteOperation: WRITE failed
//
// If an ErrWrapper referring to a major operation is wrapping
// another ErrWrapper and such ErrWrapper already refers to
// a major operation, then the new ErrWrapper should use the
// child ErrWrapper major operation. Otherwise, it should use
// its own major operation. This way, the topmost wrapper is
// supposed to refer to the major operation that failed.
Operation string
// TransactionID is the transaction ID, or zero if not known.
TransactionID int64
// WrappedErr is the error that we're wrapping.
WrappedErr error
}
// Error returns a description of the error that occurred.
func (e *ErrWrapper) Error() string {
return e.Failure
}
// Unwrap allows to access the underlying error
func (e *ErrWrapper) Unwrap() error {
return e.WrappedErr
}
// SafeErrWrapperBuilder contains a builder for ErrWrapper that
// is safe, i.e., behaves correctly when the error is nil.
type SafeErrWrapperBuilder struct {
// ConnID is the connection ID, if any
ConnID int64
// DialID is the dial ID, if any
DialID int64
// Error is the error, if any
Error error
// Operation is the operation that failed
Operation string
// TransactionID is the transaction ID, if any
TransactionID int64
}
// MaybeBuild builds a new ErrWrapper, if b.Error is not nil, and returns
// a nil error value, instead, if b.Error is nil.
func (b SafeErrWrapperBuilder) MaybeBuild() (err error) {
if b.Error != nil {
err = &ErrWrapper{
ConnID: b.ConnID,
DialID: b.DialID,
Failure: toFailureString(b.Error),
Operation: toOperationString(b.Error, b.Operation),
TransactionID: b.TransactionID,
WrappedErr: b.Error,
}
}
return
}
func toFailureString(err error) string {
// The list returned here matches the values used by MK unless
// explicitly noted otherwise with a comment.
var errwrapper *ErrWrapper
if errors.As(err, &errwrapper) {
return errwrapper.Error() // we've already wrapped it
}
if errors.Is(err, ErrDNSBogon) {
return FailureDNSBogonError // not in MK
}
if errors.Is(err, context.Canceled) {
return FailureInterrupted
}
var x509HostnameError x509.HostnameError
if errors.As(err, &x509HostnameError) {
// Test case: https://wrong.host.badssl.com/
return FailureSSLInvalidHostname
}
var x509UnknownAuthorityError x509.UnknownAuthorityError
if errors.As(err, &x509UnknownAuthorityError) {
// Test case: https://self-signed.badssl.com/. This error has
// never been among the ones returned by MK.
return FailureSSLUnknownAuthority
}
var x509CertificateInvalidError x509.CertificateInvalidError
if errors.As(err, &x509CertificateInvalidError) {
// Test case: https://expired.badssl.com/
return FailureSSLInvalidCertificate
}
s := err.Error()
if strings.HasSuffix(s, "operation was canceled") {
return FailureInterrupted
}
if strings.HasSuffix(s, "EOF") {
return FailureEOFError
}
if strings.HasSuffix(s, "connection refused") {
return FailureConnectionRefused
}
if strings.HasSuffix(s, "connection reset by peer") {
return FailureConnectionReset
}
if strings.HasSuffix(s, "context deadline exceeded") {
return FailureGenericTimeoutError
}
if strings.HasSuffix(s, "transaction is timed out") {
return FailureGenericTimeoutError
}
if strings.HasSuffix(s, "i/o timeout") {
return FailureGenericTimeoutError
}
if strings.HasSuffix(s, "TLS handshake timeout") {
return FailureGenericTimeoutError
}
if strings.HasSuffix(s, "no such host") {
// This is dns_lookup_error in MK but such error is used as a
// generic "hey, the lookup failed" error. Instead, this error
// that we return here is significantly more specific.
return FailureDNSNXDOMAINError
}
// TODO(kelmenhorst): see whether it is possible to match errors
// from qtls rather than strings for TLS errors below.
//
// TODO(kelmenhorst): make sure we have tests for all errors. Also,
// how to ensure we are robust to changes in other libs?
//
// special QUIC errors
matched, err := regexp.MatchString(`.*x509: certificate is valid for.*not.*`, s)
if matched {
return FailureSSLInvalidHostname
}
if strings.HasSuffix(s, "x509: certificate signed by unknown authority") {
return FailureSSLUnknownAuthority
}
certInvalidErrors := []string{"x509: certificate is not authorized to sign other certificates", "x509: certificate has expired or is not yet valid:", "x509: a root or intermediate certificate is not authorized to sign for this name:", "x509: a root or intermediate certificate is not authorized for an extended key usage:", "x509: too many intermediates for path length constraint", "x509: certificate specifies an incompatible key usage", "x509: issuer name does not match subject from issuing certificate", "x509: issuer has name constraints but leaf doesn't have a SAN extension", "x509: issuer has name constraints but leaf contains unknown or unconstrained name:"}
for _, errstr := range certInvalidErrors {
if strings.Contains(s, errstr) {
return FailureSSLInvalidCertificate
}
}
if strings.HasPrefix(s, "No compatible QUIC version found") {
return FailureNoCompatibleQUICVersion
}
if strings.HasSuffix(s, "Handshake did not complete in time") {
return FailureGenericTimeoutError
}
if strings.HasSuffix(s, "connection_refused") {
return FailureConnectionRefused
}
if strings.Contains(s, "stateless_reset") {
return FailureConnectionReset
}
if strings.Contains(s, "deadline exceeded") {
return FailureGenericTimeoutError
}
formatted := fmt.Sprintf("unknown_failure: %s", s)
return Scrub(formatted) // scrub IP addresses in the error
}
func toOperationString(err error, operation string) string {
var errwrapper *ErrWrapper
if errors.As(err, &errwrapper) {
// Basically, as explained in ErrWrapper docs, let's
// keep the child major operation, if any.
if errwrapper.Operation == ConnectOperation {
return errwrapper.Operation
}
if errwrapper.Operation == HTTPRoundTripOperation {
return errwrapper.Operation
}
if errwrapper.Operation == ResolveOperation {
return errwrapper.Operation
}
if errwrapper.Operation == TLSHandshakeOperation {
return errwrapper.Operation
}
if errwrapper.Operation == QUICHandshakeOperation {
return errwrapper.Operation
}
if errwrapper.Operation == "quic_handshake_start" {
return QUICHandshakeOperation
}
if errwrapper.Operation == "quic_handshake_done" {
return QUICHandshakeOperation
}
// FALLTHROUGH
}
return operation
}
+247
View File
@@ -0,0 +1,247 @@
package errorx
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"syscall"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/lucas-clemente/quic-go"
"github.com/pion/stun"
)
func TestMaybeBuildFactory(t *testing.T) {
err := SafeErrWrapperBuilder{
ConnID: 1,
DialID: 10,
Error: errors.New("mocked error"),
TransactionID: 100,
}.MaybeBuild()
var target *ErrWrapper
if errors.As(err, &target) == false {
t.Fatal("not the expected error type")
}
if target.ConnID != 1 {
t.Fatal("wrong ConnID")
}
if target.DialID != 10 {
t.Fatal("wrong DialID")
}
if target.Failure != "unknown_failure: mocked error" {
t.Fatal("the failure string is wrong")
}
if target.TransactionID != 100 {
t.Fatal("the transactionID is wrong")
}
if target.WrappedErr.Error() != "mocked error" {
t.Fatal("the wrapped error is wrong")
}
}
func TestToFailureString(t *testing.T) {
t.Run("for already wrapped error", func(t *testing.T) {
err := SafeErrWrapperBuilder{Error: io.EOF}.MaybeBuild()
if toFailureString(err) != FailureEOFError {
t.Fatal("unexpected result")
}
})
t.Run("for ErrDNSBogon", func(t *testing.T) {
if toFailureString(ErrDNSBogon) != FailureDNSBogonError {
t.Fatal("unexpected result")
}
})
t.Run("for context.Canceled", func(t *testing.T) {
if toFailureString(context.Canceled) != FailureInterrupted {
t.Fatal("unexpected result")
}
})
t.Run("for x509.HostnameError", func(t *testing.T) {
var err x509.HostnameError
if toFailureString(err) != FailureSSLInvalidHostname {
t.Fatal("unexpected result")
}
})
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
var err x509.UnknownAuthorityError
if toFailureString(err) != FailureSSLUnknownAuthority {
t.Fatal("unexpected result")
}
})
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
var err x509.CertificateInvalidError
if toFailureString(err) != FailureSSLInvalidCertificate {
t.Fatal("unexpected result")
}
})
t.Run("for operation was canceled error", func(t *testing.T) {
if toFailureString(errors.New("operation was canceled")) != FailureInterrupted {
t.Fatal("unexpected result")
}
})
t.Run("for EOF", func(t *testing.T) {
if toFailureString(io.EOF) != FailureEOFError {
t.Fatal("unexpected results")
}
})
t.Run("for connection_refused", func(t *testing.T) {
if toFailureString(syscall.ECONNREFUSED) != FailureConnectionRefused {
t.Fatal("unexpected results")
}
})
t.Run("for connection_reset", func(t *testing.T) {
if toFailureString(syscall.ECONNRESET) != FailureConnectionReset {
t.Fatal("unexpected results")
}
})
t.Run("for context deadline exceeded", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1)
defer cancel()
<-ctx.Done()
if toFailureString(ctx.Err()) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for stun's transaction is timed out", func(t *testing.T) {
if toFailureString(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for i/o error", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1)
defer cancel() // fail immediately
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", "www.google.com:80")
if err == nil {
t.Fatal("expected an error here")
}
if conn != nil {
t.Fatal("expected nil connection here")
}
if toFailureString(err) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for TLS handshake timeout error", func(t *testing.T) {
err := errors.New("net/http: TLS handshake timeout")
if toFailureString(err) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for no such host", func(t *testing.T) {
if toFailureString(&net.DNSError{
Err: "no such host",
}) != FailureDNSNXDOMAINError {
t.Fatal("unexpected results")
}
})
t.Run("for errors including IPv4 address", func(t *testing.T) {
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: use of closed network connection"
out := toFailureString(input)
if out != expected {
t.Fatal(cmp.Diff(expected, out))
}
})
t.Run("for errors including IPv6 address", func(t *testing.T) {
input := errors.New("read tcp [::1]:56948->[::1]:443: use of closed network connection")
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: use of closed network connection"
out := toFailureString(input)
if out != expected {
t.Fatal(cmp.Diff(expected, out))
}
})
// QUIC failures
t.Run("for connection_refused", func(t *testing.T) {
if toFailureString(errors.New("connection_refused")) != FailureConnectionRefused {
t.Fatal("unexpected results")
}
})
t.Run("for connection_reset", func(t *testing.T) {
if toFailureString(errors.New("stateless_reset")) != FailureConnectionReset {
t.Fatal("unexpected results")
}
})
t.Run("for incompatible quic version", func(t *testing.T) {
if toFailureString(errors.New("No compatible QUIC version found")) != FailureNoCompatibleQUICVersion {
t.Fatal("unexpected results")
}
})
t.Run("for i/o error", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1)
defer cancel() // fail immediately
udpAddr := &net.UDPAddr{IP: net.ParseIP("216.58.212.164"), Port: 80, Zone: ""}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
sess, err := quic.DialEarlyContext(ctx, udpConn, udpAddr, "google.com:80", &tls.Config{}, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil session here")
}
if toFailureString(err) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for QUIC handshake timeout error", func(t *testing.T) {
err := errors.New("Handshake did not complete in time")
if toFailureString(err) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
}
func TestToOperationString(t *testing.T) {
t.Run("for connect", func(t *testing.T) {
// You're doing HTTP and connect fails. You want to know
// that connect failed not that HTTP failed.
err := &ErrWrapper{Operation: ConnectOperation}
if toOperationString(err, HTTPRoundTripOperation) != ConnectOperation {
t.Fatal("unexpected result")
}
})
t.Run("for http_round_trip", func(t *testing.T) {
// You're doing DoH and something fails inside HTTP. You want
// to know about the internal HTTP error, not resolve.
err := &ErrWrapper{Operation: HTTPRoundTripOperation}
if toOperationString(err, ResolveOperation) != HTTPRoundTripOperation {
t.Fatal("unexpected result")
}
})
t.Run("for resolve", func(t *testing.T) {
// You're doing HTTP and the DNS fails. You want to
// know that resolve failed.
err := &ErrWrapper{Operation: ResolveOperation}
if toOperationString(err, HTTPRoundTripOperation) != ResolveOperation {
t.Fatal("unexpected result")
}
})
t.Run("for tls_handshake", func(t *testing.T) {
// You're doing HTTP and the TLS handshake fails. You want
// to know about a TLS handshake error.
err := &ErrWrapper{Operation: TLSHandshakeOperation}
if toOperationString(err, HTTPRoundTripOperation) != TLSHandshakeOperation {
t.Fatal("unexpected result")
}
})
t.Run("for minor operation", func(t *testing.T) {
// You just noticed that TLS handshake failed and you
// have a child error telling you that read failed. Here
// you want to know about a TLS handshake error.
err := &ErrWrapper{Operation: ReadOperation}
if toOperationString(err, TLSHandshakeOperation) != TLSHandshakeOperation {
t.Fatal("unexpected result")
}
})
t.Run("for quic_handshake", func(t *testing.T) {
// You're doing HTTP and the TLS handshake fails. You want
// to know about a TLS handshake error.
err := &ErrWrapper{Operation: QUICHandshakeOperation}
if toOperationString(err, HTTPRoundTripOperation) != QUICHandshakeOperation {
t.Fatal("unexpected result")
}
})
}
+70
View File
@@ -0,0 +1,70 @@
package errorx
import "regexp"
// The code in this file is adapted from github.com/keroserene/snowflake's
// common/safelog/safelog.go implementation <https://git.io/JfO9w>.
//
// ================================================================================
// Copyright (c) 2016, Serene Han, Arlo Breault
// Copyright (c) 2019-2020, The Tor Project, Inc
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation and/or
// other materials provided with the distribution.
//
// * Neither the names of the copyright owners nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// ================================================================================
const ipv4Address = `\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}`
const ipv6Address = `([0-9a-fA-F]{0,4}:){5,7}([0-9a-fA-F]{0,4})?`
const ipv6Compressed = `([0-9a-fA-F]{0,4}:){0,5}([0-9a-fA-F]{0,4})?(::)([0-9a-fA-F]{0,4}:){0,5}([0-9a-fA-F]{0,4})?`
const ipv6Full = `(` + ipv6Address + `(` + ipv4Address + `))` +
`|(` + ipv6Compressed + `(` + ipv4Address + `))` +
`|(` + ipv6Address + `)` + `|(` + ipv6Compressed + `)`
const optionalPort = `(:\d{1,5})?`
const addressPattern = `((` + ipv4Address + `)|(\[(` + ipv6Full + `)\])|(` + ipv6Full + `))` + optionalPort
const fullAddrPattern = `(^|\s|[^\w:])` + addressPattern + `(\s|(:\s)|[^\w:]|$)`
var scrubberPatterns = []*regexp.Regexp{
regexp.MustCompile(fullAddrPattern),
}
var addressRegexp = regexp.MustCompile(addressPattern)
func scrub(b []byte) []byte {
scrubbedBytes := b
for _, pattern := range scrubberPatterns {
// this is a workaround since go does not yet support look ahead or look
// behind for regular expressions.
scrubbedBytes = pattern.ReplaceAllFunc(scrubbedBytes, func(b []byte) []byte {
return addressRegexp.ReplaceAll(b, []byte("[scrubbed]"))
})
}
return scrubbedBytes
}
// Scrub sanitizes a string containing an error such that
// any occurrence of IP endpoints is scrubbed
func Scrub(s string) string {
return string(scrub([]byte(s)))
}
@@ -0,0 +1,129 @@
package errorx
import (
"testing"
"github.com/google/go-cmp/cmp"
)
// The code in this file is adapted from github.com/keroserene/snowflake's
// common/safelog/safelog.go implementation <https://git.io/JfO9w>.
//
// ================================================================================
// Copyright (c) 2016, Serene Han, Arlo Breault
// Copyright (c) 2019-2020, The Tor Project, Inc
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation and/or
// other materials provided with the distribution.
//
// * Neither the names of the copyright owners nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// ================================================================================
//Test the log scrubber on known problematic log messages
func TestLogScrubberMessages(t *testing.T) {
for _, test := range []struct {
input, expected string
}{
{
"http: TLS handshake error from 129.97.208.23:38310: ",
"http: TLS handshake error from [scrubbed]: ",
},
{
"http2: panic serving [2620:101:f000:780:9097:75b1:519f:dbb8]:58344: interface conversion: *http2.responseWriter is not http.Hijacker: missing method Hijack",
"http2: panic serving [scrubbed]: interface conversion: *http2.responseWriter is not http.Hijacker: missing method Hijack",
},
{
//Make sure it doesn't scrub fingerprint
"a=fingerprint:sha-256 33:B6:FA:F6:94:CA:74:61:45:4A:D2:1F:2C:2F:75:8A:D9:EB:23:34:B2:30:E9:1B:2A:A6:A9:E0:44:72:CC:74",
"a=fingerprint:sha-256 33:B6:FA:F6:94:CA:74:61:45:4A:D2:1F:2C:2F:75:8A:D9:EB:23:34:B2:30:E9:1B:2A:A6:A9:E0:44:72:CC:74",
},
{
//try with enclosing parens
"(1:2:3:4:c:d:e:f) {1:2:3:4:c:d:e:f}",
"([scrubbed]) {[scrubbed]}",
},
{
//Make sure it doesn't scrub timestamps
"2019/05/08 15:37:31 starting",
"2019/05/08 15:37:31 starting",
},
} {
if Scrub(test.input) != test.expected {
t.Error(cmp.Diff(test.input, test.expected))
}
}
}
func TestLogScrubberGoodFormats(t *testing.T) {
for _, addr := range []string{
// IPv4
"1.2.3.4",
"255.255.255.255",
// IPv4 with port
"1.2.3.4:55",
"255.255.255.255:65535",
// IPv6
"1:2:3:4:c:d:e:f",
"1111:2222:3333:4444:CCCC:DDDD:EEEE:FFFF",
// IPv6 with brackets
"[1:2:3:4:c:d:e:f]",
"[1111:2222:3333:4444:CCCC:DDDD:EEEE:FFFF]",
// IPv6 with brackets and port
"[1:2:3:4:c:d:e:f]:55",
"[1111:2222:3333:4444:CCCC:DDDD:EEEE:FFFF]:65535",
// compressed IPv6
"::f",
"::d:e:f",
"1:2:3::",
"1:2:3::d:e:f",
"1:2:3:d:e:f::",
"::1:2:3:d:e:f",
"1111:2222:3333::DDDD:EEEE:FFFF",
// compressed IPv6 with brackets
"[::d:e:f]",
"[1:2:3::]",
"[1:2:3::d:e:f]",
"[1111:2222:3333::DDDD:EEEE:FFFF]",
"[1:2:3:4:5:6::8]",
"[1::7:8]",
// compressed IPv6 with brackets and port
"[1::]:58344",
"[::d:e:f]:55",
"[1:2:3::]:55",
"[1:2:3::d:e:f]:55",
"[1111:2222:3333::DDDD:EEEE:FFFF]:65535",
// IPv4-compatible and IPv4-mapped
"::255.255.255.255",
"::ffff:255.255.255.255",
"[::255.255.255.255]",
"[::ffff:255.255.255.255]",
"[::255.255.255.255]:65535",
"[::ffff:255.255.255.255]:65535",
"[::ffff:0:255.255.255.255]",
"[2001:db8:3:4::192.0.2.33]",
} {
if Scrub(addr) != "[scrubbed]" {
t.Error(cmp.Diff(addr, "[scrubbed]"))
}
}
}
+56
View File
@@ -0,0 +1,56 @@
package netx
import (
"context"
"io/ioutil"
"net"
"net/http"
"time"
)
type FakeDialer struct {
Conn net.Conn
Err error
}
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return d.Conn, d.Err
}
type FakeTransport struct {
Err error
Func func(*http.Request) (*http.Response, error)
Resp *http.Response
}
func (txp FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
time.Sleep(10 * time.Microsecond)
if txp.Func != nil {
return txp.Func(req)
}
if req.Body != nil {
ioutil.ReadAll(req.Body)
req.Body.Close()
}
if txp.Err != nil {
return nil, txp.Err
}
txp.Resp.Request = req // non thread safe but it doesn't matter
return txp.Resp, nil
}
func (txp FakeTransport) CloseIdleConnections() {}
type FakeBody struct {
Err error
}
func (fb FakeBody) Read(p []byte) (int, error) {
time.Sleep(10 * time.Microsecond)
return 0, fb.Err
}
func (fb FakeBody) Close() error {
return nil
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,95 @@
// +build ignore
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// Forked from github.com/certifi/gocertifi <https://git.io/JJjmG>.
//
// This script should not be invoked directly, rather it should be
// executed by running go generate ./... from toplevel dir.
package main
import (
"crypto/x509"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"text/template"
"time"
)
var tmpl = template.Must(template.New("").Parse(`// Code generated by go generate; DO NOT EDIT.
// {{ .Timestamp }}
// {{ .URL }}
package gocertifi
//go:generate go run generate.go "{{ .URL }}"
import "crypto/x509"
const pemcerts string = ` + "`" + `
{{ .Bundle }}
` + "`" + `
// CACerts builds an X.509 certificate pool containing the
// certificate bundle from {{ .URL }} fetch on {{ .Timestamp }}.
// Returns nil on error along with an appropriate error code.
func CACerts() (*x509.CertPool, error) {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM([]byte(pemcerts))
return pool, nil
}
`))
func main() {
if len(os.Args) != 2 || !strings.HasPrefix(os.Args[1], "https://") {
log.Fatal("usage: go run generate.go <url>")
}
url := os.Args[1]
resp, err := http.Get(url)
if err != nil {
log.Fatal(err)
}
if resp.StatusCode != 200 {
log.Fatal("expected 200, got", resp.StatusCode)
}
defer resp.Body.Close()
bundle, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(bundle) {
log.Fatalf("can't parse certificates from %s", url)
}
fp, err := os.Create("certifi.go")
if err != nil {
log.Fatal(err)
}
err = tmpl.Execute(fp, struct {
Timestamp time.Time
URL string
Bundle string
}{
Timestamp: time.Now(),
URL: url,
Bundle: string(bundle),
})
if err != nil {
log.Fatal(err)
}
if err := fp.Close(); err != nil {
log.Fatal(err)
}
}
@@ -0,0 +1,73 @@
package httptransport
import (
"io"
"net/http"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
)
// ByteCountingTransport is a RoundTripper that counts bytes.
type ByteCountingTransport struct {
RoundTripper
Counter *bytecounter.Counter
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp ByteCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Body != nil {
req.Body = byteCountingBody{
ReadCloser: req.Body, Account: txp.Counter.CountBytesSent}
}
txp.estimateRequestMetadata(req)
resp, err := txp.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
txp.estimateResponseMetadata(resp)
resp.Body = byteCountingBody{
ReadCloser: resp.Body, Account: txp.Counter.CountBytesReceived}
return resp, nil
}
func (txp ByteCountingTransport) estimateRequestMetadata(req *http.Request) {
txp.Counter.CountBytesSent(len(req.Method))
txp.Counter.CountBytesSent(len(req.URL.String()))
for key, values := range req.Header {
for _, value := range values {
txp.Counter.CountBytesSent(len(key))
txp.Counter.CountBytesSent(len(": "))
txp.Counter.CountBytesSent(len(value))
txp.Counter.CountBytesSent(len("\r\n"))
}
}
txp.Counter.CountBytesSent(len("\r\n"))
}
func (txp ByteCountingTransport) estimateResponseMetadata(resp *http.Response) {
txp.Counter.CountBytesReceived(len(resp.Status))
for key, values := range resp.Header {
for _, value := range values {
txp.Counter.CountBytesReceived(len(key))
txp.Counter.CountBytesReceived(len(": "))
txp.Counter.CountBytesReceived(len(value))
txp.Counter.CountBytesReceived(len("\r\n"))
}
}
txp.Counter.CountBytesReceived(len("\r\n"))
}
type byteCountingBody struct {
io.ReadCloser
Account func(int)
}
func (r byteCountingBody) Read(p []byte) (int, error) {
count, err := r.ReadCloser.Read(p)
if count > 0 {
r.Account(count)
}
return count, err
}
var _ RoundTripper = ByteCountingTransport{}
@@ -0,0 +1,128 @@
package httptransport_test
import (
"errors"
"io"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
)
func TestByteCounterFailure(t *testing.T) {
counter := bytecounter.New()
txp := httptransport.ByteCountingTransport{
Counter: counter,
RoundTripper: httptransport.FakeTransport{
Err: io.EOF,
},
}
client := &http.Client{Transport: txp}
req, err := http.NewRequest(
"POST", "https://www.google.com", strings.NewReader("AAAAAA"))
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", "antani-browser/1.0.0")
resp, err := client.Do(req)
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
if counter.Sent.Load() != 68 {
t.Fatal("expected around 68 bytes sent")
}
if counter.Received.Load() != 0 {
t.Fatal("expected zero bytes received")
}
}
func TestByteCounterSuccess(t *testing.T) {
counter := bytecounter.New()
txp := httptransport.ByteCountingTransport{
Counter: counter,
RoundTripper: httptransport.FakeTransport{
Resp: &http.Response{
Body: ioutil.NopCloser(strings.NewReader("1234567")),
Header: http.Header{
"Server": []string{"antani/0.1.0"},
},
Status: "200 OK",
StatusCode: http.StatusOK,
},
},
}
client := &http.Client{Transport: txp}
req, err := http.NewRequest(
"POST", "https://www.google.com", strings.NewReader("AAAAAA"))
if err != nil {
t.Fatal(err)
}
req.Header.Set("User-Agent", "antani-browser/1.0.0")
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if string(data) != "1234567" {
t.Fatal("expected a different body here")
}
if counter.Sent.Load() != 68 {
t.Fatal("expected around 68 bytes sent")
}
if counter.Received.Load() != 37 {
t.Fatal("expected zero around 37 bytes received")
}
}
func TestByteCounterSuccessWithEOF(t *testing.T) {
counter := bytecounter.New()
txp := httptransport.ByteCountingTransport{
Counter: counter,
RoundTripper: httptransport.FakeTransport{
Resp: &http.Response{
Body: bodyReaderWithEOF{},
Header: http.Header{
"Server": []string{"antani/0.1.0"},
},
Status: "200 OK",
StatusCode: http.StatusOK,
},
},
}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatal(err)
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
if string(data) != "A" {
t.Fatal("expected a different body here")
}
}
type bodyReaderWithEOF struct{}
func (bodyReaderWithEOF) Read(p []byte) (int, error) {
if len(p) < 1 {
panic("should not happen")
}
p[0] = 'A'
return 1, io.EOF // we want code to be robust to this
}
func (bodyReaderWithEOF) Close() error {
return nil
}
@@ -0,0 +1,56 @@
package httptransport
import (
"context"
"io/ioutil"
"net"
"net/http"
"time"
)
type FakeDialer struct {
Conn net.Conn
Err error
}
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return d.Conn, d.Err
}
type FakeTransport struct {
Err error
Func func(*http.Request) (*http.Response, error)
Resp *http.Response
}
func (txp FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
time.Sleep(10 * time.Microsecond)
if txp.Func != nil {
return txp.Func(req)
}
if req.Body != nil {
ioutil.ReadAll(req.Body)
req.Body.Close()
}
if txp.Err != nil {
return nil, txp.Err
}
txp.Resp.Request = req // non thread safe but it doesn't matter
return txp.Resp, nil
}
func (txp FakeTransport) CloseIdleConnections() {}
type FakeBody struct {
Err error
}
func (fb FakeBody) Read(p []byte) (int, error) {
time.Sleep(10 * time.Microsecond)
return 0, fb.Err
}
func (fb FakeBody) Close() error {
return nil
}
@@ -0,0 +1,43 @@
package httptransport
import (
"context"
"crypto/tls"
"net/http"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
)
// QUICWrapperDialer is a QUICDialer that wraps a ContextDialer
// This is necessary because the http3 RoundTripper does not support a DialContext method.
type QUICWrapperDialer struct {
Dialer quicdialer.ContextDialer
}
// Dial implements QUICDialer.Dial
func (d QUICWrapperDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
return d.Dialer.DialContext(context.Background(), network, host, tlsCfg, cfg)
}
// HTTP3Transport is a httptransport.RoundTripper using the http3 protocol.
type HTTP3Transport struct {
http3.RoundTripper
}
// CloseIdleConnections closes all the connections opened by this transport.
func (t *HTTP3Transport) CloseIdleConnections() {
t.RoundTripper.Close()
}
// NewHTTP3Transport creates a new HTTP3Transport instance.
func NewHTTP3Transport(config Config) RoundTripper {
txp := &HTTP3Transport{}
txp.QuicConfig = &quic.Config{}
txp.TLSClientConfig = config.TLSConfig
txp.Dial = config.QUICDialer.Dial
return txp
}
var _ RoundTripper = &http.Transport{}
@@ -0,0 +1,157 @@
package httptransport_test
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net/http"
"strings"
"testing"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
"github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor"
)
type MockQUICDialer struct{}
func (d MockQUICDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
return quic.DialAddrEarly(host, tlsCfg, cfg)
}
type MockSNIQUICDialer struct {
namech chan string
}
func (d MockSNIQUICDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
d.namech <- tlsCfg.ServerName
return quic.DialAddrEarly(host, tlsCfg, cfg)
}
type MockCertQUICDialer struct {
certch chan *x509.CertPool
}
func (d MockCertQUICDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
d.certch <- tlsCfg.RootCAs
return quic.DialAddrEarly(host, tlsCfg, cfg)
}
func TestHTTP3TransportSNI(t *testing.T) {
namech := make(chan string, 1)
sni := "sni.org"
txp := httptransport.NewHTTP3Transport(httptransport.Config{
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockSNIQUICDialer{namech: namech}, TLSConfig: &tls.Config{ServerName: sni}})
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err == nil {
t.Fatal("expected error here")
}
if resp != nil {
t.Fatal("expected nil resp here")
}
if !strings.Contains(err.Error(), "certificate is valid for www.google.com, not "+sni) {
t.Fatal("unexpected error type", err)
}
servername := <-namech
if servername != sni {
t.Fatal("unexpected server name", servername)
}
}
func TestHTTP3TransportSNINoVerify(t *testing.T) {
namech := make(chan string, 1)
sni := "sni.org"
txp := httptransport.NewHTTP3Transport(httptransport.Config{
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockSNIQUICDialer{namech: namech}, TLSConfig: &tls.Config{ServerName: sni, InsecureSkipVerify: true}})
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatalf("unexpected error: %+v", err)
}
if resp == nil {
t.Fatal("unexpected nil resp")
}
servername := <-namech
if servername != sni {
t.Fatal("unexpected server name", servername)
}
}
func TestHTTP3TransportCABundle(t *testing.T) {
certch := make(chan *x509.CertPool, 1)
certpool := x509.NewCertPool()
txp := httptransport.NewHTTP3Transport(httptransport.Config{
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockCertQUICDialer{certch: certch}, TLSConfig: &tls.Config{RootCAs: certpool}})
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err == nil {
t.Fatal("expected error here")
}
if resp != nil {
t.Fatal("expected nil resp here")
}
// since the certificate pool is empty, the unknown authority error should be thrown
if !strings.Contains(err.Error(), "certificate signed by unknown authority") {
t.Fatal("unexpected error type")
}
certs := <-certch
if certs != certpool {
t.Fatal("not the certpool we expected")
}
}
func TestUnitHTTP3TransportSuccess(t *testing.T) {
txp := httptransport.NewHTTP3Transport(httptransport.Config{
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockQUICDialer{}})
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if resp == nil {
t.Fatal("unexpected nil response here")
}
if resp.StatusCode != 200 {
t.Fatal("HTTP statuscode should be 200 OK", resp.StatusCode)
}
}
func TestUnitHTTP3TransportFailure(t *testing.T) {
txp := httptransport.NewHTTP3Transport(httptransport.Config{
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockQUICDialer{}})
ctx, cancel := context.WithCancel(context.Background())
cancel() // so that the request immediately fails
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err == nil {
t.Fatal("expected error here")
}
// context.Canceled error occurs if the test host supports QUIC
// timeout error ("Handshake did not complete in time") occurs if the test host does not support QUIC
if !(errors.Is(err, context.Canceled) || strings.HasSuffix(err.Error(), "Handshake did not complete in time")) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
}
@@ -0,0 +1,47 @@
// Package httptransport contains HTTP transport extensions.
package httptransport
import (
"context"
"crypto/tls"
"net"
"net/http"
"github.com/lucas-clemente/quic-go"
)
// Config contains the configuration required for constructing an HTTP transport
type Config struct {
Dialer Dialer
QUICDialer QUICDialer
TLSDialer TLSDialer
TLSConfig *tls.Config
}
// Dialer is the definition of dialer assumed by this package.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// TLSDialer is the definition of a TLS dialer assumed by this package.
type TLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
}
// QUICDialer is the definition of dialer for QUIC assumed by this package.
type QUICDialer interface {
Dial(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
}
// RoundTripper is the definition of http.RoundTripper used by this package.
type RoundTripper interface {
RoundTrip(req *http.Request) (*http.Response, error)
CloseIdleConnections()
}
// Resolver is the interface we expect from a resolver
type Resolver interface {
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
Network() string
Address() string
}
@@ -0,0 +1,50 @@
package httptransport
import "net/http"
// Logger is the logger assumed by this package
type Logger interface {
Debugf(format string, v ...interface{})
Debug(message string)
}
// LoggingTransport is a logging transport
type LoggingTransport struct {
RoundTripper
Logger Logger
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp LoggingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
host := req.Host
if host == "" {
host = req.URL.Host
}
req.Header.Set("Host", host) // anticipate what Go would do
return txp.logTrip(req)
}
func (txp LoggingTransport) logTrip(req *http.Request) (*http.Response, error) {
txp.Logger.Debugf("> %s %s", req.Method, req.URL.String())
for key, values := range req.Header {
for _, value := range values {
txp.Logger.Debugf("> %s: %s", key, value)
}
}
txp.Logger.Debug(">")
resp, err := txp.RoundTripper.RoundTrip(req)
if err != nil {
txp.Logger.Debugf("< %s", err)
return nil, err
}
txp.Logger.Debugf("< %d", resp.StatusCode)
for key, values := range resp.Header {
for _, value := range values {
txp.Logger.Debugf("< %s: %s", key, value)
}
}
txp.Logger.Debug("<")
return resp, nil
}
var _ RoundTripper = LoggingTransport{}
@@ -0,0 +1,77 @@
package httptransport_test
import (
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
)
func TestLoggingFailure(t *testing.T) {
txp := httptransport.LoggingTransport{
Logger: log.Log,
RoundTripper: httptransport.FakeTransport{
Err: io.EOF,
},
}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
}
func TestLoggingFailureWithNoHostHeader(t *testing.T) {
txp := httptransport.LoggingTransport{
Logger: log.Log,
RoundTripper: httptransport.FakeTransport{
Err: io.EOF,
},
}
req := &http.Request{
Header: http.Header{},
URL: &url.URL{
Scheme: "https",
Host: "www.google.com",
Path: "/",
},
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
}
func TestLoggingSuccess(t *testing.T) {
txp := httptransport.LoggingTransport{
Logger: log.Log,
RoundTripper: httptransport.FakeTransport{
Resp: &http.Response{
Body: ioutil.NopCloser(strings.NewReader("")),
Header: http.Header{
"Server": []string{"antani/0.1.0"},
},
StatusCode: 200,
},
},
}
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatal(err)
}
ioutil.ReadAll(resp.Body)
resp.Body.Close()
}
+158
View File
@@ -0,0 +1,158 @@
package httptransport
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/http/httptrace"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// SaverPerformanceHTTPTransport is a RoundTripper that saves
// performance events occurring during the round trip
type SaverPerformanceHTTPTransport struct {
RoundTripper
Saver *trace.Saver
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverPerformanceHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
tracep := httptrace.ContextClientTrace(req.Context())
if tracep == nil {
tracep = &httptrace.ClientTrace{
WroteHeaders: func() {
txp.Saver.Write(trace.Event{Name: "http_wrote_headers", Time: time.Now()})
},
WroteRequest: func(httptrace.WroteRequestInfo) {
txp.Saver.Write(trace.Event{Name: "http_wrote_request", Time: time.Now()})
},
GotFirstResponseByte: func() {
txp.Saver.Write(trace.Event{
Name: "http_first_response_byte", Time: time.Now()})
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), tracep))
}
return txp.RoundTripper.RoundTrip(req)
}
// SaverMetadataHTTPTransport is a RoundTripper that saves
// events related to HTTP request and response metadata
type SaverMetadataHTTPTransport struct {
RoundTripper
Saver *trace.Saver
Transport string
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
txp.Saver.Write(trace.Event{
HTTPHeaders: req.Header,
HTTPMethod: req.Method,
HTTPURL: req.URL.String(),
Transport: txp.Transport,
Name: "http_request_metadata",
Time: time.Now(),
})
resp, err := txp.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
txp.Saver.Write(trace.Event{
HTTPHeaders: resp.Header,
HTTPStatusCode: resp.StatusCode,
Name: "http_response_metadata",
Time: time.Now(),
})
return resp, err
}
// SaverTransactionHTTPTransport is a RoundTripper that saves
// events related to the HTTP transaction
type SaverTransactionHTTPTransport struct {
RoundTripper
Saver *trace.Saver
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
txp.Saver.Write(trace.Event{
Name: "http_transaction_start",
Time: time.Now(),
})
resp, err := txp.RoundTripper.RoundTrip(req)
txp.Saver.Write(trace.Event{
Err: err,
Name: "http_transaction_done",
Time: time.Now(),
})
return resp, err
}
// SaverBodyHTTPTransport is a RoundTripper that saves
// body events occurring during the round trip
type SaverBodyHTTPTransport struct {
RoundTripper
Saver *trace.Saver
SnapshotSize int
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
const defaultSnapSize = 1 << 17
snapsize := defaultSnapSize
if txp.SnapshotSize != 0 {
snapsize = txp.SnapshotSize
}
if req.Body != nil {
data, err := saverSnapRead(req.Body, snapsize)
if err != nil {
return nil, err
}
req.Body = saverCompose(data, req.Body)
txp.Saver.Write(trace.Event{
DataIsTruncated: len(data) >= snapsize,
Data: data,
Name: "http_request_body_snapshot",
Time: time.Now(),
})
}
resp, err := txp.RoundTripper.RoundTrip(req)
if err != nil {
return nil, err
}
data, err := saverSnapRead(resp.Body, snapsize)
if err != nil {
resp.Body.Close()
return nil, err
}
resp.Body = saverCompose(data, resp.Body)
txp.Saver.Write(trace.Event{
DataIsTruncated: len(data) >= snapsize,
Data: data,
Name: "http_response_body_snapshot",
Time: time.Now(),
})
return resp, nil
}
func saverSnapRead(r io.ReadCloser, snapsize int) ([]byte, error) {
return ioutil.ReadAll(io.LimitReader(r, int64(snapsize)))
}
func saverCompose(data []byte, r io.ReadCloser) io.ReadCloser {
return saverReadCloser{Closer: r, Reader: io.MultiReader(bytes.NewReader(data), r)}
}
type saverReadCloser struct {
io.Closer
io.Reader
}
var _ RoundTripper = SaverPerformanceHTTPTransport{}
var _ RoundTripper = SaverMetadataHTTPTransport{}
var _ RoundTripper = SaverBodyHTTPTransport{}
var _ RoundTripper = SaverTransactionHTTPTransport{}
@@ -0,0 +1,429 @@
package httptransport_test
import (
"errors"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSaverPerformanceNoMultipleEvents(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
// register twice - do we see events twice?
txp := httptransport.SaverPerformanceHTTPTransport{
RoundTripper: http.DefaultTransport.(*http.Transport),
Saver: saver,
}
txp = httptransport.SaverPerformanceHTTPTransport{
RoundTripper: txp,
Saver: saver,
}
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal("not the error we expected")
}
if resp == nil {
t.Fatal("expected non nil response here")
}
ev := saver.Read()
// we should specifically see the events not attached to any
// context being submitted twice. This is fine because they are
// explicit, while the context is implicit and hence leads to
// more subtle bugs. For example, this happens when you measure
// every event and combine HTTP with DoH.
if len(ev) != 3 {
t.Fatal("expected three events")
}
expected := []string{
"http_wrote_headers", // measured with context
"http_wrote_request", // measured with context
"http_first_response_byte", // measured with context
}
for i := 0; i < len(expected); i++ {
if ev[i].Name != expected[i] {
t.Fatal("unexpected event name")
}
}
}
func TestSaverMetadataSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
txp := httptransport.SaverMetadataHTTPTransport{
RoundTripper: http.DefaultTransport.(*http.Transport),
Saver: saver,
}
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Add("User-Agent", "miniooni/0.1.0-dev")
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal("not the error we expected")
}
if resp == nil {
t.Fatal("expected non nil response here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected two events")
}
//
if ev[0].HTTPMethod != "GET" {
t.Fatal("unexpected Method")
}
if len(ev[0].HTTPHeaders) <= 0 {
t.Fatal("unexpected Headers")
}
if ev[0].HTTPURL != "https://www.google.com" {
t.Fatal("unexpected URL")
}
if ev[0].Name != "http_request_metadata" {
t.Fatal("unexpected Name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
//
if ev[1].HTTPStatusCode != 200 {
t.Fatal("unexpected StatusCode")
}
if len(ev[1].HTTPHeaders) <= 0 {
t.Fatal("unexpected Headers")
}
if ev[1].Name != "http_response_metadata" {
t.Fatal("unexpected Name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverMetadataFailure(t *testing.T) {
expected := errors.New("mocked error")
saver := &trace.Saver{}
txp := httptransport.SaverMetadataHTTPTransport{
RoundTripper: httptransport.FakeTransport{
Err: expected,
},
Saver: saver,
}
req, err := http.NewRequest("GET", "http://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Add("User-Agent", "miniooni/0.1.0-dev")
resp, err := txp.RoundTrip(req)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("expected one event")
}
if ev[0].HTTPMethod != "GET" {
t.Fatal("unexpected Method")
}
if len(ev[0].HTTPHeaders) <= 0 {
t.Fatal("unexpected Headers")
}
if ev[0].HTTPURL != "http://www.google.com" {
t.Fatal("unexpected URL")
}
if ev[0].Name != "http_request_metadata" {
t.Fatal("unexpected Name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
}
func TestSaverTransactionSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := &trace.Saver{}
txp := httptransport.SaverTransactionHTTPTransport{
RoundTripper: http.DefaultTransport.(*http.Transport),
Saver: saver,
}
req, err := http.NewRequest("GET", "https://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal("not the error we expected")
}
if resp == nil {
t.Fatal("expected non nil response here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected two events")
}
//
if ev[0].Name != "http_transaction_start" {
t.Fatal("unexpected Name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
//
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name != "http_transaction_done" {
t.Fatal("unexpected Name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverTransactionFailure(t *testing.T) {
expected := errors.New("mocked error")
saver := &trace.Saver{}
txp := httptransport.SaverTransactionHTTPTransport{
RoundTripper: httptransport.FakeTransport{
Err: expected,
},
Saver: saver,
}
req, err := http.NewRequest("GET", "http://www.google.com", nil)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected two events")
}
if ev[0].Name != "http_transaction_start" {
t.Fatal("unexpected Name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("unexpected Time")
}
if ev[1].Name != "http_transaction_done" {
t.Fatal("unexpected Name")
}
if !errors.Is(ev[1].Err, expected) {
t.Fatal("unexpected Err")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("unexpected Time")
}
}
func TestSaverBodySuccess(t *testing.T) {
saver := new(trace.Saver)
txp := httptransport.SaverBodyHTTPTransport{
RoundTripper: httptransport.FakeTransport{
Func: func(req *http.Request) (*http.Response, error) {
data, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Fatal(err)
}
if string(data) != "deadbeef" {
t.Fatal("invalid data")
}
return &http.Response{
StatusCode: 501,
Body: ioutil.NopCloser(strings.NewReader("abad1dea")),
}, nil
},
},
SnapshotSize: 4,
Saver: saver,
}
body := strings.NewReader("deadbeef")
req, err := http.NewRequest("POST", "http://x.org/y", body)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 501 {
t.Fatal("unexpected status code")
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(data) != "abad1dea" {
t.Fatal("unexpected body")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("unexpected number of events")
}
if string(ev[0].Data) != "dead" {
t.Fatal("invalid Data")
}
if ev[0].DataIsTruncated != true {
t.Fatal("invalid DataIsTruncated")
}
if ev[0].Name != "http_request_body_snapshot" {
t.Fatal("invalid Name")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("invalid Time")
}
if string(ev[1].Data) != "abad" {
t.Fatal("invalid Data")
}
if ev[1].DataIsTruncated != true {
t.Fatal("invalid DataIsTruncated")
}
if ev[1].Name != "http_response_body_snapshot" {
t.Fatal("invalid Name")
}
if ev[1].Time.Before(ev[0].Time) {
t.Fatal("invalid Time")
}
}
func TestSaverBodyRequestReadError(t *testing.T) {
saver := new(trace.Saver)
txp := httptransport.SaverBodyHTTPTransport{
RoundTripper: httptransport.FakeTransport{
Func: func(req *http.Request) (*http.Response, error) {
panic("should not be called")
},
},
SnapshotSize: 4,
Saver: saver,
}
expected := errors.New("mocked error")
body := httptransport.FakeBody{Err: expected}
req, err := http.NewRequest("POST", "http://x.org/y", body)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response")
}
ev := saver.Read()
if len(ev) != 0 {
t.Fatal("unexpected number of events")
}
}
func TestSaverBodyRoundTripError(t *testing.T) {
saver := new(trace.Saver)
expected := errors.New("mocked error")
txp := httptransport.SaverBodyHTTPTransport{
RoundTripper: httptransport.FakeTransport{
Err: expected,
},
SnapshotSize: 4,
Saver: saver,
}
body := strings.NewReader("deadbeef")
req, err := http.NewRequest("POST", "http://x.org/y", body)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("unexpected number of events")
}
if string(ev[0].Data) != "dead" {
t.Fatal("invalid Data")
}
if ev[0].DataIsTruncated != true {
t.Fatal("invalid DataIsTruncated")
}
if ev[0].Name != "http_request_body_snapshot" {
t.Fatal("invalid Name")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("invalid Time")
}
}
func TestSaverBodyResponseReadError(t *testing.T) {
saver := new(trace.Saver)
expected := errors.New("mocked error")
txp := httptransport.SaverBodyHTTPTransport{
RoundTripper: httptransport.FakeTransport{
Func: func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: httptransport.FakeBody{
Err: expected,
},
}, nil
},
},
SnapshotSize: 4,
Saver: saver,
}
body := strings.NewReader("deadbeef")
req, err := http.NewRequest("POST", "http://x.org/y", body)
if err != nil {
t.Fatal(err)
}
resp, err := txp.RoundTrip(req)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil response")
}
ev := saver.Read()
if len(ev) != 1 {
t.Fatal("unexpected number of events")
}
if string(ev[0].Data) != "dead" {
t.Fatal("invalid Data")
}
if ev[0].DataIsTruncated != true {
t.Fatal("invalid DataIsTruncated")
}
if ev[0].Name != "http_request_body_snapshot" {
t.Fatal("invalid Name")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("invalid Time")
}
}
@@ -0,0 +1,24 @@
package httptransport
import (
"net/http"
)
// NewSystemTransport creates a new "system" HTTP transport. That is a transport
// using the Go standard library with custom dialer and TLS dialer.
func NewSystemTransport(config Config) RoundTripper {
txp := http.DefaultTransport.(*http.Transport).Clone()
txp.DialContext = config.Dialer.DialContext
txp.DialTLSContext = config.TLSDialer.DialTLSContext
// Better for Cloudflare DNS and also better because we have less
// noisy events and we can better understand what happened.
txp.MaxConnsPerHost = 1
// The following (1) reduces the number of headers that Go will
// automatically send for us and (2) ensures that we always receive
// back the true headers, such as Content-Length. This change is
// functional to OONI's goal of observing the network.
txp.DisableCompression = true
return txp
}
var _ RoundTripper = &http.Transport{}
@@ -0,0 +1,19 @@
package httptransport
import "net/http"
// UserAgentTransport is a transport that ensures that we always
// set an OONI specific default User-Agent header.
type UserAgentTransport struct {
RoundTripper
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp UserAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Header.Get("User-Agent") == "" {
req.Header.Set("User-Agent", "miniooni/0.1.0-dev")
}
return txp.RoundTripper.RoundTrip(req)
}
var _ RoundTripper = UserAgentTransport{}
@@ -0,0 +1,51 @@
package httptransport_test
import (
"net/http"
"net/url"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
)
func TestUserAgentWithDefault(t *testing.T) {
txp := httptransport.UserAgentTransport{
RoundTripper: httptransport.FakeTransport{
Resp: &http.Response{StatusCode: 200},
},
}
req := &http.Request{URL: &url.URL{
Scheme: "https",
Host: "www.google.com",
Path: "/",
}}
req.Header = http.Header{}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if resp.Request.Header.Get("User-Agent") != "miniooni/0.1.0-dev" {
t.Fatal("not the User-Agent we expected")
}
}
func TestUserAgentWithExplicitValue(t *testing.T) {
txp := httptransport.UserAgentTransport{
RoundTripper: httptransport.FakeTransport{
Resp: &http.Response{StatusCode: 200},
},
}
req := &http.Request{URL: &url.URL{
Scheme: "https",
Host: "www.google.com",
Path: "/",
}}
req.Header = http.Header{"User-Agent": []string{"antani-client/0.1.1"}}
resp, err := txp.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
if resp.Request.Header.Get("User-Agent") != "antani-client/0.1.1" {
t.Fatal("not the User-Agent we expected")
}
}
+93
View File
@@ -0,0 +1,93 @@
package netx_test
import (
"context"
"errors"
"io/ioutil"
"net/http"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSuccess(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
log.SetLevel(log.DebugLevel)
counter := bytecounter.New()
config := netx.Config{
BogonIsError: true,
ByteCounter: counter,
CacheResolutions: true,
ContextByteCounting: true,
DialSaver: &trace.Saver{},
HTTPSaver: &trace.Saver{},
Logger: log.Log,
ReadWriteSaver: &trace.Saver{},
ResolveSaver: &trace.Saver{},
TLSSaver: &trace.Saver{},
}
txp := netx.NewHTTPTransport(config)
client := &http.Client{Transport: txp}
resp, err := client.Get("https://www.google.com")
if err != nil {
t.Fatal(err)
}
if _, err = ioutil.ReadAll(resp.Body); err != nil {
t.Fatal(err)
}
if err = resp.Body.Close(); err != nil {
t.Fatal(err)
}
if counter.Sent.Load() <= 0 {
t.Fatal("no bytes sent?!")
}
if counter.Received.Load() <= 0 {
t.Fatal("no bytes received?!")
}
if ev := config.DialSaver.Read(); len(ev) <= 0 {
t.Fatal("no dial events?!")
}
if ev := config.HTTPSaver.Read(); len(ev) <= 0 {
t.Fatal("no HTTP events?!")
}
if ev := config.ReadWriteSaver.Read(); len(ev) <= 0 {
t.Fatal("no R/W events?!")
}
if ev := config.ResolveSaver.Read(); len(ev) <= 0 {
t.Fatal("no resolver events?!")
}
if ev := config.TLSSaver.Read(); len(ev) <= 0 {
t.Fatal("no TLS events?!")
}
}
func TestBogonResolutionNotBroken(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
saver := new(trace.Saver)
r := netx.NewResolver(netx.Config{
BogonIsError: true,
DNSCache: map[string][]string{
"www.google.com": {"127.0.0.1"},
},
ResolveSaver: saver,
Logger: log.Log,
})
addrs, err := r.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, errorx.ErrDNSBogon) {
t.Fatal("not the error we expected")
}
if err.Error() != errorx.FailureDNSBogonError {
t.Fatal("error not correctly wrapped")
}
if len(addrs) != 1 || addrs[0] != "127.0.0.1" {
t.Fatal("address was not returned")
}
}
+485
View File
@@ -0,0 +1,485 @@
// Package netx contains code to perform network measurements.
//
// This library contains replacements for commonly used standard library
// interfaces that facilitate seamless network measurements. By using
// such replacements, as opposed to standard library interfaces, we can:
//
// * save the timing of HTTP events (e.g. received response headers)
// * save the timing and result of every Connect, Read, Write, Close operation
// * save the timing and result of the TLS handshake (including certificates)
//
// By default, this library uses the system resolver. In addition, it
// is possible to configure alternative DNS transports and remote
// servers. We support DNS over UDP, DNS over TCP, DNS over TLS (DoT),
// and DNS over HTTPS (DoH). When using an alternative transport, we
// are also able to intercept and save DNS messages, as well as any
// other interaction with the remote server (e.g., the result of the
// TLS handshake for DoT and DoH).
//
// We described the design and implementation of the most recent version of
// this package at <https://github.com/ooni/probe-cli/v3/internal/engine/issues/359>. Such
// issue also links to a previous design document.
package netx
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"net/http"
"net/url"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/gocertifi"
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// Logger is the logger assumed by this package
type Logger interface {
Debugf(format string, v ...interface{})
Debug(message string)
}
// Dialer is the definition of dialer assumed by this package.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// QUICDialer is the definition of a dialer for QUIC assumed by this package.
type QUICDialer interface {
Dial(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
}
// TLSDialer is the definition of a TLS dialer assumed by this package.
type TLSDialer interface {
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
}
// HTTPRoundTripper is the definition of http.HTTPRoundTripper used by this package.
type HTTPRoundTripper interface {
RoundTrip(req *http.Request) (*http.Response, error)
CloseIdleConnections()
}
// Resolver is the interface we expect from a resolver
type Resolver interface {
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
Network() string
Address() string
}
// Config contains configuration for creating a new transport. When any
// field of Config is nil/empty, we will use a suitable default.
//
// We use different savers for different kind of events such that the
// user of this library can choose what to save.
type Config struct {
BaseResolver Resolver // default: system resolver
BogonIsError bool // default: bogon is not error
ByteCounter *bytecounter.Counter // default: no explicit byte counting
CacheResolutions bool // default: no caching
CertPool *x509.CertPool // default: use vendored gocertifi
ContextByteCounting bool // default: no implicit byte counting
DNSCache map[string][]string // default: cache is empty
DialSaver *trace.Saver // default: not saving dials
Dialer Dialer // default: dialer.DNSDialer
FullResolver Resolver // default: base resolver + goodies
QUICDialer QUICDialer // default: quicdialer.DNSDialer
HTTP3Enabled bool // default: disabled
HTTPSaver *trace.Saver // default: not saving HTTP
Logger Logger // default: no logging
NoTLSVerify bool // default: perform TLS verify
ProxyURL *url.URL // default: no proxy
ReadWriteSaver *trace.Saver // default: not saving read/write
ResolveSaver *trace.Saver // default: not saving resolves
TLSConfig *tls.Config // default: attempt using h2
TLSDialer TLSDialer // default: dialer.TLSDialer
TLSSaver *trace.Saver // default: not saving TLS
}
type tlsHandshaker interface {
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}
// NewDefaultCertPool returns a copy of the default x509
// certificate pool. This function panics on failure.
func NewDefaultCertPool() *x509.CertPool {
pool, err := gocertifi.CACerts()
runtimex.PanicOnError(err, "gocertifi.CACerts() failed")
return pool
}
var defaultCertPool *x509.CertPool = NewDefaultCertPool()
// NewResolver creates a new resolver from the specified config
func NewResolver(config Config) Resolver {
if config.BaseResolver == nil {
config.BaseResolver = resolver.SystemResolver{}
}
var r Resolver = config.BaseResolver
if config.CacheResolutions {
r = &resolver.CacheResolver{Resolver: r}
}
if config.DNSCache != nil {
cache := &resolver.CacheResolver{Resolver: r, ReadOnly: true}
for key, values := range config.DNSCache {
cache.Set(key, values)
}
r = cache
}
if config.BogonIsError {
r = resolver.BogonResolver{Resolver: r}
}
r = resolver.ErrorWrapperResolver{Resolver: r}
if config.Logger != nil {
r = resolver.LoggingResolver{Logger: config.Logger, Resolver: r}
}
if config.ResolveSaver != nil {
r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver}
}
r = resolver.AddressResolver{Resolver: r}
return resolver.IDNAResolver{Resolver: r}
}
// NewDialer creates a new Dialer from the specified config
func NewDialer(config Config) Dialer {
if config.FullResolver == nil {
config.FullResolver = NewResolver(config)
}
var d Dialer = selfcensor.SystemDialer{}
d = dialer.TimeoutDialer{Dialer: d}
d = dialer.ErrorWrapperDialer{Dialer: d}
if config.Logger != nil {
d = dialer.LoggingDialer{Dialer: d, Logger: config.Logger}
}
if config.DialSaver != nil {
d = dialer.SaverDialer{Dialer: d, Saver: config.DialSaver}
}
if config.ReadWriteSaver != nil {
d = dialer.SaverConnDialer{Dialer: d, Saver: config.ReadWriteSaver}
}
d = dialer.DNSDialer{Resolver: config.FullResolver, Dialer: d}
d = dialer.ProxyDialer{ProxyURL: config.ProxyURL, Dialer: d}
if config.ContextByteCounting {
d = dialer.ByteCounterDialer{Dialer: d}
}
d = dialer.ShapingDialer{Dialer: d}
return d
}
// NewQUICDialer creates a new DNS Dialer for QUIC, with the resolver from the specified config
func NewQUICDialer(config Config) QUICDialer {
if config.FullResolver == nil {
config.FullResolver = NewResolver(config)
}
var d quicdialer.ContextDialer = &quicdialer.SystemDialer{Saver: config.ReadWriteSaver}
d = quicdialer.ErrorWrapperDialer{Dialer: d}
if config.TLSSaver != nil {
d = quicdialer.HandshakeSaver{Saver: config.TLSSaver, Dialer: d}
}
d = &quicdialer.DNSDialer{Resolver: config.FullResolver, Dialer: d}
var dialer QUICDialer = &httptransport.QUICWrapperDialer{Dialer: d}
return dialer
}
// NewTLSDialer creates a new TLSDialer from the specified config
func NewTLSDialer(config Config) TLSDialer {
if config.Dialer == nil {
config.Dialer = NewDialer(config)
}
var h tlsHandshaker = dialer.SystemTLSHandshaker{}
h = dialer.TimeoutTLSHandshaker{TLSHandshaker: h}
h = dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
if config.Logger != nil {
h = dialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h}
}
if config.TLSSaver != nil {
h = dialer.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver}
}
if config.TLSConfig == nil {
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
}
if config.CertPool == nil {
config.CertPool = defaultCertPool
}
config.TLSConfig.RootCAs = config.CertPool
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
return dialer.TLSDialer{
Config: config.TLSConfig,
Dialer: config.Dialer,
TLSHandshaker: h,
}
}
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned
// HTTPRoundTripper before wrapping it into an http.Client.
func NewHTTPTransport(config Config) HTTPRoundTripper {
if config.Dialer == nil {
config.Dialer = NewDialer(config)
}
if config.TLSDialer == nil {
config.TLSDialer = NewTLSDialer(config)
}
if config.QUICDialer == nil {
config.QUICDialer = NewQUICDialer(config)
}
tInfo := allTransportsInfo[config.HTTP3Enabled]
txp := tInfo.Factory(httptransport.Config{
Dialer: config.Dialer, QUICDialer: config.QUICDialer, TLSDialer: config.TLSDialer,
TLSConfig: config.TLSConfig})
transport := tInfo.TransportName
if config.ByteCounter != nil {
txp = httptransport.ByteCountingTransport{
Counter: config.ByteCounter, RoundTripper: txp}
}
if config.Logger != nil {
txp = httptransport.LoggingTransport{Logger: config.Logger, RoundTripper: txp}
}
if config.HTTPSaver != nil {
txp = httptransport.SaverMetadataHTTPTransport{
RoundTripper: txp, Saver: config.HTTPSaver, Transport: transport}
txp = httptransport.SaverBodyHTTPTransport{
RoundTripper: txp, Saver: config.HTTPSaver}
txp = httptransport.SaverPerformanceHTTPTransport{
RoundTripper: txp, Saver: config.HTTPSaver}
txp = httptransport.SaverTransactionHTTPTransport{
RoundTripper: txp, Saver: config.HTTPSaver}
}
txp = httptransport.UserAgentTransport{RoundTripper: txp}
return txp
}
// httpTransportInfo contains the constructing function as well as the transport name
type httpTransportInfo struct {
Factory func(httptransport.Config) httptransport.RoundTripper
TransportName string
}
var allTransportsInfo = map[bool]httpTransportInfo{
false: {
Factory: httptransport.NewSystemTransport,
TransportName: "tcp",
},
true: {
Factory: httptransport.NewHTTP3Transport,
TransportName: "quic",
},
}
// DNSClient is a DNS client. It wraps a Resolver and it possibly
// also wraps an HTTP client, but only when we're using DoH.
type DNSClient struct {
Resolver
httpClient *http.Client
}
// CloseIdleConnections closes idle connections, if any.
func (c DNSClient) CloseIdleConnections() {
if c.httpClient != nil {
c.httpClient.CloseIdleConnections()
}
}
// NewDNSClient creates a new DNS client. The config argument is used to
// create the underlying Dialer and/or HTTP transport, if needed. The URL
// argument describes the kind of client that we want to make:
//
// - if the URL is `doh://powerdns`, `doh://google` or `doh://cloudflare` or the URL
// starts with `https://`, then we create a DoH client.
//
// - if the URL is `` or `system:///`, then we create a system client,
// i.e. a client using the system resolver.
//
// - if the URL starts with `udp://`, then we create a client using
// a resolver that uses the specified UDP endpoint.
//
// We return error if the URL does not parse or the URL scheme does not
// fall into one of the cases described above.
//
// If config.ResolveSaver is not nil and we're creating an underlying
// resolver where this is possible, we will also save events.
func NewDNSClient(config Config, URL string) (DNSClient, error) {
return NewDNSClientWithOverrides(config, URL, "", "", "")
}
// ErrInvalidTLSVersion indicates that you passed us a string
// that does not represent a valid TLS version.
var ErrInvalidTLSVersion = errors.New("invalid TLS version")
// ConfigureTLSVersion configures the correct TLS version into
// the specified *tls.Config or returns an error.
func ConfigureTLSVersion(config *tls.Config, version string) error {
switch version {
case "TLSv1.3":
config.MinVersion = tls.VersionTLS13
config.MaxVersion = tls.VersionTLS13
case "TLSv1.2":
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
case "TLSv1.1":
config.MinVersion = tls.VersionTLS11
config.MaxVersion = tls.VersionTLS11
case "TLSv1.0", "TLSv1":
config.MinVersion = tls.VersionTLS10
config.MaxVersion = tls.VersionTLS10
case "":
// nothing
default:
return ErrInvalidTLSVersion
}
return nil
}
// NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient,
// with the option to override the default Hostname and SNI.
func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
TLSVersion string) (DNSClient, error) {
var c DNSClient
switch URL {
case "doh://powerdns":
URL = "https://doh.powerdns.org/"
case "doh://google":
URL = "https://dns.google/dns-query"
case "doh://cloudflare":
URL = "https://cloudflare-dns.com/dns-query"
case "":
URL = "system:///"
}
resolverURL, err := url.Parse(URL)
if err != nil {
return c, err
}
config.TLSConfig = &tls.Config{ServerName: SNIOverride}
if err := ConfigureTLSVersion(config.TLSConfig, TLSVersion); err != nil {
return c, err
}
switch resolverURL.Scheme {
case "system":
c.Resolver = resolver.SystemResolver{}
return c, nil
case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
c.httpClient = &http.Client{Transport: NewHTTPTransport(config)}
var txp resolver.RoundTripper = resolver.NewDNSOverHTTPSWithHostOverride(
c.httpClient, URL, hostOverride)
if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{
RoundTripper: txp,
Saver: config.ResolveSaver,
}
}
c.Resolver = resolver.NewSerialResolver(txp)
return c, nil
case "udp":
dialer := NewDialer(config)
endpoint, err := makeValidEndpoint(resolverURL)
if err != nil {
return c, err
}
var txp resolver.RoundTripper = resolver.NewDNSOverUDP(dialer, endpoint)
if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{
RoundTripper: txp,
Saver: config.ResolveSaver,
}
}
c.Resolver = resolver.NewSerialResolver(txp)
return c, nil
case "dot":
config.TLSConfig.NextProtos = []string{"dot"}
tlsDialer := NewTLSDialer(config)
endpoint, err := makeValidEndpoint(resolverURL)
if err != nil {
return c, err
}
var txp resolver.RoundTripper = resolver.NewDNSOverTLS(
tlsDialer.DialTLSContext, endpoint)
if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{
RoundTripper: txp,
Saver: config.ResolveSaver,
}
}
c.Resolver = resolver.NewSerialResolver(txp)
return c, nil
case "tcp":
dialer := NewDialer(config)
endpoint, err := makeValidEndpoint(resolverURL)
if err != nil {
return c, err
}
var txp resolver.RoundTripper = resolver.NewDNSOverTCP(
dialer.DialContext, endpoint)
if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{
RoundTripper: txp,
Saver: config.ResolveSaver,
}
}
c.Resolver = resolver.NewSerialResolver(txp)
return c, nil
default:
return c, errors.New("unsupported resolver scheme")
}
}
// makeValidEndpoint makes a valid endpoint for DoT and Do53 given the
// input URL representing such endpoint. Specifically, we are
// concerned with the case where the port is missing. In such a
// case, we ensure that we are using the default port 853 for DoT
// and default port 53 for TCP and UDP.
func makeValidEndpoint(URL *url.URL) (string, error) {
// Implementation note: when we're using a quoted IPv6
// address, URL.Host contains the quotes but instead the
// return value from URL.Hostname() does not.
//
// For example:
//
// - Host: [2620:fe::9]
// - Hostname(): 2620:fe::9
//
// We need to keep this in mind when trying to determine
// whether there is also a port or not.
//
// So the first step is to check whether URL.Host is already
// a whatever valid TCP/UDP endpoint and, if so, use it.
if _, _, err := net.SplitHostPort(URL.Host); err == nil {
return URL.Host, nil
}
// The second step is to assume that appending the default port
// to a host parsed by url.Parse should be giving us a valid
// endpoint. The possibilities in fact are:
//
// 1. domain w/o port
// 2. IPv4 w/o port
// 3. square bracket quoted IPv6 w/o port
// 4. other
//
// In the first three cases, appending a port leads us to a
// good endpoint. The fourth case does not.
//
// For this reason we check again whether we can split it using
// net.SplitHostPort. If we cannot, we were in case four.
host := URL.Host
if URL.Scheme == "dot" {
host += ":853"
} else {
host += ":53"
}
if _, _, err := net.SplitHostPort(host); err != nil {
return "", err
}
// Otherwise it's one of the three valid cases above.
return host, nil
}
@@ -0,0 +1,8 @@
package netx
import "crypto/x509"
// DefaultCertPool allows tests to access the default cert pool.
func DefaultCertPool() *x509.CertPool {
return defaultCertPool
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,14 @@
// +build !go1.15
package quicdialer
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
)
// ConnectionState returns the ConnectionState of a QUIC Session.
func ConnectionState(sess quic.EarlySession) tls.ConnectionState {
return tls.ConnectionState{}
}
@@ -0,0 +1,14 @@
// +build go1.15
package quicdialer
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
)
// ConnectionState returns the ConnectionState of a QUIC Session.
func ConnectionState(sess quic.EarlySession) tls.ConnectionState {
return sess.ConnectionState().ConnectionState
}
+59
View File
@@ -0,0 +1,59 @@
package quicdialer
import (
"context"
"crypto/tls"
"net"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
)
// DNSDialer is a dialer that uses the configured Resolver to resolve a
// domain name to IP addresses
type DNSDialer struct {
Dialer ContextDialer
Resolver Resolver
}
// DialContext implements ContextDialer.DialContext
func (d DNSDialer) DialContext(
ctx context.Context, network, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
onlyhost, onlyport, err := net.SplitHostPort(host)
if err != nil {
return nil, err
}
// TODO(kelmenhorst): Should this be somewhere else?
// failure if tlsCfg is nil but that should not happen
if tlsCfg.ServerName == "" {
tlsCfg.ServerName = onlyhost
}
ctx = dialid.WithDialID(ctx)
var addrs []string
addrs, err = d.LookupHost(ctx, onlyhost)
if err != nil {
return nil, err
}
var errorslist []error
for _, addr := range addrs {
target := net.JoinHostPort(addr, onlyport)
sess, err := d.Dialer.DialContext(
ctx, network, target, tlsCfg, cfg)
if err == nil {
return sess, nil
}
errorslist = append(errorslist, err)
}
// TODO(bassosimone): maybe ReduceErrors could be in netx/internal.
return nil, dialer.ReduceErrors(errorslist)
}
// LookupHost implements Resolver.LookupHost
func (d DNSDialer) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
}
return d.Resolver.LookupHost(ctx, hostname)
}
+142
View File
@@ -0,0 +1,142 @@
package quicdialer_test
import (
"context"
"crypto/tls"
"errors"
"net"
"strconv"
"strings"
"testing"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
)
type MockableResolver struct {
Addresses []string
Err error
}
func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
return r.Addresses, r.Err
}
func TestDNSDialerSuccess(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com:443",
tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error", err)
}
if sess == nil {
t.Fatal("non nil sess expected")
}
}
func TestDNSDialerNoPort(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected a nil sess here")
}
if err.Error() != "address www.google.com: missing port in address" {
t.Fatal("not the error we expected")
}
}
func TestDNSDialerLookupHostAddress(t *testing.T) {
dialer := quicdialer.DNSDialer{Resolver: MockableResolver{
Err: errors.New("mocked error"),
}}
addrs, err := dialer.LookupHost(context.Background(), "1.1.1.1")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
t.Fatal("not the result we expected")
}
}
func TestDNSDialerLookupHostFailure(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
expected := errors.New("mocked error")
dialer := quicdialer.DNSDialer{Resolver: MockableResolver{
Err: expected,
}}
sess, err := dialer.DialContext(
context.Background(), "udp", "dns.google.com:853",
tlsConf, &quic.Config{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if sess != nil {
t.Fatal("expected nil sess")
}
}
func TestDNSDialerInvalidPort(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com:0",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil sess")
}
if !strings.HasSuffix(err.Error(), "sendto: invalid argument") &&
!strings.HasSuffix(err.Error(), "sendto: can't assign requested address") {
t.Fatal("not the error we expected")
}
}
func TestDNSDialerInvalidPortSyntax(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com:port",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil sess")
}
if !errors.Is(err, strconv.ErrSyntax) {
t.Fatal("not the error we expected")
}
}
func TestDNSDialerDialEarlyFails(t *testing.T) {
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
expected := errors.New("mocked DialEarly error")
dialer := quicdialer.DNSDialer{
Resolver: new(net.Resolver), Dialer: MockDialer{Err: expected}}
sess, err := dialer.DialContext(
context.Background(), "udp", "www.google.com:443",
tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil sess")
}
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
}
@@ -0,0 +1,34 @@
package quicdialer
import (
"context"
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// ErrorWrapperDialer is a dialer that performs quic err wrapping
type ErrorWrapperDialer struct {
Dialer ContextDialer
}
// DialContext implements ContextDialer.DialContext
func (d ErrorWrapperDialer) DialContext(
ctx context.Context, network string, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialID := dialid.ContextDialID(ctx)
sess, err := d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
err = errorx.SafeErrWrapperBuilder{
// ConnID does not make any sense if we've failed and the error
// does not make any sense (and is nil) if we succeded.
DialID: dialID,
Error: err,
Operation: errorx.QUICHandshakeOperation,
}.MaybeBuild()
if err != nil {
return nil, err
}
return sess, nil
}
@@ -0,0 +1,61 @@
package quicdialer_test
import (
"context"
"crypto/tls"
"errors"
"io"
"testing"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
)
func TestErrorWrapperFailure(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
d := quicdialer.ErrorWrapperDialer{
Dialer: MockDialer{Sess: nil, Err: io.EOF}}
sess, err := d.DialContext(
ctx, "udp", "www.google.com:443", &tls.Config{}, &quic.Config{})
if sess != nil {
t.Fatal("expected a nil sess here")
}
errorWrapperCheckErr(t, err, errorx.QUICHandshakeOperation)
}
func errorWrapperCheckErr(t *testing.T, err error, op string) {
if !errors.Is(err, io.EOF) {
t.Fatal("expected another error here")
}
var errWrapper *errorx.ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot cast to ErrWrapper")
}
if errWrapper.DialID == 0 {
t.Fatal("unexpected DialID")
}
if errWrapper.Operation != op {
t.Fatal("unexpected Operation")
}
if errWrapper.Failure != errorx.FailureEOFError {
t.Fatal("unexpected failure")
}
}
func TestErrorWrapperSuccess(t *testing.T) {
ctx := dialid.WithDialID(context.Background())
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
d := quicdialer.ErrorWrapperDialer{Dialer: quicdialer.SystemDialer{}}
sess, err := d.DialContext(ctx, "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal(err)
}
if sess == nil {
t.Fatal("expected non-nil sess here")
}
}
@@ -0,0 +1,26 @@
package quicdialer
import (
"context"
"crypto/tls"
"github.com/lucas-clemente/quic-go"
)
// ContextDialer is a dialer for QUIC using Context.
type ContextDialer interface {
// Note: assumes that tlsCfg and cfg are not nil.
DialContext(ctx context.Context, network, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
}
// Dialer dials QUIC connections.
type Dialer interface {
// Note: assumes that tlsCfg and cfg are not nil.
Dial(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
}
// Resolver is the interface we expect from a resolver.
type Resolver interface {
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
}
+62
View File
@@ -0,0 +1,62 @@
package quicdialer
import (
"context"
"crypto/tls"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// HandshakeSaver saves events occurring during the handshake
type HandshakeSaver struct {
Saver *trace.Saver
Dialer ContextDialer
}
// DialContext implements ContextDialer.DialContext
func (h HandshakeSaver) DialContext(ctx context.Context, network string,
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
start := time.Now()
// TODO(bassosimone): in the future we probably want to also save
// information about what versions we're willing to accept.
h.Saver.Write(trace.Event{
Address: host,
Name: "quic_handshake_start",
NoTLSVerify: tlsCfg.InsecureSkipVerify,
Proto: network,
TLSNextProtos: tlsCfg.NextProtos,
TLSServerName: tlsCfg.ServerName,
Time: start,
})
sess, err := h.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
stop := time.Now()
if err != nil {
h.Saver.Write(trace.Event{
Duration: stop.Sub(start),
Err: err,
Name: "quic_handshake_done",
NoTLSVerify: tlsCfg.InsecureSkipVerify,
TLSNextProtos: tlsCfg.NextProtos,
TLSServerName: tlsCfg.ServerName,
Time: stop,
})
return nil, err
}
state := ConnectionState(sess)
h.Saver.Write(trace.Event{
Duration: stop.Sub(start),
Name: "quic_handshake_done",
NoTLSVerify: tlsCfg.InsecureSkipVerify,
TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite),
TLSNegotiatedProto: state.NegotiatedProtocol,
TLSNextProtos: tlsCfg.NextProtos,
TLSPeerCerts: trace.PeerCerts(state, err),
TLSServerName: tlsCfg.ServerName,
TLSVersion: tlsx.VersionString(state.Version),
Time: stop,
})
return sess, nil
}
@@ -0,0 +1,118 @@
package quicdialer_test
import (
"context"
"crypto/tls"
"reflect"
"strings"
"testing"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
type MockDialer struct {
Dialer quicdialer.ContextDialer
Sess quic.EarlySession
Err error
}
func (d MockDialer) DialContext(ctx context.Context, network, host string,
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
if d.Dialer != nil {
return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
}
return d.Sess, d.Err
}
func TestHandshakeSaverSuccess(t *testing.T) {
nextprotos := []string{"h3-29"}
servername := "www.google.com"
tlsConf := &tls.Config{
NextProtos: nextprotos,
ServerName: servername,
}
saver := &trace.Saver{}
dlr := quicdialer.HandshakeSaver{
Dialer: quicdialer.SystemDialer{},
Saver: saver,
}
sess, err := dlr.DialContext(context.Background(), "udp",
"216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal("unexpected error", err)
}
if sess == nil {
t.Fatal("unexpected nil sess")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("unexpected number of events")
}
if ev[0].Name != "quic_handshake_start" {
t.Fatal("unexpected Name")
}
if ev[0].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[0].Time.After(time.Now()) {
t.Fatal("unexpected Time")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err", ev[1].Err)
}
if ev[1].Name != "quic_handshake_done" {
t.Fatal("unexpected Name")
}
if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) {
t.Fatal("unexpected TLSNextProtos")
}
if ev[1].TLSServerName != "www.google.com" {
t.Fatal("unexpected TLSServerName")
}
if ev[1].Time.Before(ev[0].Time) {
t.Fatal("unexpected Time")
}
}
func TestHandshakeSaverHostNameError(t *testing.T) {
nextprotos := []string{"h3-29"}
servername := "wrong.host.badssl.com"
tlsConf := &tls.Config{
NextProtos: nextprotos,
ServerName: servername,
}
saver := &trace.Saver{}
dlr := quicdialer.HandshakeSaver{
Dialer: quicdialer.SystemDialer{},
Saver: saver,
}
sess, err := dlr.DialContext(context.Background(), "udp",
"216.58.212.164:443", tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil sess here")
}
for _, ev := range saver.Read() {
if ev.Name != "quic_handshake_done" {
continue
}
if ev.NoTLSVerify == true {
t.Fatal("expected NoTLSVerify to be false")
}
if !strings.Contains(ev.Err.Error(),
"certificate is valid for www.google.com, not "+servername) {
t.Fatal("unexpected error", ev.Err)
}
}
}
+87
View File
@@ -0,0 +1,87 @@
package quicdialer
import (
"context"
"crypto/tls"
"errors"
"net"
"strconv"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// SystemDialer is the basic dialer for QUIC
type SystemDialer struct {
// Saver saves read/write events on the underlying UDP
// connection. (Implementation note: we need it here since
// this is the only part in the codebase that is able to
// observe the underlying UDP connection.)
Saver *trace.Saver
}
// DialContext implements ContextDialer.DialContext
func (d SystemDialer) DialContext(ctx context.Context, network string,
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
onlyhost, onlyport, err := net.SplitHostPort(host)
port, err := strconv.Atoi(onlyport)
if err != nil {
return nil, err
}
ip := net.ParseIP(onlyhost)
if ip == nil {
// TODO(kelmenhorst): write test for this error condition.
return nil, errors.New("quicdialer: invalid IP representation")
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
var pconn net.PacketConn = udpConn
if d.Saver != nil {
pconn = saverUDPConn{UDPConn: udpConn, saver: d.Saver}
}
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
return quic.DialEarlyContext(ctx, pconn, udpAddr, host, tlsCfg, cfg)
}
type saverUDPConn struct {
*net.UDPConn
saver *trace.Saver
}
func (c saverUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) {
start := time.Now()
count, err := c.UDPConn.WriteTo(p, addr)
stop := time.Now()
c.saver.Write(trace.Event{
Address: addr.String(),
Data: p[:count],
Duration: stop.Sub(start),
Err: err,
NumBytes: count,
Name: errorx.WriteToOperation,
Time: stop,
})
return count, err
}
func (c saverUDPConn) ReadMsgUDP(b, oob []byte) (int, int, int, *net.UDPAddr, error) {
start := time.Now()
n, oobn, flags, addr, err := c.UDPConn.ReadMsgUDP(b, oob)
stop := time.Now()
var data []byte
if n > 0 {
data = b[:n]
}
c.saver.Write(trace.Event{
Address: addr.String(),
Data: data,
Duration: stop.Sub(start),
Err: err,
NumBytes: n,
Name: errorx.ReadFromOperation,
Time: stop,
})
return n, oobn, flags, addr, err
}
@@ -0,0 +1,75 @@
package quicdialer_test
import (
"context"
"crypto/tls"
"testing"
"github.com/lucas-clemente/quic-go"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSystemDialerInvalidIPFailure(t *testing.T) {
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
saver := &trace.Saver{}
systemdialer := quicdialer.SystemDialer{
Saver: saver,
}
sess, err := systemdialer.DialContext(context.Background(), "udp", "a.b.c.d:0", tlsConf, &quic.Config{})
if err == nil {
t.Fatal("expected an error here")
}
if sess != nil {
t.Fatal("expected nil sess here")
}
if err.Error() != "quicdialer: invalid IP representation" {
t.Fatal("expected another error here")
}
}
func TestSystemDialerSuccessWithReadWrite(t *testing.T) {
// This is the most common use case for collecting reads, writes
tlsConf := &tls.Config{
NextProtos: []string{"h3-29"},
ServerName: "www.google.com",
}
saver := &trace.Saver{}
systemdialer := quicdialer.SystemDialer{Saver: saver}
_, err := systemdialer.DialContext(context.Background(), "udp",
"216.58.212.164:443", tlsConf, &quic.Config{})
if err != nil {
t.Fatal(err)
}
ev := saver.Read()
if len(ev) < 2 {
t.Fatal("unexpected number of events")
}
last := len(ev) - 1
for idx := 1; idx < last; idx++ {
if ev[idx].Data == nil {
t.Fatal("unexpected Data")
}
if ev[idx].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[idx].Err != nil {
t.Fatal("unexpected Err")
}
if ev[idx].NumBytes <= 0 {
t.Fatal("unexpected NumBytes")
}
switch ev[idx].Name {
case errorx.ReadFromOperation, errorx.WriteToOperation:
default:
t.Fatal("unexpected Name")
}
if ev[idx].Time.Before(ev[idx-1].Time) {
t.Fatal("unexpected Time")
}
}
}
+22
View File
@@ -0,0 +1,22 @@
package resolver
import (
"context"
"net"
)
// AddressResolver is a resolver that knows how to correctly
// resolve IP addresses to themselves.
type AddressResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r AddressResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
}
return r.Resolver.LookupHost(ctx, hostname)
}
var _ Resolver = AddressResolver{}
@@ -0,0 +1,36 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestAddressSuccess(t *testing.T) {
r := resolver.AddressResolver{}
addrs, err := r.LookupHost(context.Background(), "8.8.8.8")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
}
func TestAddressFailure(t *testing.T) {
expected := errors.New("mocked error")
r := resolver.AddressResolver{
Resolver: resolver.FakeResolver{
Err: expected,
},
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil addrs")
}
}
+71
View File
@@ -0,0 +1,71 @@
package resolver
import (
"context"
"net"
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
var privateIPBlocks []*net.IPNet
func init() {
for _, cidr := range []string{
"0.0.0.0/8", // "This" network (however, Linux...)
"10.0.0.0/8", // RFC1918
"100.64.0.0/10", // Carrier grade NAT
"127.0.0.0/8", // IPv4 loopback
"169.254.0.0/16", // RFC3927 link-local
"172.16.0.0/12", // RFC1918
"192.168.0.0/16", // RFC1918
"224.0.0.0/4", // Multicast
"::1/128", // IPv6 loopback
"fe80::/10", // IPv6 link-local
"fc00::/7", // IPv6 unique local addr
} {
_, block, err := net.ParseCIDR(cidr)
runtimex.PanicOnError(err, "net.ParseCIDR failed")
privateIPBlocks = append(privateIPBlocks, block)
}
}
func isPrivate(ip net.IP) bool {
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
for _, block := range privateIPBlocks {
if block.Contains(ip) {
return true
}
}
return false
}
// IsBogon returns whether if an IP address is bogon. Passing to this
// function a non-IP address causes it to return bogon.
func IsBogon(address string) bool {
ip := net.ParseIP(address)
return ip == nil || isPrivate(ip)
}
// BogonResolver is a bogon aware resolver. When a bogon is encountered in
// a reply, this resolver will return an error.
type BogonResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r BogonResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
for _, addr := range addrs {
if IsBogon(addr) == true {
// We need to return the addrs otherwise the caller cannot see/log/save
// the specific addresses that triggered our bogon filter
return addrs, errorx.ErrDNSBogon
}
}
return addrs, err
}
var _ Resolver = BogonResolver{}
@@ -0,0 +1,52 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestResolverIsBogon(t *testing.T) {
if resolver.IsBogon("antani") != true {
t.Fatal("unexpected result")
}
if resolver.IsBogon("127.0.0.1") != true {
t.Fatal("unexpected result")
}
if resolver.IsBogon("1.1.1.1") != false {
t.Fatal("unexpected result")
}
if resolver.IsBogon("10.0.1.1") != true {
t.Fatal("unexpected result")
}
}
func TestBogonAwareResolverWithBogon(t *testing.T) {
r := resolver.BogonResolver{
Resolver: resolver.NewFakeResolverWithResult([]string{"127.0.0.1"}),
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if !errors.Is(err, errorx.ErrDNSBogon) {
t.Fatal("not the error we expected")
}
if len(addrs) != 1 || addrs[0] != "127.0.0.1" {
t.Fatal("expected to see address here")
}
}
func TestBogonAwareResolverWithoutBogon(t *testing.T) {
orig := []string{"8.8.8.8"}
r := resolver.BogonResolver{
Resolver: resolver.NewFakeResolverWithResult(orig),
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != len(orig) || addrs[0] != orig[0] {
t.Fatal("not the error we expected")
}
}
+47
View File
@@ -0,0 +1,47 @@
package resolver
import (
"context"
"sync"
)
// CacheResolver is a resolver that caches successful replies.
type CacheResolver struct {
ReadOnly bool
Resolver
mu sync.Mutex
cache map[string][]string
}
// LookupHost implements Resolver.LookupHost
func (r *CacheResolver) LookupHost(
ctx context.Context, hostname string) ([]string, error) {
if entry := r.Get(hostname); entry != nil {
return entry, nil
}
entry, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {
return nil, err
}
if r.ReadOnly == false {
r.Set(hostname, entry)
}
return entry, nil
}
// Get gets the currently configured entry for domain, or nil
func (r *CacheResolver) Get(domain string) []string {
r.mu.Lock()
defer r.mu.Unlock()
return r.cache[domain]
}
// Set allows to pre-populate the cache
func (r *CacheResolver) Set(domain string, addresses []string) {
r.mu.Lock()
if r.cache == nil {
r.cache = make(map[string][]string)
}
r.cache[domain] = addresses
r.mu.Unlock()
}
@@ -0,0 +1,76 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestCacheFailure(t *testing.T) {
expected := errors.New("mocked error")
var r resolver.Resolver = resolver.FakeResolver{
Err: expected,
}
cache := &resolver.CacheResolver{Resolver: r}
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil addrs here")
}
if cache.Get("www.google.com") != nil {
t.Fatal("expected empty cache here")
}
}
func TestCacheHitSuccess(t *testing.T) {
var r resolver.Resolver = resolver.FakeResolver{
Err: errors.New("mocked error"),
}
cache := &resolver.CacheResolver{Resolver: r}
cache.Set("dns.google.com", []string{"8.8.8.8"})
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
}
func TestCacheMissSuccess(t *testing.T) {
var r resolver.Resolver = resolver.FakeResolver{
Result: []string{"8.8.8.8"},
}
cache := &resolver.CacheResolver{Resolver: r}
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
if cache.Get("dns.google.com")[0] != "8.8.8.8" {
t.Fatal("expected full cache here")
}
}
func TestCacheReadonlySuccess(t *testing.T) {
var r resolver.Resolver = resolver.FakeResolver{
Result: []string{"8.8.8.8"},
}
cache := &resolver.CacheResolver{Resolver: r, ReadOnly: true}
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected")
}
if cache.Get("dns.google.com") != nil {
t.Fatal("expected empty cache here")
}
}
+33
View File
@@ -0,0 +1,33 @@
package resolver
import (
"context"
)
// ChainResolver is a chain resolver. The primary resolver is used first and, if that
// fails, we then attempt with the secondary resolver.
type ChainResolver struct {
Primary Resolver
Secondary Resolver
}
// LookupHost implements Resolver.LookupHost
func (c ChainResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := c.Primary.LookupHost(ctx, hostname)
if err != nil {
addrs, err = c.Secondary.LookupHost(ctx, hostname)
}
return addrs, err
}
// Network implements Resolver.Network
func (c ChainResolver) Network() string {
return "chain"
}
// Address implements Resolver.Address
func (c ChainResolver) Address() string {
return ""
}
var _ Resolver = ChainResolver{}
@@ -0,0 +1,28 @@
package resolver_test
import (
"context"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestChainLookupHost(t *testing.T) {
r := resolver.ChainResolver{
Primary: resolver.NewFakeResolverThatFails(),
Secondary: resolver.SystemResolver{},
}
if r.Address() != "" {
t.Fatal("invalid address")
}
if r.Network() != "chain" {
t.Fatal("invalid network")
}
addrs, err := r.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expect non nil return value here")
}
}
+54
View File
@@ -0,0 +1,54 @@
package resolver
import (
"errors"
"github.com/miekg/dns"
)
// The Decoder decodes a DNS reply into A or AAAA entries. It will use the
// provided qtype and only look for mathing entries. It will return error if
// there are no entries for the requested qtype inside the reply.
type Decoder interface {
Decode(qtype uint16, data []byte) ([]string, error)
}
// MiekgDecoder uses github.com/miekg/dns to implement the Decoder.
type MiekgDecoder struct{}
// Decode implements Decoder.Decode.
func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
reply := new(dns.Msg)
if err := reply.Unpack(data); err != nil {
return nil, err
}
// TODO(bassosimone): map more errors to net.DNSError names
switch reply.Rcode {
case dns.RcodeSuccess:
case dns.RcodeNameError:
return nil, errors.New("ooniresolver: no such host")
default:
return nil, errors.New("ooniresolver: query failed")
}
var addrs []string
for _, answer := range reply.Answer {
switch qtype {
case dns.TypeA:
if rra, ok := answer.(*dns.A); ok {
ip := rra.A
addrs = append(addrs, ip.String())
}
case dns.TypeAAAA:
if rra, ok := answer.(*dns.AAAA); ok {
ip := rra.AAAA
addrs = append(addrs, ip.String())
}
}
}
if len(addrs) <= 0 {
return nil, errors.New("ooniresolver: no response returned")
}
return addrs, nil
}
var _ Decoder = MiekgDecoder{}
@@ -0,0 +1,113 @@
package resolver_test
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDecoderUnpackError(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, nil)
if err == nil {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderNXDOMAIN(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeNameError))
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderOtherError(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeRefused))
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderNoAddress(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderDecodeA(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if data[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
}
}
func TestDecoderDecodeAAAA(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if data[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
}
}
func TestDecoderUnexpectedAReply(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected nil data here")
}
}
@@ -0,0 +1,77 @@
package resolver
import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
)
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
// an HTTP/HTTPS channel provided by URL using the Do function.
type DNSOverHTTPS struct {
Do func(req *http.Request) (*http.Response, error)
URL string
HostOverride string
}
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
// specified http.Client and URL, as a convenience.
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
}
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
// it's creating a resolver where we use the specified host.
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
if err != nil {
return nil, err
}
req.Host = t.HostOverride
req.Header.Set("user-agent", httpheader.UserAgent())
req.Header.Set("content-type", "application/dns-message")
var resp *http.Response
resp, err = t.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
// TODO(bassosimone): we should map the status code to a
// proper Error in the DNS context.
return nil, errors.New("doh: server returned error")
}
if resp.Header.Get("content-type") != "application/dns-message" {
return nil, errors.New("doh: invalid content-type")
}
return ioutil.ReadAll(resp.Body)
}
// RequiresPadding returns true for DoH according to RFC8467
func (t DNSOverHTTPS) RequiresPadding() bool {
return true
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverHTTPS) Network() string {
return "doh"
}
// Address returns the upstream server address.
func (t DNSOverHTTPS) Address() string {
return t.URL
}
var _ RoundTripper = DNSOverHTTPS{}
@@ -0,0 +1,165 @@
package resolver_test
import (
"bytes"
"context"
"errors"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
const invalidURL = "\t"
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
expected := errors.New("mocked error")
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return nil, expected
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 500,
Body: ioutil.NopCloser(strings.NewReader("")),
}, nil
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}, nil
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverHTTPSSuccess(t *testing.T) {
body := []byte("AAA")
txp := resolver.DNSOverHTTPS{
Do: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, body) {
t.Fatal("not the response we expected")
}
}
func TestDNSOverHTTPTransportOK(t *testing.T) {
const queryURL = "https://cloudflare-dns.com/dns-query"
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
if txp.Network() != "doh" {
t.Fatal("invalid network")
}
if txp.RequiresPadding() != true {
t.Fatal("should require padding")
}
if txp.Address() != queryURL {
t.Fatal("invalid address")
}
}
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
expected := errors.New("mocked error")
var correct bool
txp := resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
return nil, expected
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct user agent")
}
}
func TestDNSOverHTTPSHostOverride(t *testing.T) {
var correct bool
expected := errors.New("mocked error")
hostOverride := "test.com"
txp := resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
correct = req.Host == hostOverride
return nil, expected
},
URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride,
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct host override")
}
}
@@ -0,0 +1,97 @@
package resolver
import (
"context"
"errors"
"io"
"math"
"net"
"time"
)
// DialContextFunc is a generic function for dialing a connection.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
// and NewDNSOverTLS to create specific instances that use plaintext
// queries or encrypted queries over TLS.
//
// As a known bug, this implementation always creates a new connection
// for each incoming query, thus increasing the response delay.
type DNSOverTCP struct {
dial DialContextFunc
address string
network string
requiresPadding bool
}
// NewDNSOverTCP creates a new DNSOverTCP transport.
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
return DNSOverTCP{
dial: dial,
address: address,
network: "tcp",
requiresPadding: false,
}
}
// NewDNSOverTLS creates a new DNSOverTLS transport.
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
return DNSOverTCP{
dial: dial,
address: address,
network: "dot",
requiresPadding: true,
}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
if len(query) > math.MaxUint16 {
return nil, errors.New("query too long")
}
conn, err := t.dial(ctx, "tcp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
return nil, err
}
// Write request
buf := []byte{byte(len(query) >> 8)}
buf = append(buf, byte(len(query)))
buf = append(buf, query...)
if _, err = conn.Write(buf); err != nil {
return nil, err
}
// Read response
header := make([]byte, 2)
if _, err = io.ReadFull(conn, header); err != nil {
return nil, err
}
length := int(header[0])<<8 | int(header[1])
reply := make([]byte, length)
if _, err = io.ReadFull(conn, reply); err != nil {
return nil, err
}
return reply, nil
}
// RequiresPadding returns true for DoT and false for TCP
// according to RFC8467.
func (t DNSOverTCP) RequiresPadding() bool {
return t.requiresPadding
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverTCP) Network() string {
return t.network
}
// Address returns the upstream server address.
func (t DNSOverTCP) Address() string {
return t.address
}
var _ RoundTripper = DNSOverTCP{}
@@ -0,0 +1,146 @@
package resolver_test
import (
"context"
"errors"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
if err == nil {
t.Fatal("expected an error here")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Err: mocked}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
SetDeadlineError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
WriteError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
ReadError: mocked,
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(2)},
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
}
func TestDNSOverTCPTransportAllGood(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
ReadError: mocked,
ReadData: []byte{byte(0), byte(1), byte(1)},
}}
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil {
t.Fatal(err)
}
if len(reply) != 1 || reply[0] != 1 {
t.Fatal("not the response we expected")
}
}
func TestDNSOverTCPTransportOK(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "tcp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
}
func TestDNSOverTLSTransportOK(t *testing.T) {
const address = "9.9.9.9:853"
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "dot" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
}
@@ -0,0 +1,64 @@
package resolver
import (
"context"
"net"
"time"
)
// Dialer is the network dialer interface assumed by this package.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// DNSOverUDP is a DNS over UDP RoundTripper.
type DNSOverUDP struct {
dialer Dialer
address string
}
// NewDNSOverUDP creates a DNSOverUDP instance.
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
return DNSOverUDP{dialer: dialer, address: address}
}
// RoundTrip implements RoundTripper.RoundTrip.
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
// Use five seconds timeout like Bionic does. See
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, err
}
if _, err = conn.Write(query); err != nil {
return nil, err
}
reply := make([]byte, 1<<17)
var n int
n, err = conn.Read(reply)
if err != nil {
return nil, err
}
return reply[:n], nil
}
// RequiresPadding returns false for UDP according to RFC8467
func (t DNSOverUDP) RequiresPadding() bool {
return false
}
// Network returns the transport network (e.g., doh, dot)
func (t DNSOverUDP) Network() string {
return "udp"
}
// Address returns the upstream server address.
func (t DNSOverUDP) Address() string {
return t.address
}
var _ RoundTripper = DNSOverUDP{}
@@ -0,0 +1,107 @@
package resolver_test
import (
"context"
"errors"
"net"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDNSOverUDPDialFailure(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
SetDeadlineError: mocked,
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPWriteFailure(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
WriteError: mocked,
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPReadFailure(t *testing.T) {
mocked := errors.New("mocked error")
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{
ReadError: mocked,
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
}
func TestDNSOverUDPReadSuccess(t *testing.T) {
const expected = 17
txp := resolver.NewDNSOverUDP(
resolver.FakeDialer{
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(data) != expected {
t.Fatal("expected non nil data")
}
}
func TestDNSOverUDPTransportOK(t *testing.T) {
const address = "9.9.9.9:53"
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "udp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
}
+89
View File
@@ -0,0 +1,89 @@
package resolver
import (
"context"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
)
// EmitterTransport is a RoundTripper that emits events when they occur.
type EmitterTransport struct {
RoundTripper
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp EmitterTransport) RoundTrip(ctx context.Context, querydata []byte) ([]byte, error) {
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
DNSQuery: &modelx.DNSQueryEvent{
Data: querydata,
DialID: dialid.ContextDialID(ctx),
DurationSinceBeginning: time.Now().Sub(root.Beginning),
},
})
replydata, err := txp.RoundTripper.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
root.Handler.OnMeasurement(modelx.Measurement{
DNSReply: &modelx.DNSReplyEvent{
Data: replydata,
DialID: dialid.ContextDialID(ctx),
DurationSinceBeginning: time.Now().Sub(root.Beginning),
},
})
return replydata, nil
}
// EmitterResolver is a resolver that emits events
type EmitterResolver struct {
Resolver
}
// LookupHost returns the IP addresses of a host
func (r EmitterResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
var (
network string
address string
)
type queryableResolver interface {
Transport() RoundTripper
}
if qr, ok := r.Resolver.(queryableResolver); ok {
txp := qr.Transport()
network, address = txp.Network(), txp.Address()
}
dialID := dialid.ContextDialID(ctx)
txID := transactionid.ContextTransactionID(ctx)
root := modelx.ContextMeasurementRootOrDefault(ctx)
root.Handler.OnMeasurement(modelx.Measurement{
ResolveStart: &modelx.ResolveStartEvent{
DialID: dialID,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
Hostname: hostname,
TransactionID: txID,
TransportAddress: address,
TransportNetwork: network,
},
})
addrs, err := r.Resolver.LookupHost(ctx, hostname)
root.Handler.OnMeasurement(modelx.Measurement{
ResolveDone: &modelx.ResolveDoneEvent{
Addresses: addrs,
DialID: dialID,
DurationSinceBeginning: time.Now().Sub(root.Beginning),
Error: err,
Hostname: hostname,
TransactionID: txID,
TransportAddress: address,
TransportNetwork: network,
},
})
return addrs, err
}
var _ RoundTripper = EmitterTransport{}
var _ Resolver = EmitterResolver{}
@@ -0,0 +1,220 @@
package resolver_test
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"testing"
"time"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestEmitterTransportSuccess(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
}}
e := resolver.MiekgEncoder{}
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
if err != nil {
t.Fatal(err)
}
replydata, err := txp.RoundTrip(ctx, querydata)
if err != nil {
t.Fatal(err)
}
events := handler.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
if events[0].DNSQuery == nil {
t.Fatal("missing DNSQuery field")
}
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
t.Fatal("invalid query data")
}
if events[0].DNSQuery.DialID == 0 {
t.Fatal("invalid query DialID")
}
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
if events[1].DNSReply == nil {
t.Fatal("missing DNSReply field")
}
if !bytes.Equal(events[1].DNSReply.Data, replydata) {
t.Fatal("missing reply data")
}
if events[1].DNSReply.DialID != 1 {
t.Fatal("invalid query DialID")
}
if events[1].DNSReply.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
}
func TestEmitterTransportFailure(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
mocked := errors.New("mocked error")
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
Err: mocked,
}}
e := resolver.MiekgEncoder{}
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
if err != nil {
t.Fatal(err)
}
replydata, err := txp.RoundTrip(ctx, querydata)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if replydata != nil {
t.Fatal("expected nil replydata")
}
events := handler.Read()
if len(events) != 1 {
t.Fatal("unexpected number of events")
}
if events[0].DNSQuery == nil {
t.Fatal("missing DNSQuery field")
}
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
t.Fatal("invalid query data")
}
if events[0].DNSQuery.DialID == 0 {
t.Fatal("invalid query DialID")
}
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
}
func TestEmitterResolverFailure(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
ctx = transactionid.WithTransactionID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
resolver.DNSOverHTTPS{
Do: func(req *http.Request) (*http.Response, error) {
return nil, io.EOF
},
URL: "https://dns.google.com/",
},
)}
replies, err := r.LookupHost(ctx, "www.google.com")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if replies != nil {
t.Fatal("expected nil replies")
}
events := handler.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
if events[0].ResolveStart == nil {
t.Fatal("missing ResolveStart field")
}
if events[0].ResolveStart.DialID == 0 {
t.Fatal("invalid DialID")
}
if events[0].ResolveStart.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
if events[0].ResolveStart.Hostname != "www.google.com" {
t.Fatal("invalid Hostname")
}
if events[0].ResolveStart.TransactionID == 0 {
t.Fatal("invalid TransactionID")
}
if events[0].ResolveStart.TransportAddress != "https://dns.google.com/" {
t.Fatal("invalid TransportAddress")
}
if events[0].ResolveStart.TransportNetwork != "doh" {
t.Fatal("invalid TransportNetwork")
}
if events[1].ResolveDone == nil {
t.Fatal("missing ResolveDone field")
}
if events[1].ResolveDone.DialID == 0 {
t.Fatal("invalid DialID")
}
if events[1].ResolveDone.DurationSinceBeginning <= 0 {
t.Fatal("invalid duration since beginning")
}
if events[1].ResolveDone.Error != io.EOF {
t.Fatal("invalid Error")
}
if events[1].ResolveDone.Hostname != "www.google.com" {
t.Fatal("invalid Hostname")
}
if events[1].ResolveDone.TransactionID == 0 {
t.Fatal("invalid TransactionID")
}
if events[1].ResolveDone.TransportAddress != "https://dns.google.com/" {
t.Fatal("invalid TransportAddress")
}
if events[1].ResolveDone.TransportNetwork != "doh" {
t.Fatal("invalid TransportNetwork")
}
}
func TestEmitterResolverSuccess(t *testing.T) {
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
ctx = transactionid.WithTransactionID(ctx)
handler := &handlers.SavingHandler{}
root := &modelx.MeasurementRoot{
Beginning: time.Now(),
Handler: handler,
}
ctx = modelx.WithMeasurementRoot(ctx, root)
r := resolver.EmitterResolver{Resolver: resolver.NewFakeResolverWithResult(
[]string{"8.8.8.8"},
)}
replies, err := r.LookupHost(ctx, "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(replies) != 1 {
t.Fatal("expected a single replies")
}
events := handler.Read()
if len(events) != 2 {
t.Fatal("unexpected number of events")
}
if events[1].ResolveDone == nil {
t.Fatal("missing ResolveDone field")
}
if events[1].ResolveDone.Addresses[0] != "8.8.8.8" {
t.Fatal("invalid Addresses")
}
}
+52
View File
@@ -0,0 +1,52 @@
package resolver
import "github.com/miekg/dns"
// The Encoder encodes DNS queries to bytes
type Encoder interface {
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
}
// MiekgEncoder uses github.com/miekg/dns to implement the Encoder.
type MiekgEncoder struct{}
const (
// PaddingDesiredBlockSize is the size that the padded query should be multiple of
PaddingDesiredBlockSize = 128
// EDNS0MaxResponseSize is the maximum response size for EDNS0
EDNS0MaxResponseSize = 4096
// DNSSECEnabled turns on support for DNSSEC when using EDNS0
DNSSECEnabled = true
)
// Encode implements Encoder.Encode
func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
question := dns.Question{
Name: dns.Fqdn(domain),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
if padding {
query.SetEdns0(EDNS0MaxResponseSize, DNSSECEnabled)
// Clients SHOULD pad queries to the closest multiple of
// 128 octets RFC8467#section-4.1. We inflate the query
// length by the size of the option (i.e. 4 octets). The
// cast to uint is necessary to make the modulus operation
// work as intended when the desiredBlockSize is smaller
// than (query.Len()+4) ¯\_(ツ)_/¯.
remainder := (PaddingDesiredBlockSize - uint(query.Len()+4)) % PaddingDesiredBlockSize
opt := new(dns.EDNS0_PADDING)
opt.Padding = make([]byte, remainder)
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
}
return query.Pack()
}
var _ Encoder = MiekgEncoder{}
@@ -0,0 +1,99 @@
package resolver_test
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestEncoderEncodeA(t *testing.T) {
e := resolver.MiekgEncoder{}
data, err := e.Encode("x.org", dns.TypeA, false)
if err != nil {
t.Fatal(err)
}
validate(t, data, byte(dns.TypeA))
}
func TestEncoderEncodeAAAA(t *testing.T) {
e := resolver.MiekgEncoder{}
data, err := e.Encode("x.org", dns.TypeAAAA, false)
if err != nil {
t.Fatal(err)
}
validate(t, data, byte(dns.TypeA))
}
func validate(t *testing.T, data []byte, qtype byte) {
// skipping over the query ID
if data[2] != 1 {
t.Fatal("FLAGS should only have RD set")
}
if data[3] != 0 {
t.Fatal("RA|Z|Rcode should be zero")
}
if data[4] != 0 || data[5] != 1 {
t.Fatal("QCOUNT high should be one")
}
if data[6] != 0 || data[7] != 0 {
t.Fatal("ANCOUNT should be zero")
}
if data[8] != 0 || data[9] != 0 {
t.Fatal("NSCOUNT should be zero")
}
if data[10] != 0 || data[11] != 0 {
t.Fatal("ARCOUNT should be zero")
}
t.Log(data[12])
if data[12] != 1 || data[13] != byte('x') {
t.Fatal("The name does not contain 1:x")
}
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
t.Fatal("The name does not containg 3:org")
}
if data[18] != 0 {
t.Fatal("The name does not terminate where expected")
}
if data[19] != 0 && data[20] != qtype {
t.Fatal("The query is not for the expected type")
}
if data[21] != 0 && data[22] != 1 {
t.Fatal("The query is not IN")
}
}
func TestEncoderPadding(t *testing.T) {
// The purpose of this unit test is to make sure that for a wide
// array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int {
e := resolver.MiekgEncoder{}
data, err := e.Encode(
// This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing
// us to generate such name and we are not going to send
// it on the wire. Also, we check below that the query that
// we generate is long enough, so we should be good.
dns.Fqdn(strings.Repeat("x.", domainlen)),
dns.TypeA, padding,
)
if err != nil {
t.Fatal(err)
}
return len(data)
}
for domainlen := 1; domainlen <= 4000; domainlen++ {
vanillalen := getquerylen(domainlen, false)
paddedlen := getquerylen(domainlen, true)
if vanillalen < domainlen {
t.Fatal("vanillalen is smaller than domainlen")
}
if (paddedlen % resolver.PaddingDesiredBlockSize) != 0 {
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
}
if paddedlen < vanillalen {
t.Fatal("paddedlen is smaller than vanillalen")
}
}
}
@@ -0,0 +1,30 @@
package resolver
import (
"context"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
)
// ErrorWrapperResolver is a Resolver that knows about wrapping errors.
type ErrorWrapperResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r ErrorWrapperResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
dialID := dialid.ContextDialID(ctx)
txID := transactionid.ContextTransactionID(ctx)
addrs, err := r.Resolver.LookupHost(ctx, hostname)
err = errorx.SafeErrWrapperBuilder{
DialID: dialID,
Error: err,
Operation: errorx.ResolveOperation,
TransactionID: txID,
}.MaybeBuild()
return addrs, err
}
var _ Resolver = ErrorWrapperResolver{}
@@ -0,0 +1,58 @@
package resolver_test
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestErrorWrapperSuccess(t *testing.T) {
orig := []string{"8.8.8.8"}
r := resolver.ErrorWrapperResolver{
Resolver: resolver.NewFakeResolverWithResult(orig),
}
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if len(addrs) != len(orig) || addrs[0] != orig[0] {
t.Fatal("not the result we expected")
}
}
func TestErrorWrapperFailure(t *testing.T) {
r := resolver.ErrorWrapperResolver{
Resolver: resolver.NewFakeResolverThatFails(),
}
ctx := context.Background()
ctx = dialid.WithDialID(ctx)
ctx = transactionid.WithTransactionID(ctx)
addrs, err := r.LookupHost(ctx, "dns.google.com")
if addrs != nil {
t.Fatal("expected nil addr here")
}
var errWrapper *errorx.ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("cannot properly cast the returned error")
}
if errWrapper.Failure != errorx.FailureDNSNXDOMAINError {
t.Fatal("unexpected failure")
}
if errWrapper.ConnID != 0 {
t.Fatal("unexpected ConnID")
}
if errWrapper.DialID == 0 {
t.Fatal("unexpected DialID")
}
if errWrapper.TransactionID == 0 {
t.Fatal("unexpected TransactionID")
}
if errWrapper.Operation != errorx.ResolveOperation {
t.Fatal("unexpected Operation")
}
}
+142
View File
@@ -0,0 +1,142 @@
package resolver
import (
"context"
"io"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
)
type FakeDialer struct {
Conn net.Conn
Err error
}
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
time.Sleep(10 * time.Microsecond)
return d.Conn, d.Err
}
type FakeConn struct {
ReadError error
ReadData []byte
SetDeadlineError error
SetReadDeadlineError error
SetWriteDeadlineError error
WriteError error
}
func (c *FakeConn) Read(b []byte) (int, error) {
if len(c.ReadData) > 0 {
n := copy(b, c.ReadData)
c.ReadData = c.ReadData[n:]
return n, nil
}
if c.ReadError != nil {
return 0, c.ReadError
}
return 0, io.EOF
}
func (c *FakeConn) Write(b []byte) (n int, err error) {
if c.WriteError != nil {
return 0, c.WriteError
}
n = len(b)
return
}
func (*FakeConn) Close() (err error) {
return
}
func (*FakeConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}
func (*FakeConn) RemoteAddr() net.Addr {
return &net.TCPAddr{}
}
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
return c.SetDeadlineError
}
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
return c.SetReadDeadlineError
}
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
return c.SetWriteDeadlineError
}
type FakeTransport struct {
Data []byte
Err error
}
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
return ft.Data, ft.Err
}
func (ft FakeTransport) RequiresPadding() bool {
return false
}
func (ft FakeTransport) Address() string {
return ""
}
func (ft FakeTransport) Network() string {
return "fake"
}
type FakeEncoder struct {
Data []byte
Err error
}
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
return fe.Data, fe.Err
}
type FakeResolver struct {
NumFailures *atomicx.Int64
Err error
Result []string
}
func NewFakeResolverThatFails() FakeResolver {
return FakeResolver{NumFailures: atomicx.NewInt64(), Err: errNotFound}
}
func NewFakeResolverWithResult(r []string) FakeResolver {
return FakeResolver{NumFailures: atomicx.NewInt64(), Result: r}
}
var errNotFound = &net.DNSError{
Err: "no such host",
}
func (c FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
time.Sleep(10 * time.Microsecond)
if c.Err != nil {
if c.NumFailures != nil {
c.NumFailures.Add(1)
}
return nil, c.Err
}
return c.Result, nil
}
func (c FakeResolver) Network() string {
return "fake"
}
func (c FakeResolver) Address() string {
return ""
}
var _ Resolver = FakeResolver{}
@@ -0,0 +1,76 @@
package resolver
import (
"net"
"testing"
"github.com/miekg/dns"
)
func GenReplyError(t *testing.T, code int) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetRcode(query, code)
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
func GenReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
for _, ip := range ips {
switch qtype {
case dns.TypeA:
reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
A: net.ParseIP(ip),
})
case dns.TypeAAAA:
reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: net.ParseIP(ip),
})
}
}
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
+34
View File
@@ -0,0 +1,34 @@
package resolver
import (
"context"
"golang.org/x/net/idna"
)
// IDNAResolver is to support resolving Internationalized Domain Names.
// See RFC3492 for more information.
type IDNAResolver struct {
Resolver
}
// LookupHost implements Resolver.LookupHost
func (r IDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
host, err := idna.ToASCII(hostname)
if err != nil {
return nil, err
}
return r.Resolver.LookupHost(ctx, host)
}
// Network implements Resolver.Network.
func (r IDNAResolver) Network() string {
return "idna"
}
// Address implements Resolver.Address.
func (r IDNAResolver) Address() string {
return ""
}
var _ Resolver = IDNAResolver{}
@@ -0,0 +1,76 @@
package resolver_test
import (
"context"
"errors"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
var ErrUnexpectedPunycode = errors.New("unexpected punycode value")
type CheckIDNAResolver struct {
Addresses []string
Error error
Expect string
}
func (resolv CheckIDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
if resolv.Error != nil {
return nil, resolv.Error
}
if hostname != resolv.Expect {
return nil, ErrUnexpectedPunycode
}
return resolv.Addresses, nil
}
func (r CheckIDNAResolver) Network() string {
return "checkidna"
}
func (r CheckIDNAResolver) Address() string {
return ""
}
func TestIDNAResolverSuccess(t *testing.T) {
expectedIPs := []string{"77.88.55.66"}
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
Addresses: expectedIPs,
Expect: "xn--d1acpjx3f.xn--p1ai",
}}
addrs, err := resolv.LookupHost(context.Background(), "яндекс.рф")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
t.Fatal(diff)
}
}
func TestIDNAResolverFailure(t *testing.T) {
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
Error: errors.New("we should not arrive here"),
}}
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
addrs, err := resolv.LookupHost(context.Background(), "xn--0000h")
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected no response here")
}
}
func TestIDNAResolverTransportOK(t *testing.T) {
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{}}
if resolv.Network() != "idna" {
t.Fatal("invalid network")
}
if resolv.Address() != "" {
t.Fatal("invalid address")
}
}
@@ -0,0 +1,111 @@
package resolver_test
import (
"context"
"net"
"net/http"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func init() {
log.SetLevel(log.DebugLevel)
}
func testresolverquick(t *testing.T, reso resolver.Resolver) {
if testing.Short() {
t.Skip("skip test in short mode")
}
reso = resolver.LoggingResolver{Logger: log.Log, Resolver: reso}
addrs, err := reso.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil addrs here")
}
var foundquad8 bool
for _, addr := range addrs {
// See https://github.com/ooni/probe-cli/v3/internal/engine/pull/954/checks?check_run_id=1182269025
if addr == "8.8.8.8" || addr == "2001:4860:4860::8888" {
foundquad8 = true
}
}
if !foundquad8 {
t.Fatalf("did not find 8.8.8.8 in ouput; output=%+v", addrs)
}
}
// Ensuring we can handle Internationalized Domain Names (IDNs) without issues
func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
if testing.Short() {
t.Skip("skip test in short mode")
}
reso = resolver.IDNAResolver{
resolver.LoggingResolver{Logger: log.Log, Resolver: reso},
}
addrs, err := reso.LookupHost(context.Background(), "яндекс.рф")
if err != nil {
t.Fatal(err)
}
if addrs == nil {
t.Fatal("expected non-nil addrs here")
}
}
func TestNewResolverSystem(t *testing.T) {
reso := resolver.SystemResolver{}
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverUDPAddress(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverUDP(new(net.Dialer), "8.8.8.8:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverUDPDomain(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverUDP(new(net.Dialer), "dns.google.com:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverTCPAddress(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "8.8.8.8:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverTCPDomain(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "dns.google.com:53"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTAddress(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTDomain(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverTLS(resolver.DialTLSContext, "dns.google.com:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoH(t *testing.T) {
reso := resolver.NewSerialResolver(
resolver.NewDNSOverHTTPS(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
+30
View File
@@ -0,0 +1,30 @@
package resolver
import (
"context"
"time"
)
// Logger is the logger assumed by this package
type Logger interface {
Debugf(format string, v ...interface{})
Debug(message string)
}
// LoggingResolver is a resolver that emits events
type LoggingResolver struct {
Resolver
Logger Logger
}
// LookupHost returns the IP addresses of a host
func (r LoggingResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
r.Logger.Debugf("resolve %s...", hostname)
start := time.Now()
addrs, err := r.Resolver.LookupHost(ctx, hostname)
stop := time.Now()
r.Logger.Debugf("resolve %s... (%+v, %+v) in %s", hostname, addrs, err, stop.Sub(start))
return addrs, err
}
var _ Resolver = LoggingResolver{}
@@ -0,0 +1,23 @@
package resolver_test
import (
"context"
"testing"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestLoggingResolver(t *testing.T) {
r := resolver.LoggingResolver{
Logger: log.Log,
Resolver: resolver.NewFakeResolverThatFails(),
}
addrs, err := r.LookupHost(context.Background(), "www.google.com")
if err == nil {
t.Fatal("expected an error here")
}
if addrs != nil {
t.Fatal("expected nil addr here")
}
}
+18
View File
@@ -0,0 +1,18 @@
package resolver
import (
"context"
)
// Resolver is a DNS resolver. The *net.Resolver used by Go implements
// this interface, but other implementations are possible.
type Resolver interface {
// LookupHost resolves a hostname to a list of IP addresses.
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
// Network returns the network being used by the resolver
Network() string
// Address returns the address being used by the resolver
Address() string
}
+73
View File
@@ -0,0 +1,73 @@
package resolver
import (
"context"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
// SaverResolver is a resolver that saves events
type SaverResolver struct {
Resolver
Saver *trace.Saver
}
// LookupHost implements Resolver.LookupHost
func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
start := time.Now()
r.Saver.Write(trace.Event{
Address: r.Resolver.Address(),
Hostname: hostname,
Name: "resolve_start",
Proto: r.Resolver.Network(),
Time: start,
})
addrs, err := r.Resolver.LookupHost(ctx, hostname)
stop := time.Now()
r.Saver.Write(trace.Event{
Addresses: addrs,
Address: r.Resolver.Address(),
Duration: stop.Sub(start),
Err: err,
Hostname: hostname,
Name: "resolve_done",
Proto: r.Resolver.Network(),
Time: stop,
})
return addrs, err
}
// SaverDNSTransport is a DNS transport that saves events
type SaverDNSTransport struct {
RoundTripper
Saver *trace.Saver
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
start := time.Now()
txp.Saver.Write(trace.Event{
Address: txp.Address(),
DNSQuery: query,
Name: "dns_round_trip_start",
Proto: txp.Network(),
Time: start,
})
reply, err := txp.RoundTripper.RoundTrip(ctx, query)
stop := time.Now()
txp.Saver.Write(trace.Event{
Address: txp.Address(),
DNSQuery: query,
DNSReply: reply,
Duration: stop.Sub(start),
Err: err,
Name: "dns_round_trip_done",
Proto: txp.Network(),
Time: stop,
})
return reply, err
}
var _ Resolver = SaverResolver{}
var _ RoundTripper = SaverDNSTransport{}
+211
View File
@@ -0,0 +1,211 @@
package resolver_test
import (
"bytes"
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
)
func TestSaverResolverFailure(t *testing.T) {
expected := errors.New("no such host")
saver := &trace.Saver{}
reso := resolver.SaverResolver{
Resolver: resolver.FakeResolver{
Err: expected,
},
Saver: saver,
}
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if ev[1].Addresses != nil {
t.Fatal("unexpected Addresses")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverResolverSuccess(t *testing.T) {
expected := []string{"8.8.8.8", "8.8.4.4"}
saver := &trace.Saver{}
reso := resolver.SaverResolver{
Resolver: resolver.FakeResolver{
Result: expected,
},
Saver: saver,
}
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
if err != nil {
t.Fatal("expected nil error here")
}
if !reflect.DeepEqual(addrs, expected) {
t.Fatal("not the result we expected")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if ev[0].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[0].Name != "resolve_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !reflect.DeepEqual(ev[1].Addresses, expected) {
t.Fatal("unexpected Addresses")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Hostname != "www.google.com" {
t.Fatal("unexpected Hostname")
}
if ev[1].Name != "resolve_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverDNSTransportFailure(t *testing.T) {
expected := errors.New("no such host")
saver := &trace.Saver{}
txp := resolver.SaverDNSTransport{
RoundTripper: resolver.FakeTransport{
Err: expected,
},
Saver: saver,
}
query := []byte("abc")
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].DNSReply != nil {
t.Fatal("unexpected DNSReply")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if !errors.Is(ev[1].Err, expected) {
t.Fatal("unexpected Err")
}
if ev[1].Name != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
func TestSaverDNSTransportSuccess(t *testing.T) {
expected := []byte("def")
saver := &trace.Saver{}
txp := resolver.SaverDNSTransport{
RoundTripper: resolver.FakeTransport{
Data: expected,
},
Saver: saver,
}
query := []byte("abc")
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply, expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name != "dns_round_trip_start" {
t.Fatal("unexpected name")
}
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].DNSQuery, query) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].DNSReply, expected) {
t.Fatal("unexpected DNSReply")
}
if ev[1].Duration <= 0 {
t.Fatal("unexpected Duration")
}
if ev[1].Err != nil {
t.Fatal("unexpected Err")
}
if ev[1].Name != "dns_round_trip_done" {
t.Fatal("unexpected name")
}
if !ev[1].Time.After(ev[0].Time) {
t.Fatal("the saved time is wrong")
}
}
+113
View File
@@ -0,0 +1,113 @@
package resolver
import (
"context"
"errors"
"net"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
)
// RoundTripper represents an abstract DNS transport.
type RoundTripper interface {
// RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
// RequiresPadding return true for DoH and DoT according to RFC8467
RequiresPadding() bool
// Network is the network of the round tripper (e.g. "dot")
Network() string
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
Address() string
}
// SerialResolver is a resolver that first issues an A query and then
// issues an AAAA query for the requested domain.
type SerialResolver struct {
Encoder Encoder
Decoder Decoder
NumTimeouts *atomicx.Int64
Txp RoundTripper
}
// NewSerialResolver creates a new OONI Resolver instance.
func NewSerialResolver(t RoundTripper) SerialResolver {
return SerialResolver{
Encoder: MiekgEncoder{},
Decoder: MiekgDecoder{},
NumTimeouts: atomicx.NewInt64(),
Txp: t,
}
}
// Transport returns the transport being used.
func (r SerialResolver) Transport() RoundTripper {
return r.Txp
}
// Network implements Resolver.Network
func (r SerialResolver) Network() string {
return r.Txp.Network()
}
// Address implements Resolver.Address
func (r SerialResolver) Address() string {
return r.Txp.Address()
}
// LookupHost implements Resolver.LookupHost.
func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
var addrs []string
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
if errA != nil && errAAAA != nil {
return nil, errA
}
addrs = append(addrs, addrsA...)
addrs = append(addrs, addrsAAAA...)
return addrs, nil
}
func (r SerialResolver) roundTripWithRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
var errorslist []error
for i := 0; i < 3; i++ {
replies, err := r.roundTrip(ctx, hostname, qtype)
if err == nil {
return replies, nil
}
errorslist = append(errorslist, err)
var operr *net.OpError
if errors.As(err, &operr) == false || operr.Timeout() == false {
// The first error is the one that is most likely to be caused
// by the network. Subsequent errors are more likely to be caused
// by context deadlines. So, the first error is attached to an
// operation, while subsequent errors may possibly not be. If
// so, the resulting failing operation is not correct.
break
}
r.NumTimeouts.Add(1)
}
// bugfix: we MUST return one of the errors otherwise we confuse the
// mechanism in errwrap that classifies the root cause operation, since
// it would not be able to find a child with a major operation error
return nil, errorslist[0]
}
func (r SerialResolver) roundTrip(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.Decode(qtype, replydata)
}
var _ Resolver = SerialResolver{}

Some files were not shown because too many files have changed in this diff Show More