chore: merge probe-engine into probe-cli (#201)
This is how I did it: 1. `git clone https://github.com/ooni/probe-engine internal/engine` 2. ``` (cd internal/engine && git describe --tags) v0.23.0 ``` 3. `nvim go.mod` (merging `go.mod` with `internal/engine/go.mod` 4. `rm -rf internal/.git internal/engine/go.{mod,sum}` 5. `git add internal/engine` 6. `find . -type f -name \*.go -exec sed -i 's@/ooni/probe-engine@/ooni/probe-cli/v3/internal/engine@g' {} \;` 7. `go build ./...` (passes) 8. `go test -race ./...` (temporary failure on RiseupVPN) 9. `go mod tidy` 10. this commit message Once this piece of work is done, we can build a new version of `ooniprobe` that is using `internal/engine` directly. We need to do more work to ensure all the other functionality in `probe-engine` (e.g. making mobile packages) are still WAI. Part of https://github.com/ooni/probe/issues/1335
This commit is contained in:
@@ -0,0 +1,400 @@
|
||||
# OONI Network Extensions
|
||||
|
||||
| Author | Simone Basso |
|
||||
|--------------|--------------|
|
||||
| Last-Updated | 2020-04-02 |
|
||||
| Status | approved |
|
||||
|
||||
## Introduction
|
||||
|
||||
OONI experiments send and/or receive network traffic to
|
||||
determine if there is blocking. We want the implementation
|
||||
of OONI experiments to be as simple as possible. We also
|
||||
_want to attribute errors to the major network or protocol
|
||||
operation that caused them_.
|
||||
|
||||
At the same time, _we want an experiment to collect as much
|
||||
low-level data as possible_. For example, we want to know
|
||||
whether and when the TLS handshake completed; what certificates
|
||||
were provided by the server; what TLS version was selected;
|
||||
and so forth. These bits of information are very useful
|
||||
to analyze a measurement and better classify it.
|
||||
|
||||
We also want to _automatically or manually run follow-up
|
||||
measurements where we change some configuration properties
|
||||
and repeat the measurement_. For example, we may want to
|
||||
configure DNS over HTTPS (DoH) and then attempt to
|
||||
fetch again an URL. Or we may want to detect whether
|
||||
there is SNI bases blocking. This package allows us to
|
||||
do that in other parts of probe-engine.
|
||||
|
||||
## Rationale
|
||||
|
||||
As we observed [ooni/probe-engine#13](
|
||||
https://github.com/ooni/probe-engine/issues/13), every
|
||||
experiment consists of two separate phases:
|
||||
|
||||
1. measurement gathering
|
||||
|
||||
2. measurement analysis
|
||||
|
||||
During measurement gathering, we perform specific actions
|
||||
that cause network data to be sent and/or received. During
|
||||
measurement analysis, we process the measurement on the
|
||||
device. For some experiments (e.g., Web Connectivity), this
|
||||
second phase also entails contacting OONI backend services
|
||||
that provide data useful to complete the analysis.
|
||||
|
||||
This package implements measurement gathering. The analysis
|
||||
is performed by other packages in probe-engine. The core
|
||||
design idea is to provide OONI-measurements-aware replacements
|
||||
for Go standard library interfaces, e.g., the
|
||||
`http.RoundTripper`. On top of that, we'll create all the
|
||||
required interfaces to achive the measurement goals mentioned above.
|
||||
|
||||
We are of course writing test templates in `probe-engine`
|
||||
anyway, because we need additional abstraction, but we can
|
||||
take advantage of the fact that the API exposed by this package
|
||||
is stable by definition, because it mimics the stdlib. Also,
|
||||
for many experiments we can collect information pertaining
|
||||
to TCP, DNS, TLS, and HTTP with a single call to `netx`.
|
||||
|
||||
This code used to live at `github.com/ooni/netx`. On 2020-03-02
|
||||
we merged github.com/ooni/netx@4f8d645bce6466bb into `probe-engine`
|
||||
because it was more practical and enabled easier refactoring.
|
||||
|
||||
## Definitions
|
||||
|
||||
Consistently with Go's terminology, we define
|
||||
_HTTP round trip_ the process where we get a request
|
||||
to send; we find a suitable connection for sending
|
||||
it, or we create one; we send headers and
|
||||
possibly body; and we receive response headers.
|
||||
|
||||
We also define _HTTP transaction_ the process starting
|
||||
with an HTTP round trip and terminating by reading
|
||||
the full response body.
|
||||
|
||||
We define _netx replacement_ a Go struct of interface that
|
||||
has the same interface of a Go standard library object
|
||||
but additionally performs measurements.
|
||||
|
||||
## Enhanced error handling
|
||||
|
||||
This library MUST wrap `error` such that:
|
||||
|
||||
1. we can classify all errors we care about; and
|
||||
|
||||
2. we can map them to major operations.
|
||||
|
||||
The `github.com/ooni/netx/modelx` MUST contain a wrapper for
|
||||
Go `error` named `ErrWrapper` that is at least like:
|
||||
|
||||
```Go
|
||||
type ErrWrapper struct {
|
||||
Failure string // error classification
|
||||
Operation string // operation that caused error
|
||||
WrappedErr error // the original error
|
||||
}
|
||||
|
||||
func (e *ErrWrapper) Error() string {
|
||||
return e.Failure
|
||||
}
|
||||
```
|
||||
|
||||
Where `Failure` is one of the errors we care about, i.e.:
|
||||
|
||||
- `connection_refused`: ECONNREFUSED
|
||||
- `connection_reset`: ECONNRESET
|
||||
- `dns_bogon_error`: detected bogon in DNS reply
|
||||
- `dns_nxdomain_error`: NXDOMAIN in DNS reply
|
||||
- `eof_error`: unexpected EOF on connection
|
||||
- `generic_timeout_error`: some timer has expired
|
||||
- `ssl_invalid_hostname`: certificate not valid for SNI
|
||||
- `ssl_unknown_autority`: cannot find CA validating certificate
|
||||
- `ssl_invalid_certificate`: e.g. certificate expired
|
||||
- `unknown_failure <string>`: any other error
|
||||
|
||||
Note that we care about bogons in DNS replies because they are
|
||||
often used to censor specific websites.
|
||||
|
||||
And where `Operation` is one of:
|
||||
|
||||
- `resolve`: domain name resolution
|
||||
- `connect`: TCP connect
|
||||
- `tls_handshake`: TLS handshake
|
||||
- `http_round_trip`: reading/writing HTTP
|
||||
|
||||
The code in this library MUST wrap returned errors such
|
||||
that we can cast back to `ErrWrapper` during the analysis
|
||||
phase, using Go 1.13 `errors` library as follows:
|
||||
|
||||
```Go
|
||||
var wrapper *modelx.ErrWrapper
|
||||
if errors.As(err, &wrapper) == true {
|
||||
// Do something with the error
|
||||
}
|
||||
```
|
||||
|
||||
## Netx replacements
|
||||
|
||||
We want to provide netx replacements for the following
|
||||
interfaces in the Go standard library:
|
||||
|
||||
1. `http.RoundTripper`
|
||||
|
||||
2. `http.Client`
|
||||
|
||||
3. `net.Dialer`
|
||||
|
||||
4. `net.Resolver`
|
||||
|
||||
Accordingly, we'll define the following interfaces in
|
||||
the `github.com/ooni/probe-engine/netx/modelx` package:
|
||||
|
||||
```Go
|
||||
type DNSResolver interface {
|
||||
LookupHost(ctx context.Context, hostname string) ([]string, error)
|
||||
}
|
||||
|
||||
type Dialer interface {
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type TLSDialer interface {
|
||||
DialTLS(network, address string) (net.Conn, error)
|
||||
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
```
|
||||
|
||||
We won't need an interface for `http.RoundTripper`
|
||||
because it is already an interface, so we'll just use it.
|
||||
|
||||
Our replacements will implement these interfaces.
|
||||
|
||||
Using an API compatible with Go's standard libary makes
|
||||
it possible to use, say, our `net.Dialer` replacement with
|
||||
other libraries. Both `http.Transport` and
|
||||
`gorilla/websocket`'s `websocket.Dialer` have
|
||||
functions like `Dial` and `DialContext` that can be
|
||||
overriden. By overriding such function pointers,
|
||||
we could use our replacements instead of the standard
|
||||
libary, thus we could collect measurements while
|
||||
using third party code to implement specific protocols.
|
||||
|
||||
Also, using interfaces allows us to combine code
|
||||
quite easily. For example, a resolver that detects
|
||||
bogons is easily implemented as a wrapper around
|
||||
another resolve that performs the real resolution.
|
||||
|
||||
## Dispatching events
|
||||
|
||||
The `github.com/ooni/netx/modelx` package will define
|
||||
an handler for low level events as:
|
||||
|
||||
```Go
|
||||
type Handler interface {
|
||||
OnMeasurement(Measurement)
|
||||
}
|
||||
```
|
||||
|
||||
We will provide a mechanism to bind a specific
|
||||
handler to a `context.Context` such that the handler
|
||||
will receive all the measurements caused by code
|
||||
using such context. This mechanism is like:
|
||||
|
||||
```Go
|
||||
type MeasurementRoot struct {
|
||||
Beginning time.Time // the "zero" time
|
||||
Handler Handler // the handler to use
|
||||
}
|
||||
```
|
||||
|
||||
You will be able to assign a `MeasurementRoot` to
|
||||
a context by using the following function:
|
||||
|
||||
```Go
|
||||
func WithMeasurementRoot(
|
||||
ctx context.Context, root *MeasurementRoot) context.Context
|
||||
```
|
||||
|
||||
which will return a clone of the original context
|
||||
that uses the `MeasurementRoot`. Pass this context to
|
||||
any method of our replacements to get measurements.
|
||||
|
||||
Given such context, or a subcontext, you can get
|
||||
back the original `MeasurementRoot` using:
|
||||
|
||||
```Go
|
||||
func ContextMeasurementRoot(ctx context.Context) *MeasurementRoot
|
||||
```
|
||||
|
||||
which will return the context `MeasurementRoot` or
|
||||
`nil` if none is set into the context. This is how our
|
||||
internal code gets access to the `MeasurementRoot`.
|
||||
|
||||
## Constructing and configuring replacements
|
||||
|
||||
The `github.com/ooni/probe-engine/netx` package MUST provide an API such
|
||||
that you can construct and configure a `net.Resolver` replacement
|
||||
as follows:
|
||||
|
||||
```Go
|
||||
r, err := netx.NewResolverWithoutHandler(dnsNetwork, dnsAddress)
|
||||
if err != nil {
|
||||
log.Fatal("cannot configure specifc resolver")
|
||||
}
|
||||
var resolver modelx.DNSResolver = r
|
||||
// now use resolver ...
|
||||
```
|
||||
|
||||
where `DNSNetwork` and `DNSAddress` configure the type
|
||||
of the resolver as follows:
|
||||
|
||||
- when `DNSNetwork` is `""` or `"system"`, `DNSAddress` does
|
||||
not matter and we use the system resolver
|
||||
|
||||
- when `DNSNetwork` is `"udp"`, `DNSAddress` is the address
|
||||
or domain name, with optional port, of the DNS server
|
||||
(e.g., `8.8.8.8:53`)
|
||||
|
||||
- when `DNSNetwork` is `"tcp"`, `DNSAddress` is the address
|
||||
or domain name, with optional port, of the DNS server
|
||||
(e.g., `8.8.8.8:53`)
|
||||
|
||||
- when `DNSNetwork` is `"dot"`, `DNSAddress` is the address
|
||||
or domain name, with optional port, of the DNS server
|
||||
(e.g., `8.8.8.8:853`)
|
||||
|
||||
- when `DNSNetwork` is `"doh"`, `DNSAddress` is the URL
|
||||
of the DNS server (e.g. `https://cloudflare-dns.com/dns-query`)
|
||||
|
||||
When the resolve is not the system one, we'll also be able
|
||||
to emit events when performing resolution. Otherwise, we'll
|
||||
just emit the `DNSResolveDone` event defined below.
|
||||
|
||||
Any resolver returned by this function may be configured to return the
|
||||
`dns_bogon_error` if any `LookupHost` lookup returns a bogon IP.
|
||||
|
||||
The package will also contain this function:
|
||||
|
||||
```Go
|
||||
func ChainResolvers(
|
||||
primary, secondary modelx.DNSResolver) modelx.DNSResolver
|
||||
```
|
||||
|
||||
where you can create a new resolver where `secondary` will be
|
||||
invoked whenever `primary` fails. This functionality allows
|
||||
us to be more resilient and bypass automatically certain types
|
||||
of censorship, e.g., a resolver returning a bogon.
|
||||
|
||||
The `github.com/ooni/probe-engine/netx` package MUST also provide an API such
|
||||
that you can construct and configure a `net.Dialer` replacement
|
||||
as follows:
|
||||
|
||||
```Go
|
||||
d := netx.NewDialerWithoutHandler()
|
||||
d.SetResolver(resolver)
|
||||
d.ForceSpecificSNI("www.kernel.org")
|
||||
d.SetCABundle("/etc/ssl/cert.pem")
|
||||
d.ForceSkipVerify()
|
||||
var dialer modelx.Dialer = d
|
||||
// now use dialer
|
||||
```
|
||||
|
||||
where `SetResolver` allows you to change the resolver,
|
||||
`ForceSpecificSNI` forces the TLS dials to use such SNI
|
||||
instead of using the provided domain, `SetCABundle`
|
||||
allows to set a specific CA bundle, and `ForceSkipVerify`
|
||||
allows to disable certificate verification. All these funcs
|
||||
MUST NOT be invoked once you're using the dialer.
|
||||
|
||||
The `github.com/ooni/probe-engine/netx` package MUST contain
|
||||
code so that we can do:
|
||||
|
||||
```Go
|
||||
t := netx.NewHTTPTransportWithProxyFunc(
|
||||
http.ProxyFromEnvironment,
|
||||
)
|
||||
t.SetResolver(resolver)
|
||||
t.ForceSpecificSNI("www.kernel.org")
|
||||
t.SetCABundle("/etc/ssl/cert.pem")
|
||||
t.ForceSkipVerify()
|
||||
var transport http.RoundTripper = t
|
||||
// now use transport
|
||||
```
|
||||
|
||||
where the functions have the same semantics as the
|
||||
namesake functions described before and the same caveats.
|
||||
|
||||
We also have syntactic sugar on top of that and legacy
|
||||
methods, but this fully describes the design.
|
||||
|
||||
## Structure of events
|
||||
|
||||
The `github.com/ooni/probe-engine/netx/modelx` will contain the
|
||||
definition of low-level events. We are interested in
|
||||
knowing the following:
|
||||
|
||||
1. the timing and result of each I/O operation.
|
||||
|
||||
2. the timing of HTTP events occurring during the
|
||||
lifecycle of an HTTP request.
|
||||
|
||||
3. the timing and result of the TLS handshake including
|
||||
the negotiated TLS version and other details such as
|
||||
what certificates the server has provided.
|
||||
|
||||
4. DNS events, e.g. queries and replies, generated
|
||||
as part of using DoT and DoH.
|
||||
|
||||
We will represent time as a `time.Duration` since the
|
||||
beginning configured either in the context or when
|
||||
constructing an object. The `modelx` package will also
|
||||
define the `Measurement` event as follows:
|
||||
|
||||
```Go
|
||||
type Measurement struct {
|
||||
Connect *ConnectEvent
|
||||
HTTPConnectionReady *HTTPConnectionReadyEvent
|
||||
HTTPRoundTripDone *HTTPRoundTripDoneEvent
|
||||
ResolveDone *ResolveDoneEvent
|
||||
TLSHandshakeDone *TLSHandshakeDoneEvent
|
||||
}
|
||||
```
|
||||
|
||||
The events above MUST always be present, but more
|
||||
events will likely be available. The structure
|
||||
will contain a pointer for every event that
|
||||
we support. The events processing code will check
|
||||
what pointer or pointers are not `nil` to known
|
||||
which event or events have occurred.
|
||||
|
||||
To simplify joining events together the following holds:
|
||||
|
||||
1. when we're establishing a new connection there is a nonzero
|
||||
`DialID` shared by `Connect` and `ResolveDone`
|
||||
|
||||
2. a new connection has a nonzero `ConnID` that is emitted
|
||||
as part of a successful `Connect` event
|
||||
|
||||
3. during an HTTP transaction there is a nonzero `TransactionID`
|
||||
shared by `HTTPConnectionReady` and `HTTPRoundTripDone`
|
||||
|
||||
4. if the TLS handshake is invoked by HTTP code it will have a
|
||||
nonzero `TrasactionID` otherwise a nonzero `ConnID`
|
||||
|
||||
5. the `HTTPConnectionReady` will also see the `ConnID`
|
||||
|
||||
6. when a transaction starts dialing, it will pass its
|
||||
`TransactionID` to `ResolveDone` and `Connect`
|
||||
|
||||
7. when we're dialing a connection for DoH, we pass the `DialID`
|
||||
to the `HTTPConnectionReady` event as well
|
||||
|
||||
Because of the following rules, it should always be possible
|
||||
to bind together events. Also, we define more events than the
|
||||
above, but they are ancillary to the above events. Also, the
|
||||
main reason why `HTTPConnectionReady` is here is because it is
|
||||
the event allowing to bind `ConnID` and `TransactionID`.
|
||||
@@ -0,0 +1,31 @@
|
||||
// Package connid contains code to generate the connectionID
|
||||
package connid
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Compute computes the connectionID from the local socket address. The zero
|
||||
// value is conventionally returned to mean "unknown".
|
||||
func Compute(network, address string) int64 {
|
||||
_, portstring, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
portnum, err := strconv.Atoi(portstring)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
if portnum < 0 || portnum > 65535 {
|
||||
return 0
|
||||
}
|
||||
result := int64(portnum)
|
||||
if strings.Contains(network, "udp") {
|
||||
result *= -1
|
||||
} else if !strings.Contains(network, "tcp") {
|
||||
result = 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package connid
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTCP(t *testing.T) {
|
||||
num := Compute("tcp", "1.2.3.4:6789")
|
||||
if num != 6789 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCP4(t *testing.T) {
|
||||
num := Compute("tcp4", "130.192.91.211:34566")
|
||||
if num != 34566 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCP6(t *testing.T) {
|
||||
num := Compute("tcp4", "[::1]:4444")
|
||||
if num != 4444 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP(t *testing.T) {
|
||||
num := Compute("udp", "1.2.3.4:6789")
|
||||
if num != -6789 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP4(t *testing.T) {
|
||||
num := Compute("udp4", "130.192.91.211:34566")
|
||||
if num != -34566 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDP6(t *testing.T) {
|
||||
num := Compute("udp6", "[::1]:4444")
|
||||
if num != -4444 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidAddress(t *testing.T) {
|
||||
num := Compute("udp6", "[::1]")
|
||||
if num != 0 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidPort(t *testing.T) {
|
||||
num := Compute("udp6", "[::1]:antani")
|
||||
if num != 0 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegativePort(t *testing.T) {
|
||||
num := Compute("udp6", "[::1]:-1")
|
||||
if num != 0 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargePort(t *testing.T) {
|
||||
num := Compute("udp6", "[::1]:65536")
|
||||
if num != 0 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidNetwork(t *testing.T) {
|
||||
num := Compute("unix", "[::1]:65531")
|
||||
if num != 0 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
// Package netx contains OONI's net extensions.
|
||||
package netx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
// Dialer performs measurements while dialing.
|
||||
type Dialer struct {
|
||||
Beginning time.Time
|
||||
Handler modelx.Handler
|
||||
Resolver modelx.DNSResolver
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
func newDialer(beginning time.Time, handler modelx.Handler) *Dialer {
|
||||
return &Dialer{
|
||||
Beginning: beginning,
|
||||
Handler: handler,
|
||||
Resolver: newResolverSystem(),
|
||||
TLSConfig: new(tls.Config),
|
||||
}
|
||||
}
|
||||
|
||||
// NewDialer creates a new Dialer instance.
|
||||
func NewDialer() *Dialer {
|
||||
return newDialer(time.Now(), handlers.NoHandler)
|
||||
}
|
||||
|
||||
// Dial creates a TCP or UDP connection. See net.Dial docs.
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func maybeWithMeasurementRoot(
|
||||
ctx context.Context, beginning time.Time, handler modelx.Handler,
|
||||
) context.Context {
|
||||
if modelx.ContextMeasurementRoot(ctx) != nil {
|
||||
return ctx
|
||||
}
|
||||
return modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
|
||||
Beginning: beginning,
|
||||
Handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// newDNSDialer creates a new DNS dialer using the following chain:
|
||||
//
|
||||
// - DNSDialer (topmost)
|
||||
// - EmitterDialer
|
||||
// - ErrorWrapperDialer
|
||||
// - TimeoutDialer
|
||||
// - ByteCountingDialer
|
||||
// - net.Dialer
|
||||
//
|
||||
// If you have others needs, manually build the chain you need.
|
||||
func newDNSDialer(resolver dialer.Resolver) dialer.DNSDialer {
|
||||
return dialer.DNSDialer{
|
||||
Dialer: dialer.EmitterDialer{
|
||||
Dialer: dialer.ErrorWrapperDialer{
|
||||
Dialer: dialer.TimeoutDialer{
|
||||
Dialer: dialer.ByteCounterDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Resolver: resolver,
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext is like Dial but the context allows to interrupt a
|
||||
// pending connection attempt at any time.
|
||||
func (d *Dialer) DialContext(
|
||||
ctx context.Context, network, address string,
|
||||
) (conn net.Conn, err error) {
|
||||
ctx = maybeWithMeasurementRoot(ctx, d.Beginning, d.Handler)
|
||||
return newDNSDialer(d.Resolver).DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialTLS is like Dial, but creates TLS connections.
|
||||
func (d *Dialer) DialTLS(network, address string) (net.Conn, error) {
|
||||
return d.DialTLSContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
// newTLSDialer creates a new TLSDialer using:
|
||||
//
|
||||
// - EmitterTLSHandshaker (topmost)
|
||||
// - ErrorWrapperTLSHandshaker
|
||||
// - TimeoutTLSHandshaker
|
||||
// - SystemTLSHandshaker
|
||||
//
|
||||
// If you have others needs, manually build the chain you need.
|
||||
func newTLSDialer(d dialer.Dialer, config *tls.Config) dialer.TLSDialer {
|
||||
return dialer.TLSDialer{
|
||||
Config: config,
|
||||
Dialer: d,
|
||||
TLSHandshaker: dialer.EmitterTLSHandshaker{
|
||||
TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{
|
||||
TLSHandshaker: dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DialTLSContext is like DialTLS, but with context
|
||||
func (d *Dialer) DialTLSContext(
|
||||
ctx context.Context, network, address string,
|
||||
) (net.Conn, error) {
|
||||
ctx = maybeWithMeasurementRoot(ctx, d.Beginning, d.Handler)
|
||||
return newTLSDialer(
|
||||
newDNSDialer(d.Resolver),
|
||||
d.TLSConfig,
|
||||
).DialTLSContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// SetCABundle configures the dialer to use a specific CA bundle. This
|
||||
// function is not goroutine safe. Make sure you call it before starting
|
||||
// to use this specific dialer.
|
||||
func (d *Dialer) SetCABundle(path string) error {
|
||||
cert, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if pool.AppendCertsFromPEM(cert) == false {
|
||||
return errors.New("AppendCertsFromPEM failed")
|
||||
}
|
||||
d.TLSConfig.RootCAs = pool
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceSpecificSNI forces using a specific SNI.
|
||||
func (d *Dialer) ForceSpecificSNI(sni string) error {
|
||||
d.TLSConfig.ServerName = sni
|
||||
return nil
|
||||
}
|
||||
|
||||
// ForceSkipVerify forces to skip certificate verification
|
||||
func (d *Dialer) ForceSkipVerify() error {
|
||||
d.TLSConfig.InsecureSkipVerify = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigureDNS configures the DNS resolver. The network argument
|
||||
// selects the type of resolver. The address argument indicates the
|
||||
// resolver address and depends on the network.
|
||||
//
|
||||
// This functionality is not goroutine safe. You should only change
|
||||
// the DNS settings before starting to use the Dialer.
|
||||
//
|
||||
// The following is a list of all the possible network values:
|
||||
//
|
||||
// - "": behaves exactly like "system"
|
||||
//
|
||||
// - "system": this indicates that Go should use the system resolver
|
||||
// and prevents us from seeing any DNS packet. The value of the
|
||||
// address parameter is ignored when using "system". If you do
|
||||
// not ConfigureDNS, this is the default resolver used.
|
||||
//
|
||||
// - "udp": indicates that we should send queries using UDP. In this
|
||||
// case the address is a host, port UDP endpoint.
|
||||
//
|
||||
// - "tcp": like "udp" but we use TCP.
|
||||
//
|
||||
// - "dot": we use DNS over TLS (DoT). In this case the address is
|
||||
// the domain name of the DoT server.
|
||||
//
|
||||
// - "doh": we use DNS over HTTPS (DoH). In this case the address is
|
||||
// the URL of the DoH server.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// d.ConfigureDNS("system", "")
|
||||
// d.ConfigureDNS("udp", "8.8.8.8:53")
|
||||
// d.ConfigureDNS("tcp", "8.8.8.8:53")
|
||||
// d.ConfigureDNS("dot", "dns.quad9.net")
|
||||
// d.ConfigureDNS("doh", "https://cloudflare-dns.com/dns-query")
|
||||
func (d *Dialer) ConfigureDNS(network, address string) error {
|
||||
r, err := newResolver(d.Beginning, d.Handler, network, address)
|
||||
if err == nil {
|
||||
d.Resolver = r
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SetResolver is a more flexible way of configuring a resolver
|
||||
// that should perhaps be used instead of ConfigureDNS.
|
||||
func (d *Dialer) SetResolver(r modelx.DNSResolver) {
|
||||
d.Resolver = r
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package netx_test
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx"
|
||||
)
|
||||
|
||||
func TestDialerDial(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
conn, err := dialer.Dial("tcp", "www.google.com:80")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialerDialWithCustomResolver(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
resolver, err := netx.NewResolver("tcp", "1.1.1.1:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dialer.SetResolver(resolver)
|
||||
conn, err := dialer.Dial("tcp", "www.google.com:80")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialerDialWithConfigureDNS(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
err := dialer.ConfigureDNS("tcp", "1.1.1.1:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialer.Dial("tcp", "www.google.com:80")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialerDialTLS(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
conn, err := dialer.DialTLS("tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialerDialTLSForceSkipVerify(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
dialer.ForceSkipVerify()
|
||||
conn, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialerSetCABundleNonexisting(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
err := dialer.SetCABundle("testdata/cacert-nonexistent.pem")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerSetCABundleInvalid(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
err := dialer.SetCABundle("testdata/cacert-invalid.pem")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerSetCABundleWAI(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
err := dialer.SetCABundle("testdata/cacert.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := dialer.DialTLS("tcp", "www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
var target x509.UnknownAuthorityError
|
||||
if errors.As(err, &target) == false {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialerForceSpecificSNI(t *testing.T) {
|
||||
dialer := netx.NewDialer()
|
||||
err := dialer.ForceSpecificSNI("www.facebook.com")
|
||||
conn, err := dialer.DialTLS("tcp", "www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
var target x509.HostnameError
|
||||
if errors.As(err, &target) == false {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil connection here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package dialid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
)
|
||||
|
||||
type contextkey struct{}
|
||||
|
||||
var id = atomicx.NewInt64()
|
||||
|
||||
// WithDialID returns a copy of ctx with DialID
|
||||
func WithDialID(ctx context.Context) context.Context {
|
||||
return context.WithValue(
|
||||
ctx, contextkey{}, id.Add(1),
|
||||
)
|
||||
}
|
||||
|
||||
// ContextDialID returns the DialID of the context, or zero
|
||||
func ContextDialID(ctx context.Context) int64 {
|
||||
id, _ := ctx.Value(contextkey{}).(int64)
|
||||
return id
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package dialid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
id := ContextDialID(ctx)
|
||||
if id != 0 {
|
||||
t.Fatal("unexpected ID for empty context")
|
||||
}
|
||||
ctx = WithDialID(ctx)
|
||||
id = ContextDialID(ctx)
|
||||
if id != 1 {
|
||||
t.Fatal("expected ID equal to 1")
|
||||
}
|
||||
ctx = WithDialID(ctx)
|
||||
id = ContextDialID(ctx)
|
||||
if id != 2 {
|
||||
t.Fatal("expected ID equal to 2")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Package handlers contains default modelx.Handler handlers.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/runtimex"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
)
|
||||
|
||||
type stdoutHandler struct{}
|
||||
|
||||
func (stdoutHandler) OnMeasurement(m modelx.Measurement) {
|
||||
data, err := json.Marshal(m)
|
||||
runtimex.PanicOnError(err, "unexpected json.Marshal failure")
|
||||
fmt.Printf("%s\n", string(data))
|
||||
}
|
||||
|
||||
// StdoutHandler is a Handler that logs on stdout.
|
||||
var StdoutHandler stdoutHandler
|
||||
|
||||
type noHandler struct{}
|
||||
|
||||
func (noHandler) OnMeasurement(m modelx.Measurement) {
|
||||
}
|
||||
|
||||
// NoHandler is a Handler that does not print anything
|
||||
var NoHandler noHandler
|
||||
|
||||
// SavingHandler saves the events it receives.
|
||||
type SavingHandler struct {
|
||||
mu sync.Mutex
|
||||
v []modelx.Measurement
|
||||
}
|
||||
|
||||
// OnMeasurement implements modelx.Handler.OnMeasurement
|
||||
func (sh *SavingHandler) OnMeasurement(ev modelx.Measurement) {
|
||||
sh.mu.Lock()
|
||||
sh.v = append(sh.v, ev)
|
||||
sh.mu.Unlock()
|
||||
}
|
||||
|
||||
// Read extracts the saved events
|
||||
func (sh *SavingHandler) Read() []modelx.Measurement {
|
||||
sh.mu.Lock()
|
||||
v := sh.v
|
||||
sh.v = nil
|
||||
sh.mu.Unlock()
|
||||
return v
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
handlers.NoHandler.OnMeasurement(modelx.Measurement{})
|
||||
handlers.StdoutHandler.OnMeasurement(modelx.Measurement{})
|
||||
saver := handlers.SavingHandler{}
|
||||
saver.OnMeasurement(modelx.Measurement{})
|
||||
events := saver.Read()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("invalid number of events")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package netx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/oldhttptransport"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// HTTPTransport performs single HTTP transactions and emits
|
||||
// measurement events as they happen.
|
||||
type HTTPTransport struct {
|
||||
Beginning time.Time
|
||||
Dialer *Dialer
|
||||
Handler modelx.Handler
|
||||
Transport *http.Transport
|
||||
roundTripper http.RoundTripper
|
||||
}
|
||||
|
||||
func newHTTPTransport(
|
||||
beginning time.Time,
|
||||
handler modelx.Handler,
|
||||
dialer *Dialer,
|
||||
disableKeepAlives bool,
|
||||
proxyFunc func(*http.Request) (*url.URL, error),
|
||||
) *HTTPTransport {
|
||||
baseTransport := &http.Transport{
|
||||
// The following values are copied from Go 1.12 docs and match
|
||||
// what should be used by the default transport
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
Proxy: proxyFunc,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: disableKeepAlives,
|
||||
}
|
||||
ooniTransport := oldhttptransport.New(baseTransport)
|
||||
// Configure h2 and make sure that the custom TLSConfig we use for dialing
|
||||
// is actually compatible with upgrading to h2. (This mainly means we
|
||||
// need to make sure we include "h2" in the NextProtos array.) Because
|
||||
// http2.ConfigureTransport only returns error when we have already
|
||||
// configured http2, it is safe to ignore the return value.
|
||||
http2.ConfigureTransport(baseTransport)
|
||||
// Since we're not going to use our dialer for TLS, the main purpose of
|
||||
// the following line is to make sure ForseSpecificSNI has impact on the
|
||||
// config we are going to use when doing TLS. The code is as such since
|
||||
// we used to force net/http through using dialer.DialTLS.
|
||||
dialer.TLSConfig = baseTransport.TLSClientConfig
|
||||
// Arrange the configuration such that we always use `dialer` for dialing
|
||||
// cleartext connections. The net/http code will dial TLS connections.
|
||||
baseTransport.DialContext = dialer.DialContext
|
||||
// Better for Cloudflare DNS and also better because we have less
|
||||
// noisy events and we can better understand what happened.
|
||||
baseTransport.MaxConnsPerHost = 1
|
||||
// The following (1) reduces the number of headers that Go will
|
||||
// automatically send for us and (2) ensures that we always receive
|
||||
// back the true headers, such as Content-Length. This change is
|
||||
// functional to OONI's goal of observing the network.
|
||||
baseTransport.DisableCompression = true
|
||||
return &HTTPTransport{
|
||||
Beginning: beginning,
|
||||
Dialer: dialer,
|
||||
Handler: handler,
|
||||
Transport: baseTransport,
|
||||
roundTripper: ooniTransport,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction, returning
|
||||
// a Response for the provided Request.
|
||||
func (t *HTTPTransport) RoundTrip(
|
||||
req *http.Request,
|
||||
) (resp *http.Response, err error) {
|
||||
ctx := maybeWithMeasurementRoot(req.Context(), t.Beginning, t.Handler)
|
||||
req = req.WithContext(ctx)
|
||||
resp, err = t.roundTripper.RoundTrip(req)
|
||||
// For safety wrap the error as modelx.HTTPRoundTripOperation but this
|
||||
// will only be used if the error chain does not contain any
|
||||
// other major operation failure. See errorx.ErrWrapper.
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
Error: err,
|
||||
Operation: errorx.HTTPRoundTripOperation,
|
||||
}.MaybeBuild()
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections.
|
||||
func (t *HTTPTransport) CloseIdleConnections() {
|
||||
// Adapted from net/http code
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := t.roundTripper.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPTransportWithProxyFunc creates a transport without any
|
||||
// handler attached using the specified proxy func.
|
||||
func NewHTTPTransportWithProxyFunc(
|
||||
proxyFunc func(*http.Request) (*url.URL, error),
|
||||
) *HTTPTransport {
|
||||
return newHTTPTransport(time.Now(), handlers.NoHandler, NewDialer(), false, proxyFunc)
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new HTTP transport.
|
||||
func NewHTTPTransport() *HTTPTransport {
|
||||
return NewHTTPTransportWithProxyFunc(http.ProxyFromEnvironment)
|
||||
}
|
||||
|
||||
// ConfigureDNS is exactly like netx.Dialer.ConfigureDNS.
|
||||
func (t *HTTPTransport) ConfigureDNS(network, address string) error {
|
||||
return t.Dialer.ConfigureDNS(network, address)
|
||||
}
|
||||
|
||||
// SetResolver is exactly like netx.Dialer.SetResolver.
|
||||
func (t *HTTPTransport) SetResolver(r modelx.DNSResolver) {
|
||||
t.Dialer.SetResolver(r)
|
||||
}
|
||||
|
||||
// SetCABundle internally calls netx.Dialer.SetCABundle and
|
||||
// therefore it has the same caveats and limitations.
|
||||
func (t *HTTPTransport) SetCABundle(path string) error {
|
||||
return t.Dialer.SetCABundle(path)
|
||||
}
|
||||
|
||||
// ForceSpecificSNI forces using a specific SNI.
|
||||
func (t *HTTPTransport) ForceSpecificSNI(sni string) error {
|
||||
return t.Dialer.ForceSpecificSNI(sni)
|
||||
}
|
||||
|
||||
// ForceSkipVerify forces to skip certificate verification
|
||||
func (t *HTTPTransport) ForceSkipVerify() error {
|
||||
return t.Dialer.ForceSkipVerify()
|
||||
}
|
||||
|
||||
// HTTPClient is a replacement for http.HTTPClient.
|
||||
type HTTPClient struct {
|
||||
// HTTPClient is the underlying client. Pass this client to existing code
|
||||
// that expects an *http.HTTPClient. For this reason we can't embed it.
|
||||
HTTPClient *http.Client
|
||||
|
||||
// Transport is the transport configured by NewClient to be used
|
||||
// by the HTTPClient field.
|
||||
Transport *HTTPTransport
|
||||
}
|
||||
|
||||
// NewHTTPClientWithProxyFunc creates a new client using the
|
||||
// specified proxyFunc for handling proxying.
|
||||
func NewHTTPClientWithProxyFunc(
|
||||
proxyFunc func(*http.Request) (*url.URL, error),
|
||||
) *HTTPClient {
|
||||
transport := NewHTTPTransportWithProxyFunc(proxyFunc)
|
||||
return &HTTPClient{
|
||||
HTTPClient: &http.Client{Transport: transport},
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHTTPClient creates a new client instance.
|
||||
func NewHTTPClient() *HTTPClient {
|
||||
return NewHTTPClientWithProxyFunc(http.ProxyFromEnvironment)
|
||||
}
|
||||
|
||||
// NewHTTPClientWithoutProxy creates a new client instance that
|
||||
// does not use any kind of proxy.
|
||||
func NewHTTPClientWithoutProxy() *HTTPClient {
|
||||
return NewHTTPClientWithProxyFunc(nil)
|
||||
}
|
||||
|
||||
// ConfigureDNS internally calls netx.Dialer.ConfigureDNS and
|
||||
// therefore it has the same caveats and limitations.
|
||||
func (c *HTTPClient) ConfigureDNS(network, address string) error {
|
||||
return c.Transport.ConfigureDNS(network, address)
|
||||
}
|
||||
|
||||
// SetResolver internally calls netx.Dialer.SetResolver
|
||||
func (c *HTTPClient) SetResolver(r modelx.DNSResolver) {
|
||||
c.Transport.SetResolver(r)
|
||||
}
|
||||
|
||||
// SetCABundle internally calls netx.Dialer.SetCABundle and
|
||||
// therefore it has the same caveats and limitations.
|
||||
func (c *HTTPClient) SetCABundle(path string) error {
|
||||
return c.Transport.SetCABundle(path)
|
||||
}
|
||||
|
||||
// ForceSpecificSNI forces using a specific SNI.
|
||||
func (c *HTTPClient) ForceSpecificSNI(sni string) error {
|
||||
return c.Transport.ForceSpecificSNI(sni)
|
||||
}
|
||||
|
||||
// ForceSkipVerify forces to skip certificate verification
|
||||
func (c *HTTPClient) ForceSkipVerify() error {
|
||||
return c.Transport.ForceSkipVerify()
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections.
|
||||
func (c *HTTPClient) CloseIdleConnections() {
|
||||
c.Transport.CloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package netx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func dowithclient(t *testing.T, client *netx.HTTPClient) {
|
||||
defer client.CloseIdleConnections()
|
||||
resp, err := client.HTTPClient.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClient(t *testing.T) {
|
||||
client := netx.NewHTTPClient()
|
||||
dowithclient(t, client)
|
||||
}
|
||||
|
||||
func TestHTTPClientAndTransport(t *testing.T) {
|
||||
client := netx.NewHTTPClient()
|
||||
client.Transport = netx.NewHTTPTransport()
|
||||
dowithclient(t, client)
|
||||
}
|
||||
|
||||
func TestHTTPClientConfigureDNS(t *testing.T) {
|
||||
client := netx.NewHTTPClientWithoutProxy()
|
||||
err := client.ConfigureDNS("udp", "1.1.1.1:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dowithclient(t, client)
|
||||
}
|
||||
|
||||
func TestHTTPClientSetResolver(t *testing.T) {
|
||||
client := netx.NewHTTPClientWithoutProxy()
|
||||
client.SetResolver(new(net.Resolver))
|
||||
dowithclient(t, client)
|
||||
}
|
||||
|
||||
func TestHTTPClientSetCABundle(t *testing.T) {
|
||||
client := netx.NewHTTPClientWithoutProxy()
|
||||
err := client.SetCABundle("testdata/cacert.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := client.HTTPClient.Get("https://www.google.com")
|
||||
var target x509.UnknownAuthorityError
|
||||
if errors.As(err, &target) == false {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientForceSpecificSNI(t *testing.T) {
|
||||
client := netx.NewHTTPClientWithoutProxy()
|
||||
err := client.ForceSpecificSNI("www.facebook.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := client.HTTPClient.Get("https://www.google.com")
|
||||
var target x509.HostnameError
|
||||
if errors.As(err, &target) == false {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPClientForceSkipVerify(t *testing.T) {
|
||||
client := netx.NewHTTPClientWithoutProxy()
|
||||
client.ForceSkipVerify()
|
||||
resp, err := client.HTTPClient.Get("https://self-signed.badssl.com/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non nil response here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPNewClientProxy(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(451)
|
||||
}))
|
||||
defer server.Close()
|
||||
client := netx.NewHTTPClientWithoutProxy()
|
||||
httpProxyTestMain(t, client.HTTPClient, 200)
|
||||
client = netx.NewHTTPClientWithProxyFunc(func(req *http.Request) (*url.URL, error) {
|
||||
return url.Parse(server.URL)
|
||||
})
|
||||
httpProxyTestMain(t, client.HTTPClient, 451)
|
||||
}
|
||||
|
||||
const httpProxyTestsURL = "http://explorer.ooni.org"
|
||||
|
||||
func httpProxyTestMain(t *testing.T, client *http.Client, expect int) {
|
||||
req, err := http.NewRequest("GET", httpProxyTestsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != expect {
|
||||
t.Fatal("unexpected status code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportTimeout(t *testing.T) {
|
||||
client := &http.Client{Transport: netx.NewHTTPTransport()}
|
||||
req, err := http.NewRequest("GET", "https://www.google.com", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if !strings.HasSuffix(err.Error(), errorx.FailureGenericTimeoutError) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPTransportFailure(t *testing.T) {
|
||||
client := &http.Client{Transport: netx.NewHTTPTransport()}
|
||||
// This fails the request because we attempt to speak cleartext HTTP with
|
||||
// a server that instead is expecting TLS.
|
||||
resp, err := client.Get("http://www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response here")
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,699 @@
|
||||
// Package modelx contains the data modelx.
|
||||
package modelx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Measurement contains zero or more events. Do not assume that at any
|
||||
// time a Measurement will only contain a single event. When a Measurement
|
||||
// contains an event, the corresponding pointer is non nil.
|
||||
//
|
||||
// All events contain a time measurement, `DurationSinceBeginning`, that
|
||||
// uses a monotonic clock and is relative to a preconfigured "zero".
|
||||
type Measurement struct {
|
||||
// DNS events
|
||||
//
|
||||
// These are all identifed by a DialID. A ResolveEvent optionally has
|
||||
// a reference to the TransactionID that started the dial, if any.
|
||||
ResolveStart *ResolveStartEvent `json:",omitempty"`
|
||||
DNSQuery *DNSQueryEvent `json:",omitempty"`
|
||||
DNSReply *DNSReplyEvent `json:",omitempty"`
|
||||
ResolveDone *ResolveDoneEvent `json:",omitempty"`
|
||||
|
||||
// Syscalls
|
||||
//
|
||||
// These are all identified by a ConnID. A ConnectEvent has a reference
|
||||
// to the DialID that caused this connection to be attempted.
|
||||
//
|
||||
// Because they are syscalls, we don't split them in start/done pairs
|
||||
// but we record the amount of time in which we were blocked.
|
||||
Connect *ConnectEvent `json:",omitempty"`
|
||||
Read *ReadEvent `json:",omitempty"`
|
||||
Write *WriteEvent `json:",omitempty"`
|
||||
Close *CloseEvent `json:",omitempty"`
|
||||
|
||||
// TLS events
|
||||
//
|
||||
// Identified by either ConnID or TransactionID. In the former case
|
||||
// the TLS handshake is managed by net code, in the latter case it is
|
||||
// instead managed by Golang's HTTP engine. It should not happen to
|
||||
// have both ConnID and TransactionID different from zero.
|
||||
TLSHandshakeStart *TLSHandshakeStartEvent `json:",omitempty"`
|
||||
TLSHandshakeDone *TLSHandshakeDoneEvent `json:",omitempty"`
|
||||
|
||||
// HTTP roundtrip events
|
||||
//
|
||||
// A round trip starts when we need a connection to send a request
|
||||
// and ends when we've got the response headers or an error.
|
||||
//
|
||||
// The identifer here is TransactionID, where the transaction is
|
||||
// like the round trip except that it terminates when we've finished
|
||||
// reading the whole response body.
|
||||
HTTPRoundTripStart *HTTPRoundTripStartEvent `json:",omitempty"`
|
||||
HTTPConnectionReady *HTTPConnectionReadyEvent `json:",omitempty"`
|
||||
HTTPRequestHeader *HTTPRequestHeaderEvent `json:",omitempty"`
|
||||
HTTPRequestHeadersDone *HTTPRequestHeadersDoneEvent `json:",omitempty"`
|
||||
HTTPRequestDone *HTTPRequestDoneEvent `json:",omitempty"`
|
||||
HTTPResponseStart *HTTPResponseStartEvent `json:",omitempty"`
|
||||
HTTPRoundTripDone *HTTPRoundTripDoneEvent `json:",omitempty"`
|
||||
|
||||
// HTTP body events
|
||||
//
|
||||
// They are identified by the TransactionID. You are not going to see
|
||||
// these events if you don't fully read response bodies. But that's
|
||||
// something you are supposed to do, so you should be fine.
|
||||
HTTPResponseBodyPart *HTTPResponseBodyPartEvent `json:",omitempty"`
|
||||
HTTPResponseDone *HTTPResponseDoneEvent `json:",omitempty"`
|
||||
|
||||
// Extension events.
|
||||
//
|
||||
// The purpose of these events is to give us some flexibility to
|
||||
// experiment with message formats before blessing something as
|
||||
// part of the official API of the library. The intent however is
|
||||
// to avoid keeping something as an extension for a long time.
|
||||
Extension *ExtensionEvent `json:",omitempty"`
|
||||
}
|
||||
|
||||
// CloseEvent is emitted when the CLOSE syscall returns.
|
||||
type CloseEvent struct {
|
||||
// ConnID is the identifier of this connection.
|
||||
ConnID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the error returned by CLOSE.
|
||||
Error error
|
||||
|
||||
// SyscallDuration is the number of nanoseconds we were
|
||||
// blocked waiting for the syscall to return.
|
||||
SyscallDuration time.Duration
|
||||
}
|
||||
|
||||
// ConnectEvent is emitted when the CONNECT syscall returns.
|
||||
type ConnectEvent struct {
|
||||
// ConnID is the identifier of this connection.
|
||||
ConnID int64
|
||||
|
||||
// DialID is the identifier of the dial operation as
|
||||
// part of which we called CONNECT.
|
||||
DialID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the error returned by CONNECT.
|
||||
Error error
|
||||
|
||||
// Network is the network we're dialing for, e.g. "tcp"
|
||||
Network string
|
||||
|
||||
// RemoteAddress is the remote IP address we're dialing for
|
||||
RemoteAddress string
|
||||
|
||||
// SyscallDuration is the number of nanoseconds we were
|
||||
// blocked waiting for the syscall to return.
|
||||
SyscallDuration time.Duration
|
||||
|
||||
// TransactionID is the ID of the HTTP transaction that caused the
|
||||
// current dial to run, or zero if there's no such transaction.
|
||||
TransactionID int64 `json:",omitempty"`
|
||||
}
|
||||
|
||||
// DNSQueryEvent is emitted when we send a DNS query.
|
||||
type DNSQueryEvent struct {
|
||||
// Data is the raw data we're sending to the server.
|
||||
Data []byte
|
||||
|
||||
// DialID is the identifier of the dial operation as
|
||||
// part of which we're sending this query.
|
||||
DialID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Msg is the parsed message we're sending to the server.
|
||||
Msg *dns.Msg `json:"-"`
|
||||
}
|
||||
|
||||
// DNSReplyEvent is emitted when we receive byte that are
|
||||
// successfully parsed into a DNS reply.
|
||||
type DNSReplyEvent struct {
|
||||
// Data is the raw data we've received and parsed.
|
||||
Data []byte
|
||||
|
||||
// DialID is the identifier of the dial operation as
|
||||
// part of which we've received this query.
|
||||
DialID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Msg is the received parsed message.
|
||||
Msg *dns.Msg `json:"-"`
|
||||
}
|
||||
|
||||
// ExtensionEvent is emitted by a netx extension.
|
||||
type ExtensionEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Key is the unique identifier of the event. A good rule of
|
||||
// thumb is to use `${packageName}.${messageType}`.
|
||||
Key string
|
||||
|
||||
// Severity of the emitted message ("WARN", "INFO", "DEBUG")
|
||||
Severity string
|
||||
|
||||
// TransactionID is the identifier of this transaction, provided
|
||||
// that we have an active one, otherwise is zero.
|
||||
TransactionID int64
|
||||
|
||||
// Value is the extension dependent message. This message
|
||||
// has the only requirement of being JSON serializable.
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// HTTPRoundTripStartEvent is emitted when the HTTP transport
|
||||
// starts the HTTP "round trip". That is, when the transport
|
||||
// receives from the HTTP client a request to sent. The round
|
||||
// trip terminates when we receive headers. What we call the
|
||||
// "transaction" here starts with this event and does not finish
|
||||
// until we have also finished receiving the response body.
|
||||
type HTTPRoundTripStartEvent struct {
|
||||
// DialID is the identifier of the dial operation that
|
||||
// caused this round trip to start. Typically, this occures
|
||||
// when doing DoH. If zero, means that this round trip has
|
||||
// not been started by any dial operation.
|
||||
DialID int64 `json:",omitempty"`
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Method is the request method
|
||||
Method string
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
|
||||
// URL is the request URL
|
||||
URL string
|
||||
}
|
||||
|
||||
// HTTPConnectionReadyEvent is emitted when the HTTP transport has got
|
||||
// a connection which is ready for sending the request.
|
||||
type HTTPConnectionReadyEvent struct {
|
||||
// ConnID is the identifier of the connection that is ready. Knowing
|
||||
// this ID allows you to bind HTTP events to net events.
|
||||
ConnID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// HTTPRequestHeaderEvent is emitted when we have written a header,
|
||||
// where written typically means just "buffered".
|
||||
type HTTPRequestHeaderEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Key is the header key
|
||||
Key string
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
|
||||
// Value is the value/values of this header.
|
||||
Value []string
|
||||
}
|
||||
|
||||
// HTTPRequestHeadersDoneEvent is emitted when we have written, or more
|
||||
// correctly, "buffered" all headers.
|
||||
type HTTPRequestHeadersDoneEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Headers contain the original request headers. This is included
|
||||
// here to make this event actionable without needing to join it with
|
||||
// other events, i.e., to simplify logging.
|
||||
Headers http.Header
|
||||
|
||||
// Method is the original request method. This is here
|
||||
// for the same reason of Headers.
|
||||
Method string
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
|
||||
// URL is the original request URL. This is here
|
||||
// for the same reason of Headers. We use an object
|
||||
// rather than a string, because here you want to
|
||||
// use specific subfields directly for logging.
|
||||
URL *url.URL
|
||||
}
|
||||
|
||||
// HTTPRequestDoneEvent is emitted when we have sent the request
|
||||
// body or there has been any failure in sending the request.
|
||||
type HTTPRequestDoneEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is non nil if we could not write the request headers or
|
||||
// some specific part of the body. When this step of writing
|
||||
// the request fails, of course the whole transaction will fail
|
||||
// as well. This error however tells you that the issue was
|
||||
// when sending the request, not when receiving the response.
|
||||
Error error
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// HTTPResponseStartEvent is emitted when we receive the byte from
|
||||
// the response on the wire.
|
||||
type HTTPResponseStartEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
const defaultBodySnapSize int64 = 1 << 20
|
||||
|
||||
// ComputeBodySnapSize computes the body snap size. If snapSize is negative
|
||||
// we return MaxInt64. If it's zero we return the default snap size. Otherwise
|
||||
// the value of snapSize is returned.
|
||||
func ComputeBodySnapSize(snapSize int64) int64 {
|
||||
if snapSize < 0 {
|
||||
snapSize = math.MaxInt64
|
||||
} else if snapSize == 0 {
|
||||
snapSize = defaultBodySnapSize
|
||||
}
|
||||
return snapSize
|
||||
}
|
||||
|
||||
// HTTPRoundTripDoneEvent is emitted at the end of the round trip. Either
|
||||
// we have an error, or a valid HTTP response. An error could be caused
|
||||
// either by not being able to send the request or not being able to receive
|
||||
// the response. Note that here errors are network/TLS/dialing errors or
|
||||
// protocol violation errors. No status code will cause errors here.
|
||||
type HTTPRoundTripDoneEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the overall result of the round trip. If non-nil, checking
|
||||
// also the result of HTTPResponseDone helps to disambiguate whether the
|
||||
// error was in sending the request or receiving the response.
|
||||
Error error
|
||||
|
||||
// RequestBodySnap contains a snap of the request body. We'll
|
||||
// not read more than SnapSize bytes of the body. Because typically
|
||||
// you control the request bodies that you send, perhaps think
|
||||
// about saving them using other means.
|
||||
RequestBodySnap []byte
|
||||
|
||||
// RequestHeaders contain the original request headers. This is
|
||||
// included here to make this event actionable without needing to
|
||||
// join it with other events, as it's too important.
|
||||
RequestHeaders http.Header
|
||||
|
||||
// RequestMethod is the original request method. This is here
|
||||
// for the same reason of RequestHeaders.
|
||||
RequestMethod string
|
||||
|
||||
// RequestURL is the original request URL. This is here
|
||||
// for the same reason of RequestHeaders.
|
||||
RequestURL string
|
||||
|
||||
// ResponseBodySnap is like RequestBodySnap but for the response. You
|
||||
// can still save the whole body by just reading it, if this
|
||||
// is something that you need to do. We're using the snaps here
|
||||
// mainly to log small stuff like DoH and redirects.
|
||||
ResponseBodySnap []byte
|
||||
|
||||
// ResponseHeaders contains the response headers if error is nil.
|
||||
ResponseHeaders http.Header
|
||||
|
||||
// ResponseProto contains the response protocol
|
||||
ResponseProto string
|
||||
|
||||
// ResponseStatusCode contains the HTTP status code if error is nil.
|
||||
ResponseStatusCode int64
|
||||
|
||||
// MaxBodySnapSize is the maximum size of the bodies snapshot.
|
||||
MaxBodySnapSize int64
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// HTTPResponseBodyPartEvent is emitted after we have received
|
||||
// a part of the response body, or an error reading it. Note that
|
||||
// bytes read here does not necessarily match bytes returned by
|
||||
// ReadEvent because of (1) transparent gzip decompression by Go,
|
||||
// (2) HTTP overhead (headers and chunked body), (3) TLS. This
|
||||
// is the reason why we also want to record the error here rather
|
||||
// than just recording the error in ReadEvent.
|
||||
//
|
||||
// Note that you are not going to see this event if you do not
|
||||
// drain the response body, which you're supposed to do, tho.
|
||||
type HTTPResponseBodyPartEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error indicates whether we could not read a part of the body
|
||||
Error error
|
||||
|
||||
// Data is a reference to the body we've just read.
|
||||
Data []byte
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// HTTPResponseDoneEvent is emitted after we have received the body,
|
||||
// when the response body is being closed.
|
||||
//
|
||||
// Note that you are not going to see this event if you do not
|
||||
// drain the response body, which you're supposed to do, tho.
|
||||
type HTTPResponseDoneEvent struct {
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// TransactionID is the identifier of this transaction
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// ReadEvent is emitted when the READ/RECV syscall returns.
|
||||
type ReadEvent struct {
|
||||
// ConnID is the identifier of this connection.
|
||||
ConnID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the error returned by READ/RECV.
|
||||
Error error
|
||||
|
||||
// NumBytes is the number of bytes received, which may in
|
||||
// principle also be nonzero on error.
|
||||
NumBytes int64
|
||||
|
||||
// SyscallDuration is the number of nanoseconds we were
|
||||
// blocked waiting for the syscall to return.
|
||||
SyscallDuration time.Duration
|
||||
}
|
||||
|
||||
// ResolveStartEvent is emitted when we start resolving a domain name.
|
||||
type ResolveStartEvent struct {
|
||||
// DialID is the identifier of the dial operation as
|
||||
// part of which we're resolving this domain.
|
||||
DialID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Hostname is the domain name to resolve.
|
||||
Hostname string
|
||||
|
||||
// TransactionID is the ID of the HTTP transaction that caused the
|
||||
// current dial to run, or zero if there's no such transaction.
|
||||
TransactionID int64 `json:",omitempty"`
|
||||
|
||||
// TransportNetwork is the network used by the DNS transport, which
|
||||
// can be one of "doh", "dot", "tcp", "udp", or "system".
|
||||
TransportNetwork string
|
||||
|
||||
// TransportAddress is the address used by the DNS transport, which
|
||||
// is of course relative to the TransportNetwork.
|
||||
TransportAddress string
|
||||
}
|
||||
|
||||
// ResolveDoneEvent is emitted when we know the IP addresses of a
|
||||
// specific domain name, or the resolution failed.
|
||||
type ResolveDoneEvent struct {
|
||||
// Addresses is the list of returned addresses (empty on error).
|
||||
Addresses []string
|
||||
|
||||
// ContainsBogons indicates whether Addresses contains one
|
||||
// or more IP addresses that classify as bogons.
|
||||
ContainsBogons bool
|
||||
|
||||
// DialID is the identifier of the dial operation as
|
||||
// part of which we're resolving this domain.
|
||||
DialID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the result of the dial operation.
|
||||
Error error
|
||||
|
||||
// Hostname is the domain name to resolve.
|
||||
Hostname string
|
||||
|
||||
// TransactionID is the ID of the HTTP transaction that caused the
|
||||
// current dial to run, or zero if there's no such transaction.
|
||||
TransactionID int64 `json:",omitempty"`
|
||||
|
||||
// TransportNetwork is the network used by the DNS transport, which
|
||||
// can be one of "doh", "dot", "tcp", "udp", or "system".
|
||||
TransportNetwork string
|
||||
|
||||
// TransportAddress is the address used by the DNS transport, which
|
||||
// is of course relative to the TransportNetwork.
|
||||
TransportAddress string
|
||||
}
|
||||
|
||||
// X509Certificate is an x.509 certificate.
|
||||
type X509Certificate struct {
|
||||
// Data contains the certificate bytes in DER format.
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// TLSConnectionState contains the TLS connection state.
|
||||
type TLSConnectionState struct {
|
||||
CipherSuite uint16
|
||||
NegotiatedProtocol string
|
||||
PeerCertificates []X509Certificate
|
||||
Version uint16
|
||||
}
|
||||
|
||||
// NewTLSConnectionState creates a new TLSConnectionState.
|
||||
func NewTLSConnectionState(s tls.ConnectionState) TLSConnectionState {
|
||||
return TLSConnectionState{
|
||||
CipherSuite: s.CipherSuite,
|
||||
NegotiatedProtocol: s.NegotiatedProtocol,
|
||||
PeerCertificates: SimplifyCerts(s.PeerCertificates),
|
||||
Version: s.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// SimplifyCerts simplifies a certificate chain for archival
|
||||
func SimplifyCerts(in []*x509.Certificate) (out []X509Certificate) {
|
||||
for _, cert := range in {
|
||||
out = append(out, X509Certificate{
|
||||
Data: cert.Raw,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// TLSHandshakeStartEvent is emitted when the TLS handshake starts.
|
||||
type TLSHandshakeStartEvent struct {
|
||||
// ConnID is the ID of the connection that started the TLS
|
||||
// handshake, or zero if we don't know it. Typically, it is
|
||||
// zero for connections managed by the HTTP transport, for
|
||||
// which we know instead the TransactionID.
|
||||
ConnID int64 `json:",omitempty"`
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// SNI is the SNI used when we force a specific SNI.
|
||||
SNI string
|
||||
|
||||
// TransactionID is the ID of the transaction that started
|
||||
// this TLS handshake, or zero if we don't know it. Typically,
|
||||
// it is zero for explicit dials, and it's nonzero instead
|
||||
// when a connection is managed by HTTP code.
|
||||
TransactionID int64 `json:",omitempty"`
|
||||
}
|
||||
|
||||
// TLSHandshakeDoneEvent is emitted when conn.Handshake returns.
|
||||
type TLSHandshakeDoneEvent struct {
|
||||
// ConnectionState is the TLS connection state. Depending on the
|
||||
// error type, some fields may have little meaning.
|
||||
ConnectionState TLSConnectionState
|
||||
|
||||
// ConnID is the ID of the connection that started the TLS
|
||||
// handshake, or zero if we don't know it. Typically, it is
|
||||
// zero for connections managed by the HTTP transport, for
|
||||
// which we know instead the TransactionID.
|
||||
ConnID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the result of the TLS handshake.
|
||||
Error error
|
||||
|
||||
// TransactionID is the ID of the transaction that started
|
||||
// this TLS handshake, or zero if we don't know it. Typically,
|
||||
// it is zero for explicit dials, and it's nonzero instead
|
||||
// when a connection is managed by HTTP code.
|
||||
TransactionID int64
|
||||
}
|
||||
|
||||
// WriteEvent is emitted when the WRITE/SEND syscall returns.
|
||||
type WriteEvent struct {
|
||||
// ConnID is the identifier of this connection.
|
||||
ConnID int64
|
||||
|
||||
// DurationSinceBeginning is the number of nanoseconds since
|
||||
// the time configured as the "zero" time.
|
||||
DurationSinceBeginning time.Duration
|
||||
|
||||
// Error is the error returned by WRITE/SEND.
|
||||
Error error
|
||||
|
||||
// NumBytes is the number of bytes sent, which may in
|
||||
// principle also be nonzero on error.
|
||||
NumBytes int64
|
||||
|
||||
// SyscallDuration is the number of nanoseconds we were
|
||||
// blocked waiting for the syscall to return.
|
||||
SyscallDuration time.Duration
|
||||
}
|
||||
|
||||
// Handler handles measurement events.
|
||||
type Handler interface {
|
||||
// OnMeasurement is called when an event occurs. There will be no
|
||||
// events after the code that is using the modified Dialer, Transport,
|
||||
// or Client is returned. OnMeasurement may be called by background
|
||||
// goroutines and OnMeasurement calls may happen concurrently.
|
||||
OnMeasurement(Measurement)
|
||||
}
|
||||
|
||||
// DNSResolver is a DNS resolver. The *net.Resolver used by Go implements
|
||||
// this interface, but other implementations are possible.
|
||||
type DNSResolver interface {
|
||||
// LookupHost resolves a hostname to a list of IP addresses.
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
}
|
||||
|
||||
// Dialer is a dialer for network connections.
|
||||
type Dialer interface {
|
||||
// Dial dials a new connection
|
||||
Dial(network, address string) (net.Conn, error)
|
||||
|
||||
// DialContext is like Dial but with context
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// TLSDialer is a dialer for TLS connections.
|
||||
type TLSDialer interface {
|
||||
// DialTLS dials a new TLS connection
|
||||
DialTLS(network, address string) (net.Conn, error)
|
||||
|
||||
// DialTLSContext is like DialTLS but with context
|
||||
DialTLSContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// MeasurementRoot is the measurement root.
|
||||
//
|
||||
// If you attach this to a context, we'll use it rather than using
|
||||
// the beginning and hndler configured with resolvers, dialers, HTTP
|
||||
// clients, and HTTP transports. By attaching a measurement root to
|
||||
// a context, you can naturally split events by HTTP round trip.
|
||||
type MeasurementRoot struct {
|
||||
// Beginning is the "zero" used to compute the elapsed time.
|
||||
Beginning time.Time
|
||||
|
||||
// Handler is the handler that will handle events.
|
||||
Handler Handler
|
||||
|
||||
// MaxBodySnapSize is the maximum size after which we'll stop
|
||||
// reading request and response bodies. They will of course
|
||||
// be fully transmitted, but we'll save only MaxBodySnapSize
|
||||
// bytes as part of the event stream. If this value is negative,
|
||||
// we use math.MaxInt64. If the value is zero, we use a
|
||||
// reasonable large value. Otherwise, we'll use this value.
|
||||
MaxBodySnapSize int64
|
||||
}
|
||||
|
||||
type measurementRootContextKey struct{}
|
||||
|
||||
type dummyHandler struct{}
|
||||
|
||||
func (*dummyHandler) OnMeasurement(Measurement) {}
|
||||
|
||||
// ContextMeasurementRoot returns the MeasurementRoot configured in the
|
||||
// provided context, or a nil pointer, if not set.
|
||||
func ContextMeasurementRoot(ctx context.Context) *MeasurementRoot {
|
||||
root, _ := ctx.Value(measurementRootContextKey{}).(*MeasurementRoot)
|
||||
return root
|
||||
}
|
||||
|
||||
// ContextMeasurementRootOrDefault returns the MeasurementRoot configured in
|
||||
// the provided context, or a working, dummy, MeasurementRoot otherwise.
|
||||
func ContextMeasurementRootOrDefault(ctx context.Context) *MeasurementRoot {
|
||||
root := ContextMeasurementRoot(ctx)
|
||||
if root == nil {
|
||||
root = &MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: &dummyHandler{},
|
||||
}
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
// WithMeasurementRoot returns a copy of the context with the
|
||||
// configured MeasurementRoot set. Panics if the provided root
|
||||
// is a nil pointer, like httptrace.WithClientTrace.
|
||||
//
|
||||
// Merging more than one root is not supported. Setting again
|
||||
// the root is just going to replace the original root.
|
||||
func WithMeasurementRoot(
|
||||
ctx context.Context, root *MeasurementRoot,
|
||||
) context.Context {
|
||||
if root == nil {
|
||||
panic("nil measurement root")
|
||||
}
|
||||
return context.WithValue(
|
||||
ctx, measurementRootContextKey{}, root,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package modelx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestNewTLSConnectionState(t *testing.T) {
|
||||
conn, err := tls.Dial("tcp", "www.google.com:443", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
state := NewTLSConnectionState(conn.ConnectionState())
|
||||
if len(state.PeerCertificates) < 1 {
|
||||
t.Fatal("too few certificates")
|
||||
}
|
||||
if state.Version < tls.VersionSSL30 || state.Version > 0x0304 /*tls.VersionTLS13*/ {
|
||||
t.Fatal("unexpected TLS version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeasurementRoot(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
if ContextMeasurementRoot(ctx) != nil {
|
||||
t.Fatal("unexpected value for ContextMeasurementRoot")
|
||||
}
|
||||
if ContextMeasurementRootOrDefault(ctx) == nil {
|
||||
t.Fatal("unexpected value ContextMeasurementRootOrDefault")
|
||||
}
|
||||
handler := &dummyHandler{}
|
||||
root := &MeasurementRoot{
|
||||
Handler: handler,
|
||||
Beginning: time.Time{},
|
||||
}
|
||||
ctx = WithMeasurementRoot(ctx, root)
|
||||
v := ContextMeasurementRoot(ctx)
|
||||
if v != root {
|
||||
t.Fatal("unexpected ContextMeasurementRoot value")
|
||||
}
|
||||
v = ContextMeasurementRootOrDefault(ctx)
|
||||
if v != root {
|
||||
t.Fatal("unexpected ContextMeasurementRoot value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeasurementRootWithMeasurementRootPanic(t *testing.T) {
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("expected panic")
|
||||
}
|
||||
}()
|
||||
ctx := context.Background()
|
||||
ctx = WithMeasurementRoot(ctx, nil)
|
||||
}
|
||||
|
||||
func TestErrWrapperPublicAPI(t *testing.T) {
|
||||
child := errors.New("mocked error")
|
||||
wrapper := &errorx.ErrWrapper{
|
||||
Failure: "moobar",
|
||||
WrappedErr: child,
|
||||
}
|
||||
if wrapper.Error() != "moobar" {
|
||||
t.Fatal("The Error() method is misbehaving")
|
||||
}
|
||||
if wrapper.Unwrap() != child {
|
||||
t.Fatal("The Unwrap() method is misbehaving")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeBodySnapSize(t *testing.T) {
|
||||
if ComputeBodySnapSize(-1) != math.MaxInt64 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if ComputeBodySnapSize(0) != defaultBodySnapSize {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
if ComputeBodySnapSize(127) != 127 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
// BodyTracer performs single HTTP transactions and emits
|
||||
// measurement events as they happen.
|
||||
type BodyTracer struct {
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
|
||||
// NewBodyTracer creates a new Transport.
|
||||
func NewBodyTracer(roundTripper http.RoundTripper) *BodyTracer {
|
||||
return &BodyTracer{Transport: roundTripper}
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction, returning
|
||||
// a Response for the provided Request.
|
||||
func (t *BodyTracer) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||
resp, err = t.Transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// "The http Client and Transport guarantee that Body is always
|
||||
// non-nil, even on responses without a body or responses with
|
||||
// a zero-length body." (from the docs)
|
||||
resp.Body = &bodyWrapper{
|
||||
ReadCloser: resp.Body,
|
||||
root: modelx.ContextMeasurementRootOrDefault(req.Context()),
|
||||
tid: transactionid.ContextTransactionID(req.Context()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections.
|
||||
func (t *BodyTracer) CloseIdleConnections() {
|
||||
// Adapted from net/http code
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := t.Transport.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
type bodyWrapper struct {
|
||||
io.ReadCloser
|
||||
root *modelx.MeasurementRoot
|
||||
tid int64
|
||||
}
|
||||
|
||||
func (bw *bodyWrapper) Read(b []byte) (n int, err error) {
|
||||
n, err = bw.ReadCloser.Read(b)
|
||||
bw.root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPResponseBodyPart: &modelx.HTTPResponseBodyPartEvent{
|
||||
// "Read reads up to len(p) bytes into p. It returns the number of
|
||||
// bytes read (0 <= n <= len(p)) and any error encountered."
|
||||
Data: b[:n],
|
||||
Error: err,
|
||||
DurationSinceBeginning: time.Now().Sub(bw.root.Beginning),
|
||||
TransactionID: bw.tid,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (bw *bodyWrapper) Close() (err error) {
|
||||
err = bw.ReadCloser.Close()
|
||||
bw.root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPResponseDone: &modelx.HTTPResponseDoneEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(bw.root.Beginning),
|
||||
TransactionID: bw.tid,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBodyTracerSuccess(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewBodyTracer(http.DefaultTransport),
|
||||
}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func TestBodyTracerFailure(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewBodyTracer(http.DefaultTransport),
|
||||
}
|
||||
// This fails the request because we attempt to speak cleartext HTTP with
|
||||
// a server that instead is expecting TLS.
|
||||
resp, err := client.Get("http://www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response here")
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Package oldhttptransport contains HTTP transport extensions. Here we
|
||||
// define a http.Transport that emits events.
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Transport performs single HTTP transactions and emits
|
||||
// measurement events as they happen.
|
||||
type Transport struct {
|
||||
roundTripper http.RoundTripper
|
||||
}
|
||||
|
||||
// New creates a new Transport.
|
||||
func New(roundTripper http.RoundTripper) *Transport {
|
||||
return &Transport{
|
||||
roundTripper: NewTransactioner(NewBodyTracer(
|
||||
NewTraceTripper(roundTripper))),
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction, returning
|
||||
// a Response for the provided Request.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||
// Make sure we're not sending Go's default User-Agent
|
||||
// if the user has configured no user agent
|
||||
if req.Header.Get("User-Agent") == "" {
|
||||
req.Header["User-Agent"] = nil
|
||||
}
|
||||
return t.roundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections.
|
||||
func (t *Transport) CloseIdleConnections() {
|
||||
// Adapted from net/http code
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := t.roundTripper.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: New(http.DefaultTransport),
|
||||
}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func TestFailure(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: New(http.DefaultTransport),
|
||||
}
|
||||
// This fails the request because we attempt to speak cleartext HTTP with
|
||||
// a server that instead is expecting TLS.
|
||||
resp, err := client.Get("http://www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response here")
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// TraceTripper performs single HTTP transactions.
|
||||
type TraceTripper struct {
|
||||
readAllErrs *atomicx.Int64
|
||||
readAll func(r io.Reader) ([]byte, error)
|
||||
roundTripper http.RoundTripper
|
||||
}
|
||||
|
||||
// NewTraceTripper creates a new Transport.
|
||||
func NewTraceTripper(roundTripper http.RoundTripper) *TraceTripper {
|
||||
return &TraceTripper{
|
||||
readAllErrs: atomicx.NewInt64(),
|
||||
readAll: ioutil.ReadAll,
|
||||
roundTripper: roundTripper,
|
||||
}
|
||||
}
|
||||
|
||||
type readCloseWrapper struct {
|
||||
closer io.Closer
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newReadCloseWrapper(
|
||||
reader io.Reader, closer io.ReadCloser,
|
||||
) *readCloseWrapper {
|
||||
return &readCloseWrapper{
|
||||
closer: closer,
|
||||
reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *readCloseWrapper) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *readCloseWrapper) Close() error {
|
||||
return c.closer.Close()
|
||||
}
|
||||
|
||||
func readSnap(
|
||||
source *io.ReadCloser, limit int64,
|
||||
readAll func(r io.Reader) ([]byte, error),
|
||||
) (data []byte, err error) {
|
||||
data, err = readAll(io.LimitReader(*source, limit))
|
||||
if err == nil {
|
||||
*source = newReadCloseWrapper(
|
||||
io.MultiReader(bytes.NewReader(data), *source),
|
||||
*source,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction, returning
|
||||
// a Response for the provided Request.
|
||||
func (t *TraceTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
root := modelx.ContextMeasurementRootOrDefault(req.Context())
|
||||
|
||||
tid := transactionid.ContextTransactionID(req.Context())
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPRoundTripStart: &modelx.HTTPRoundTripStartEvent{
|
||||
DialID: dialid.ContextDialID(req.Context()),
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Method: req.Method,
|
||||
TransactionID: tid,
|
||||
URL: req.URL.String(),
|
||||
},
|
||||
})
|
||||
|
||||
var (
|
||||
err error
|
||||
majorOp = errorx.HTTPRoundTripOperation
|
||||
majorOpMu sync.Mutex
|
||||
requestBody []byte
|
||||
requestHeaders = http.Header{}
|
||||
requestHeadersMu sync.Mutex
|
||||
snapSize = modelx.ComputeBodySnapSize(root.MaxBodySnapSize)
|
||||
)
|
||||
|
||||
// Save a snapshot of the request body
|
||||
if req.Body != nil {
|
||||
requestBody, err = readSnap(&req.Body, snapSize, t.readAll)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare a tracer for delivering events
|
||||
tracer := &httptrace.ClientTrace{
|
||||
TLSHandshakeStart: func() {
|
||||
majorOpMu.Lock()
|
||||
majorOp = errorx.TLSHandshakeOperation
|
||||
majorOpMu.Unlock()
|
||||
// Event emitted by net/http when DialTLS is not
|
||||
// configured in the http.Transport
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
TransactionID: tid,
|
||||
},
|
||||
})
|
||||
},
|
||||
TLSHandshakeDone: func(state tls.ConnectionState, err error) {
|
||||
// Wrapping the error even if we're not returning it because it may
|
||||
// less confusing to users to see the wrapped name
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
Error: err,
|
||||
Operation: errorx.TLSHandshakeOperation,
|
||||
TransactionID: tid,
|
||||
}.MaybeBuild()
|
||||
durationSinceBeginning := time.Now().Sub(root.Beginning)
|
||||
// Event emitted by net/http when DialTLS is not
|
||||
// configured in the http.Transport
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{
|
||||
ConnectionState: modelx.NewTLSConnectionState(state),
|
||||
Error: err,
|
||||
DurationSinceBeginning: durationSinceBeginning,
|
||||
TransactionID: tid,
|
||||
},
|
||||
})
|
||||
},
|
||||
GotConn: func(info httptrace.GotConnInfo) {
|
||||
majorOpMu.Lock()
|
||||
majorOp = errorx.HTTPRoundTripOperation
|
||||
majorOpMu.Unlock()
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPConnectionReady: &modelx.HTTPConnectionReadyEvent{
|
||||
ConnID: connid.Compute(
|
||||
info.Conn.LocalAddr().Network(),
|
||||
info.Conn.LocalAddr().String(),
|
||||
),
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
TransactionID: tid,
|
||||
},
|
||||
})
|
||||
},
|
||||
WroteHeaderField: func(key string, values []string) {
|
||||
requestHeadersMu.Lock()
|
||||
// Important: do not set directly into the headers map using
|
||||
// the [] operator because net/http expects to be able to
|
||||
// perform normalization of header names!
|
||||
for _, value := range values {
|
||||
requestHeaders.Add(key, value)
|
||||
}
|
||||
requestHeadersMu.Unlock()
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPRequestHeader: &modelx.HTTPRequestHeaderEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Key: key,
|
||||
TransactionID: tid,
|
||||
Value: values,
|
||||
},
|
||||
})
|
||||
},
|
||||
WroteHeaders: func() {
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPRequestHeadersDone: &modelx.HTTPRequestHeadersDoneEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Headers: requestHeaders, // [*]
|
||||
Method: req.Method, // [*]
|
||||
TransactionID: tid,
|
||||
URL: req.URL, // [*]
|
||||
},
|
||||
})
|
||||
},
|
||||
WroteRequest: func(info httptrace.WroteRequestInfo) {
|
||||
// Wrapping the error even if we're not returning it because it may
|
||||
// less confusing to users to see the wrapped name
|
||||
err := errorx.SafeErrWrapperBuilder{
|
||||
Error: info.Err,
|
||||
Operation: errorx.HTTPRoundTripOperation,
|
||||
TransactionID: tid,
|
||||
}.MaybeBuild()
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPRequestDone: &modelx.HTTPRequestDoneEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Error: err,
|
||||
TransactionID: tid,
|
||||
},
|
||||
})
|
||||
},
|
||||
GotFirstResponseByte: func() {
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPResponseStart: &modelx.HTTPResponseStartEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
TransactionID: tid,
|
||||
},
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// If we don't have already a tracer this is a toplevel request, so just
|
||||
// set the tracer. Otherwise, we're doing DoH. We cannot set anothert trace
|
||||
// because they'd be merged. Instead, replace the existing trace content
|
||||
// with the new trace and then remember to reset it.
|
||||
origtracer := httptrace.ContextClientTrace(req.Context())
|
||||
if origtracer != nil {
|
||||
bkp := *origtracer
|
||||
*origtracer = *tracer
|
||||
defer func() {
|
||||
*origtracer = bkp
|
||||
}()
|
||||
} else {
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), tracer))
|
||||
}
|
||||
|
||||
resp, err := t.roundTripper.RoundTrip(req)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
Error: err,
|
||||
Operation: majorOp,
|
||||
TransactionID: tid,
|
||||
}.MaybeBuild()
|
||||
// [*] Require less event joining work by providing info that
|
||||
// makes this event alone actionable for OONI
|
||||
event := &modelx.HTTPRoundTripDoneEvent{
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
Error: err,
|
||||
RequestBodySnap: requestBody,
|
||||
RequestHeaders: requestHeaders, // [*]
|
||||
RequestMethod: req.Method, // [*]
|
||||
RequestURL: req.URL.String(), // [*]
|
||||
MaxBodySnapSize: snapSize,
|
||||
TransactionID: tid,
|
||||
}
|
||||
if resp != nil {
|
||||
event.ResponseHeaders = resp.Header
|
||||
event.ResponseStatusCode = int64(resp.StatusCode)
|
||||
event.ResponseProto = resp.Proto
|
||||
// Save a snapshot of the response body
|
||||
var data []byte
|
||||
data, err = readSnap(&resp.Body, snapSize, t.readAll)
|
||||
if err != nil {
|
||||
t.readAllErrs.Add(1)
|
||||
resp = nil // this is how net/http likes it
|
||||
} else {
|
||||
event.ResponseBodySnap = data
|
||||
}
|
||||
}
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
HTTPRoundTripDone: event,
|
||||
})
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections.
|
||||
func (t *TraceTripper) CloseIdleConnections() {
|
||||
// Adapted from net/http code
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := t.roundTripper.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
)
|
||||
|
||||
func TestTraceTripperSuccess(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewTraceTripper(http.DefaultTransport),
|
||||
}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
type roundTripHandler struct {
|
||||
roundTrips []*modelx.HTTPRoundTripDoneEvent
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (h *roundTripHandler) OnMeasurement(m modelx.Measurement) {
|
||||
if m.HTTPRoundTripDone != nil {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.roundTrips = append(h.roundTrips, m.HTTPRoundTripDone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraceTripperReadAllFailure(t *testing.T) {
|
||||
transport := NewTraceTripper(http.DefaultTransport)
|
||||
transport.readAll = func(r io.Reader) ([]byte, error) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err := client.Get("https://google.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
if transport.readAllErrs.Load() <= 0 {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func TestTraceTripperFailure(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewTraceTripper(http.DefaultTransport),
|
||||
}
|
||||
// This fails the request because we attempt to speak cleartext HTTP with
|
||||
// a server that instead is expecting TLS.
|
||||
resp, err := client.Get("http://www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response here")
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func TestTraceTripperWithClientTrace(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewTraceTripper(http.DefaultTransport),
|
||||
}
|
||||
req, err := http.NewRequest("GET", "https://www.kernel.org/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req = req.WithContext(
|
||||
httptrace.WithClientTrace(req.Context(), new(httptrace.ClientTrace)),
|
||||
)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected a good response here")
|
||||
}
|
||||
resp.Body.Close()
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func TestTraceTripperWithCorrectSnaps(t *testing.T) {
|
||||
// Prepare a DNS query for dns.google.com A, for which we
|
||||
// know the answer in terms of well know IP addresses
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = dns.Question{
|
||||
Name: dns.Fqdn("dns.google.com"),
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
queryData, err := query.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Prepare a new transport with limited snapshot size and
|
||||
// use such transport to configure an ordinary client
|
||||
transport := NewTraceTripper(http.DefaultTransport)
|
||||
const snapSize = 15
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
// Prepare a new request for Cloudflare DNS, register
|
||||
// a handler, issue the request, fetch the response.
|
||||
req, err := http.NewRequest(
|
||||
"POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
handler := &roundTripHandler{}
|
||||
ctx := modelx.WithMeasurementRoot(
|
||||
context.Background(), &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
MaxBodySnapSize: snapSize,
|
||||
},
|
||||
)
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatal("HTTP request failed")
|
||||
}
|
||||
|
||||
// Read the whole response body, parse it as valid DNS
|
||||
// reply and verify we obtained what we expected
|
||||
replyData, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
reply := new(dns.Msg)
|
||||
err = reply.Unpack(replyData)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reply.Rcode != 0 {
|
||||
t.Fatal("unexpected Rcode")
|
||||
}
|
||||
if len(reply.Answer) < 1 {
|
||||
t.Fatal("no answers?!")
|
||||
}
|
||||
found8888, found8844, foundother := false, false, false
|
||||
for _, answer := range reply.Answer {
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A.String()
|
||||
if ip == "8.8.8.8" {
|
||||
found8888 = true
|
||||
} else if ip == "8.8.4.4" {
|
||||
found8844 = true
|
||||
} else {
|
||||
foundother = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found8888 || !found8844 || foundother {
|
||||
t.Fatal("unexpected reply")
|
||||
}
|
||||
|
||||
// Finally, make sure we have captured the correct
|
||||
// snapshots for the request and response bodies
|
||||
if len(handler.roundTrips) != 1 {
|
||||
t.Fatal("more round trips than expected")
|
||||
}
|
||||
roundTrip := handler.roundTrips[0]
|
||||
if len(roundTrip.RequestBodySnap) != snapSize {
|
||||
t.Fatal("unexpected request body snap length")
|
||||
}
|
||||
if len(roundTrip.ResponseBodySnap) != snapSize {
|
||||
t.Fatal("unexpected response body snap length")
|
||||
}
|
||||
if !bytes.Equal(roundTrip.RequestBodySnap, queryData[:snapSize]) {
|
||||
t.Fatal("the request body snap is wrong")
|
||||
}
|
||||
if !bytes.Equal(roundTrip.ResponseBodySnap, replyData[:snapSize]) {
|
||||
t.Fatal("the response body snap is wrong")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTraceTripperWithReadAllFailingForBody(t *testing.T) {
|
||||
// Prepare a DNS query for dns.google.com A, for which we
|
||||
// know the answer in terms of well know IP addresses
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = dns.Question{
|
||||
Name: dns.Fqdn("dns.google.com"),
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
queryData, err := query.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Prepare a new transport with limited snapshot size and
|
||||
// use such transport to configure an ordinary client
|
||||
transport := NewTraceTripper(http.DefaultTransport)
|
||||
errorMocked := errors.New("mocked error")
|
||||
transport.readAll = func(r io.Reader) ([]byte, error) {
|
||||
return nil, errorMocked
|
||||
}
|
||||
const snapSize = 15
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
// Prepare a new request for Cloudflare DNS, register
|
||||
// a handler, issue the request, fetch the response.
|
||||
req, err := http.NewRequest(
|
||||
"POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
handler := &roundTripHandler{}
|
||||
ctx := modelx.WithMeasurementRoot(
|
||||
context.Background(), &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: handler,
|
||||
MaxBodySnapSize: snapSize,
|
||||
},
|
||||
)
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if !errors.Is(err, errorMocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
|
||||
// Finally, make sure we got something that makes sense
|
||||
if len(handler.roundTrips) != 0 {
|
||||
t.Fatal("more round trips than expected")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
// Transactioner performs single HTTP transactions.
|
||||
type Transactioner struct {
|
||||
roundTripper http.RoundTripper
|
||||
}
|
||||
|
||||
// NewTransactioner creates a new Transport.
|
||||
func NewTransactioner(roundTripper http.RoundTripper) *Transactioner {
|
||||
return &Transactioner{
|
||||
roundTripper: roundTripper,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip executes a single HTTP transaction, returning
|
||||
// a Response for the provided Request.
|
||||
func (t *Transactioner) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.roundTripper.RoundTrip(req.WithContext(
|
||||
transactionid.WithTransactionID(req.Context()),
|
||||
))
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes the idle connections.
|
||||
func (t *Transactioner) CloseIdleConnections() {
|
||||
// Adapted from net/http code
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := t.roundTripper.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package oldhttptransport
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
type transactionerCheckTransactionID struct {
|
||||
roundTripper http.RoundTripper
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (t *transactionerCheckTransactionID) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
if id := transactionid.ContextTransactionID(ctx); id == 0 {
|
||||
t.t.Fatal("transaction ID not set")
|
||||
}
|
||||
return t.roundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
func TestTransactionerSuccess(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewTransactioner(&transactionerCheckTransactionID{
|
||||
roundTripper: http.DefaultTransport,
|
||||
t: t,
|
||||
}),
|
||||
}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func TestTransactionerFailure(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: NewTransactioner(http.DefaultTransport),
|
||||
}
|
||||
// This fails the request because we attempt to speak cleartext HTTP with
|
||||
// a server that instead is expecting TLS.
|
||||
resp, err := client.Get("http://www.google.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected a nil response here")
|
||||
}
|
||||
client.CloseIdleConnections()
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package netx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
var (
|
||||
dohClientHandle *http.Client
|
||||
dohClientOnce sync.Once
|
||||
)
|
||||
|
||||
func newHTTPClientForDoH(beginning time.Time, handler modelx.Handler) *http.Client {
|
||||
if handler == handlers.NoHandler {
|
||||
// A bit of extra complexity for a good reason: if the user is not
|
||||
// interested into setting a default handler, then it is fine to
|
||||
// always return the same *http.Client for DoH. This means that we
|
||||
// don't need to care about closing the connections used by this
|
||||
// *http.Client, therefore we don't leak resources because we fail
|
||||
// to close the idle connections.
|
||||
dohClientOnce.Do(func() {
|
||||
transport := newHTTPTransport(
|
||||
time.Now(),
|
||||
handlers.NoHandler,
|
||||
newDialer(time.Now(), handler),
|
||||
false, // DisableKeepAlives
|
||||
http.ProxyFromEnvironment,
|
||||
)
|
||||
dohClientHandle = &http.Client{Transport: transport}
|
||||
})
|
||||
return dohClientHandle
|
||||
}
|
||||
// Otherwise, if the user wants to have a default handler, we
|
||||
// return a transport that does not leak connections.
|
||||
transport := newHTTPTransport(
|
||||
beginning,
|
||||
handler,
|
||||
newDialer(beginning, handler),
|
||||
true, // DisableKeepAlives
|
||||
http.ProxyFromEnvironment,
|
||||
)
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
|
||||
func withPort(address, port string) string {
|
||||
// Handle the case where port was not specified. We have written in
|
||||
// a bunch of places that we can just pass a domain in this case and
|
||||
// so we need to gracefully ensure this is still possible.
|
||||
_, _, err := net.SplitHostPort(address)
|
||||
if err != nil && strings.Contains(err.Error(), "missing port in address") {
|
||||
address = net.JoinHostPort(address, port)
|
||||
}
|
||||
return address
|
||||
}
|
||||
|
||||
type resolverWrapper struct {
|
||||
beginning time.Time
|
||||
handler modelx.Handler
|
||||
resolver modelx.DNSResolver
|
||||
}
|
||||
|
||||
func newResolverWrapper(
|
||||
beginning time.Time, handler modelx.Handler,
|
||||
resolver modelx.DNSResolver,
|
||||
) *resolverWrapper {
|
||||
return &resolverWrapper{
|
||||
beginning: beginning,
|
||||
handler: handler,
|
||||
resolver: resolver,
|
||||
}
|
||||
}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r *resolverWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
ctx = maybeWithMeasurementRoot(ctx, r.beginning, r.handler)
|
||||
return r.resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
func newResolver(
|
||||
beginning time.Time, handler modelx.Handler, network, address string,
|
||||
) (modelx.DNSResolver, error) {
|
||||
// Implementation note: system need to be dealt with
|
||||
// separately because it doesn't have any transport.
|
||||
if network == "system" || network == "" {
|
||||
return newResolverWrapper(
|
||||
beginning, handler, newResolverSystem()), nil
|
||||
}
|
||||
if network == "doh" {
|
||||
return newResolverWrapper(beginning, handler, newResolverHTTPS(
|
||||
newHTTPClientForDoH(beginning, handler), address,
|
||||
)), nil
|
||||
}
|
||||
if network == "dot" {
|
||||
// We need a child dialer here to avoid an endless loop where the
|
||||
// dialer will ask us to resolve, we'll tell the dialer to dial, it
|
||||
// will ask us to resolve, ...
|
||||
return newResolverWrapper(beginning, handler, newResolverTLS(
|
||||
newDialer(beginning, handler).DialTLSContext, withPort(address, "853"),
|
||||
)), nil
|
||||
}
|
||||
if network == "tcp" {
|
||||
// Same rationale as above: avoid possible endless loop
|
||||
return newResolverWrapper(beginning, handler, newResolverTCP(
|
||||
newDialer(beginning, handler).DialContext, withPort(address, "53"),
|
||||
)), nil
|
||||
}
|
||||
if network == "udp" {
|
||||
// Same rationale as above: avoid possible endless loop
|
||||
return newResolverWrapper(beginning, handler, newResolverUDP(
|
||||
newDialer(beginning, handler), withPort(address, "53"),
|
||||
)), nil
|
||||
}
|
||||
return nil, errors.New("resolver.New: unsupported network value")
|
||||
}
|
||||
|
||||
// NewResolver creates a standalone Resolver
|
||||
func NewResolver(network, address string) (modelx.DNSResolver, error) {
|
||||
return newResolver(time.Now(), handlers.NoHandler, network, address)
|
||||
}
|
||||
|
||||
type chainWrapperResolver struct {
|
||||
modelx.DNSResolver
|
||||
}
|
||||
|
||||
func (r chainWrapperResolver) Network() string {
|
||||
return "chain"
|
||||
}
|
||||
|
||||
func (r chainWrapperResolver) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ChainResolvers chains a primary and a secondary resolver such that
|
||||
// we can fallback to the secondary if primary is broken.
|
||||
func ChainResolvers(primary, secondary modelx.DNSResolver) modelx.DNSResolver {
|
||||
return resolver.ChainResolver{
|
||||
Primary: chainWrapperResolver{DNSResolver: primary},
|
||||
Secondary: chainWrapperResolver{DNSResolver: secondary},
|
||||
}
|
||||
}
|
||||
|
||||
func resolverWrapResolver(r resolver.Resolver) resolver.EmitterResolver {
|
||||
return resolver.EmitterResolver{Resolver: resolver.ErrorWrapperResolver{Resolver: r}}
|
||||
}
|
||||
|
||||
func resolverWrapTransport(txp resolver.RoundTripper) resolver.EmitterResolver {
|
||||
return resolverWrapResolver(resolver.NewSerialResolver(
|
||||
resolver.EmitterTransport{RoundTripper: txp}))
|
||||
}
|
||||
|
||||
func newResolverSystem() resolver.EmitterResolver {
|
||||
return resolverWrapResolver(resolver.SystemResolver{})
|
||||
}
|
||||
|
||||
func newResolverUDP(dialer resolver.Dialer, address string) resolver.EmitterResolver {
|
||||
return resolverWrapTransport(resolver.NewDNSOverUDP(dialer, address))
|
||||
}
|
||||
|
||||
func newResolverTCP(dial resolver.DialContextFunc, address string) resolver.EmitterResolver {
|
||||
return resolverWrapTransport(resolver.NewDNSOverTCP(dial, address))
|
||||
}
|
||||
|
||||
func newResolverTLS(dial resolver.DialContextFunc, address string) resolver.EmitterResolver {
|
||||
return resolverWrapTransport(resolver.NewDNSOverTLS(dial, address))
|
||||
}
|
||||
|
||||
func newResolverHTTPS(client *http.Client, address string) resolver.EmitterResolver {
|
||||
return resolverWrapTransport(resolver.NewDNSOverHTTPS(client, address))
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package netx
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
)
|
||||
|
||||
func NewHTTPClientForDoH(beginning time.Time, handler modelx.Handler) *http.Client {
|
||||
return newHTTPClientForDoH(beginning, handler)
|
||||
}
|
||||
|
||||
type ChainWrapperResolver = chainWrapperResolver
|
||||
@@ -0,0 +1,168 @@
|
||||
package netx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
)
|
||||
|
||||
func testresolverquick(t *testing.T, network, address string) {
|
||||
resolver, err := netx.NewResolver(network, address)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resolver == nil {
|
||||
t.Fatal("expected non-nil resolver here")
|
||||
}
|
||||
addrs, err := resolver.LookupHost(context.Background(), "dns.google.com")
|
||||
if err != nil {
|
||||
t.Fatalf("legacy/netx/resolver_test.go: %+v with %s/%s", err, network, address)
|
||||
}
|
||||
if addrs == nil {
|
||||
t.Fatal("expected non-nil addrs here")
|
||||
}
|
||||
var foundquad8 bool
|
||||
for _, addr := range addrs {
|
||||
// See https://github.com/ooni/probe-cli/v3/internal/engine/pull/954/checks?check_run_id=1182269025
|
||||
if addr == "8.8.8.8" || addr == "2001:4860:4860::8888" {
|
||||
foundquad8 = true
|
||||
}
|
||||
}
|
||||
if !foundquad8 {
|
||||
t.Fatalf("did not find 8.8.8.8 in ouput; output=%+v", addrs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolverUDPAddress(t *testing.T) {
|
||||
testresolverquick(t, "udp", "8.8.8.8:53")
|
||||
}
|
||||
|
||||
func TestNewResolverUDPAddressNoPort(t *testing.T) {
|
||||
testresolverquick(t, "udp", "8.8.8.8")
|
||||
}
|
||||
|
||||
func TestNewResolverUDPDomain(t *testing.T) {
|
||||
testresolverquick(t, "udp", "dns.google.com:53")
|
||||
}
|
||||
|
||||
func TestNewResolverUDPDomainNoPort(t *testing.T) {
|
||||
testresolverquick(t, "udp", "dns.google.com")
|
||||
}
|
||||
|
||||
func TestNewResolverSystem(t *testing.T) {
|
||||
testresolverquick(t, "system", "")
|
||||
}
|
||||
|
||||
func TestNewResolverTCPAddress(t *testing.T) {
|
||||
testresolverquick(t, "tcp", "8.8.8.8:53")
|
||||
}
|
||||
|
||||
func TestNewResolverTCPAddressNoPort(t *testing.T) {
|
||||
testresolverquick(t, "tcp", "8.8.8.8")
|
||||
}
|
||||
|
||||
func TestNewResolverTCPDomain(t *testing.T) {
|
||||
testresolverquick(t, "tcp", "dns.google.com:53")
|
||||
}
|
||||
|
||||
func TestNewResolverTCPDomainNoPort(t *testing.T) {
|
||||
testresolverquick(t, "tcp", "dns.google.com")
|
||||
}
|
||||
|
||||
func TestNewResolverDoTAddress(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("this test is not reliable in GitHub actions")
|
||||
}
|
||||
testresolverquick(t, "dot", "9.9.9.9:853")
|
||||
}
|
||||
|
||||
func TestNewResolverDoTAddressNoPort(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("this test is not reliable in GitHub actions")
|
||||
}
|
||||
testresolverquick(t, "dot", "9.9.9.9")
|
||||
}
|
||||
|
||||
func TestNewResolverDoTDomain(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("this test is not reliable in GitHub actions")
|
||||
}
|
||||
testresolverquick(t, "dot", "dns.quad9.net:853")
|
||||
}
|
||||
|
||||
func TestNewResolverDoTDomainNoPort(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("this test is not reliable in GitHub actions")
|
||||
}
|
||||
testresolverquick(t, "dot", "dns.quad9.net")
|
||||
}
|
||||
|
||||
func TestNewResolverDoH(t *testing.T) {
|
||||
testresolverquick(t, "doh", "https://cloudflare-dns.com/dns-query")
|
||||
}
|
||||
|
||||
func TestNewResolverInvalid(t *testing.T) {
|
||||
resolver, err := netx.NewResolver(
|
||||
"antani", "https://cloudflare-dns.com/dns-query",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if resolver != nil {
|
||||
t.Fatal("expected a nil resolver here")
|
||||
}
|
||||
}
|
||||
|
||||
type failingResolver struct{}
|
||||
|
||||
func (failingResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
func TestChainResolvers(t *testing.T) {
|
||||
fallback, err := netx.NewResolver("udp", "1.1.1.1:53")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dialer := netx.NewDialer()
|
||||
resolver := netx.ChainResolvers(failingResolver{}, fallback)
|
||||
dialer.SetResolver(resolver)
|
||||
conn, err := dialer.Dial("tcp", "www.google.com:80")
|
||||
if err != nil {
|
||||
t.Fatal(err) // we don't expect error because good resolver is first
|
||||
}
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
func TestNewHTTPClientForDoH(t *testing.T) {
|
||||
first := netx.NewHTTPClientForDoH(
|
||||
time.Now(), handlers.NoHandler,
|
||||
)
|
||||
second := netx.NewHTTPClientForDoH(
|
||||
time.Now(), handlers.NoHandler,
|
||||
)
|
||||
if first != second {
|
||||
t.Fatal("expected to see same client here")
|
||||
}
|
||||
third := netx.NewHTTPClientForDoH(
|
||||
time.Now(), handlers.StdoutHandler,
|
||||
)
|
||||
if first == third {
|
||||
t.Fatal("expected to see different client here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainWrapperResolver(t *testing.T) {
|
||||
r := netx.ChainWrapperResolver{}
|
||||
if r.Address() != "" {
|
||||
t.Fatal("invalid Address")
|
||||
}
|
||||
if r.Network() != "chain" {
|
||||
t.Fatal("invalid Network")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
#
|
||||
# The following is a truncated CA bundle for integration testing. This
|
||||
# will give us confidence that we fail if the file is wrong.
|
||||
#
|
||||
|
||||
emSign ECC Root CA - C3
|
||||
=======================
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAbGgAwIBAgIKe3G2gla4EnycqDAKBggqhkjOPQQDAzBaMQswCQYDVQQGEwJVUzETMBEG
|
||||
A1UECxMKZW1TaWduIFBLSTEUMBIGA1UEChMLZU11ZGhyYSBJbmMxIDAeBgNVBAMTF2VtU2lnbiBF
|
||||
Q0MgUm9vdCBDQSAtIEMzMB4XDTE4MDIxODE4MzAwMFoXDTQzMDIxODE4MzAwMFowWjELMAkGA1UE
|
||||
BhMCVVMxEzARBgNVBAsTCmVtU2lnbiBQS0kxFDASBgNVBAoTC2VNdWRocmEgSW5jMSAwHgYDVQQD
|
||||
ExdlbVNpZ24gRUNDIFJvb3QgQ0EgLSBDMzB2MBAGByqGSM49AgEGBSuBBAAiA2IABP2lYa57JhAd
|
||||
+54
@@ -0,0 +1,54 @@
|
||||
#
|
||||
# The following is a minimal, valid CA bundle. We do not include
|
||||
# however the certificates required to validate www.google.com
|
||||
# and we check in tests that we cannot connect to it and successfully
|
||||
# complete a TLS handshake. This gives us confidence that we can
|
||||
# actually override the CA bundle path.
|
||||
#
|
||||
|
||||
emSign ECC Root CA - C3
|
||||
=======================
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAbGgAwIBAgIKe3G2gla4EnycqDAKBggqhkjOPQQDAzBaMQswCQYDVQQGEwJVUzETMBEG
|
||||
A1UECxMKZW1TaWduIFBLSTEUMBIGA1UEChMLZU11ZGhyYSBJbmMxIDAeBgNVBAMTF2VtU2lnbiBF
|
||||
Q0MgUm9vdCBDQSAtIEMzMB4XDTE4MDIxODE4MzAwMFoXDTQzMDIxODE4MzAwMFowWjELMAkGA1UE
|
||||
BhMCVVMxEzARBgNVBAsTCmVtU2lnbiBQS0kxFDASBgNVBAoTC2VNdWRocmEgSW5jMSAwHgYDVQQD
|
||||
ExdlbVNpZ24gRUNDIFJvb3QgQ0EgLSBDMzB2MBAGByqGSM49AgEGBSuBBAAiA2IABP2lYa57JhAd
|
||||
6bciMK4G9IGzsUJxlTm801Ljr6/58pc1kjZGDoeVjbk5Wum739D+yAdBPLtVb4OjavtisIGJAnB9
|
||||
SMVK4+kiVCJNk7tCDK93nCOmfddhEc5lx/h//vXyqaNCMEAwHQYDVR0OBBYEFPtaSNCAIEDyqOkA
|
||||
B2kZd6fmw/TPMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMDA2gA
|
||||
MGUCMQC02C8Cif22TGK6Q04ThHK1rt0c3ta13FaPWEBaLd4gTCKDypOofu4SQMfWh0/434UCMBwU
|
||||
ZOR8loMRnLDRWmFLpg9J0wD8ofzkpf9/rdcw0Md3f76BB1UwUCAU9Vc4CqgxUQ==
|
||||
-----END CERTIFICATE-----
|
||||
|
||||
Hongkong Post Root CA 3
|
||||
=======================
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFzzCCA7egAwIBAgIUCBZfikyl7ADJk0DfxMauI7gcWqQwDQYJKoZIhvcNAQELBQAwbzELMAkG
|
||||
A1UEBhMCSEsxEjAQBgNVBAgTCUhvbmcgS29uZzESMBAGA1UEBxMJSG9uZyBLb25nMRYwFAYDVQQK
|
||||
Ew1Ib25na29uZyBQb3N0MSAwHgYDVQQDExdIb25na29uZyBQb3N0IFJvb3QgQ0EgMzAeFw0xNzA2
|
||||
MDMwMjI5NDZaFw00MjA2MDMwMjI5NDZaMG8xCzAJBgNVBAYTAkhLMRIwEAYDVQQIEwlIb25nIEtv
|
||||
bmcxEjAQBgNVBAcTCUhvbmcgS29uZzEWMBQGA1UEChMNSG9uZ2tvbmcgUG9zdDEgMB4GA1UEAxMX
|
||||
SG9uZ2tvbmcgUG9zdCBSb290IENBIDMwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCz
|
||||
iNfqzg8gTr7m1gNt7ln8wlffKWihgw4+aMdoWJwcYEuJQwy51BWy7sFOdem1p+/l6TWZ5Mwc50tf
|
||||
jTMwIDNT2aa71T4Tjukfh0mtUC1Qyhi+AViiE3CWu4mIVoBc+L0sPOFMV4i707mV78vH9toxdCim
|
||||
5lSJ9UExyuUmGs2C4HDaOym71QP1mbpV9WTRYA6ziUm4ii8F0oRFKHyPaFASePwLtVPLwpgchKOe
|
||||
sL4jpNrcyCse2m5FHomY2vkALgbpDDtw1VAliJnLzXNg99X/NWfFobxeq81KuEXryGgeDQ0URhLj
|
||||
0mRiikKYvLTGCAj4/ahMZJx2Ab0vqWwzD9g/KLg8aQFChn5pwckGyuV6RmXpwtZQQS4/t+TtbNe/
|
||||
JgERohYpSms0BpDsE9K2+2p20jzt8NYt3eEV7KObLyzJPivkaTv/ciWxNoZbx39ri1UbSsUgYT2u
|
||||
y1DhCDq+sI9jQVMwCFk8mB13umOResoQUGC/8Ne8lYePl8X+l2oBlKN8W4UdKjk60FSh0Tlxnf0h
|
||||
+bV78OLgAo9uliQlLKAeLKjEiafv7ZkGL7YKTE/bosw3Gq9HhS2KX8Q0NEwA/RiTZxPRN+ZItIsG
|
||||
xVd7GYYKecsAyVKvQv83j+GjHno9UKtjBucVtT+2RTeUN7F+8kjDf8V1/peNRY8apxpyKBpADwID
|
||||
AQABo2MwYTAPBgNVHRMBAf8EBTADAQH/MA4GA1UdDwEB/wQEAwIBBjAfBgNVHSMEGDAWgBQXnc0e
|
||||
i9Y5K3DTXNSguB+wAPzFYTAdBgNVHQ4EFgQUF53NHovWOStw01zUoLgfsAD8xWEwDQYJKoZIhvcN
|
||||
AQELBQADggIBAFbVe27mIgHSQpsY1Q7XZiNc4/6gx5LS6ZStS6LG7BJ8dNVI0lkUmcDrudHr9Egw
|
||||
W62nV3OZqdPlt9EuWSRY3GguLmLYauRwCy0gUCCkMpXRAJi70/33MvJJrsZ64Ee+bs7Lo3I6LWld
|
||||
y8joRTnU+kLBEUx3XZL7av9YROXrgZ6voJmtvqkBZss4HTzfQx/0TW60uhdG/H39h4F5ag0zD/ov
|
||||
+BS5gLNdTaqX4fnkGMX41TiMJjz98iji7lpJiCzfeT2OnpA8vUFKOt1b9pq0zj8lMH8yfaIDlNDc
|
||||
eqFS3m6TjRgm/VWsvY+b0s+v54Ysyx8Jb6NvqYTUc79NoXQbTiNg8swOqn+knEwlqLJmOzj/2ZQw
|
||||
9nKEvmhVEA/GcywWaZMH/rFF7buiVWqw2rVKAiUnhde3t4ZEFolsgCs+l6mc1X5VTMbeRRAc6uk7
|
||||
nwNT7u56AQIWeNTowr5GdogTPyK7SBIdUgC0An4hGh6cJfTzPV4e0hz5sy229zdcxsshTrD3mUcY
|
||||
hcErulWuBurQB7Lcq9CClnXO0lD+mefPL5/ndtFhKvshuzHQqp9HpLIiyhY6UFfEW0NnxWViA0kB
|
||||
60PZ2Pierc+xYw5F9KBaLJstxabArahH9CdMOA0uG0k7UvToiIMrVCjU8jVStDKDYmlkDJGcn5fq
|
||||
dBb9HxEGmpv0
|
||||
-----END CERTIFICATE-----
|
||||
@@ -0,0 +1,25 @@
|
||||
// Package transactionid contains code to share the transactionID
|
||||
package transactionid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/atomicx"
|
||||
)
|
||||
|
||||
type contextkey struct{}
|
||||
|
||||
var id = atomicx.NewInt64()
|
||||
|
||||
// WithTransactionID returns a copy of ctx with TransactionID
|
||||
func WithTransactionID(ctx context.Context) context.Context {
|
||||
return context.WithValue(
|
||||
ctx, contextkey{}, id.Add(1),
|
||||
)
|
||||
}
|
||||
|
||||
// ContextTransactionID returns the TransactionID of the context, or zero
|
||||
func ContextTransactionID(ctx context.Context) int64 {
|
||||
id, _ := ctx.Value(contextkey{}).(int64)
|
||||
return id
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package transactionid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
id := ContextTransactionID(ctx)
|
||||
if id != 0 {
|
||||
t.Fatal("unexpected ID for empty context")
|
||||
}
|
||||
ctx = WithTransactionID(ctx)
|
||||
id = ContextTransactionID(ctx)
|
||||
if id != 1 {
|
||||
t.Fatal("expected ID equal to 1")
|
||||
}
|
||||
ctx = WithTransactionID(ctx)
|
||||
id = ContextTransactionID(ctx)
|
||||
if id != 2 {
|
||||
t.Fatal("expected ID equal to 2")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user