chore: merge probe-engine into probe-cli (#201)
This is how I did it: 1. `git clone https://github.com/ooni/probe-engine internal/engine` 2. ``` (cd internal/engine && git describe --tags) v0.23.0 ``` 3. `nvim go.mod` (merging `go.mod` with `internal/engine/go.mod` 4. `rm -rf internal/.git internal/engine/go.{mod,sum}` 5. `git add internal/engine` 6. `find . -type f -name \*.go -exec sed -i 's@/ooni/probe-engine@/ooni/probe-cli/v3/internal/engine@g' {} \;` 7. `go build ./...` (passes) 8. `go test -race ./...` (temporary failure on RiseupVPN) 9. `go mod tidy` 10. this commit message Once this piece of work is done, we can build a new version of `ooniprobe` that is using `internal/engine` directly. We need to do more work to ensure all the other functionality in `probe-engine` (e.g. making mobile packages) are still WAI. Part of https://github.com/ooni/probe/issues/1335
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
# Package github.com/ooni/probe-engine/netx
|
||||
|
||||
OONI extensions to the `net` and `net/http` packages. This code is
|
||||
used by `ooni/probe-engine` as a low level library to collect
|
||||
network, DNS, and HTTP events occurring during OONI measurements.
|
||||
|
||||
This library contains replacements for commonly used standard library
|
||||
interfaces that facilitate seamless network measurements. By using
|
||||
such replacements, as opposed to standard library interfaces, we can:
|
||||
|
||||
* save the timing of HTTP events (e.g. received response headers)
|
||||
* save the timing and result of every Connect, Read, Write, Close operation
|
||||
* save the timing and result of the TLS handshake (including certificates)
|
||||
|
||||
By default, this library uses the system resolver. In addition, it
|
||||
is possible to configure alternative DNS transports and remote
|
||||
servers. We support DNS over UDP, DNS over TCP, DNS over TLS (DoT),
|
||||
and DNS over HTTPS (DoH). When using an alternative transport, we
|
||||
are also able to intercept and save DNS messages, as well as any
|
||||
other interaction with the remote server (e.g., the result of the
|
||||
TLS handshake for DoT and DoH).
|
||||
|
||||
This package is a fork of [github.com/ooni/netx](https://github.com/ooni/netx).
|
||||
@@ -0,0 +1,580 @@
|
||||
// Package archival contains data formats used for archival.
|
||||
//
|
||||
// See https://github.com/ooni/spec.
|
||||
package archival
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// ExtSpec describes a data format extension
|
||||
type ExtSpec struct {
|
||||
Name string // extension name
|
||||
V int64 // extension version
|
||||
}
|
||||
|
||||
// AddTo adds the current ExtSpec to the specified measurement
|
||||
func (spec ExtSpec) AddTo(m *model.Measurement) {
|
||||
if m.Extensions == nil {
|
||||
m.Extensions = make(map[string]int64)
|
||||
}
|
||||
m.Extensions[spec.Name] = spec.V
|
||||
}
|
||||
|
||||
var (
|
||||
// ExtDNS is the version of df-002-dnst.md
|
||||
ExtDNS = ExtSpec{Name: "dnst", V: 0}
|
||||
|
||||
// ExtNetevents is the version of df-008-netevents.md
|
||||
ExtNetevents = ExtSpec{Name: "netevents", V: 0}
|
||||
|
||||
// ExtHTTP is the version of df-001-httpt.md
|
||||
ExtHTTP = ExtSpec{Name: "httpt", V: 0}
|
||||
|
||||
// ExtTCPConnect is the version of df-005-tcpconnect.md
|
||||
ExtTCPConnect = ExtSpec{Name: "tcpconnect", V: 0}
|
||||
|
||||
// ExtTLSHandshake is the version of df-006-tlshandshake.md
|
||||
ExtTLSHandshake = ExtSpec{Name: "tlshandshake", V: 0}
|
||||
|
||||
// ExtTunnel is the version of df-009-tunnel.md
|
||||
ExtTunnel = ExtSpec{Name: "tunnel", V: 0}
|
||||
)
|
||||
|
||||
// TCPConnectStatus contains the TCP connect status.
|
||||
//
|
||||
// The Blocked field breaks the separation between measurement and analysis
|
||||
// we have been enforcing for quite some time now. It is a legacy from the
|
||||
// Web Connectivity experiment and it should be here because of that.
|
||||
type TCPConnectStatus struct {
|
||||
Blocked *bool `json:"blocked,omitempty"` // Web Connectivity only
|
||||
Failure *string `json:"failure"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// TCPConnectEntry contains one of the entries that are part
|
||||
// of the "tcp_connect" key of a OONI report.
|
||||
type TCPConnectEntry struct {
|
||||
ConnID int64 `json:"conn_id,omitempty"`
|
||||
DialID int64 `json:"dial_id,omitempty"`
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
Status TCPConnectStatus `json:"status"`
|
||||
T float64 `json:"t"`
|
||||
TransactionID int64 `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewTCPConnectList creates a new TCPConnectList
|
||||
func NewTCPConnectList(begin time.Time, events []trace.Event) []TCPConnectEntry {
|
||||
var out []TCPConnectEntry
|
||||
for _, event := range events {
|
||||
if event.Name != errorx.ConnectOperation {
|
||||
continue
|
||||
}
|
||||
if event.Proto != "tcp" {
|
||||
continue
|
||||
}
|
||||
// We assume Go is passing us legit data structures
|
||||
ip, sport, _ := net.SplitHostPort(event.Address)
|
||||
iport, _ := strconv.Atoi(sport)
|
||||
out = append(out, TCPConnectEntry{
|
||||
IP: ip,
|
||||
Port: iport,
|
||||
Status: TCPConnectStatus{
|
||||
Failure: NewFailure(event.Err),
|
||||
Success: event.Err == nil,
|
||||
},
|
||||
T: event.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// NewFailure creates a failure nullable string from the given error
|
||||
func NewFailure(err error) *string {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// The following code guarantees that the error is always wrapped even
|
||||
// when we could not actually hit our code that does the wrapping. A case
|
||||
// in which this happen is with context deadline for HTTP.
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
Error: err,
|
||||
Operation: errorx.TopLevelOperation,
|
||||
}.MaybeBuild()
|
||||
errWrapper := err.(*errorx.ErrWrapper)
|
||||
s := errWrapper.Failure
|
||||
if s == "" {
|
||||
s = "unknown_failure: errWrapper.Failure is empty"
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// NewFailedOperation creates a failed operation string from the given error.
|
||||
func NewFailedOperation(err error) *string {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
errWrapper *errorx.ErrWrapper
|
||||
s = errorx.UnknownOperation
|
||||
)
|
||||
if errors.As(err, &errWrapper) && errWrapper.Operation != "" {
|
||||
s = errWrapper.Operation
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// HTTPTor contains Tor information
|
||||
type HTTPTor struct {
|
||||
ExitIP *string `json:"exit_ip"`
|
||||
ExitName *string `json:"exit_name"`
|
||||
IsTor bool `json:"is_tor"`
|
||||
}
|
||||
|
||||
// MaybeBinaryValue is a possibly binary string. We use this helper class
|
||||
// to define a custom JSON encoder that allows us to choose the proper
|
||||
// representation depending on whether the Value field is valid UTF-8 or not.
|
||||
type MaybeBinaryValue struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
// MarshalJSON marshals a string-like to JSON following the OONI spec that
|
||||
// says that UTF-8 content is represened as string and non-UTF-8 content is
|
||||
// instead represented using `{"format":"base64","data":"..."}`.
|
||||
func (hb MaybeBinaryValue) MarshalJSON() ([]byte, error) {
|
||||
if utf8.ValidString(hb.Value) {
|
||||
return json.Marshal(hb.Value)
|
||||
}
|
||||
er := make(map[string]string)
|
||||
er["format"] = "base64"
|
||||
er["data"] = base64.StdEncoding.EncodeToString([]byte(hb.Value))
|
||||
return json.Marshal(er)
|
||||
}
|
||||
|
||||
// UnmarshalJSON is the opposite of MarshalJSON.
|
||||
func (hb *MaybeBinaryValue) UnmarshalJSON(d []byte) error {
|
||||
if err := json.Unmarshal(d, &hb.Value); err == nil {
|
||||
return nil
|
||||
}
|
||||
er := make(map[string]string)
|
||||
if err := json.Unmarshal(d, &er); err != nil {
|
||||
return err
|
||||
}
|
||||
if v, ok := er["format"]; !ok || v != "base64" {
|
||||
return errors.New("missing or invalid format field")
|
||||
}
|
||||
if _, ok := er["data"]; !ok {
|
||||
return errors.New("missing data field")
|
||||
}
|
||||
b64, err := base64.StdEncoding.DecodeString(er["data"])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hb.Value = string(b64)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HTTPBody is an HTTP body. As an implementation note, this type must be
|
||||
// an alias for the MaybeBinaryValue type, otherwise the specific serialisation
|
||||
// mechanism implemented by MaybeBinaryValue is not working.
|
||||
type HTTPBody = MaybeBinaryValue
|
||||
|
||||
// HTTPHeader is a single HTTP header.
|
||||
type HTTPHeader struct {
|
||||
Key string
|
||||
Value MaybeBinaryValue
|
||||
}
|
||||
|
||||
// MarshalJSON marshals a single HTTP header to a tuple where the first
|
||||
// element is a string and the second element is maybe-binary data.
|
||||
func (hh HTTPHeader) MarshalJSON() ([]byte, error) {
|
||||
if utf8.ValidString(hh.Value.Value) {
|
||||
return json.Marshal([]string{hh.Key, hh.Value.Value})
|
||||
}
|
||||
value := make(map[string]string)
|
||||
value["format"] = "base64"
|
||||
value["data"] = base64.StdEncoding.EncodeToString([]byte(hh.Value.Value))
|
||||
return json.Marshal([]interface{}{hh.Key, value})
|
||||
}
|
||||
|
||||
// UnmarshalJSON is the opposite of MarshalJSON.
|
||||
func (hh *HTTPHeader) UnmarshalJSON(d []byte) error {
|
||||
var pair []interface{}
|
||||
if err := json.Unmarshal(d, &pair); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(pair) != 2 {
|
||||
return errors.New("unexpected pair length")
|
||||
}
|
||||
key, ok := pair[0].(string)
|
||||
if !ok {
|
||||
return errors.New("the key is not a string")
|
||||
}
|
||||
value, ok := pair[1].(string)
|
||||
if !ok {
|
||||
mapvalue, ok := pair[1].(map[string]interface{})
|
||||
if !ok {
|
||||
return errors.New("the value is neither a string nor a map[string]interface{}")
|
||||
}
|
||||
if _, ok := mapvalue["format"]; !ok {
|
||||
return errors.New("missing format")
|
||||
}
|
||||
if v, ok := mapvalue["format"].(string); !ok || v != "base64" {
|
||||
return errors.New("invalid format")
|
||||
}
|
||||
if _, ok := mapvalue["data"]; !ok {
|
||||
return errors.New("missing data field")
|
||||
}
|
||||
v, ok := mapvalue["data"].(string)
|
||||
if !ok {
|
||||
return errors.New("the data field is not a string")
|
||||
}
|
||||
b64, err := base64.StdEncoding.DecodeString(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value = string(b64)
|
||||
}
|
||||
hh.Key, hh.Value = key, MaybeBinaryValue{Value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HTTPRequest contains an HTTP request.
|
||||
//
|
||||
// Headers are a map in Web Connectivity data format but
|
||||
// we have added support for a list since January 2020.
|
||||
type HTTPRequest struct {
|
||||
Body HTTPBody `json:"body"`
|
||||
BodyIsTruncated bool `json:"body_is_truncated"`
|
||||
HeadersList []HTTPHeader `json:"headers_list"`
|
||||
Headers map[string]MaybeBinaryValue `json:"headers"`
|
||||
Method string `json:"method"`
|
||||
Tor HTTPTor `json:"tor"`
|
||||
Transport string `json:"x_transport"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// HTTPResponse contains an HTTP response.
|
||||
//
|
||||
// Headers are a map in Web Connectivity data format but
|
||||
// we have added support for a list since January 2020.
|
||||
type HTTPResponse struct {
|
||||
Body HTTPBody `json:"body"`
|
||||
BodyIsTruncated bool `json:"body_is_truncated"`
|
||||
Code int64 `json:"code"`
|
||||
HeadersList []HTTPHeader `json:"headers_list"`
|
||||
Headers map[string]MaybeBinaryValue `json:"headers"`
|
||||
|
||||
// The following fields are not serialised but are useful to simplify
|
||||
// analysing the measurements in telegram, whatsapp, etc.
|
||||
Locations []string `json:"-"`
|
||||
}
|
||||
|
||||
// RequestEntry is one of the entries that are part of
|
||||
// the "requests" key of a OONI report.
|
||||
type RequestEntry struct {
|
||||
Failure *string `json:"failure"`
|
||||
Request HTTPRequest `json:"request"`
|
||||
Response HTTPResponse `json:"response"`
|
||||
T float64 `json:"t"`
|
||||
TransactionID int64 `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
func addheaders(
|
||||
source http.Header,
|
||||
destList *[]HTTPHeader,
|
||||
destMap *map[string]MaybeBinaryValue,
|
||||
) {
|
||||
for key, values := range source {
|
||||
for index, value := range values {
|
||||
value := MaybeBinaryValue{Value: value}
|
||||
// With the map representation we can only represent a single
|
||||
// value for every key. Hence the list representation.
|
||||
if index == 0 {
|
||||
(*destMap)[key] = value
|
||||
}
|
||||
*destList = append(*destList, HTTPHeader{
|
||||
Key: key,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Slice(*destList, func(i, j int) bool {
|
||||
return (*destList)[i].Key < (*destList)[j].Key
|
||||
})
|
||||
}
|
||||
|
||||
// NewRequestList returns the list for "requests"
|
||||
func NewRequestList(begin time.Time, events []trace.Event) []RequestEntry {
|
||||
// OONI wants the last request to appear first
|
||||
var out []RequestEntry
|
||||
tmp := newRequestList(begin, events)
|
||||
for i := len(tmp) - 1; i >= 0; i-- {
|
||||
out = append(out, tmp[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func newRequestList(begin time.Time, events []trace.Event) []RequestEntry {
|
||||
var (
|
||||
out []RequestEntry
|
||||
entry RequestEntry
|
||||
)
|
||||
for _, ev := range events {
|
||||
switch ev.Name {
|
||||
case "http_transaction_start":
|
||||
entry = RequestEntry{}
|
||||
entry.T = ev.Time.Sub(begin).Seconds()
|
||||
case "http_request_body_snapshot":
|
||||
entry.Request.Body.Value = string(ev.Data)
|
||||
entry.Request.BodyIsTruncated = ev.DataIsTruncated
|
||||
case "http_request_metadata":
|
||||
entry.Request.Headers = make(map[string]MaybeBinaryValue)
|
||||
addheaders(
|
||||
ev.HTTPHeaders, &entry.Request.HeadersList, &entry.Request.Headers)
|
||||
entry.Request.Method = ev.HTTPMethod
|
||||
entry.Request.URL = ev.HTTPURL
|
||||
entry.Request.Transport = ev.Transport
|
||||
case "http_response_metadata":
|
||||
entry.Response.Headers = make(map[string]MaybeBinaryValue)
|
||||
addheaders(
|
||||
ev.HTTPHeaders, &entry.Response.HeadersList, &entry.Response.Headers)
|
||||
entry.Response.Code = int64(ev.HTTPStatusCode)
|
||||
entry.Response.Locations = ev.HTTPHeaders.Values("Location")
|
||||
case "http_response_body_snapshot":
|
||||
entry.Response.Body.Value = string(ev.Data)
|
||||
entry.Response.BodyIsTruncated = ev.DataIsTruncated
|
||||
case "http_transaction_done":
|
||||
entry.Failure = NewFailure(ev.Err)
|
||||
out = append(out, entry)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DNSAnswerEntry is the answer to a DNS query
|
||||
type DNSAnswerEntry struct {
|
||||
ASN int64 `json:"asn,omitempty"`
|
||||
ASOrgName string `json:"as_org_name,omitempty"`
|
||||
AnswerType string `json:"answer_type"`
|
||||
Hostname string `json:"hostname,omitempty"`
|
||||
IPv4 string `json:"ipv4,omitempty"`
|
||||
IPv6 string `json:"ipv6,omitempty"`
|
||||
TTL *uint32 `json:"ttl"`
|
||||
}
|
||||
|
||||
// DNSQueryEntry is a DNS query with possibly an answer
|
||||
type DNSQueryEntry struct {
|
||||
Answers []DNSAnswerEntry `json:"answers"`
|
||||
DialID int64 `json:"dial_id,omitempty"`
|
||||
Engine string `json:"engine"`
|
||||
Failure *string `json:"failure"`
|
||||
Hostname string `json:"hostname"`
|
||||
QueryType string `json:"query_type"`
|
||||
ResolverHostname *string `json:"resolver_hostname"`
|
||||
ResolverPort *string `json:"resolver_port"`
|
||||
ResolverAddress string `json:"resolver_address"`
|
||||
T float64 `json:"t"`
|
||||
TransactionID int64 `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
type dnsQueryType string
|
||||
|
||||
// NewDNSQueriesList returns a list of DNS queries.
|
||||
func NewDNSQueriesList(begin time.Time, events []trace.Event, dbpath string) []DNSQueryEntry {
|
||||
// TODO(bassosimone): add support for CNAME lookups.
|
||||
var out []DNSQueryEntry
|
||||
for _, ev := range events {
|
||||
if ev.Name != "resolve_done" {
|
||||
continue
|
||||
}
|
||||
for _, qtype := range []dnsQueryType{"A", "AAAA"} {
|
||||
entry := qtype.makequeryentry(begin, ev)
|
||||
for _, addr := range ev.Addresses {
|
||||
if qtype.ipoftype(addr) {
|
||||
entry.Answers = append(
|
||||
entry.Answers, qtype.makeanswerentry(addr, dbpath))
|
||||
}
|
||||
}
|
||||
if len(entry.Answers) <= 0 && ev.Err == nil {
|
||||
// This allows us to skip cases where the server does not have
|
||||
// an IPv6 address but has an IPv4 address. Instead, when we
|
||||
// receive an error, we want to track its existence. The main
|
||||
// issue here is that we are cheating, because we are creating
|
||||
// entries representing queries, but we don't know what the
|
||||
// resolver actually did, especially the system resolver. So,
|
||||
// this output is just our best guess.
|
||||
continue
|
||||
}
|
||||
out = append(out, entry)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (qtype dnsQueryType) ipoftype(addr string) bool {
|
||||
switch qtype {
|
||||
case "A":
|
||||
return strings.Contains(addr, ":") == false
|
||||
case "AAAA":
|
||||
return strings.Contains(addr, ":") == true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (qtype dnsQueryType) makeanswerentry(addr string, dbpath string) DNSAnswerEntry {
|
||||
answer := DNSAnswerEntry{AnswerType: string(qtype)}
|
||||
asn, org, _ := geolocate.LookupASN(dbpath, addr)
|
||||
answer.ASN = int64(asn)
|
||||
answer.ASOrgName = org
|
||||
switch qtype {
|
||||
case "A":
|
||||
answer.IPv4 = addr
|
||||
case "AAAA":
|
||||
answer.IPv6 = addr
|
||||
}
|
||||
return answer
|
||||
}
|
||||
|
||||
func (qtype dnsQueryType) makequeryentry(begin time.Time, ev trace.Event) DNSQueryEntry {
|
||||
return DNSQueryEntry{
|
||||
Engine: ev.Proto,
|
||||
Failure: NewFailure(ev.Err),
|
||||
Hostname: ev.Hostname,
|
||||
QueryType: string(qtype),
|
||||
ResolverAddress: ev.Address,
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// NetworkEvent is a network event.
|
||||
type NetworkEvent struct {
|
||||
Address string `json:"address,omitempty"`
|
||||
ConnID int64 `json:"conn_id,omitempty"`
|
||||
DialID int64 `json:"dial_id,omitempty"`
|
||||
Failure *string `json:"failure"`
|
||||
NumBytes int64 `json:"num_bytes,omitempty"`
|
||||
Operation string `json:"operation"`
|
||||
Proto string `json:"proto,omitempty"`
|
||||
T float64 `json:"t"`
|
||||
TransactionID int64 `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewNetworkEventsList returns a list of DNS queries.
|
||||
func NewNetworkEventsList(begin time.Time, events []trace.Event) []NetworkEvent {
|
||||
var out []NetworkEvent
|
||||
for _, ev := range events {
|
||||
if ev.Name == errorx.ConnectOperation {
|
||||
out = append(out, NetworkEvent{
|
||||
Address: ev.Address,
|
||||
Failure: NewFailure(ev.Err),
|
||||
Operation: ev.Name,
|
||||
Proto: ev.Proto,
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if ev.Name == errorx.ReadOperation {
|
||||
out = append(out, NetworkEvent{
|
||||
Failure: NewFailure(ev.Err),
|
||||
Operation: ev.Name,
|
||||
NumBytes: int64(ev.NumBytes),
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if ev.Name == errorx.WriteOperation {
|
||||
out = append(out, NetworkEvent{
|
||||
Failure: NewFailure(ev.Err),
|
||||
Operation: ev.Name,
|
||||
NumBytes: int64(ev.NumBytes),
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if ev.Name == errorx.ReadFromOperation {
|
||||
out = append(out, NetworkEvent{
|
||||
Address: ev.Address,
|
||||
Failure: NewFailure(ev.Err),
|
||||
Operation: ev.Name,
|
||||
NumBytes: int64(ev.NumBytes),
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if ev.Name == errorx.WriteToOperation {
|
||||
out = append(out, NetworkEvent{
|
||||
Address: ev.Address,
|
||||
Failure: NewFailure(ev.Err),
|
||||
Operation: ev.Name,
|
||||
NumBytes: int64(ev.NumBytes),
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
out = append(out, NetworkEvent{
|
||||
Failure: NewFailure(ev.Err),
|
||||
Operation: ev.Name,
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TLSHandshake contains TLS handshake data
|
||||
type TLSHandshake struct {
|
||||
CipherSuite string `json:"cipher_suite"`
|
||||
ConnID int64 `json:"conn_id,omitempty"`
|
||||
Failure *string `json:"failure"`
|
||||
NegotiatedProtocol string `json:"negotiated_protocol"`
|
||||
NoTLSVerify bool `json:"no_tls_verify"`
|
||||
PeerCertificates []MaybeBinaryValue `json:"peer_certificates"`
|
||||
ServerName string `json:"server_name"`
|
||||
T float64 `json:"t"`
|
||||
TLSVersion string `json:"tls_version"`
|
||||
TransactionID int64 `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
// NewTLSHandshakesList creates a new TLSHandshakesList
|
||||
func NewTLSHandshakesList(begin time.Time, events []trace.Event) []TLSHandshake {
|
||||
var out []TLSHandshake
|
||||
for _, ev := range events {
|
||||
if !strings.Contains(ev.Name, "_handshake_done") {
|
||||
continue
|
||||
}
|
||||
out = append(out, TLSHandshake{
|
||||
CipherSuite: ev.TLSCipherSuite,
|
||||
Failure: NewFailure(ev.Err),
|
||||
NegotiatedProtocol: ev.TLSNegotiatedProto,
|
||||
NoTLSVerify: ev.NoTLSVerify,
|
||||
PeerCertificates: makePeerCerts(ev.TLSPeerCerts),
|
||||
ServerName: ev.TLSServerName,
|
||||
T: ev.Time.Sub(begin).Seconds(),
|
||||
TLSVersion: ev.TLSVersion,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func makePeerCerts(in []*x509.Certificate) (out []MaybeBinaryValue) {
|
||||
for _, e := range in {
|
||||
out = append(out, MaybeBinaryValue{Value: string(e.Raw)})
|
||||
}
|
||||
return
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,8 @@
|
||||
package archival
|
||||
|
||||
// DNSQueryType allows to access dnsQueryType from unit tests
|
||||
type DNSQueryType = dnsQueryType
|
||||
|
||||
func (qtype dnsQueryType) IPOfType(addr string) bool {
|
||||
return qtype.ipoftype(addr)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package bytecounter
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
|
||||
// Counter counts bytes sent and received.
|
||||
type Counter struct {
|
||||
Received *atomicx.Int64
|
||||
Sent *atomicx.Int64
|
||||
}
|
||||
|
||||
// New creates a new Counter.
|
||||
func New() *Counter {
|
||||
return &Counter{Received: atomicx.NewInt64(), Sent: atomicx.NewInt64()}
|
||||
}
|
||||
|
||||
// CountBytesSent adds count to the bytes sent counter.
|
||||
func (c *Counter) CountBytesSent(count int) {
|
||||
c.Sent.Add(int64(count))
|
||||
}
|
||||
|
||||
// CountKibiBytesSent adds 1024*count to the bytes sent counter.
|
||||
func (c *Counter) CountKibiBytesSent(count float64) {
|
||||
c.Sent.Add(int64(1024 * count))
|
||||
}
|
||||
|
||||
// BytesSent returns the bytes sent so far.
|
||||
func (c *Counter) BytesSent() int64 {
|
||||
return c.Sent.Load()
|
||||
}
|
||||
|
||||
// KibiBytesSent returns the KiB sent so far.
|
||||
func (c *Counter) KibiBytesSent() float64 {
|
||||
return float64(c.BytesSent()) / 1024
|
||||
}
|
||||
|
||||
// CountBytesReceived adds count to the bytes received counter.
|
||||
func (c *Counter) CountBytesReceived(count int) {
|
||||
c.Received.Add(int64(count))
|
||||
}
|
||||
|
||||
// CountKibiBytesReceived adds 1024*count to the bytes received counter.
|
||||
func (c *Counter) CountKibiBytesReceived(count float64) {
|
||||
c.Received.Add(int64(1024 * count))
|
||||
}
|
||||
|
||||
// BytesReceived returns the bytes received so far.
|
||||
func (c *Counter) BytesReceived() int64 {
|
||||
return c.Received.Load()
|
||||
}
|
||||
|
||||
// KibiBytesReceived returns the KiB received so far.
|
||||
func (c *Counter) KibiBytesReceived() float64 {
|
||||
return float64(c.BytesReceived()) / 1024
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package bytecounter_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
counter := bytecounter.New()
|
||||
counter.CountBytesReceived(16384)
|
||||
counter.CountKibiBytesReceived(10)
|
||||
counter.CountBytesSent(2048)
|
||||
counter.CountKibiBytesSent(10)
|
||||
if counter.BytesSent() != 12288 {
|
||||
t.Fatal("invalid bytes sent")
|
||||
}
|
||||
if counter.BytesReceived() != 26624 {
|
||||
t.Fatal("invalid bytes received")
|
||||
}
|
||||
if v := counter.KibiBytesSent(); v < 11.9 || v > 12.1 {
|
||||
t.Fatal("invalid kibibytes sent")
|
||||
}
|
||||
if v := counter.KibiBytesReceived(); v < 25.9 || v > 26.1 {
|
||||
t.Fatal("invalid kibibytes received")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
)
|
||||
|
||||
// ByteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you
|
||||
// should make sure that you insert this dialer in the dialing chain.
|
||||
//
|
||||
// Bug
|
||||
//
|
||||
// This implementation cannot properly account for the bytes that are sent by
|
||||
// persistent connections, because they strick to the counters set when the
|
||||
// connection was established. This typically means we miss the bytes sent and
|
||||
// received when submitting a measurement. Such bytes are specifically not
|
||||
// see by the experiment specific byte counter.
|
||||
//
|
||||
// For this reason, this implementation may be heavily changed/removed.
|
||||
type ByteCounterDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ByteCounterDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exp := ContextExperimentByteCounter(ctx)
|
||||
sess := ContextSessionByteCounter(ctx)
|
||||
if exp == nil && sess == nil {
|
||||
return conn, nil // no point in wrapping
|
||||
}
|
||||
return byteCounterConnWrapper{Conn: conn, exp: exp, sess: sess}, nil
|
||||
}
|
||||
|
||||
type byteCounterSessionKey struct{}
|
||||
|
||||
// ContextSessionByteCounter retrieves the session byte counter from the context
|
||||
func ContextSessionByteCounter(ctx context.Context) *bytecounter.Counter {
|
||||
counter, _ := ctx.Value(byteCounterSessionKey{}).(*bytecounter.Counter)
|
||||
return counter
|
||||
}
|
||||
|
||||
// WithSessionByteCounter assigns the session byte counter to the context
|
||||
func WithSessionByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context {
|
||||
return context.WithValue(ctx, byteCounterSessionKey{}, counter)
|
||||
}
|
||||
|
||||
type byteCounterExperimentKey struct{}
|
||||
|
||||
// ContextExperimentByteCounter retrieves the experiment byte counter from the context
|
||||
func ContextExperimentByteCounter(ctx context.Context) *bytecounter.Counter {
|
||||
counter, _ := ctx.Value(byteCounterExperimentKey{}).(*bytecounter.Counter)
|
||||
return counter
|
||||
}
|
||||
|
||||
// WithExperimentByteCounter assigns the experiment byte counter to the context
|
||||
func WithExperimentByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context {
|
||||
return context.WithValue(ctx, byteCounterExperimentKey{}, counter)
|
||||
}
|
||||
|
||||
type byteCounterConnWrapper struct {
|
||||
net.Conn
|
||||
exp *bytecounter.Counter
|
||||
sess *bytecounter.Counter
|
||||
}
|
||||
|
||||
func (c byteCounterConnWrapper) Read(p []byte) (int, error) {
|
||||
count, err := c.Conn.Read(p)
|
||||
if c.exp != nil {
|
||||
c.exp.CountBytesReceived(count)
|
||||
}
|
||||
if c.sess != nil {
|
||||
c.sess.CountBytesReceived(count)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c byteCounterConnWrapper) Write(p []byte) (int, error) {
|
||||
count, err := c.Conn.Write(p)
|
||||
if c.exp != nil {
|
||||
c.exp.CountBytesSent(count)
|
||||
}
|
||||
if c.sess != nil {
|
||||
c.sess.CountBytesSent(count)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func dorequest(ctx context.Context, url string) error {
|
||||
txp := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defer txp.CloseIdleConnections()
|
||||
dialer := dialer.ByteCounterDialer{Dialer: new(net.Dialer)}
|
||||
txp.DialContext = dialer.DialContext
|
||||
client := &http.Client{Transport: txp}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestByteCounterNormalUsage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
sess := bytecounter.New()
|
||||
ctx := context.Background()
|
||||
ctx = dialer.WithSessionByteCounter(ctx, sess)
|
||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
exp := bytecounter.New()
|
||||
ctx = dialer.WithExperimentByteCounter(ctx, exp)
|
||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess.Received.Load() <= exp.Received.Load() {
|
||||
t.Fatal("session should have received more than experiment")
|
||||
}
|
||||
if sess.Sent.Load() <= exp.Sent.Load() {
|
||||
t.Fatal("session should have sent more than experiment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterNoHandlers(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterConnectFailure(t *testing.T) {
|
||||
dialer := dialer.ByteCounterDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
|
||||
)
|
||||
|
||||
// Dialer is the interface we expect from a dialer
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Resolver is the interface we expect from a resolver
|
||||
type Resolver interface {
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
}
|
||||
|
||||
func safeLocalAddress(conn net.Conn) (s string) {
|
||||
if conn != nil && conn.LocalAddr() != nil {
|
||||
s = conn.LocalAddr().String()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func safeConnID(network string, conn net.Conn) int64 {
|
||||
return connid.Compute(network, safeLocalAddress(conn))
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// DNSDialer is a dialer that uses the configured Resolver to resolver a
|
||||
// domain name to IP addresses, and the configured Dialer to connect.
|
||||
type DNSDialer struct {
|
||||
Dialer
|
||||
Resolver Resolver
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext.
|
||||
func (d DNSDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
onlyhost, onlyport, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = dialid.WithDialID(ctx) // important to create before lookupHost
|
||||
var addrs []string
|
||||
addrs, err = d.LookupHost(ctx, onlyhost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var errorslist []error
|
||||
for _, addr := range addrs {
|
||||
target := net.JoinHostPort(addr, onlyport)
|
||||
conn, err := d.Dialer.DialContext(ctx, network, target)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
}
|
||||
return nil, ReduceErrors(errorslist)
|
||||
}
|
||||
|
||||
// ReduceErrors finds a known error in a list of errors since it's probably most relevant
|
||||
func ReduceErrors(errorslist []error) error {
|
||||
if len(errorslist) == 0 {
|
||||
return nil
|
||||
}
|
||||
// If we have a known error, let's consider this the real error
|
||||
// since it's probably most relevant. Otherwise let's return the
|
||||
// first considering that (1) local resolvers likely will give
|
||||
// us IPv4 first and (2) also our resolver does that. So, in case
|
||||
// the user has no IPv6 connectivity, an IPv6 error is going to
|
||||
// appear later in the list of errors.
|
||||
for _, err := range errorslist {
|
||||
var wrapper *errorx.ErrWrapper
|
||||
if errors.As(err, &wrapper) && !strings.HasPrefix(
|
||||
err.Error(), "unknown_failure",
|
||||
) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO(bassosimone): handle this case in a better way
|
||||
return errorslist[0]
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (d DNSDialer) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return d.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestDNSDialerNoPort(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: new(net.Resolver)}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "antani.ooni.nu")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerLookupHostAddress(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{
|
||||
Err: errors.New("mocked error"),
|
||||
}}
|
||||
addrs, err := dialer.LookupHost(context.Background(), "1.1.1.1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerLookupHostFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{
|
||||
Err: expected,
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "dns.google.com:853")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
type MockableResolver struct {
|
||||
Addresses []string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||
return r.Addresses, r.Err
|
||||
}
|
||||
|
||||
func TestDNSDialerDialForSingleIPFails(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: new(net.Resolver)}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "1.1.1.1:853")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerDialForManyIPFails(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: MockableResolver{
|
||||
Addresses: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerDialForManyIPSuccess(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EOFConnDialer{}, Resolver: MockableResolver{
|
||||
Addresses: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||
if err != nil {
|
||||
t.Fatal("expected nil error here")
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn")
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDNSDialerDialSetsDialID(t *testing.T) {
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EmitterDialer{
|
||||
Dialer: dialer.EOFConnDialer{},
|
||||
}, Resolver: MockableResolver{
|
||||
Addresses: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "dot.dns:853")
|
||||
if err != nil {
|
||||
t.Fatal("expected nil error here")
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn")
|
||||
}
|
||||
conn.Close()
|
||||
events := saver.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
for _, ev := range events {
|
||||
if ev.Connect != nil && ev.Connect.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReduceErrors(t *testing.T) {
|
||||
t.Run("no errors", func(t *testing.T) {
|
||||
result := dialer.ReduceErrors(nil)
|
||||
if result != nil {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single error", func(t *testing.T) {
|
||||
err := errors.New("mocked error")
|
||||
result := dialer.ReduceErrors([]error{err})
|
||||
if result != err {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple errors", func(t *testing.T) {
|
||||
err1 := errors.New("mocked error #1")
|
||||
err2 := errors.New("mocked error #2")
|
||||
result := dialer.ReduceErrors([]error{err1, err2})
|
||||
if result.Error() != "mocked error #1" {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
||||
err1 := errors.New("mocked error #1")
|
||||
err2 := &errorx.ErrWrapper{
|
||||
Failure: "unknown_failure: antani",
|
||||
}
|
||||
err3 := &errorx.ErrWrapper{
|
||||
Failure: errorx.FailureConnectionRefused,
|
||||
}
|
||||
err4 := errors.New("mocked error #3")
|
||||
result := dialer.ReduceErrors([]error{err1, err2, err3, err4})
|
||||
if result.Error() != errorx.FailureConnectionRefused {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
// EmitterDialer is a Dialer that emits events
|
||||
type EmitterDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d EmitterDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
start := time.Now()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
stop := time.Now()
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
Connect: &modelx.ConnectEvent{
|
||||
ConnID: safeConnID(network, conn),
|
||||
DialID: dialid.ContextDialID(ctx),
|
||||
DurationSinceBeginning: stop.Sub(root.Beginning),
|
||||
Error: err,
|
||||
Network: network,
|
||||
RemoteAddress: address,
|
||||
SyscallDuration: stop.Sub(start),
|
||||
TransactionID: transactionid.ContextTransactionID(ctx),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return EmitterConn{
|
||||
Conn: conn,
|
||||
Beginning: root.Beginning,
|
||||
Handler: root.Handler,
|
||||
ID: safeConnID(network, conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EmitterConn is a net.Conn used to emit events
|
||||
type EmitterConn struct {
|
||||
net.Conn
|
||||
Beginning time.Time
|
||||
Handler modelx.Handler
|
||||
ID int64
|
||||
}
|
||||
|
||||
// Read implements net.Conn.Read
|
||||
func (c EmitterConn) Read(b []byte) (n int, err error) {
|
||||
start := time.Now()
|
||||
n, err = c.Conn.Read(b)
|
||||
stop := time.Now()
|
||||
c.Handler.OnMeasurement(modelx.Measurement{
|
||||
Read: &modelx.ReadEvent{
|
||||
ConnID: c.ID,
|
||||
DurationSinceBeginning: stop.Sub(c.Beginning),
|
||||
Error: err,
|
||||
NumBytes: int64(n),
|
||||
SyscallDuration: stop.Sub(start),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write
|
||||
func (c EmitterConn) Write(b []byte) (n int, err error) {
|
||||
start := time.Now()
|
||||
n, err = c.Conn.Write(b)
|
||||
stop := time.Now()
|
||||
c.Handler.OnMeasurement(modelx.Measurement{
|
||||
Write: &modelx.WriteEvent{
|
||||
ConnID: c.ID,
|
||||
DurationSinceBeginning: stop.Sub(c.Beginning),
|
||||
Error: err,
|
||||
NumBytes: int64(n),
|
||||
SyscallDuration: stop.Sub(start),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close
|
||||
func (c EmitterConn) Close() (err error) {
|
||||
start := time.Now()
|
||||
err = c.Conn.Close()
|
||||
stop := time.Now()
|
||||
c.Handler.OnMeasurement(modelx.Measurement{
|
||||
Close: &modelx.CloseEvent{
|
||||
ConnID: c.ID,
|
||||
DurationSinceBeginning: stop.Sub(c.Beginning),
|
||||
Error: err,
|
||||
SyscallDuration: stop.Sub(start),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestEmitterFailure(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
d := dialer.EmitterDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
events := saver.Read()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("unexpected number of events saved")
|
||||
}
|
||||
if events[0].Connect == nil {
|
||||
t.Fatal("expected non nil Connect")
|
||||
}
|
||||
conninfo := events[0].Connect
|
||||
if conninfo.ConnID != 0 {
|
||||
t.Fatal("unexpected ConnID value")
|
||||
}
|
||||
emitterCheckConnectEventCommon(t, conninfo, io.EOF)
|
||||
}
|
||||
|
||||
func emitterCheckConnectEventCommon(
|
||||
t *testing.T, conninfo *modelx.ConnectEvent, err error) {
|
||||
if conninfo.DialID == 0 {
|
||||
t.Fatal("unexpected DialID value")
|
||||
}
|
||||
if conninfo.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning value")
|
||||
}
|
||||
if !errors.Is(conninfo.Error, err) {
|
||||
t.Fatal("unexpected Error value")
|
||||
}
|
||||
if conninfo.Network != "tcp" {
|
||||
t.Fatal("unexpected Network value")
|
||||
}
|
||||
if conninfo.RemoteAddress != "www.google.com:443" {
|
||||
t.Fatal("unexpected Network value")
|
||||
}
|
||||
if conninfo.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration value")
|
||||
}
|
||||
if conninfo.TransactionID == 0 {
|
||||
t.Fatal("unexpected TransactionID value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterSuccess(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
d := dialer.EmitterDialer{Dialer: dialer.EOFConnDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal("we expected no error")
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected a non-nil conn here")
|
||||
}
|
||||
conn.Read(nil)
|
||||
conn.Write(nil)
|
||||
conn.Close()
|
||||
events := saver.Read()
|
||||
if len(events) != 4 {
|
||||
t.Fatal("unexpected number of events saved")
|
||||
}
|
||||
if events[0].Connect == nil {
|
||||
t.Fatal("expected non nil Connect")
|
||||
}
|
||||
conninfo := events[0].Connect
|
||||
if conninfo.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID value")
|
||||
}
|
||||
emitterCheckConnectEventCommon(t, conninfo, nil)
|
||||
if events[1].Read == nil {
|
||||
t.Fatal("expected non nil Read")
|
||||
}
|
||||
emitterCheckReadEvent(t, events[1].Read)
|
||||
if events[2].Write == nil {
|
||||
t.Fatal("expected non nil Write")
|
||||
}
|
||||
emitterCheckWriteEvent(t, events[2].Write)
|
||||
if events[3].Close == nil {
|
||||
t.Fatal("expected non nil Close")
|
||||
}
|
||||
emitterCheckCloseEvent(t, events[3].Close)
|
||||
}
|
||||
|
||||
func emitterCheckReadEvent(t *testing.T, ev *modelx.ReadEvent) {
|
||||
if ev.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if ev.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning")
|
||||
}
|
||||
if !errors.Is(ev.Error, io.EOF) {
|
||||
t.Fatal("unexpected Error")
|
||||
}
|
||||
if ev.NumBytes != 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
if ev.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration")
|
||||
}
|
||||
}
|
||||
|
||||
func emitterCheckWriteEvent(t *testing.T, ev *modelx.WriteEvent) {
|
||||
if ev.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if ev.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning")
|
||||
}
|
||||
if !errors.Is(ev.Error, io.EOF) {
|
||||
t.Fatal("unexpected Error")
|
||||
}
|
||||
if ev.NumBytes != 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
if ev.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration")
|
||||
}
|
||||
}
|
||||
|
||||
func emitterCheckCloseEvent(t *testing.T, ev *modelx.CloseEvent) {
|
||||
if ev.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if ev.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning")
|
||||
}
|
||||
if !errors.Is(ev.Error, io.EOF) {
|
||||
t.Fatal("unexpected Error")
|
||||
}
|
||||
if ev.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EOFDialer struct{}
|
||||
|
||||
func (EOFDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
type EOFConnDialer struct{}
|
||||
|
||||
func (EOFConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return EOFConn{}, nil
|
||||
}
|
||||
|
||||
type EOFConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (EOFConn) Read(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (EOFConn) Write(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (EOFConn) Close() error {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (EOFConn) LocalAddr() net.Addr {
|
||||
return EOFAddr{}
|
||||
}
|
||||
|
||||
func (EOFConn) RemoteAddr() net.Addr {
|
||||
return EOFAddr{}
|
||||
}
|
||||
|
||||
func (EOFConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EOFConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EOFConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type EOFAddr struct{}
|
||||
|
||||
func (EOFAddr) Network() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (EOFAddr) String() string {
|
||||
return "127.0.0.1:1234"
|
||||
}
|
||||
|
||||
type EOFTLSHandshaker struct{}
|
||||
|
||||
func (EOFTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return nil, tls.ConnectionState{}, io.EOF
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// ErrorWrapperDialer is a dialer that performs err wrapping
|
||||
type ErrorWrapperDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
// ConnID does not make any sense if we've failed and the error
|
||||
// does not make any sense (and is nil) if we succeded.
|
||||
DialID: dialID,
|
||||
Error: err,
|
||||
Operation: errorx.ConnectOperation,
|
||||
}.MaybeBuild()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ErrorWrapperConn{
|
||||
Conn: conn, ConnID: safeConnID(network, conn), DialID: dialID}, nil
|
||||
}
|
||||
|
||||
// ErrorWrapperConn is a net.Conn that performs error wrapping.
|
||||
type ErrorWrapperConn struct {
|
||||
net.Conn
|
||||
ConnID int64
|
||||
DialID int64
|
||||
}
|
||||
|
||||
// Read implements net.Conn.Read
|
||||
func (c ErrorWrapperConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: c.ConnID,
|
||||
DialID: c.DialID,
|
||||
Error: err,
|
||||
Operation: errorx.ReadOperation,
|
||||
}.MaybeBuild()
|
||||
return
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write
|
||||
func (c ErrorWrapperConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: c.ConnID,
|
||||
DialID: c.DialID,
|
||||
Error: err,
|
||||
Operation: errorx.WriteOperation,
|
||||
}.MaybeBuild()
|
||||
return
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close
|
||||
func (c ErrorWrapperConn) Close() (err error) {
|
||||
err = c.Conn.Close()
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: c.ConnID,
|
||||
DialID: c.DialID,
|
||||
Error: err,
|
||||
Operation: errorx.CloseOperation,
|
||||
}.MaybeBuild()
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestErrorWrapperFailure(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
errorWrapperCheckErr(t, err, errorx.ConnectOperation)
|
||||
}
|
||||
|
||||
func errorWrapperCheckErr(t *testing.T, err error, op string) {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected another error here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot cast to ErrWrapper")
|
||||
}
|
||||
if errWrapper.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
if errWrapper.Operation != op {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureEOFError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapperSuccess(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFConnDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn here")
|
||||
}
|
||||
count, err := conn.Read(nil)
|
||||
errorWrapperCheckIOResult(t, count, err, errorx.ReadOperation)
|
||||
count, err = conn.Write(nil)
|
||||
errorWrapperCheckIOResult(t, count, err, errorx.WriteOperation)
|
||||
err = conn.Close()
|
||||
errorWrapperCheckErr(t, err, errorx.CloseOperation)
|
||||
}
|
||||
|
||||
func errorWrapperCheckIOResult(t *testing.T, count int, err error, op string) {
|
||||
if count != 0 {
|
||||
t.Fatal("expected nil count here")
|
||||
}
|
||||
errorWrapperCheckErr(t, err, op)
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FakeDialer struct {
|
||||
Conn net.Conn
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return d.Conn, d.Err
|
||||
}
|
||||
|
||||
type FakeConn struct {
|
||||
ReadError error
|
||||
ReadData []byte
|
||||
SetDeadlineError error
|
||||
SetReadDeadlineError error
|
||||
SetWriteDeadlineError error
|
||||
WriteError error
|
||||
}
|
||||
|
||||
func (c *FakeConn) Read(b []byte) (int, error) {
|
||||
if len(c.ReadData) > 0 {
|
||||
n := copy(b, c.ReadData)
|
||||
c.ReadData = c.ReadData[n:]
|
||||
return n, nil
|
||||
}
|
||||
if c.ReadError != nil {
|
||||
return 0, c.ReadError
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *FakeConn) Write(b []byte) (n int, err error) {
|
||||
if c.WriteError != nil {
|
||||
return 0, c.WriteError
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) Close() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (*FakeConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
|
||||
return c.SetDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
|
||||
return c.SetReadDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
|
||||
return c.SetWriteDeadlineError
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestTLSDialerSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
log.SetLevel(log.DebugLevel)
|
||||
dialer := dialer.TLSDialer{Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.LoggingTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Logger: log.Log,
|
||||
},
|
||||
}
|
||||
txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) {
|
||||
// AlpineLinux edge is still using Go 1.13. We cannot switch to
|
||||
// using DialTLSContext here as we'd like to until either Alpine
|
||||
// switches to Go 1.14 or we drop the MK dependency.
|
||||
return dialer.DialTLSContext(context.Background(), network, address)
|
||||
}}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestDNSDialerSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
log.SetLevel(log.DebugLevel)
|
||||
dialer := dialer.DNSDialer{
|
||||
Dialer: dialer.LoggingDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
Logger: log.Log,
|
||||
},
|
||||
Resolver: new(net.Resolver),
|
||||
}
|
||||
txp := &http.Transport{DialContext: dialer.DialContext}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("http://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
|
||||
)
|
||||
|
||||
// Logger is the logger assumed by this package
|
||||
type Logger interface {
|
||||
Debugf(format string, v ...interface{})
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
// LoggingDialer is a Dialer with logging
|
||||
type LoggingDialer struct {
|
||||
Dialer
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d LoggingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d.Logger.Debugf("dial %s/%s...", address, network)
|
||||
start := time.Now()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
stop := time.Now()
|
||||
d.Logger.Debugf("dial %s/%s... %+v in %s", address, network, err, stop.Sub(start))
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// LoggingTLSHandshaker is a TLSHandshaker with logging
|
||||
type LoggingTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h LoggingTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos)
|
||||
start := time.Now()
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
stop := time.Now()
|
||||
h.Logger.Debugf(
|
||||
"tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName,
|
||||
config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol,
|
||||
tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version))
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
var _ Dialer = LoggingDialer{}
|
||||
var _ TLSHandshaker = LoggingTLSHandshaker{}
|
||||
@@ -0,0 +1,42 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestLoggingDialerFailure(t *testing.T) {
|
||||
d := dialer.LoggingDialer{
|
||||
Dialer: dialer.EOFDialer{},
|
||||
Logger: log.Log,
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggingTLSHandshakerFailure(t *testing.T) {
|
||||
h := dialer.LoggingTLSHandshaker{
|
||||
TLSHandshaker: dialer.EOFTLSHandshaker{},
|
||||
Logger: log.Log,
|
||||
}
|
||||
tlsconn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "www.google.com",
|
||||
})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if tlsconn != nil {
|
||||
t.Fatal("expected nil tlsconn here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// ProxyDialer is a dialer that uses a proxy. If the ProxyURL is not configured, this
|
||||
// dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally
|
||||
// create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer.
|
||||
//
|
||||
// As a special case, you can force a proxy to be used only extemporarily. To this end,
|
||||
// you can use the WithProxyURL function, to store the proxy URL in the context. This
|
||||
// will take precedence over any otherwise configured proxy. The use case for this
|
||||
// functionality is when you need a tunnel to contact OONI probe services.
|
||||
type ProxyDialer struct {
|
||||
Dialer
|
||||
ProxyURL *url.URL
|
||||
}
|
||||
|
||||
type proxyKey struct{}
|
||||
|
||||
// ContextProxyURL retrieves the proxy URL from the context. This is mainly used
|
||||
// to force a tunnel when we fail contacting OONI probe services otherwise.
|
||||
func ContextProxyURL(ctx context.Context) *url.URL {
|
||||
url, _ := ctx.Value(proxyKey{}).(*url.URL)
|
||||
return url
|
||||
}
|
||||
|
||||
// WithProxyURL assigns the proxy URL to the context
|
||||
func WithProxyURL(ctx context.Context, url *url.URL) context.Context {
|
||||
return context.WithValue(ctx, proxyKey{}, url)
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
url := ContextProxyURL(ctx) // context URL takes precendence
|
||||
if url == nil {
|
||||
url = d.ProxyURL
|
||||
}
|
||||
if url == nil {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
if url.Scheme != "socks5" {
|
||||
return nil, errors.New("Scheme is not socks5")
|
||||
}
|
||||
// the code at proxy/socks5.go never fails; see https://git.io/JfJ4g
|
||||
child, _ := proxy.SOCKS5(
|
||||
network, url.Host, nil, proxyDialerWrapper{Dialer: d.Dialer})
|
||||
return d.dial(ctx, child, network, address)
|
||||
}
|
||||
|
||||
func (d ProxyDialer) dial(
|
||||
ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) {
|
||||
connch := make(chan net.Conn)
|
||||
errch := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := child.Dial(network, address)
|
||||
if err != nil {
|
||||
errch <- err
|
||||
return
|
||||
}
|
||||
select {
|
||||
case connch <- conn:
|
||||
default:
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-errch:
|
||||
return nil, err
|
||||
case conn := <-connch:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// proxyDialerWrapper is required because SOCKS5 expects a Dialer.Dial type but internally
|
||||
// it checks whether DialContext is available and prefers that. So, we need to use this
|
||||
// structure to cast our inner Dialer the way in which SOCKS5 likes it.
|
||||
//
|
||||
// See https://git.io/JfJ4g.
|
||||
type proxyDialerWrapper struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
func (d proxyDialerWrapper) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type ProxyDialerWrapper = proxyDialerWrapper
|
||||
|
||||
func (d ProxyDialer) DialContextWithDialer(
|
||||
ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) {
|
||||
return d.dial(ctx, child, network, address)
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestProxyDialerDialContextNoProxyURL(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{Err: expected},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerContextTakesPrecedence(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{Err: expected},
|
||||
ProxyURL: &url.URL{Scheme: "antani"},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx = dialer.WithProxyURL(ctx, &url.URL{Scheme: "socks5", Host: "[::1]:443"})
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextInvalidScheme(t *testing.T) {
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{},
|
||||
ProxyURL: &url.URL{Scheme: "antani"},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if err.Error() != "Scheme is not socks5" {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithEOF(t *testing.T) {
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: io.EOF,
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // immediately fail
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: io.EOF,
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithDialerSuccess(t *testing.T) {
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Conn: &dialer.FakeConn{
|
||||
ReadError: io.EOF,
|
||||
WriteError: io.EOF,
|
||||
},
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContextWithDialer(
|
||||
context.Background(), dialer.ProxyDialerWrapper{
|
||||
Dialer: d.Dialer,
|
||||
}, "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithDialerCanceledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Stop immediately. The FakeDialer sleeps for some microseconds so
|
||||
// it is much more likely we immediately exit with done context. The
|
||||
// arm where we receive the conn is much less likely.
|
||||
cancel()
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Conn: &dialer.FakeConn{
|
||||
ReadError: io.EOF,
|
||||
WriteError: io.EOF,
|
||||
},
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContextWithDialer(
|
||||
ctx, dialer.ProxyDialerWrapper{
|
||||
Dialer: d.Dialer,
|
||||
}, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerWrapper(t *testing.T) {
|
||||
d := dialer.ProxyDialerWrapper{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: io.EOF,
|
||||
},
|
||||
}
|
||||
conn, err := d.Dial("tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// SaverDialer saves events occurring during the dial
|
||||
type SaverDialer struct {
|
||||
Dialer
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
start := time.Now()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
stop := time.Now()
|
||||
d.Saver.Write(trace.Event{
|
||||
Address: address,
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: errorx.ConnectOperation,
|
||||
Proto: network,
|
||||
Time: stop,
|
||||
})
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// SaverTLSHandshaker saves events occurring during the handshake
|
||||
type SaverTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// Handshake implements TLSHandshaker.Handshake
|
||||
func (h SaverTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
start := time.Now()
|
||||
h.Saver.Write(trace.Event{
|
||||
Name: "tls_handshake_start",
|
||||
NoTLSVerify: config.InsecureSkipVerify,
|
||||
TLSNextProtos: config.NextProtos,
|
||||
TLSServerName: config.ServerName,
|
||||
Time: start,
|
||||
})
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
stop := time.Now()
|
||||
h.Saver.Write(trace.Event{
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: "tls_handshake_done",
|
||||
NoTLSVerify: config.InsecureSkipVerify,
|
||||
TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite),
|
||||
TLSNegotiatedProto: state.NegotiatedProtocol,
|
||||
TLSNextProtos: config.NextProtos,
|
||||
TLSPeerCerts: trace.PeerCerts(state, err),
|
||||
TLSServerName: config.ServerName,
|
||||
TLSVersion: tlsx.VersionString(state.Version),
|
||||
Time: stop,
|
||||
})
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
// SaverConnDialer wraps the returned connection such that we
|
||||
// collect all the read/write events that occur.
|
||||
type SaverConnDialer struct {
|
||||
Dialer
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return saverConn{saver: d.Saver, Conn: conn}, nil
|
||||
}
|
||||
|
||||
type saverConn struct {
|
||||
net.Conn
|
||||
saver *trace.Saver
|
||||
}
|
||||
|
||||
func (c saverConn) Read(p []byte) (int, error) {
|
||||
start := time.Now()
|
||||
count, err := c.Conn.Read(p)
|
||||
stop := time.Now()
|
||||
c.saver.Write(trace.Event{
|
||||
Data: p[:count],
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
NumBytes: count,
|
||||
Name: errorx.ReadOperation,
|
||||
Time: stop,
|
||||
})
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c saverConn) Write(p []byte) (int, error) {
|
||||
start := time.Now()
|
||||
count, err := c.Conn.Write(p)
|
||||
stop := time.Now()
|
||||
c.saver.Write(trace.Event{
|
||||
Data: p[:count],
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
NumBytes: count,
|
||||
Name: errorx.WriteOperation,
|
||||
Time: stop,
|
||||
})
|
||||
return count, err
|
||||
}
|
||||
|
||||
var _ Dialer = SaverDialer{}
|
||||
var _ TLSHandshaker = SaverTLSHandshaker{}
|
||||
var _ net.Conn = saverConn{}
|
||||
@@ -0,0 +1,371 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSaverDialerFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
saver := &trace.Saver{}
|
||||
dlr := dialer.SaverDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected another error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 1 {
|
||||
t.Fatal("expected a single event here")
|
||||
}
|
||||
if ev[0].Address != "www.google.com:443" {
|
||||
t.Fatal("unexpected Address")
|
||||
}
|
||||
if ev[0].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if !errors.Is(ev[0].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[0].Name != errorx.ConnectOperation {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].Proto != "tcp" {
|
||||
t.Fatal("unexpected Proto")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverConnDialerFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
saver := &trace.Saver{}
|
||||
dlr := dialer.SaverConnDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
||||
// This is the most common use case for collecting reads, writes
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
nextprotos := []string{"h2"}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Config: &tls.Config{NextProtos: nextprotos},
|
||||
Dialer: dialer.SaverConnDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
Saver: saver,
|
||||
},
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
// Implementation note: we don't close the connection here because it is
|
||||
// very handy to have the last event being the end of the handshake
|
||||
_, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) < 4 {
|
||||
// it's a bit tricky to be sure about the right number of
|
||||
// events because network conditions may influence that
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if ev[0].Name != "tls_handshake_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
last := len(ev) - 1
|
||||
for idx := 1; idx < last; idx++ {
|
||||
if ev[idx].Data == nil {
|
||||
t.Fatal("unexpected Data")
|
||||
}
|
||||
if ev[idx].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[idx].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[idx].NumBytes <= 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
switch ev[idx].Name {
|
||||
case errorx.ReadOperation, errorx.WriteOperation:
|
||||
default:
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[idx].Time.Before(ev[idx-1].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
if ev[last].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[last].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[last].Name != "tls_handshake_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[last].TLSCipherSuite == "" {
|
||||
t.Fatal("unexpected TLSCipherSuite")
|
||||
}
|
||||
if ev[last].TLSNegotiatedProto != "h2" {
|
||||
t.Fatal("unexpected TLSNegotiatedProto")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[last].TLSPeerCerts == nil {
|
||||
t.Fatal("unexpected TLSPeerCerts")
|
||||
}
|
||||
if ev[last].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if ev[last].TLSVersion == "" {
|
||||
t.Fatal("unexpected TLSVersion")
|
||||
}
|
||||
if ev[last].Time.Before(ev[last-1].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
nextprotos := []string{"h2"}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Config: &tls.Config{NextProtos: nextprotos},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if ev[0].Name != "tls_handshake_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "tls_handshake_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[1].TLSCipherSuite == "" {
|
||||
t.Fatal("unexpected TLSCipherSuite")
|
||||
}
|
||||
if ev[1].TLSNegotiatedProto != "h2" {
|
||||
t.Fatal("unexpected TLSNegotiatedProto")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[1].TLSPeerCerts == nil {
|
||||
t.Fatal("unexpected TLSPeerCerts")
|
||||
}
|
||||
if ev[1].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if ev[1].TLSVersion == "" {
|
||||
t.Fatal("unexpected TLSVersion")
|
||||
}
|
||||
if ev[1].Time.Before(ev[0].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "wrong.host.badssl.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "expired.badssl.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "self-signed.badssl.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Config: &tls.Config{InsecureSkipVerify: true},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "self-signed.badssl.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn here")
|
||||
}
|
||||
conn.Close()
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify != true {
|
||||
t.Fatal("expected NoTLSVerify to be true")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// +build !shaping
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ShapingDialer ensures we don't use too much bandwidth
|
||||
// when using integration tests at GitHub. To select
|
||||
// the implementation with shaping use `-tags shaping`.
|
||||
type ShapingDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ShapingDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// +build shaping
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShapingDialer ensures we don't use too much bandwidth
|
||||
// when using integration tests at GitHub. To select
|
||||
// the implementation with shaping use `-tags shaping`.
|
||||
type ShapingDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ShapingDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shapingConn{Conn: conn}, nil
|
||||
}
|
||||
|
||||
type shapingConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c shapingConn) Read(p []byte) (int, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
func (c shapingConn) Write(p []byte) (int, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
txp := netx.NewHTTPTransport(netx.Config{
|
||||
Dialer: dialer.ShapingDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
},
|
||||
})
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimeoutDialer is a Dialer that enforces a timeout
|
||||
type TimeoutDialer struct {
|
||||
Dialer
|
||||
ConnectTimeout time.Duration // default: 30 seconds
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d TimeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
timeout := 30 * time.Second
|
||||
if d.ConnectTimeout != 0 {
|
||||
timeout = d.ConnectTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
type SlowDialer struct{}
|
||||
|
||||
func (SlowDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(30 * time.Second):
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutDialer(t *testing.T) {
|
||||
d := dialer.TimeoutDialer{Dialer: SlowDialer{}, ConnectTimeout: time.Second}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// TLSHandshaker is the generic TLS handshaker
|
||||
type TLSHandshaker interface {
|
||||
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
||||
net.Conn, tls.ConnectionState, error)
|
||||
}
|
||||
|
||||
// SystemTLSHandshaker is the system TLS handshaker.
|
||||
type SystemTLSHandshaker struct{}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h SystemTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
tlsconn := tls.Client(conn, config)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
}
|
||||
|
||||
// TimeoutTLSHandshaker is a TLSHandshaker with timeout
|
||||
type TimeoutTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
HandshakeTimeout time.Duration // default: 10 second
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h TimeoutTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
timeout := 10 * time.Second
|
||||
if h.HandshakeTimeout != 0 {
|
||||
timeout = h.HandshakeTimeout
|
||||
}
|
||||
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
conn.SetDeadline(time.Time{})
|
||||
return tlsconn, connstate, err
|
||||
}
|
||||
|
||||
// ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error
|
||||
type ErrorWrapperTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h ErrorWrapperTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: connID,
|
||||
Error: err,
|
||||
Operation: errorx.TLSHandshakeOperation,
|
||||
}.MaybeBuild()
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
// EmitterTLSHandshaker emits events using the MeasurementRoot
|
||||
type EmitterTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h EmitterTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{
|
||||
ConnID: connID,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
SNI: config.ServerName,
|
||||
},
|
||||
})
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{
|
||||
ConnID: connID,
|
||||
ConnectionState: modelx.NewTLSConnectionState(state),
|
||||
Error: err,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
},
|
||||
})
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
// TLSDialer is the TLS dialer
|
||||
type TLSDialer struct {
|
||||
Config *tls.Config
|
||||
Dialer Dialer
|
||||
TLSHandshaker TLSHandshaker
|
||||
}
|
||||
|
||||
// DialTLSContext is like tls.DialTLS but with the signature of net.Dialer.DialContext
|
||||
func (d TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// Implementation note: when DialTLS is not set, the code in
|
||||
// net/http will perform the handshake. Otherwise, if DialTLS
|
||||
// is set, we will end up here. This code is still used when
|
||||
// performing non-HTTP TLS-enabled dial operations.
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config := d.Config
|
||||
if config == nil {
|
||||
config = new(tls.Config)
|
||||
} else {
|
||||
config = config.Clone()
|
||||
}
|
||||
if config.ServerName == "" {
|
||||
config.ServerName = host
|
||||
}
|
||||
tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsconn, nil
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestSystemTLSHandshakerEOFError(t *testing.T) {
|
||||
h := dialer.SystemTLSHandshaker{}
|
||||
conn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "x.org",
|
||||
})
|
||||
if err != io.EOF {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) {
|
||||
h := dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 200 * time.Millisecond,
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), &dialer.FakeConn{SetDeadlineError: expected},
|
||||
new(tls.Config))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutTLSHandshakerEOFError(t *testing.T) {
|
||||
h := dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 200 * time.Millisecond,
|
||||
}
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), dialer.EOFConn{}, &tls.Config{ServerName: "x.org"})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) {
|
||||
h := dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 200 * time.Millisecond,
|
||||
}
|
||||
underlying := &SetDeadlineConn{}
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), underlying, &tls.Config{ServerName: "x.org"})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
if len(underlying.deadlines) != 2 {
|
||||
t.Fatal("SetDeadline not called twice")
|
||||
}
|
||||
if underlying.deadlines[0].Before(time.Now()) {
|
||||
t.Fatal("the first SetDeadline call was incorrect")
|
||||
}
|
||||
if !underlying.deadlines[1].IsZero() {
|
||||
t.Fatal("the second SetDeadline call was incorrect")
|
||||
}
|
||||
}
|
||||
|
||||
type SetDeadlineConn struct {
|
||||
dialer.EOFConn
|
||||
deadlines []time.Time
|
||||
}
|
||||
|
||||
func (c *SetDeadlineConn) SetDeadline(t time.Time) error {
|
||||
c.deadlines = append(c.deadlines, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestErrorWrapperTLSHandshakerFailure(t *testing.T) {
|
||||
h := dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), dialer.EOFConn{}, new(tls.Config))
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot cast to ErrWrapper")
|
||||
}
|
||||
if errWrapper.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureEOFError {
|
||||
t.Fatal("unexpected Failure")
|
||||
}
|
||||
if errWrapper.Operation != errorx.TLSHandshakeOperation {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterTLSHandshakerFailure(t *testing.T) {
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
h := dialer.EmitterTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
|
||||
conn, _, err := h.Handshake(ctx, dialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "www.kernel.org",
|
||||
})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
events := saver.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("Wrong number of events")
|
||||
}
|
||||
if events[0].TLSHandshakeStart == nil {
|
||||
t.Fatal("missing TLSHandshakeStart event")
|
||||
}
|
||||
if events[0].TLSHandshakeStart.ConnID == 0 {
|
||||
t.Fatal("expected nonzero ConnID")
|
||||
}
|
||||
if events[0].TLSHandshakeStart.DurationSinceBeginning == 0 {
|
||||
t.Fatal("expected nonzero DurationSinceBeginning")
|
||||
}
|
||||
if events[0].TLSHandshakeStart.SNI != "www.kernel.org" {
|
||||
t.Fatal("expected nonzero SNI")
|
||||
}
|
||||
if events[1].TLSHandshakeDone == nil {
|
||||
t.Fatal("missing TLSHandshakeDone event")
|
||||
}
|
||||
if events[1].TLSHandshakeDone.ConnID == 0 {
|
||||
t.Fatal("expected nonzero ConnID")
|
||||
}
|
||||
if events[1].TLSHandshakeDone.DurationSinceBeginning == 0 {
|
||||
t.Fatal("expected nonzero DurationSinceBeginning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureSplitHostPort(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com") // missing port
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureDialing(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureHandshaking(t *testing.T) {
|
||||
rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}}
|
||||
dialer := dialer.TLSDialer{
|
||||
Dialer: dialer.EOFConnDialer{},
|
||||
TLSHandshaker: rec,
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
if rec.SNI != "www.google.com" {
|
||||
t.Fatal("unexpected SNI value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) {
|
||||
rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}}
|
||||
dialer := dialer.TLSDialer{
|
||||
Config: &tls.Config{
|
||||
ServerName: "x.org",
|
||||
},
|
||||
Dialer: dialer.EOFConnDialer{},
|
||||
TLSHandshaker: rec,
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
if rec.SNI != "x.org" {
|
||||
t.Fatal("unexpected SNI value")
|
||||
}
|
||||
}
|
||||
|
||||
type RecorderTLSHandshaker struct {
|
||||
dialer.TLSHandshaker
|
||||
SNI string
|
||||
}
|
||||
|
||||
func (h *RecorderTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
h.SNI = config.ServerName
|
||||
return h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
}
|
||||
|
||||
func TestDialTLSContextGood(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{
|
||||
Config: &tls.Config{ServerName: "google.com"},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("connection is nil")
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialTLSContextTimeout(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{
|
||||
Config: &tls.Config{ServerName: "google.com"},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{
|
||||
TLSHandshaker: dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 10 * time.Microsecond,
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
|
||||
if err.Error() != errorx.FailureGenericTimeoutError {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,322 @@
|
||||
// Package errorx contains error extensions
|
||||
package errorx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// FailureConnectionRefused means ECONNREFUSED.
|
||||
FailureConnectionRefused = "connection_refused"
|
||||
|
||||
// FailureConnectionReset means ECONNRESET.
|
||||
FailureConnectionReset = "connection_reset"
|
||||
|
||||
// FailureDNSBogonError means we detected bogon in DNS reply.
|
||||
FailureDNSBogonError = "dns_bogon_error"
|
||||
|
||||
// FailureDNSNXDOMAINError means we got NXDOMAIN in DNS reply.
|
||||
FailureDNSNXDOMAINError = "dns_nxdomain_error"
|
||||
|
||||
// FailureEOFError means we got unexpected EOF on connection.
|
||||
FailureEOFError = "eof_error"
|
||||
|
||||
// FailureGenericTimeoutError means we got some timer has expired.
|
||||
FailureGenericTimeoutError = "generic_timeout_error"
|
||||
|
||||
// FailureInterrupted means that the user interrupted us.
|
||||
FailureInterrupted = "interrupted"
|
||||
|
||||
// FailureNoCompatibleQUICVersion means that the server does not support the proposed QUIC version
|
||||
FailureNoCompatibleQUICVersion = "quic_incompatible_version"
|
||||
|
||||
// FailureSSLInvalidHostname means we got certificate is not valid for SNI.
|
||||
FailureSSLInvalidHostname = "ssl_invalid_hostname"
|
||||
|
||||
// FailureSSLUnknownAuthority means we cannot find CA validating certificate.
|
||||
FailureSSLUnknownAuthority = "ssl_unknown_authority"
|
||||
|
||||
// FailureSSLInvalidCertificate means certificate experired or other
|
||||
// sort of errors causing it to be invalid.
|
||||
FailureSSLInvalidCertificate = "ssl_invalid_certificate"
|
||||
|
||||
// FailureJSONParseError indicates that we couldn't parse a JSON
|
||||
FailureJSONParseError = "json_parse_error"
|
||||
)
|
||||
|
||||
const (
|
||||
// ResolveOperation is the operation where we resolve a domain name
|
||||
ResolveOperation = "resolve"
|
||||
|
||||
// ConnectOperation is the operation where we do a TCP connect
|
||||
ConnectOperation = "connect"
|
||||
|
||||
// TLSHandshakeOperation is the TLS handshake
|
||||
TLSHandshakeOperation = "tls_handshake"
|
||||
|
||||
// QUICHandshakeOperation is the handshake to setup a QUIC connection
|
||||
QUICHandshakeOperation = "quic_handshake"
|
||||
|
||||
// HTTPRoundTripOperation is the HTTP round trip
|
||||
HTTPRoundTripOperation = "http_round_trip"
|
||||
|
||||
// CloseOperation is when we close a socket
|
||||
CloseOperation = "close"
|
||||
|
||||
// ReadOperation is when we read from a socket
|
||||
ReadOperation = "read"
|
||||
|
||||
// WriteOperation is when we write to a socket
|
||||
WriteOperation = "write"
|
||||
|
||||
// ReadFromOperation is when we read from an UDP socket
|
||||
ReadFromOperation = "read_from"
|
||||
|
||||
// WriteToOperation is when we write to an UDP socket
|
||||
WriteToOperation = "write_to"
|
||||
|
||||
// UnknownOperation is when we cannot determine the operation
|
||||
UnknownOperation = "unknown"
|
||||
|
||||
// TopLevelOperation is used when the failure happens at top level. This
|
||||
// happens for example with urlgetter with a cancelled context.
|
||||
TopLevelOperation = "top_level"
|
||||
)
|
||||
|
||||
// ErrDNSBogon indicates that we found a bogon address. This is the
|
||||
// correct value with which to initialize MeasurementRoot.ErrDNSBogon
|
||||
// to tell this library to return an error when a bogon is found.
|
||||
var ErrDNSBogon = errors.New("dns: detected bogon address")
|
||||
|
||||
// ErrWrapper is our error wrapper for Go errors. The key objective of
|
||||
// this structure is to properly set Failure, which is also returned by
|
||||
// the Error() method, so be one of the OONI defined strings.
|
||||
type ErrWrapper struct {
|
||||
// ConnID is the connection ID, or zero if not known.
|
||||
ConnID int64
|
||||
|
||||
// DialID is the dial ID, or zero if not known.
|
||||
DialID int64
|
||||
|
||||
// Failure is the OONI failure string. The failure strings are
|
||||
// loosely backward compatible with Measurement Kit.
|
||||
//
|
||||
// This is either one of the FailureXXX strings or any other
|
||||
// string like `unknown_failure ...`. The latter represents an
|
||||
// error that we have not yet mapped to a failure.
|
||||
Failure string
|
||||
|
||||
// Operation is the operation that failed. If possible, it
|
||||
// SHOULD be a _major_ operation. Major operations are:
|
||||
//
|
||||
// - ResolveOperation: resolving a domain name failed
|
||||
// - ConnectOperation: connecting to an IP failed
|
||||
// - TLSHandshakeOperation: TLS handshaking failed
|
||||
// - HTTPRoundTripOperation: other errors during round trip
|
||||
//
|
||||
// Because a network connection doesn't necessarily know
|
||||
// what is the current major operation we also have the
|
||||
// following _minor_ operations:
|
||||
//
|
||||
// - CloseOperation: CLOSE failed
|
||||
// - ReadOperation: READ failed
|
||||
// - WriteOperation: WRITE failed
|
||||
//
|
||||
// If an ErrWrapper referring to a major operation is wrapping
|
||||
// another ErrWrapper and such ErrWrapper already refers to
|
||||
// a major operation, then the new ErrWrapper should use the
|
||||
// child ErrWrapper major operation. Otherwise, it should use
|
||||
// its own major operation. This way, the topmost wrapper is
|
||||
// supposed to refer to the major operation that failed.
|
||||
Operation string
|
||||
|
||||
// TransactionID is the transaction ID, or zero if not known.
|
||||
TransactionID int64
|
||||
|
||||
// WrappedErr is the error that we're wrapping.
|
||||
WrappedErr error
|
||||
}
|
||||
|
||||
// Error returns a description of the error that occurred.
|
||||
func (e *ErrWrapper) Error() string {
|
||||
return e.Failure
|
||||
}
|
||||
|
||||
// Unwrap allows to access the underlying error
|
||||
func (e *ErrWrapper) Unwrap() error {
|
||||
return e.WrappedErr
|
||||
}
|
||||
|
||||
// SafeErrWrapperBuilder contains a builder for ErrWrapper that
|
||||
// is safe, i.e., behaves correctly when the error is nil.
|
||||
type SafeErrWrapperBuilder struct {
|
||||
// ConnID is the connection ID, if any
|
||||
ConnID int64
|
||||
|
||||
// DialID is the dial ID, if any
|
||||
DialID int64
|
||||
|
||||
// Error is the error, if any
|
||||
Error error
|
||||
|
||||
// Operation is the operation that failed
|
||||
Operation string
|
||||
|
||||
// TransactionID is the transaction ID, if any
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// MaybeBuild builds a new ErrWrapper, if b.Error is not nil, and returns
|
||||
// a nil error value, instead, if b.Error is nil.
|
||||
func (b SafeErrWrapperBuilder) MaybeBuild() (err error) {
|
||||
if b.Error != nil {
|
||||
err = &ErrWrapper{
|
||||
ConnID: b.ConnID,
|
||||
DialID: b.DialID,
|
||||
Failure: toFailureString(b.Error),
|
||||
Operation: toOperationString(b.Error, b.Operation),
|
||||
TransactionID: b.TransactionID,
|
||||
WrappedErr: b.Error,
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func toFailureString(err error) string {
|
||||
// The list returned here matches the values used by MK unless
|
||||
// explicitly noted otherwise with a comment.
|
||||
|
||||
var errwrapper *ErrWrapper
|
||||
if errors.As(err, &errwrapper) {
|
||||
return errwrapper.Error() // we've already wrapped it
|
||||
}
|
||||
|
||||
if errors.Is(err, ErrDNSBogon) {
|
||||
return FailureDNSBogonError // not in MK
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return FailureInterrupted
|
||||
}
|
||||
var x509HostnameError x509.HostnameError
|
||||
if errors.As(err, &x509HostnameError) {
|
||||
// Test case: https://wrong.host.badssl.com/
|
||||
return FailureSSLInvalidHostname
|
||||
}
|
||||
var x509UnknownAuthorityError x509.UnknownAuthorityError
|
||||
if errors.As(err, &x509UnknownAuthorityError) {
|
||||
// Test case: https://self-signed.badssl.com/. This error has
|
||||
// never been among the ones returned by MK.
|
||||
return FailureSSLUnknownAuthority
|
||||
}
|
||||
var x509CertificateInvalidError x509.CertificateInvalidError
|
||||
if errors.As(err, &x509CertificateInvalidError) {
|
||||
// Test case: https://expired.badssl.com/
|
||||
return FailureSSLInvalidCertificate
|
||||
}
|
||||
|
||||
s := err.Error()
|
||||
if strings.HasSuffix(s, "operation was canceled") {
|
||||
return FailureInterrupted
|
||||
}
|
||||
if strings.HasSuffix(s, "EOF") {
|
||||
return FailureEOFError
|
||||
}
|
||||
if strings.HasSuffix(s, "connection refused") {
|
||||
return FailureConnectionRefused
|
||||
}
|
||||
if strings.HasSuffix(s, "connection reset by peer") {
|
||||
return FailureConnectionReset
|
||||
}
|
||||
if strings.HasSuffix(s, "context deadline exceeded") {
|
||||
return FailureGenericTimeoutError
|
||||
}
|
||||
if strings.HasSuffix(s, "transaction is timed out") {
|
||||
return FailureGenericTimeoutError
|
||||
}
|
||||
if strings.HasSuffix(s, "i/o timeout") {
|
||||
return FailureGenericTimeoutError
|
||||
}
|
||||
if strings.HasSuffix(s, "TLS handshake timeout") {
|
||||
return FailureGenericTimeoutError
|
||||
}
|
||||
if strings.HasSuffix(s, "no such host") {
|
||||
// This is dns_lookup_error in MK but such error is used as a
|
||||
// generic "hey, the lookup failed" error. Instead, this error
|
||||
// that we return here is significantly more specific.
|
||||
return FailureDNSNXDOMAINError
|
||||
}
|
||||
|
||||
// TODO(kelmenhorst): see whether it is possible to match errors
|
||||
// from qtls rather than strings for TLS errors below.
|
||||
//
|
||||
// TODO(kelmenhorst): make sure we have tests for all errors. Also,
|
||||
// how to ensure we are robust to changes in other libs?
|
||||
//
|
||||
// special QUIC errors
|
||||
matched, err := regexp.MatchString(`.*x509: certificate is valid for.*not.*`, s)
|
||||
if matched {
|
||||
return FailureSSLInvalidHostname
|
||||
}
|
||||
if strings.HasSuffix(s, "x509: certificate signed by unknown authority") {
|
||||
return FailureSSLUnknownAuthority
|
||||
}
|
||||
certInvalidErrors := []string{"x509: certificate is not authorized to sign other certificates", "x509: certificate has expired or is not yet valid:", "x509: a root or intermediate certificate is not authorized to sign for this name:", "x509: a root or intermediate certificate is not authorized for an extended key usage:", "x509: too many intermediates for path length constraint", "x509: certificate specifies an incompatible key usage", "x509: issuer name does not match subject from issuing certificate", "x509: issuer has name constraints but leaf doesn't have a SAN extension", "x509: issuer has name constraints but leaf contains unknown or unconstrained name:"}
|
||||
for _, errstr := range certInvalidErrors {
|
||||
if strings.Contains(s, errstr) {
|
||||
return FailureSSLInvalidCertificate
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(s, "No compatible QUIC version found") {
|
||||
return FailureNoCompatibleQUICVersion
|
||||
}
|
||||
if strings.HasSuffix(s, "Handshake did not complete in time") {
|
||||
return FailureGenericTimeoutError
|
||||
}
|
||||
if strings.HasSuffix(s, "connection_refused") {
|
||||
return FailureConnectionRefused
|
||||
}
|
||||
if strings.Contains(s, "stateless_reset") {
|
||||
return FailureConnectionReset
|
||||
}
|
||||
if strings.Contains(s, "deadline exceeded") {
|
||||
return FailureGenericTimeoutError
|
||||
}
|
||||
formatted := fmt.Sprintf("unknown_failure: %s", s)
|
||||
return Scrub(formatted) // scrub IP addresses in the error
|
||||
}
|
||||
|
||||
func toOperationString(err error, operation string) string {
|
||||
var errwrapper *ErrWrapper
|
||||
if errors.As(err, &errwrapper) {
|
||||
// Basically, as explained in ErrWrapper docs, let's
|
||||
// keep the child major operation, if any.
|
||||
if errwrapper.Operation == ConnectOperation {
|
||||
return errwrapper.Operation
|
||||
}
|
||||
if errwrapper.Operation == HTTPRoundTripOperation {
|
||||
return errwrapper.Operation
|
||||
}
|
||||
if errwrapper.Operation == ResolveOperation {
|
||||
return errwrapper.Operation
|
||||
}
|
||||
if errwrapper.Operation == TLSHandshakeOperation {
|
||||
return errwrapper.Operation
|
||||
}
|
||||
if errwrapper.Operation == QUICHandshakeOperation {
|
||||
return errwrapper.Operation
|
||||
}
|
||||
if errwrapper.Operation == "quic_handshake_start" {
|
||||
return QUICHandshakeOperation
|
||||
}
|
||||
if errwrapper.Operation == "quic_handshake_done" {
|
||||
return QUICHandshakeOperation
|
||||
}
|
||||
// FALLTHROUGH
|
||||
}
|
||||
return operation
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package errorx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/pion/stun"
|
||||
)
|
||||
|
||||
func TestMaybeBuildFactory(t *testing.T) {
|
||||
err := SafeErrWrapperBuilder{
|
||||
ConnID: 1,
|
||||
DialID: 10,
|
||||
Error: errors.New("mocked error"),
|
||||
TransactionID: 100,
|
||||
}.MaybeBuild()
|
||||
var target *ErrWrapper
|
||||
if errors.As(err, &target) == false {
|
||||
t.Fatal("not the expected error type")
|
||||
}
|
||||
if target.ConnID != 1 {
|
||||
t.Fatal("wrong ConnID")
|
||||
}
|
||||
if target.DialID != 10 {
|
||||
t.Fatal("wrong DialID")
|
||||
}
|
||||
if target.Failure != "unknown_failure: mocked error" {
|
||||
t.Fatal("the failure string is wrong")
|
||||
}
|
||||
if target.TransactionID != 100 {
|
||||
t.Fatal("the transactionID is wrong")
|
||||
}
|
||||
if target.WrappedErr.Error() != "mocked error" {
|
||||
t.Fatal("the wrapped error is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToFailureString(t *testing.T) {
|
||||
t.Run("for already wrapped error", func(t *testing.T) {
|
||||
err := SafeErrWrapperBuilder{Error: io.EOF}.MaybeBuild()
|
||||
if toFailureString(err) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for ErrDNSBogon", func(t *testing.T) {
|
||||
if toFailureString(ErrDNSBogon) != FailureDNSBogonError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for context.Canceled", func(t *testing.T) {
|
||||
if toFailureString(context.Canceled) != FailureInterrupted {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for x509.HostnameError", func(t *testing.T) {
|
||||
var err x509.HostnameError
|
||||
if toFailureString(err) != FailureSSLInvalidHostname {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
|
||||
var err x509.UnknownAuthorityError
|
||||
if toFailureString(err) != FailureSSLUnknownAuthority {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
|
||||
var err x509.CertificateInvalidError
|
||||
if toFailureString(err) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for operation was canceled error", func(t *testing.T) {
|
||||
if toFailureString(errors.New("operation was canceled")) != FailureInterrupted {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for EOF", func(t *testing.T) {
|
||||
if toFailureString(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for connection_refused", func(t *testing.T) {
|
||||
if toFailureString(syscall.ECONNREFUSED) != FailureConnectionRefused {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for connection_reset", func(t *testing.T) {
|
||||
if toFailureString(syscall.ECONNRESET) != FailureConnectionReset {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for context deadline exceeded", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1)
|
||||
defer cancel()
|
||||
<-ctx.Done()
|
||||
if toFailureString(ctx.Err()) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for stun's transaction is timed out", func(t *testing.T) {
|
||||
if toFailureString(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for i/o error", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1)
|
||||
defer cancel() // fail immediately
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", "www.google.com:80")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil connection here")
|
||||
}
|
||||
if toFailureString(err) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for TLS handshake timeout error", func(t *testing.T) {
|
||||
err := errors.New("net/http: TLS handshake timeout")
|
||||
if toFailureString(err) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for no such host", func(t *testing.T) {
|
||||
if toFailureString(&net.DNSError{
|
||||
Err: "no such host",
|
||||
}) != FailureDNSNXDOMAINError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for errors including IPv4 address", func(t *testing.T) {
|
||||
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
|
||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: use of closed network connection"
|
||||
out := toFailureString(input)
|
||||
if out != expected {
|
||||
t.Fatal(cmp.Diff(expected, out))
|
||||
}
|
||||
})
|
||||
t.Run("for errors including IPv6 address", func(t *testing.T) {
|
||||
input := errors.New("read tcp [::1]:56948->[::1]:443: use of closed network connection")
|
||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: use of closed network connection"
|
||||
out := toFailureString(input)
|
||||
if out != expected {
|
||||
t.Fatal(cmp.Diff(expected, out))
|
||||
}
|
||||
})
|
||||
// QUIC failures
|
||||
t.Run("for connection_refused", func(t *testing.T) {
|
||||
if toFailureString(errors.New("connection_refused")) != FailureConnectionRefused {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for connection_reset", func(t *testing.T) {
|
||||
if toFailureString(errors.New("stateless_reset")) != FailureConnectionReset {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for incompatible quic version", func(t *testing.T) {
|
||||
if toFailureString(errors.New("No compatible QUIC version found")) != FailureNoCompatibleQUICVersion {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for i/o error", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1)
|
||||
defer cancel() // fail immediately
|
||||
udpAddr := &net.UDPAddr{IP: net.ParseIP("216.58.212.164"), Port: 80, Zone: ""}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
sess, err := quic.DialEarlyContext(ctx, udpConn, udpAddr, "google.com:80", &tls.Config{}, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil session here")
|
||||
}
|
||||
if toFailureString(err) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
t.Run("for QUIC handshake timeout error", func(t *testing.T) {
|
||||
err := errors.New("Handshake did not complete in time")
|
||||
if toFailureString(err) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestToOperationString(t *testing.T) {
|
||||
t.Run("for connect", func(t *testing.T) {
|
||||
// You're doing HTTP and connect fails. You want to know
|
||||
// that connect failed not that HTTP failed.
|
||||
err := &ErrWrapper{Operation: ConnectOperation}
|
||||
if toOperationString(err, HTTPRoundTripOperation) != ConnectOperation {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for http_round_trip", func(t *testing.T) {
|
||||
// You're doing DoH and something fails inside HTTP. You want
|
||||
// to know about the internal HTTP error, not resolve.
|
||||
err := &ErrWrapper{Operation: HTTPRoundTripOperation}
|
||||
if toOperationString(err, ResolveOperation) != HTTPRoundTripOperation {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for resolve", func(t *testing.T) {
|
||||
// You're doing HTTP and the DNS fails. You want to
|
||||
// know that resolve failed.
|
||||
err := &ErrWrapper{Operation: ResolveOperation}
|
||||
if toOperationString(err, HTTPRoundTripOperation) != ResolveOperation {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for tls_handshake", func(t *testing.T) {
|
||||
// You're doing HTTP and the TLS handshake fails. You want
|
||||
// to know about a TLS handshake error.
|
||||
err := &ErrWrapper{Operation: TLSHandshakeOperation}
|
||||
if toOperationString(err, HTTPRoundTripOperation) != TLSHandshakeOperation {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for minor operation", func(t *testing.T) {
|
||||
// You just noticed that TLS handshake failed and you
|
||||
// have a child error telling you that read failed. Here
|
||||
// you want to know about a TLS handshake error.
|
||||
err := &ErrWrapper{Operation: ReadOperation}
|
||||
if toOperationString(err, TLSHandshakeOperation) != TLSHandshakeOperation {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
t.Run("for quic_handshake", func(t *testing.T) {
|
||||
// You're doing HTTP and the TLS handshake fails. You want
|
||||
// to know about a TLS handshake error.
|
||||
err := &ErrWrapper{Operation: QUICHandshakeOperation}
|
||||
if toOperationString(err, HTTPRoundTripOperation) != QUICHandshakeOperation {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package errorx
|
||||
|
||||
import "regexp"
|
||||
|
||||
// The code in this file is adapted from github.com/keroserene/snowflake's
|
||||
// common/safelog/safelog.go implementation <https://git.io/JfO9w>.
|
||||
//
|
||||
// ================================================================================
|
||||
// Copyright (c) 2016, Serene Han, Arlo Breault
|
||||
// Copyright (c) 2019-2020, The Tor Project, Inc
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright notice, this
|
||||
// list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistributions in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation and/or
|
||||
// other materials provided with the distribution.
|
||||
//
|
||||
// * Neither the names of the copyright owners nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from this
|
||||
// software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// ================================================================================
|
||||
|
||||
const ipv4Address = `\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}`
|
||||
const ipv6Address = `([0-9a-fA-F]{0,4}:){5,7}([0-9a-fA-F]{0,4})?`
|
||||
const ipv6Compressed = `([0-9a-fA-F]{0,4}:){0,5}([0-9a-fA-F]{0,4})?(::)([0-9a-fA-F]{0,4}:){0,5}([0-9a-fA-F]{0,4})?`
|
||||
const ipv6Full = `(` + ipv6Address + `(` + ipv4Address + `))` +
|
||||
`|(` + ipv6Compressed + `(` + ipv4Address + `))` +
|
||||
`|(` + ipv6Address + `)` + `|(` + ipv6Compressed + `)`
|
||||
const optionalPort = `(:\d{1,5})?`
|
||||
const addressPattern = `((` + ipv4Address + `)|(\[(` + ipv6Full + `)\])|(` + ipv6Full + `))` + optionalPort
|
||||
const fullAddrPattern = `(^|\s|[^\w:])` + addressPattern + `(\s|(:\s)|[^\w:]|$)`
|
||||
|
||||
var scrubberPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(fullAddrPattern),
|
||||
}
|
||||
|
||||
var addressRegexp = regexp.MustCompile(addressPattern)
|
||||
|
||||
func scrub(b []byte) []byte {
|
||||
scrubbedBytes := b
|
||||
for _, pattern := range scrubberPatterns {
|
||||
// this is a workaround since go does not yet support look ahead or look
|
||||
// behind for regular expressions.
|
||||
scrubbedBytes = pattern.ReplaceAllFunc(scrubbedBytes, func(b []byte) []byte {
|
||||
return addressRegexp.ReplaceAll(b, []byte("[scrubbed]"))
|
||||
})
|
||||
}
|
||||
return scrubbedBytes
|
||||
}
|
||||
|
||||
// Scrub sanitizes a string containing an error such that
|
||||
// any occurrence of IP endpoints is scrubbed
|
||||
func Scrub(s string) string {
|
||||
return string(scrub([]byte(s)))
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package errorx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// The code in this file is adapted from github.com/keroserene/snowflake's
|
||||
// common/safelog/safelog.go implementation <https://git.io/JfO9w>.
|
||||
//
|
||||
// ================================================================================
|
||||
// Copyright (c) 2016, Serene Han, Arlo Breault
|
||||
// Copyright (c) 2019-2020, The Tor Project, Inc
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright notice, this
|
||||
// list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistributions in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation and/or
|
||||
// other materials provided with the distribution.
|
||||
//
|
||||
// * Neither the names of the copyright owners nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from this
|
||||
// software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
// ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// ================================================================================
|
||||
|
||||
//Test the log scrubber on known problematic log messages
|
||||
func TestLogScrubberMessages(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
input, expected string
|
||||
}{
|
||||
{
|
||||
"http: TLS handshake error from 129.97.208.23:38310: ",
|
||||
"http: TLS handshake error from [scrubbed]: ",
|
||||
},
|
||||
{
|
||||
"http2: panic serving [2620:101:f000:780:9097:75b1:519f:dbb8]:58344: interface conversion: *http2.responseWriter is not http.Hijacker: missing method Hijack",
|
||||
"http2: panic serving [scrubbed]: interface conversion: *http2.responseWriter is not http.Hijacker: missing method Hijack",
|
||||
},
|
||||
{
|
||||
//Make sure it doesn't scrub fingerprint
|
||||
"a=fingerprint:sha-256 33:B6:FA:F6:94:CA:74:61:45:4A:D2:1F:2C:2F:75:8A:D9:EB:23:34:B2:30:E9:1B:2A:A6:A9:E0:44:72:CC:74",
|
||||
"a=fingerprint:sha-256 33:B6:FA:F6:94:CA:74:61:45:4A:D2:1F:2C:2F:75:8A:D9:EB:23:34:B2:30:E9:1B:2A:A6:A9:E0:44:72:CC:74",
|
||||
},
|
||||
{
|
||||
//try with enclosing parens
|
||||
"(1:2:3:4:c:d:e:f) {1:2:3:4:c:d:e:f}",
|
||||
"([scrubbed]) {[scrubbed]}",
|
||||
},
|
||||
{
|
||||
//Make sure it doesn't scrub timestamps
|
||||
"2019/05/08 15:37:31 starting",
|
||||
"2019/05/08 15:37:31 starting",
|
||||
},
|
||||
} {
|
||||
if Scrub(test.input) != test.expected {
|
||||
t.Error(cmp.Diff(test.input, test.expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogScrubberGoodFormats(t *testing.T) {
|
||||
for _, addr := range []string{
|
||||
// IPv4
|
||||
"1.2.3.4",
|
||||
"255.255.255.255",
|
||||
// IPv4 with port
|
||||
"1.2.3.4:55",
|
||||
"255.255.255.255:65535",
|
||||
// IPv6
|
||||
"1:2:3:4:c:d:e:f",
|
||||
"1111:2222:3333:4444:CCCC:DDDD:EEEE:FFFF",
|
||||
// IPv6 with brackets
|
||||
"[1:2:3:4:c:d:e:f]",
|
||||
"[1111:2222:3333:4444:CCCC:DDDD:EEEE:FFFF]",
|
||||
// IPv6 with brackets and port
|
||||
"[1:2:3:4:c:d:e:f]:55",
|
||||
"[1111:2222:3333:4444:CCCC:DDDD:EEEE:FFFF]:65535",
|
||||
// compressed IPv6
|
||||
"::f",
|
||||
"::d:e:f",
|
||||
"1:2:3::",
|
||||
"1:2:3::d:e:f",
|
||||
"1:2:3:d:e:f::",
|
||||
"::1:2:3:d:e:f",
|
||||
"1111:2222:3333::DDDD:EEEE:FFFF",
|
||||
// compressed IPv6 with brackets
|
||||
"[::d:e:f]",
|
||||
"[1:2:3::]",
|
||||
"[1:2:3::d:e:f]",
|
||||
"[1111:2222:3333::DDDD:EEEE:FFFF]",
|
||||
"[1:2:3:4:5:6::8]",
|
||||
"[1::7:8]",
|
||||
// compressed IPv6 with brackets and port
|
||||
"[1::]:58344",
|
||||
"[::d:e:f]:55",
|
||||
"[1:2:3::]:55",
|
||||
"[1:2:3::d:e:f]:55",
|
||||
"[1111:2222:3333::DDDD:EEEE:FFFF]:65535",
|
||||
// IPv4-compatible and IPv4-mapped
|
||||
"::255.255.255.255",
|
||||
"::ffff:255.255.255.255",
|
||||
"[::255.255.255.255]",
|
||||
"[::ffff:255.255.255.255]",
|
||||
"[::255.255.255.255]:65535",
|
||||
"[::ffff:255.255.255.255]:65535",
|
||||
"[::ffff:0:255.255.255.255]",
|
||||
"[2001:db8:3:4::192.0.2.33]",
|
||||
} {
|
||||
if Scrub(addr) != "[scrubbed]" {
|
||||
t.Error(cmp.Diff(addr, "[scrubbed]"))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package netx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FakeDialer struct {
|
||||
Conn net.Conn
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return d.Conn, d.Err
|
||||
}
|
||||
|
||||
type FakeTransport struct {
|
||||
Err error
|
||||
Func func(*http.Request) (*http.Response, error)
|
||||
Resp *http.Response
|
||||
}
|
||||
|
||||
func (txp FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if txp.Func != nil {
|
||||
return txp.Func(req)
|
||||
}
|
||||
if req.Body != nil {
|
||||
ioutil.ReadAll(req.Body)
|
||||
req.Body.Close()
|
||||
}
|
||||
if txp.Err != nil {
|
||||
return nil, txp.Err
|
||||
}
|
||||
txp.Resp.Request = req // non thread safe but it doesn't matter
|
||||
return txp.Resp, nil
|
||||
}
|
||||
|
||||
func (txp FakeTransport) CloseIdleConnections() {}
|
||||
|
||||
type FakeBody struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (fb FakeBody) Read(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return 0, fb.Err
|
||||
}
|
||||
|
||||
func (fb FakeBody) Close() error {
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,95 @@
|
||||
// +build ignore
|
||||
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||
//
|
||||
// Forked from github.com/certifi/gocertifi <https://git.io/JJjmG>.
|
||||
//
|
||||
// This script should not be invoked directly, rather it should be
|
||||
// executed by running go generate ./... from toplevel dir.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
)
|
||||
|
||||
var tmpl = template.Must(template.New("").Parse(`// Code generated by go generate; DO NOT EDIT.
|
||||
// {{ .Timestamp }}
|
||||
// {{ .URL }}
|
||||
|
||||
package gocertifi
|
||||
|
||||
//go:generate go run generate.go "{{ .URL }}"
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
const pemcerts string = ` + "`" + `
|
||||
{{ .Bundle }}
|
||||
` + "`" + `
|
||||
|
||||
// CACerts builds an X.509 certificate pool containing the
|
||||
// certificate bundle from {{ .URL }} fetch on {{ .Timestamp }}.
|
||||
// Returns nil on error along with an appropriate error code.
|
||||
func CACerts() (*x509.CertPool, error) {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AppendCertsFromPEM([]byte(pemcerts))
|
||||
return pool, nil
|
||||
}
|
||||
`))
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 2 || !strings.HasPrefix(os.Args[1], "https://") {
|
||||
log.Fatal("usage: go run generate.go <url>")
|
||||
}
|
||||
url := os.Args[1]
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
log.Fatal("expected 200, got", resp.StatusCode)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bundle, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(bundle) {
|
||||
log.Fatalf("can't parse certificates from %s", url)
|
||||
}
|
||||
|
||||
fp, err := os.Create("certifi.go")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = tmpl.Execute(fp, struct {
|
||||
Timestamp time.Time
|
||||
URL string
|
||||
Bundle string
|
||||
}{
|
||||
Timestamp: time.Now(),
|
||||
URL: url,
|
||||
Bundle: string(bundle),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := fp.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package httptransport
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
)
|
||||
|
||||
// ByteCountingTransport is a RoundTripper that counts bytes.
|
||||
type ByteCountingTransport struct {
|
||||
RoundTripper
|
||||
Counter *bytecounter.Counter
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp ByteCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
req.Body = byteCountingBody{
|
||||
ReadCloser: req.Body, Account: txp.Counter.CountBytesSent}
|
||||
}
|
||||
txp.estimateRequestMetadata(req)
|
||||
resp, err := txp.RoundTripper.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
txp.estimateResponseMetadata(resp)
|
||||
resp.Body = byteCountingBody{
|
||||
ReadCloser: resp.Body, Account: txp.Counter.CountBytesReceived}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (txp ByteCountingTransport) estimateRequestMetadata(req *http.Request) {
|
||||
txp.Counter.CountBytesSent(len(req.Method))
|
||||
txp.Counter.CountBytesSent(len(req.URL.String()))
|
||||
for key, values := range req.Header {
|
||||
for _, value := range values {
|
||||
txp.Counter.CountBytesSent(len(key))
|
||||
txp.Counter.CountBytesSent(len(": "))
|
||||
txp.Counter.CountBytesSent(len(value))
|
||||
txp.Counter.CountBytesSent(len("\r\n"))
|
||||
}
|
||||
}
|
||||
txp.Counter.CountBytesSent(len("\r\n"))
|
||||
}
|
||||
|
||||
func (txp ByteCountingTransport) estimateResponseMetadata(resp *http.Response) {
|
||||
txp.Counter.CountBytesReceived(len(resp.Status))
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
txp.Counter.CountBytesReceived(len(key))
|
||||
txp.Counter.CountBytesReceived(len(": "))
|
||||
txp.Counter.CountBytesReceived(len(value))
|
||||
txp.Counter.CountBytesReceived(len("\r\n"))
|
||||
}
|
||||
}
|
||||
txp.Counter.CountBytesReceived(len("\r\n"))
|
||||
}
|
||||
|
||||
type byteCountingBody struct {
|
||||
io.ReadCloser
|
||||
Account func(int)
|
||||
}
|
||||
|
||||
func (r byteCountingBody) Read(p []byte) (int, error) {
|
||||
count, err := r.ReadCloser.Read(p)
|
||||
if count > 0 {
|
||||
r.Account(count)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
var _ RoundTripper = ByteCountingTransport{}
|
||||
@@ -0,0 +1,128 @@
|
||||
package httptransport_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
|
||||
)
|
||||
|
||||
func TestByteCounterFailure(t *testing.T) {
|
||||
counter := bytecounter.New()
|
||||
txp := httptransport.ByteCountingTransport{
|
||||
Counter: counter,
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Err: io.EOF,
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: txp}
|
||||
req, err := http.NewRequest(
|
||||
"POST", "https://www.google.com", strings.NewReader("AAAAAA"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "antani-browser/1.0.0")
|
||||
resp, err := client.Do(req)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
if counter.Sent.Load() != 68 {
|
||||
t.Fatal("expected around 68 bytes sent")
|
||||
}
|
||||
if counter.Received.Load() != 0 {
|
||||
t.Fatal("expected zero bytes received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterSuccess(t *testing.T) {
|
||||
counter := bytecounter.New()
|
||||
txp := httptransport.ByteCountingTransport{
|
||||
Counter: counter,
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Resp: &http.Response{
|
||||
Body: ioutil.NopCloser(strings.NewReader("1234567")),
|
||||
Header: http.Header{
|
||||
"Server": []string{"antani/0.1.0"},
|
||||
},
|
||||
Status: "200 OK",
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: txp}
|
||||
req, err := http.NewRequest(
|
||||
"POST", "https://www.google.com", strings.NewReader("AAAAAA"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("User-Agent", "antani-browser/1.0.0")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if string(data) != "1234567" {
|
||||
t.Fatal("expected a different body here")
|
||||
}
|
||||
if counter.Sent.Load() != 68 {
|
||||
t.Fatal("expected around 68 bytes sent")
|
||||
}
|
||||
if counter.Received.Load() != 37 {
|
||||
t.Fatal("expected zero around 37 bytes received")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterSuccessWithEOF(t *testing.T) {
|
||||
counter := bytecounter.New()
|
||||
txp := httptransport.ByteCountingTransport{
|
||||
Counter: counter,
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Resp: &http.Response{
|
||||
Body: bodyReaderWithEOF{},
|
||||
Header: http.Header{
|
||||
"Server": []string{"antani/0.1.0"},
|
||||
},
|
||||
Status: "200 OK",
|
||||
StatusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if string(data) != "A" {
|
||||
t.Fatal("expected a different body here")
|
||||
}
|
||||
}
|
||||
|
||||
type bodyReaderWithEOF struct{}
|
||||
|
||||
func (bodyReaderWithEOF) Read(p []byte) (int, error) {
|
||||
if len(p) < 1 {
|
||||
panic("should not happen")
|
||||
}
|
||||
p[0] = 'A'
|
||||
return 1, io.EOF // we want code to be robust to this
|
||||
}
|
||||
func (bodyReaderWithEOF) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package httptransport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FakeDialer struct {
|
||||
Conn net.Conn
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return d.Conn, d.Err
|
||||
}
|
||||
|
||||
type FakeTransport struct {
|
||||
Err error
|
||||
Func func(*http.Request) (*http.Response, error)
|
||||
Resp *http.Response
|
||||
}
|
||||
|
||||
func (txp FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if txp.Func != nil {
|
||||
return txp.Func(req)
|
||||
}
|
||||
if req.Body != nil {
|
||||
ioutil.ReadAll(req.Body)
|
||||
req.Body.Close()
|
||||
}
|
||||
if txp.Err != nil {
|
||||
return nil, txp.Err
|
||||
}
|
||||
txp.Resp.Request = req // non thread safe but it doesn't matter
|
||||
return txp.Resp, nil
|
||||
}
|
||||
|
||||
func (txp FakeTransport) CloseIdleConnections() {}
|
||||
|
||||
type FakeBody struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (fb FakeBody) Read(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return 0, fb.Err
|
||||
}
|
||||
|
||||
func (fb FakeBody) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package httptransport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/http3"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
|
||||
)
|
||||
|
||||
// QUICWrapperDialer is a QUICDialer that wraps a ContextDialer
|
||||
// This is necessary because the http3 RoundTripper does not support a DialContext method.
|
||||
type QUICWrapperDialer struct {
|
||||
Dialer quicdialer.ContextDialer
|
||||
}
|
||||
|
||||
// Dial implements QUICDialer.Dial
|
||||
func (d QUICWrapperDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
return d.Dialer.DialContext(context.Background(), network, host, tlsCfg, cfg)
|
||||
}
|
||||
|
||||
// HTTP3Transport is a httptransport.RoundTripper using the http3 protocol.
|
||||
type HTTP3Transport struct {
|
||||
http3.RoundTripper
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes all the connections opened by this transport.
|
||||
func (t *HTTP3Transport) CloseIdleConnections() {
|
||||
t.RoundTripper.Close()
|
||||
}
|
||||
|
||||
// NewHTTP3Transport creates a new HTTP3Transport instance.
|
||||
func NewHTTP3Transport(config Config) RoundTripper {
|
||||
txp := &HTTP3Transport{}
|
||||
txp.QuicConfig = &quic.Config{}
|
||||
txp.TLSClientConfig = config.TLSConfig
|
||||
txp.Dial = config.QUICDialer.Dial
|
||||
return txp
|
||||
}
|
||||
|
||||
var _ RoundTripper = &http.Transport{}
|
||||
@@ -0,0 +1,157 @@
|
||||
package httptransport_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor"
|
||||
)
|
||||
|
||||
type MockQUICDialer struct{}
|
||||
|
||||
func (d MockQUICDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
return quic.DialAddrEarly(host, tlsCfg, cfg)
|
||||
}
|
||||
|
||||
type MockSNIQUICDialer struct {
|
||||
namech chan string
|
||||
}
|
||||
|
||||
func (d MockSNIQUICDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
d.namech <- tlsCfg.ServerName
|
||||
return quic.DialAddrEarly(host, tlsCfg, cfg)
|
||||
}
|
||||
|
||||
type MockCertQUICDialer struct {
|
||||
certch chan *x509.CertPool
|
||||
}
|
||||
|
||||
func (d MockCertQUICDialer) Dial(network, host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
d.certch <- tlsCfg.RootCAs
|
||||
return quic.DialAddrEarly(host, tlsCfg, cfg)
|
||||
}
|
||||
|
||||
func TestHTTP3TransportSNI(t *testing.T) {
|
||||
namech := make(chan string, 1)
|
||||
sni := "sni.org"
|
||||
txp := httptransport.NewHTTP3Transport(httptransport.Config{
|
||||
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockSNIQUICDialer{namech: namech}, TLSConfig: &tls.Config{ServerName: sni}})
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "certificate is valid for www.google.com, not "+sni) {
|
||||
t.Fatal("unexpected error type", err)
|
||||
}
|
||||
servername := <-namech
|
||||
if servername != sni {
|
||||
t.Fatal("unexpected server name", servername)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP3TransportSNINoVerify(t *testing.T) {
|
||||
namech := make(chan string, 1)
|
||||
sni := "sni.org"
|
||||
txp := httptransport.NewHTTP3Transport(httptransport.Config{
|
||||
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockSNIQUICDialer{namech: namech}, TLSConfig: &tls.Config{ServerName: sni, InsecureSkipVerify: true}})
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %+v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("unexpected nil resp")
|
||||
}
|
||||
servername := <-namech
|
||||
if servername != sni {
|
||||
t.Fatal("unexpected server name", servername)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP3TransportCABundle(t *testing.T) {
|
||||
certch := make(chan *x509.CertPool, 1)
|
||||
certpool := x509.NewCertPool()
|
||||
txp := httptransport.NewHTTP3Transport(httptransport.Config{
|
||||
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockCertQUICDialer{certch: certch}, TLSConfig: &tls.Config{RootCAs: certpool}})
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
// since the certificate pool is empty, the unknown authority error should be thrown
|
||||
if !strings.Contains(err.Error(), "certificate signed by unknown authority") {
|
||||
t.Fatal("unexpected error type")
|
||||
}
|
||||
certs := <-certch
|
||||
if certs != certpool {
|
||||
t.Fatal("not the certpool we expected")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestUnitHTTP3TransportSuccess(t *testing.T) {
|
||||
txp := httptransport.NewHTTP3Transport(httptransport.Config{
|
||||
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockQUICDialer{}})
|
||||
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("unexpected nil response here")
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatal("HTTP statuscode should be 200 OK", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnitHTTP3TransportFailure(t *testing.T) {
|
||||
txp := httptransport.NewHTTP3Transport(httptransport.Config{
|
||||
Dialer: selfcensor.SystemDialer{}, QUICDialer: MockQUICDialer{}})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // so that the request immediately fails
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error here")
|
||||
}
|
||||
// context.Canceled error occurs if the test host supports QUIC
|
||||
// timeout error ("Handshake did not complete in time") occurs if the test host does not support QUIC
|
||||
if !(errors.Is(err, context.Canceled) || strings.HasSuffix(err.Error(), "Handshake did not complete in time")) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
// Package httptransport contains HTTP transport extensions.
|
||||
package httptransport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// Config contains the configuration required for constructing an HTTP transport
|
||||
type Config struct {
|
||||
Dialer Dialer
|
||||
QUICDialer QUICDialer
|
||||
TLSDialer TLSDialer
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// Dialer is the definition of dialer assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// TLSDialer is the definition of a TLS dialer assumed by this package.
|
||||
type TLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// QUICDialer is the definition of dialer for QUIC assumed by this package.
|
||||
type QUICDialer interface {
|
||||
Dial(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||
}
|
||||
|
||||
// RoundTripper is the definition of http.RoundTripper used by this package.
|
||||
type RoundTripper interface {
|
||||
RoundTrip(req *http.Request) (*http.Response, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// Resolver is the interface we expect from a resolver
|
||||
type Resolver interface {
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
Network() string
|
||||
Address() string
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package httptransport
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Logger is the logger assumed by this package
|
||||
type Logger interface {
|
||||
Debugf(format string, v ...interface{})
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
// LoggingTransport is a logging transport
|
||||
type LoggingTransport struct {
|
||||
RoundTripper
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp LoggingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
req.Header.Set("Host", host) // anticipate what Go would do
|
||||
return txp.logTrip(req)
|
||||
}
|
||||
|
||||
func (txp LoggingTransport) logTrip(req *http.Request) (*http.Response, error) {
|
||||
txp.Logger.Debugf("> %s %s", req.Method, req.URL.String())
|
||||
for key, values := range req.Header {
|
||||
for _, value := range values {
|
||||
txp.Logger.Debugf("> %s: %s", key, value)
|
||||
}
|
||||
}
|
||||
txp.Logger.Debug(">")
|
||||
resp, err := txp.RoundTripper.RoundTrip(req)
|
||||
if err != nil {
|
||||
txp.Logger.Debugf("< %s", err)
|
||||
return nil, err
|
||||
}
|
||||
txp.Logger.Debugf("< %d", resp.StatusCode)
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
txp.Logger.Debugf("< %s: %s", key, value)
|
||||
}
|
||||
}
|
||||
txp.Logger.Debug("<")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var _ RoundTripper = LoggingTransport{}
|
||||
@@ -0,0 +1,77 @@
|
||||
package httptransport_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
|
||||
)
|
||||
|
||||
func TestLoggingFailure(t *testing.T) {
|
||||
txp := httptransport.LoggingTransport{
|
||||
Logger: log.Log,
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Err: io.EOF,
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggingFailureWithNoHostHeader(t *testing.T) {
|
||||
txp := httptransport.LoggingTransport{
|
||||
Logger: log.Log,
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Err: io.EOF,
|
||||
},
|
||||
}
|
||||
req := &http.Request{
|
||||
Header: http.Header{},
|
||||
URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.google.com",
|
||||
Path: "/",
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggingSuccess(t *testing.T) {
|
||||
txp := httptransport.LoggingTransport{
|
||||
Logger: log.Log,
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Resp: &http.Response{
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
Header: http.Header{
|
||||
"Server": []string{"antani/0.1.0"},
|
||||
},
|
||||
StatusCode: 200,
|
||||
},
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package httptransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// SaverPerformanceHTTPTransport is a RoundTripper that saves
|
||||
// performance events occurring during the round trip
|
||||
type SaverPerformanceHTTPTransport struct {
|
||||
RoundTripper
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverPerformanceHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
tracep := httptrace.ContextClientTrace(req.Context())
|
||||
if tracep == nil {
|
||||
tracep = &httptrace.ClientTrace{
|
||||
WroteHeaders: func() {
|
||||
txp.Saver.Write(trace.Event{Name: "http_wrote_headers", Time: time.Now()})
|
||||
},
|
||||
WroteRequest: func(httptrace.WroteRequestInfo) {
|
||||
txp.Saver.Write(trace.Event{Name: "http_wrote_request", Time: time.Now()})
|
||||
},
|
||||
GotFirstResponseByte: func() {
|
||||
txp.Saver.Write(trace.Event{
|
||||
Name: "http_first_response_byte", Time: time.Now()})
|
||||
},
|
||||
}
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), tracep))
|
||||
}
|
||||
return txp.RoundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
// SaverMetadataHTTPTransport is a RoundTripper that saves
|
||||
// events related to HTTP request and response metadata
|
||||
type SaverMetadataHTTPTransport struct {
|
||||
RoundTripper
|
||||
Saver *trace.Saver
|
||||
Transport string
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverMetadataHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
txp.Saver.Write(trace.Event{
|
||||
HTTPHeaders: req.Header,
|
||||
HTTPMethod: req.Method,
|
||||
HTTPURL: req.URL.String(),
|
||||
Transport: txp.Transport,
|
||||
Name: "http_request_metadata",
|
||||
Time: time.Now(),
|
||||
})
|
||||
resp, err := txp.RoundTripper.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
txp.Saver.Write(trace.Event{
|
||||
HTTPHeaders: resp.Header,
|
||||
HTTPStatusCode: resp.StatusCode,
|
||||
Name: "http_response_metadata",
|
||||
Time: time.Now(),
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// SaverTransactionHTTPTransport is a RoundTripper that saves
|
||||
// events related to the HTTP transaction
|
||||
type SaverTransactionHTTPTransport struct {
|
||||
RoundTripper
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverTransactionHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
txp.Saver.Write(trace.Event{
|
||||
Name: "http_transaction_start",
|
||||
Time: time.Now(),
|
||||
})
|
||||
resp, err := txp.RoundTripper.RoundTrip(req)
|
||||
txp.Saver.Write(trace.Event{
|
||||
Err: err,
|
||||
Name: "http_transaction_done",
|
||||
Time: time.Now(),
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// SaverBodyHTTPTransport is a RoundTripper that saves
|
||||
// body events occurring during the round trip
|
||||
type SaverBodyHTTPTransport struct {
|
||||
RoundTripper
|
||||
Saver *trace.Saver
|
||||
SnapshotSize int
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverBodyHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
const defaultSnapSize = 1 << 17
|
||||
snapsize := defaultSnapSize
|
||||
if txp.SnapshotSize != 0 {
|
||||
snapsize = txp.SnapshotSize
|
||||
}
|
||||
if req.Body != nil {
|
||||
data, err := saverSnapRead(req.Body, snapsize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Body = saverCompose(data, req.Body)
|
||||
txp.Saver.Write(trace.Event{
|
||||
DataIsTruncated: len(data) >= snapsize,
|
||||
Data: data,
|
||||
Name: "http_request_body_snapshot",
|
||||
Time: time.Now(),
|
||||
})
|
||||
}
|
||||
resp, err := txp.RoundTripper.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := saverSnapRead(resp.Body, snapsize)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return nil, err
|
||||
}
|
||||
resp.Body = saverCompose(data, resp.Body)
|
||||
txp.Saver.Write(trace.Event{
|
||||
DataIsTruncated: len(data) >= snapsize,
|
||||
Data: data,
|
||||
Name: "http_response_body_snapshot",
|
||||
Time: time.Now(),
|
||||
})
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func saverSnapRead(r io.ReadCloser, snapsize int) ([]byte, error) {
|
||||
return ioutil.ReadAll(io.LimitReader(r, int64(snapsize)))
|
||||
}
|
||||
|
||||
func saverCompose(data []byte, r io.ReadCloser) io.ReadCloser {
|
||||
return saverReadCloser{Closer: r, Reader: io.MultiReader(bytes.NewReader(data), r)}
|
||||
}
|
||||
|
||||
type saverReadCloser struct {
|
||||
io.Closer
|
||||
io.Reader
|
||||
}
|
||||
|
||||
var _ RoundTripper = SaverPerformanceHTTPTransport{}
|
||||
var _ RoundTripper = SaverMetadataHTTPTransport{}
|
||||
var _ RoundTripper = SaverBodyHTTPTransport{}
|
||||
var _ RoundTripper = SaverTransactionHTTPTransport{}
|
||||
@@ -0,0 +1,429 @@
|
||||
package httptransport_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSaverPerformanceNoMultipleEvents(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
// register twice - do we see events twice?
|
||||
txp := httptransport.SaverPerformanceHTTPTransport{
|
||||
RoundTripper: http.DefaultTransport.(*http.Transport),
|
||||
Saver: saver,
|
||||
}
|
||||
txp = httptransport.SaverPerformanceHTTPTransport{
|
||||
RoundTripper: txp,
|
||||
Saver: saver,
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non nil response here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
// we should specifically see the events not attached to any
|
||||
// context being submitted twice. This is fine because they are
|
||||
// explicit, while the context is implicit and hence leads to
|
||||
// more subtle bugs. For example, this happens when you measure
|
||||
// every event and combine HTTP with DoH.
|
||||
if len(ev) != 3 {
|
||||
t.Fatal("expected three events")
|
||||
}
|
||||
expected := []string{
|
||||
"http_wrote_headers", // measured with context
|
||||
"http_wrote_request", // measured with context
|
||||
"http_first_response_byte", // measured with context
|
||||
}
|
||||
for i := 0; i < len(expected); i++ {
|
||||
if ev[i].Name != expected[i] {
|
||||
t.Fatal("unexpected event name")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverMetadataSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
txp := httptransport.SaverMetadataHTTPTransport{
|
||||
RoundTripper: http.DefaultTransport.(*http.Transport),
|
||||
Saver: saver,
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add("User-Agent", "miniooni/0.1.0-dev")
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non nil response here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected two events")
|
||||
}
|
||||
//
|
||||
if ev[0].HTTPMethod != "GET" {
|
||||
t.Fatal("unexpected Method")
|
||||
}
|
||||
if len(ev[0].HTTPHeaders) <= 0 {
|
||||
t.Fatal("unexpected Headers")
|
||||
}
|
||||
if ev[0].HTTPURL != "https://www.google.com" {
|
||||
t.Fatal("unexpected URL")
|
||||
}
|
||||
if ev[0].Name != "http_request_metadata" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
//
|
||||
if ev[1].HTTPStatusCode != 200 {
|
||||
t.Fatal("unexpected StatusCode")
|
||||
}
|
||||
if len(ev[1].HTTPHeaders) <= 0 {
|
||||
t.Fatal("unexpected Headers")
|
||||
}
|
||||
if ev[1].Name != "http_response_metadata" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverMetadataFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
saver := &trace.Saver{}
|
||||
txp := httptransport.SaverMetadataHTTPTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
req, err := http.NewRequest("GET", "http://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Add("User-Agent", "miniooni/0.1.0-dev")
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 1 {
|
||||
t.Fatal("expected one event")
|
||||
}
|
||||
if ev[0].HTTPMethod != "GET" {
|
||||
t.Fatal("unexpected Method")
|
||||
}
|
||||
if len(ev[0].HTTPHeaders) <= 0 {
|
||||
t.Fatal("unexpected Headers")
|
||||
}
|
||||
if ev[0].HTTPURL != "http://www.google.com" {
|
||||
t.Fatal("unexpected URL")
|
||||
}
|
||||
if ev[0].Name != "http_request_metadata" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTransactionSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
txp := httptransport.SaverTransactionHTTPTransport{
|
||||
RoundTripper: http.DefaultTransport.(*http.Transport),
|
||||
Saver: saver,
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non nil response here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected two events")
|
||||
}
|
||||
//
|
||||
if ev[0].Name != "http_transaction_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
//
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "http_transaction_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTransactionFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
saver := &trace.Saver{}
|
||||
txp := httptransport.SaverTransactionHTTPTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
req, err := http.NewRequest("GET", "http://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected two events")
|
||||
}
|
||||
if ev[0].Name != "http_transaction_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
if ev[1].Name != "http_transaction_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !errors.Is(ev[1].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverBodySuccess(t *testing.T) {
|
||||
saver := new(trace.Saver)
|
||||
txp := httptransport.SaverBodyHTTPTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Func: func(req *http.Request) (*http.Response, error) {
|
||||
data, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "deadbeef" {
|
||||
t.Fatal("invalid data")
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: 501,
|
||||
Body: ioutil.NopCloser(strings.NewReader("abad1dea")),
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
SnapshotSize: 4,
|
||||
Saver: saver,
|
||||
}
|
||||
body := strings.NewReader("deadbeef")
|
||||
req, err := http.NewRequest("POST", "http://x.org/y", body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != 501 {
|
||||
t.Fatal("unexpected status code")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "abad1dea" {
|
||||
t.Fatal("unexpected body")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if string(ev[0].Data) != "dead" {
|
||||
t.Fatal("invalid Data")
|
||||
}
|
||||
if ev[0].DataIsTruncated != true {
|
||||
t.Fatal("invalid DataIsTruncated")
|
||||
}
|
||||
if ev[0].Name != "http_request_body_snapshot" {
|
||||
t.Fatal("invalid Name")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("invalid Time")
|
||||
}
|
||||
if string(ev[1].Data) != "abad" {
|
||||
t.Fatal("invalid Data")
|
||||
}
|
||||
if ev[1].DataIsTruncated != true {
|
||||
t.Fatal("invalid DataIsTruncated")
|
||||
}
|
||||
if ev[1].Name != "http_response_body_snapshot" {
|
||||
t.Fatal("invalid Name")
|
||||
}
|
||||
if ev[1].Time.Before(ev[0].Time) {
|
||||
t.Fatal("invalid Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverBodyRequestReadError(t *testing.T) {
|
||||
saver := new(trace.Saver)
|
||||
txp := httptransport.SaverBodyHTTPTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Func: func(req *http.Request) (*http.Response, error) {
|
||||
panic("should not be called")
|
||||
},
|
||||
},
|
||||
SnapshotSize: 4,
|
||||
Saver: saver,
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
body := httptransport.FakeBody{Err: expected}
|
||||
req, err := http.NewRequest("POST", "http://x.org/y", body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 0 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverBodyRoundTripError(t *testing.T) {
|
||||
saver := new(trace.Saver)
|
||||
expected := errors.New("mocked error")
|
||||
txp := httptransport.SaverBodyHTTPTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Err: expected,
|
||||
},
|
||||
SnapshotSize: 4,
|
||||
Saver: saver,
|
||||
}
|
||||
body := strings.NewReader("deadbeef")
|
||||
req, err := http.NewRequest("POST", "http://x.org/y", body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 1 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if string(ev[0].Data) != "dead" {
|
||||
t.Fatal("invalid Data")
|
||||
}
|
||||
if ev[0].DataIsTruncated != true {
|
||||
t.Fatal("invalid DataIsTruncated")
|
||||
}
|
||||
if ev[0].Name != "http_request_body_snapshot" {
|
||||
t.Fatal("invalid Name")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("invalid Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverBodyResponseReadError(t *testing.T) {
|
||||
saver := new(trace.Saver)
|
||||
expected := errors.New("mocked error")
|
||||
txp := httptransport.SaverBodyHTTPTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Func: func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: httptransport.FakeBody{
|
||||
Err: expected,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
SnapshotSize: 4,
|
||||
Saver: saver,
|
||||
}
|
||||
body := strings.NewReader("deadbeef")
|
||||
req, err := http.NewRequest("POST", "http://x.org/y", body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 1 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if string(ev[0].Data) != "dead" {
|
||||
t.Fatal("invalid Data")
|
||||
}
|
||||
if ev[0].DataIsTruncated != true {
|
||||
t.Fatal("invalid DataIsTruncated")
|
||||
}
|
||||
if ev[0].Name != "http_request_body_snapshot" {
|
||||
t.Fatal("invalid Name")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("invalid Time")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package httptransport
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// NewSystemTransport creates a new "system" HTTP transport. That is a transport
|
||||
// using the Go standard library with custom dialer and TLS dialer.
|
||||
func NewSystemTransport(config Config) RoundTripper {
|
||||
txp := http.DefaultTransport.(*http.Transport).Clone()
|
||||
txp.DialContext = config.Dialer.DialContext
|
||||
txp.DialTLSContext = config.TLSDialer.DialTLSContext
|
||||
// Better for Cloudflare DNS and also better because we have less
|
||||
// noisy events and we can better understand what happened.
|
||||
txp.MaxConnsPerHost = 1
|
||||
// The following (1) reduces the number of headers that Go will
|
||||
// automatically send for us and (2) ensures that we always receive
|
||||
// back the true headers, such as Content-Length. This change is
|
||||
// functional to OONI's goal of observing the network.
|
||||
txp.DisableCompression = true
|
||||
return txp
|
||||
}
|
||||
|
||||
var _ RoundTripper = &http.Transport{}
|
||||
@@ -0,0 +1,19 @@
|
||||
package httptransport
|
||||
|
||||
import "net/http"
|
||||
|
||||
// UserAgentTransport is a transport that ensures that we always
|
||||
// set an OONI specific default User-Agent header.
|
||||
type UserAgentTransport struct {
|
||||
RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp UserAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("User-Agent") == "" {
|
||||
req.Header.Set("User-Agent", "miniooni/0.1.0-dev")
|
||||
}
|
||||
return txp.RoundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
var _ RoundTripper = UserAgentTransport{}
|
||||
@@ -0,0 +1,51 @@
|
||||
package httptransport_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
|
||||
)
|
||||
|
||||
func TestUserAgentWithDefault(t *testing.T) {
|
||||
txp := httptransport.UserAgentTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Resp: &http.Response{StatusCode: 200},
|
||||
},
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.google.com",
|
||||
Path: "/",
|
||||
}}
|
||||
req.Header = http.Header{}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Request.Header.Get("User-Agent") != "miniooni/0.1.0-dev" {
|
||||
t.Fatal("not the User-Agent we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserAgentWithExplicitValue(t *testing.T) {
|
||||
txp := httptransport.UserAgentTransport{
|
||||
RoundTripper: httptransport.FakeTransport{
|
||||
Resp: &http.Response{StatusCode: 200},
|
||||
},
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "www.google.com",
|
||||
Path: "/",
|
||||
}}
|
||||
req.Header = http.Header{"User-Agent": []string{"antani-client/0.1.1"}}
|
||||
resp, err := txp.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Request.Header.Get("User-Agent") != "antani-client/0.1.1" {
|
||||
t.Fatal("not the User-Agent we expected")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package netx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
log.SetLevel(log.DebugLevel)
|
||||
counter := bytecounter.New()
|
||||
config := netx.Config{
|
||||
BogonIsError: true,
|
||||
ByteCounter: counter,
|
||||
CacheResolutions: true,
|
||||
ContextByteCounting: true,
|
||||
DialSaver: &trace.Saver{},
|
||||
HTTPSaver: &trace.Saver{},
|
||||
Logger: log.Log,
|
||||
ReadWriteSaver: &trace.Saver{},
|
||||
ResolveSaver: &trace.Saver{},
|
||||
TLSSaver: &trace.Saver{},
|
||||
}
|
||||
txp := netx.NewHTTPTransport(config)
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err = ioutil.ReadAll(resp.Body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if counter.Sent.Load() <= 0 {
|
||||
t.Fatal("no bytes sent?!")
|
||||
}
|
||||
if counter.Received.Load() <= 0 {
|
||||
t.Fatal("no bytes received?!")
|
||||
}
|
||||
if ev := config.DialSaver.Read(); len(ev) <= 0 {
|
||||
t.Fatal("no dial events?!")
|
||||
}
|
||||
if ev := config.HTTPSaver.Read(); len(ev) <= 0 {
|
||||
t.Fatal("no HTTP events?!")
|
||||
}
|
||||
if ev := config.ReadWriteSaver.Read(); len(ev) <= 0 {
|
||||
t.Fatal("no R/W events?!")
|
||||
}
|
||||
if ev := config.ResolveSaver.Read(); len(ev) <= 0 {
|
||||
t.Fatal("no resolver events?!")
|
||||
}
|
||||
if ev := config.TLSSaver.Read(); len(ev) <= 0 {
|
||||
t.Fatal("no TLS events?!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBogonResolutionNotBroken(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := new(trace.Saver)
|
||||
r := netx.NewResolver(netx.Config{
|
||||
BogonIsError: true,
|
||||
DNSCache: map[string][]string{
|
||||
"www.google.com": {"127.0.0.1"},
|
||||
},
|
||||
ResolveSaver: saver,
|
||||
Logger: log.Log,
|
||||
})
|
||||
addrs, err := r.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, errorx.ErrDNSBogon) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if err.Error() != errorx.FailureDNSBogonError {
|
||||
t.Fatal("error not correctly wrapped")
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "127.0.0.1" {
|
||||
t.Fatal("address was not returned")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
// Package netx contains code to perform network measurements.
|
||||
//
|
||||
// This library contains replacements for commonly used standard library
|
||||
// interfaces that facilitate seamless network measurements. By using
|
||||
// such replacements, as opposed to standard library interfaces, we can:
|
||||
//
|
||||
// * save the timing of HTTP events (e.g. received response headers)
|
||||
// * save the timing and result of every Connect, Read, Write, Close operation
|
||||
// * save the timing and result of the TLS handshake (including certificates)
|
||||
//
|
||||
// By default, this library uses the system resolver. In addition, it
|
||||
// is possible to configure alternative DNS transports and remote
|
||||
// servers. We support DNS over UDP, DNS over TCP, DNS over TLS (DoT),
|
||||
// and DNS over HTTPS (DoH). When using an alternative transport, we
|
||||
// are also able to intercept and save DNS messages, as well as any
|
||||
// other interaction with the remote server (e.g., the result of the
|
||||
// TLS handshake for DoT and DoH).
|
||||
//
|
||||
// We described the design and implementation of the most recent version of
|
||||
// this package at <https://github.com/ooni/probe-cli/v3/internal/engine/issues/359>. Such
|
||||
// issue also links to a previous design document.
|
||||
package netx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/gocertifi"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/httptransport"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/selfcensor"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// Logger is the logger assumed by this package
|
||||
type Logger interface {
|
||||
Debugf(format string, v ...interface{})
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
// Dialer is the definition of dialer assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// QUICDialer is the definition of a dialer for QUIC assumed by this package.
|
||||
type QUICDialer interface {
|
||||
Dial(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||
}
|
||||
|
||||
// TLSDialer is the definition of a TLS dialer assumed by this package.
|
||||
type TLSDialer interface {
|
||||
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// HTTPRoundTripper is the definition of http.HTTPRoundTripper used by this package.
|
||||
type HTTPRoundTripper interface {
|
||||
RoundTrip(req *http.Request) (*http.Response, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// Resolver is the interface we expect from a resolver
|
||||
type Resolver interface {
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
Network() string
|
||||
Address() string
|
||||
}
|
||||
|
||||
// Config contains configuration for creating a new transport. When any
|
||||
// field of Config is nil/empty, we will use a suitable default.
|
||||
//
|
||||
// We use different savers for different kind of events such that the
|
||||
// user of this library can choose what to save.
|
||||
type Config struct {
|
||||
BaseResolver Resolver // default: system resolver
|
||||
BogonIsError bool // default: bogon is not error
|
||||
ByteCounter *bytecounter.Counter // default: no explicit byte counting
|
||||
CacheResolutions bool // default: no caching
|
||||
CertPool *x509.CertPool // default: use vendored gocertifi
|
||||
ContextByteCounting bool // default: no implicit byte counting
|
||||
DNSCache map[string][]string // default: cache is empty
|
||||
DialSaver *trace.Saver // default: not saving dials
|
||||
Dialer Dialer // default: dialer.DNSDialer
|
||||
FullResolver Resolver // default: base resolver + goodies
|
||||
QUICDialer QUICDialer // default: quicdialer.DNSDialer
|
||||
HTTP3Enabled bool // default: disabled
|
||||
HTTPSaver *trace.Saver // default: not saving HTTP
|
||||
Logger Logger // default: no logging
|
||||
NoTLSVerify bool // default: perform TLS verify
|
||||
ProxyURL *url.URL // default: no proxy
|
||||
ReadWriteSaver *trace.Saver // default: not saving read/write
|
||||
ResolveSaver *trace.Saver // default: not saving resolves
|
||||
TLSConfig *tls.Config // default: attempt using h2
|
||||
TLSDialer TLSDialer // default: dialer.TLSDialer
|
||||
TLSSaver *trace.Saver // default: not saving TLS
|
||||
}
|
||||
|
||||
type tlsHandshaker interface {
|
||||
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
||||
net.Conn, tls.ConnectionState, error)
|
||||
}
|
||||
|
||||
// NewDefaultCertPool returns a copy of the default x509
|
||||
// certificate pool. This function panics on failure.
|
||||
func NewDefaultCertPool() *x509.CertPool {
|
||||
pool, err := gocertifi.CACerts()
|
||||
runtimex.PanicOnError(err, "gocertifi.CACerts() failed")
|
||||
return pool
|
||||
}
|
||||
|
||||
var defaultCertPool *x509.CertPool = NewDefaultCertPool()
|
||||
|
||||
// NewResolver creates a new resolver from the specified config
|
||||
func NewResolver(config Config) Resolver {
|
||||
if config.BaseResolver == nil {
|
||||
config.BaseResolver = resolver.SystemResolver{}
|
||||
}
|
||||
var r Resolver = config.BaseResolver
|
||||
if config.CacheResolutions {
|
||||
r = &resolver.CacheResolver{Resolver: r}
|
||||
}
|
||||
if config.DNSCache != nil {
|
||||
cache := &resolver.CacheResolver{Resolver: r, ReadOnly: true}
|
||||
for key, values := range config.DNSCache {
|
||||
cache.Set(key, values)
|
||||
}
|
||||
r = cache
|
||||
}
|
||||
if config.BogonIsError {
|
||||
r = resolver.BogonResolver{Resolver: r}
|
||||
}
|
||||
r = resolver.ErrorWrapperResolver{Resolver: r}
|
||||
if config.Logger != nil {
|
||||
r = resolver.LoggingResolver{Logger: config.Logger, Resolver: r}
|
||||
}
|
||||
if config.ResolveSaver != nil {
|
||||
r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver}
|
||||
}
|
||||
r = resolver.AddressResolver{Resolver: r}
|
||||
return resolver.IDNAResolver{Resolver: r}
|
||||
}
|
||||
|
||||
// NewDialer creates a new Dialer from the specified config
|
||||
func NewDialer(config Config) Dialer {
|
||||
if config.FullResolver == nil {
|
||||
config.FullResolver = NewResolver(config)
|
||||
}
|
||||
var d Dialer = selfcensor.SystemDialer{}
|
||||
d = dialer.TimeoutDialer{Dialer: d}
|
||||
d = dialer.ErrorWrapperDialer{Dialer: d}
|
||||
if config.Logger != nil {
|
||||
d = dialer.LoggingDialer{Dialer: d, Logger: config.Logger}
|
||||
}
|
||||
if config.DialSaver != nil {
|
||||
d = dialer.SaverDialer{Dialer: d, Saver: config.DialSaver}
|
||||
}
|
||||
if config.ReadWriteSaver != nil {
|
||||
d = dialer.SaverConnDialer{Dialer: d, Saver: config.ReadWriteSaver}
|
||||
}
|
||||
d = dialer.DNSDialer{Resolver: config.FullResolver, Dialer: d}
|
||||
d = dialer.ProxyDialer{ProxyURL: config.ProxyURL, Dialer: d}
|
||||
if config.ContextByteCounting {
|
||||
d = dialer.ByteCounterDialer{Dialer: d}
|
||||
}
|
||||
d = dialer.ShapingDialer{Dialer: d}
|
||||
return d
|
||||
}
|
||||
|
||||
// NewQUICDialer creates a new DNS Dialer for QUIC, with the resolver from the specified config
|
||||
func NewQUICDialer(config Config) QUICDialer {
|
||||
if config.FullResolver == nil {
|
||||
config.FullResolver = NewResolver(config)
|
||||
}
|
||||
var d quicdialer.ContextDialer = &quicdialer.SystemDialer{Saver: config.ReadWriteSaver}
|
||||
d = quicdialer.ErrorWrapperDialer{Dialer: d}
|
||||
if config.TLSSaver != nil {
|
||||
d = quicdialer.HandshakeSaver{Saver: config.TLSSaver, Dialer: d}
|
||||
}
|
||||
d = &quicdialer.DNSDialer{Resolver: config.FullResolver, Dialer: d}
|
||||
var dialer QUICDialer = &httptransport.QUICWrapperDialer{Dialer: d}
|
||||
return dialer
|
||||
}
|
||||
|
||||
// NewTLSDialer creates a new TLSDialer from the specified config
|
||||
func NewTLSDialer(config Config) TLSDialer {
|
||||
if config.Dialer == nil {
|
||||
config.Dialer = NewDialer(config)
|
||||
}
|
||||
var h tlsHandshaker = dialer.SystemTLSHandshaker{}
|
||||
h = dialer.TimeoutTLSHandshaker{TLSHandshaker: h}
|
||||
h = dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
|
||||
if config.Logger != nil {
|
||||
h = dialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h}
|
||||
}
|
||||
if config.TLSSaver != nil {
|
||||
h = dialer.SaverTLSHandshaker{TLSHandshaker: h, Saver: config.TLSSaver}
|
||||
}
|
||||
if config.TLSConfig == nil {
|
||||
config.TLSConfig = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
|
||||
}
|
||||
if config.CertPool == nil {
|
||||
config.CertPool = defaultCertPool
|
||||
}
|
||||
config.TLSConfig.RootCAs = config.CertPool
|
||||
config.TLSConfig.InsecureSkipVerify = config.NoTLSVerify
|
||||
return dialer.TLSDialer{
|
||||
Config: config.TLSConfig,
|
||||
Dialer: config.Dialer,
|
||||
TLSHandshaker: h,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTPRoundTripper. You can further extend the returned
|
||||
// HTTPRoundTripper before wrapping it into an http.Client.
|
||||
func NewHTTPTransport(config Config) HTTPRoundTripper {
|
||||
if config.Dialer == nil {
|
||||
config.Dialer = NewDialer(config)
|
||||
}
|
||||
if config.TLSDialer == nil {
|
||||
config.TLSDialer = NewTLSDialer(config)
|
||||
}
|
||||
if config.QUICDialer == nil {
|
||||
config.QUICDialer = NewQUICDialer(config)
|
||||
}
|
||||
|
||||
tInfo := allTransportsInfo[config.HTTP3Enabled]
|
||||
txp := tInfo.Factory(httptransport.Config{
|
||||
Dialer: config.Dialer, QUICDialer: config.QUICDialer, TLSDialer: config.TLSDialer,
|
||||
TLSConfig: config.TLSConfig})
|
||||
transport := tInfo.TransportName
|
||||
|
||||
if config.ByteCounter != nil {
|
||||
txp = httptransport.ByteCountingTransport{
|
||||
Counter: config.ByteCounter, RoundTripper: txp}
|
||||
}
|
||||
if config.Logger != nil {
|
||||
txp = httptransport.LoggingTransport{Logger: config.Logger, RoundTripper: txp}
|
||||
}
|
||||
if config.HTTPSaver != nil {
|
||||
txp = httptransport.SaverMetadataHTTPTransport{
|
||||
RoundTripper: txp, Saver: config.HTTPSaver, Transport: transport}
|
||||
txp = httptransport.SaverBodyHTTPTransport{
|
||||
RoundTripper: txp, Saver: config.HTTPSaver}
|
||||
txp = httptransport.SaverPerformanceHTTPTransport{
|
||||
RoundTripper: txp, Saver: config.HTTPSaver}
|
||||
txp = httptransport.SaverTransactionHTTPTransport{
|
||||
RoundTripper: txp, Saver: config.HTTPSaver}
|
||||
}
|
||||
txp = httptransport.UserAgentTransport{RoundTripper: txp}
|
||||
return txp
|
||||
}
|
||||
|
||||
// httpTransportInfo contains the constructing function as well as the transport name
|
||||
type httpTransportInfo struct {
|
||||
Factory func(httptransport.Config) httptransport.RoundTripper
|
||||
TransportName string
|
||||
}
|
||||
|
||||
var allTransportsInfo = map[bool]httpTransportInfo{
|
||||
false: {
|
||||
Factory: httptransport.NewSystemTransport,
|
||||
TransportName: "tcp",
|
||||
},
|
||||
true: {
|
||||
Factory: httptransport.NewHTTP3Transport,
|
||||
TransportName: "quic",
|
||||
},
|
||||
}
|
||||
|
||||
// DNSClient is a DNS client. It wraps a Resolver and it possibly
|
||||
// also wraps an HTTP client, but only when we're using DoH.
|
||||
type DNSClient struct {
|
||||
Resolver
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections, if any.
|
||||
func (c DNSClient) CloseIdleConnections() {
|
||||
if c.httpClient != nil {
|
||||
c.httpClient.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// NewDNSClient creates a new DNS client. The config argument is used to
|
||||
// create the underlying Dialer and/or HTTP transport, if needed. The URL
|
||||
// argument describes the kind of client that we want to make:
|
||||
//
|
||||
// - if the URL is `doh://powerdns`, `doh://google` or `doh://cloudflare` or the URL
|
||||
// starts with `https://`, then we create a DoH client.
|
||||
//
|
||||
// - if the URL is `` or `system:///`, then we create a system client,
|
||||
// i.e. a client using the system resolver.
|
||||
//
|
||||
// - if the URL starts with `udp://`, then we create a client using
|
||||
// a resolver that uses the specified UDP endpoint.
|
||||
//
|
||||
// We return error if the URL does not parse or the URL scheme does not
|
||||
// fall into one of the cases described above.
|
||||
//
|
||||
// If config.ResolveSaver is not nil and we're creating an underlying
|
||||
// resolver where this is possible, we will also save events.
|
||||
func NewDNSClient(config Config, URL string) (DNSClient, error) {
|
||||
return NewDNSClientWithOverrides(config, URL, "", "", "")
|
||||
}
|
||||
|
||||
// ErrInvalidTLSVersion indicates that you passed us a string
|
||||
// that does not represent a valid TLS version.
|
||||
var ErrInvalidTLSVersion = errors.New("invalid TLS version")
|
||||
|
||||
// ConfigureTLSVersion configures the correct TLS version into
|
||||
// the specified *tls.Config or returns an error.
|
||||
func ConfigureTLSVersion(config *tls.Config, version string) error {
|
||||
switch version {
|
||||
case "TLSv1.3":
|
||||
config.MinVersion = tls.VersionTLS13
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
case "TLSv1.2":
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
case "TLSv1.1":
|
||||
config.MinVersion = tls.VersionTLS11
|
||||
config.MaxVersion = tls.VersionTLS11
|
||||
case "TLSv1.0", "TLSv1":
|
||||
config.MinVersion = tls.VersionTLS10
|
||||
config.MaxVersion = tls.VersionTLS10
|
||||
case "":
|
||||
// nothing
|
||||
default:
|
||||
return ErrInvalidTLSVersion
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient,
|
||||
// with the option to override the default Hostname and SNI.
|
||||
func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||
TLSVersion string) (DNSClient, error) {
|
||||
var c DNSClient
|
||||
switch URL {
|
||||
case "doh://powerdns":
|
||||
URL = "https://doh.powerdns.org/"
|
||||
case "doh://google":
|
||||
URL = "https://dns.google/dns-query"
|
||||
case "doh://cloudflare":
|
||||
URL = "https://cloudflare-dns.com/dns-query"
|
||||
case "":
|
||||
URL = "system:///"
|
||||
}
|
||||
resolverURL, err := url.Parse(URL)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
config.TLSConfig = &tls.Config{ServerName: SNIOverride}
|
||||
if err := ConfigureTLSVersion(config.TLSConfig, TLSVersion); err != nil {
|
||||
return c, err
|
||||
}
|
||||
switch resolverURL.Scheme {
|
||||
case "system":
|
||||
c.Resolver = resolver.SystemResolver{}
|
||||
return c, nil
|
||||
case "https":
|
||||
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
c.httpClient = &http.Client{Transport: NewHTTPTransport(config)}
|
||||
var txp resolver.RoundTripper = resolver.NewDNSOverHTTPSWithHostOverride(
|
||||
c.httpClient, URL, hostOverride)
|
||||
if config.ResolveSaver != nil {
|
||||
txp = resolver.SaverDNSTransport{
|
||||
RoundTripper: txp,
|
||||
Saver: config.ResolveSaver,
|
||||
}
|
||||
}
|
||||
c.Resolver = resolver.NewSerialResolver(txp)
|
||||
return c, nil
|
||||
case "udp":
|
||||
dialer := NewDialer(config)
|
||||
endpoint, err := makeValidEndpoint(resolverURL)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
var txp resolver.RoundTripper = resolver.NewDNSOverUDP(dialer, endpoint)
|
||||
if config.ResolveSaver != nil {
|
||||
txp = resolver.SaverDNSTransport{
|
||||
RoundTripper: txp,
|
||||
Saver: config.ResolveSaver,
|
||||
}
|
||||
}
|
||||
c.Resolver = resolver.NewSerialResolver(txp)
|
||||
return c, nil
|
||||
case "dot":
|
||||
config.TLSConfig.NextProtos = []string{"dot"}
|
||||
tlsDialer := NewTLSDialer(config)
|
||||
endpoint, err := makeValidEndpoint(resolverURL)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
var txp resolver.RoundTripper = resolver.NewDNSOverTLS(
|
||||
tlsDialer.DialTLSContext, endpoint)
|
||||
if config.ResolveSaver != nil {
|
||||
txp = resolver.SaverDNSTransport{
|
||||
RoundTripper: txp,
|
||||
Saver: config.ResolveSaver,
|
||||
}
|
||||
}
|
||||
c.Resolver = resolver.NewSerialResolver(txp)
|
||||
return c, nil
|
||||
case "tcp":
|
||||
dialer := NewDialer(config)
|
||||
endpoint, err := makeValidEndpoint(resolverURL)
|
||||
if err != nil {
|
||||
return c, err
|
||||
}
|
||||
var txp resolver.RoundTripper = resolver.NewDNSOverTCP(
|
||||
dialer.DialContext, endpoint)
|
||||
if config.ResolveSaver != nil {
|
||||
txp = resolver.SaverDNSTransport{
|
||||
RoundTripper: txp,
|
||||
Saver: config.ResolveSaver,
|
||||
}
|
||||
}
|
||||
c.Resolver = resolver.NewSerialResolver(txp)
|
||||
return c, nil
|
||||
default:
|
||||
return c, errors.New("unsupported resolver scheme")
|
||||
}
|
||||
}
|
||||
|
||||
// makeValidEndpoint makes a valid endpoint for DoT and Do53 given the
|
||||
// input URL representing such endpoint. Specifically, we are
|
||||
// concerned with the case where the port is missing. In such a
|
||||
// case, we ensure that we are using the default port 853 for DoT
|
||||
// and default port 53 for TCP and UDP.
|
||||
func makeValidEndpoint(URL *url.URL) (string, error) {
|
||||
// Implementation note: when we're using a quoted IPv6
|
||||
// address, URL.Host contains the quotes but instead the
|
||||
// return value from URL.Hostname() does not.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// - Host: [2620:fe::9]
|
||||
// - Hostname(): 2620:fe::9
|
||||
//
|
||||
// We need to keep this in mind when trying to determine
|
||||
// whether there is also a port or not.
|
||||
//
|
||||
// So the first step is to check whether URL.Host is already
|
||||
// a whatever valid TCP/UDP endpoint and, if so, use it.
|
||||
if _, _, err := net.SplitHostPort(URL.Host); err == nil {
|
||||
return URL.Host, nil
|
||||
}
|
||||
// The second step is to assume that appending the default port
|
||||
// to a host parsed by url.Parse should be giving us a valid
|
||||
// endpoint. The possibilities in fact are:
|
||||
//
|
||||
// 1. domain w/o port
|
||||
// 2. IPv4 w/o port
|
||||
// 3. square bracket quoted IPv6 w/o port
|
||||
// 4. other
|
||||
//
|
||||
// In the first three cases, appending a port leads us to a
|
||||
// good endpoint. The fourth case does not.
|
||||
//
|
||||
// For this reason we check again whether we can split it using
|
||||
// net.SplitHostPort. If we cannot, we were in case four.
|
||||
host := URL.Host
|
||||
if URL.Scheme == "dot" {
|
||||
host += ":853"
|
||||
} else {
|
||||
host += ":53"
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(host); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Otherwise it's one of the three valid cases above.
|
||||
return host, nil
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package netx
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
// DefaultCertPool allows tests to access the default cert pool.
|
||||
func DefaultCertPool() *x509.CertPool {
|
||||
return defaultCertPool
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,14 @@
|
||||
// +build !go1.15
|
||||
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// ConnectionState returns the ConnectionState of a QUIC Session.
|
||||
func ConnectionState(sess quic.EarlySession) tls.ConnectionState {
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
// +build go1.15
|
||||
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// ConnectionState returns the ConnectionState of a QUIC Session.
|
||||
func ConnectionState(sess quic.EarlySession) tls.ConnectionState {
|
||||
return sess.ConnectionState().ConnectionState
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
// DNSDialer is a dialer that uses the configured Resolver to resolve a
|
||||
// domain name to IP addresses
|
||||
type DNSDialer struct {
|
||||
Dialer ContextDialer
|
||||
Resolver Resolver
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer.DialContext
|
||||
func (d DNSDialer) DialContext(
|
||||
ctx context.Context, network, host string,
|
||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
onlyhost, onlyport, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(kelmenhorst): Should this be somewhere else?
|
||||
// failure if tlsCfg is nil but that should not happen
|
||||
if tlsCfg.ServerName == "" {
|
||||
tlsCfg.ServerName = onlyhost
|
||||
}
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
var addrs []string
|
||||
addrs, err = d.LookupHost(ctx, onlyhost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var errorslist []error
|
||||
for _, addr := range addrs {
|
||||
target := net.JoinHostPort(addr, onlyport)
|
||||
sess, err := d.Dialer.DialContext(
|
||||
ctx, network, target, tlsCfg, cfg)
|
||||
if err == nil {
|
||||
return sess, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
}
|
||||
// TODO(bassosimone): maybe ReduceErrors could be in netx/internal.
|
||||
return nil, dialer.ReduceErrors(errorslist)
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (d DNSDialer) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return d.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package quicdialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
|
||||
)
|
||||
|
||||
type MockableResolver struct {
|
||||
Addresses []string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||
return r.Addresses, r.Err
|
||||
}
|
||||
|
||||
func TestDNSDialerSuccess(t *testing.T) {
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
|
||||
dialer := quicdialer.DNSDialer{
|
||||
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "www.google.com:443",
|
||||
tlsConf, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if sess == nil {
|
||||
t.Fatal("non nil sess expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerNoPort(t *testing.T) {
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
|
||||
dialer := quicdialer.DNSDialer{
|
||||
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "www.google.com",
|
||||
tlsConf, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected a nil sess here")
|
||||
}
|
||||
if err.Error() != "address www.google.com: missing port in address" {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerLookupHostAddress(t *testing.T) {
|
||||
dialer := quicdialer.DNSDialer{Resolver: MockableResolver{
|
||||
Err: errors.New("mocked error"),
|
||||
}}
|
||||
addrs, err := dialer.LookupHost(context.Background(), "1.1.1.1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerLookupHostFailure(t *testing.T) {
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
|
||||
expected := errors.New("mocked error")
|
||||
dialer := quicdialer.DNSDialer{Resolver: MockableResolver{
|
||||
Err: expected,
|
||||
}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "dns.google.com:853",
|
||||
tlsConf, &quic.Config{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil sess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerInvalidPort(t *testing.T) {
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
|
||||
dialer := quicdialer.DNSDialer{
|
||||
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "www.google.com:0",
|
||||
tlsConf, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil sess")
|
||||
}
|
||||
if !strings.HasSuffix(err.Error(), "sendto: invalid argument") &&
|
||||
!strings.HasSuffix(err.Error(), "sendto: can't assign requested address") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerInvalidPortSyntax(t *testing.T) {
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
|
||||
dialer := quicdialer.DNSDialer{
|
||||
Resolver: new(net.Resolver), Dialer: quicdialer.SystemDialer{}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "www.google.com:port",
|
||||
tlsConf, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil sess")
|
||||
}
|
||||
if !errors.Is(err, strconv.ErrSyntax) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerDialEarlyFails(t *testing.T) {
|
||||
tlsConf := &tls.Config{NextProtos: []string{"h3-29"}}
|
||||
expected := errors.New("mocked DialEarly error")
|
||||
dialer := quicdialer.DNSDialer{
|
||||
Resolver: new(net.Resolver), Dialer: MockDialer{Err: expected}}
|
||||
sess, err := dialer.DialContext(
|
||||
context.Background(), "udp", "www.google.com:443",
|
||||
tlsConf, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil sess")
|
||||
}
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// ErrorWrapperDialer is a dialer that performs quic err wrapping
|
||||
type ErrorWrapperDialer struct {
|
||||
Dialer ContextDialer
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer.DialContext
|
||||
func (d ErrorWrapperDialer) DialContext(
|
||||
ctx context.Context, network string, host string,
|
||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
sess, err := d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
// ConnID does not make any sense if we've failed and the error
|
||||
// does not make any sense (and is nil) if we succeded.
|
||||
DialID: dialID,
|
||||
Error: err,
|
||||
Operation: errorx.QUICHandshakeOperation,
|
||||
}.MaybeBuild()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package quicdialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
|
||||
)
|
||||
|
||||
func TestErrorWrapperFailure(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
d := quicdialer.ErrorWrapperDialer{
|
||||
Dialer: MockDialer{Sess: nil, Err: io.EOF}}
|
||||
sess, err := d.DialContext(
|
||||
ctx, "udp", "www.google.com:443", &tls.Config{}, &quic.Config{})
|
||||
if sess != nil {
|
||||
t.Fatal("expected a nil sess here")
|
||||
}
|
||||
errorWrapperCheckErr(t, err, errorx.QUICHandshakeOperation)
|
||||
}
|
||||
|
||||
func errorWrapperCheckErr(t *testing.T, err error, op string) {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected another error here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot cast to ErrWrapper")
|
||||
}
|
||||
if errWrapper.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
if errWrapper.Operation != op {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureEOFError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapperSuccess(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
tlsConf := &tls.Config{
|
||||
NextProtos: []string{"h3-29"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
d := quicdialer.ErrorWrapperDialer{Dialer: quicdialer.SystemDialer{}}
|
||||
sess, err := d.DialContext(ctx, "udp", "216.58.212.164:443", tlsConf, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess == nil {
|
||||
t.Fatal("expected non-nil sess here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// ContextDialer is a dialer for QUIC using Context.
|
||||
type ContextDialer interface {
|
||||
// Note: assumes that tlsCfg and cfg are not nil.
|
||||
DialContext(ctx context.Context, network, host string,
|
||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||
}
|
||||
|
||||
// Dialer dials QUIC connections.
|
||||
type Dialer interface {
|
||||
// Note: assumes that tlsCfg and cfg are not nil.
|
||||
Dial(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
|
||||
}
|
||||
|
||||
// Resolver is the interface we expect from a resolver.
|
||||
type Resolver interface {
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// HandshakeSaver saves events occurring during the handshake
|
||||
type HandshakeSaver struct {
|
||||
Saver *trace.Saver
|
||||
Dialer ContextDialer
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer.DialContext
|
||||
func (h HandshakeSaver) DialContext(ctx context.Context, network string,
|
||||
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
start := time.Now()
|
||||
// TODO(bassosimone): in the future we probably want to also save
|
||||
// information about what versions we're willing to accept.
|
||||
h.Saver.Write(trace.Event{
|
||||
Address: host,
|
||||
Name: "quic_handshake_start",
|
||||
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
||||
Proto: network,
|
||||
TLSNextProtos: tlsCfg.NextProtos,
|
||||
TLSServerName: tlsCfg.ServerName,
|
||||
Time: start,
|
||||
})
|
||||
sess, err := h.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||
stop := time.Now()
|
||||
if err != nil {
|
||||
h.Saver.Write(trace.Event{
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: "quic_handshake_done",
|
||||
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
||||
TLSNextProtos: tlsCfg.NextProtos,
|
||||
TLSServerName: tlsCfg.ServerName,
|
||||
Time: stop,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
state := ConnectionState(sess)
|
||||
h.Saver.Write(trace.Event{
|
||||
Duration: stop.Sub(start),
|
||||
Name: "quic_handshake_done",
|
||||
NoTLSVerify: tlsCfg.InsecureSkipVerify,
|
||||
TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite),
|
||||
TLSNegotiatedProto: state.NegotiatedProtocol,
|
||||
TLSNextProtos: tlsCfg.NextProtos,
|
||||
TLSPeerCerts: trace.PeerCerts(state, err),
|
||||
TLSServerName: tlsCfg.ServerName,
|
||||
TLSVersion: tlsx.VersionString(state.Version),
|
||||
Time: stop,
|
||||
})
|
||||
return sess, nil
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package quicdialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
type MockDialer struct {
|
||||
Dialer quicdialer.ContextDialer
|
||||
Sess quic.EarlySession
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d MockDialer) DialContext(ctx context.Context, network, host string,
|
||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
if d.Dialer != nil {
|
||||
return d.Dialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||
}
|
||||
return d.Sess, d.Err
|
||||
}
|
||||
|
||||
func TestHandshakeSaverSuccess(t *testing.T) {
|
||||
nextprotos := []string{"h3-29"}
|
||||
servername := "www.google.com"
|
||||
tlsConf := &tls.Config{
|
||||
NextProtos: nextprotos,
|
||||
ServerName: servername,
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
dlr := quicdialer.HandshakeSaver{
|
||||
Dialer: quicdialer.SystemDialer{},
|
||||
Saver: saver,
|
||||
}
|
||||
sess, err := dlr.DialContext(context.Background(), "udp",
|
||||
"216.58.212.164:443", tlsConf, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if sess == nil {
|
||||
t.Fatal("unexpected nil sess")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if ev[0].Name != "quic_handshake_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err", ev[1].Err)
|
||||
}
|
||||
if ev[1].Name != "quic_handshake_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[1].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if ev[1].Time.Before(ev[0].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshakeSaverHostNameError(t *testing.T) {
|
||||
nextprotos := []string{"h3-29"}
|
||||
servername := "wrong.host.badssl.com"
|
||||
tlsConf := &tls.Config{
|
||||
NextProtos: nextprotos,
|
||||
ServerName: servername,
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
dlr := quicdialer.HandshakeSaver{
|
||||
Dialer: quicdialer.SystemDialer{},
|
||||
Saver: saver,
|
||||
}
|
||||
sess, err := dlr.DialContext(context.Background(), "udp",
|
||||
"216.58.212.164:443", tlsConf, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil sess here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "quic_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if !strings.Contains(ev.Err.Error(),
|
||||
"certificate is valid for www.google.com, not "+servername) {
|
||||
t.Fatal("unexpected error", ev.Err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
package quicdialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// SystemDialer is the basic dialer for QUIC
|
||||
type SystemDialer struct {
|
||||
// Saver saves read/write events on the underlying UDP
|
||||
// connection. (Implementation note: we need it here since
|
||||
// this is the only part in the codebase that is able to
|
||||
// observe the underlying UDP connection.)
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer.DialContext
|
||||
func (d SystemDialer) DialContext(ctx context.Context, network string,
|
||||
host string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
|
||||
onlyhost, onlyport, err := net.SplitHostPort(host)
|
||||
port, err := strconv.Atoi(onlyport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip := net.ParseIP(onlyhost)
|
||||
if ip == nil {
|
||||
// TODO(kelmenhorst): write test for this error condition.
|
||||
return nil, errors.New("quicdialer: invalid IP representation")
|
||||
}
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
var pconn net.PacketConn = udpConn
|
||||
if d.Saver != nil {
|
||||
pconn = saverUDPConn{UDPConn: udpConn, saver: d.Saver}
|
||||
}
|
||||
udpAddr := &net.UDPAddr{IP: ip, Port: port, Zone: ""}
|
||||
return quic.DialEarlyContext(ctx, pconn, udpAddr, host, tlsCfg, cfg)
|
||||
|
||||
}
|
||||
|
||||
type saverUDPConn struct {
|
||||
*net.UDPConn
|
||||
saver *trace.Saver
|
||||
}
|
||||
|
||||
func (c saverUDPConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
start := time.Now()
|
||||
count, err := c.UDPConn.WriteTo(p, addr)
|
||||
stop := time.Now()
|
||||
c.saver.Write(trace.Event{
|
||||
Address: addr.String(),
|
||||
Data: p[:count],
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
NumBytes: count,
|
||||
Name: errorx.WriteToOperation,
|
||||
Time: stop,
|
||||
})
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c saverUDPConn) ReadMsgUDP(b, oob []byte) (int, int, int, *net.UDPAddr, error) {
|
||||
start := time.Now()
|
||||
n, oobn, flags, addr, err := c.UDPConn.ReadMsgUDP(b, oob)
|
||||
stop := time.Now()
|
||||
var data []byte
|
||||
if n > 0 {
|
||||
data = b[:n]
|
||||
}
|
||||
c.saver.Write(trace.Event{
|
||||
Address: addr.String(),
|
||||
Data: data,
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
NumBytes: n,
|
||||
Name: errorx.ReadFromOperation,
|
||||
Time: stop,
|
||||
})
|
||||
return n, oobn, flags, addr, err
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package quicdialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/quicdialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSystemDialerInvalidIPFailure(t *testing.T) {
|
||||
tlsConf := &tls.Config{
|
||||
NextProtos: []string{"h3-29"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
systemdialer := quicdialer.SystemDialer{
|
||||
Saver: saver,
|
||||
}
|
||||
sess, err := systemdialer.DialContext(context.Background(), "udp", "a.b.c.d:0", tlsConf, &quic.Config{})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if sess != nil {
|
||||
t.Fatal("expected nil sess here")
|
||||
}
|
||||
if err.Error() != "quicdialer: invalid IP representation" {
|
||||
t.Fatal("expected another error here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemDialerSuccessWithReadWrite(t *testing.T) {
|
||||
// This is the most common use case for collecting reads, writes
|
||||
tlsConf := &tls.Config{
|
||||
NextProtos: []string{"h3-29"},
|
||||
ServerName: "www.google.com",
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
systemdialer := quicdialer.SystemDialer{Saver: saver}
|
||||
_, err := systemdialer.DialContext(context.Background(), "udp",
|
||||
"216.58.212.164:443", tlsConf, &quic.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) < 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
last := len(ev) - 1
|
||||
for idx := 1; idx < last; idx++ {
|
||||
if ev[idx].Data == nil {
|
||||
t.Fatal("unexpected Data")
|
||||
}
|
||||
if ev[idx].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[idx].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[idx].NumBytes <= 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
switch ev[idx].Name {
|
||||
case errorx.ReadFromOperation, errorx.WriteToOperation:
|
||||
default:
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[idx].Time.Before(ev[idx-1].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// AddressResolver is a resolver that knows how to correctly
|
||||
// resolve IP addresses to themselves.
|
||||
type AddressResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r AddressResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return r.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
var _ Resolver = AddressResolver{}
|
||||
@@ -0,0 +1,36 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestAddressSuccess(t *testing.T) {
|
||||
r := resolver.AddressResolver{}
|
||||
addrs, err := r.LookupHost(context.Background(), "8.8.8.8")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddressFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := resolver.AddressResolver{
|
||||
Resolver: resolver.FakeResolver{
|
||||
Err: expected,
|
||||
},
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
var privateIPBlocks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"0.0.0.0/8", // "This" network (however, Linux...)
|
||||
"10.0.0.0/8", // RFC1918
|
||||
"100.64.0.0/10", // Carrier grade NAT
|
||||
"127.0.0.0/8", // IPv4 loopback
|
||||
"169.254.0.0/16", // RFC3927 link-local
|
||||
"172.16.0.0/12", // RFC1918
|
||||
"192.168.0.0/16", // RFC1918
|
||||
"224.0.0.0/4", // Multicast
|
||||
"::1/128", // IPv6 loopback
|
||||
"fe80::/10", // IPv6 link-local
|
||||
"fc00::/7", // IPv6 unique local addr
|
||||
} {
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
runtimex.PanicOnError(err, "net.ParseCIDR failed")
|
||||
privateIPBlocks = append(privateIPBlocks, block)
|
||||
}
|
||||
}
|
||||
|
||||
func isPrivate(ip net.IP) bool {
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return true
|
||||
}
|
||||
for _, block := range privateIPBlocks {
|
||||
if block.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsBogon returns whether if an IP address is bogon. Passing to this
|
||||
// function a non-IP address causes it to return bogon.
|
||||
func IsBogon(address string) bool {
|
||||
ip := net.ParseIP(address)
|
||||
return ip == nil || isPrivate(ip)
|
||||
}
|
||||
|
||||
// BogonResolver is a bogon aware resolver. When a bogon is encountered in
|
||||
// a reply, this resolver will return an error.
|
||||
type BogonResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r BogonResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
for _, addr := range addrs {
|
||||
if IsBogon(addr) == true {
|
||||
// We need to return the addrs otherwise the caller cannot see/log/save
|
||||
// the specific addresses that triggered our bogon filter
|
||||
return addrs, errorx.ErrDNSBogon
|
||||
}
|
||||
}
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ Resolver = BogonResolver{}
|
||||
@@ -0,0 +1,52 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestResolverIsBogon(t *testing.T) {
|
||||
if resolver.IsBogon("antani") != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if resolver.IsBogon("127.0.0.1") != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if resolver.IsBogon("1.1.1.1") != false {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if resolver.IsBogon("10.0.1.1") != true {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBogonAwareResolverWithBogon(t *testing.T) {
|
||||
r := resolver.BogonResolver{
|
||||
Resolver: resolver.NewFakeResolverWithResult([]string{"127.0.0.1"}),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if !errors.Is(err, errorx.ErrDNSBogon) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "127.0.0.1" {
|
||||
t.Fatal("expected to see address here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBogonAwareResolverWithoutBogon(t *testing.T) {
|
||||
orig := []string{"8.8.8.8"}
|
||||
r := resolver.BogonResolver{
|
||||
Resolver: resolver.NewFakeResolverWithResult(orig),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != len(orig) || addrs[0] != orig[0] {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CacheResolver is a resolver that caches successful replies.
|
||||
type CacheResolver struct {
|
||||
ReadOnly bool
|
||||
Resolver
|
||||
mu sync.Mutex
|
||||
cache map[string][]string
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r *CacheResolver) LookupHost(
|
||||
ctx context.Context, hostname string) ([]string, error) {
|
||||
if entry := r.Get(hostname); entry != nil {
|
||||
return entry, nil
|
||||
}
|
||||
entry, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.ReadOnly == false {
|
||||
r.Set(hostname, entry)
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// Get gets the currently configured entry for domain, or nil
|
||||
func (r *CacheResolver) Get(domain string) []string {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.cache[domain]
|
||||
}
|
||||
|
||||
// Set allows to pre-populate the cache
|
||||
func (r *CacheResolver) Set(domain string, addresses []string) {
|
||||
r.mu.Lock()
|
||||
if r.cache == nil {
|
||||
r.cache = make(map[string][]string)
|
||||
}
|
||||
r.cache[domain] = addresses
|
||||
r.mu.Unlock()
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestCacheFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Err: expected,
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addrs here")
|
||||
}
|
||||
if cache.Get("www.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheHitSuccess(t *testing.T) {
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Err: errors.New("mocked error"),
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r}
|
||||
cache.Set("dns.google.com", []string{"8.8.8.8"})
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMissSuccess(t *testing.T) {
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Result: []string{"8.8.8.8"},
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.Get("dns.google.com")[0] != "8.8.8.8" {
|
||||
t.Fatal("expected full cache here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheReadonlySuccess(t *testing.T) {
|
||||
var r resolver.Resolver = resolver.FakeResolver{
|
||||
Result: []string{"8.8.8.8"},
|
||||
}
|
||||
cache := &resolver.CacheResolver{Resolver: r, ReadOnly: true}
|
||||
addrs, err := cache.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
if cache.Get("dns.google.com") != nil {
|
||||
t.Fatal("expected empty cache here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// ChainResolver is a chain resolver. The primary resolver is used first and, if that
|
||||
// fails, we then attempt with the secondary resolver.
|
||||
type ChainResolver struct {
|
||||
Primary Resolver
|
||||
Secondary Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (c ChainResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := c.Primary.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
addrs, err = c.Secondary.LookupHost(ctx, hostname)
|
||||
}
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (c ChainResolver) Network() string {
|
||||
return "chain"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (c ChainResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ Resolver = ChainResolver{}
|
||||
@@ -0,0 +1,28 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestChainLookupHost(t *testing.T) {
|
||||
r := resolver.ChainResolver{
|
||||
Primary: resolver.NewFakeResolverThatFails(),
|
||||
Secondary: resolver.SystemResolver{},
|
||||
}
|
||||
if r.Address() != "" {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
if r.Network() != "chain" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expect non nil return value here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// The Decoder decodes a DNS reply into A or AAAA entries. It will use the
|
||||
// provided qtype and only look for mathing entries. It will return error if
|
||||
// there are no entries for the requested qtype inside the reply.
|
||||
type Decoder interface {
|
||||
Decode(qtype uint16, data []byte) ([]string, error)
|
||||
}
|
||||
|
||||
// MiekgDecoder uses github.com/miekg/dns to implement the Decoder.
|
||||
type MiekgDecoder struct{}
|
||||
|
||||
// Decode implements Decoder.Decode.
|
||||
func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
||||
reply := new(dns.Msg)
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bassosimone): map more errors to net.DNSError names
|
||||
switch reply.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
case dns.RcodeNameError:
|
||||
return nil, errors.New("ooniresolver: no such host")
|
||||
default:
|
||||
return nil, errors.New("ooniresolver: query failed")
|
||||
}
|
||||
var addrs []string
|
||||
for _, answer := range reply.Answer {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
if rra, ok := answer.(*dns.AAAA); ok {
|
||||
ip := rra.AAAA
|
||||
addrs = append(addrs, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(addrs) <= 0 {
|
||||
return nil, errors.New("ooniresolver: no response returned")
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
var _ Decoder = MiekgDecoder{}
|
||||
@@ -0,0 +1,113 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDecoderUnpackError(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNXDOMAIN(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeNameError))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderOtherError(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeRefused))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderNoAddress(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeA(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if data[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecodeAAAA(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if data[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAReply(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
|
||||
)
|
||||
|
||||
// DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over
|
||||
// an HTTP/HTTPS channel provided by URL using the Do function.
|
||||
type DNSOverHTTPS struct {
|
||||
Do func(req *http.Request) (*http.Response, error)
|
||||
URL string
|
||||
HostOverride string
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPS creates a new DNSOverHTTP instance from the
|
||||
// specified http.Client and URL, as a convenience.
|
||||
func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS {
|
||||
return NewDNSOverHTTPSWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
||||
// NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that
|
||||
// it's creating a resolver where we use the specified host.
|
||||
func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS {
|
||||
return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = t.HostOverride
|
||||
req.Header.Set("user-agent", httpheader.UserAgent())
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
// TODO(bassosimone): we should map the status code to a
|
||||
// proper Error in the DNS context.
|
||||
return nil, errors.New("doh: server returned error")
|
||||
}
|
||||
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||
return nil, errors.New("doh: invalid content-type")
|
||||
}
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467
|
||||
func (t DNSOverHTTPS) RequiresPadding() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverHTTPS) Network() string {
|
||||
return "doh"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverHTTPS) Address() string {
|
||||
return t.URL
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverHTTPS{}
|
||||
@@ -0,0 +1,165 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/httpheader"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return nil, expected
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSSuccess(t *testing.T) {
|
||||
body := []byte("AAA")
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, body) {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPTransportOK(t *testing.T) {
|
||||
const queryURL = "https://cloudflare-dns.com/dns-query"
|
||||
txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
|
||||
if txp.Network() != "doh" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("should require padding")
|
||||
}
|
||||
if txp.Address() != queryURL {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
var correct bool
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Header.Get("User-Agent") == httpheader.UserAgent()
|
||||
return nil, expected
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct user agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverHTTPSHostOverride(t *testing.T) {
|
||||
var correct bool
|
||||
expected := errors.New("mocked error")
|
||||
|
||||
hostOverride := "test.com"
|
||||
txp := resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
correct = req.Host == hostOverride
|
||||
return nil, expected
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct host override")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DialContextFunc is a generic function for dialing a connection.
|
||||
type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
|
||||
// and NewDNSOverTLS to create specific instances that use plaintext
|
||||
// queries or encrypted queries over TLS.
|
||||
//
|
||||
// As a known bug, this implementation always creates a new connection
|
||||
// for each incoming query, thus increasing the response delay.
|
||||
type DNSOverTCP struct {
|
||||
dial DialContextFunc
|
||||
address string
|
||||
network string
|
||||
requiresPadding bool
|
||||
}
|
||||
|
||||
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
||||
func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP {
|
||||
return DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
requiresPadding: false,
|
||||
}
|
||||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP {
|
||||
return DNSOverTCP{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "dot",
|
||||
requiresPadding: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
}
|
||||
conn, err := t.dial(ctx, "tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Write request
|
||||
buf := []byte{byte(len(query) >> 8)}
|
||||
buf = append(buf, byte(len(query)))
|
||||
buf = append(buf, query...)
|
||||
if _, err = conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
header := make([]byte, 2)
|
||||
if _, err = io.ReadFull(conn, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := int(header[0])<<8 | int(header[1])
|
||||
reply := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
// according to RFC8467.
|
||||
func (t DNSOverTCP) RequiresPadding() bool {
|
||||
return t.requiresPadding
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverTCP) Network() string {
|
||||
return t.network
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverTCP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverTCP{}
|
||||
@@ -0,0 +1,146 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportDialFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Err: mocked}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
WriteError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(2)},
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportAllGood(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
ReadData: []byte{byte(0), byte(1), byte(1)},
|
||||
}}
|
||||
txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 || reply[0] != 1 {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTCPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "tcp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverTLSTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "dot" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer is the network dialer interface assumed by this package.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// DNSOverUDP is a DNS over UDP RoundTripper.
|
||||
type DNSOverUDP struct {
|
||||
dialer Dialer
|
||||
address string
|
||||
}
|
||||
|
||||
// NewDNSOverUDP creates a DNSOverUDP instance.
|
||||
func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP {
|
||||
return DNSOverUDP{dialer: dialer, address: address}
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip.
|
||||
func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
// Use five seconds timeout like Bionic does. See
|
||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply := make([]byte, 1<<17)
|
||||
var n int
|
||||
n, err = conn.Read(reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply[:n], nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467
|
||||
func (t DNSOverUDP) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Network returns the transport network (e.g., doh, dot)
|
||||
func (t DNSOverUDP) Network() string {
|
||||
return "udp"
|
||||
}
|
||||
|
||||
// Address returns the upstream server address.
|
||||
func (t DNSOverUDP) Address() string {
|
||||
return t.address
|
||||
}
|
||||
|
||||
var _ RoundTripper = DNSOverUDP{}
|
||||
@@ -0,0 +1,107 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPDialFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPSetDeadlineError(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
SetDeadlineError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPWriteFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
WriteError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadFailure(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{
|
||||
ReadError: mocked,
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPReadSuccess(t *testing.T) {
|
||||
const expected = 17
|
||||
txp := resolver.NewDNSOverUDP(
|
||||
resolver.FakeDialer{
|
||||
Conn: &resolver.FakeConn{ReadData: make([]byte, 17)},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != expected {
|
||||
t.Fatal("expected non nil data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSOverUDPTransportOK(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := resolver.NewDNSOverUDP(&net.Dialer{}, address)
|
||||
if txp.RequiresPadding() != false {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
if txp.Network() != "udp" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
if txp.Address() != address {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
// EmitterTransport is a RoundTripper that emits events when they occur.
|
||||
type EmitterTransport struct {
|
||||
RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp EmitterTransport) RoundTrip(ctx context.Context, querydata []byte) ([]byte, error) {
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
DNSQuery: &modelx.DNSQueryEvent{
|
||||
Data: querydata,
|
||||
DialID: dialid.ContextDialID(ctx),
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
},
|
||||
})
|
||||
replydata, err := txp.RoundTripper.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
DNSReply: &modelx.DNSReplyEvent{
|
||||
Data: replydata,
|
||||
DialID: dialid.ContextDialID(ctx),
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
},
|
||||
})
|
||||
return replydata, nil
|
||||
}
|
||||
|
||||
// EmitterResolver is a resolver that emits events
|
||||
type EmitterResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r EmitterResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var (
|
||||
network string
|
||||
address string
|
||||
)
|
||||
type queryableResolver interface {
|
||||
Transport() RoundTripper
|
||||
}
|
||||
if qr, ok := r.Resolver.(queryableResolver); ok {
|
||||
txp := qr.Transport()
|
||||
network, address = txp.Network(), txp.Address()
|
||||
}
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
txID := transactionid.ContextTransactionID(ctx)
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
ResolveStart: &modelx.ResolveStartEvent{
|
||||
DialID: dialID,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Hostname: hostname,
|
||||
TransactionID: txID,
|
||||
TransportAddress: address,
|
||||
TransportNetwork: network,
|
||||
},
|
||||
})
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
ResolveDone: &modelx.ResolveDoneEvent{
|
||||
Addresses: addrs,
|
||||
DialID: dialID,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Error: err,
|
||||
Hostname: hostname,
|
||||
TransactionID: txID,
|
||||
TransportAddress: address,
|
||||
TransportNetwork: network,
|
||||
},
|
||||
})
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ RoundTripper = EmitterTransport{}
|
||||
var _ Resolver = EmitterResolver{}
|
||||
@@ -0,0 +1,220 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestEmitterTransportSuccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
|
||||
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"),
|
||||
}}
|
||||
e := resolver.MiekgEncoder{}
|
||||
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
replydata, err := txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[0].DNSQuery == nil {
|
||||
t.Fatal("missing DNSQuery field")
|
||||
}
|
||||
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
|
||||
t.Fatal("invalid query data")
|
||||
}
|
||||
if events[0].DNSQuery.DialID == 0 {
|
||||
t.Fatal("invalid query DialID")
|
||||
}
|
||||
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
if events[1].DNSReply == nil {
|
||||
t.Fatal("missing DNSReply field")
|
||||
}
|
||||
if !bytes.Equal(events[1].DNSReply.Data, replydata) {
|
||||
t.Fatal("missing reply data")
|
||||
}
|
||||
if events[1].DNSReply.DialID != 1 {
|
||||
t.Fatal("invalid query DialID")
|
||||
}
|
||||
if events[1].DNSReply.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterTransportFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
mocked := errors.New("mocked error")
|
||||
txp := resolver.EmitterTransport{RoundTripper: resolver.FakeTransport{
|
||||
Err: mocked,
|
||||
}}
|
||||
e := resolver.MiekgEncoder{}
|
||||
querydata, err := e.Encode("www.google.com", dns.TypeAAAA, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
replydata, err := txp.RoundTrip(ctx, querydata)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if replydata != nil {
|
||||
t.Fatal("expected nil replydata")
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[0].DNSQuery == nil {
|
||||
t.Fatal("missing DNSQuery field")
|
||||
}
|
||||
if !bytes.Equal(events[0].DNSQuery.Data, querydata) {
|
||||
t.Fatal("invalid query data")
|
||||
}
|
||||
if events[0].DNSQuery.DialID == 0 {
|
||||
t.Fatal("invalid query DialID")
|
||||
}
|
||||
if events[0].DNSQuery.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterResolverFailure(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver(
|
||||
resolver.DNSOverHTTPS{
|
||||
Do: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
URL: "https://dns.google.com/",
|
||||
},
|
||||
)}
|
||||
replies, err := r.LookupHost(ctx, "www.google.com")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if replies != nil {
|
||||
t.Fatal("expected nil replies")
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[0].ResolveStart == nil {
|
||||
t.Fatal("missing ResolveStart field")
|
||||
}
|
||||
if events[0].ResolveStart.DialID == 0 {
|
||||
t.Fatal("invalid DialID")
|
||||
}
|
||||
if events[0].ResolveStart.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
if events[0].ResolveStart.Hostname != "www.google.com" {
|
||||
t.Fatal("invalid Hostname")
|
||||
}
|
||||
if events[0].ResolveStart.TransactionID == 0 {
|
||||
t.Fatal("invalid TransactionID")
|
||||
}
|
||||
if events[0].ResolveStart.TransportAddress != "https://dns.google.com/" {
|
||||
t.Fatal("invalid TransportAddress")
|
||||
}
|
||||
if events[0].ResolveStart.TransportNetwork != "doh" {
|
||||
t.Fatal("invalid TransportNetwork")
|
||||
}
|
||||
if events[1].ResolveDone == nil {
|
||||
t.Fatal("missing ResolveDone field")
|
||||
}
|
||||
if events[1].ResolveDone.DialID == 0 {
|
||||
t.Fatal("invalid DialID")
|
||||
}
|
||||
if events[1].ResolveDone.DurationSinceBeginning <= 0 {
|
||||
t.Fatal("invalid duration since beginning")
|
||||
}
|
||||
if events[1].ResolveDone.Error != io.EOF {
|
||||
t.Fatal("invalid Error")
|
||||
}
|
||||
if events[1].ResolveDone.Hostname != "www.google.com" {
|
||||
t.Fatal("invalid Hostname")
|
||||
}
|
||||
if events[1].ResolveDone.TransactionID == 0 {
|
||||
t.Fatal("invalid TransactionID")
|
||||
}
|
||||
if events[1].ResolveDone.TransportAddress != "https://dns.google.com/" {
|
||||
t.Fatal("invalid TransportAddress")
|
||||
}
|
||||
if events[1].ResolveDone.TransportNetwork != "doh" {
|
||||
t.Fatal("invalid TransportNetwork")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterResolverSuccess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
handler := &handlers.SavingHandler{}
|
||||
root := &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, root)
|
||||
r := resolver.EmitterResolver{Resolver: resolver.NewFakeResolverWithResult(
|
||||
[]string{"8.8.8.8"},
|
||||
)}
|
||||
replies, err := r.LookupHost(ctx, "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(replies) != 1 {
|
||||
t.Fatal("expected a single replies")
|
||||
}
|
||||
events := handler.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if events[1].ResolveDone == nil {
|
||||
t.Fatal("missing ResolveDone field")
|
||||
}
|
||||
if events[1].ResolveDone.Addresses[0] != "8.8.8.8" {
|
||||
t.Fatal("invalid Addresses")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package resolver
|
||||
|
||||
import "github.com/miekg/dns"
|
||||
|
||||
// The Encoder encodes DNS queries to bytes
|
||||
type Encoder interface {
|
||||
Encode(domain string, qtype uint16, padding bool) ([]byte, error)
|
||||
}
|
||||
|
||||
// MiekgEncoder uses github.com/miekg/dns to implement the Encoder.
|
||||
type MiekgEncoder struct{}
|
||||
|
||||
const (
|
||||
// PaddingDesiredBlockSize is the size that the padded query should be multiple of
|
||||
PaddingDesiredBlockSize = 128
|
||||
|
||||
// EDNS0MaxResponseSize is the maximum response size for EDNS0
|
||||
EDNS0MaxResponseSize = 4096
|
||||
|
||||
// DNSSECEnabled turns on support for DNSSEC when using EDNS0
|
||||
DNSSECEnabled = true
|
||||
)
|
||||
|
||||
// Encode implements Encoder.Encode
|
||||
func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
if padding {
|
||||
query.SetEdns0(EDNS0MaxResponseSize, DNSSECEnabled)
|
||||
// Clients SHOULD pad queries to the closest multiple of
|
||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||
// length by the size of the option (i.e. 4 octets). The
|
||||
// cast to uint is necessary to make the modulus operation
|
||||
// work as intended when the desiredBlockSize is smaller
|
||||
// than (query.Len()+4) ¯\_(ツ)_/¯.
|
||||
remainder := (PaddingDesiredBlockSize - uint(query.Len()+4)) % PaddingDesiredBlockSize
|
||||
opt := new(dns.EDNS0_PADDING)
|
||||
opt.Padding = make([]byte, remainder)
|
||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||
}
|
||||
return query.Pack()
|
||||
}
|
||||
|
||||
var _ Encoder = MiekgEncoder{}
|
||||
@@ -0,0 +1,99 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestEncoderEncodeA(t *testing.T) {
|
||||
e := resolver.MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func TestEncoderEncodeAAAA(t *testing.T) {
|
||||
e := resolver.MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
validate(t, data, byte(dns.TypeA))
|
||||
}
|
||||
|
||||
func validate(t *testing.T, data []byte, qtype byte) {
|
||||
// skipping over the query ID
|
||||
if data[2] != 1 {
|
||||
t.Fatal("FLAGS should only have RD set")
|
||||
}
|
||||
if data[3] != 0 {
|
||||
t.Fatal("RA|Z|Rcode should be zero")
|
||||
}
|
||||
if data[4] != 0 || data[5] != 1 {
|
||||
t.Fatal("QCOUNT high should be one")
|
||||
}
|
||||
if data[6] != 0 || data[7] != 0 {
|
||||
t.Fatal("ANCOUNT should be zero")
|
||||
}
|
||||
if data[8] != 0 || data[9] != 0 {
|
||||
t.Fatal("NSCOUNT should be zero")
|
||||
}
|
||||
if data[10] != 0 || data[11] != 0 {
|
||||
t.Fatal("ARCOUNT should be zero")
|
||||
}
|
||||
t.Log(data[12])
|
||||
if data[12] != 1 || data[13] != byte('x') {
|
||||
t.Fatal("The name does not contain 1:x")
|
||||
}
|
||||
if data[14] != 3 || data[15] != byte('o') || data[16] != byte('r') || data[17] != byte('g') {
|
||||
t.Fatal("The name does not containg 3:org")
|
||||
}
|
||||
if data[18] != 0 {
|
||||
t.Fatal("The name does not terminate where expected")
|
||||
}
|
||||
if data[19] != 0 && data[20] != qtype {
|
||||
t.Fatal("The query is not for the expected type")
|
||||
}
|
||||
if data[21] != 0 && data[22] != 1 {
|
||||
t.Fatal("The query is not IN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderPadding(t *testing.T) {
|
||||
// The purpose of this unit test is to make sure that for a wide
|
||||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := resolver.MiekgEncoder{}
|
||||
data, err := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
// us to generate such name and we are not going to send
|
||||
// it on the wire. Also, we check below that the query that
|
||||
// we generate is long enough, so we should be good.
|
||||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||
dns.TypeA, padding,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return len(data)
|
||||
}
|
||||
for domainlen := 1; domainlen <= 4000; domainlen++ {
|
||||
vanillalen := getquerylen(domainlen, false)
|
||||
paddedlen := getquerylen(domainlen, true)
|
||||
if vanillalen < domainlen {
|
||||
t.Fatal("vanillalen is smaller than domainlen")
|
||||
}
|
||||
if (paddedlen % resolver.PaddingDesiredBlockSize) != 0 {
|
||||
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
|
||||
}
|
||||
if paddedlen < vanillalen {
|
||||
t.Fatal("paddedlen is smaller than vanillalen")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// ErrorWrapperResolver is a Resolver that knows about wrapping errors.
|
||||
type ErrorWrapperResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r ErrorWrapperResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
txID := transactionid.ContextTransactionID(ctx)
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
DialID: dialID,
|
||||
Error: err,
|
||||
Operation: errorx.ResolveOperation,
|
||||
TransactionID: txID,
|
||||
}.MaybeBuild()
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ Resolver = ErrorWrapperResolver{}
|
||||
@@ -0,0 +1,58 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestErrorWrapperSuccess(t *testing.T) {
|
||||
orig := []string{"8.8.8.8"}
|
||||
r := resolver.ErrorWrapperResolver{
|
||||
Resolver: resolver.NewFakeResolverWithResult(orig),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != len(orig) || addrs[0] != orig[0] {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapperFailure(t *testing.T) {
|
||||
r := resolver.ErrorWrapperResolver{
|
||||
Resolver: resolver.NewFakeResolverThatFails(),
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx = dialid.WithDialID(ctx)
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
addrs, err := r.LookupHost(ctx, "dns.google.com")
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addr here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot properly cast the returned error")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureDNSNXDOMAINError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
if errWrapper.ConnID != 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if errWrapper.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
if errWrapper.TransactionID == 0 {
|
||||
t.Fatal("unexpected TransactionID")
|
||||
}
|
||||
if errWrapper.Operation != errorx.ResolveOperation {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
)
|
||||
|
||||
type FakeDialer struct {
|
||||
Conn net.Conn
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return d.Conn, d.Err
|
||||
}
|
||||
|
||||
type FakeConn struct {
|
||||
ReadError error
|
||||
ReadData []byte
|
||||
SetDeadlineError error
|
||||
SetReadDeadlineError error
|
||||
SetWriteDeadlineError error
|
||||
WriteError error
|
||||
}
|
||||
|
||||
func (c *FakeConn) Read(b []byte) (int, error) {
|
||||
if len(c.ReadData) > 0 {
|
||||
n := copy(b, c.ReadData)
|
||||
c.ReadData = c.ReadData[n:]
|
||||
return n, nil
|
||||
}
|
||||
if c.ReadError != nil {
|
||||
return 0, c.ReadError
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *FakeConn) Write(b []byte) (n int, err error) {
|
||||
if c.WriteError != nil {
|
||||
return 0, c.WriteError
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) Close() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (*FakeConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
|
||||
return c.SetDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
|
||||
return c.SetReadDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
|
||||
return c.SetWriteDeadlineError
|
||||
}
|
||||
|
||||
type FakeTransport struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
return ft.Data, ft.Err
|
||||
}
|
||||
|
||||
func (ft FakeTransport) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (ft FakeTransport) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ft FakeTransport) Network() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
type FakeEncoder struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return fe.Data, fe.Err
|
||||
}
|
||||
|
||||
type FakeResolver struct {
|
||||
NumFailures *atomicx.Int64
|
||||
Err error
|
||||
Result []string
|
||||
}
|
||||
|
||||
func NewFakeResolverThatFails() FakeResolver {
|
||||
return FakeResolver{NumFailures: atomicx.NewInt64(), Err: errNotFound}
|
||||
}
|
||||
|
||||
func NewFakeResolverWithResult(r []string) FakeResolver {
|
||||
return FakeResolver{NumFailures: atomicx.NewInt64(), Result: r}
|
||||
}
|
||||
|
||||
var errNotFound = &net.DNSError{
|
||||
Err: "no such host",
|
||||
}
|
||||
|
||||
func (c FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
if c.Err != nil {
|
||||
if c.NumFailures != nil {
|
||||
c.NumFailures.Add(1)
|
||||
}
|
||||
return nil, c.Err
|
||||
}
|
||||
return c.Result, nil
|
||||
}
|
||||
|
||||
func (c FakeResolver) Network() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (c FakeResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ Resolver = FakeResolver{}
|
||||
@@ -0,0 +1,76 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func GenReplyError(t *testing.T, code int) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetRcode(query, code)
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func GenReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Qtype: qtype,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
for _, ip := range ips {
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
reply.Answer = append(reply.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: net.ParseIP(ip),
|
||||
})
|
||||
case dns.TypeAAAA:
|
||||
reply.Answer = append(reply.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
AAAA: net.ParseIP(ip),
|
||||
})
|
||||
}
|
||||
}
|
||||
data, err := reply.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return data
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// IDNAResolver is to support resolving Internationalized Domain Names.
|
||||
// See RFC3492 for more information.
|
||||
type IDNAResolver struct {
|
||||
Resolver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r IDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
host, err := idna.ToASCII(hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Resolver.LookupHost(ctx, host)
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r IDNAResolver) Network() string {
|
||||
return "idna"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address.
|
||||
func (r IDNAResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ Resolver = IDNAResolver{}
|
||||
@@ -0,0 +1,76 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
var ErrUnexpectedPunycode = errors.New("unexpected punycode value")
|
||||
|
||||
type CheckIDNAResolver struct {
|
||||
Addresses []string
|
||||
Error error
|
||||
Expect string
|
||||
}
|
||||
|
||||
func (resolv CheckIDNAResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if resolv.Error != nil {
|
||||
return nil, resolv.Error
|
||||
}
|
||||
if hostname != resolv.Expect {
|
||||
return nil, ErrUnexpectedPunycode
|
||||
}
|
||||
return resolv.Addresses, nil
|
||||
}
|
||||
|
||||
func (r CheckIDNAResolver) Network() string {
|
||||
return "checkidna"
|
||||
}
|
||||
|
||||
func (r CheckIDNAResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func TestIDNAResolverSuccess(t *testing.T) {
|
||||
expectedIPs := []string{"77.88.55.66"}
|
||||
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
|
||||
Addresses: expectedIPs,
|
||||
Expect: "xn--d1acpjx3f.xn--p1ai",
|
||||
}}
|
||||
addrs, err := resolv.LookupHost(context.Background(), "яндекс.рф")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expectedIPs, addrs); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDNAResolverFailure(t *testing.T) {
|
||||
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{
|
||||
Error: errors.New("we should not arrive here"),
|
||||
}}
|
||||
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
|
||||
addrs, err := resolv.LookupHost(context.Background(), "xn--0000h")
|
||||
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDNAResolverTransportOK(t *testing.T) {
|
||||
resolv := resolver.IDNAResolver{Resolver: CheckIDNAResolver{}}
|
||||
if resolv.Network() != "idna" {
|
||||
t.Fatal("invalid network")
|
||||
}
|
||||
if resolv.Address() != "" {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
func testresolverquick(t *testing.T, reso resolver.Resolver) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
reso = resolver.LoggingResolver{Logger: log.Log, Resolver: reso}
|
||||
addrs, err := reso.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil addrs here")
|
||||
}
|
||||
var foundquad8 bool
|
||||
for _, addr := range addrs {
|
||||
// See https://github.com/ooni/probe-cli/v3/internal/engine/pull/954/checks?check_run_id=1182269025
|
||||
if addr == "8.8.8.8" || addr == "2001:4860:4860::8888" {
|
||||
foundquad8 = true
|
||||
}
|
||||
}
|
||||
if !foundquad8 {
|
||||
t.Fatalf("did not find 8.8.8.8 in ouput; output=%+v", addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensuring we can handle Internationalized Domain Names (IDNs) without issues
|
||||
func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
reso = resolver.IDNAResolver{
|
||||
resolver.LoggingResolver{Logger: log.Log, Resolver: reso},
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "яндекс.рф")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil addrs here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolverSystem(t *testing.T) {
|
||||
reso := resolver.SystemResolver{}
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverUDPAddress(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverUDP(new(net.Dialer), "8.8.8.8:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverUDPDomain(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverUDP(new(net.Dialer), "dns.google.com:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverTCPAddress(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "8.8.8.8:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverTCPDomain(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTCP(new(net.Dialer).DialContext, "dns.google.com:53"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoTAddress(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoTDomain(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverTLS(resolver.DialTLSContext, "dns.google.com:853"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoH(t *testing.T) {
|
||||
reso := resolver.NewSerialResolver(
|
||||
resolver.NewDNSOverHTTPS(http.DefaultClient, "https://cloudflare-dns.com/dns-query"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger is the logger assumed by this package
|
||||
type Logger interface {
|
||||
Debugf(format string, v ...interface{})
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
// LoggingResolver is a resolver that emits events
|
||||
type LoggingResolver struct {
|
||||
Resolver
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r LoggingResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
r.Logger.Debugf("resolve %s...", hostname)
|
||||
start := time.Now()
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
stop := time.Now()
|
||||
r.Logger.Debugf("resolve %s... (%+v, %+v) in %s", hostname, addrs, err, stop.Sub(start))
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
var _ Resolver = LoggingResolver{}
|
||||
@@ -0,0 +1,23 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestLoggingResolver(t *testing.T) {
|
||||
r := resolver.LoggingResolver{
|
||||
Logger: log.Log,
|
||||
Resolver: resolver.NewFakeResolverThatFails(),
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.google.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil addr here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Resolver is a DNS resolver. The *net.Resolver used by Go implements
|
||||
// this interface, but other implementations are possible.
|
||||
type Resolver interface {
|
||||
// LookupHost resolves a hostname to a list of IP addresses.
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
|
||||
// Network returns the network being used by the resolver
|
||||
Network() string
|
||||
|
||||
// Address returns the address being used by the resolver
|
||||
Address() string
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// SaverResolver is a resolver that saves events
|
||||
type SaverResolver struct {
|
||||
Resolver
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (r SaverResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
start := time.Now()
|
||||
r.Saver.Write(trace.Event{
|
||||
Address: r.Resolver.Address(),
|
||||
Hostname: hostname,
|
||||
Name: "resolve_start",
|
||||
Proto: r.Resolver.Network(),
|
||||
Time: start,
|
||||
})
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
stop := time.Now()
|
||||
r.Saver.Write(trace.Event{
|
||||
Addresses: addrs,
|
||||
Address: r.Resolver.Address(),
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Hostname: hostname,
|
||||
Name: "resolve_done",
|
||||
Proto: r.Resolver.Network(),
|
||||
Time: stop,
|
||||
})
|
||||
return addrs, err
|
||||
}
|
||||
|
||||
// SaverDNSTransport is a DNS transport that saves events
|
||||
type SaverDNSTransport struct {
|
||||
RoundTripper
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
start := time.Now()
|
||||
txp.Saver.Write(trace.Event{
|
||||
Address: txp.Address(),
|
||||
DNSQuery: query,
|
||||
Name: "dns_round_trip_start",
|
||||
Proto: txp.Network(),
|
||||
Time: start,
|
||||
})
|
||||
reply, err := txp.RoundTripper.RoundTrip(ctx, query)
|
||||
stop := time.Now()
|
||||
txp.Saver.Write(trace.Event{
|
||||
Address: txp.Address(),
|
||||
DNSQuery: query,
|
||||
DNSReply: reply,
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: "dns_round_trip_done",
|
||||
Proto: txp.Network(),
|
||||
Time: stop,
|
||||
})
|
||||
return reply, err
|
||||
}
|
||||
|
||||
var _ Resolver = SaverResolver{}
|
||||
var _ RoundTripper = SaverDNSTransport{}
|
||||
@@ -0,0 +1,211 @@
|
||||
package resolver_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSaverResolverFailure(t *testing.T) {
|
||||
expected := errors.New("no such host")
|
||||
saver := &trace.Saver{}
|
||||
reso := resolver.SaverResolver{
|
||||
Resolver: resolver.FakeResolver{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if ev[0].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[0].Name != "resolve_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if ev[1].Addresses != nil {
|
||||
t.Fatal("unexpected Addresses")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if !errors.Is(ev[1].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[1].Name != "resolve_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverResolverSuccess(t *testing.T) {
|
||||
expected := []string{"8.8.8.8", "8.8.4.4"}
|
||||
saver := &trace.Saver{}
|
||||
reso := resolver.SaverResolver{
|
||||
Resolver: resolver.FakeResolver{
|
||||
Result: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
addrs, err := reso.LookupHost(context.Background(), "www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal("expected nil error here")
|
||||
}
|
||||
if !reflect.DeepEqual(addrs, expected) {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if ev[0].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[0].Name != "resolve_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[1].Addresses, expected) {
|
||||
t.Fatal("unexpected Addresses")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Hostname != "www.google.com" {
|
||||
t.Fatal("unexpected Hostname")
|
||||
}
|
||||
if ev[1].Name != "resolve_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverDNSTransportFailure(t *testing.T) {
|
||||
expected := errors.New("no such host")
|
||||
saver := &trace.Saver{}
|
||||
txp := resolver.SaverDNSTransport{
|
||||
RoundTripper: resolver.FakeTransport{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
query := []byte("abc")
|
||||
reply, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[0].Name != "dns_round_trip_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[1].DNSReply != nil {
|
||||
t.Fatal("unexpected DNSReply")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if !errors.Is(ev[1].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "dns_round_trip_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverDNSTransportSuccess(t *testing.T) {
|
||||
expected := []byte("def")
|
||||
saver := &trace.Saver{}
|
||||
txp := resolver.SaverDNSTransport{
|
||||
RoundTripper: resolver.FakeTransport{
|
||||
Data: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
query := []byte("abc")
|
||||
reply, err := txp.RoundTrip(context.Background(), query)
|
||||
if err != nil {
|
||||
t.Fatal("we expected nil error here")
|
||||
}
|
||||
if !bytes.Equal(reply, expected) {
|
||||
t.Fatal("expected another reply here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[0].Name != "dns_round_trip_start" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSReply, expected) {
|
||||
t.Fatal("unexpected DNSReply")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "dns_round_trip_done" {
|
||||
t.Fatal("unexpected name")
|
||||
}
|
||||
if !ev[1].Time.After(ev[0].Time) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
)
|
||||
|
||||
// RoundTripper represents an abstract DNS transport.
|
||||
type RoundTripper interface {
|
||||
// RoundTrip sends a DNS query and receives the reply.
|
||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
|
||||
// RequiresPadding return true for DoH and DoT according to RFC8467
|
||||
RequiresPadding() bool
|
||||
|
||||
// Network is the network of the round tripper (e.g. "dot")
|
||||
Network() string
|
||||
|
||||
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
|
||||
Address() string
|
||||
}
|
||||
|
||||
// SerialResolver is a resolver that first issues an A query and then
|
||||
// issues an AAAA query for the requested domain.
|
||||
type SerialResolver struct {
|
||||
Encoder Encoder
|
||||
Decoder Decoder
|
||||
NumTimeouts *atomicx.Int64
|
||||
Txp RoundTripper
|
||||
}
|
||||
|
||||
// NewSerialResolver creates a new OONI Resolver instance.
|
||||
func NewSerialResolver(t RoundTripper) SerialResolver {
|
||||
return SerialResolver{
|
||||
Encoder: MiekgEncoder{},
|
||||
Decoder: MiekgDecoder{},
|
||||
NumTimeouts: atomicx.NewInt64(),
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the transport being used.
|
||||
func (r SerialResolver) Transport() RoundTripper {
|
||||
return r.Txp
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r SerialResolver) Network() string {
|
||||
return r.Txp.Network()
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (r SerialResolver) Address() string {
|
||||
return r.Txp.Address()
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var addrs []string
|
||||
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
|
||||
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
|
||||
if errA != nil && errAAAA != nil {
|
||||
return nil, errA
|
||||
}
|
||||
addrs = append(addrs, addrsA...)
|
||||
addrs = append(addrs, addrsAAAA...)
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (r SerialResolver) roundTripWithRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
var errorslist []error
|
||||
for i := 0; i < 3; i++ {
|
||||
replies, err := r.roundTrip(ctx, hostname, qtype)
|
||||
if err == nil {
|
||||
return replies, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
var operr *net.OpError
|
||||
if errors.As(err, &operr) == false || operr.Timeout() == false {
|
||||
// The first error is the one that is most likely to be caused
|
||||
// by the network. Subsequent errors are more likely to be caused
|
||||
// by context deadlines. So, the first error is attached to an
|
||||
// operation, while subsequent errors may possibly not be. If
|
||||
// so, the resulting failing operation is not correct.
|
||||
break
|
||||
}
|
||||
r.NumTimeouts.Add(1)
|
||||
}
|
||||
// bugfix: we MUST return one of the errors otherwise we confuse the
|
||||
// mechanism in errwrap that classifies the root cause operation, since
|
||||
// it would not be able to find a child with a major operation error
|
||||
return nil, errorslist[0]
|
||||
}
|
||||
|
||||
func (r SerialResolver) roundTrip(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.Decode(qtype, replydata)
|
||||
}
|
||||
|
||||
var _ Resolver = SerialResolver{}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user