mirror of
https://github.com/traefik/traefik.git
synced 2025-08-14 02:27:09 +02:00
Consider default cert domain in certificate store
Co-authored-by: Nicolas Mengin <nmengin.pro@gmail.com>
This commit is contained in:
parent
f4f62e7fb3
commit
a7dbcc282c
@ -2,6 +2,7 @@ package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
@ -47,6 +48,11 @@ func (c CertificateStore) GetAllDomains() []string {
|
||||
allCerts = append(allCerts, domains)
|
||||
}
|
||||
}
|
||||
|
||||
// Get Default certificate
|
||||
if c.DefaultCertificate != nil {
|
||||
allCerts = append(allCerts, getCertificateDomains(c.DefaultCertificate)...)
|
||||
}
|
||||
return allCerts
|
||||
}
|
||||
|
||||
@ -115,6 +121,27 @@ func (c CertificateStore) ResetCache() {
|
||||
}
|
||||
}
|
||||
|
||||
func getCertificateDomains(cert *tls.Certificate) []string {
|
||||
if cert == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var names []string
|
||||
if len(x509Cert.Subject.CommonName) > 0 {
|
||||
names = append(names, x509Cert.Subject.CommonName)
|
||||
}
|
||||
for _, san := range x509Cert.DNSNames {
|
||||
names = append(names, san)
|
||||
}
|
||||
|
||||
return names
|
||||
}
|
||||
|
||||
// MatchDomain return true if a domain match the cert domain
|
||||
func MatchDomain(domain string, certDomain string) bool {
|
||||
if domain == certDomain {
|
||||
|
@ -13,6 +13,90 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetAllDomains(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
staticCert string
|
||||
dynamicCert string
|
||||
defaultCert string
|
||||
expectedDomains []string
|
||||
}{
|
||||
{
|
||||
desc: "Empty Store, returns no domains",
|
||||
staticCert: "",
|
||||
dynamicCert: "",
|
||||
defaultCert: "",
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "Static cert domains",
|
||||
staticCert: "snitest.com",
|
||||
dynamicCert: "",
|
||||
defaultCert: "",
|
||||
expectedDomains: []string{"snitest.com"},
|
||||
},
|
||||
{
|
||||
desc: "Dynamic cert domains",
|
||||
staticCert: "",
|
||||
dynamicCert: "snitest.com",
|
||||
defaultCert: "",
|
||||
expectedDomains: []string{"snitest.com"},
|
||||
},
|
||||
{
|
||||
desc: "Default cert domains",
|
||||
staticCert: "",
|
||||
dynamicCert: "",
|
||||
defaultCert: "snitest.com",
|
||||
expectedDomains: []string{"snitest.com"},
|
||||
},
|
||||
{
|
||||
desc: "All domains",
|
||||
staticCert: "www.snitest.com",
|
||||
dynamicCert: "*.snitest.com",
|
||||
defaultCert: "snitest.com",
|
||||
expectedDomains: []string{"www.snitest.com", "*.snitest.com", "snitest.com"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
staticMap := map[string]*tls.Certificate{}
|
||||
if test.staticCert != "" {
|
||||
cert, err := loadTestCert(test.staticCert, false)
|
||||
require.NoError(t, err)
|
||||
staticMap[strings.ToLower(test.staticCert)] = cert
|
||||
}
|
||||
|
||||
dynamicMap := map[string]*tls.Certificate{}
|
||||
if test.dynamicCert != "" {
|
||||
cert, err := loadTestCert(test.dynamicCert, false)
|
||||
require.NoError(t, err)
|
||||
dynamicMap[strings.ToLower(test.dynamicCert)] = cert
|
||||
}
|
||||
|
||||
var defaultCert *tls.Certificate
|
||||
if test.defaultCert != "" {
|
||||
cert, err := loadTestCert(test.defaultCert, false)
|
||||
require.NoError(t, err)
|
||||
defaultCert = cert
|
||||
}
|
||||
|
||||
store := &CertificateStore{
|
||||
DynamicCerts: safe.New(dynamicMap),
|
||||
StaticCerts: safe.New(staticMap),
|
||||
DefaultCertificate: defaultCert,
|
||||
CertCache: cache.New(1*time.Hour, 10*time.Minute),
|
||||
}
|
||||
|
||||
actual := store.GetAllDomains()
|
||||
assert.Equal(t, test.expectedDomains, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBestCertificate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
@ -116,15 +200,15 @@ func TestGetBestCertificate(t *testing.T) {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
staticMap := map[string]*tls.Certificate{}
|
||||
dynamicMap := map[string]*tls.Certificate{}
|
||||
|
||||
staticMap := map[string]*tls.Certificate{}
|
||||
if test.staticCert != "" {
|
||||
cert, err := loadTestCert(test.staticCert, test.uppercase)
|
||||
require.NoError(t, err)
|
||||
staticMap[strings.ToLower(test.staticCert)] = cert
|
||||
}
|
||||
|
||||
dynamicMap := map[string]*tls.Certificate{}
|
||||
if test.dynamicCert != "" {
|
||||
cert, err := loadTestCert(test.dynamicCert, test.uppercase)
|
||||
require.NoError(t, err)
|
||||
|
Loading…
Reference in New Issue
Block a user