mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-14 18:47:01 +02:00
* Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License. Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUS-1.1 * Fix test that expected exact offset on hcl file --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com> Co-authored-by: Sarah Thompson <sthompson@hashicorp.com> Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
323 lines
14 KiB
Go
323 lines
14 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package pki
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/json"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/acme"
|
|
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/builtin/logical/pki/dnstest"
|
|
"github.com/hashicorp/vault/helper/constants"
|
|
"github.com/hashicorp/vault/helper/timeutil"
|
|
"github.com/hashicorp/vault/vault"
|
|
"github.com/hashicorp/vault/vault/activity"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestACMEBilling is a basic test that will validate client counts created via ACME workflows.
|
|
func TestACMEBilling(t *testing.T) {
|
|
t.Parallel()
|
|
timeutil.SkipAtEndOfMonth(t)
|
|
|
|
cluster, client, _ := setupAcmeBackend(t)
|
|
defer cluster.Cleanup()
|
|
|
|
dns := dnstest.SetupResolver(t, "dadgarcorp.com")
|
|
defer dns.Cleanup()
|
|
|
|
// Enable additional mounts.
|
|
setupAcmeBackendOnClusterAtPath(t, cluster, client, "pki2")
|
|
setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns1/pki")
|
|
setupAcmeBackendOnClusterAtPath(t, cluster, client, "ns2/pki")
|
|
|
|
// Enable custom DNS resolver for testing.
|
|
for _, mount := range []string{"pki", "pki2", "ns1/pki", "ns2/pki"} {
|
|
_, err := client.Logical().Write(mount+"/config/acme", map[string]interface{}{
|
|
"dns_resolver": dns.GetLocalAddr(),
|
|
})
|
|
require.NoError(t, err, "failed to set local dns resolver address for testing on mount: "+mount)
|
|
}
|
|
|
|
// Enable client counting.
|
|
_, err := client.Logical().Write("/sys/internal/counters/config", map[string]interface{}{
|
|
"enabled": "enable",
|
|
})
|
|
require.NoError(t, err, "failed to enable client counting")
|
|
|
|
// Setup ACME clients. We refresh account keys each time for consistency.
|
|
acmeClientPKI := getAcmeClientForCluster(t, cluster, "/v1/pki/acme/", nil)
|
|
acmeClientPKI2 := getAcmeClientForCluster(t, cluster, "/v1/pki2/acme/", nil)
|
|
acmeClientPKINS1 := getAcmeClientForCluster(t, cluster, "/v1/ns1/pki/acme/", nil)
|
|
acmeClientPKINS2 := getAcmeClientForCluster(t, cluster, "/v1/ns2/pki/acme/", nil)
|
|
|
|
// Get our initial count.
|
|
expectedCount := validateClientCount(t, client, "", -1, "initial fetch")
|
|
|
|
// Unique identifier: should increase by one.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")
|
|
|
|
// Different identifier; should increase by one.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"example.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")
|
|
|
|
// While same identifiers, used together and so thus are unique; increase by one.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"example.dadgarcorp.com", "dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "pki", expectedCount+1, "new certificate")
|
|
|
|
// Same identifiers in different order are not unique; keep the same.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI, []string{"dadgarcorp.com", "example.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "pki", expectedCount, "different order; same identifiers")
|
|
|
|
// Using a different mount shouldn't affect counts.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI2, []string{"dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "", expectedCount, "different mount; same identifiers")
|
|
|
|
// But using a different identifier should.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI2, []string{"pki2.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "pki2", expectedCount+1, "different mount with different identifiers")
|
|
|
|
// A new identifier in a unique namespace will affect results.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKINS1, []string{"unique.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "ns1/pki", expectedCount+1, "unique identifier in a namespace")
|
|
|
|
// But in a different namespace with the existing identifier will not.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKINS2, []string{"unique.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier in a namespace")
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKI2, []string{"unique.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "", expectedCount, "existing identifier outside of a namespace")
|
|
|
|
// Creating a unique identifier in a namespace with a mount with the
|
|
// same name as another namespace should increase counts as well.
|
|
doACMEForDomainWithDNS(t, dns, acmeClientPKINS2, []string{"very-unique.dadgarcorp.com"})
|
|
expectedCount = validateClientCount(t, client, "ns2/pki", expectedCount+1, "unique identifier in a different namespace")
|
|
|
|
// Check the current fragment
|
|
fragment := cluster.Cores[0].Core.ResetActivityLog()[0]
|
|
if fragment == nil {
|
|
t.Fatal("no fragment created")
|
|
}
|
|
validateAcmeClientTypes(t, fragment, expectedCount)
|
|
}
|
|
|
|
func validateAcmeClientTypes(t *testing.T, fragment *activity.LogFragment, expectedCount int64) {
|
|
t.Helper()
|
|
if int64(len(fragment.Clients)) != expectedCount {
|
|
t.Fatalf("bad number of entities, expected %v: got %v, entities are: %v", expectedCount, len(fragment.Clients), fragment.Clients)
|
|
}
|
|
|
|
for _, ac := range fragment.Clients {
|
|
if ac.ClientType != vault.ACMEActivityType {
|
|
t.Fatalf("Couldn't find expected '%v' client_type in %v", vault.ACMEActivityType, fragment.Clients)
|
|
}
|
|
}
|
|
}
|
|
|
|
func validateClientCount(t *testing.T, client *api.Client, mount string, expected int64, message string) int64 {
|
|
resp, err := client.Logical().Read("/sys/internal/counters/activity/monthly")
|
|
require.NoError(t, err, "failed to fetch client count values")
|
|
t.Logf("got client count numbers: %v", resp)
|
|
|
|
require.NotNil(t, resp)
|
|
require.NotNil(t, resp.Data)
|
|
require.Contains(t, resp.Data, "non_entity_clients")
|
|
require.Contains(t, resp.Data, "months")
|
|
|
|
rawCount := resp.Data["non_entity_clients"].(json.Number)
|
|
count, err := rawCount.Int64()
|
|
require.NoError(t, err, "failed to parse number as int64: "+rawCount.String())
|
|
|
|
if expected != -1 {
|
|
require.Equal(t, expected, count, "value of client counts did not match expectations: "+message)
|
|
}
|
|
|
|
if mount == "" {
|
|
return count
|
|
}
|
|
|
|
months := resp.Data["months"].([]interface{})
|
|
if len(months) > 1 {
|
|
t.Fatalf("running across a month boundary despite using SkipAtEndOfMonth(...); rerun test from start fully in the next month instead")
|
|
}
|
|
|
|
require.Equal(t, 1, len(months), "expected only a single month when running this test")
|
|
|
|
monthlyInfo := months[0].(map[string]interface{})
|
|
|
|
// Validate this month's aggregate counts match the overall value.
|
|
require.Contains(t, monthlyInfo, "counts", "expected monthly info to contain a count key")
|
|
monthlyCounts := monthlyInfo["counts"].(map[string]interface{})
|
|
require.Contains(t, monthlyCounts, "non_entity_clients", "expected month[0].counts to contain a non_entity_clients key")
|
|
monthlyCountNonEntityRaw := monthlyCounts["non_entity_clients"].(json.Number)
|
|
monthlyCountNonEntity, err := monthlyCountNonEntityRaw.Int64()
|
|
require.NoError(t, err, "failed to parse number as int64: "+monthlyCountNonEntityRaw.String())
|
|
require.Equal(t, count, monthlyCountNonEntity, "expected equal values for non entity client counts")
|
|
|
|
// Validate this mount's namespace is included in the namespaces list,
|
|
// if this is enterprise. Otherwise, if its OSS or we don't have a
|
|
// namespace, we default to the value root.
|
|
mountNamespace := ""
|
|
mountPath := mount + "/"
|
|
if constants.IsEnterprise && strings.Contains(mount, "/") {
|
|
pieces := strings.Split(mount, "/")
|
|
require.Equal(t, 2, len(pieces), "we do not support nested namespaces in this test")
|
|
mountNamespace = pieces[0] + "/"
|
|
mountPath = pieces[1] + "/"
|
|
}
|
|
|
|
require.Contains(t, monthlyInfo, "namespaces", "expected monthly info to contain a namespaces key")
|
|
monthlyNamespaces := monthlyInfo["namespaces"].([]interface{})
|
|
foundNamespace := false
|
|
for index, namespaceRaw := range monthlyNamespaces {
|
|
namespace := namespaceRaw.(map[string]interface{})
|
|
require.Contains(t, namespace, "namespace_path", "expected monthly.namespaces[%v] to contain a namespace_path key", index)
|
|
namespacePath := namespace["namespace_path"].(string)
|
|
|
|
if namespacePath != mountNamespace {
|
|
t.Logf("skipping non-matching namespace %v: %v != %v / %v", index, namespacePath, mountNamespace, namespace)
|
|
continue
|
|
}
|
|
|
|
foundNamespace = true
|
|
|
|
// This namespace must have a non-empty aggregate non-entity count.
|
|
require.Contains(t, namespace, "counts", "expected monthly.namespaces[%v] to contain a counts key", index)
|
|
namespaceCounts := namespace["counts"].(map[string]interface{})
|
|
require.Contains(t, namespaceCounts, "non_entity_clients", "expected namespace counts to contain a non_entity_clients key")
|
|
namespaceCountNonEntityRaw := namespaceCounts["non_entity_clients"].(json.Number)
|
|
namespaceCountNonEntity, err := namespaceCountNonEntityRaw.Int64()
|
|
require.NoError(t, err, "failed to parse number as int64: "+namespaceCountNonEntityRaw.String())
|
|
require.Greater(t, namespaceCountNonEntity, int64(0), "expected at least one non-entity client count value in the namespace")
|
|
|
|
require.Contains(t, namespace, "mounts", "expected monthly.namespaces[%v] to contain a mounts key", index)
|
|
namespaceMounts := namespace["mounts"].([]interface{})
|
|
foundMount := false
|
|
for mountIndex, mountRaw := range namespaceMounts {
|
|
mountInfo := mountRaw.(map[string]interface{})
|
|
require.Contains(t, mountInfo, "mount_path", "expected monthly.namespaces[%v].mounts[%v] to contain a mount_path key", index, mountIndex)
|
|
mountInfoPath := mountInfo["mount_path"].(string)
|
|
if mountPath != mountInfoPath {
|
|
t.Logf("skipping non-matching mount path %v in namespace %v: %v != %v / %v of %v", mountIndex, index, mountPath, mountInfoPath, mountInfo, namespace)
|
|
continue
|
|
}
|
|
|
|
foundMount = true
|
|
|
|
// This mount must also have a non-empty non-entity client count.
|
|
require.Contains(t, mountInfo, "counts", "expected monthly.namespaces[%v].mounts[%v] to contain a counts key", index, mountIndex)
|
|
mountCounts := mountInfo["counts"].(map[string]interface{})
|
|
require.Contains(t, mountCounts, "non_entity_clients", "expected mount counts to contain a non_entity_clients key")
|
|
mountCountNonEntityRaw := mountCounts["non_entity_clients"].(json.Number)
|
|
mountCountNonEntity, err := mountCountNonEntityRaw.Int64()
|
|
require.NoError(t, err, "failed to parse number as int64: "+mountCountNonEntityRaw.String())
|
|
require.Greater(t, mountCountNonEntity, int64(0), "expected at least one non-entity client count value in the mount")
|
|
}
|
|
|
|
require.True(t, foundMount, "expected to find the mount "+mountPath+" in the list of mounts for namespace, but did not")
|
|
}
|
|
|
|
require.True(t, foundNamespace, "expected to find the namespace "+mountNamespace+" in the list of namespaces, but did not")
|
|
|
|
return count
|
|
}
|
|
|
|
func doACMEForDomainWithDNS(t *testing.T, dns *dnstest.TestServer, acmeClient *acme.Client, domains []string) *x509.Certificate {
|
|
cr := &x509.CertificateRequest{
|
|
Subject: pkix.Name{CommonName: domains[0]},
|
|
DNSNames: domains,
|
|
}
|
|
|
|
return doACMEForCSRWithDNS(t, dns, acmeClient, domains, cr)
|
|
}
|
|
|
|
func doACMEForCSRWithDNS(t *testing.T, dns *dnstest.TestServer, acmeClient *acme.Client, domains []string, cr *x509.CertificateRequest) *x509.Certificate {
|
|
accountKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
require.NoError(t, err, "failed to generate account key")
|
|
acmeClient.Key = accountKey
|
|
|
|
testCtx, cancelFunc := context.WithTimeout(context.Background(), 2*time.Minute)
|
|
defer cancelFunc()
|
|
|
|
// Register the client.
|
|
_, err = acmeClient.Register(testCtx, &acme.Account{Contact: []string{"mailto:ipsans@dadgarcorp.com"}}, func(tosURL string) bool { return true })
|
|
require.NoError(t, err, "failed registering account")
|
|
|
|
// Create the Order
|
|
var orderIdentifiers []acme.AuthzID
|
|
for _, domain := range domains {
|
|
orderIdentifiers = append(orderIdentifiers, acme.AuthzID{Type: "dns", Value: domain})
|
|
}
|
|
order, err := acmeClient.AuthorizeOrder(testCtx, orderIdentifiers)
|
|
require.NoError(t, err, "failed creating ACME order")
|
|
|
|
// Fetch its authorizations.
|
|
var auths []*acme.Authorization
|
|
for _, authUrl := range order.AuthzURLs {
|
|
authorization, err := acmeClient.GetAuthorization(testCtx, authUrl)
|
|
require.NoError(t, err, "failed to lookup authorization at url: %s", authUrl)
|
|
auths = append(auths, authorization)
|
|
}
|
|
|
|
// For each dns-01 challenge, place the record in the associated DNS resolver.
|
|
var challengesToAccept []*acme.Challenge
|
|
for _, auth := range auths {
|
|
for _, challenge := range auth.Challenges {
|
|
if challenge.Status != acme.StatusPending {
|
|
t.Logf("ignoring challenge not in status pending: %v", challenge)
|
|
continue
|
|
}
|
|
|
|
if challenge.Type == "dns-01" {
|
|
challengeBody, err := acmeClient.DNS01ChallengeRecord(challenge.Token)
|
|
require.NoError(t, err, "failed generating challenge response")
|
|
|
|
dns.AddRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody)
|
|
defer dns.RemoveRecord("_acme-challenge."+auth.Identifier.Value, "TXT", challengeBody)
|
|
|
|
require.NoError(t, err, "failed setting DNS record")
|
|
|
|
challengesToAccept = append(challengesToAccept, challenge)
|
|
}
|
|
}
|
|
}
|
|
|
|
dns.PushConfig()
|
|
require.GreaterOrEqual(t, len(challengesToAccept), 1, "Need at least one challenge, got none")
|
|
|
|
// Tell the ACME server, that they can now validate those challenges.
|
|
for _, challenge := range challengesToAccept {
|
|
_, err = acmeClient.Accept(testCtx, challenge)
|
|
require.NoError(t, err, "failed to accept challenge: %v", challenge)
|
|
}
|
|
|
|
// Wait for the order/challenges to be validated.
|
|
_, err = acmeClient.WaitOrder(testCtx, order.URI)
|
|
require.NoError(t, err, "failed waiting for order to be ready")
|
|
|
|
// Create/sign the CSR and ask ACME server to sign it returning us the final certificate
|
|
csrKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
csr, err := x509.CreateCertificateRequest(rand.Reader, cr, csrKey)
|
|
require.NoError(t, err, "failed generating csr")
|
|
|
|
certs, _, err := acmeClient.CreateOrderCert(testCtx, order.FinalizeURL, csr, false)
|
|
require.NoError(t, err, "failed to get a certificate back from ACME")
|
|
|
|
acmeCert, err := x509.ParseCertificate(certs[0])
|
|
require.NoError(t, err, "failed parsing acme cert bytes")
|
|
|
|
return acmeCert
|
|
}
|