chore: merge probe-engine into probe-cli (#201)
This is how I did it: 1. `git clone https://github.com/ooni/probe-engine internal/engine` 2. ``` (cd internal/engine && git describe --tags) v0.23.0 ``` 3. `nvim go.mod` (merging `go.mod` with `internal/engine/go.mod` 4. `rm -rf internal/.git internal/engine/go.{mod,sum}` 5. `git add internal/engine` 6. `find . -type f -name \*.go -exec sed -i 's@/ooni/probe-engine@/ooni/probe-cli/v3/internal/engine@g' {} \;` 7. `go build ./...` (passes) 8. `go test -race ./...` (temporary failure on RiseupVPN) 9. `go mod tidy` 10. this commit message Once this piece of work is done, we can build a new version of `ooniprobe` that is using `internal/engine` directly. We need to do more work to ensure all the other functionality in `probe-engine` (e.g. making mobile packages) are still WAI. Part of https://github.com/ooni/probe/issues/1335
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
)
|
||||
|
||||
// ByteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you
|
||||
// should make sure that you insert this dialer in the dialing chain.
|
||||
//
|
||||
// Bug
|
||||
//
|
||||
// This implementation cannot properly account for the bytes that are sent by
|
||||
// persistent connections, because they strick to the counters set when the
|
||||
// connection was established. This typically means we miss the bytes sent and
|
||||
// received when submitting a measurement. Such bytes are specifically not
|
||||
// see by the experiment specific byte counter.
|
||||
//
|
||||
// For this reason, this implementation may be heavily changed/removed.
|
||||
type ByteCounterDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ByteCounterDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
exp := ContextExperimentByteCounter(ctx)
|
||||
sess := ContextSessionByteCounter(ctx)
|
||||
if exp == nil && sess == nil {
|
||||
return conn, nil // no point in wrapping
|
||||
}
|
||||
return byteCounterConnWrapper{Conn: conn, exp: exp, sess: sess}, nil
|
||||
}
|
||||
|
||||
type byteCounterSessionKey struct{}
|
||||
|
||||
// ContextSessionByteCounter retrieves the session byte counter from the context
|
||||
func ContextSessionByteCounter(ctx context.Context) *bytecounter.Counter {
|
||||
counter, _ := ctx.Value(byteCounterSessionKey{}).(*bytecounter.Counter)
|
||||
return counter
|
||||
}
|
||||
|
||||
// WithSessionByteCounter assigns the session byte counter to the context
|
||||
func WithSessionByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context {
|
||||
return context.WithValue(ctx, byteCounterSessionKey{}, counter)
|
||||
}
|
||||
|
||||
type byteCounterExperimentKey struct{}
|
||||
|
||||
// ContextExperimentByteCounter retrieves the experiment byte counter from the context
|
||||
func ContextExperimentByteCounter(ctx context.Context) *bytecounter.Counter {
|
||||
counter, _ := ctx.Value(byteCounterExperimentKey{}).(*bytecounter.Counter)
|
||||
return counter
|
||||
}
|
||||
|
||||
// WithExperimentByteCounter assigns the experiment byte counter to the context
|
||||
func WithExperimentByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context {
|
||||
return context.WithValue(ctx, byteCounterExperimentKey{}, counter)
|
||||
}
|
||||
|
||||
type byteCounterConnWrapper struct {
|
||||
net.Conn
|
||||
exp *bytecounter.Counter
|
||||
sess *bytecounter.Counter
|
||||
}
|
||||
|
||||
func (c byteCounterConnWrapper) Read(p []byte) (int, error) {
|
||||
count, err := c.Conn.Read(p)
|
||||
if c.exp != nil {
|
||||
c.exp.CountBytesReceived(count)
|
||||
}
|
||||
if c.sess != nil {
|
||||
c.sess.CountBytesReceived(count)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c byteCounterConnWrapper) Write(p []byte) (int, error) {
|
||||
count, err := c.Conn.Write(p)
|
||||
if c.exp != nil {
|
||||
c.exp.CountBytesSent(count)
|
||||
}
|
||||
if c.sess != nil {
|
||||
c.sess.CountBytesSent(count)
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func dorequest(ctx context.Context, url string) error {
|
||||
txp := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defer txp.CloseIdleConnections()
|
||||
dialer := dialer.ByteCounterDialer{Dialer: new(net.Dialer)}
|
||||
txp.DialContext = dialer.DialContext
|
||||
client := &http.Client{Transport: txp}
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "http://www.google.com", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestByteCounterNormalUsage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
sess := bytecounter.New()
|
||||
ctx := context.Background()
|
||||
ctx = dialer.WithSessionByteCounter(ctx, sess)
|
||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
exp := bytecounter.New()
|
||||
ctx = dialer.WithExperimentByteCounter(ctx, exp)
|
||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if sess.Received.Load() <= exp.Received.Load() {
|
||||
t.Fatal("session should have received more than experiment")
|
||||
}
|
||||
if sess.Sent.Load() <= exp.Sent.Load() {
|
||||
t.Fatal("session should have sent more than experiment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterNoHandlers(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := dorequest(ctx, "http://www.google.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := dorequest(ctx, "http://facebook.com"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteCounterConnectFailure(t *testing.T) {
|
||||
dialer := dialer.ByteCounterDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
|
||||
)
|
||||
|
||||
// Dialer is the interface we expect from a dialer
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Resolver is the interface we expect from a resolver
|
||||
type Resolver interface {
|
||||
LookupHost(ctx context.Context, hostname string) (addrs []string, err error)
|
||||
}
|
||||
|
||||
func safeLocalAddress(conn net.Conn) (s string) {
|
||||
if conn != nil && conn.LocalAddr() != nil {
|
||||
s = conn.LocalAddr().String()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func safeConnID(network string, conn net.Conn) int64 {
|
||||
return connid.Compute(network, safeLocalAddress(conn))
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// DNSDialer is a dialer that uses the configured Resolver to resolver a
|
||||
// domain name to IP addresses, and the configured Dialer to connect.
|
||||
type DNSDialer struct {
|
||||
Dialer
|
||||
Resolver Resolver
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext.
|
||||
func (d DNSDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
onlyhost, onlyport, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = dialid.WithDialID(ctx) // important to create before lookupHost
|
||||
var addrs []string
|
||||
addrs, err = d.LookupHost(ctx, onlyhost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var errorslist []error
|
||||
for _, addr := range addrs {
|
||||
target := net.JoinHostPort(addr, onlyport)
|
||||
conn, err := d.Dialer.DialContext(ctx, network, target)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
}
|
||||
return nil, ReduceErrors(errorslist)
|
||||
}
|
||||
|
||||
// ReduceErrors finds a known error in a list of errors since it's probably most relevant
|
||||
func ReduceErrors(errorslist []error) error {
|
||||
if len(errorslist) == 0 {
|
||||
return nil
|
||||
}
|
||||
// If we have a known error, let's consider this the real error
|
||||
// since it's probably most relevant. Otherwise let's return the
|
||||
// first considering that (1) local resolvers likely will give
|
||||
// us IPv4 first and (2) also our resolver does that. So, in case
|
||||
// the user has no IPv6 connectivity, an IPv6 error is going to
|
||||
// appear later in the list of errors.
|
||||
for _, err := range errorslist {
|
||||
var wrapper *errorx.ErrWrapper
|
||||
if errors.As(err, &wrapper) && !strings.HasPrefix(
|
||||
err.Error(), "unknown_failure",
|
||||
) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO(bassosimone): handle this case in a better way
|
||||
return errorslist[0]
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost
|
||||
func (d DNSDialer) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return d.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestDNSDialerNoPort(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: new(net.Resolver)}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "antani.ooni.nu")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerLookupHostAddress(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{
|
||||
Err: errors.New("mocked error"),
|
||||
}}
|
||||
addrs, err := dialer.LookupHost(context.Background(), "1.1.1.1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 1 || addrs[0] != "1.1.1.1" {
|
||||
t.Fatal("not the result we expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerLookupHostFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
dialer := dialer.DNSDialer{Dialer: new(net.Dialer), Resolver: MockableResolver{
|
||||
Err: expected,
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "dns.google.com:853")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
type MockableResolver struct {
|
||||
Addresses []string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||
return r.Addresses, r.Err
|
||||
}
|
||||
|
||||
func TestDNSDialerDialForSingleIPFails(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: new(net.Resolver)}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "1.1.1.1:853")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerDialForManyIPFails(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: MockableResolver{
|
||||
Addresses: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSDialerDialForManyIPSuccess(t *testing.T) {
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EOFConnDialer{}, Resolver: MockableResolver{
|
||||
Addresses: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}}
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||
if err != nil {
|
||||
t.Fatal("expected nil error here")
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn")
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDNSDialerDialSetsDialID(t *testing.T) {
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
dialer := dialer.DNSDialer{Dialer: dialer.EmitterDialer{
|
||||
Dialer: dialer.EOFConnDialer{},
|
||||
}, Resolver: MockableResolver{
|
||||
Addresses: []string{"1.1.1.1", "8.8.8.8"},
|
||||
}}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "dot.dns:853")
|
||||
if err != nil {
|
||||
t.Fatal("expected nil error here")
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn")
|
||||
}
|
||||
conn.Close()
|
||||
events := saver.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
for _, ev := range events {
|
||||
if ev.Connect != nil && ev.Connect.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReduceErrors(t *testing.T) {
|
||||
t.Run("no errors", func(t *testing.T) {
|
||||
result := dialer.ReduceErrors(nil)
|
||||
if result != nil {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("single error", func(t *testing.T) {
|
||||
err := errors.New("mocked error")
|
||||
result := dialer.ReduceErrors([]error{err})
|
||||
if result != err {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple errors", func(t *testing.T) {
|
||||
err1 := errors.New("mocked error #1")
|
||||
err2 := errors.New("mocked error #2")
|
||||
result := dialer.ReduceErrors([]error{err1, err2})
|
||||
if result.Error() != "mocked error #1" {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
||||
err1 := errors.New("mocked error #1")
|
||||
err2 := &errorx.ErrWrapper{
|
||||
Failure: "unknown_failure: antani",
|
||||
}
|
||||
err3 := &errorx.ErrWrapper{
|
||||
Failure: errorx.FailureConnectionRefused,
|
||||
}
|
||||
err4 := errors.New("mocked error #3")
|
||||
result := dialer.ReduceErrors([]error{err1, err2, err3, err4})
|
||||
if result.Error() != errorx.FailureConnectionRefused {
|
||||
t.Fatal("wrong result")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
)
|
||||
|
||||
// EmitterDialer is a Dialer that emits events
|
||||
type EmitterDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d EmitterDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
start := time.Now()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
stop := time.Now()
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
Connect: &modelx.ConnectEvent{
|
||||
ConnID: safeConnID(network, conn),
|
||||
DialID: dialid.ContextDialID(ctx),
|
||||
DurationSinceBeginning: stop.Sub(root.Beginning),
|
||||
Error: err,
|
||||
Network: network,
|
||||
RemoteAddress: address,
|
||||
SyscallDuration: stop.Sub(start),
|
||||
TransactionID: transactionid.ContextTransactionID(ctx),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return EmitterConn{
|
||||
Conn: conn,
|
||||
Beginning: root.Beginning,
|
||||
Handler: root.Handler,
|
||||
ID: safeConnID(network, conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EmitterConn is a net.Conn used to emit events
|
||||
type EmitterConn struct {
|
||||
net.Conn
|
||||
Beginning time.Time
|
||||
Handler modelx.Handler
|
||||
ID int64
|
||||
}
|
||||
|
||||
// Read implements net.Conn.Read
|
||||
func (c EmitterConn) Read(b []byte) (n int, err error) {
|
||||
start := time.Now()
|
||||
n, err = c.Conn.Read(b)
|
||||
stop := time.Now()
|
||||
c.Handler.OnMeasurement(modelx.Measurement{
|
||||
Read: &modelx.ReadEvent{
|
||||
ConnID: c.ID,
|
||||
DurationSinceBeginning: stop.Sub(c.Beginning),
|
||||
Error: err,
|
||||
NumBytes: int64(n),
|
||||
SyscallDuration: stop.Sub(start),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write
|
||||
func (c EmitterConn) Write(b []byte) (n int, err error) {
|
||||
start := time.Now()
|
||||
n, err = c.Conn.Write(b)
|
||||
stop := time.Now()
|
||||
c.Handler.OnMeasurement(modelx.Measurement{
|
||||
Write: &modelx.WriteEvent{
|
||||
ConnID: c.ID,
|
||||
DurationSinceBeginning: stop.Sub(c.Beginning),
|
||||
Error: err,
|
||||
NumBytes: int64(n),
|
||||
SyscallDuration: stop.Sub(start),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close
|
||||
func (c EmitterConn) Close() (err error) {
|
||||
start := time.Now()
|
||||
err = c.Conn.Close()
|
||||
stop := time.Now()
|
||||
c.Handler.OnMeasurement(modelx.Measurement{
|
||||
Close: &modelx.CloseEvent{
|
||||
ConnID: c.ID,
|
||||
DurationSinceBeginning: stop.Sub(c.Beginning),
|
||||
Error: err,
|
||||
SyscallDuration: stop.Sub(start),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestEmitterFailure(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
d := dialer.EmitterDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
events := saver.Read()
|
||||
if len(events) != 1 {
|
||||
t.Fatal("unexpected number of events saved")
|
||||
}
|
||||
if events[0].Connect == nil {
|
||||
t.Fatal("expected non nil Connect")
|
||||
}
|
||||
conninfo := events[0].Connect
|
||||
if conninfo.ConnID != 0 {
|
||||
t.Fatal("unexpected ConnID value")
|
||||
}
|
||||
emitterCheckConnectEventCommon(t, conninfo, io.EOF)
|
||||
}
|
||||
|
||||
func emitterCheckConnectEventCommon(
|
||||
t *testing.T, conninfo *modelx.ConnectEvent, err error) {
|
||||
if conninfo.DialID == 0 {
|
||||
t.Fatal("unexpected DialID value")
|
||||
}
|
||||
if conninfo.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning value")
|
||||
}
|
||||
if !errors.Is(conninfo.Error, err) {
|
||||
t.Fatal("unexpected Error value")
|
||||
}
|
||||
if conninfo.Network != "tcp" {
|
||||
t.Fatal("unexpected Network value")
|
||||
}
|
||||
if conninfo.RemoteAddress != "www.google.com:443" {
|
||||
t.Fatal("unexpected Network value")
|
||||
}
|
||||
if conninfo.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration value")
|
||||
}
|
||||
if conninfo.TransactionID == 0 {
|
||||
t.Fatal("unexpected TransactionID value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterSuccess(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx = modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
ctx = transactionid.WithTransactionID(ctx)
|
||||
d := dialer.EmitterDialer{Dialer: dialer.EOFConnDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal("we expected no error")
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected a non-nil conn here")
|
||||
}
|
||||
conn.Read(nil)
|
||||
conn.Write(nil)
|
||||
conn.Close()
|
||||
events := saver.Read()
|
||||
if len(events) != 4 {
|
||||
t.Fatal("unexpected number of events saved")
|
||||
}
|
||||
if events[0].Connect == nil {
|
||||
t.Fatal("expected non nil Connect")
|
||||
}
|
||||
conninfo := events[0].Connect
|
||||
if conninfo.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID value")
|
||||
}
|
||||
emitterCheckConnectEventCommon(t, conninfo, nil)
|
||||
if events[1].Read == nil {
|
||||
t.Fatal("expected non nil Read")
|
||||
}
|
||||
emitterCheckReadEvent(t, events[1].Read)
|
||||
if events[2].Write == nil {
|
||||
t.Fatal("expected non nil Write")
|
||||
}
|
||||
emitterCheckWriteEvent(t, events[2].Write)
|
||||
if events[3].Close == nil {
|
||||
t.Fatal("expected non nil Close")
|
||||
}
|
||||
emitterCheckCloseEvent(t, events[3].Close)
|
||||
}
|
||||
|
||||
func emitterCheckReadEvent(t *testing.T, ev *modelx.ReadEvent) {
|
||||
if ev.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if ev.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning")
|
||||
}
|
||||
if !errors.Is(ev.Error, io.EOF) {
|
||||
t.Fatal("unexpected Error")
|
||||
}
|
||||
if ev.NumBytes != 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
if ev.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration")
|
||||
}
|
||||
}
|
||||
|
||||
func emitterCheckWriteEvent(t *testing.T, ev *modelx.WriteEvent) {
|
||||
if ev.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if ev.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning")
|
||||
}
|
||||
if !errors.Is(ev.Error, io.EOF) {
|
||||
t.Fatal("unexpected Error")
|
||||
}
|
||||
if ev.NumBytes != 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
if ev.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration")
|
||||
}
|
||||
}
|
||||
|
||||
func emitterCheckCloseEvent(t *testing.T, ev *modelx.CloseEvent) {
|
||||
if ev.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if ev.DurationSinceBeginning == 0 {
|
||||
t.Fatal("unexpected DurationSinceBeginning")
|
||||
}
|
||||
if !errors.Is(ev.Error, io.EOF) {
|
||||
t.Fatal("unexpected Error")
|
||||
}
|
||||
if ev.SyscallDuration == 0 {
|
||||
t.Fatal("unexpected SyscallDuration")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EOFDialer struct{}
|
||||
|
||||
func (EOFDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
type EOFConnDialer struct{}
|
||||
|
||||
func (EOFConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return EOFConn{}, nil
|
||||
}
|
||||
|
||||
type EOFConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (EOFConn) Read(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (EOFConn) Write(p []byte) (int, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (EOFConn) Close() error {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (EOFConn) LocalAddr() net.Addr {
|
||||
return EOFAddr{}
|
||||
}
|
||||
|
||||
func (EOFConn) RemoteAddr() net.Addr {
|
||||
return EOFAddr{}
|
||||
}
|
||||
|
||||
func (EOFConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EOFConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (EOFConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type EOFAddr struct{}
|
||||
|
||||
func (EOFAddr) Network() string {
|
||||
return "tcp"
|
||||
}
|
||||
|
||||
func (EOFAddr) String() string {
|
||||
return "127.0.0.1:1234"
|
||||
}
|
||||
|
||||
type EOFTLSHandshaker struct{}
|
||||
|
||||
func (EOFTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return nil, tls.ConnectionState{}, io.EOF
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// ErrorWrapperDialer is a dialer that performs err wrapping
|
||||
type ErrorWrapperDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ErrorWrapperDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
dialID := dialid.ContextDialID(ctx)
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
// ConnID does not make any sense if we've failed and the error
|
||||
// does not make any sense (and is nil) if we succeded.
|
||||
DialID: dialID,
|
||||
Error: err,
|
||||
Operation: errorx.ConnectOperation,
|
||||
}.MaybeBuild()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ErrorWrapperConn{
|
||||
Conn: conn, ConnID: safeConnID(network, conn), DialID: dialID}, nil
|
||||
}
|
||||
|
||||
// ErrorWrapperConn is a net.Conn that performs error wrapping.
|
||||
type ErrorWrapperConn struct {
|
||||
net.Conn
|
||||
ConnID int64
|
||||
DialID int64
|
||||
}
|
||||
|
||||
// Read implements net.Conn.Read
|
||||
func (c ErrorWrapperConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Read(b)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: c.ConnID,
|
||||
DialID: c.DialID,
|
||||
Error: err,
|
||||
Operation: errorx.ReadOperation,
|
||||
}.MaybeBuild()
|
||||
return
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write
|
||||
func (c ErrorWrapperConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.Conn.Write(b)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: c.ConnID,
|
||||
DialID: c.DialID,
|
||||
Error: err,
|
||||
Operation: errorx.WriteOperation,
|
||||
}.MaybeBuild()
|
||||
return
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close
|
||||
func (c ErrorWrapperConn) Close() (err error) {
|
||||
err = c.Conn.Close()
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: c.ConnID,
|
||||
DialID: c.DialID,
|
||||
Error: err,
|
||||
Operation: errorx.CloseOperation,
|
||||
}.MaybeBuild()
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestErrorWrapperFailure(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if conn != nil {
|
||||
t.Fatal("expected a nil conn here")
|
||||
}
|
||||
errorWrapperCheckErr(t, err, errorx.ConnectOperation)
|
||||
}
|
||||
|
||||
func errorWrapperCheckErr(t *testing.T, err error, op string) {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected another error here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot cast to ErrWrapper")
|
||||
}
|
||||
if errWrapper.DialID == 0 {
|
||||
t.Fatal("unexpected DialID")
|
||||
}
|
||||
if errWrapper.Operation != op {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureEOFError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorWrapperSuccess(t *testing.T) {
|
||||
ctx := dialid.WithDialID(context.Background())
|
||||
d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFConnDialer{}}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn here")
|
||||
}
|
||||
count, err := conn.Read(nil)
|
||||
errorWrapperCheckIOResult(t, count, err, errorx.ReadOperation)
|
||||
count, err = conn.Write(nil)
|
||||
errorWrapperCheckIOResult(t, count, err, errorx.WriteOperation)
|
||||
err = conn.Close()
|
||||
errorWrapperCheckErr(t, err, errorx.CloseOperation)
|
||||
}
|
||||
|
||||
func errorWrapperCheckIOResult(t *testing.T, count int, err error, op string) {
|
||||
if count != 0 {
|
||||
t.Fatal("expected nil count here")
|
||||
}
|
||||
errorWrapperCheckErr(t, err, op)
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FakeDialer struct {
|
||||
Conn net.Conn
|
||||
Err error
|
||||
}
|
||||
|
||||
func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
time.Sleep(10 * time.Microsecond)
|
||||
return d.Conn, d.Err
|
||||
}
|
||||
|
||||
type FakeConn struct {
|
||||
ReadError error
|
||||
ReadData []byte
|
||||
SetDeadlineError error
|
||||
SetReadDeadlineError error
|
||||
SetWriteDeadlineError error
|
||||
WriteError error
|
||||
}
|
||||
|
||||
func (c *FakeConn) Read(b []byte) (int, error) {
|
||||
if len(c.ReadData) > 0 {
|
||||
n := copy(b, c.ReadData)
|
||||
c.ReadData = c.ReadData[n:]
|
||||
return n, nil
|
||||
}
|
||||
if c.ReadError != nil {
|
||||
return 0, c.ReadError
|
||||
}
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (c *FakeConn) Write(b []byte) (n int, err error) {
|
||||
if c.WriteError != nil {
|
||||
return 0, c.WriteError
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) Close() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (*FakeConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (*FakeConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetDeadline(t time.Time) (err error) {
|
||||
return c.SetDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetReadDeadline(t time.Time) (err error) {
|
||||
return c.SetReadDeadlineError
|
||||
}
|
||||
|
||||
func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
|
||||
return c.SetWriteDeadlineError
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestTLSDialerSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
log.SetLevel(log.DebugLevel)
|
||||
dialer := dialer.TLSDialer{Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.LoggingTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Logger: log.Log,
|
||||
},
|
||||
}
|
||||
txp := &http.Transport{DialTLS: func(network, address string) (net.Conn, error) {
|
||||
// AlpineLinux edge is still using Go 1.13. We cannot switch to
|
||||
// using DialTLSContext here as we'd like to until either Alpine
|
||||
// switches to Go 1.14 or we drop the MK dependency.
|
||||
return dialer.DialTLSContext(context.Background(), network, address)
|
||||
}}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
func TestDNSDialerSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
log.SetLevel(log.DebugLevel)
|
||||
dialer := dialer.DNSDialer{
|
||||
Dialer: dialer.LoggingDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
Logger: log.Log,
|
||||
},
|
||||
Resolver: new(net.Resolver),
|
||||
}
|
||||
txp := &http.Transport{DialContext: dialer.DialContext}
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("http://www.google.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
|
||||
)
|
||||
|
||||
// Logger is the logger assumed by this package
|
||||
type Logger interface {
|
||||
Debugf(format string, v ...interface{})
|
||||
Debug(message string)
|
||||
}
|
||||
|
||||
// LoggingDialer is a Dialer with logging
|
||||
type LoggingDialer struct {
|
||||
Dialer
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d LoggingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d.Logger.Debugf("dial %s/%s...", address, network)
|
||||
start := time.Now()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
stop := time.Now()
|
||||
d.Logger.Debugf("dial %s/%s... %+v in %s", address, network, err, stop.Sub(start))
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// LoggingTLSHandshaker is a TLSHandshaker with logging
|
||||
type LoggingTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h LoggingTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
h.Logger.Debugf("tls {sni=%s next=%+v}...", config.ServerName, config.NextProtos)
|
||||
start := time.Now()
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
stop := time.Now()
|
||||
h.Logger.Debugf(
|
||||
"tls {sni=%s next=%+v}... %+v in %s {next=%s cipher=%s v=%s}", config.ServerName,
|
||||
config.NextProtos, err, stop.Sub(start), state.NegotiatedProtocol,
|
||||
tlsx.CipherSuiteString(state.CipherSuite), tlsx.VersionString(state.Version))
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
var _ Dialer = LoggingDialer{}
|
||||
var _ TLSHandshaker = LoggingTLSHandshaker{}
|
||||
@@ -0,0 +1,42 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestLoggingDialerFailure(t *testing.T) {
|
||||
d := dialer.LoggingDialer{
|
||||
Dialer: dialer.EOFDialer{},
|
||||
Logger: log.Log,
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggingTLSHandshakerFailure(t *testing.T) {
|
||||
h := dialer.LoggingTLSHandshaker{
|
||||
TLSHandshaker: dialer.EOFTLSHandshaker{},
|
||||
Logger: log.Log,
|
||||
}
|
||||
tlsconn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "www.google.com",
|
||||
})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if tlsconn != nil {
|
||||
t.Fatal("expected nil tlsconn here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// ProxyDialer is a dialer that uses a proxy. If the ProxyURL is not configured, this
|
||||
// dialer is a passthrough for the next Dialer in chain. Otherwise, it will internally
|
||||
// create a SOCKS5 dialer that will connect to the proxy using the underlying Dialer.
|
||||
//
|
||||
// As a special case, you can force a proxy to be used only extemporarily. To this end,
|
||||
// you can use the WithProxyURL function, to store the proxy URL in the context. This
|
||||
// will take precedence over any otherwise configured proxy. The use case for this
|
||||
// functionality is when you need a tunnel to contact OONI probe services.
|
||||
type ProxyDialer struct {
|
||||
Dialer
|
||||
ProxyURL *url.URL
|
||||
}
|
||||
|
||||
type proxyKey struct{}
|
||||
|
||||
// ContextProxyURL retrieves the proxy URL from the context. This is mainly used
|
||||
// to force a tunnel when we fail contacting OONI probe services otherwise.
|
||||
func ContextProxyURL(ctx context.Context) *url.URL {
|
||||
url, _ := ctx.Value(proxyKey{}).(*url.URL)
|
||||
return url
|
||||
}
|
||||
|
||||
// WithProxyURL assigns the proxy URL to the context
|
||||
func WithProxyURL(ctx context.Context, url *url.URL) context.Context {
|
||||
return context.WithValue(ctx, proxyKey{}, url)
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
url := ContextProxyURL(ctx) // context URL takes precendence
|
||||
if url == nil {
|
||||
url = d.ProxyURL
|
||||
}
|
||||
if url == nil {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
if url.Scheme != "socks5" {
|
||||
return nil, errors.New("Scheme is not socks5")
|
||||
}
|
||||
// the code at proxy/socks5.go never fails; see https://git.io/JfJ4g
|
||||
child, _ := proxy.SOCKS5(
|
||||
network, url.Host, nil, proxyDialerWrapper{Dialer: d.Dialer})
|
||||
return d.dial(ctx, child, network, address)
|
||||
}
|
||||
|
||||
func (d ProxyDialer) dial(
|
||||
ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) {
|
||||
connch := make(chan net.Conn)
|
||||
errch := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := child.Dial(network, address)
|
||||
if err != nil {
|
||||
errch <- err
|
||||
return
|
||||
}
|
||||
select {
|
||||
case connch <- conn:
|
||||
default:
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-errch:
|
||||
return nil, err
|
||||
case conn := <-connch:
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
// proxyDialerWrapper is required because SOCKS5 expects a Dialer.Dial type but internally
|
||||
// it checks whether DialContext is available and prefers that. So, we need to use this
|
||||
// structure to cast our inner Dialer the way in which SOCKS5 likes it.
|
||||
//
|
||||
// See https://git.io/JfJ4g.
|
||||
type proxyDialerWrapper struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
func (d proxyDialerWrapper) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type ProxyDialerWrapper = proxyDialerWrapper
|
||||
|
||||
func (d ProxyDialer) DialContextWithDialer(
|
||||
ctx context.Context, child proxy.Dialer, network, address string) (net.Conn, error) {
|
||||
return d.dial(ctx, child, network, address)
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestProxyDialerDialContextNoProxyURL(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{Err: expected},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerContextTakesPrecedence(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{Err: expected},
|
||||
ProxyURL: &url.URL{Scheme: "antani"},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ctx = dialer.WithProxyURL(ctx, &url.URL{Scheme: "socks5", Host: "[::1]:443"})
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextInvalidScheme(t *testing.T) {
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{},
|
||||
ProxyURL: &url.URL{Scheme: "antani"},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if err.Error() != "Scheme is not socks5" {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithEOF(t *testing.T) {
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: io.EOF,
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // immediately fail
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: io.EOF,
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContext(ctx, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithDialerSuccess(t *testing.T) {
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Conn: &dialer.FakeConn{
|
||||
ReadError: io.EOF,
|
||||
WriteError: io.EOF,
|
||||
},
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContextWithDialer(
|
||||
context.Background(), dialer.ProxyDialerWrapper{
|
||||
Dialer: d.Dialer,
|
||||
}, "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestProxyDialerDialContextWithDialerCanceledContext(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Stop immediately. The FakeDialer sleeps for some microseconds so
|
||||
// it is much more likely we immediately exit with done context. The
|
||||
// arm where we receive the conn is much less likely.
|
||||
cancel()
|
||||
d := dialer.ProxyDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Conn: &dialer.FakeConn{
|
||||
ReadError: io.EOF,
|
||||
WriteError: io.EOF,
|
||||
},
|
||||
},
|
||||
ProxyURL: &url.URL{Scheme: "socks5"},
|
||||
}
|
||||
conn, err := d.DialContextWithDialer(
|
||||
ctx, dialer.ProxyDialerWrapper{
|
||||
Dialer: d.Dialer,
|
||||
}, "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialerWrapper(t *testing.T) {
|
||||
d := dialer.ProxyDialerWrapper{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: io.EOF,
|
||||
},
|
||||
}
|
||||
conn, err := d.Dial("tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("conn is not nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/internal/tlsx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
// SaverDialer saves events occurring during the dial
|
||||
type SaverDialer struct {
|
||||
Dialer
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d SaverDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
start := time.Now()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
stop := time.Now()
|
||||
d.Saver.Write(trace.Event{
|
||||
Address: address,
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: errorx.ConnectOperation,
|
||||
Proto: network,
|
||||
Time: stop,
|
||||
})
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// SaverTLSHandshaker saves events occurring during the handshake
|
||||
type SaverTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// Handshake implements TLSHandshaker.Handshake
|
||||
func (h SaverTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
start := time.Now()
|
||||
h.Saver.Write(trace.Event{
|
||||
Name: "tls_handshake_start",
|
||||
NoTLSVerify: config.InsecureSkipVerify,
|
||||
TLSNextProtos: config.NextProtos,
|
||||
TLSServerName: config.ServerName,
|
||||
Time: start,
|
||||
})
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
stop := time.Now()
|
||||
h.Saver.Write(trace.Event{
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: "tls_handshake_done",
|
||||
NoTLSVerify: config.InsecureSkipVerify,
|
||||
TLSCipherSuite: tlsx.CipherSuiteString(state.CipherSuite),
|
||||
TLSNegotiatedProto: state.NegotiatedProtocol,
|
||||
TLSNextProtos: config.NextProtos,
|
||||
TLSPeerCerts: trace.PeerCerts(state, err),
|
||||
TLSServerName: config.ServerName,
|
||||
TLSVersion: tlsx.VersionString(state.Version),
|
||||
Time: stop,
|
||||
})
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
// SaverConnDialer wraps the returned connection such that we
|
||||
// collect all the read/write events that occur.
|
||||
type SaverConnDialer struct {
|
||||
Dialer
|
||||
Saver *trace.Saver
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d SaverConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return saverConn{saver: d.Saver, Conn: conn}, nil
|
||||
}
|
||||
|
||||
type saverConn struct {
|
||||
net.Conn
|
||||
saver *trace.Saver
|
||||
}
|
||||
|
||||
func (c saverConn) Read(p []byte) (int, error) {
|
||||
start := time.Now()
|
||||
count, err := c.Conn.Read(p)
|
||||
stop := time.Now()
|
||||
c.saver.Write(trace.Event{
|
||||
Data: p[:count],
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
NumBytes: count,
|
||||
Name: errorx.ReadOperation,
|
||||
Time: stop,
|
||||
})
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c saverConn) Write(p []byte) (int, error) {
|
||||
start := time.Now()
|
||||
count, err := c.Conn.Write(p)
|
||||
stop := time.Now()
|
||||
c.saver.Write(trace.Event{
|
||||
Data: p[:count],
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
NumBytes: count,
|
||||
Name: errorx.WriteOperation,
|
||||
Time: stop,
|
||||
})
|
||||
return count, err
|
||||
}
|
||||
|
||||
var _ Dialer = SaverDialer{}
|
||||
var _ TLSHandshaker = SaverTLSHandshaker{}
|
||||
var _ net.Conn = saverConn{}
|
||||
@@ -0,0 +1,371 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
)
|
||||
|
||||
func TestSaverDialerFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
saver := &trace.Saver{}
|
||||
dlr := dialer.SaverDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected another error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 1 {
|
||||
t.Fatal("expected a single event here")
|
||||
}
|
||||
if ev[0].Address != "www.google.com:443" {
|
||||
t.Fatal("unexpected Address")
|
||||
}
|
||||
if ev[0].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if !errors.Is(ev[0].Err, expected) {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[0].Name != errorx.ConnectOperation {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].Proto != "tcp" {
|
||||
t.Fatal("unexpected Proto")
|
||||
}
|
||||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverConnDialerFailure(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
saver := &trace.Saver{}
|
||||
dlr := dialer.SaverConnDialer{
|
||||
Dialer: dialer.FakeDialer{
|
||||
Err: expected,
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
|
||||
// This is the most common use case for collecting reads, writes
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
nextprotos := []string{"h2"}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Config: &tls.Config{NextProtos: nextprotos},
|
||||
Dialer: dialer.SaverConnDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
Saver: saver,
|
||||
},
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
// Implementation note: we don't close the connection here because it is
|
||||
// very handy to have the last event being the end of the handshake
|
||||
_, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) < 4 {
|
||||
// it's a bit tricky to be sure about the right number of
|
||||
// events because network conditions may influence that
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if ev[0].Name != "tls_handshake_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
last := len(ev) - 1
|
||||
for idx := 1; idx < last; idx++ {
|
||||
if ev[idx].Data == nil {
|
||||
t.Fatal("unexpected Data")
|
||||
}
|
||||
if ev[idx].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[idx].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[idx].NumBytes <= 0 {
|
||||
t.Fatal("unexpected NumBytes")
|
||||
}
|
||||
switch ev[idx].Name {
|
||||
case errorx.ReadOperation, errorx.WriteOperation:
|
||||
default:
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[idx].Time.Before(ev[idx-1].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
if ev[last].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[last].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[last].Name != "tls_handshake_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[last].TLSCipherSuite == "" {
|
||||
t.Fatal("unexpected TLSCipherSuite")
|
||||
}
|
||||
if ev[last].TLSNegotiatedProto != "h2" {
|
||||
t.Fatal("unexpected TLSNegotiatedProto")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[last].TLSPeerCerts == nil {
|
||||
t.Fatal("unexpected TLSPeerCerts")
|
||||
}
|
||||
if ev[last].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if ev[last].TLSVersion == "" {
|
||||
t.Fatal("unexpected TLSVersion")
|
||||
}
|
||||
if ev[last].Time.Before(ev[last-1].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerSuccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
nextprotos := []string{"h2"}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Config: &tls.Config{NextProtos: nextprotos},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Close()
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("unexpected number of events")
|
||||
}
|
||||
if ev[0].Name != "tls_handshake_start" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[0].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[0].Time.After(time.Now()) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
if ev[1].Duration <= 0 {
|
||||
t.Fatal("unexpected Duration")
|
||||
}
|
||||
if ev[1].Err != nil {
|
||||
t.Fatal("unexpected Err")
|
||||
}
|
||||
if ev[1].Name != "tls_handshake_done" {
|
||||
t.Fatal("unexpected Name")
|
||||
}
|
||||
if ev[1].TLSCipherSuite == "" {
|
||||
t.Fatal("unexpected TLSCipherSuite")
|
||||
}
|
||||
if ev[1].TLSNegotiatedProto != "h2" {
|
||||
t.Fatal("unexpected TLSNegotiatedProto")
|
||||
}
|
||||
if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) {
|
||||
t.Fatal("unexpected TLSNextProtos")
|
||||
}
|
||||
if ev[1].TLSPeerCerts == nil {
|
||||
t.Fatal("unexpected TLSPeerCerts")
|
||||
}
|
||||
if ev[1].TLSServerName != "www.google.com" {
|
||||
t.Fatal("unexpected TLSServerName")
|
||||
}
|
||||
if ev[1].TLSVersion == "" {
|
||||
t.Fatal("unexpected TLSVersion")
|
||||
}
|
||||
if ev[1].Time.Before(ev[0].Time) {
|
||||
t.Fatal("unexpected Time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerHostnameError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "wrong.host.badssl.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "expired.badssl.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "self-signed.badssl.com:443")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify == true {
|
||||
t.Fatal("expected NoTLSVerify to be false")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skip test in short mode")
|
||||
}
|
||||
saver := &trace.Saver{}
|
||||
tlsdlr := dialer.TLSDialer{
|
||||
Config: &tls.Config{InsecureSkipVerify: true},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SaverTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
Saver: saver,
|
||||
},
|
||||
}
|
||||
conn, err := tlsdlr.DialTLSContext(
|
||||
context.Background(), "tcp", "self-signed.badssl.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("expected non-nil conn here")
|
||||
}
|
||||
conn.Close()
|
||||
for _, ev := range saver.Read() {
|
||||
if ev.Name != "tls_handshake_done" {
|
||||
continue
|
||||
}
|
||||
if ev.NoTLSVerify != true {
|
||||
t.Fatal("expected NoTLSVerify to be true")
|
||||
}
|
||||
if len(ev.TLSPeerCerts) < 1 {
|
||||
t.Fatal("expected at least a certificate here")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// +build !shaping
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ShapingDialer ensures we don't use too much bandwidth
|
||||
// when using integration tests at GitHub. To select
|
||||
// the implementation with shaping use `-tags shaping`.
|
||||
type ShapingDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ShapingDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// +build shaping
|
||||
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShapingDialer ensures we don't use too much bandwidth
|
||||
// when using integration tests at GitHub. To select
|
||||
// the implementation with shaping use `-tags shaping`.
|
||||
type ShapingDialer struct {
|
||||
Dialer
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d ShapingDialer) DialContext(
|
||||
ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &shapingConn{Conn: conn}, nil
|
||||
}
|
||||
|
||||
type shapingConn struct {
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c shapingConn) Read(p []byte) (int, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
func (c shapingConn) Write(p []byte) (int, error) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
func TestGood(t *testing.T) {
|
||||
txp := netx.NewHTTPTransport(netx.Config{
|
||||
Dialer: dialer.ShapingDialer{
|
||||
Dialer: new(net.Dialer),
|
||||
},
|
||||
})
|
||||
client := &http.Client{Transport: txp}
|
||||
resp, err := client.Get("https://www.google.com/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimeoutDialer is a Dialer that enforces a timeout
|
||||
type TimeoutDialer struct {
|
||||
Dialer
|
||||
ConnectTimeout time.Duration // default: 30 seconds
|
||||
}
|
||||
|
||||
// DialContext implements Dialer.DialContext
|
||||
func (d TimeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
timeout := 30 * time.Second
|
||||
if d.ConnectTimeout != 0 {
|
||||
timeout = d.ConnectTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
)
|
||||
|
||||
type SlowDialer struct{}
|
||||
|
||||
func (SlowDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(30 * time.Second):
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutDialer(t *testing.T) {
|
||||
d := dialer.TimeoutDialer{Dialer: SlowDialer{}, ConnectTimeout: time.Second}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn here")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/connid"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
// TLSHandshaker is the generic TLS handshaker
|
||||
type TLSHandshaker interface {
|
||||
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
||||
net.Conn, tls.ConnectionState, error)
|
||||
}
|
||||
|
||||
// SystemTLSHandshaker is the system TLS handshaker.
|
||||
type SystemTLSHandshaker struct{}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h SystemTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
tlsconn := tls.Client(conn, config)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
}
|
||||
|
||||
// TimeoutTLSHandshaker is a TLSHandshaker with timeout
|
||||
type TimeoutTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
HandshakeTimeout time.Duration // default: 10 second
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h TimeoutTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
timeout := 10 * time.Second
|
||||
if h.HandshakeTimeout != 0 {
|
||||
timeout = h.HandshakeTimeout
|
||||
}
|
||||
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
conn.SetDeadline(time.Time{})
|
||||
return tlsconn, connstate, err
|
||||
}
|
||||
|
||||
// ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error
|
||||
type ErrorWrapperTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h ErrorWrapperTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
err = errorx.SafeErrWrapperBuilder{
|
||||
ConnID: connID,
|
||||
Error: err,
|
||||
Operation: errorx.TLSHandshakeOperation,
|
||||
}.MaybeBuild()
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
// EmitterTLSHandshaker emits events using the MeasurementRoot
|
||||
type EmitterTLSHandshaker struct {
|
||||
TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h EmitterTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
connID := connid.Compute(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
root := modelx.ContextMeasurementRootOrDefault(ctx)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{
|
||||
ConnID: connID,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
SNI: config.ServerName,
|
||||
},
|
||||
})
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
root.Handler.OnMeasurement(modelx.Measurement{
|
||||
TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{
|
||||
ConnID: connID,
|
||||
ConnectionState: modelx.NewTLSConnectionState(state),
|
||||
Error: err,
|
||||
DurationSinceBeginning: time.Now().Sub(root.Beginning),
|
||||
},
|
||||
})
|
||||
return tlsconn, state, err
|
||||
}
|
||||
|
||||
// TLSDialer is the TLS dialer
|
||||
type TLSDialer struct {
|
||||
Config *tls.Config
|
||||
Dialer Dialer
|
||||
TLSHandshaker TLSHandshaker
|
||||
}
|
||||
|
||||
// DialTLSContext is like tls.DialTLS but with the signature of net.Dialer.DialContext
|
||||
func (d TLSDialer) DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// Implementation note: when DialTLS is not set, the code in
|
||||
// net/http will perform the handshake. Otherwise, if DialTLS
|
||||
// is set, we will end up here. This code is still used when
|
||||
// performing non-HTTP TLS-enabled dial operations.
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config := d.Config
|
||||
if config == nil {
|
||||
config = new(tls.Config)
|
||||
} else {
|
||||
config = config.Clone()
|
||||
}
|
||||
if config.ServerName == "" {
|
||||
config.ServerName = host
|
||||
}
|
||||
tlsconn, _, err := d.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return tlsconn, nil
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package dialer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
|
||||
)
|
||||
|
||||
func TestSystemTLSHandshakerEOFError(t *testing.T) {
|
||||
h := dialer.SystemTLSHandshaker{}
|
||||
conn, _, err := h.Handshake(context.Background(), dialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "x.org",
|
||||
})
|
||||
if err != io.EOF {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) {
|
||||
h := dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 200 * time.Millisecond,
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), &dialer.FakeConn{SetDeadlineError: expected},
|
||||
new(tls.Config))
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutTLSHandshakerEOFError(t *testing.T) {
|
||||
h := dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 200 * time.Millisecond,
|
||||
}
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), dialer.EOFConn{}, &tls.Config{ServerName: "x.org"})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) {
|
||||
h := dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 200 * time.Millisecond,
|
||||
}
|
||||
underlying := &SetDeadlineConn{}
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), underlying, &tls.Config{ServerName: "x.org"})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
if len(underlying.deadlines) != 2 {
|
||||
t.Fatal("SetDeadline not called twice")
|
||||
}
|
||||
if underlying.deadlines[0].Before(time.Now()) {
|
||||
t.Fatal("the first SetDeadline call was incorrect")
|
||||
}
|
||||
if !underlying.deadlines[1].IsZero() {
|
||||
t.Fatal("the second SetDeadline call was incorrect")
|
||||
}
|
||||
}
|
||||
|
||||
type SetDeadlineConn struct {
|
||||
dialer.EOFConn
|
||||
deadlines []time.Time
|
||||
}
|
||||
|
||||
func (c *SetDeadlineConn) SetDeadline(t time.Time) error {
|
||||
c.deadlines = append(c.deadlines, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestErrorWrapperTLSHandshakerFailure(t *testing.T) {
|
||||
h := dialer.ErrorWrapperTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
|
||||
conn, _, err := h.Handshake(
|
||||
context.Background(), dialer.EOFConn{}, new(tls.Config))
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
var errWrapper *errorx.ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("cannot cast to ErrWrapper")
|
||||
}
|
||||
if errWrapper.ConnID == 0 {
|
||||
t.Fatal("unexpected ConnID")
|
||||
}
|
||||
if errWrapper.Failure != errorx.FailureEOFError {
|
||||
t.Fatal("unexpected Failure")
|
||||
}
|
||||
if errWrapper.Operation != errorx.TLSHandshakeOperation {
|
||||
t.Fatal("unexpected Operation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmitterTLSHandshakerFailure(t *testing.T) {
|
||||
saver := &handlers.SavingHandler{}
|
||||
ctx := modelx.WithMeasurementRoot(context.Background(), &modelx.MeasurementRoot{
|
||||
Beginning: time.Now(),
|
||||
Handler: saver,
|
||||
})
|
||||
h := dialer.EmitterTLSHandshaker{TLSHandshaker: dialer.EOFTLSHandshaker{}}
|
||||
conn, _, err := h.Handshake(ctx, dialer.EOFConn{}, &tls.Config{
|
||||
ServerName: "www.kernel.org",
|
||||
})
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
events := saver.Read()
|
||||
if len(events) != 2 {
|
||||
t.Fatal("Wrong number of events")
|
||||
}
|
||||
if events[0].TLSHandshakeStart == nil {
|
||||
t.Fatal("missing TLSHandshakeStart event")
|
||||
}
|
||||
if events[0].TLSHandshakeStart.ConnID == 0 {
|
||||
t.Fatal("expected nonzero ConnID")
|
||||
}
|
||||
if events[0].TLSHandshakeStart.DurationSinceBeginning == 0 {
|
||||
t.Fatal("expected nonzero DurationSinceBeginning")
|
||||
}
|
||||
if events[0].TLSHandshakeStart.SNI != "www.kernel.org" {
|
||||
t.Fatal("expected nonzero SNI")
|
||||
}
|
||||
if events[1].TLSHandshakeDone == nil {
|
||||
t.Fatal("missing TLSHandshakeDone event")
|
||||
}
|
||||
if events[1].TLSHandshakeDone.ConnID == 0 {
|
||||
t.Fatal("expected nonzero ConnID")
|
||||
}
|
||||
if events[1].TLSHandshakeDone.DurationSinceBeginning == 0 {
|
||||
t.Fatal("expected nonzero DurationSinceBeginning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureSplitHostPort(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com") // missing port
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureDialing(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{Dialer: dialer.EOFDialer{}}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureHandshaking(t *testing.T) {
|
||||
rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}}
|
||||
dialer := dialer.TLSDialer{
|
||||
Dialer: dialer.EOFConnDialer{},
|
||||
TLSHandshaker: rec,
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
if rec.SNI != "www.google.com" {
|
||||
t.Fatal("unexpected SNI value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) {
|
||||
rec := &RecorderTLSHandshaker{TLSHandshaker: dialer.SystemTLSHandshaker{}}
|
||||
dialer := dialer.TLSDialer{
|
||||
Config: &tls.Config{
|
||||
ServerName: "x.org",
|
||||
},
|
||||
Dialer: dialer.EOFConnDialer{},
|
||||
TLSHandshaker: rec,
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(
|
||||
context.Background(), "tcp", "www.google.com:443")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
if rec.SNI != "x.org" {
|
||||
t.Fatal("unexpected SNI value")
|
||||
}
|
||||
}
|
||||
|
||||
type RecorderTLSHandshaker struct {
|
||||
dialer.TLSHandshaker
|
||||
SNI string
|
||||
}
|
||||
|
||||
func (h *RecorderTLSHandshaker) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
h.SNI = config.ServerName
|
||||
return h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
}
|
||||
|
||||
func TestDialTLSContextGood(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{
|
||||
Config: &tls.Config{ServerName: "google.com"},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatal("connection is nil")
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestDialTLSContextTimeout(t *testing.T) {
|
||||
dialer := dialer.TLSDialer{
|
||||
Config: &tls.Config{ServerName: "google.com"},
|
||||
Dialer: new(net.Dialer),
|
||||
TLSHandshaker: dialer.ErrorWrapperTLSHandshaker{
|
||||
TLSHandshaker: dialer.TimeoutTLSHandshaker{
|
||||
TLSHandshaker: dialer.SystemTLSHandshaker{},
|
||||
HandshakeTimeout: 10 * time.Microsecond,
|
||||
},
|
||||
},
|
||||
}
|
||||
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
|
||||
if err.Error() != errorx.FailureGenericTimeoutError {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("connection is not nil")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user