From a7dbcc282cd25d63edd599b5a78baedd3d1920f5 Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Wed, 11 Sep 2019 17:46:04 +0200 Subject: [PATCH] Consider default cert domain in certificate store Co-authored-by: Nicolas Mengin --- tls/certificate_store.go | 27 +++++++++++ tls/certificate_store_test.go | 88 ++++++++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/tls/certificate_store.go b/tls/certificate_store.go index 6ddb9407c..dcd1fd661 100644 --- a/tls/certificate_store.go +++ b/tls/certificate_store.go @@ -2,6 +2,7 @@ package tls import ( "crypto/tls" + "crypto/x509" "net" "sort" "strings" @@ -47,6 +48,11 @@ func (c CertificateStore) GetAllDomains() []string { allCerts = append(allCerts, domains) } } + + // Get Default certificate + if c.DefaultCertificate != nil { + allCerts = append(allCerts, getCertificateDomains(c.DefaultCertificate)...) + } return allCerts } @@ -115,6 +121,27 @@ func (c CertificateStore) ResetCache() { } } +func getCertificateDomains(cert *tls.Certificate) []string { + if cert == nil { + return nil + } + + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil + } + + var names []string + if len(x509Cert.Subject.CommonName) > 0 { + names = append(names, x509Cert.Subject.CommonName) + } + for _, san := range x509Cert.DNSNames { + names = append(names, san) + } + + return names +} + // MatchDomain return true if a domain match the cert domain func MatchDomain(domain string, certDomain string) bool { if domain == certDomain { diff --git a/tls/certificate_store_test.go b/tls/certificate_store_test.go index 506937010..96c570069 100644 --- a/tls/certificate_store_test.go +++ b/tls/certificate_store_test.go @@ -13,6 +13,90 @@ import ( "github.com/stretchr/testify/require" ) +func TestGetAllDomains(t *testing.T) { + testCases := []struct { + desc string + staticCert string + dynamicCert string + defaultCert string + expectedDomains []string + }{ + { + desc: "Empty Store, returns no domains", + staticCert: "", + dynamicCert: "", + defaultCert: "", + expectedDomains: nil, + }, + { + desc: "Static cert domains", + staticCert: "snitest.com", + dynamicCert: "", + defaultCert: "", + expectedDomains: []string{"snitest.com"}, + }, + { + desc: "Dynamic cert domains", + staticCert: "", + dynamicCert: "snitest.com", + defaultCert: "", + expectedDomains: []string{"snitest.com"}, + }, + { + desc: "Default cert domains", + staticCert: "", + dynamicCert: "", + defaultCert: "snitest.com", + expectedDomains: []string{"snitest.com"}, + }, + { + desc: "All domains", + staticCert: "www.snitest.com", + dynamicCert: "*.snitest.com", + defaultCert: "snitest.com", + expectedDomains: []string{"www.snitest.com", "*.snitest.com", "snitest.com"}, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + staticMap := map[string]*tls.Certificate{} + if test.staticCert != "" { + cert, err := loadTestCert(test.staticCert, false) + require.NoError(t, err) + staticMap[strings.ToLower(test.staticCert)] = cert + } + + dynamicMap := map[string]*tls.Certificate{} + if test.dynamicCert != "" { + cert, err := loadTestCert(test.dynamicCert, false) + require.NoError(t, err) + dynamicMap[strings.ToLower(test.dynamicCert)] = cert + } + + var defaultCert *tls.Certificate + if test.defaultCert != "" { + cert, err := loadTestCert(test.defaultCert, false) + require.NoError(t, err) + defaultCert = cert + } + + store := &CertificateStore{ + DynamicCerts: safe.New(dynamicMap), + StaticCerts: safe.New(staticMap), + DefaultCertificate: defaultCert, + CertCache: cache.New(1*time.Hour, 10*time.Minute), + } + + actual := store.GetAllDomains() + assert.Equal(t, test.expectedDomains, actual) + }) + } +} + func TestGetBestCertificate(t *testing.T) { testCases := []struct { desc string @@ -116,15 +200,15 @@ func TestGetBestCertificate(t *testing.T) { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() - staticMap := map[string]*tls.Certificate{} - dynamicMap := map[string]*tls.Certificate{} + staticMap := map[string]*tls.Certificate{} if test.staticCert != "" { cert, err := loadTestCert(test.staticCert, test.uppercase) require.NoError(t, err) staticMap[strings.ToLower(test.staticCert)] = cert } + dynamicMap := map[string]*tls.Certificate{} if test.dynamicCert != "" { cert, err := loadTestCert(test.dynamicCert, test.uppercase) require.NoError(t, err)