feat: tlsping and tcpping using step-by-step (#815)

## Checklist

- [x] I have read the [contribution guidelines](https://github.com/ooni/probe-cli/blob/master/CONTRIBUTING.md)
- [x] reference issue for this pull request: https://github.com/ooni/probe/issues/2158
- [x] if you changed anything related how experiments work and you need to reflect these changes in the ooni/spec repository, please link to the related ooni/spec pull request: https://github.com/ooni/spec/pull/250

## Description

This diff refactors the codebase to reimplement tlsping and tcpping
to use the step-by-step measurements style.

See docs/design/dd-003-step-by-step.md for more information on the
step-by-step measurement style.
This commit is contained in:
Simone Basso
2022-07-01 12:22:22 +02:00
committed by GitHub
parent 5371c7f486
commit 5ebdeb56ca
48 changed files with 2825 additions and 299 deletions
+1 -1
View File
@@ -47,7 +47,7 @@ func (r *bogonResolver) LookupHost(ctx context.Context, hostname string) ([]stri
for _, addr := range addrs {
if IsBogon(addr) {
// wrap ErrDNSBogon as documented
return nil, newErrWrapper(classifyResolverError, ResolveOperation, ErrDNSBogon)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, ErrDNSBogon)
}
}
return addrs, nil
+12 -12
View File
@@ -15,8 +15,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/scrubber"
)
// classifyGenericError is maps an error occurred during an operation
// to an OONI failure string. This specific classifier is the most
// ClassifyGenericError maps an error occurred during an operation to
// an OONI failure string. This specific classifier is the most
// generic one. You usually use it when mapping I/O errors. You should
// check whether there is a specific classifier for more specific
// operations (e.g., DNS resolution, TLS handshake).
@@ -38,7 +38,7 @@ import (
// If everything else fails, this classifier returns a string
// like "unknown_failure: XXX" where XXX has been scrubbed
// so to remove any network endpoints from the original error string.
func classifyGenericError(err error) string {
func ClassifyGenericError(err error) string {
// The list returned here matches the values used by MK unless
// explicitly noted otherwise with a comment.
@@ -139,7 +139,7 @@ const (
quicTLSUnrecognizedName = 112
)
// classifyQUICHandshakeError maps errors during a QUIC
// ClassifyQUICHandshakeError maps errors during a QUIC
// handshake to OONI failure strings.
//
// If the input error is an *ErrWrapper we don't perform
@@ -147,7 +147,7 @@ const (
//
// If this classifier fails, it calls ClassifyGenericError
// and returns to the caller its return value.
func classifyQUICHandshakeError(err error) string {
func ClassifyQUICHandshakeError(err error) string {
// Robustness: handle the case where we're passed a wrapped error.
var errwrapper *ErrWrapper
@@ -207,7 +207,7 @@ func classifyQUICHandshakeError(err error) string {
}
}
}
return classifyGenericError(err)
return ClassifyGenericError(err)
}
// quicIsCertificateError tells us whether a specific TLS alert error
@@ -277,7 +277,7 @@ var (
// anything as explained in getaddrinfo_linux.go.
var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
// classifyResolverError maps DNS resolution errors to
// ClassifyResolverError maps DNS resolution errors to
// OONI failure strings.
//
// If the input error is an *ErrWrapper we don't perform
@@ -285,7 +285,7 @@ var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
//
// If this classifier fails, it calls ClassifyGenericError and
// returns to the caller its return value.
func classifyResolverError(err error) string {
func ClassifyResolverError(err error) string {
// Robustness: handle the case where we're passed a wrapped error.
var errwrapper *ErrWrapper
@@ -310,10 +310,10 @@ func classifyResolverError(err error) string {
if errors.Is(err, ErrAndroidDNSCacheNoData) {
return FailureAndroidDNSCacheNoData
}
return classifyGenericError(err)
return ClassifyGenericError(err)
}
// classifyTLSHandshakeError maps an error occurred during the TLS
// ClassifyTLSHandshakeError maps an error occurred during the TLS
// handshake to an OONI failure string.
//
// If the input error is an *ErrWrapper we don't perform
@@ -321,7 +321,7 @@ func classifyResolverError(err error) string {
//
// If this classifier fails, it calls ClassifyGenericError and
// returns to the caller its return value.
func classifyTLSHandshakeError(err error) string {
func ClassifyTLSHandshakeError(err error) string {
// Robustness: handle the case where we're passed a wrapped error.
var errwrapper *ErrWrapper
@@ -345,5 +345,5 @@ func classifyTLSHandshakeError(err error) string {
// Test case: https://expired.badssl.com/
return FailureSSLInvalidCertificate
}
return classifyGenericError(err)
return ClassifyGenericError(err)
}
+44 -44
View File
@@ -18,13 +18,13 @@ func TestClassifyGenericError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyGenericError(err) != FailureEOFError {
if ClassifyGenericError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for a system call error", func(t *testing.T) {
if classifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
if ClassifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
t.Fatal("unexpected results")
}
})
@@ -35,63 +35,63 @@ func TestClassifyGenericError(t *testing.T) {
// is just an implementation detail.
t.Run("for operation was canceled", func(t *testing.T) {
if classifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
if ClassifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
t.Fatal("unexpected result")
}
})
t.Run("for EOF", func(t *testing.T) {
if classifyGenericError(io.EOF) != FailureEOFError {
if ClassifyGenericError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
t.Run("for context deadline exceeded", func(t *testing.T) {
if classifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
if ClassifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for stun's transaction is timed out", func(t *testing.T) {
if classifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
if ClassifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for i/o timeout", func(t *testing.T) {
if classifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
if ClassifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for TLS handshake timeout", func(t *testing.T) {
err := errors.New("net/http: TLS handshake timeout")
if classifyGenericError(err) != FailureGenericTimeoutError {
if ClassifyGenericError(err) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for no such host", func(t *testing.T) {
if classifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
if ClassifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
t.Fatal("unexpected results")
}
})
t.Run("for dns server misbehaving", func(t *testing.T) {
if classifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
if ClassifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
t.Fatal("unexpected results")
}
})
t.Run("for no answer from DNS server", func(t *testing.T) {
if classifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
if ClassifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
t.Fatal("unexpected results")
}
})
t.Run("for use of closed network connection", func(t *testing.T) {
err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
if classifyGenericError(err) != FailureConnectionAlreadyClosed {
if ClassifyGenericError(err) != FailureConnectionAlreadyClosed {
t.Fatal("unexpected results")
}
})
@@ -99,7 +99,7 @@ func TestClassifyGenericError(t *testing.T) {
// Now we're back in ClassifyGenericError
t.Run("for context.Canceled", func(t *testing.T) {
if classifyGenericError(context.Canceled) != FailureInterrupted {
if ClassifyGenericError(context.Canceled) != FailureInterrupted {
t.Fatal("unexpected result")
}
})
@@ -108,7 +108,7 @@ func TestClassifyGenericError(t *testing.T) {
t.Run("with an IPv4 address", func(t *testing.T) {
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error")
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
out := classifyGenericError(input)
out := ClassifyGenericError(input)
if out != expected {
t.Fatal(cmp.Diff(expected, out))
}
@@ -117,7 +117,7 @@ func TestClassifyGenericError(t *testing.T) {
t.Run("with an IPv6 address", func(t *testing.T) {
input := errors.New("read tcp [::1]:56948->[::1]:443: some error")
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
out := classifyGenericError(input)
out := ClassifyGenericError(input)
if out != expected {
t.Fatal(cmp.Diff(expected, out))
}
@@ -131,100 +131,100 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyQUICHandshakeError(err) != FailureEOFError {
if ClassifyQUICHandshakeError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for incompatible quic version", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
if ClassifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
t.Fatal("unexpected results")
}
})
t.Run("for stateless reset", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
if ClassifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
t.Fatal("unexpected results")
}
})
t.Run("for handshake timeout", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
if ClassifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for idle timeout", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
if ClassifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
t.Fatal("unexpected results")
}
})
t.Run("for connection refused", func(t *testing.T) {
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
t.Fatal("unexpected results")
}
})
t.Run("for bad certificate", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertBadCertificate
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for unsupported certificate", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for certificate expired", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertCertificateExpired
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for certificate revoked", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for certificate unknown", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
t.Fatal("unexpected results")
}
})
t.Run("for decrypt error", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertDecryptError
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
t.Fatal("unexpected results")
}
})
t.Run("for handshake failure", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
t.Fatal("unexpected results")
}
})
t.Run("for unknown CA", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSAlertUnknownCA
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
t.Fatal("unexpected results")
}
})
t.Run("for unrecognized hostname", func(t *testing.T) {
var err quic.TransportErrorCode = quicTLSUnrecognizedName
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
t.Fatal("unexpected results")
}
})
@@ -234,13 +234,13 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
ErrorCode: quic.InternalError,
ErrorMessage: FailureHostUnreachable,
}
if classifyQUICHandshakeError(err) != FailureHostUnreachable {
if ClassifyQUICHandshakeError(err) != FailureHostUnreachable {
t.Fatal("unexpected results")
}
})
t.Run("for another kind of error", func(t *testing.T) {
if classifyQUICHandshakeError(io.EOF) != FailureEOFError {
if ClassifyQUICHandshakeError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
@@ -252,43 +252,43 @@ func TestClassifyResolverError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyResolverError(err) != FailureEOFError {
if ClassifyResolverError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for ErrDNSBogon", func(t *testing.T) {
if classifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
if ClassifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
t.Fatal("unexpected result")
}
})
t.Run("for refused", func(t *testing.T) {
if classifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
if ClassifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
t.Fatal("unexpected result")
}
})
t.Run("for servfail", func(t *testing.T) {
if classifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
if ClassifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
t.Fatal("unexpected result")
}
})
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
if ClassifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
t.Fatal("unexpected result")
}
})
t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) {
if classifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
if ClassifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
t.Fatal("unexpected result")
}
})
t.Run("for another kind of error", func(t *testing.T) {
if classifyResolverError(io.EOF) != FailureEOFError {
if ClassifyResolverError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
@@ -300,34 +300,34 @@ func TestClassifyTLSHandshakeError(t *testing.T) {
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
err := &ErrWrapper{Failure: FailureEOFError}
if classifyTLSHandshakeError(err) != FailureEOFError {
if ClassifyTLSHandshakeError(err) != FailureEOFError {
t.Fatal("did not classify existing ErrWrapper correctly")
}
})
t.Run("for x509.HostnameError", func(t *testing.T) {
var err x509.HostnameError
if classifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
t.Fatal("unexpected result")
}
})
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
var err x509.UnknownAuthorityError
if classifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
if ClassifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
t.Fatal("unexpected result")
}
})
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
var err x509.CertificateInvalidError
if classifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
t.Fatal("unexpected result")
}
})
t.Run("for another kind of error", func(t *testing.T) {
if classifyTLSHandshakeError(io.EOF) != FailureEOFError {
if ClassifyTLSHandshakeError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result")
}
})
+37 -11
View File
@@ -125,7 +125,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver,
outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors
}
return &dialerLogger{
Dialer: &dialerResolver{
Dialer: &dialerResolverWithTracing{
Dialer: &dialerLogger{
Dialer: outDialer,
DebugLogger: logger,
@@ -171,15 +171,24 @@ func (d *DialerSystem) CloseIdleConnections() {
// nothing to do here
}
// dialerResolver combines dialing with domain name resolution.
type dialerResolver struct {
// dialerResolverWithTracing combines dialing with domain name resolution and
// implements hooks to trace TCP (or UDP) connect operations.
type dialerResolverWithTracing struct {
Dialer model.Dialer
Resolver model.Resolver
}
var _ model.Dialer = &dialerResolver{}
var _ model.Dialer = &dialerResolverWithTracing{}
func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// DialContext implements model.Dialer.DialContext. Specifically this
// method performs the following operations:
//
// 1. resolve the domain inside the address using a resolver;
//
// 2. cycle through the available IP addresses and try to dial each of them;
//
// 3. trace the TCP (or UDP) connect and allow wrapping the returned conn.
func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
// QUIRK: this routine and the related routines in quirks.go cannot
// be changed easily until we use events tracing to measure.
//
@@ -194,10 +203,27 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
}
addrs = quirkSortIPAddrs(addrs)
var errorslist []error
trace := ContextTraceOrDefault(ctx)
for _, addr := range addrs {
target := net.JoinHostPort(addr, onlyport)
started := trace.TimeNow()
conn, err := d.Dialer.DialContext(ctx, network, target)
finished := trace.TimeNow()
// TODO(bassosimone): to make the code robust to future refactoring we have
// moved error wrapping inside this type. This change opens up the possibility
// of simplifying the dialing chain by removing dialerErrWrapper. We'll be
// able to implement this refactoring once netx is gone. We cannot complete
// this refactoring _before_ because WrapDialer inserts extra wrappers
// provided by netx in the dialers chain _before_ this dialer and the dialers
// that netx insert assume that they wrap a dialer with error wrapping.
//
// Because error wrapping should be idempotent, it should not be a problem
// to have two error wrapping dialers in the chain except that, of course, it
// would be less efficient than just having a single wrapper.
err = MaybeNewErrWrapper(ClassifyGenericError, ConnectOperation, err)
trace.OnConnectDone(started, network, onlyhost, target, err, finished)
if err == nil {
conn = &dialerErrWrapperConn{conn}
return conn, nil
}
errorslist = append(errorslist, err)
@@ -206,14 +232,14 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
}
// lookupHost ensures we correctly handle IP addresses.
func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
func (d *dialerResolverWithTracing) lookupHost(ctx context.Context, hostname string) ([]string, error) {
if net.ParseIP(hostname) != nil {
return []string{hostname}, nil
}
return d.Resolver.LookupHost(ctx, hostname)
}
func (d *dialerResolver) CloseIdleConnections() {
func (d *dialerResolverWithTracing) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
d.Resolver.CloseIdleConnections()
}
@@ -303,7 +329,7 @@ var _ model.Dialer = &dialerErrWrapper{}
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, newErrWrapper(classifyGenericError, ConnectOperation, err)
return nil, NewErrWrapper(ClassifyGenericError, ConnectOperation, err)
}
return &dialerErrWrapperConn{Conn: conn}, nil
}
@@ -322,7 +348,7 @@ var _ net.Conn = &dialerErrWrapperConn{}
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b)
if err != nil {
return 0, newErrWrapper(classifyGenericError, ReadOperation, err)
return 0, NewErrWrapper(ClassifyGenericError, ReadOperation, err)
}
return count, nil
}
@@ -330,7 +356,7 @@ func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
count, err := c.Conn.Write(b)
if err != nil {
return 0, newErrWrapper(classifyGenericError, WriteOperation, err)
return 0, NewErrWrapper(ClassifyGenericError, WriteOperation, err)
}
return count, nil
}
@@ -338,7 +364,7 @@ func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
func (c *dialerErrWrapperConn) Close() error {
err := c.Conn.Close()
if err != nil {
return newErrWrapper(classifyGenericError, CloseOperation, err)
return NewErrWrapper(ClassifyGenericError, CloseOperation, err)
}
return nil
}
+130 -22
View File
@@ -13,6 +13,7 @@ import (
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestNewDialerWithStdlibResolver(t *testing.T) {
@@ -22,7 +23,7 @@ func TestNewDialerWithStdlibResolver(t *testing.T) {
t.Fatal("invalid logger")
}
// typecheck the resolver
reso := logger.Dialer.(*dialerResolver)
reso := logger.Dialer.(*dialerResolverWithTracing)
typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger)
// typecheck the dialer
logger = reso.Dialer.(*dialerLogger)
@@ -64,7 +65,7 @@ func TestNewDialer(t *testing.T) {
if logger.DebugLogger != log.Log {
t.Fatal("invalid logger")
}
reso := logger.Dialer.(*dialerResolver)
reso := logger.Dialer.(*dialerResolverWithTracing)
if _, okay := reso.Resolver.(*NullResolver); !okay {
t.Fatal("invalid Resolver type")
}
@@ -136,10 +137,10 @@ func TestDialerSystem(t *testing.T) {
})
}
func TestDialerResolver(t *testing.T) {
func TestDialerResolverWithTracing(t *testing.T) {
t.Run("DialContext", func(t *testing.T) {
t.Run("fails without a port", func(t *testing.T) {
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &DialerSystem{},
Resolver: NewUnwrappedStdlibResolver(),
}
@@ -154,7 +155,7 @@ func TestDialerResolver(t *testing.T) {
})
t.Run("handles dialing error correctly for single IP address", func(t *testing.T) {
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
@@ -166,13 +167,26 @@ func TestDialerResolver(t *testing.T) {
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("the error has not been wrapped")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("invalid wrapped error's failure")
}
if errWrapper.Operation != ConnectOperation {
t.Fatal("invalid wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) {
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
@@ -188,6 +202,19 @@ func TestDialerResolver(t *testing.T) {
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("the error has not been wrapped")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("invalid wrapped error's failure")
}
if errWrapper.Operation != ConnectOperation {
t.Fatal("invalid wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
@@ -199,7 +226,7 @@ func TestDialerResolver(t *testing.T) {
return nil
},
}
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return expectedConn, nil
@@ -215,7 +242,10 @@ func TestDialerResolver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if conn != expectedConn {
// Ensure that the dialer returns a connection that is already wrapping errors,
// which is a new behavior since https://github.com/ooni/probe-cli/pull/815
errWrapperConn := conn.(*dialerErrWrapperConn)
if errWrapperConn.Conn != expectedConn {
t.Fatal("unexpected conn")
}
conn.Close()
@@ -225,7 +255,7 @@ func TestDialerResolver(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@@ -257,7 +287,7 @@ func TestDialerResolver(t *testing.T) {
// TODO(https://github.com/ooni/probe/issues/1779)
mu := &sync.Mutex{}
var attempts []string
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@@ -298,14 +328,14 @@ func TestDialerResolver(t *testing.T) {
mu := &sync.Mutex{}
errorsList := []error{
errors.New("a mocked error"),
newErrWrapper(
classifyGenericError,
NewErrWrapper(
ClassifyGenericError,
CloseOperation,
io.EOF,
),
}
var errorIdx int
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@@ -337,17 +367,18 @@ func TestDialerResolver(t *testing.T) {
t.Run("though ignores the unknown failures", func(t *testing.T) {
// This test is fundamental to the following
// TODO(https://github.com/ooni/probe/issues/1779)
expectedErr := errors.New("a mocked error")
mu := &sync.Mutex{}
errorsList := []error{
errors.New("a mocked error"),
newErrWrapper(
classifyGenericError,
expectedErr,
NewErrWrapper(
ClassifyGenericError,
CloseOperation,
errors.New("antani"),
errors.New("antani"), // this is an unknown failure and we should not return it
),
}
var errorIdx int
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
// It should not happen to have parallel dials with
@@ -368,18 +399,95 @@ func TestDialerResolver(t *testing.T) {
},
}
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
if err == nil || err.Error() != "a mocked error" {
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("error has not been wrapped")
}
if errWrapper.Failure != "unknown_failure: a mocked error" {
t.Fatal("unexpected wrapped error's failure")
}
if errWrapper.Operation != ConnectOperation {
t.Fatal("unexpected wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, expectedErr) {
t.Fatal("unexpected wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("uses a context-injected custom trace", func(t *testing.T) {
var (
called bool
domainOK bool
networkOK bool
remoteAddrOK bool
startTimeOK bool
finishTimeOK bool
wrappedErr bool
)
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
var ew *ErrWrapper
called = true
domainOK = (domain == "1.1.1.1")
networkOK = (network == "tcp")
remoteAddrOK = (remoteAddr == "1.1.1.1:853")
startTimeOK = (started.Sub(zeroTime) == 0)
finishTimeOK = (finished.Sub(zeroTime) == time.Second)
wrappedErr = errors.As(err, &ew) && ew.Failure == FailureEOFError
},
}
ctx := ContextWithTrace(context.Background(), tx)
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
return nil, io.EOF
},
},
Resolver: &NullResolver{},
}
conn, err := d.DialContext(ctx, "tcp", "1.1.1.1:853")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected")
}
if conn != nil {
t.Fatal("expected nil conn")
}
if !called {
t.Fatal("not called")
}
if !domainOK {
t.Fatal("domain was not okay")
}
if !networkOK {
t.Fatal("network was not okay")
}
if !remoteAddrOK {
t.Fatal("remoteAddr was not okay")
}
if !startTimeOK {
t.Fatal("start time was not okay")
}
if !finishTimeOK {
t.Fatal("finish time was not okay")
}
if !wrappedErr {
t.Fatal("not wrapped")
}
})
})
t.Run("lookupHost", func(t *testing.T) {
t.Run("handles addresses correctly", func(t *testing.T) {
dialer := &dialerResolver{
dialer := &dialerResolverWithTracing{
Dialer: &DialerSystem{},
Resolver: &NullResolver{},
}
@@ -393,7 +501,7 @@ func TestDialerResolver(t *testing.T) {
})
t.Run("fails correctly on lookup error", func(t *testing.T) {
dialer := &dialerResolver{
dialer := &dialerResolverWithTracing{
Dialer: &DialerSystem{},
Resolver: &NullResolver{},
}
@@ -413,7 +521,7 @@ func TestDialerResolver(t *testing.T) {
calledDialer bool
calledResolver bool
)
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockCloseIdleConnections: func() {
calledDialer = true
+1 -1
View File
@@ -38,7 +38,7 @@ func (t *dnsTransportErrWrapper) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
resp, err := t.DNSTransport.RoundTrip(ctx, query)
if err != nil {
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, DNSRoundTripOperation, err)
}
return resp, nil
}
+4
View File
@@ -34,6 +34,10 @@
//
// We want to have reasonable watchdog timeouts for each operation.
//
// We also want lightweight support for tracing network events. To this end, we
// use context.WithValue and context.Value to inject, and retrieve, a model.Trace
// implementation OPTIONALLY configured by the user.
//
// See also the design document at docs/design/dd-003-step-by-step.md,
// which provides an overview of netxlite's main concerns.
//
+16 -3
View File
@@ -71,7 +71,7 @@ func (e *ErrWrapper) MarshalJSON() ([]byte, error) {
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
type classifier func(err error) string
// newErrWrapper creates a new ErrWrapper using the given
// NewErrWrapper creates a new ErrWrapper using the given
// classifier, operation name, and underlying error.
//
// This function panics if classifier is nil, or operation
@@ -81,7 +81,7 @@ type classifier func(err error) string
// error wrapper will use the same classification string and
// will determine whether to keep the major operation as documented
// in the ErrWrapper.Operation documentation.
func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
func NewErrWrapper(c classifier, op string, err error) *ErrWrapper {
var wrapper *ErrWrapper
if errors.As(err, &wrapper) {
return &ErrWrapper{
@@ -106,6 +106,19 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
}
}
// TODO(https://github.com/ooni/probe/issues/2163): we can really
// simplify the error wrapping situation here by just dropping
// NewErrWrapper and always using MaybeNewErrWrapper.
// MaybeNewErrWrapper is like NewErrWrapper except that this
// function won't panic if passed a nil error.
func MaybeNewErrWrapper(c classifier, op string, err error) error {
if err != nil {
return NewErrWrapper(c, op, err)
}
return nil
}
// NewTopLevelGenericErrWrapper wraps an error occurring at top
// level using a generic classifier as classifier. This is the
// function you should call when you suspect a given error hasn't
@@ -115,7 +128,7 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
// error wrapper will use the same classification string and
// failed operation of the original error.
func NewTopLevelGenericErrWrapper(err error) *ErrWrapper {
return newErrWrapper(classifyGenericError, TopLevelOperation, err)
return NewErrWrapper(ClassifyGenericError, TopLevelOperation, err)
}
func classifyOperation(ew *ErrWrapper, operation string) string {
+34 -6
View File
@@ -53,7 +53,7 @@ func TestNewErrWrapper(t *testing.T) {
recovered.Add(1)
}
}()
newErrWrapper(nil, CloseOperation, io.EOF)
NewErrWrapper(nil, CloseOperation, io.EOF)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
@@ -68,7 +68,7 @@ func TestNewErrWrapper(t *testing.T) {
recovered.Add(1)
}
}()
newErrWrapper(classifyGenericError, "", io.EOF)
NewErrWrapper(ClassifyGenericError, "", io.EOF)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
@@ -83,7 +83,7 @@ func TestNewErrWrapper(t *testing.T) {
recovered.Add(1)
}
}()
newErrWrapper(classifyGenericError, CloseOperation, nil)
NewErrWrapper(ClassifyGenericError, CloseOperation, nil)
}()
if recovered.Load() != 1 {
t.Fatal("did not panic")
@@ -91,7 +91,7 @@ func TestNewErrWrapper(t *testing.T) {
})
t.Run("otherwise, works as intended", func(t *testing.T) {
ew := newErrWrapper(classifyGenericError, CloseOperation, io.EOF)
ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF)
if ew.Failure != FailureEOFError {
t.Fatal("unexpected failure")
}
@@ -104,10 +104,10 @@ func TestNewErrWrapper(t *testing.T) {
})
t.Run("when the underlying error is already a wrapped error", func(t *testing.T) {
ew := newErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
ew := NewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
var err1 error = ew
err2 := fmt.Errorf("cannot read: %w", err1)
ew2 := newErrWrapper(classifyGenericError, HTTPRoundTripOperation, err2)
ew2 := NewErrWrapper(ClassifyGenericError, HTTPRoundTripOperation, err2)
if ew2.Failure != ew.Failure {
t.Fatal("not the same failure")
}
@@ -117,6 +117,34 @@ func TestNewErrWrapper(t *testing.T) {
if ew2.WrappedErr != err2 {
t.Fatal("invalid underlying error")
}
// Make sure we can still use errors.Is with two layers of wrapping
if !errors.Is(ew2, ECONNRESET) {
t.Fatal("we cannot use errors.Is to retrieve the real syscall error")
}
})
}
func TestMaybeNewErrWrapper(t *testing.T) {
// TODO(https://github.com/ooni/probe/issues/2163): we can really
// simplify the error wrapping situation here by just dropping
// NewErrWrapper and always using MaybeNewErrWrapper.
t.Run("when we pass a nil error to this function", func(t *testing.T) {
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, nil)
if err != nil {
t.Fatal("unexpected output", err)
}
})
t.Run("when we pass a non-nil error to this function", func(t *testing.T) {
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
if !errors.Is(err, ECONNRESET) {
t.Fatal("unexpected output", err)
}
var ew *ErrWrapper
if !errors.As(err, &ew) {
t.Fatal("not an instance of ErrWrapper")
}
})
}
+4 -4
View File
@@ -31,7 +31,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
dialer := txpCc.Dialer
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger)
dialerReso := dialerLog.Dialer.(*dialerResolver)
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
if dialerReso.Resolver != resolver {
t.Fatal("invalid resolver")
}
@@ -52,7 +52,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer)
dialerLog := dialerProxy.Dialer.(*dialerLogger)
dialerReso := dialerLog.Dialer.(*dialerResolver)
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
if dialerReso.Resolver != resolver {
t.Fatal("invalid resolver")
}
@@ -269,7 +269,7 @@ func TestNewHTTPTransport(t *testing.T) {
t.Run("works as intended with failing dialer", func(t *testing.T) {
called := &atomicx.Int64{}
expected := errors.New("mocked error")
d := &dialerResolver{
d := &dialerResolverWithTracing{
Dialer: &mocks.Dialer{
MockDialContext: func(ctx context.Context,
network, address string) (net.Conn, error) {
@@ -612,7 +612,7 @@ func TestNewHTTPClientWithResolver(t *testing.T) {
txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser)
dialer := txpCc.Dialer.(*httpDialerWithReadTimeout)
dialerLogger := dialer.Dialer.(*dialerLogger)
dialerReso := dialerLogger.Dialer.(*dialerResolver)
dialerReso := dialerLogger.Dialer.(*dialerResolverWithTracing)
if dialerReso.Resolver != reso {
t.Fatal("invalid resolver")
}
+2 -2
View File
@@ -41,7 +41,7 @@ func TestReadAllContext(t *testing.T) {
//
// Note: Returning a wrapped error to ensure we address
// https://github.com/ooni/probe/issues/1965
return len(b), newErrWrapper(classifyGenericError,
return len(b), NewErrWrapper(ClassifyGenericError,
ReadOperation, io.EOF)
},
}
@@ -171,7 +171,7 @@ func TestCopyContext(t *testing.T) {
//
// Note: Returning a wrapped error to ensure we address
// https://github.com/ooni/probe/issues/1965
return len(b), newErrWrapper(classifyGenericError,
return len(b), NewErrWrapper(ClassifyGenericError,
ReadOperation, io.EOF)
},
}
+6 -6
View File
@@ -380,7 +380,7 @@ var _ model.QUICListener = &quicListenerErrWrapper{}
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
pconn, err := qls.QUICListener.Listen(addr)
if err != nil {
return nil, newErrWrapper(classifyGenericError, QUICListenOperation, err)
return nil, NewErrWrapper(ClassifyGenericError, QUICListenOperation, err)
}
return &quicErrWrapperUDPLikeConn{pconn}, nil
}
@@ -397,7 +397,7 @@ var _ model.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
count, err := c.UDPLikeConn.WriteTo(p, addr)
if err != nil {
return 0, newErrWrapper(classifyGenericError, WriteToOperation, err)
return 0, NewErrWrapper(ClassifyGenericError, WriteToOperation, err)
}
return count, nil
}
@@ -406,7 +406,7 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, addr, err := c.UDPLikeConn.ReadFrom(b)
if err != nil {
return 0, nil, newErrWrapper(classifyGenericError, ReadFromOperation, err)
return 0, nil, NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
}
return n, addr, nil
}
@@ -415,7 +415,7 @@ func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
func (c *quicErrWrapperUDPLikeConn) Close() error {
err := c.UDPLikeConn.Close()
if err != nil {
return newErrWrapper(classifyGenericError, ReadFromOperation, err)
return NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
}
return nil
}
@@ -433,8 +433,8 @@ func (d *quicDialerErrWrapper) DialContext(
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
if err != nil {
return nil, newErrWrapper(
classifyQUICHandshakeError, QUICHandshakeOperation, err)
return nil, NewErrWrapper(
ClassifyQUICHandshakeError, QUICHandshakeOperation, err)
}
return qconn, nil
}
+4 -4
View File
@@ -34,13 +34,13 @@ func TestQuirkReduceErrors(t *testing.T) {
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
err1 := errors.New("mocked error #1")
err2 := newErrWrapper(
classifyGenericError,
err2 := NewErrWrapper(
ClassifyGenericError,
CloseOperation,
errors.New("antani"),
)
err3 := newErrWrapper(
classifyGenericError,
err3 := NewErrWrapper(
ClassifyGenericError,
CloseOperation,
ECONNREFUSED,
)
+3 -3
View File
@@ -387,7 +387,7 @@ var _ model.Resolver = &resolverErrWrapper{}
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
addrs, err := r.Resolver.LookupHost(ctx, hostname)
if err != nil {
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
}
return addrs, nil
}
@@ -396,7 +396,7 @@ func (r *resolverErrWrapper) LookupHTTPS(
ctx context.Context, domain string) (*model.HTTPSSvc, error) {
out, err := r.Resolver.LookupHTTPS(ctx, domain)
if err != nil {
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
}
return out, nil
}
@@ -417,7 +417,7 @@ func (r *resolverErrWrapper) LookupNS(
ctx context.Context, domain string) ([]*net.NS, error) {
out, err := r.Resolver.LookupNS(ctx, domain)
if err != nil {
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
}
return out, nil
}
+33 -26
View File
@@ -18,9 +18,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
// TODO(bassosimone): check whether there's now equivalent functionality
// inside the standard library allowing us to map numbers to names.
var (
tlsVersionString = map[uint16]string{
tls.VersionTLS10: "TLSv1",
@@ -85,6 +82,13 @@ func TLSVersionString(value uint16) string {
// the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd`
// where `ddd` is the numeric value passed to this function.
func TLSCipherSuiteString(value uint16) string {
// TODO(https://github.com/ooni/probe/issues/2166): the standard library has a
// function for mapping a cipher suite to a string, but the value returned in case of
// missing cipher suite is different from the one we would return
// here. We could consider simplifying this code anyway because
// in most, if not all, cases we have a valid cipher suite and we
// just need to make sure what the spec says we should do when
// passed an unknown cipher suite.
if str, found := tlsCipherSuiteString[value]; found {
return str
}
@@ -158,15 +162,15 @@ func NewTLSHandshakerStdlib(logger model.DebugLogger) model.TLSHandshaker {
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker {
return &tlsHandshakerLogger{
TLSHandshaker: &tlsHandshakerErrWrapper{
TLSHandshaker: th,
},
DebugLogger: logger,
TLSHandshaker: th,
DebugLogger: logger,
}
}
// tlsHandshakerConfigurable is a configurable TLS handshaker that
// uses by default the standard library's TLS implementation.
//
// This type also implements error wrapping and events tracing.
type tlsHandshakerConfigurable struct {
// NewConn is the OPTIONAL factory for creating a new connection. If
// this factory is not set, we'll use the stdlib.
@@ -183,9 +187,20 @@ var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}
// value into a private variable to enable for unit testing.
var defaultCertPool = NewDefaultCertPool()
// tlsMaybeConnectionState returns the connection state if error is nil
// and otherwise just returns an empty state to the caller.
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
if err != nil {
return tls.ConnectionState{}
}
return conn.ConnectionState()
}
// Handshake implements Handshaker.Handshake. This function will
// configure the code to use the built-in Mozilla CA if the config
// field contains a nil RootCAs field.
//
// This function will also emit TLS-handshake-related tracing events.
func (h *tlsHandshakerConfigurable) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
@@ -203,10 +218,19 @@ func (h *tlsHandshakerConfigurable) Handshake(
if err != nil {
return nil, tls.ConnectionState{}, err
}
if err := tlsconn.HandshakeContext(ctx); err != nil {
remoteAddr := conn.RemoteAddr().String()
trace := ContextTraceOrDefault(ctx)
started := trace.TimeNow()
trace.OnTLSHandshakeStart(started, remoteAddr, config)
err = tlsconn.HandshakeContext(ctx)
err = MaybeNewErrWrapper(ClassifyTLSHandshakeError, TLSHandshakeOperation, err)
finished := trace.TimeNow()
state := tlsMaybeConnectionState(tlsconn, err)
trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
if err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
return tlsconn, state, nil
}
// newConn creates a new TLSConn.
@@ -352,23 +376,6 @@ func (d *tlsDialerSingleUseAdapter) CloseIdleConnections() {
d.Dialer.CloseIdleConnections()
}
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
type tlsHandshakerErrWrapper struct {
TLSHandshaker model.TLSHandshaker
}
// Handshake implements TLSHandshaker.Handshake
func (h *tlsHandshakerErrWrapper) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
if err != nil {
return nil, tls.ConnectionState{}, newErrWrapper(
classifyTLSHandshakeError, TLSHandshakeOperation, err)
}
return tlsconn, state, nil
}
// ErrNoTLSDialer is the type of error returned by "null" TLS dialers
// when you attempt to dial with them.
var ErrNoTLSDialer = errors.New("no configured TLS dialer")
+301 -54
View File
@@ -16,7 +16,10 @@ import (
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestVersionString(t *testing.T) {
@@ -123,8 +126,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
if logger.DebugLogger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn != nil {
t.Fatal("expected nil NewConn")
}
@@ -132,7 +134,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
func TestTLSHandshakerConfigurable(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("with error", func(t *testing.T) {
t.Run("with handshake I/O error", func(t *testing.T) {
var times []time.Time
h := &tlsHandshakerConfigurable{}
tcpConn := &mocks.Conn{
@@ -143,14 +145,37 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
times = append(times, t)
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "1.1.1.1:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
ctx := context.Background()
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
var errWrapper *ErrWrapper
if !errors.As(err, &errWrapper) {
t.Fatal("the error has not been wrapped")
}
if errWrapper.Failure != FailureEOFError {
t.Fatal("invalid wrapped error's failure")
}
if errWrapper.Operation != TLSHandshakeOperation {
t.Fatal("invalid wrapped error's operation")
}
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
t.Fatal("invalid wrapped error's underlying error")
}
if conn != nil {
t.Fatal("expected nil con here")
}
@@ -163,6 +188,9 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
if !times[1].IsZero() {
t.Fatal("did not clear timeout on exit")
}
if !reflect.ValueOf(state).IsZero() {
t.Fatal("the returned connection state is not a zero value")
}
})
t.Run("with success", func(t *testing.T) {
@@ -217,6 +245,16 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
MockSetDeadline: func(t time.Time) error {
return nil
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return "1.1.1.1:443"
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
if !errors.Is(err, expected) {
@@ -236,7 +274,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
}
})
t.Run("we cannot create a new conn", func(t *testing.T) {
t.Run("h.newConn fails", func(t *testing.T) {
expected := errors.New("mocked error")
handshaker := &tlsHandshakerConfigurable{
NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) {
@@ -261,6 +299,222 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
t.Fatal("expected nil tlsConn here")
}
})
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
var (
expectedSNI = "dns.google"
goodStartStartTime bool
goodStartInsecureSkipVerify bool
goodDoneInsecureSkipVerify bool
goodStartServerName bool
goodDoneServerName bool
goodDoneStartTime bool
goodDoneDoneTime bool
goodStartRemoteAddr bool
goodDoneRemoteAddr bool
goodDoneError bool
goodConnectionState bool
startCalled bool
doneCalled bool
)
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
defer server.Close()
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
startCalled = true
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodStartServerName = (config.ServerName == expectedSNI)
goodStartStartTime = (now.Sub(zeroTime) == 0)
goodStartRemoteAddr = (remoteAddr == server.Endpoint())
},
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
doneCalled = true
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodDoneServerName = (config.ServerName == expectedSNI)
goodDoneStartTime = (started.Sub(zeroTime) == 0)
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
goodDoneRemoteAddr = (remoteAddr == server.Endpoint())
goodDoneError = (err == nil)
goodConnectionState = (!reflect.ValueOf(state).IsZero())
},
}
ctx := ContextWithTrace(context.Background(), tx)
tcpConn, err := net.Dial("tcp", server.Endpoint())
if err != nil {
t.Fatal(err)
}
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: expectedSNI,
}
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if err != nil {
t.Fatal(err)
}
tlsConn.Close()
if reflect.ValueOf(connState).IsZero() {
t.Fatal("expected nonzero connState")
}
if !startCalled {
t.Fatal("start not called")
}
if !doneCalled {
t.Fatal("done not called")
}
if !goodStartInsecureSkipVerify {
t.Fatal("invalid start-event's InsecureSkipVerify")
}
if !goodDoneInsecureSkipVerify {
t.Fatal("invalid done-event's InsecureSkipVerify")
}
if !goodStartServerName {
t.Fatal("invalid start-event's ServerName")
}
if !goodDoneServerName {
t.Fatal("invalid done-event's ServerName")
}
if !goodStartStartTime {
t.Fatal("invalid start-event's start time")
}
if !goodDoneStartTime {
t.Fatal("invalid done-event's start time")
}
if !goodDoneDoneTime {
t.Fatal("invalid done-event's done time")
}
if !goodStartRemoteAddr {
t.Fatal("invalid start-event's remoteAddr")
}
if !goodDoneRemoteAddr {
t.Fatal("invalid done-event's remoteAddr")
}
if !goodDoneError {
t.Fatal("invalid done-event's error")
}
if !goodConnectionState {
t.Fatal("invalid done-event's connState")
}
})
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
var (
expectedEndpoint = "8.8.8.8:443"
expectedSNI = "dns.google"
goodStartStartTime bool
goodStartInsecureSkipVerify bool
goodDoneInsecureSkipVerify bool
goodStartServerName bool
goodDoneServerName bool
goodDoneStartTime bool
goodDoneDoneTime bool
goodStartRemoteAddr bool
goodDoneRemoteAddr bool
goodDoneError bool
goodConnectionState bool
startCalled bool
doneCalled bool
)
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
startCalled = true
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodStartServerName = (config.ServerName == expectedSNI)
goodStartStartTime = (now.Sub(zeroTime) == 0)
goodStartRemoteAddr = (remoteAddr == expectedEndpoint)
},
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
doneCalled = true
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
goodDoneServerName = (config.ServerName == expectedSNI)
goodDoneStartTime = (started.Sub(zeroTime) == 0)
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
goodDoneRemoteAddr = (remoteAddr == expectedEndpoint)
var ew *ErrWrapper
goodDoneError = (errors.As(err, &ew) && ew.Error() == FailureEOFError)
goodConnectionState = (reflect.ValueOf(state).IsZero())
},
}
ctx := ContextWithTrace(context.Background(), tx)
tcpConn := &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockString: func() string {
return expectedEndpoint
},
MockNetwork: func() string {
return "tcp"
},
}
},
}
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: expectedSNI,
}
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
if !errors.Is(err, io.EOF) {
t.Fatal("unexpected err", err)
}
if tlsConn != nil {
t.Fatal("expected nil tlsConn")
}
if !reflect.ValueOf(connState).IsZero() {
t.Fatal("expected zero connState")
}
if !startCalled {
t.Fatal("start not called")
}
if !doneCalled {
t.Fatal("done not called")
}
if !goodStartInsecureSkipVerify {
t.Fatal("invalid start-event's InsecureSkipVerify")
}
if !goodDoneInsecureSkipVerify {
t.Fatal("invalid done-event's InsecureSkipVerify")
}
if !goodStartServerName {
t.Fatal("invalid start-event's ServerName")
}
if !goodDoneServerName {
t.Fatal("invalid done-event's ServerName")
}
if !goodStartStartTime {
t.Fatal("invalid start-event's start time")
}
if !goodDoneStartTime {
t.Fatal("invalid done-event's start time")
}
if !goodDoneDoneTime {
t.Fatal("invalid done-event's done time")
}
if !goodStartRemoteAddr {
t.Fatal("invalid start-event's remoteAddr")
}
if !goodDoneRemoteAddr {
t.Fatal("invalid done-event's remoteAddr")
}
if !goodDoneError {
t.Fatal("invalid done-event's error")
}
if !goodConnectionState {
t.Fatal("invalid done-event's connState")
}
})
})
}
@@ -413,6 +667,15 @@ func TestTLSDialer(t *testing.T) {
return nil
}, MockSetDeadline: func(t time.Time) error {
return nil
}, MockRemoteAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "1.1.1.1:443"
},
MockString: func() string {
return "tcp"
},
}
}}, nil
}},
TLSHandshaker: &tlsHandshakerConfigurable{},
@@ -532,54 +795,6 @@ func TestNewSingleUseTLSDialer(t *testing.T) {
}
}
func TestTLSHandshakerErrWrapper(t *testing.T) {
t.Run("Handshake", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
expectedConn := &mocks.TLSConn{}
expectedState := tls.ConnectionState{
Version: tls.VersionTLS12,
}
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return expectedConn, expectedState, nil
},
},
}
ctx := context.Background()
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err != nil {
t.Fatal(err)
}
if expectedState.Version != state.Version {
t.Fatal("unexpected state")
}
if expectedConn != conn {
t.Fatal("unexpected conn")
}
})
t.Run("on failure", func(t *testing.T) {
expectedErr := io.EOF
th := &tlsHandshakerErrWrapper{
TLSHandshaker: &mocks.TLSHandshaker{
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
return nil, tls.ConnectionState{}, expectedErr
},
},
}
ctx := context.Background()
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
if err == nil || err.Error() != FailureEOFError {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("unexpected conn")
}
})
})
}
func TestNewNullTLSDialer(t *testing.T) {
dialer := NewNullTLSDialer()
conn, err := dialer.DialTLSContext(context.Background(), "", "")
@@ -618,3 +833,35 @@ func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
}
})
}
func TestMaybeConnectionState(t *testing.T) {
t.Run("with an error", func(t *testing.T) {
returned := tls.ConnectionState{
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
}
conn := &mocks.TLSConn{
MockConnectionState: func() tls.ConnectionState {
return returned
},
}
state := tlsMaybeConnectionState(conn, errors.New("mocked error"))
if !reflect.ValueOf(state).IsZero() {
t.Fatal("expected to see a zero connection state")
}
})
t.Run("without an error", func(t *testing.T) {
returned := tls.ConnectionState{
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
}
conn := &mocks.TLSConn{
MockConnectionState: func() tls.ConnectionState {
return returned
},
}
state := tlsMaybeConnectionState(conn, nil)
if reflect.ValueOf(state).IsZero() {
t.Fatal("expected to see a nonzero connection state")
}
})
}
+67
View File
@@ -0,0 +1,67 @@
package netxlite
//
// Context-based tracing
//
import (
"context"
"crypto/tls"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
// traceKey is the private type used to set/retrieve the context's trace.
type traceKey struct{}
// ContextTraceOrDefault retrieves the trace bound to the context or returns
// a default implementation of the trace in case no tracing was configured.
func ContextTraceOrDefault(ctx context.Context) model.Trace {
t, _ := ctx.Value(traceKey{}).(model.Trace)
return traceOrDefault(t)
}
// ContextWithTrace returns a new context that binds to the given trace. If the
// given trace is nil, this function will call panic.
func ContextWithTrace(ctx context.Context, trace model.Trace) context.Context {
runtimex.PanicIfTrue(trace == nil, "netxlite.WithTrace passed a nil trace")
return context.WithValue(ctx, traceKey{}, trace)
}
// traceOrDefault takes in input a trace and returns in output the
// given trace, if not nil, or a default trace implementation.
func traceOrDefault(trace model.Trace) model.Trace {
if trace != nil {
return trace
}
return &traceDefault{}
}
// traceDefault is a default model.Trace implementation where each method is a no-op.
type traceDefault struct{}
var _ model.Trace = &traceDefault{}
// TimeNow implements model.Trace.TimeNow
func (*traceDefault) TimeNow() time.Time {
return time.Now()
}
// OnConnectDone implements model.Trace.OnConnectDone.
func (*traceDefault) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
// nothing
}
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
func (*traceDefault) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
// nothing
}
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
func (*traceDefault) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
state tls.ConnectionState, err error, finished time.Time) {
// nothing
}
+40
View File
@@ -0,0 +1,40 @@
package netxlite
import (
"context"
"testing"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestContextTraceOrDefault(t *testing.T) {
t.Run("without a configured trace we get a default", func(t *testing.T) {
ctx := context.Background()
tx := ContextTraceOrDefault(ctx)
_ = tx.(*traceDefault) // panic if cannot cast
})
t.Run("with a configured trace we get the expected trace", func(t *testing.T) {
realTrace := &mocks.Trace{}
ctx := ContextWithTrace(context.Background(), realTrace)
tx := ContextTraceOrDefault(ctx)
if tx != realTrace {
t.Fatal("not the trace we expected")
}
})
}
func TestContextWithTrace(t *testing.T) {
t.Run("panics if passed a nil trace", func(t *testing.T) {
var called bool
func() {
defer func() {
called = (recover() != nil)
}()
_ = ContextWithTrace(context.Background(), nil)
}()
if !called {
t.Fatal("not called")
}
})
}
+1 -2
View File
@@ -19,8 +19,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) {
if logger.DebugLogger != log.Log {
t.Fatal("invalid logger")
}
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
if configurable.NewConn == nil {
t.Fatal("expected non-nil NewConn")
}