refactor(netx): start moving tls-specific code inside the tlsx pkg (#363)
* refactor(netx): move cert pool code inside tlsx * refactor(netx): move more tls code inside tlsx
This commit is contained in:
parent
0317420398
commit
c553afdbd5
|
@ -40,7 +40,6 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsdialer"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsx"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/tlsx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Logger is the logger assumed by this package
|
// Logger is the logger assumed by this package
|
||||||
|
@ -112,12 +111,8 @@ type tlsHandshaker interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultCertPool returns a copy of the default x509
|
// NewDefaultCertPool returns a copy of the default x509
|
||||||
// certificate pool. This function panics on failure.
|
// certificate pool that we bundle from Mozilla.
|
||||||
func NewDefaultCertPool() *x509.CertPool {
|
var NewDefaultCertPool = tlsx.NewDefaultCertPool
|
||||||
pool, err := tlsx.CACerts()
|
|
||||||
runtimex.PanicOnError(err, "tlsx.CACerts() failed")
|
|
||||||
return pool
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultCertPool *x509.CertPool = NewDefaultCertPool()
|
var defaultCertPool *x509.CertPool = NewDefaultCertPool()
|
||||||
|
|
||||||
|
@ -316,31 +311,11 @@ func NewDNSClient(config Config, URL string) (DNSClient, error) {
|
||||||
|
|
||||||
// ErrInvalidTLSVersion indicates that you passed us a string
|
// ErrInvalidTLSVersion indicates that you passed us a string
|
||||||
// that does not represent a valid TLS version.
|
// that does not represent a valid TLS version.
|
||||||
var ErrInvalidTLSVersion = errors.New("invalid TLS version")
|
var ErrInvalidTLSVersion = tlsx.ErrInvalidTLSVersion
|
||||||
|
|
||||||
// ConfigureTLSVersion configures the correct TLS version into
|
// ConfigureTLSVersion configures the correct TLS version into
|
||||||
// the specified *tls.Config or returns an error.
|
// the specified *tls.Config or returns an error.
|
||||||
func ConfigureTLSVersion(config *tls.Config, version string) error {
|
var ConfigureTLSVersion = tlsx.ConfigureTLSVersion
|
||||||
switch version {
|
|
||||||
case "TLSv1.3":
|
|
||||||
config.MinVersion = tls.VersionTLS13
|
|
||||||
config.MaxVersion = tls.VersionTLS13
|
|
||||||
case "TLSv1.2":
|
|
||||||
config.MinVersion = tls.VersionTLS12
|
|
||||||
config.MaxVersion = tls.VersionTLS12
|
|
||||||
case "TLSv1.1":
|
|
||||||
config.MinVersion = tls.VersionTLS11
|
|
||||||
config.MaxVersion = tls.VersionTLS11
|
|
||||||
case "TLSv1.0", "TLSv1":
|
|
||||||
config.MinVersion = tls.VersionTLS10
|
|
||||||
config.MaxVersion = tls.VersionTLS10
|
|
||||||
case "":
|
|
||||||
// nothing
|
|
||||||
default:
|
|
||||||
return ErrInvalidTLSVersion
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient,
|
// NewDNSClientWithOverrides creates a new DNS client, similar to NewDNSClient,
|
||||||
// with the option to override the default Hostname and SNI.
|
// with the option to override the default Hostname and SNI.
|
||||||
|
|
|
@ -1193,70 +1193,3 @@ func TestNewDNSCLientWithInvalidTLSVersion(t *testing.T) {
|
||||||
t.Fatalf("not the error we expected: %+v", err)
|
t.Fatalf("not the error we expected: %+v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigureTLSVersion(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
wantErr error
|
|
||||||
versionMin int
|
|
||||||
versionMax int
|
|
||||||
}{{
|
|
||||||
name: "with TLSv1.3",
|
|
||||||
version: "TLSv1.3",
|
|
||||||
wantErr: nil,
|
|
||||||
versionMin: tls.VersionTLS13,
|
|
||||||
versionMax: tls.VersionTLS13,
|
|
||||||
}, {
|
|
||||||
name: "with TLSv1.2",
|
|
||||||
version: "TLSv1.2",
|
|
||||||
wantErr: nil,
|
|
||||||
versionMin: tls.VersionTLS12,
|
|
||||||
versionMax: tls.VersionTLS12,
|
|
||||||
}, {
|
|
||||||
name: "with TLSv1.1",
|
|
||||||
version: "TLSv1.1",
|
|
||||||
wantErr: nil,
|
|
||||||
versionMin: tls.VersionTLS11,
|
|
||||||
versionMax: tls.VersionTLS11,
|
|
||||||
}, {
|
|
||||||
name: "with TLSv1.0",
|
|
||||||
version: "TLSv1.0",
|
|
||||||
wantErr: nil,
|
|
||||||
versionMin: tls.VersionTLS10,
|
|
||||||
versionMax: tls.VersionTLS10,
|
|
||||||
}, {
|
|
||||||
name: "with TLSv1",
|
|
||||||
version: "TLSv1",
|
|
||||||
wantErr: nil,
|
|
||||||
versionMin: tls.VersionTLS10,
|
|
||||||
versionMax: tls.VersionTLS10,
|
|
||||||
}, {
|
|
||||||
name: "with default",
|
|
||||||
version: "",
|
|
||||||
wantErr: nil,
|
|
||||||
versionMin: 0,
|
|
||||||
versionMax: 0,
|
|
||||||
}, {
|
|
||||||
name: "with invalid version",
|
|
||||||
version: "TLSv999",
|
|
||||||
wantErr: netx.ErrInvalidTLSVersion,
|
|
||||||
versionMin: 0,
|
|
||||||
versionMax: 0,
|
|
||||||
}}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
conf := new(tls.Config)
|
|
||||||
err := netx.ConfigureTLSVersion(conf, tt.version)
|
|
||||||
if !errors.Is(err, tt.wantErr) {
|
|
||||||
t.Fatalf("not the error we expected: %+v", err)
|
|
||||||
}
|
|
||||||
if conf.MinVersion != uint16(tt.versionMin) {
|
|
||||||
t.Fatalf("not the min version we expected: %+v", conf.MinVersion)
|
|
||||||
}
|
|
||||||
if conf.MaxVersion != uint16(tt.versionMax) {
|
|
||||||
t.Fatalf("not the max version we expected: %+v", conf.MaxVersion)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
// Code generated by go generate; DO NOT EDIT.
|
// Code generated by go generate; DO NOT EDIT.
|
||||||
// 2021-06-08 13:03:08.852763 +0200 CEST m=+2.563240667
|
// 2021-06-08 14:36:18.474117 +0200 CEST m=+0.459588210
|
||||||
// https://curl.haxx.se/ca/cacert.pem
|
// https://curl.haxx.se/ca/cacert.pem
|
||||||
|
|
||||||
package tlsx
|
package tlsx
|
||||||
|
|
||||||
//go:generate go run generate.go "https://curl.haxx.se/ca/cacert.pem"
|
//go:generate go run generate.go "https://curl.haxx.se/ca/cacert.pem"
|
||||||
|
|
||||||
import "crypto/x509"
|
|
||||||
|
|
||||||
const pemcerts string = `
|
const pemcerts string = `
|
||||||
##
|
##
|
||||||
## Bundle of CA Root Certificates
|
## Bundle of CA Root Certificates
|
||||||
|
@ -3149,12 +3147,3 @@ CAezNIm8BZ/3Hobui3A=
|
||||||
-----END CERTIFICATE-----
|
-----END CERTIFICATE-----
|
||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
// CACerts builds an X.509 certificate pool containing the
|
|
||||||
// certificate bundle from https://curl.haxx.se/ca/cacert.pem fetch on 2021-06-08 13:03:08.852763 +0200 CEST m=+2.563240667.
|
|
||||||
// Returns nil on error along with an appropriate error code.
|
|
||||||
func CACerts() (*x509.CertPool, error) {
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AppendCertsFromPEM([]byte(pemcerts))
|
|
||||||
return pool, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -30,20 +30,9 @@ package tlsx
|
||||||
|
|
||||||
//go:generate go run generate.go "{{ .URL }}"
|
//go:generate go run generate.go "{{ .URL }}"
|
||||||
|
|
||||||
import "crypto/x509"
|
|
||||||
|
|
||||||
const pemcerts string = ` + "`" + `
|
const pemcerts string = ` + "`" + `
|
||||||
{{ .Bundle }}
|
{{ .Bundle }}
|
||||||
` + "`" + `
|
` + "`" + `
|
||||||
|
|
||||||
// CACerts builds an X.509 certificate pool containing the
|
|
||||||
// certificate bundle from {{ .URL }} fetch on {{ .Timestamp }}.
|
|
||||||
// Returns nil on error along with an appropriate error code.
|
|
||||||
func CACerts() (*x509.CertPool, error) {
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AppendCertsFromPEM([]byte(pemcerts))
|
|
||||||
return pool, nil
|
|
||||||
}
|
|
||||||
`))
|
`))
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
@ -3,6 +3,8 @@ package tlsx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -61,3 +63,41 @@ func CipherSuiteString(value uint16) string {
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("TLS_CIPHER_SUITE_UNKNOWN_%d", value)
|
return fmt.Sprintf("TLS_CIPHER_SUITE_UNKNOWN_%d", value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewDefaultCertPool returns a copy of the default x509
|
||||||
|
// certificate pool that we bundle from Mozilla.
|
||||||
|
func NewDefaultCertPool() *x509.CertPool {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
// Assumption: AppendCertsFromPEM cannot fail because we
|
||||||
|
// run this function already in the generate.go file
|
||||||
|
pool.AppendCertsFromPEM([]byte(pemcerts))
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrInvalidTLSVersion indicates that you passed us a string
|
||||||
|
// that does not represent a valid TLS version.
|
||||||
|
var ErrInvalidTLSVersion = errors.New("invalid TLS version")
|
||||||
|
|
||||||
|
// ConfigureTLSVersion configures the correct TLS version into
|
||||||
|
// the specified *tls.Config or returns an error.
|
||||||
|
func ConfigureTLSVersion(config *tls.Config, version string) error {
|
||||||
|
switch version {
|
||||||
|
case "TLSv1.3":
|
||||||
|
config.MinVersion = tls.VersionTLS13
|
||||||
|
config.MaxVersion = tls.VersionTLS13
|
||||||
|
case "TLSv1.2":
|
||||||
|
config.MinVersion = tls.VersionTLS12
|
||||||
|
config.MaxVersion = tls.VersionTLS12
|
||||||
|
case "TLSv1.1":
|
||||||
|
config.MinVersion = tls.VersionTLS11
|
||||||
|
config.MaxVersion = tls.VersionTLS11
|
||||||
|
case "TLSv1.0", "TLSv1":
|
||||||
|
config.MinVersion = tls.VersionTLS10
|
||||||
|
config.MaxVersion = tls.VersionTLS10
|
||||||
|
case "":
|
||||||
|
// nothing
|
||||||
|
default:
|
||||||
|
return ErrInvalidTLSVersion
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package tlsx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,3 +29,77 @@ func TestCipherSuite(t *testing.T) {
|
||||||
t.Fatal("not working for zero cipher suite")
|
t.Fatal("not working for zero cipher suite")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewDefaultCertPoolWorks(t *testing.T) {
|
||||||
|
pool := NewDefaultCertPool()
|
||||||
|
if pool == nil {
|
||||||
|
t.Fatal("expected non-nil value here")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigureTLSVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
version string
|
||||||
|
wantErr error
|
||||||
|
versionMin int
|
||||||
|
versionMax int
|
||||||
|
}{{
|
||||||
|
name: "with TLSv1.3",
|
||||||
|
version: "TLSv1.3",
|
||||||
|
wantErr: nil,
|
||||||
|
versionMin: tls.VersionTLS13,
|
||||||
|
versionMax: tls.VersionTLS13,
|
||||||
|
}, {
|
||||||
|
name: "with TLSv1.2",
|
||||||
|
version: "TLSv1.2",
|
||||||
|
wantErr: nil,
|
||||||
|
versionMin: tls.VersionTLS12,
|
||||||
|
versionMax: tls.VersionTLS12,
|
||||||
|
}, {
|
||||||
|
name: "with TLSv1.1",
|
||||||
|
version: "TLSv1.1",
|
||||||
|
wantErr: nil,
|
||||||
|
versionMin: tls.VersionTLS11,
|
||||||
|
versionMax: tls.VersionTLS11,
|
||||||
|
}, {
|
||||||
|
name: "with TLSv1.0",
|
||||||
|
version: "TLSv1.0",
|
||||||
|
wantErr: nil,
|
||||||
|
versionMin: tls.VersionTLS10,
|
||||||
|
versionMax: tls.VersionTLS10,
|
||||||
|
}, {
|
||||||
|
name: "with TLSv1",
|
||||||
|
version: "TLSv1",
|
||||||
|
wantErr: nil,
|
||||||
|
versionMin: tls.VersionTLS10,
|
||||||
|
versionMax: tls.VersionTLS10,
|
||||||
|
}, {
|
||||||
|
name: "with default",
|
||||||
|
version: "",
|
||||||
|
wantErr: nil,
|
||||||
|
versionMin: 0,
|
||||||
|
versionMax: 0,
|
||||||
|
}, {
|
||||||
|
name: "with invalid version",
|
||||||
|
version: "TLSv999",
|
||||||
|
wantErr: ErrInvalidTLSVersion,
|
||||||
|
versionMin: 0,
|
||||||
|
versionMax: 0,
|
||||||
|
}}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
conf := new(tls.Config)
|
||||||
|
err := ConfigureTLSVersion(conf, tt.version)
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("not the error we expected: %+v", err)
|
||||||
|
}
|
||||||
|
if conf.MinVersion != uint16(tt.versionMin) {
|
||||||
|
t.Fatalf("not the min version we expected: %+v", conf.MinVersion)
|
||||||
|
}
|
||||||
|
if conf.MaxVersion != uint16(tt.versionMax) {
|
||||||
|
t.Fatalf("not the max version we expected: %+v", conf.MaxVersion)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user