Consider default cert domain in certificate store

Co-authored-by: Nicolas Mengin <nmengin.pro@gmail.com>
This commit is contained in:
Nicholas Wiersma 2019-09-11 17:46:04 +02:00 committed by Traefiker Bot
parent f4f62e7fb3
commit a7dbcc282c
2 changed files with 113 additions and 2 deletions

View File

@ -2,6 +2,7 @@ package tls
import (
"crypto/tls"
"crypto/x509"
"net"
"sort"
"strings"
@ -47,6 +48,11 @@ func (c CertificateStore) GetAllDomains() []string {
allCerts = append(allCerts, domains)
}
}
// Get Default certificate
if c.DefaultCertificate != nil {
allCerts = append(allCerts, getCertificateDomains(c.DefaultCertificate)...)
}
return allCerts
}
@ -115,6 +121,27 @@ func (c CertificateStore) ResetCache() {
}
}
func getCertificateDomains(cert *tls.Certificate) []string {
if cert == nil {
return nil
}
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil
}
var names []string
if len(x509Cert.Subject.CommonName) > 0 {
names = append(names, x509Cert.Subject.CommonName)
}
for _, san := range x509Cert.DNSNames {
names = append(names, san)
}
return names
}
// MatchDomain return true if a domain match the cert domain
func MatchDomain(domain string, certDomain string) bool {
if domain == certDomain {

View File

@ -13,6 +13,90 @@ import (
"github.com/stretchr/testify/require"
)
func TestGetAllDomains(t *testing.T) {
testCases := []struct {
desc string
staticCert string
dynamicCert string
defaultCert string
expectedDomains []string
}{
{
desc: "Empty Store, returns no domains",
staticCert: "",
dynamicCert: "",
defaultCert: "",
expectedDomains: nil,
},
{
desc: "Static cert domains",
staticCert: "snitest.com",
dynamicCert: "",
defaultCert: "",
expectedDomains: []string{"snitest.com"},
},
{
desc: "Dynamic cert domains",
staticCert: "",
dynamicCert: "snitest.com",
defaultCert: "",
expectedDomains: []string{"snitest.com"},
},
{
desc: "Default cert domains",
staticCert: "",
dynamicCert: "",
defaultCert: "snitest.com",
expectedDomains: []string{"snitest.com"},
},
{
desc: "All domains",
staticCert: "www.snitest.com",
dynamicCert: "*.snitest.com",
defaultCert: "snitest.com",
expectedDomains: []string{"www.snitest.com", "*.snitest.com", "snitest.com"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
staticMap := map[string]*tls.Certificate{}
if test.staticCert != "" {
cert, err := loadTestCert(test.staticCert, false)
require.NoError(t, err)
staticMap[strings.ToLower(test.staticCert)] = cert
}
dynamicMap := map[string]*tls.Certificate{}
if test.dynamicCert != "" {
cert, err := loadTestCert(test.dynamicCert, false)
require.NoError(t, err)
dynamicMap[strings.ToLower(test.dynamicCert)] = cert
}
var defaultCert *tls.Certificate
if test.defaultCert != "" {
cert, err := loadTestCert(test.defaultCert, false)
require.NoError(t, err)
defaultCert = cert
}
store := &CertificateStore{
DynamicCerts: safe.New(dynamicMap),
StaticCerts: safe.New(staticMap),
DefaultCertificate: defaultCert,
CertCache: cache.New(1*time.Hour, 10*time.Minute),
}
actual := store.GetAllDomains()
assert.Equal(t, test.expectedDomains, actual)
})
}
}
func TestGetBestCertificate(t *testing.T) {
testCases := []struct {
desc string
@ -116,15 +200,15 @@ func TestGetBestCertificate(t *testing.T) {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
staticMap := map[string]*tls.Certificate{}
dynamicMap := map[string]*tls.Certificate{}
staticMap := map[string]*tls.Certificate{}
if test.staticCert != "" {
cert, err := loadTestCert(test.staticCert, test.uppercase)
require.NoError(t, err)
staticMap[strings.ToLower(test.staticCert)] = cert
}
dynamicMap := map[string]*tls.Certificate{}
if test.dynamicCert != "" {
cert, err := loadTestCert(test.dynamicCert, test.uppercase)
require.NoError(t, err)