refactor(netx): start moving tls-specific code inside the tlsx pkg (#363)

* refactor(netx): move cert pool code inside tlsx

* refactor(netx): move more tls code inside tlsx
This commit is contained in:
Simone Basso 2021-06-08 15:39:25 +02:00 committed by GitHub
parent 0317420398
commit c553afdbd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 120 additions and 119 deletions

View File

@ -40,7 +40,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsx"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
// Logger is the logger assumed by this package
@ -112,12 +111,8 @@ type tlsHandshaker interface {
}
// NewDefaultCertPool returns a copy of the default x509
// certificate pool. This function panics on failure.
func NewDefaultCertPool() *x509.CertPool {
pool, err := tlsx.CACerts()
runtimex.PanicOnError(err, "tlsx.CACerts() failed")
return pool
}
// certificate pool that we bundle from Mozilla.
var NewDefaultCertPool = tlsx.NewDefaultCertPool
var defaultCertPool *x509.CertPool = NewDefaultCertPool()
@ -316,31 +311,11 @@ func NewDNSClient(config Config, URL string) (DNSClient, error) {
// ErrInvalidTLSVersion indicates that you passed us a string
// that does not represent a valid TLS version.
var ErrInvalidTLSVersion = errors.New("invalid TLS version")
var ErrInvalidTLSVersion = tlsx.ErrInvalidTLSVersion
// ConfigureTLSVersion configures the correct TLS version into
// the specified *tls.Config or returns an error.
func ConfigureTLSVersion(config *tls.Config, version string) error {
switch version {
case "TLSv1.3":
config.MinVersion = tls.VersionTLS13
config.MaxVersion = tls.VersionTLS13
case "TLSv1.2":
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
case "TLSv1.1":
config.MinVersion = tls.VersionTLS11
config.MaxVersion = tls.VersionTLS11
case "TLSv1.0", "TLSv1":
config.MinVersion = tls.VersionTLS10
config.MaxVersion = tls.VersionTLS10
case "":
// nothing
default:
return ErrInvalidTLSVersion
}
return nil
}
var ConfigureTLSVersion = tlsx.ConfigureTLSVersion
// NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient,
// with the option to override the default Hostname and SNI.

View File

@ -1193,70 +1193,3 @@ func TestNewDNSCLientWithInvalidTLSVersion(t *testing.T) {
t.Fatalf("not the error we expected: %+v", err)
}
}
func TestConfigureTLSVersion(t *testing.T) {
tests := []struct {
name string
version string
wantErr error
versionMin int
versionMax int
}{{
name: "with TLSv1.3",
version: "TLSv1.3",
wantErr: nil,
versionMin: tls.VersionTLS13,
versionMax: tls.VersionTLS13,
}, {
name: "with TLSv1.2",
version: "TLSv1.2",
wantErr: nil,
versionMin: tls.VersionTLS12,
versionMax: tls.VersionTLS12,
}, {
name: "with TLSv1.1",
version: "TLSv1.1",
wantErr: nil,
versionMin: tls.VersionTLS11,
versionMax: tls.VersionTLS11,
}, {
name: "with TLSv1.0",
version: "TLSv1.0",
wantErr: nil,
versionMin: tls.VersionTLS10,
versionMax: tls.VersionTLS10,
}, {
name: "with TLSv1",
version: "TLSv1",
wantErr: nil,
versionMin: tls.VersionTLS10,
versionMax: tls.VersionTLS10,
}, {
name: "with default",
version: "",
wantErr: nil,
versionMin: 0,
versionMax: 0,
}, {
name: "with invalid version",
version: "TLSv999",
wantErr: netx.ErrInvalidTLSVersion,
versionMin: 0,
versionMax: 0,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := new(tls.Config)
err := netx.ConfigureTLSVersion(conf, tt.version)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("not the error we expected: %+v", err)
}
if conf.MinVersion != uint16(tt.versionMin) {
t.Fatalf("not the min version we expected: %+v", conf.MinVersion)
}
if conf.MaxVersion != uint16(tt.versionMax) {
t.Fatalf("not the max version we expected: %+v", conf.MaxVersion)
}
})
}
}

View File

@ -1,13 +1,11 @@
// Code generated by go generate; DO NOT EDIT.
// 2021-06-08 13:03:08.852763 +0200 CEST m=+2.563240667
// 2021-06-08 14:36:18.474117 +0200 CEST m=+0.459588210
// https://curl.haxx.se/ca/cacert.pem
package tlsx
//go:generate go run generate.go "https://curl.haxx.se/ca/cacert.pem"
import "crypto/x509"
const pemcerts string = `
##
## Bundle of CA Root Certificates
@ -3149,12 +3147,3 @@ CAezNIm8BZ/3Hobui3A=
-----END CERTIFICATE-----
`
// CACerts builds an X.509 certificate pool containing the
// certificate bundle from https://curl.haxx.se/ca/cacert.pem fetch on 2021-06-08 13:03:08.852763 +0200 CEST m=+2.563240667.
// Returns nil on error along with an appropriate error code.
func CACerts() (*x509.CertPool, error) {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM([]byte(pemcerts))
return pool, nil
}

