feat: tlsping and tcpping using step-by-step (#815)
## Checklist - [x] I have read the [contribution guidelines](https://github.com/ooni/probe-cli/blob/master/CONTRIBUTING.md) - [x] reference issue for this pull request: https://github.com/ooni/probe/issues/2158 - [x] if you changed anything related how experiments work and you need to reflect these changes in the ooni/spec repository, please link to the related ooni/spec pull request: https://github.com/ooni/spec/pull/250 ## Description This diff refactors the codebase to reimplement tlsping and tcpping to use the step-by-step measurements style. See docs/design/dd-003-step-by-step.md for more information on the step-by-step measurement style.
This commit is contained in:
@@ -47,7 +47,7 @@ func (r *bogonResolver) LookupHost(ctx context.Context, hostname string) ([]stri
|
||||
for _, addr := range addrs {
|
||||
if IsBogon(addr) {
|
||||
// wrap ErrDNSBogon as documented
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, ErrDNSBogon)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, ErrDNSBogon)
|
||||
}
|
||||
}
|
||||
return addrs, nil
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/scrubber"
|
||||
)
|
||||
|
||||
// classifyGenericError is maps an error occurred during an operation
|
||||
// to an OONI failure string. This specific classifier is the most
|
||||
// ClassifyGenericError maps an error occurred during an operation to
|
||||
// an OONI failure string. This specific classifier is the most
|
||||
// generic one. You usually use it when mapping I/O errors. You should
|
||||
// check whether there is a specific classifier for more specific
|
||||
// operations (e.g., DNS resolution, TLS handshake).
|
||||
@@ -38,7 +38,7 @@ import (
|
||||
// If everything else fails, this classifier returns a string
|
||||
// like "unknown_failure: XXX" where XXX has been scrubbed
|
||||
// so to remove any network endpoints from the original error string.
|
||||
func classifyGenericError(err error) string {
|
||||
func ClassifyGenericError(err error) string {
|
||||
// The list returned here matches the values used by MK unless
|
||||
// explicitly noted otherwise with a comment.
|
||||
|
||||
@@ -139,7 +139,7 @@ const (
|
||||
quicTLSUnrecognizedName = 112
|
||||
)
|
||||
|
||||
// classifyQUICHandshakeError maps errors during a QUIC
|
||||
// ClassifyQUICHandshakeError maps errors during a QUIC
|
||||
// handshake to OONI failure strings.
|
||||
//
|
||||
// If the input error is an *ErrWrapper we don't perform
|
||||
@@ -147,7 +147,7 @@ const (
|
||||
//
|
||||
// If this classifier fails, it calls ClassifyGenericError
|
||||
// and returns to the caller its return value.
|
||||
func classifyQUICHandshakeError(err error) string {
|
||||
func ClassifyQUICHandshakeError(err error) string {
|
||||
|
||||
// Robustness: handle the case where we're passed a wrapped error.
|
||||
var errwrapper *ErrWrapper
|
||||
@@ -207,7 +207,7 @@ func classifyQUICHandshakeError(err error) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
return classifyGenericError(err)
|
||||
return ClassifyGenericError(err)
|
||||
}
|
||||
|
||||
// quicIsCertificateError tells us whether a specific TLS alert error
|
||||
@@ -277,7 +277,7 @@ var (
|
||||
// anything as explained in getaddrinfo_linux.go.
|
||||
var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
||||
|
||||
// classifyResolverError maps DNS resolution errors to
|
||||
// ClassifyResolverError maps DNS resolution errors to
|
||||
// OONI failure strings.
|
||||
//
|
||||
// If the input error is an *ErrWrapper we don't perform
|
||||
@@ -285,7 +285,7 @@ var ErrAndroidDNSCacheNoData = errors.New(FailureAndroidDNSCacheNoData)
|
||||
//
|
||||
// If this classifier fails, it calls ClassifyGenericError and
|
||||
// returns to the caller its return value.
|
||||
func classifyResolverError(err error) string {
|
||||
func ClassifyResolverError(err error) string {
|
||||
|
||||
// Robustness: handle the case where we're passed a wrapped error.
|
||||
var errwrapper *ErrWrapper
|
||||
@@ -310,10 +310,10 @@ func classifyResolverError(err error) string {
|
||||
if errors.Is(err, ErrAndroidDNSCacheNoData) {
|
||||
return FailureAndroidDNSCacheNoData
|
||||
}
|
||||
return classifyGenericError(err)
|
||||
return ClassifyGenericError(err)
|
||||
}
|
||||
|
||||
// classifyTLSHandshakeError maps an error occurred during the TLS
|
||||
// ClassifyTLSHandshakeError maps an error occurred during the TLS
|
||||
// handshake to an OONI failure string.
|
||||
//
|
||||
// If the input error is an *ErrWrapper we don't perform
|
||||
@@ -321,7 +321,7 @@ func classifyResolverError(err error) string {
|
||||
//
|
||||
// If this classifier fails, it calls ClassifyGenericError and
|
||||
// returns to the caller its return value.
|
||||
func classifyTLSHandshakeError(err error) string {
|
||||
func ClassifyTLSHandshakeError(err error) string {
|
||||
|
||||
// Robustness: handle the case where we're passed a wrapped error.
|
||||
var errwrapper *ErrWrapper
|
||||
@@ -345,5 +345,5 @@ func classifyTLSHandshakeError(err error) string {
|
||||
// Test case: https://expired.badssl.com/
|
||||
return FailureSSLInvalidCertificate
|
||||
}
|
||||
return classifyGenericError(err)
|
||||
return ClassifyGenericError(err)
|
||||
}
|
||||
|
||||
@@ -18,13 +18,13 @@ func TestClassifyGenericError(t *testing.T) {
|
||||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyGenericError(err) != FailureEOFError {
|
||||
if ClassifyGenericError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for a system call error", func(t *testing.T) {
|
||||
if classifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
|
||||
if ClassifyGenericError(EWOULDBLOCK) != FailureOperationWouldBlock {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
@@ -35,63 +35,63 @@ func TestClassifyGenericError(t *testing.T) {
|
||||
// is just an implementation detail.
|
||||
|
||||
t.Run("for operation was canceled", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
|
||||
if ClassifyGenericError(errors.New("operation was canceled")) != FailureInterrupted {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for EOF", func(t *testing.T) {
|
||||
if classifyGenericError(io.EOF) != FailureEOFError {
|
||||
if ClassifyGenericError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for context deadline exceeded", func(t *testing.T) {
|
||||
if classifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(context.DeadlineExceeded) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for stun's transaction is timed out", func(t *testing.T) {
|
||||
if classifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(stun.ErrTransactionTimeOut) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for i/o timeout", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(errors.New("i/o timeout")) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for TLS handshake timeout", func(t *testing.T) {
|
||||
err := errors.New("net/http: TLS handshake timeout")
|
||||
if classifyGenericError(err) != FailureGenericTimeoutError {
|
||||
if ClassifyGenericError(err) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for no such host", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
|
||||
if ClassifyGenericError(errors.New("no such host")) != FailureDNSNXDOMAINError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for dns server misbehaving", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
|
||||
if ClassifyGenericError(errors.New("dns server misbehaving")) != FailureDNSServerMisbehaving {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for no answer from DNS server", func(t *testing.T) {
|
||||
if classifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
|
||||
if ClassifyGenericError(errors.New("no answer from DNS server")) != FailureDNSNoAnswer {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for use of closed network connection", func(t *testing.T) {
|
||||
err := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: use of closed network connection")
|
||||
if classifyGenericError(err) != FailureConnectionAlreadyClosed {
|
||||
if ClassifyGenericError(err) != FailureConnectionAlreadyClosed {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
@@ -99,7 +99,7 @@ func TestClassifyGenericError(t *testing.T) {
|
||||
// Now we're back in ClassifyGenericError
|
||||
|
||||
t.Run("for context.Canceled", func(t *testing.T) {
|
||||
if classifyGenericError(context.Canceled) != FailureInterrupted {
|
||||
if ClassifyGenericError(context.Canceled) != FailureInterrupted {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
@@ -108,7 +108,7 @@ func TestClassifyGenericError(t *testing.T) {
|
||||
t.Run("with an IPv4 address", func(t *testing.T) {
|
||||
input := errors.New("read tcp 10.0.2.15:56948->93.184.216.34:443: some error")
|
||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
||||
out := classifyGenericError(input)
|
||||
out := ClassifyGenericError(input)
|
||||
if out != expected {
|
||||
t.Fatal(cmp.Diff(expected, out))
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func TestClassifyGenericError(t *testing.T) {
|
||||
t.Run("with an IPv6 address", func(t *testing.T) {
|
||||
input := errors.New("read tcp [::1]:56948->[::1]:443: some error")
|
||||
expected := "unknown_failure: read tcp [scrubbed]->[scrubbed]: some error"
|
||||
out := classifyGenericError(input)
|
||||
out := ClassifyGenericError(input)
|
||||
if out != expected {
|
||||
t.Fatal(cmp.Diff(expected, out))
|
||||
}
|
||||
@@ -131,100 +131,100 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
|
||||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyQUICHandshakeError(err) != FailureEOFError {
|
||||
if ClassifyQUICHandshakeError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for incompatible quic version", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
|
||||
if ClassifyQUICHandshakeError(&quic.VersionNegotiationError{}) != FailureQUICIncompatibleVersion {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for stateless reset", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
|
||||
if ClassifyQUICHandshakeError(&quic.StatelessResetError{}) != FailureConnectionReset {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for handshake timeout", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
|
||||
if ClassifyQUICHandshakeError(&quic.HandshakeTimeoutError{}) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for idle timeout", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
|
||||
if ClassifyQUICHandshakeError(&quic.IdleTimeoutError{}) != FailureGenericTimeoutError {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for connection refused", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: quic.ConnectionRefused}) != FailureConnectionRefused {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for bad certificate", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertBadCertificate
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for unsupported certificate", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertUnsupportedCertificate
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for certificate expired", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertCertificateExpired
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for certificate revoked", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertCertificateRevoked
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for certificate unknown", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertCertificateUnknown
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for decrypt error", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertDecryptError
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for handshake failure", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertHandshakeFailure
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLFailedHandshake {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for unknown CA", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSAlertUnknownCA
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLUnknownAuthority {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for unrecognized hostname", func(t *testing.T) {
|
||||
var err quic.TransportErrorCode = quicTLSUnrecognizedName
|
||||
if classifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
|
||||
if ClassifyQUICHandshakeError(&quic.TransportError{ErrorCode: err}) != FailureSSLInvalidHostname {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
@@ -234,13 +234,13 @@ func TestClassifyQUICHandshakeError(t *testing.T) {
|
||||
ErrorCode: quic.InternalError,
|
||||
ErrorMessage: FailureHostUnreachable,
|
||||
}
|
||||
if classifyQUICHandshakeError(err) != FailureHostUnreachable {
|
||||
if ClassifyQUICHandshakeError(err) != FailureHostUnreachable {
|
||||
t.Fatal("unexpected results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for another kind of error", func(t *testing.T) {
|
||||
if classifyQUICHandshakeError(io.EOF) != FailureEOFError {
|
||||
if ClassifyQUICHandshakeError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
@@ -252,43 +252,43 @@ func TestClassifyResolverError(t *testing.T) {
|
||||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyResolverError(err) != FailureEOFError {
|
||||
if ClassifyResolverError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for ErrDNSBogon", func(t *testing.T) {
|
||||
if classifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
|
||||
if ClassifyResolverError(ErrDNSBogon) != FailureDNSBogonError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for refused", func(t *testing.T) {
|
||||
if classifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
|
||||
if ClassifyResolverError(ErrOODNSRefused) != FailureDNSRefusedError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for servfail", func(t *testing.T) {
|
||||
if classifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
|
||||
if ClassifyResolverError(ErrOODNSServfail) != FailureDNSServfailError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
|
||||
if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
|
||||
if ClassifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for EAI_NODATA returned by Android's getaddrinfo", func(t *testing.T) {
|
||||
if classifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
|
||||
if ClassifyResolverError(ErrAndroidDNSCacheNoData) != FailureAndroidDNSCacheNoData {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for another kind of error", func(t *testing.T) {
|
||||
if classifyResolverError(io.EOF) != FailureEOFError {
|
||||
if ClassifyResolverError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
@@ -300,34 +300,34 @@ func TestClassifyTLSHandshakeError(t *testing.T) {
|
||||
|
||||
t.Run("for input being already an ErrWrapper", func(t *testing.T) {
|
||||
err := &ErrWrapper{Failure: FailureEOFError}
|
||||
if classifyTLSHandshakeError(err) != FailureEOFError {
|
||||
if ClassifyTLSHandshakeError(err) != FailureEOFError {
|
||||
t.Fatal("did not classify existing ErrWrapper correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for x509.HostnameError", func(t *testing.T) {
|
||||
var err x509.HostnameError
|
||||
if classifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
|
||||
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidHostname {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for x509.UnknownAuthorityError", func(t *testing.T) {
|
||||
var err x509.UnknownAuthorityError
|
||||
if classifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
|
||||
if ClassifyTLSHandshakeError(err) != FailureSSLUnknownAuthority {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for x509.CertificateInvalidError", func(t *testing.T) {
|
||||
var err x509.CertificateInvalidError
|
||||
if classifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
|
||||
if ClassifyTLSHandshakeError(err) != FailureSSLInvalidCertificate {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for another kind of error", func(t *testing.T) {
|
||||
if classifyTLSHandshakeError(io.EOF) != FailureEOFError {
|
||||
if ClassifyTLSHandshakeError(io.EOF) != FailureEOFError {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
+37
-11
@@ -125,7 +125,7 @@ func WrapDialer(logger model.DebugLogger, resolver model.Resolver,
|
||||
outDialer = wrapper.WrapDialer(outDialer) // extend with user-supplied constructors
|
||||
}
|
||||
return &dialerLogger{
|
||||
Dialer: &dialerResolver{
|
||||
Dialer: &dialerResolverWithTracing{
|
||||
Dialer: &dialerLogger{
|
||||
Dialer: outDialer,
|
||||
DebugLogger: logger,
|
||||
@@ -171,15 +171,24 @@ func (d *DialerSystem) CloseIdleConnections() {
|
||||
// nothing to do here
|
||||
}
|
||||
|
||||
// dialerResolver combines dialing with domain name resolution.
|
||||
type dialerResolver struct {
|
||||
// dialerResolverWithTracing combines dialing with domain name resolution and
|
||||
// implements hooks to trace TCP (or UDP) connect operations.
|
||||
type dialerResolverWithTracing struct {
|
||||
Dialer model.Dialer
|
||||
Resolver model.Resolver
|
||||
}
|
||||
|
||||
var _ model.Dialer = &dialerResolver{}
|
||||
var _ model.Dialer = &dialerResolverWithTracing{}
|
||||
|
||||
func (d *dialerResolver) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// DialContext implements model.Dialer.DialContext. Specifically this
|
||||
// method performs the following operations:
|
||||
//
|
||||
// 1. resolve the domain inside the address using a resolver;
|
||||
//
|
||||
// 2. cycle through the available IP addresses and try to dial each of them;
|
||||
//
|
||||
// 3. trace the TCP (or UDP) connect and allow wrapping the returned conn.
|
||||
func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// QUIRK: this routine and the related routines in quirks.go cannot
|
||||
// be changed easily until we use events tracing to measure.
|
||||
//
|
||||
@@ -194,10 +203,27 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
|
||||
}
|
||||
addrs = quirkSortIPAddrs(addrs)
|
||||
var errorslist []error
|
||||
trace := ContextTraceOrDefault(ctx)
|
||||
for _, addr := range addrs {
|
||||
target := net.JoinHostPort(addr, onlyport)
|
||||
started := trace.TimeNow()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, target)
|
||||
finished := trace.TimeNow()
|
||||
// TODO(bassosimone): to make the code robust to future refactoring we have
|
||||
// moved error wrapping inside this type. This change opens up the possibility
|
||||
// of simplifying the dialing chain by removing dialerErrWrapper. We'll be
|
||||
// able to implement this refactoring once netx is gone. We cannot complete
|
||||
// this refactoring _before_ because WrapDialer inserts extra wrappers
|
||||
// provided by netx in the dialers chain _before_ this dialer and the dialers
|
||||
// that netx insert assume that they wrap a dialer with error wrapping.
|
||||
//
|
||||
// Because error wrapping should be idempotent, it should not be a problem
|
||||
// to have two error wrapping dialers in the chain except that, of course, it
|
||||
// would be less efficient than just having a single wrapper.
|
||||
err = MaybeNewErrWrapper(ClassifyGenericError, ConnectOperation, err)
|
||||
trace.OnConnectDone(started, network, onlyhost, target, err, finished)
|
||||
if err == nil {
|
||||
conn = &dialerErrWrapperConn{conn}
|
||||
return conn, nil
|
||||
}
|
||||
errorslist = append(errorslist, err)
|
||||
@@ -206,14 +232,14 @@ func (d *dialerResolver) DialContext(ctx context.Context, network, address strin
|
||||
}
|
||||
|
||||
// lookupHost ensures we correctly handle IP addresses.
|
||||
func (d *dialerResolver) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
func (d *dialerResolverWithTracing) lookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
return d.Resolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
func (d *dialerResolver) CloseIdleConnections() {
|
||||
func (d *dialerResolverWithTracing) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
d.Resolver.CloseIdleConnections()
|
||||
}
|
||||
@@ -303,7 +329,7 @@ var _ model.Dialer = &dialerErrWrapper{}
|
||||
func (d *dialerErrWrapper) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyGenericError, ConnectOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyGenericError, ConnectOperation, err)
|
||||
}
|
||||
return &dialerErrWrapperConn{Conn: conn}, nil
|
||||
}
|
||||
@@ -322,7 +348,7 @@ var _ net.Conn = &dialerErrWrapperConn{}
|
||||
func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
||||
count, err := c.Conn.Read(b)
|
||||
if err != nil {
|
||||
return 0, newErrWrapper(classifyGenericError, ReadOperation, err)
|
||||
return 0, NewErrWrapper(ClassifyGenericError, ReadOperation, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@@ -330,7 +356,7 @@ func (c *dialerErrWrapperConn) Read(b []byte) (int, error) {
|
||||
func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
||||
count, err := c.Conn.Write(b)
|
||||
if err != nil {
|
||||
return 0, newErrWrapper(classifyGenericError, WriteOperation, err)
|
||||
return 0, NewErrWrapper(ClassifyGenericError, WriteOperation, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@@ -338,7 +364,7 @@ func (c *dialerErrWrapperConn) Write(b []byte) (int, error) {
|
||||
func (c *dialerErrWrapperConn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
if err != nil {
|
||||
return newErrWrapper(classifyGenericError, CloseOperation, err)
|
||||
return NewErrWrapper(ClassifyGenericError, CloseOperation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestNewDialerWithStdlibResolver(t *testing.T) {
|
||||
@@ -22,7 +23,7 @@ func TestNewDialerWithStdlibResolver(t *testing.T) {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
// typecheck the resolver
|
||||
reso := logger.Dialer.(*dialerResolver)
|
||||
reso := logger.Dialer.(*dialerResolverWithTracing)
|
||||
typecheckForSystemResolver(t, reso.Resolver, model.DiscardLogger)
|
||||
// typecheck the dialer
|
||||
logger = reso.Dialer.(*dialerLogger)
|
||||
@@ -64,7 +65,7 @@ func TestNewDialer(t *testing.T) {
|
||||
if logger.DebugLogger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
reso := logger.Dialer.(*dialerResolver)
|
||||
reso := logger.Dialer.(*dialerResolverWithTracing)
|
||||
if _, okay := reso.Resolver.(*NullResolver); !okay {
|
||||
t.Fatal("invalid Resolver type")
|
||||
}
|
||||
@@ -136,10 +137,10 @@ func TestDialerSystem(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialerResolver(t *testing.T) {
|
||||
func TestDialerResolverWithTracing(t *testing.T) {
|
||||
t.Run("DialContext", func(t *testing.T) {
|
||||
t.Run("fails without a port", func(t *testing.T) {
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &DialerSystem{},
|
||||
Resolver: NewUnwrappedStdlibResolver(),
|
||||
}
|
||||
@@ -154,7 +155,7 @@ func TestDialerResolver(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("handles dialing error correctly for single IP address", func(t *testing.T) {
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
@@ -166,13 +167,26 @@ func TestDialerResolver(t *testing.T) {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("the error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != FailureEOFError {
|
||||
t.Fatal("invalid wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != ConnectOperation {
|
||||
t.Fatal("invalid wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||
t.Fatal("invalid wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles dialing error correctly for many IP addresses", func(t *testing.T) {
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
@@ -188,6 +202,19 @@ func TestDialerResolver(t *testing.T) {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("the error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != FailureEOFError {
|
||||
t.Fatal("invalid wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != ConnectOperation {
|
||||
t.Fatal("invalid wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||
t.Fatal("invalid wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
@@ -199,7 +226,7 @@ func TestDialerResolver(t *testing.T) {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return expectedConn, nil
|
||||
@@ -215,7 +242,10 @@ func TestDialerResolver(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conn != expectedConn {
|
||||
// Ensure that the dialer returns a connection that is already wrapping errors,
|
||||
// which is a new behavior since https://github.com/ooni/probe-cli/pull/815
|
||||
errWrapperConn := conn.(*dialerErrWrapperConn)
|
||||
if errWrapperConn.Conn != expectedConn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
conn.Close()
|
||||
@@ -225,7 +255,7 @@ func TestDialerResolver(t *testing.T) {
|
||||
// This test is fundamental to the following
|
||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||
mu := &sync.Mutex{}
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
@@ -257,7 +287,7 @@ func TestDialerResolver(t *testing.T) {
|
||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||
mu := &sync.Mutex{}
|
||||
var attempts []string
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
@@ -298,14 +328,14 @@ func TestDialerResolver(t *testing.T) {
|
||||
mu := &sync.Mutex{}
|
||||
errorsList := []error{
|
||||
errors.New("a mocked error"),
|
||||
newErrWrapper(
|
||||
classifyGenericError,
|
||||
NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
io.EOF,
|
||||
),
|
||||
}
|
||||
var errorIdx int
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
@@ -337,17 +367,18 @@ func TestDialerResolver(t *testing.T) {
|
||||
t.Run("though ignores the unknown failures", func(t *testing.T) {
|
||||
// This test is fundamental to the following
|
||||
// TODO(https://github.com/ooni/probe/issues/1779)
|
||||
expectedErr := errors.New("a mocked error")
|
||||
mu := &sync.Mutex{}
|
||||
errorsList := []error{
|
||||
errors.New("a mocked error"),
|
||||
newErrWrapper(
|
||||
classifyGenericError,
|
||||
expectedErr,
|
||||
NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
errors.New("antani"),
|
||||
errors.New("antani"), // this is an unknown failure and we should not return it
|
||||
),
|
||||
}
|
||||
var errorIdx int
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
// It should not happen to have parallel dials with
|
||||
@@ -368,18 +399,95 @@ func TestDialerResolver(t *testing.T) {
|
||||
},
|
||||
}
|
||||
conn, err := d.DialContext(context.Background(), "tcp", "dot.dns:853")
|
||||
if err == nil || err.Error() != "a mocked error" {
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != "unknown_failure: a mocked error" {
|
||||
t.Fatal("unexpected wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != ConnectOperation {
|
||||
t.Fatal("unexpected wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, expectedErr) {
|
||||
t.Fatal("unexpected wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses a context-injected custom trace", func(t *testing.T) {
|
||||
var (
|
||||
called bool
|
||||
domainOK bool
|
||||
networkOK bool
|
||||
remoteAddrOK bool
|
||||
startTimeOK bool
|
||||
finishTimeOK bool
|
||||
wrappedErr bool
|
||||
)
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
var ew *ErrWrapper
|
||||
called = true
|
||||
domainOK = (domain == "1.1.1.1")
|
||||
networkOK = (network == "tcp")
|
||||
remoteAddrOK = (remoteAddr == "1.1.1.1:853")
|
||||
startTimeOK = (started.Sub(zeroTime) == 0)
|
||||
finishTimeOK = (finished.Sub(zeroTime) == time.Second)
|
||||
wrappedErr = errors.As(err, &ew) && ew.Failure == FailureEOFError
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||
return nil, io.EOF
|
||||
},
|
||||
},
|
||||
Resolver: &NullResolver{},
|
||||
}
|
||||
conn, err := d.DialContext(ctx, "tcp", "1.1.1.1:853")
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
if !domainOK {
|
||||
t.Fatal("domain was not okay")
|
||||
}
|
||||
if !networkOK {
|
||||
t.Fatal("network was not okay")
|
||||
}
|
||||
if !remoteAddrOK {
|
||||
t.Fatal("remoteAddr was not okay")
|
||||
}
|
||||
if !startTimeOK {
|
||||
t.Fatal("start time was not okay")
|
||||
}
|
||||
if !finishTimeOK {
|
||||
t.Fatal("finish time was not okay")
|
||||
}
|
||||
if !wrappedErr {
|
||||
t.Fatal("not wrapped")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("lookupHost", func(t *testing.T) {
|
||||
t.Run("handles addresses correctly", func(t *testing.T) {
|
||||
dialer := &dialerResolver{
|
||||
dialer := &dialerResolverWithTracing{
|
||||
Dialer: &DialerSystem{},
|
||||
Resolver: &NullResolver{},
|
||||
}
|
||||
@@ -393,7 +501,7 @@ func TestDialerResolver(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("fails correctly on lookup error", func(t *testing.T) {
|
||||
dialer := &dialerResolver{
|
||||
dialer := &dialerResolverWithTracing{
|
||||
Dialer: &DialerSystem{},
|
||||
Resolver: &NullResolver{},
|
||||
}
|
||||
@@ -413,7 +521,7 @@ func TestDialerResolver(t *testing.T) {
|
||||
calledDialer bool
|
||||
calledResolver bool
|
||||
)
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockCloseIdleConnections: func() {
|
||||
calledDialer = true
|
||||
|
||||
@@ -38,7 +38,7 @@ func (t *dnsTransportErrWrapper) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
resp, err := t.DNSTransport.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, DNSRoundTripOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, DNSRoundTripOperation, err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -34,6 +34,10 @@
|
||||
//
|
||||
// We want to have reasonable watchdog timeouts for each operation.
|
||||
//
|
||||
// We also want lightweight support for tracing network events. To this end, we
|
||||
// use context.WithValue and context.Value to inject, and retrieve, a model.Trace
|
||||
// implementation OPTIONALLY configured by the user.
|
||||
//
|
||||
// See also the design document at docs/design/dd-003-step-by-step.md,
|
||||
// which provides an overview of netxlite's main concerns.
|
||||
//
|
||||
|
||||
@@ -71,7 +71,7 @@ func (e *ErrWrapper) MarshalJSON() ([]byte, error) {
|
||||
// https://github.com/ooni/spec/blob/master/data-formats/df-007-errors.md.
|
||||
type classifier func(err error) string
|
||||
|
||||
// newErrWrapper creates a new ErrWrapper using the given
|
||||
// NewErrWrapper creates a new ErrWrapper using the given
|
||||
// classifier, operation name, and underlying error.
|
||||
//
|
||||
// This function panics if classifier is nil, or operation
|
||||
@@ -81,7 +81,7 @@ type classifier func(err error) string
|
||||
// error wrapper will use the same classification string and
|
||||
// will determine whether to keep the major operation as documented
|
||||
// in the ErrWrapper.Operation documentation.
|
||||
func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||
func NewErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||
var wrapper *ErrWrapper
|
||||
if errors.As(err, &wrapper) {
|
||||
return &ErrWrapper{
|
||||
@@ -106,6 +106,19 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(https://github.com/ooni/probe/issues/2163): we can really
|
||||
// simplify the error wrapping situation here by just dropping
|
||||
// NewErrWrapper and always using MaybeNewErrWrapper.
|
||||
|
||||
// MaybeNewErrWrapper is like NewErrWrapper except that this
|
||||
// function won't panic if passed a nil error.
|
||||
func MaybeNewErrWrapper(c classifier, op string, err error) error {
|
||||
if err != nil {
|
||||
return NewErrWrapper(c, op, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTopLevelGenericErrWrapper wraps an error occurring at top
|
||||
// level using a generic classifier as classifier. This is the
|
||||
// function you should call when you suspect a given error hasn't
|
||||
@@ -115,7 +128,7 @@ func newErrWrapper(c classifier, op string, err error) *ErrWrapper {
|
||||
// error wrapper will use the same classification string and
|
||||
// failed operation of the original error.
|
||||
func NewTopLevelGenericErrWrapper(err error) *ErrWrapper {
|
||||
return newErrWrapper(classifyGenericError, TopLevelOperation, err)
|
||||
return NewErrWrapper(ClassifyGenericError, TopLevelOperation, err)
|
||||
}
|
||||
|
||||
func classifyOperation(ew *ErrWrapper, operation string) string {
|
||||
|
||||
@@ -53,7 +53,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||
recovered.Add(1)
|
||||
}
|
||||
}()
|
||||
newErrWrapper(nil, CloseOperation, io.EOF)
|
||||
NewErrWrapper(nil, CloseOperation, io.EOF)
|
||||
}()
|
||||
if recovered.Load() != 1 {
|
||||
t.Fatal("did not panic")
|
||||
@@ -68,7 +68,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||
recovered.Add(1)
|
||||
}
|
||||
}()
|
||||
newErrWrapper(classifyGenericError, "", io.EOF)
|
||||
NewErrWrapper(ClassifyGenericError, "", io.EOF)
|
||||
}()
|
||||
if recovered.Load() != 1 {
|
||||
t.Fatal("did not panic")
|
||||
@@ -83,7 +83,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||
recovered.Add(1)
|
||||
}
|
||||
}()
|
||||
newErrWrapper(classifyGenericError, CloseOperation, nil)
|
||||
NewErrWrapper(ClassifyGenericError, CloseOperation, nil)
|
||||
}()
|
||||
if recovered.Load() != 1 {
|
||||
t.Fatal("did not panic")
|
||||
@@ -91,7 +91,7 @@ func TestNewErrWrapper(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("otherwise, works as intended", func(t *testing.T) {
|
||||
ew := newErrWrapper(classifyGenericError, CloseOperation, io.EOF)
|
||||
ew := NewErrWrapper(ClassifyGenericError, CloseOperation, io.EOF)
|
||||
if ew.Failure != FailureEOFError {
|
||||
t.Fatal("unexpected failure")
|
||||
}
|
||||
@@ -104,10 +104,10 @@ func TestNewErrWrapper(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("when the underlying error is already a wrapped error", func(t *testing.T) {
|
||||
ew := newErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||
ew := NewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||
var err1 error = ew
|
||||
err2 := fmt.Errorf("cannot read: %w", err1)
|
||||
ew2 := newErrWrapper(classifyGenericError, HTTPRoundTripOperation, err2)
|
||||
ew2 := NewErrWrapper(ClassifyGenericError, HTTPRoundTripOperation, err2)
|
||||
if ew2.Failure != ew.Failure {
|
||||
t.Fatal("not the same failure")
|
||||
}
|
||||
@@ -117,6 +117,34 @@ func TestNewErrWrapper(t *testing.T) {
|
||||
if ew2.WrappedErr != err2 {
|
||||
t.Fatal("invalid underlying error")
|
||||
}
|
||||
// Make sure we can still use errors.Is with two layers of wrapping
|
||||
if !errors.Is(ew2, ECONNRESET) {
|
||||
t.Fatal("we cannot use errors.Is to retrieve the real syscall error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaybeNewErrWrapper(t *testing.T) {
|
||||
// TODO(https://github.com/ooni/probe/issues/2163): we can really
|
||||
// simplify the error wrapping situation here by just dropping
|
||||
// NewErrWrapper and always using MaybeNewErrWrapper.
|
||||
|
||||
t.Run("when we pass a nil error to this function", func(t *testing.T) {
|
||||
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, nil)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected output", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when we pass a non-nil error to this function", func(t *testing.T) {
|
||||
err := MaybeNewErrWrapper(classifySyscallError, ReadOperation, ECONNRESET)
|
||||
if !errors.Is(err, ECONNRESET) {
|
||||
t.Fatal("unexpected output", err)
|
||||
}
|
||||
var ew *ErrWrapper
|
||||
if !errors.As(err, &ew) {
|
||||
t.Fatal("not an instance of ErrWrapper")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
|
||||
dialer := txpCc.Dialer
|
||||
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
||||
dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolver)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
|
||||
if dialerReso.Resolver != resolver {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) {
|
||||
dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout)
|
||||
dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer)
|
||||
dialerLog := dialerProxy.Dialer.(*dialerLogger)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolver)
|
||||
dialerReso := dialerLog.Dialer.(*dialerResolverWithTracing)
|
||||
if dialerReso.Resolver != resolver {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
@@ -269,7 +269,7 @@ func TestNewHTTPTransport(t *testing.T) {
|
||||
t.Run("works as intended with failing dialer", func(t *testing.T) {
|
||||
called := &atomicx.Int64{}
|
||||
expected := errors.New("mocked error")
|
||||
d := &dialerResolver{
|
||||
d := &dialerResolverWithTracing{
|
||||
Dialer: &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context,
|
||||
network, address string) (net.Conn, error) {
|
||||
@@ -612,7 +612,7 @@ func TestNewHTTPClientWithResolver(t *testing.T) {
|
||||
txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser)
|
||||
dialer := txpCc.Dialer.(*httpDialerWithReadTimeout)
|
||||
dialerLogger := dialer.Dialer.(*dialerLogger)
|
||||
dialerReso := dialerLogger.Dialer.(*dialerResolver)
|
||||
dialerReso := dialerLogger.Dialer.(*dialerResolverWithTracing)
|
||||
if dialerReso.Resolver != reso {
|
||||
t.Fatal("invalid resolver")
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestReadAllContext(t *testing.T) {
|
||||
//
|
||||
// Note: Returning a wrapped error to ensure we address
|
||||
// https://github.com/ooni/probe/issues/1965
|
||||
return len(b), newErrWrapper(classifyGenericError,
|
||||
return len(b), NewErrWrapper(ClassifyGenericError,
|
||||
ReadOperation, io.EOF)
|
||||
},
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func TestCopyContext(t *testing.T) {
|
||||
//
|
||||
// Note: Returning a wrapped error to ensure we address
|
||||
// https://github.com/ooni/probe/issues/1965
|
||||
return len(b), newErrWrapper(classifyGenericError,
|
||||
return len(b), NewErrWrapper(ClassifyGenericError,
|
||||
ReadOperation, io.EOF)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -380,7 +380,7 @@ var _ model.QUICListener = &quicListenerErrWrapper{}
|
||||
func (qls *quicListenerErrWrapper) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
|
||||
pconn, err := qls.QUICListener.Listen(addr)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyGenericError, QUICListenOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyGenericError, QUICListenOperation, err)
|
||||
}
|
||||
return &quicErrWrapperUDPLikeConn{pconn}, nil
|
||||
}
|
||||
@@ -397,7 +397,7 @@ var _ model.UDPLikeConn = &quicErrWrapperUDPLikeConn{}
|
||||
func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
count, err := c.UDPLikeConn.WriteTo(p, addr)
|
||||
if err != nil {
|
||||
return 0, newErrWrapper(classifyGenericError, WriteToOperation, err)
|
||||
return 0, NewErrWrapper(ClassifyGenericError, WriteToOperation, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
@@ -406,7 +406,7 @@ func (c *quicErrWrapperUDPLikeConn) WriteTo(p []byte, addr net.Addr) (int, error
|
||||
func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
n, addr, err := c.UDPLikeConn.ReadFrom(b)
|
||||
if err != nil {
|
||||
return 0, nil, newErrWrapper(classifyGenericError, ReadFromOperation, err)
|
||||
return 0, nil, NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
|
||||
}
|
||||
return n, addr, nil
|
||||
}
|
||||
@@ -415,7 +415,7 @@ func (c *quicErrWrapperUDPLikeConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
func (c *quicErrWrapperUDPLikeConn) Close() error {
|
||||
err := c.UDPLikeConn.Close()
|
||||
if err != nil {
|
||||
return newErrWrapper(classifyGenericError, ReadFromOperation, err)
|
||||
return NewErrWrapper(ClassifyGenericError, ReadFromOperation, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -433,8 +433,8 @@ func (d *quicDialerErrWrapper) DialContext(
|
||||
tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
qconn, err := d.QUICDialer.DialContext(ctx, network, host, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(
|
||||
classifyQUICHandshakeError, QUICHandshakeOperation, err)
|
||||
return nil, NewErrWrapper(
|
||||
ClassifyQUICHandshakeError, QUICHandshakeOperation, err)
|
||||
}
|
||||
return qconn, nil
|
||||
}
|
||||
|
||||
@@ -34,13 +34,13 @@ func TestQuirkReduceErrors(t *testing.T) {
|
||||
|
||||
t.Run("multiple errors with meaningful ones", func(t *testing.T) {
|
||||
err1 := errors.New("mocked error #1")
|
||||
err2 := newErrWrapper(
|
||||
classifyGenericError,
|
||||
err2 := NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
errors.New("antani"),
|
||||
)
|
||||
err3 := newErrWrapper(
|
||||
classifyGenericError,
|
||||
err3 := NewErrWrapper(
|
||||
ClassifyGenericError,
|
||||
CloseOperation,
|
||||
ECONNREFUSED,
|
||||
)
|
||||
|
||||
@@ -387,7 +387,7 @@ var _ model.Resolver = &resolverErrWrapper{}
|
||||
func (r *resolverErrWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
@@ -396,7 +396,7 @@ func (r *resolverErrWrapper) LookupHTTPS(
|
||||
ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||
out, err := r.Resolver.LookupHTTPS(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -417,7 +417,7 @@ func (r *resolverErrWrapper) LookupNS(
|
||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
out, err := r.Resolver.LookupNS(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
return nil, NewErrWrapper(ClassifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
+33
-26
@@ -18,9 +18,6 @@ import (
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
// TODO(bassosimone): check whether there's now equivalent functionality
|
||||
// inside the standard library allowing us to map numbers to names.
|
||||
|
||||
var (
|
||||
tlsVersionString = map[uint16]string{
|
||||
tls.VersionTLS10: "TLSv1",
|
||||
@@ -85,6 +82,13 @@ func TLSVersionString(value uint16) string {
|
||||
// the value to a cipher suite name, we return `TLS_CIPHER_SUITE_UNKNOWN_ddd`
|
||||
// where `ddd` is the numeric value passed to this function.
|
||||
func TLSCipherSuiteString(value uint16) string {
|
||||
// TODO(https://github.com/ooni/probe/issues/2166): the standard library has a
|
||||
// function for mapping a cipher suite to a string, but the value returned in case of
|
||||
// missing cipher suite is different from the one we would return
|
||||
// here. We could consider simplifying this code anyway because
|
||||
// in most, if not all, cases we have a valid cipher suite and we
|
||||
// just need to make sure what the spec says we should do when
|
||||
// passed an unknown cipher suite.
|
||||
if str, found := tlsCipherSuiteString[value]; found {
|
||||
return str
|
||||
}
|
||||
@@ -158,15 +162,15 @@ func NewTLSHandshakerStdlib(logger model.DebugLogger) model.TLSHandshaker {
|
||||
// newTLSHandshaker is the common factory for creating a new TLSHandshaker
|
||||
func newTLSHandshaker(th model.TLSHandshaker, logger model.DebugLogger) model.TLSHandshaker {
|
||||
return &tlsHandshakerLogger{
|
||||
TLSHandshaker: &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: th,
|
||||
},
|
||||
DebugLogger: logger,
|
||||
TLSHandshaker: th,
|
||||
DebugLogger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// tlsHandshakerConfigurable is a configurable TLS handshaker that
|
||||
// uses by default the standard library's TLS implementation.
|
||||
//
|
||||
// This type also implements error wrapping and events tracing.
|
||||
type tlsHandshakerConfigurable struct {
|
||||
// NewConn is the OPTIONAL factory for creating a new connection. If
|
||||
// this factory is not set, we'll use the stdlib.
|
||||
@@ -183,9 +187,20 @@ var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}
|
||||
// value into a private variable to enable for unit testing.
|
||||
var defaultCertPool = NewDefaultCertPool()
|
||||
|
||||
// tlsMaybeConnectionState returns the connection state if error is nil
|
||||
// and otherwise just returns an empty state to the caller.
|
||||
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
|
||||
if err != nil {
|
||||
return tls.ConnectionState{}
|
||||
}
|
||||
return conn.ConnectionState()
|
||||
}
|
||||
|
||||
// Handshake implements Handshaker.Handshake. This function will
|
||||
// configure the code to use the built-in Mozilla CA if the config
|
||||
// field contains a nil RootCAs field.
|
||||
//
|
||||
// This function will also emit TLS-handshake-related tracing events.
|
||||
func (h *tlsHandshakerConfigurable) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
@@ -203,10 +218,19 @@ func (h *tlsHandshakerConfigurable) Handshake(
|
||||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
if err := tlsconn.HandshakeContext(ctx); err != nil {
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
trace := ContextTraceOrDefault(ctx)
|
||||
started := trace.TimeNow()
|
||||
trace.OnTLSHandshakeStart(started, remoteAddr, config)
|
||||
err = tlsconn.HandshakeContext(ctx)
|
||||
err = MaybeNewErrWrapper(ClassifyTLSHandshakeError, TLSHandshakeOperation, err)
|
||||
finished := trace.TimeNow()
|
||||
state := tlsMaybeConnectionState(tlsconn, err)
|
||||
trace.OnTLSHandshakeDone(started, remoteAddr, config, state, err, finished)
|
||||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
return tlsconn, state, nil
|
||||
}
|
||||
|
||||
// newConn creates a new TLSConn.
|
||||
@@ -352,23 +376,6 @@ func (d *tlsDialerSingleUseAdapter) CloseIdleConnections() {
|
||||
d.Dialer.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// tlsHandshakerErrWrapper wraps the returned error to be an OONI error
|
||||
type tlsHandshakerErrWrapper struct {
|
||||
TLSHandshaker model.TLSHandshaker
|
||||
}
|
||||
|
||||
// Handshake implements TLSHandshaker.Handshake
|
||||
func (h *tlsHandshakerErrWrapper) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
tlsconn, state, err := h.TLSHandshaker.Handshake(ctx, conn, config)
|
||||
if err != nil {
|
||||
return nil, tls.ConnectionState{}, newErrWrapper(
|
||||
classifyTLSHandshakeError, TLSHandshakeOperation, err)
|
||||
}
|
||||
return tlsconn, state, nil
|
||||
}
|
||||
|
||||
// ErrNoTLSDialer is the type of error returned by "null" TLS dialers
|
||||
// when you attempt to dial with them.
|
||||
var ErrNoTLSDialer = errors.New("no configured TLS dialer")
|
||||
|
||||
+301
-54
@@ -16,7 +16,10 @@ import (
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestVersionString(t *testing.T) {
|
||||
@@ -123,8 +126,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||
if logger.DebugLogger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if configurable.NewConn != nil {
|
||||
t.Fatal("expected nil NewConn")
|
||||
}
|
||||
@@ -132,7 +134,7 @@ func TestNewTLSHandshakerStdlib(t *testing.T) {
|
||||
|
||||
func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("with error", func(t *testing.T) {
|
||||
t.Run("with handshake I/O error", func(t *testing.T) {
|
||||
var times []time.Time
|
||||
h := &tlsHandshakerConfigurable{}
|
||||
tcpConn := &mocks.Conn{
|
||||
@@ -143,14 +145,37 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
times = append(times, t)
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||
conn, state, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||
ServerName: "x.org",
|
||||
})
|
||||
if err != io.EOF {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
var errWrapper *ErrWrapper
|
||||
if !errors.As(err, &errWrapper) {
|
||||
t.Fatal("the error has not been wrapped")
|
||||
}
|
||||
if errWrapper.Failure != FailureEOFError {
|
||||
t.Fatal("invalid wrapped error's failure")
|
||||
}
|
||||
if errWrapper.Operation != TLSHandshakeOperation {
|
||||
t.Fatal("invalid wrapped error's operation")
|
||||
}
|
||||
if !errors.Is(errWrapper.WrappedErr, io.EOF) {
|
||||
t.Fatal("invalid wrapped error's underlying error")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
@@ -163,6 +188,9 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
if !times[1].IsZero() {
|
||||
t.Fatal("did not clear timeout on exit")
|
||||
}
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("the returned connection state is not a zero value")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with success", func(t *testing.T) {
|
||||
@@ -217,6 +245,16 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockString: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||
if !errors.Is(err, expected) {
|
||||
@@ -236,7 +274,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("we cannot create a new conn", func(t *testing.T) {
|
||||
t.Run("h.newConn fails", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
handshaker := &tlsHandshakerConfigurable{
|
||||
NewConn: func(conn net.Conn, config *tls.Config) (TLSConn, error) {
|
||||
@@ -261,6 +299,222 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
|
||||
t.Fatal("expected nil tlsConn here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
|
||||
var (
|
||||
expectedSNI = "dns.google"
|
||||
goodStartStartTime bool
|
||||
goodStartInsecureSkipVerify bool
|
||||
goodDoneInsecureSkipVerify bool
|
||||
goodStartServerName bool
|
||||
goodDoneServerName bool
|
||||
goodDoneStartTime bool
|
||||
goodDoneDoneTime bool
|
||||
goodStartRemoteAddr bool
|
||||
goodDoneRemoteAddr bool
|
||||
goodDoneError bool
|
||||
goodConnectionState bool
|
||||
startCalled bool
|
||||
doneCalled bool
|
||||
)
|
||||
server := filtering.NewTLSServer(filtering.TLSActionBlockText)
|
||||
defer server.Close()
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
startCalled = true
|
||||
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodStartServerName = (config.ServerName == expectedSNI)
|
||||
goodStartStartTime = (now.Sub(zeroTime) == 0)
|
||||
goodStartRemoteAddr = (remoteAddr == server.Endpoint())
|
||||
},
|
||||
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||
doneCalled = true
|
||||
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodDoneServerName = (config.ServerName == expectedSNI)
|
||||
goodDoneStartTime = (started.Sub(zeroTime) == 0)
|
||||
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
|
||||
goodDoneRemoteAddr = (remoteAddr == server.Endpoint())
|
||||
goodDoneError = (err == nil)
|
||||
goodConnectionState = (!reflect.ValueOf(state).IsZero())
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
tcpConn, err := net.Dial("tcp", server.Endpoint())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: expectedSNI,
|
||||
}
|
||||
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tlsConn.Close()
|
||||
if reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected nonzero connState")
|
||||
}
|
||||
if !startCalled {
|
||||
t.Fatal("start not called")
|
||||
}
|
||||
if !doneCalled {
|
||||
t.Fatal("done not called")
|
||||
}
|
||||
if !goodStartInsecureSkipVerify {
|
||||
t.Fatal("invalid start-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodDoneInsecureSkipVerify {
|
||||
t.Fatal("invalid done-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodStartServerName {
|
||||
t.Fatal("invalid start-event's ServerName")
|
||||
}
|
||||
if !goodDoneServerName {
|
||||
t.Fatal("invalid done-event's ServerName")
|
||||
}
|
||||
if !goodStartStartTime {
|
||||
t.Fatal("invalid start-event's start time")
|
||||
}
|
||||
if !goodDoneStartTime {
|
||||
t.Fatal("invalid done-event's start time")
|
||||
}
|
||||
if !goodDoneDoneTime {
|
||||
t.Fatal("invalid done-event's done time")
|
||||
}
|
||||
if !goodStartRemoteAddr {
|
||||
t.Fatal("invalid start-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneRemoteAddr {
|
||||
t.Fatal("invalid done-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneError {
|
||||
t.Fatal("invalid done-event's error")
|
||||
}
|
||||
if !goodConnectionState {
|
||||
t.Fatal("invalid done-event's connState")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
|
||||
var (
|
||||
expectedEndpoint = "8.8.8.8:443"
|
||||
expectedSNI = "dns.google"
|
||||
goodStartStartTime bool
|
||||
goodStartInsecureSkipVerify bool
|
||||
goodDoneInsecureSkipVerify bool
|
||||
goodStartServerName bool
|
||||
goodDoneServerName bool
|
||||
goodDoneStartTime bool
|
||||
goodDoneDoneTime bool
|
||||
goodStartRemoteAddr bool
|
||||
goodDoneRemoteAddr bool
|
||||
goodDoneError bool
|
||||
goodConnectionState bool
|
||||
startCalled bool
|
||||
doneCalled bool
|
||||
)
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnTLSHandshakeStart: func(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
startCalled = true
|
||||
goodStartInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodStartServerName = (config.ServerName == expectedSNI)
|
||||
goodStartStartTime = (now.Sub(zeroTime) == 0)
|
||||
goodStartRemoteAddr = (remoteAddr == expectedEndpoint)
|
||||
},
|
||||
MockOnTLSHandshakeDone: func(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) {
|
||||
doneCalled = true
|
||||
goodDoneInsecureSkipVerify = (config.InsecureSkipVerify == true)
|
||||
goodDoneServerName = (config.ServerName == expectedSNI)
|
||||
goodDoneStartTime = (started.Sub(zeroTime) == 0)
|
||||
goodDoneDoneTime = (finished.Sub(zeroTime) == time.Second)
|
||||
goodDoneRemoteAddr = (remoteAddr == expectedEndpoint)
|
||||
var ew *ErrWrapper
|
||||
goodDoneError = (errors.As(err, &ew) && ew.Error() == FailureEOFError)
|
||||
goodConnectionState = (reflect.ValueOf(state).IsZero())
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
tcpConn := &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockString: func() string {
|
||||
return expectedEndpoint
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
thx := NewTLSHandshakerStdlib(model.DiscardLogger)
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: expectedSNI,
|
||||
}
|
||||
tlsConn, connState, err := thx.Handshake(ctx, tcpConn, tlsConfig)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if tlsConn != nil {
|
||||
t.Fatal("expected nil tlsConn")
|
||||
}
|
||||
if !reflect.ValueOf(connState).IsZero() {
|
||||
t.Fatal("expected zero connState")
|
||||
}
|
||||
if !startCalled {
|
||||
t.Fatal("start not called")
|
||||
}
|
||||
if !doneCalled {
|
||||
t.Fatal("done not called")
|
||||
}
|
||||
if !goodStartInsecureSkipVerify {
|
||||
t.Fatal("invalid start-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodDoneInsecureSkipVerify {
|
||||
t.Fatal("invalid done-event's InsecureSkipVerify")
|
||||
}
|
||||
if !goodStartServerName {
|
||||
t.Fatal("invalid start-event's ServerName")
|
||||
}
|
||||
if !goodDoneServerName {
|
||||
t.Fatal("invalid done-event's ServerName")
|
||||
}
|
||||
if !goodStartStartTime {
|
||||
t.Fatal("invalid start-event's start time")
|
||||
}
|
||||
if !goodDoneStartTime {
|
||||
t.Fatal("invalid done-event's start time")
|
||||
}
|
||||
if !goodDoneDoneTime {
|
||||
t.Fatal("invalid done-event's done time")
|
||||
}
|
||||
if !goodStartRemoteAddr {
|
||||
t.Fatal("invalid start-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneRemoteAddr {
|
||||
t.Fatal("invalid done-event's remoteAddr")
|
||||
}
|
||||
if !goodDoneError {
|
||||
t.Fatal("invalid done-event's error")
|
||||
}
|
||||
if !goodConnectionState {
|
||||
t.Fatal("invalid done-event's connState")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -413,6 +667,15 @@ func TestTLSDialer(t *testing.T) {
|
||||
return nil
|
||||
}, MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
}, MockRemoteAddr: func() net.Addr {
|
||||
return &mocks.Addr{
|
||||
MockNetwork: func() string {
|
||||
return "1.1.1.1:443"
|
||||
},
|
||||
MockString: func() string {
|
||||
return "tcp"
|
||||
},
|
||||
}
|
||||
}}, nil
|
||||
}},
|
||||
TLSHandshaker: &tlsHandshakerConfigurable{},
|
||||
@@ -532,54 +795,6 @@ func TestNewSingleUseTLSDialer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerErrWrapper(t *testing.T) {
|
||||
t.Run("Handshake", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expectedConn := &mocks.TLSConn{}
|
||||
expectedState := tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
}
|
||||
th := &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return expectedConn, expectedState, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, state, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if expectedState.Version != state.Version {
|
||||
t.Fatal("unexpected state")
|
||||
}
|
||||
if expectedConn != conn {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expectedErr := io.EOF
|
||||
th := &tlsHandshakerErrWrapper{
|
||||
TLSHandshaker: &mocks.TLSHandshaker{
|
||||
MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
|
||||
return nil, tls.ConnectionState{}, expectedErr
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, _, err := th.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
|
||||
if err == nil || err.Error() != FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("unexpected conn")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewNullTLSDialer(t *testing.T) {
|
||||
dialer := NewNullTLSDialer()
|
||||
conn, err := dialer.DialTLSContext(context.Background(), "", "")
|
||||
@@ -618,3 +833,35 @@ func TestClonedTLSConfigOrNewEmptyConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMaybeConnectionState(t *testing.T) {
|
||||
t.Run("with an error", func(t *testing.T) {
|
||||
returned := tls.ConnectionState{
|
||||
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||
}
|
||||
conn := &mocks.TLSConn{
|
||||
MockConnectionState: func() tls.ConnectionState {
|
||||
return returned
|
||||
},
|
||||
}
|
||||
state := tlsMaybeConnectionState(conn, errors.New("mocked error"))
|
||||
if !reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected to see a zero connection state")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without an error", func(t *testing.T) {
|
||||
returned := tls.ConnectionState{
|
||||
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
|
||||
}
|
||||
conn := &mocks.TLSConn{
|
||||
MockConnectionState: func() tls.ConnectionState {
|
||||
return returned
|
||||
},
|
||||
}
|
||||
state := tlsMaybeConnectionState(conn, nil)
|
||||
if reflect.ValueOf(state).IsZero() {
|
||||
t.Fatal("expected to see a nonzero connection state")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
package netxlite
|
||||
|
||||
//
|
||||
// Context-based tracing
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
// traceKey is the private type used to set/retrieve the context's trace.
|
||||
type traceKey struct{}
|
||||
|
||||
// ContextTraceOrDefault retrieves the trace bound to the context or returns
|
||||
// a default implementation of the trace in case no tracing was configured.
|
||||
func ContextTraceOrDefault(ctx context.Context) model.Trace {
|
||||
t, _ := ctx.Value(traceKey{}).(model.Trace)
|
||||
return traceOrDefault(t)
|
||||
}
|
||||
|
||||
// ContextWithTrace returns a new context that binds to the given trace. If the
|
||||
// given trace is nil, this function will call panic.
|
||||
func ContextWithTrace(ctx context.Context, trace model.Trace) context.Context {
|
||||
runtimex.PanicIfTrue(trace == nil, "netxlite.WithTrace passed a nil trace")
|
||||
return context.WithValue(ctx, traceKey{}, trace)
|
||||
}
|
||||
|
||||
// traceOrDefault takes in input a trace and returns in output the
|
||||
// given trace, if not nil, or a default trace implementation.
|
||||
func traceOrDefault(trace model.Trace) model.Trace {
|
||||
if trace != nil {
|
||||
return trace
|
||||
}
|
||||
return &traceDefault{}
|
||||
}
|
||||
|
||||
// traceDefault is a default model.Trace implementation where each method is a no-op.
|
||||
type traceDefault struct{}
|
||||
|
||||
var _ model.Trace = &traceDefault{}
|
||||
|
||||
// TimeNow implements model.Trace.TimeNow
|
||||
func (*traceDefault) TimeNow() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
// OnConnectDone implements model.Trace.OnConnectDone.
|
||||
func (*traceDefault) OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
// nothing
|
||||
}
|
||||
|
||||
// OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart.
|
||||
func (*traceDefault) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) {
|
||||
// nothing
|
||||
}
|
||||
|
||||
// OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone.
|
||||
func (*traceDefault) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config,
|
||||
state tls.ConnectionState, err error, finished time.Time) {
|
||||
// nothing
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package netxlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestContextTraceOrDefault(t *testing.T) {
|
||||
t.Run("without a configured trace we get a default", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := ContextTraceOrDefault(ctx)
|
||||
_ = tx.(*traceDefault) // panic if cannot cast
|
||||
})
|
||||
|
||||
t.Run("with a configured trace we get the expected trace", func(t *testing.T) {
|
||||
realTrace := &mocks.Trace{}
|
||||
ctx := ContextWithTrace(context.Background(), realTrace)
|
||||
tx := ContextTraceOrDefault(ctx)
|
||||
if tx != realTrace {
|
||||
t.Fatal("not the trace we expected")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextWithTrace(t *testing.T) {
|
||||
t.Run("panics if passed a nil trace", func(t *testing.T) {
|
||||
var called bool
|
||||
func() {
|
||||
defer func() {
|
||||
called = (recover() != nil)
|
||||
}()
|
||||
_ = ContextWithTrace(context.Background(), nil)
|
||||
}()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -19,8 +19,7 @@ func TestNewTLSHandshakerUTLS(t *testing.T) {
|
||||
if logger.DebugLogger != log.Log {
|
||||
t.Fatal("invalid logger")
|
||||
}
|
||||
errWrapper := logger.TLSHandshaker.(*tlsHandshakerErrWrapper)
|
||||
configurable := errWrapper.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
configurable := logger.TLSHandshaker.(*tlsHandshakerConfigurable)
|
||||
if configurable.NewConn == nil {
|
||||
t.Fatal("expected non-nil NewConn")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user