Add expiration warning to certificate checking for diagnose [VAULT-1018] (#11850)

* add expiration warning to certificate checking for diagnose

* Update serviceregistration/consul/consul_service_registration.go

Co-authored-by: swayne275 <swayne275@gmail.com>

* review comments

Co-authored-by: swayne275 <swayne275@gmail.com>
This commit is contained in:
Hridoy Roy 2021-06-15 09:53:29 -07:00 committed by GitHub
parent a036b3a4d1
commit cea68aaa68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 227 additions and 117 deletions

View File

@ -255,7 +255,7 @@ func (c *OperatorDiagnoseCommand) offlineDiagnostics(ctx context.Context) error
if config.Storage != nil && config.Storage.Type == storageTypeConsul { if config.Storage != nil && config.Storage.Type == storageTypeConsul {
diagnose.Test(ctx, "test-storage-tls-consul", func(ctx context.Context) error { diagnose.Test(ctx, "test-storage-tls-consul", func(ctx context.Context) error {
err = physconsul.SetupSecureTLS(api.DefaultConfig(), config.Storage.Config, server.logger, true) err = physconsul.SetupSecureTLS(ctx, api.DefaultConfig(), config.Storage.Config, server.logger, true)
if err != nil { if err != nil {
return err return err
} }
@ -323,7 +323,7 @@ func (c *OperatorDiagnoseCommand) offlineDiagnostics(ctx context.Context) error
diagnose.Test(ctx, "test-serviceregistration-tls-consul", func(ctx context.Context) error { diagnose.Test(ctx, "test-serviceregistration-tls-consul", func(ctx context.Context) error {
// SetupSecureTLS for service discovery uses the same cert and key to set up physical // SetupSecureTLS for service discovery uses the same cert and key to set up physical
// storage. See the consul package in physical for details. // storage. See the consul package in physical for details.
err = srconsul.SetupSecureTLS(api.DefaultConfig(), srConfig, server.logger, true) err = srconsul.SetupSecureTLS(ctx, api.DefaultConfig(), srConfig, server.logger, true)
if err != nil { if err != nil {
return err return err
} }
@ -424,7 +424,7 @@ SEALFAIL:
}) })
if config.HAStorage != nil && config.HAStorage.Type == storageTypeConsul { if config.HAStorage != nil && config.HAStorage.Type == storageTypeConsul {
diagnose.Test(ctx, "test-ha-storage-tls-consul", func(ctx context.Context) error { diagnose.Test(ctx, "test-ha-storage-tls-consul", func(ctx context.Context) error {
err = physconsul.SetupSecureTLS(api.DefaultConfig(), config.HAStorage.Config, server.logger, true) err = physconsul.SetupSecureTLS(ctx, api.DefaultConfig(), config.HAStorage.Config, server.logger, true)
if err != nil { if err != nil {
return err return err
} }
@ -493,36 +493,26 @@ SEALFAIL:
defer c.cleanupGuard.Do(listenerCloseFunc) defer c.cleanupGuard.Do(listenerCloseFunc)
diagnose.Test(ctx, "check-listener-tls", func(ctx context.Context) error { listenerTLSContext, listenerTLSSpan := diagnose.StartSpan(ctx, "check-listener-tls")
sanitizedListeners := make([]listenerutil.Listener, 0, len(config.Listeners)) sanitizedListeners := make([]listenerutil.Listener, 0, len(config.Listeners))
for _, ln := range lns { for _, ln := range lns {
if ln.Config.TLSDisable { if ln.Config.TLSDisable {
diagnose.Warn(ctx, "TLS is disabled in a Listener config stanza.") diagnose.Warn(listenerTLSContext, "TLS is disabled in a Listener config stanza.")
continue continue
}
if ln.Config.TLSDisableClientCerts {
diagnose.Warn(ctx, "TLS for a listener is turned on without requiring client certs.")
}
// Check ciphersuite and load ca/cert/key files
// TODO: TLSConfig returns a reloadFunc and a TLSConfig. We can use this to
// perform an active probe.
_, _, err := listenerutil.TLSConfig(ln.Config, make(map[string]string), c.UI)
if err != nil {
return err
}
sanitizedListeners = append(sanitizedListeners, listenerutil.Listener{
Listener: ln.Listener,
Config: ln.Config,
})
} }
err = diagnose.ListenerChecks(sanitizedListeners) if ln.Config.TLSDisableClientCerts {
if err != nil { diagnose.Warn(listenerTLSContext, "TLS for a listener is turned on without requiring client certs.")
return err
} }
return nil
}) sanitizedListeners = append(sanitizedListeners, listenerutil.Listener{
Listener: ln.Listener,
Config: ln.Config,
})
}
diagnose.ListenerChecks(listenerTLSContext, sanitizedListeners)
listenerTLSSpan.End()
return nil return nil
}) })

View File

@ -226,7 +226,10 @@ func TestOperatorDiagnoseCommand_Run(t *testing.T) {
{ {
Name: "test-storage-tls-consul", Name: "test-storage-tls-consul",
Status: diagnose.ErrorStatus, Status: diagnose.ErrorStatus,
Message: "expired", Message: "certificate has expired or is not yet valid",
Warnings: []string{
"expired or near expiry",
},
}, },
{ {
Name: "test-consul-direct-access-storage", Name: "test-consul-direct-access-storage",
@ -281,7 +284,10 @@ func TestOperatorDiagnoseCommand_Run(t *testing.T) {
{ {
Name: "test-ha-storage-tls-consul", Name: "test-ha-storage-tls-consul",
Status: diagnose.ErrorStatus, Status: diagnose.ErrorStatus,
Message: "x509: certificate has expired or is not yet valid", Message: "certificate has expired or is not yet valid",
Warnings: []string{
"expired or near expiry",
},
}, },
}, },
}, },
@ -304,7 +310,10 @@ func TestOperatorDiagnoseCommand_Run(t *testing.T) {
{ {
Name: "test-serviceregistration-tls-consul", Name: "test-serviceregistration-tls-consul",
Status: diagnose.ErrorStatus, Status: diagnose.ErrorStatus,
Message: "failed to verify certificate: x509: certificate has expired or is not yet valid", Message: "certificate has expired or is not yet valid",
Warnings: []string{
"expired or near expiry",
},
}, },
{ {
Name: "test-consul-direct-access-service-discovery", Name: "test-consul-direct-access-service-discovery",

View File

@ -129,7 +129,7 @@ func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backe
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
SetupSecureTLS(consulConf, conf, logger, false) SetupSecureTLS(context.Background(), consulConf, conf, logger, false)
consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
client, err := api.NewClient(consulConf) client, err := api.NewClient(consulConf)
@ -151,7 +151,7 @@ func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backe
return c, nil return c, nil
} }
func SetupSecureTLS(consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error { func SetupSecureTLS(ctx context.Context, consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error {
if addr, ok := conf["address"]; ok { if addr, ok := conf["address"]; ok {
consulConf.Address = addr consulConf.Address = addr
if logger.IsDebug() { if logger.IsDebug() {
@ -189,13 +189,16 @@ func SetupSecureTLS(consulConf *api.Config, conf map[string]string, logger log.L
certPath, okCert := conf["tls_cert_file"] certPath, okCert := conf["tls_cert_file"]
keyPath, okKey := conf["tls_key_file"] keyPath, okKey := conf["tls_key_file"]
if okCert && okKey { if okCert && okKey {
err := diagnose.TLSFileChecks(certPath, keyPath) warnings, err := diagnose.TLSFileChecks(certPath, keyPath)
for _, warning := range warnings {
diagnose.Warn(ctx, warning)
}
if err != nil { if err != nil {
return err return err
} }
} else { return nil
return fmt.Errorf("key or cert path: %s, %s, cannot be loaded from consul config file", certPath, keyPath)
} }
return fmt.Errorf("key or cert path: %s, %s, cannot be loaded from consul config file", certPath, keyPath)
} }
// Use the parsed Address instead of the raw conf['address'] // Use the parsed Address instead of the raw conf['address']

View File

@ -1,6 +1,7 @@
package consul package consul
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
@ -146,7 +147,7 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.
// Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore
consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount
SetupSecureTLS(consulConf, conf, logger, false) SetupSecureTLS(context.Background(), consulConf, conf, logger, false)
consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} consulConf.HttpClient = &http.Client{Transport: consulConf.Transport}
client, err := api.NewClient(consulConf) client, err := api.NewClient(consulConf)
@ -178,7 +179,7 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.
return c, nil return c, nil
} }
func SetupSecureTLS(consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error { func SetupSecureTLS(ctx context.Context, consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error {
if addr, ok := conf["address"]; ok { if addr, ok := conf["address"]; ok {
consulConf.Address = addr consulConf.Address = addr
if logger.IsDebug() { if logger.IsDebug() {
@ -216,13 +217,16 @@ func SetupSecureTLS(consulConf *api.Config, conf map[string]string, logger log.L
certPath, okCert := conf["tls_cert_file"] certPath, okCert := conf["tls_cert_file"]
keyPath, okKey := conf["tls_key_file"] keyPath, okKey := conf["tls_key_file"]
if okCert && okKey { if okCert && okKey {
err := diagnose.TLSFileChecks(certPath, keyPath) warnings, err := diagnose.TLSFileChecks(certPath, keyPath)
for _, warning := range warnings {
diagnose.Warn(ctx, warning)
}
if err != nil { if err != nil {
return err return err
} }
} else { return nil
return fmt.Errorf("key or cert path: %s, %s, cannot be loaded from consul config file", certPath, keyPath)
} }
return fmt.Errorf("key or cert path: %s, %s, cannot be loaded from consul config file", certPath, keyPath)
} }
// Use the parsed Address instead of the raw conf['address'] // Use the parsed Address instead of the raw conf['address']

View File

@ -2,11 +2,13 @@ package diagnose
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"time"
"github.com/hashicorp/vault/internalshared/listenerutil" "github.com/hashicorp/vault/internalshared/listenerutil"
"github.com/hashicorp/vault/sdk/helper/tlsutil" "github.com/hashicorp/vault/sdk/helper/tlsutil"
@ -15,9 +17,18 @@ import (
const minVersionError = "'tls_min_version' value %q not supported, please specify one of [tls10,tls11,tls12,tls13]" const minVersionError = "'tls_min_version' value %q not supported, please specify one of [tls10,tls11,tls12,tls13]"
const maxVersionError = "'tls_max_version' value %q not supported, please specify one of [tls10,tls11,tls12,tls13]" const maxVersionError = "'tls_max_version' value %q not supported, please specify one of [tls10,tls11,tls12,tls13]"
func ListenerChecks(listeners []listenerutil.Listener) error { // ListenerChecks diagnoses warnings and the first encountered error for the listener
// configuration stanzas.
func ListenerChecks(ctx context.Context, listeners []listenerutil.Listener) ([]string, []error) {
// These aggregated warnings and errors are returned purely for testing purposes.
// The errors and warnings will report in this function itself.
var listenerWarnings []string
var listenerErrors []error
for _, listener := range listeners { for _, listener := range listeners {
l := listener.Config l := listener.Config
listenerID := l.Address
// Perform the TLS version check for listeners. // Perform the TLS version check for listeners.
if l.TLSMinVersion == "" { if l.TLSMinVersion == "" {
@ -28,64 +39,119 @@ func ListenerChecks(listeners []listenerutil.Listener) error {
} }
_, ok := tlsutil.TLSLookup[l.TLSMinVersion] _, ok := tlsutil.TLSLookup[l.TLSMinVersion]
if !ok { if !ok {
return fmt.Errorf(minVersionError, l.TLSMinVersion) err := fmt.Errorf("listener at address: %s has error %s: ", listenerID, fmt.Sprintf(minVersionError, l.TLSMinVersion))
listenerErrors = append(listenerErrors, err)
Error(ctx, err)
} }
_, ok = tlsutil.TLSLookup[l.TLSMaxVersion] _, ok = tlsutil.TLSLookup[l.TLSMaxVersion]
if !ok { if !ok {
return fmt.Errorf(maxVersionError, l.TLSMaxVersion) err := fmt.Errorf("listener at address: %s has error %s: ", listenerID, fmt.Sprintf(maxVersionError, l.TLSMaxVersion))
listenerErrors = append(listenerErrors, err)
Error(ctx, err)
} }
// Perform checks on the TLS Cryptographic Information. // Perform checks on the TLS Cryptographic Information.
if err := TLSFileChecks(l.TLSCertFile, l.TLSKeyFile); err != nil { warnings, err := TLSFileChecks(l.TLSCertFile, l.TLSKeyFile)
return err for _, warning := range warnings {
warning = listenerID + ": " + warning
listenerWarnings = append(listenerWarnings, warning)
Warn(ctx, warning)
} }
if err != nil {
errMsg := listenerID + ": " + err.Error()
listenerErrors = append(listenerErrors, fmt.Errorf(errMsg))
Error(ctx, fmt.Errorf(errMsg))
}
// TODO: Use listenerutil.TLSConfig to warn on incorrect protocol specified
// Alternatively, use tlsutil.SetupTLSConfig.
} }
return nil return listenerWarnings, listenerErrors
} }
// TLSFileChecks contains manual error checks against the TLS configuration // TLSFileChecks returns an error and warnings after checking TLS information
func TLSFileChecks(certFilePath, keyFilePath string) error { func TLSFileChecks(certpath, keypath string) ([]string, error) {
// Parse TLS Certs from the certpath
leafCerts, interCerts, rootCerts, err := ParseTLSInformation(certpath)
if err != nil {
return nil, err
}
// Check for TLS Warnings
warnings, err := TLSFileWarningChecks(leafCerts, interCerts, rootCerts)
if err != nil {
return warnings, err
}
// Check for TLS Errors
if err = TLSErrorChecks(leafCerts, interCerts, rootCerts); err != nil {
return warnings, err
}
// Utilize the native TLS Loading mechanism to ensure we have missed no errors
_, err = tls.LoadX509KeyPair(certpath, keypath)
return warnings, err
}
// ParseTLSInformation parses certficate information and returns it from a cert path.
func ParseTLSInformation(certFilePath string) ([]*x509.Certificate, []*x509.Certificate, []*x509.Certificate, error) {
leafCerts := []*x509.Certificate{}
interCerts := []*x509.Certificate{}
rootCerts := []*x509.Certificate{}
data, err := ioutil.ReadFile(certFilePath) data, err := ioutil.ReadFile(certFilePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to read tls_client_ca_file: %w", err) return leafCerts, interCerts, rootCerts, fmt.Errorf("failed to read certificate file: %w", err)
} }
certBlocks := []*pem.Block{} certBlocks := []*pem.Block{}
leafCerts := []*x509.Certificate{}
rootPool := x509.NewCertPool()
interPool := x509.NewCertPool()
rst := []byte(data) rst := []byte(data)
for len(rst) != 0 { for len(rst) != 0 {
block, rest := pem.Decode(rst) block, rest := pem.Decode(rst)
if block == nil { if block == nil {
return fmt.Errorf("could not decode cert") return leafCerts, interCerts, rootCerts, fmt.Errorf("could not decode cert")
} }
certBlocks = append(certBlocks, block) certBlocks = append(certBlocks, block)
rst = rest rst = rest
} }
if len(certBlocks) == 0 { if len(certBlocks) == 0 {
return fmt.Errorf("no certificates found in cert file") return leafCerts, interCerts, rootCerts, fmt.Errorf("no certificates found in cert file")
} }
for _, certBlock := range certBlocks { for _, certBlock := range certBlocks {
cert, err := x509.ParseCertificate(certBlock.Bytes) cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil { if err != nil {
return fmt.Errorf("A pem block does not parse to a certificate: %w", err) return leafCerts, interCerts, rootCerts, fmt.Errorf("A pem block does not parse to a certificate: %w", err)
} }
// Detect if the certificate is a root, leaf, or intermediate // Detect if the certificate is a root, leaf, or intermediate
if cert.IsCA && bytes.Equal(cert.RawIssuer, cert.RawSubject) { if cert.IsCA && bytes.Equal(cert.RawIssuer, cert.RawSubject) {
// It's a root // It's a root
rootPool.AddCert(cert) rootCerts = append(rootCerts, cert)
} else if cert.IsCA { } else if cert.IsCA {
// It's not a root but it's a CA, so it's an inter // It's not a root but it's a CA, so it's an inter
interPool.AddCert(cert) interCerts = append(interCerts, cert)
} else { } else {
// It's gotta be a leaf // It's gotta be a leaf
leafCerts = append(leafCerts, cert) leafCerts = append(leafCerts, cert)
} }
} }
return leafCerts, interCerts, rootCerts, nil
}
// TLSErrorChecks contains manual error checks against the TLS configuration
func TLSErrorChecks(leafCerts, interCerts, rootCerts []*x509.Certificate) error {
// First, create root pools and interPools from the root and inter certs lists
rootPool := x509.NewCertPool()
interPool := x509.NewCertPool()
for _, root := range rootCerts {
rootPool.AddCert(root)
}
for _, inter := range interCerts {
interPool.AddCert(inter)
}
// Make sure there's only one leaf. If there are multiple, it's a bad pem file. // Make sure there's only one leaf. If there are multiple, it's a bad pem file.
if len(leafCerts) != 1 { if len(leafCerts) != 1 {
@ -102,23 +168,47 @@ func TLSFileChecks(certFilePath, keyFilePath string) error {
// Verify checks that certificate isn't expired, is of correct usage type, and has an appropriate // Verify checks that certificate isn't expired, is of correct usage type, and has an appropriate
// chain. // chain.
_, err = leafCerts[0].Verify(x509.VerifyOptions{ _, err := leafCerts[0].Verify(x509.VerifyOptions{
Roots: rootPool, Roots: rootPool,
Intermediates: interPool, Intermediates: interPool,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to verify certificate: %w", err) return fmt.Errorf("failed to verify primary provided leaf certificate: %w", err)
}
// After verify passes, we need to check the values on the certificate itself.
// This is a separate check beyond the certificate expiry and chain checks.
_, err = tls.LoadX509KeyPair(certFilePath, keyFilePath)
if err != nil {
return err
} }
return nil return nil
} }
// TLSFileWarningChecks returns warnings based on the leaf certificates, intermediate certificates,
// and root certificates provided.
func TLSFileWarningChecks(leafCerts, interCerts, rootCerts []*x509.Certificate) ([]string, error) {
var warnings []string
for _, c := range leafCerts {
if NearExpiration(c) {
warnings = append(warnings, fmt.Sprintf("leaf certificate %d is expired or near expiry", c.SerialNumber))
}
}
for _, c := range interCerts {
if NearExpiration(c) {
warnings = append(warnings, fmt.Sprintf("intermediate certificate %d is expired or near expiry", c.SerialNumber))
}
}
for _, c := range rootCerts {
if NearExpiration(c) {
warnings = append(warnings, fmt.Sprintf("root certificate %d is expired or near expiry", c.SerialNumber))
}
}
return warnings, nil
}
// NearExpiration returns a true if a certficate will expire in a week and false otherwise
func NearExpiration(c *x509.Certificate) bool {
oneWeekFromNow := time.Now().Add(7 * 24 * time.Hour)
if oneWeekFromNow.After(c.NotAfter) {
return true
}
return false
}

View File

@ -1,6 +1,7 @@
package diagnose package diagnose
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -26,9 +27,13 @@ func TestTLSValidCert(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) warnings, errs := ListenerChecks(context.Background(), listeners)
if err != nil { if errs != nil {
t.Fatalf(err.Error()) // The test failed -- we can just return one of the errors
t.Fatalf(errs[0].Error())
}
if warnings != nil {
t.Fatalf("warnings returned from good listener")
} }
} }
@ -48,12 +53,15 @@ func TestTLSFakeCert(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if !strings.Contains(err.Error(), "could not decode cert") { if len(errs) != 1 {
t.Fatalf("Bad error message: %s", err) t.Fatalf("more than one error returned: %+v", errs)
}
if !strings.Contains(errs[0].Error(), "could not decode cert") {
t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -76,12 +84,12 @@ func TestTLSTrailingData(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if !strings.Contains(err.Error(), "asn1: syntax error: trailing data") { if !strings.Contains(errs[0].Error(), "asn1: syntax error: trailing data") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -102,12 +110,18 @@ func TestTLSExpiredCert(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) warnings, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if !strings.Contains(err.Error(), "certificate has expired or is not yet valid") { if !strings.Contains(errs[0].Error(), "certificate has expired or is not yet valid") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
}
if warnings == nil || len(warnings) != 1 {
t.Fatalf("TLS Config check on fake certificate should warn")
}
if !strings.Contains(warnings[0], "expired or near expiry") {
t.Fatalf("Bad warning: %s", errs[0])
} }
} }
@ -128,12 +142,12 @@ func TestTLSMismatchedCryptographicInfo(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if err.Error() != "tls: private key type does not match public key type" { if !strings.Contains(errs[0].Error(), "tls: private key type does not match public key type") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
listeners = []listenerutil.Listener{ listeners = []listenerutil.Listener{
@ -151,12 +165,12 @@ func TestTLSMismatchedCryptographicInfo(t *testing.T) {
}, },
}, },
} }
err = ListenerChecks(listeners) _, errs = ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if err.Error() != "tls: private key type does not match public key type" { if !strings.Contains(errs[0].Error(), "tls: private key type does not match public key type") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -177,12 +191,12 @@ func TestTLSMultiKeys(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if !strings.Contains(err.Error(), "pem block does not parse to a certificate") { if !strings.Contains(errs[0].Error(), "pem block does not parse to a certificate") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -202,12 +216,12 @@ func TestTLSMultiCerts(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if !strings.Contains(err.Error(), "found a certificate rather than a key in the PEM for the private key") { if !strings.Contains(errs[0].Error(), "found a certificate rather than a key in the PEM for the private key") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -229,12 +243,12 @@ func TestTLSInvalidRoot(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if err.Error() != "failed to verify certificate: x509: certificate signed by unknown authority" { if !strings.Contains(errs[0].Error(), "certificate signed by unknown authority") {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -256,9 +270,9 @@ func TestTLSNoRoot(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err != nil { if errs != nil {
t.Fatalf("Server certificate without root certificate is insecure, but still valid.") t.Fatalf("server certificate without root certificate is insecure, but still valid")
} }
} }
@ -280,12 +294,12 @@ func TestTLSInvalidMinVersion(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if err.Error() != fmt.Errorf(minVersionError, "0").Error() { if !strings.Contains(errs[0].Error(), fmt.Errorf(minVersionError, "0").Error()) {
t.Fatalf("Bad error message: %s", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }
@ -307,11 +321,11 @@ func TestTLSInvalidMaxVersion(t *testing.T) {
}, },
}, },
} }
err := ListenerChecks(listeners) _, errs := ListenerChecks(context.Background(), listeners)
if err == nil { if errs == nil || len(errs) != 1 {
t.Fatalf("TLS Config check on fake certificate should fail") t.Fatalf("TLS Config check on fake certificate should fail")
} }
if err.Error() != fmt.Errorf(maxVersionError, "0").Error() { if !strings.Contains(errs[0].Error(), fmt.Errorf(maxVersionError, "0").Error()) {
t.Errorf("Bad error message: %w", err) t.Fatalf("Bad error message: %s", errs[0])
} }
} }