View File

@ -30,20 +30,9 @@ package tlsx
//go:generate go run generate.go "{{ .URL }}"
import "crypto/x509"
const pemcerts string = ` + "`" + `
{{ .Bundle }}
` + "`" + `
// CACerts builds an X.509 certificate pool containing the
// certificate bundle from {{ .URL }} fetch on {{ .Timestamp }}.
// Returns nil on error along with an appropriate error code.
func CACerts() (*x509.CertPool, error) {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM([]byte(pemcerts))
return pool, nil
}
`))
func main() {

View File

@ -3,6 +3,8 @@ package tlsx
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
)
@ -61,3 +63,41 @@ func CipherSuiteString(value uint16) string {
}
return fmt.Sprintf("TLS_CIPHER_SUITE_UNKNOWN_%d", value)
}
// NewDefaultCertPool returns a copy of the default x509
// certificate pool that we bundle from Mozilla.
func NewDefaultCertPool() *x509.CertPool {
pool := x509.NewCertPool()
// Assumption: AppendCertsFromPEM cannot fail because we
// run this function already in the generate.go file
pool.AppendCertsFromPEM([]byte(pemcerts))
return pool
}
// ErrInvalidTLSVersion indicates that you passed us a string
// that does not represent a valid TLS version.
var ErrInvalidTLSVersion = errors.New("invalid TLS version")
// ConfigureTLSVersion configures the correct TLS version into
// the specified *tls.Config or returns an error.
func ConfigureTLSVersion(config *tls.Config, version string) error {
switch version {
case "TLSv1.3":
config.MinVersion = tls.VersionTLS13
config.MaxVersion = tls.VersionTLS13
case "TLSv1.2":
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
case "TLSv1.1":
config.MinVersion = tls.VersionTLS11
config.MaxVersion = tls.VersionTLS11
case "TLSv1.0", "TLSv1":
config.MinVersion = tls.VersionTLS10
config.MaxVersion = tls.VersionTLS10
case "":
// nothing
default:
return ErrInvalidTLSVersion
}
return nil
}

View File

@ -2,6 +2,7 @@ package tlsx
import (
"crypto/tls"
"errors"
"testing"
)
@ -28,3 +29,77 @@ func TestCipherSuite(t *testing.T) {
t.Fatal("not working for zero cipher suite")
}
}
func TestNewDefaultCertPoolWorks(t *testing.T) {
pool := NewDefaultCertPool()
if pool == nil {
t.Fatal("expected non-nil value here")
}
}
func TestConfigureTLSVersion(t *testing.T) {
tests := []struct {
name string
version string
wantErr error
versionMin int
versionMax int
}{{
name: "with TLSv1.3",
version: "TLSv1.3",
wantErr: nil,
versionMin: tls.VersionTLS13,
versionMax: tls.VersionTLS13,
}, {
name: "with TLSv1.2",
version: "TLSv1.2",
wantErr: nil,
versionMin: tls.VersionTLS12,
versionMax: tls.VersionTLS12,
}, {
name: "with TLSv1.1",
version: "TLSv1.1",
wantErr: nil,
versionMin: tls.VersionTLS11,
versionMax: tls.VersionTLS11,
}, {
name: "with TLSv1.0",
version: "TLSv1.0",
wantErr: nil,
versionMin: tls.VersionTLS10,
versionMax: tls.VersionTLS10,
}, {
name: "with TLSv1",
version: "TLSv1",
wantErr: nil,
versionMin: tls.VersionTLS10,
versionMax: tls.VersionTLS10,
}, {
name: "with default",
version: "",
wantErr: nil,
versionMin: 0,
versionMax: 0,
}, {
name: "with invalid version",
version: "TLSv999",
wantErr: ErrInvalidTLSVersion,
versionMin: 0,
versionMax: 0,
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := new(tls.Config)
err := ConfigureTLSVersion(conf, tt.version)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("not the error we expected: %+v", err)
}
if conf.MinVersion != uint16(tt.versionMin) {
t.Fatalf("not the min version we expected: %+v", conf.MinVersion)
}
if conf.MaxVersion != uint16(tt.versionMax) {
t.Fatalf("not the max version we expected: %+v", conf.MaxVersion)
}
})
}
}