refactor: move tls handshaker to netxlite (#400)

Part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
Simone Basso
2021-06-25 11:07:26 +02:00
committed by GitHub
parent b8428b302f
commit 6b7d270bda
15 changed files with 182 additions and 172 deletions
+2 -2
View File
@@ -34,8 +34,8 @@ func newDialManager(ndt7URL string, logger model.Logger, userAgent string) dialM
}
func (mgr dialManager) dialWithTestName(ctx context.Context, testName string) (*websocket.Conn, error) {
var reso resolver.Resolver = netxlite.ResolverSystem{}
reso = netxlite.ResolverLogger{Resolver: reso, Logger: mgr.logger}
var reso resolver.Resolver = &netxlite.ResolverSystem{}
reso = &netxlite.ResolverLogger{Resolver: reso, Logger: mgr.logger}
dlr := dialer.New(&dialer.Config{
ContextByteCounting: true,
Logger: mgr.logger,
+2 -3
View File
@@ -14,6 +14,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/dialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
// Dialer performs measurements while dialing.
@@ -107,9 +108,7 @@ func newTLSDialer(d dialer.Dialer, config *tls.Config) tlsdialer.TLSDialer {
Dialer: d,
TLSHandshaker: tlsdialer.EmitterTLSHandshaker{
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
},
},
}
+1 -1
View File
@@ -159,7 +159,7 @@ func resolverWrapTransport(txp resolver.RoundTripper) resolver.EmitterResolver {
}
func newResolverSystem() resolver.EmitterResolver {
return resolverWrapResolver(netxlite.ResolverSystem{})
return resolverWrapResolver(&netxlite.ResolverSystem{})
}
func newResolverUDP(dialer resolver.Dialer, address string) resolver.EmitterResolver {
+4 -5
View File
@@ -115,7 +115,7 @@ var defaultCertPool *x509.CertPool = tlsx.NewDefaultCertPool()
// NewResolver creates a new resolver from the specified config
func NewResolver(config Config) Resolver {
if config.BaseResolver == nil {
config.BaseResolver = netxlite.ResolverSystem{}
config.BaseResolver = &netxlite.ResolverSystem{}
}
var r Resolver = config.BaseResolver
r = resolver.AddressResolver{Resolver: r}
@@ -134,7 +134,7 @@ func NewResolver(config Config) Resolver {
}
r = resolver.ErrorWrapperResolver{Resolver: r}
if config.Logger != nil {
r = netxlite.ResolverLogger{Logger: config.Logger, Resolver: r}
r = &netxlite.ResolverLogger{Logger: config.Logger, Resolver: r}
}
if config.ResolveSaver != nil {
r = resolver.SaverResolver{Resolver: r, Saver: config.ResolveSaver}
@@ -176,8 +176,7 @@ func NewTLSDialer(config Config) TLSDialer {
if config.Dialer == nil {
config.Dialer = NewDialer(config)
}
var h tlsHandshaker = tlsdialer.SystemTLSHandshaker{}
h = tlsdialer.TimeoutTLSHandshaker{TLSHandshaker: h}
var h tlsHandshaker = &netxlite.TLSHandshakerStdlib{}
h = tlsdialer.ErrorWrapperTLSHandshaker{TLSHandshaker: h}
if config.Logger != nil {
h = tlsdialer.LoggingTLSHandshaker{Logger: config.Logger, TLSHandshaker: h}
@@ -318,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
}
switch resolverURL.Scheme {
case "system":
c.Resolver = netxlite.ResolverSystem{}
c.Resolver = &netxlite.ResolverSystem{}
return c, nil
case "https":
config.TLSConfig.NextProtos = []string{"h2", "http/1.1"}
+16 -40
View File
@@ -32,7 +32,7 @@ func TestNewResolverVanilla(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -82,7 +82,7 @@ func TestNewResolverWithBogonFilter(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -96,7 +96,7 @@ func TestNewResolverWithLogging(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
lr, ok := ir.Resolver.(netxlite.ResolverLogger)
lr, ok := ir.Resolver.(*netxlite.ResolverLogger)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -111,7 +111,7 @@ func TestNewResolverWithLogging(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -141,7 +141,7 @@ func TestNewResolverWithSaver(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -170,7 +170,7 @@ func TestNewResolverWithReadWriteCache(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -204,7 +204,7 @@ func TestNewResolverWithPrefilledReadonlyCache(t *testing.T) {
if !ok {
t.Fatal("not the resolver we expected")
}
_, ok = ar.Resolver.(netxlite.ResolverSystem)
_, ok = ar.Resolver.(*netxlite.ResolverSystem)
if !ok {
t.Fatal("not the resolver we expected")
}
@@ -235,11 +235,7 @@ func TestNewTLSDialerVanilla(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@@ -268,11 +264,7 @@ func TestNewTLSDialerWithConfig(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@@ -311,11 +303,7 @@ func TestNewTLSDialerWithLogging(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@@ -355,11 +343,7 @@ func TestNewTLSDialerWithSaver(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@@ -392,11 +376,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndConfig(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@@ -431,11 +411,7 @@ func TestNewTLSDialerWithNoTLSVerifyAndNoConfig(t *testing.T) {
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
tth, ok := ewth.TLSHandshaker.(tlsdialer.TimeoutTLSHandshaker)
if !ok {
t.Fatal("not the TLSHandshaker we expected")
}
if _, ok := tth.TLSHandshaker.(tlsdialer.SystemTLSHandshaker); !ok {
if _, ok := ewth.TLSHandshaker.(*netxlite.TLSHandshakerStdlib); !ok {
t.Fatal("not the TLSHandshaker we expected")
}
}
@@ -472,7 +448,7 @@ func TestNewWithTLSDialer(t *testing.T) {
tlsDialer := tlsdialer.TLSDialer{
Config: new(tls.Config),
Dialer: netx.FakeDialer{Err: expected},
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
}
txp := netx.NewHTTPTransport(netx.Config{
TLSDialer: tlsDialer,
@@ -598,7 +574,7 @@ func TestNewDNSClientSystemResolver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, ok := dnsclient.Resolver.(netxlite.ResolverSystem); !ok {
if _, ok := dnsclient.Resolver.(*netxlite.ResolverSystem); !ok {
t.Fatal("not the resolver we expected")
}
dnsclient.CloseIdleConnections()
@@ -610,7 +586,7 @@ func TestNewDNSClientEmpty(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, ok := dnsclient.Resolver.(netxlite.ResolverSystem); !ok {
if _, ok := dnsclient.Resolver.(*netxlite.ResolverSystem); !ok {
t.Fatal("not the resolver we expected")
}
dnsclient.CloseIdleConnections()
+1 -1
View File
@@ -11,7 +11,7 @@ import (
func TestChainLookupHost(t *testing.T) {
r := resolver.ChainResolver{
Primary: resolver.NewFakeResolverThatFails(),
Secondary: netxlite.ResolverSystem{},
Secondary: &netxlite.ResolverSystem{},
}
if r.Address() != "" {
t.Fatal("invalid address")
@@ -19,7 +19,7 @@ func testresolverquick(t *testing.T, reso resolver.Resolver) {
if testing.Short() {
t.Skip("skip test in short mode")
}
reso = netxlite.ResolverLogger{Logger: log.Log, Resolver: reso}
reso = &netxlite.ResolverLogger{Logger: log.Log, Resolver: reso}
addrs, err := reso.LookupHost(context.Background(), "dns.google.com")
if err != nil {
t.Fatal(err)
@@ -45,7 +45,7 @@ func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
t.Skip("skip test in short mode")
}
reso = resolver.IDNAResolver{
netxlite.ResolverLogger{Logger: log.Log, Resolver: reso},
&netxlite.ResolverLogger{Logger: log.Log, Resolver: reso},
}
addrs, err := reso.LookupHost(context.Background(), "яндекс.рф")
if err != nil {
@@ -57,7 +57,7 @@ func testresolverquickidna(t *testing.T, reso resolver.Resolver) {
}
func TestNewResolverSystem(t *testing.T) {
reso := netxlite.ResolverSystem{}
reso := &netxlite.ResolverSystem{}
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
@@ -8,6 +8,7 @@ import (
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestTLSDialerSuccess(t *testing.T) {
@@ -17,7 +18,7 @@ func TestTLSDialerSuccess(t *testing.T) {
log.SetLevel(log.DebugLevel)
dialer := tlsdialer.TLSDialer{Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.LoggingTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Logger: log.Log,
},
}
+7 -6
View File
@@ -12,6 +12,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
@@ -25,7 +26,7 @@ func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) {
Config: &tls.Config{NextProtos: nextprotos},
Dialer: dialer.New(&dialer.Config{ReadWriteSaver: saver}, &net.Resolver{}),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Saver: saver,
},
}
@@ -118,7 +119,7 @@ func TestSaverTLSHandshakerSuccess(t *testing.T) {
Config: &tls.Config{NextProtos: nextprotos},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Saver: saver,
},
}
@@ -183,7 +184,7 @@ func TestSaverTLSHandshakerHostnameError(t *testing.T) {
tlsdlr := tlsdialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Saver: saver,
},
}
@@ -216,7 +217,7 @@ func TestSaverTLSHandshakerInvalidCertError(t *testing.T) {
tlsdlr := tlsdialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Saver: saver,
},
}
@@ -249,7 +250,7 @@ func TestSaverTLSHandshakerAuthorityError(t *testing.T) {
tlsdlr := tlsdialer.TLSDialer{
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Saver: saver,
},
}
@@ -283,7 +284,7 @@ func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) {
Config: &tls.Config{InsecureSkipVerify: true},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SaverTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
Saver: saver,
},
}
-36
View File
@@ -22,42 +22,6 @@ type TLSHandshaker interface {
net.Conn, tls.ConnectionState, error)
}
// SystemTLSHandshaker is the system TLS handshaker.
type SystemTLSHandshaker struct{}
// Handshake implements Handshaker.Handshake
func (h SystemTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
tlsconn := tls.Client(conn, config)
if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
}
// TimeoutTLSHandshaker is a TLSHandshaker with timeout
type TimeoutTLSHandshaker struct {
TLSHandshaker
HandshakeTimeout time.Duration // default: 10 second
}
// Handshake implements Handshaker.Handshake
func (h TimeoutTLSHandshaker) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
timeout := 10 * time.Second
if h.HandshakeTimeout != 0 {
timeout = h.HandshakeTimeout
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, tls.ConnectionState{}, err
}
tlsconn, connstate, err := h.TLSHandshaker.Handshake(ctx, conn, config)
conn.SetDeadline(time.Time{})
return tlsconn, connstate, err
}
// ErrorWrapperTLSHandshaker wraps the returned error to be an OONI error
type ErrorWrapperTLSHandshaker struct {
TLSHandshaker
+7 -64
View File
@@ -13,10 +13,11 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/errorx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)
func TestSystemTLSHandshakerEOFError(t *testing.T) {
h := tlsdialer.SystemTLSHandshaker{}
h := &netxlite.TLSHandshakerStdlib{}
conn, _, err := h.Handshake(context.Background(), tlsdialer.EOFConn{}, &tls.Config{
ServerName: "x.org",
})
@@ -28,63 +29,6 @@ func TestSystemTLSHandshakerEOFError(t *testing.T) {
}
}
func TestTimeoutTLSHandshakerSetDeadlineError(t *testing.T) {
h := tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
expected := errors.New("mocked error")
conn, _, err := h.Handshake(
context.Background(), &tlsdialer.FakeConn{SetDeadlineError: expected},
new(tls.Config))
if !errors.Is(err, expected) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerEOFError(t *testing.T) {
h := tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
conn, _, err := h.Handshake(
context.Background(), tlsdialer.EOFConn{}, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
}
func TestTimeoutTLSHandshakerCallsSetDeadline(t *testing.T) {
h := tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 200 * time.Millisecond,
}
underlying := &SetDeadlineConn{}
conn, _, err := h.Handshake(
context.Background(), underlying, &tls.Config{ServerName: "x.org"})
if !errors.Is(err, io.EOF) {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
if len(underlying.deadlines) != 2 {
t.Fatal("SetDeadline not called twice")
}
if underlying.deadlines[0].Before(time.Now()) {
t.Fatal("the first SetDeadline call was incorrect")
}
if !underlying.deadlines[1].IsZero() {
t.Fatal("the second SetDeadline call was incorrect")
}
}
type SetDeadlineConn struct {
tlsdialer.EOFConn
deadlines []time.Time
@@ -179,7 +123,7 @@ func TestTLSDialerFailureDialing(t *testing.T) {
}
func TestTLSDialerFailureHandshaking(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}}
rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}}
dialer := tlsdialer.TLSDialer{
Dialer: tlsdialer.EOFConnDialer{},
TLSHandshaker: rec,
@@ -198,7 +142,7 @@ func TestTLSDialerFailureHandshaking(t *testing.T) {
}
func TestTLSDialerFailureHandshakingOverrideSNI(t *testing.T) {
rec := &RecorderTLSHandshaker{TLSHandshaker: tlsdialer.SystemTLSHandshaker{}}
rec := &RecorderTLSHandshaker{TLSHandshaker: &netxlite.TLSHandshakerStdlib{}}
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{
ServerName: "x.org",
@@ -235,7 +179,7 @@ func TestDialTLSContextGood(t *testing.T) {
dialer := tlsdialer.TLSDialer{
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
TLSHandshaker: &netxlite.TLSHandshakerStdlib{},
}
conn, err := dialer.DialTLSContext(context.Background(), "tcp", "google.com:443")
if err != nil {
@@ -252,9 +196,8 @@ func TestDialTLSContextTimeout(t *testing.T) {
Config: &tls.Config{ServerName: "google.com"},
Dialer: new(net.Dialer),
TLSHandshaker: tlsdialer.ErrorWrapperTLSHandshaker{
TLSHandshaker: tlsdialer.TimeoutTLSHandshaker{
TLSHandshaker: tlsdialer.SystemTLSHandshaker{},
HandshakeTimeout: 10 * time.Microsecond,
TLSHandshaker: &netxlite.TLSHandshakerStdlib{
Timeout: 10 * time.Microsecond,
},
},
}