refactor: move tls handshaker to netxlite (#400)
Part of https://github.com/ooni/probe/issues/1505
This commit is contained in:
@@ -15,25 +15,25 @@ type Resolver interface {
|
||||
// ResolverSystem is the system resolver.
|
||||
type ResolverSystem struct{}
|
||||
|
||||
var _ Resolver = ResolverSystem{}
|
||||
var _ Resolver = &ResolverSystem{}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r ResolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
func (r *ResolverSystem) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
return net.DefaultResolver.LookupHost(ctx, hostname)
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r ResolverSystem) Network() string {
|
||||
func (r *ResolverSystem) Network() string {
|
||||
return "system"
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address.
|
||||
func (r ResolverSystem) Address() string {
|
||||
func (r *ResolverSystem) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// DefaultResolver is the resolver we use by default.
|
||||
var DefaultResolver = ResolverSystem{}
|
||||
var DefaultResolver = &ResolverSystem{}
|
||||
|
||||
// ResolverLogger is a resolver that emits events
|
||||
type ResolverLogger struct {
|
||||
@@ -41,10 +41,10 @@ type ResolverLogger struct {
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
var _ Resolver = ResolverLogger{}
|
||||
var _ Resolver = &ResolverLogger{}
|
||||
|
||||
// LookupHost returns the IP addresses of a host
|
||||
func (r ResolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
func (r *ResolverLogger) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
r.Logger.Debugf("resolve %s...", hostname)
|
||||
start := time.Now()
|
||||
addrs, err := r.Resolver.LookupHost(ctx, hostname)
|
||||
@@ -62,7 +62,7 @@ type resolverNetworker interface {
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network.
|
||||
func (r ResolverLogger) Network() string {
|
||||
func (r *ResolverLogger) Network() string {
|
||||
if rn, ok := r.Resolver.(resolverNetworker); ok {
|
||||
return rn.Network()
|
||||
}
|
||||
@@ -74,7 +74,7 @@ type resolverAddresser interface {
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address.
|
||||
func (r ResolverLogger) Address() string {
|
||||
func (r *ResolverLogger) Address() string {
|
||||
if ra, ok := r.Resolver.(resolverAddresser); ok {
|
||||
return ra.Address()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package netxlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TLSHandshaker is the generic TLS handshaker.
|
||||
type TLSHandshaker interface {
|
||||
// Handshake creates a new TLS connection from the given connection and
|
||||
// the given config. This function DOES NOT take ownership of the connection
|
||||
// and it's your responsibility to close it on failure.
|
||||
Handshake(ctx context.Context, conn net.Conn, config *tls.Config) (
|
||||
net.Conn, tls.ConnectionState, error)
|
||||
}
|
||||
|
||||
// TLSHandshakerStdlib is the stdlib's TLS handshaker.
|
||||
type TLSHandshakerStdlib struct {
|
||||
// Timeout is the timeout imposed on the TLS handshake. If zero
|
||||
// or negative, we will use default timeout of 10 seconds.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
var _ TLSHandshaker = &TLSHandshakerStdlib{}
|
||||
|
||||
// Handshake implements Handshaker.Handshake
|
||||
func (h *TLSHandshakerStdlib) Handshake(
|
||||
ctx context.Context, conn net.Conn, config *tls.Config,
|
||||
) (net.Conn, tls.ConnectionState, error) {
|
||||
timeout := h.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
defer conn.SetDeadline(time.Time{})
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
tlsconn := tls.Client(conn, config)
|
||||
if err := tlsconn.Handshake(); err != nil {
|
||||
return nil, tls.ConnectionState{}, err
|
||||
}
|
||||
return tlsconn, tlsconn.ConnectionState(), nil
|
||||
}
|
||||
|
||||
// DefaultTLSHandshaker is the default TLS handshaker.
|
||||
var DefaultTLSHandshaker = &TLSHandshakerStdlib{}
|
||||
@@ -0,0 +1,81 @@
|
||||
package netxlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/netxmocks"
|
||||
)
|
||||
|
||||
func TestTLSHandshakerStdlibWithError(t *testing.T) {
|
||||
var times []time.Time
|
||||
h := &TLSHandshakerStdlib{}
|
||||
tcpConn := &netxmocks.Conn{
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return 0, io.EOF
|
||||
},
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
times = append(times, t)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn, _, err := h.Handshake(ctx, tcpConn, &tls.Config{
|
||||
ServerName: "x.org",
|
||||
})
|
||||
if err != io.EOF {
|
||||
t.Fatal("not the error that we expected")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil con here")
|
||||
}
|
||||
if len(times) != 2 {
|
||||
t.Fatal("expected two time entries")
|
||||
}
|
||||
if !times[0].After(time.Now()) {
|
||||
t.Fatal("timeout not in the future")
|
||||
}
|
||||
if !times[1].IsZero() {
|
||||
t.Fatal("did not clear timeout on exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHandshakerStdlibSuccess(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
})
|
||||
srvr := httptest.NewTLSServer(handler)
|
||||
defer srvr.Close()
|
||||
URL, err := url.Parse(srvr.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := net.Dial("tcp", URL.Host)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
handshaker := &TLSHandshakerStdlib{}
|
||||
ctx := context.Background()
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
ServerName: URL.Hostname(),
|
||||
}
|
||||
tlsConn, connState, err := handshaker.Handshake(ctx, conn, config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tlsConn.Close()
|
||||
if connState.Version != tls.VersionTLS13 {
|
||||
t.Fatal("unexpected TLS version")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user