mirror of
https://github.com/traefik/traefik.git
synced 2025-08-16 03:27:14 +02:00
Consider default cert domain in certificate store
Co-authored-by: Nicolas Mengin <nmengin.pro@gmail.com>
This commit is contained in:
parent
f4f62e7fb3
commit
a7dbcc282c
@ -2,6 +2,7 @@ package tls
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@ -47,6 +48,11 @@ func (c CertificateStore) GetAllDomains() []string {
|
|||||||
allCerts = append(allCerts, domains)
|
allCerts = append(allCerts, domains)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get Default certificate
|
||||||
|
if c.DefaultCertificate != nil {
|
||||||
|
allCerts = append(allCerts, getCertificateDomains(c.DefaultCertificate)...)
|
||||||
|
}
|
||||||
return allCerts
|
return allCerts
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,6 +121,27 @@ func (c CertificateStore) ResetCache() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCertificateDomains(cert *tls.Certificate) []string {
|
||||||
|
if cert == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
if len(x509Cert.Subject.CommonName) > 0 {
|
||||||
|
names = append(names, x509Cert.Subject.CommonName)
|
||||||
|
}
|
||||||
|
for _, san := range x509Cert.DNSNames {
|
||||||
|
names = append(names, san)
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
// MatchDomain return true if a domain match the cert domain
|
// MatchDomain return true if a domain match the cert domain
|
||||||
func MatchDomain(domain string, certDomain string) bool {
|
func MatchDomain(domain string, certDomain string) bool {
|
||||||
if domain == certDomain {
|
if domain == certDomain {
|
||||||
|
@ -13,6 +13,90 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestGetAllDomains(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
desc string
|
||||||
|
staticCert string
|
||||||
|
dynamicCert string
|
||||||
|
defaultCert string
|
||||||
|
expectedDomains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "Empty Store, returns no domains",
|
||||||
|
staticCert: "",
|
||||||
|
dynamicCert: "",
|
||||||
|
defaultCert: "",
|
||||||
|
expectedDomains: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Static cert domains",
|
||||||
|
staticCert: "snitest.com",
|
||||||
|
dynamicCert: "",
|
||||||
|
defaultCert: "",
|
||||||
|
expectedDomains: []string{"snitest.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Dynamic cert domains",
|
||||||
|
staticCert: "",
|
||||||
|
dynamicCert: "snitest.com",
|
||||||
|
defaultCert: "",
|
||||||
|
expectedDomains: []string{"snitest.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "Default cert domains",
|
||||||
|
staticCert: "",
|
||||||
|
dynamicCert: "",
|
||||||
|
defaultCert: "snitest.com",
|
||||||
|
expectedDomains: []string{"snitest.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "All domains",
|
||||||
|
staticCert: "www.snitest.com",
|
||||||
|
dynamicCert: "*.snitest.com",
|
||||||
|
defaultCert: "snitest.com",
|
||||||
|
expectedDomains: []string{"www.snitest.com", "*.snitest.com", "snitest.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range testCases {
|
||||||
|
test := test
|
||||||
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
staticMap := map[string]*tls.Certificate{}
|
||||||
|
if test.staticCert != "" {
|
||||||
|
cert, err := loadTestCert(test.staticCert, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
staticMap[strings.ToLower(test.staticCert)] = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
dynamicMap := map[string]*tls.Certificate{}
|
||||||
|
if test.dynamicCert != "" {
|
||||||
|
cert, err := loadTestCert(test.dynamicCert, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
dynamicMap[strings.ToLower(test.dynamicCert)] = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultCert *tls.Certificate
|
||||||
|
if test.defaultCert != "" {
|
||||||
|
cert, err := loadTestCert(test.defaultCert, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defaultCert = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
store := &CertificateStore{
|
||||||
|
DynamicCerts: safe.New(dynamicMap),
|
||||||
|
StaticCerts: safe.New(staticMap),
|
||||||
|
DefaultCertificate: defaultCert,
|
||||||
|
CertCache: cache.New(1*time.Hour, 10*time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := store.GetAllDomains()
|
||||||
|
assert.Equal(t, test.expectedDomains, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetBestCertificate(t *testing.T) {
|
func TestGetBestCertificate(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
@ -116,15 +200,15 @@ func TestGetBestCertificate(t *testing.T) {
|
|||||||
test := test
|
test := test
|
||||||
t.Run(test.desc, func(t *testing.T) {
|
t.Run(test.desc, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
staticMap := map[string]*tls.Certificate{}
|
|
||||||
dynamicMap := map[string]*tls.Certificate{}
|
|
||||||
|
|
||||||
|
staticMap := map[string]*tls.Certificate{}
|
||||||
if test.staticCert != "" {
|
if test.staticCert != "" {
|
||||||
cert, err := loadTestCert(test.staticCert, test.uppercase)
|
cert, err := loadTestCert(test.staticCert, test.uppercase)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
staticMap[strings.ToLower(test.staticCert)] = cert
|
staticMap[strings.ToLower(test.staticCert)] = cert
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dynamicMap := map[string]*tls.Certificate{}
|
||||||
if test.dynamicCert != "" {
|
if test.dynamicCert != "" {
|
||||||
cert, err := loadTestCert(test.dynamicCert, test.uppercase)
|
cert, err := loadTestCert(test.dynamicCert, test.uppercase)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
Loading…
Reference in New Issue
Block a user