refactor: move tls handshaker to netxlite (#400)

Part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
Simone Basso
2021-06-25 11:07:26 +02:00
committed by GitHub
parent b8428b302f
commit 6b7d270bda
15 changed files with 182 additions and 172 deletions
+9 -9
View File
@@ -15,25 +15,25 @@ type Resolver interface {
// ResolverSystem is the system resolver.
type ResolverSystem struct{}
var _ Resolver = ResolverSystem{}
var _ Resolver = &ResolverSystem{}
// LookupHost implements Resolver.LookupHost.
func (r ResolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
func (r *ResolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
return net.DefaultResolver.LookupHost(ctx, hostname)
}
// Network implements Resolver.Network.
func (r ResolverSystem) Network() string {
func (r *ResolverSystem) Network() string {
return "system"
}
// Address implements Resolver.Address.
func (r ResolverSystem) Address() string {
func (r *ResolverSystem) Address() string {
return ""
}
// DefaultResolver is the resolver we use by default.
var DefaultResolver = ResolverSystem{}
var DefaultResolver = &ResolverSystem{}
// ResolverLogger is a resolver that emits events
type ResolverLogger struct {
@@ -41,10 +41,10 @@ type ResolverLogger struct {
Logger Logger
}
var _ Resolver = ResolverLogger{}
var _ Resolver = &ResolverLogger{}
// LookupHost returns the IP addresses of a host
func (r ResolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
func (r *ResolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
r.Logger.Debugf("resolve %s...", hostname)
start := time.Now()
addrs, err := r.Resolver.LookupHost(ctx, hostname)
@@ -62,7 +62,7 @@ type resolverNetworker interface {
}
// Network implements Resolver.Network.
func (r ResolverLogger) Network() string {
func (r *ResolverLogger) Network() string {
if rn, ok := r.Resolver.(resolverNetworker); ok {
return rn.Network()
}
@@ -74,7 +74,7 @@ type resolverAddresser interface {
}
// Address implements Resolver.Address.
func (r ResolverLogger) Address() string {
func (r *ResolverLogger) Address() string {
if ra, ok := r.Resolver.(resolverAddresser); ok {
return ra.Address()
}
+46
View File
@@ -0,0 +1,46 @@
package netxlite
import (
"context"
"crypto/tls"
"net"
"time"
)
// TLSHandshaker is the generic TLS handshaker.
type TLSHandshaker interface {
// Handshake creates a new TLS connection from the given connection and
// the given config. This function DOES NOT take ownership of the connection
// and it's your responsibility to close it on failure.
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
net.Conn, tls.ConnectionState, error)
}
// TLSHandshakerStdlib is the stdlib's TLS handshaker.
type TLSHandshakerStdlib struct {
// Timeout is the timeout imposed on the TLS handshake. If zero
// or negative, we will use default timeout of 10 seconds.
Timeout time.Duration
}
var _ TLSHandshaker = &TLSHandshakerStdlib{}
// Handshake implements Handshaker.Handshake
func (h *TLSHandshakerStdlib) Handshake(
ctx context.Context, conn net.Conn, config *tls.Config,
) (net.Conn, tls.ConnectionState, error) {
timeout := h.Timeout
if timeout <= 0 {
timeout = 10 * time.Second
}
defer conn.SetDeadline(time.Time{})
conn.SetDeadline(time.Now().Add(timeout))
tlsconn := tls.Client(conn, config)
if err := tlsconn.Handshake(); err != nil {
return nil, tls.ConnectionState{}, err
}
return tlsconn, tlsconn.ConnectionState(), nil
}
// DefaultTLSHandshaker is the default TLS handshaker.
var DefaultTLSHandshaker = &TLSHandshakerStdlib{}
+81
View File
@@ -0,0 +1,81 @@
package netxlite
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/netxmocks"
)
func TestTLSHandshakerStdlibWithError(t *testing.T) {
var times []time.Time
h := &TLSHandshakerStdlib{}
tcpConn := &netxmocks.Conn{
MockWrite: func(b []byte) (int, error) {
return 0, io.EOF
},
MockSetDeadline: func(t time.Time) error {
times = append(times, t)
return nil
},
}
ctx := context.Background()
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
ServerName: "x.org",
})
if err != io.EOF {
t.Fatal("not the error that we expected")
}
if conn != nil {
t.Fatal("expected nil con here")
}
if len(times) != 2 {
t.Fatal("expected two time entries")
}
if !times[0].After(time.Now()) {
t.Fatal("timeout not in the future")
}
if !times[1].IsZero() {
t.Fatal("did not clear timeout on exit")
}
}
func TestTLSHandshakerStdlibSuccess(t *testing.T) {
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
})
srvr := httptest.NewTLSServer(handler)
defer srvr.Close()
URL, err := url.Parse(srvr.URL)
if err != nil {
t.Fatal(err)
}
conn, err := net.Dial("tcp", URL.Host)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
handshaker := &TLSHandshakerStdlib{}
ctx := context.Background()
config := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
ServerName: URL.Hostname(),
}
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
if err != nil {
t.Fatal(err)
}
defer tlsConn.Close()
if connState.Version != tls.VersionTLS13 {
t.Fatal("unexpected TLS version")
}
}