vault/builtin/logical/pki/acme_state.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

683 lines
19 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package pki
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"path"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/go-secure-stdlib/nonceutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
const (
// How many bytes are in a token. Per RFC 8555 Section
// 8.3. HTTP Challenge and Section 11.3 Token Entropy:
//
// > token (required, string): A random value that uniquely identifies
// > the challenge. This value MUST have at least 128 bits of entropy.
tokenBytes = 128 / 8
// Path Prefixes
acmePathPrefix = "acme/"
acmeAccountPrefix = acmePathPrefix + "accounts/"
acmeThumbprintPrefix = acmePathPrefix + "account-thumbprints/"
acmeValidationPrefix = acmePathPrefix + "validations/"
acmeEabPrefix = acmePathPrefix + "eab/"
)
type acmeState struct {
nonces nonceutil.NonceService
validator *ACMEChallengeEngine
configDirty *atomic.Bool
_config sync.RWMutex
config acmeConfigEntry
}
type acmeThumbprint struct {
Kid string `json:"kid"`
Thumbprint string `json:"-"`
}
func NewACMEState() *acmeState {
state := &acmeState{
nonces: nonceutil.NewNonceService(),
validator: NewACMEChallengeEngine(),
configDirty: new(atomic.Bool),
}
// Config hasn't been loaded yet; mark dirty.
state.configDirty.Store(true)
return state
}
func (a *acmeState) Initialize(b *backend, sc *storageContext) error {
// Initialize the nonce service.
if err := a.nonces.Initialize(); err != nil {
return fmt.Errorf("failed to initialize the ACME nonce service: %w", err)
}
// Load the ACME config.
_, err := a.getConfigWithUpdate(sc)
if err != nil {
return fmt.Errorf("error initializing ACME engine: %w", err)
}
// Kick off our ACME challenge validation engine.
go a.validator.Run(b, a, sc)
// All good.
return nil
}
func (a *acmeState) markConfigDirty() {
a.configDirty.Store(true)
}
func (a *acmeState) reloadConfigIfRequired(sc *storageContext) error {
if !a.configDirty.Load() {
return nil
}
a._config.Lock()
defer a._config.Unlock()
if !a.configDirty.Load() {
// Someone beat us to grabbing the above write lock and already
// updated the config.
return nil
}
config, err := sc.getAcmeConfig()
if err != nil {
return fmt.Errorf("failed reading ACME config: %w", err)
}
a.config = *config
a.configDirty.Store(false)
return nil
}
func (a *acmeState) getConfigWithUpdate(sc *storageContext) (*acmeConfigEntry, error) {
if err := a.reloadConfigIfRequired(sc); err != nil {
return nil, err
}
a._config.RLock()
defer a._config.RUnlock()
configCopy := a.config
return &configCopy, nil
}
func (a *acmeState) getConfigWithForcedUpdate(sc *storageContext) (*acmeConfigEntry, error) {
a.markConfigDirty()
return a.getConfigWithUpdate(sc)
}
func (a *acmeState) writeConfig(sc *storageContext, config *acmeConfigEntry) (*acmeConfigEntry, error) {
a._config.Lock()
defer a._config.Unlock()
if err := sc.setAcmeConfig(config); err != nil {
a.markConfigDirty()
return nil, fmt.Errorf("failed writing ACME config: %w", err)
}
if config != nil {
a.config = *config
} else {
a.config = defaultAcmeConfig
}
return config, nil
}
func generateRandomBase64(srcBytes int) (string, error) {
data := make([]byte, 21)
if _, err := io.ReadFull(rand.Reader, data); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(data), nil
}
func (a *acmeState) GetNonce() (string, time.Time, error) {
return a.nonces.Get()
}
func (a *acmeState) RedeemNonce(nonce string) bool {
return a.nonces.Redeem(nonce)
}
func (a *acmeState) DoTidyNonces() {
a.nonces.Tidy()
}
type ACMEAccountStatus string
func (aas ACMEAccountStatus) String() string {
return string(aas)
}
const (
AccountStatusValid ACMEAccountStatus = "valid"
AccountStatusDeactivated ACMEAccountStatus = "deactivated"
AccountStatusRevoked ACMEAccountStatus = "revoked"
)
type acmeAccount struct {
KeyId string `json:"-"`
Status ACMEAccountStatus `json:"status"`
Contact []string `json:"contact"`
TermsOfServiceAgreed bool `json:"terms-of-service-agreed"`
Jwk []byte `json:"jwk"`
AcmeDirectory string `json:"acme-directory"`
AccountCreatedDate time.Time `json:"account-created-date"`
MaxCertExpiry time.Time `json:"account-max-cert-expiry"`
AccountRevokedDate time.Time `json:"account-revoked-date"`
Eab *eabType `json:"eab"`
}
type acmeOrder struct {
OrderId string `json:"-"`
AccountId string `json:"account-id"`
Status ACMEOrderStatusType `json:"status"`
Expires time.Time `json:"expires"`
Identifiers []*ACMEIdentifier `json:"identifiers"`
AuthorizationIds []string `json:"authorization-ids"`
CertificateSerialNumber string `json:"cert-serial-number"`
CertificateExpiry time.Time `json:"cert-expiry"`
// The actual issuer UUID that issued the certificate, blank if an order exists but no certificate was issued.
IssuerId issuerID `json:"issuer-id"`
}
func (o acmeOrder) getIdentifierDNSValues() []string {
var identifiers []string
for _, value := range o.Identifiers {
if value.Type == ACMEDNSIdentifier {
// Here, because of wildcard processing, we need to use the
// original value provided by the caller rather than the
// post-modification (trimmed '*.' prefix) value.
identifiers = append(identifiers, value.OriginalValue)
}
}
return identifiers
}
func (o acmeOrder) getIdentifierIPValues() []net.IP {
var identifiers []net.IP
for _, value := range o.Identifiers {
if value.Type == ACMEIPIdentifier {
identifiers = append(identifiers, net.ParseIP(value.Value))
}
}
return identifiers
}
func (a *acmeState) CreateAccount(ac *acmeContext, c *jwsCtx, contact []string, termsOfServiceAgreed bool, eab *eabType) (*acmeAccount, error) {
// Write out the thumbprint value/entry out first, if we get an error mid-way through
// this is easier to recover from. The new kid with the same existing public key
// will rewrite the thumbprint entry. This goes in hand with LoadAccountByKey that
// will return a nil, nil value if the referenced kid in a loaded thumbprint does not
// exist. This effectively makes this self-healing IF the end-user re-attempts the
// account creation with the same public key.
thumbprint, err := c.GetKeyThumbprint()
if err != nil {
return nil, fmt.Errorf("failed generating thumbprint: %w", err)
}
thumbPrint := &acmeThumbprint{
Kid: c.Kid,
Thumbprint: thumbprint,
}
thumbPrintEntry, err := logical.StorageEntryJSON(acmeThumbprintPrefix+thumbprint, thumbPrint)
if err != nil {
return nil, fmt.Errorf("error generating account thumbprint entry: %w", err)
}
if err = ac.sc.Storage.Put(ac.sc.Context, thumbPrintEntry); err != nil {
return nil, fmt.Errorf("error writing account thumbprint entry: %w", err)
}
// Now write out the main value that the thumbprint points too.
acct := &acmeAccount{
KeyId: c.Kid,
Contact: contact,
TermsOfServiceAgreed: termsOfServiceAgreed,
Jwk: c.Jwk,
Status: AccountStatusValid,
AcmeDirectory: ac.acmeDirectory,
AccountCreatedDate: time.Now(),
Eab: eab,
}
json, err := logical.StorageEntryJSON(acmeAccountPrefix+c.Kid, acct)
if err != nil {
return nil, fmt.Errorf("error creating account entry: %w", err)
}
if err := ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
return nil, fmt.Errorf("error writing account entry: %w", err)
}
return acct, nil
}
func (a *acmeState) UpdateAccount(sc *storageContext, acct *acmeAccount) error {
json, err := logical.StorageEntryJSON(acmeAccountPrefix+acct.KeyId, acct)
if err != nil {
return fmt.Errorf("error creating account entry: %w", err)
}
if err := sc.Storage.Put(sc.Context, json); err != nil {
return fmt.Errorf("error writing account entry: %w", err)
}
return nil
}
// LoadAccount will load the account object based on the passed in keyId field value
// otherwise will return an error if the account does not exist.
func (a *acmeState) LoadAccount(ac *acmeContext, keyId string) (*acmeAccount, error) {
entry, err := ac.sc.Storage.Get(ac.sc.Context, acmeAccountPrefix+keyId)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
if entry == nil {
return nil, fmt.Errorf("account not found: %w", ErrAccountDoesNotExist)
}
var acct acmeAccount
err = entry.DecodeJSON(&acct)
if err != nil {
return nil, fmt.Errorf("error decoding account: %w", err)
}
if acct.AcmeDirectory != ac.acmeDirectory {
return nil, fmt.Errorf("%w: account part of different ACME directory path", ErrMalformed)
}
acct.KeyId = keyId
return &acct, nil
}
// LoadAccountByKey will attempt to load the account based on a key thumbprint. If the thumbprint
// or kid is unknown a nil, nil will be returned.
func (a *acmeState) LoadAccountByKey(ac *acmeContext, keyThumbprint string) (*acmeAccount, error) {
thumbprintEntry, err := ac.sc.Storage.Get(ac.sc.Context, acmeThumbprintPrefix+keyThumbprint)
if err != nil {
return nil, fmt.Errorf("failed loading acme thumbprintEntry for key: %w", err)
}
if thumbprintEntry == nil {
return nil, nil
}
var thumbprint acmeThumbprint
err = thumbprintEntry.DecodeJSON(&thumbprint)
if err != nil {
return nil, fmt.Errorf("failed decoding thumbprint entry: %s: %w", keyThumbprint, err)
}
if len(thumbprint.Kid) == 0 {
return nil, fmt.Errorf("empty kid within thumbprint entry: %s", keyThumbprint)
}
acct, err := a.LoadAccount(ac, thumbprint.Kid)
if err != nil {
// If we fail to lookup the account that the thumbprint entry references, assume a bad
// write previously occurred in which we managed to write out the thumbprint but failed
// writing out the main account information.
if errors.Is(err, ErrAccountDoesNotExist) {
return nil, nil
}
return nil, err
}
return acct, nil
}
func (a *acmeState) LoadJWK(ac *acmeContext, keyId string) ([]byte, error) {
key, err := a.LoadAccount(ac, keyId)
if err != nil {
return nil, err
}
if len(key.Jwk) == 0 {
return nil, fmt.Errorf("malformed key entry lacks JWK")
}
return key.Jwk, nil
}
func (a *acmeState) LoadAuthorization(ac *acmeContext, userCtx *jwsCtx, authId string) (*ACMEAuthorization, error) {
if authId == "" {
return nil, fmt.Errorf("malformed authorization identifier")
}
authorizationPath := getAuthorizationPath(userCtx.Kid, authId)
authz, err := loadAuthorizationAtPath(ac.sc, authorizationPath)
if err != nil {
return nil, err
}
if userCtx.Kid != authz.AccountId {
return nil, ErrUnauthorized
}
return authz, nil
}
func loadAuthorizationAtPath(sc *storageContext, authorizationPath string) (*ACMEAuthorization, error) {
entry, err := sc.Storage.Get(sc.Context, authorizationPath)
if err != nil {
return nil, fmt.Errorf("error loading authorization: %w", err)
}
if entry == nil {
return nil, fmt.Errorf("authorization does not exist: %w", ErrMalformed)
}
var authz ACMEAuthorization
err = entry.DecodeJSON(&authz)
if err != nil {
return nil, fmt.Errorf("error decoding authorization: %w", err)
}
return &authz, nil
}
func (a *acmeState) SaveAuthorization(ac *acmeContext, authz *ACMEAuthorization) error {
path := getAuthorizationPath(authz.AccountId, authz.Id)
return saveAuthorizationAtPath(ac.sc, path, authz)
}
func saveAuthorizationAtPath(sc *storageContext, path string, authz *ACMEAuthorization) error {
if authz.Id == "" {
return fmt.Errorf("invalid authorization, missing id")
}
if authz.AccountId == "" {
return fmt.Errorf("invalid authorization, missing account id")
}
json, err := logical.StorageEntryJSON(path, authz)
if err != nil {
return fmt.Errorf("error creating authorization entry: %w", err)
}
if err = sc.Storage.Put(sc.Context, json); err != nil {
return fmt.Errorf("error writing authorization entry: %w", err)
}
return nil
}
func (a *acmeState) ParseRequestParams(ac *acmeContext, req *logical.Request, data *framework.FieldData) (*jwsCtx, map[string]interface{}, error) {
var c jwsCtx
var m map[string]interface{}
// Parse the key out.
rawJWKBase64, ok := data.GetOk("protected")
if !ok {
return nil, nil, fmt.Errorf("missing required field 'protected': %w", ErrMalformed)
}
jwkBase64 := rawJWKBase64.(string)
jwkBytes, err := base64.RawURLEncoding.DecodeString(jwkBase64)
if err != nil {
return nil, nil, fmt.Errorf("failed to base64 parse 'protected': %s: %w", err, ErrMalformed)
}
if err = c.UnmarshalOuterJwsJson(a, ac, jwkBytes); err != nil {
return nil, nil, fmt.Errorf("failed to json unmarshal 'protected': %w", err)
}
// Since we already parsed the header to verify the JWS context, we
// should read and redeem the nonce here too, to avoid doing any extra
// work if it is invalid.
if !a.RedeemNonce(c.Nonce) {
return nil, nil, fmt.Errorf("invalid or reused nonce: %w", ErrBadNonce)
}
// If the path is incorrect, reject the request.
//
// See RFC 8555 Section 6.4. Request URL Integrity:
//
// > As noted in Section 6.2, all ACME request objects carry a "url"
// > header parameter in their protected header. ... On receiving such
// > an object in an HTTP request, the server MUST compare the "url"
// > header parameter to the request URL. If the two do not match,
// > then the server MUST reject the request as unauthorized.
if len(c.Url) == 0 {
return nil, nil, fmt.Errorf("missing required parameter 'url' in 'protected': %w", ErrMalformed)
}
if ac.clusterUrl.JoinPath(req.Path).String() != c.Url {
return nil, nil, fmt.Errorf("invalid value for 'url' in 'protected': got '%v' expected '%v': %w", c.Url, ac.clusterUrl.JoinPath(req.Path).String(), ErrUnauthorized)
}
rawPayloadBase64, ok := data.GetOk("payload")
if !ok {
return nil, nil, fmt.Errorf("missing required field 'payload': %w", ErrMalformed)
}
payloadBase64 := rawPayloadBase64.(string)
rawSignatureBase64, ok := data.GetOk("signature")
if !ok {
return nil, nil, fmt.Errorf("missing required field 'signature': %w", ErrMalformed)
}
signatureBase64 := rawSignatureBase64.(string)
// go-jose only seems to support compact signature encodings.
compactSig := fmt.Sprintf("%v.%v.%v", jwkBase64, payloadBase64, signatureBase64)
m, err = c.VerifyJWS(compactSig)
if err != nil {
return nil, nil, fmt.Errorf("failed to verify signature: %w", err)
}
return &c, m, nil
}
func (a *acmeState) LoadOrder(ac *acmeContext, userCtx *jwsCtx, orderId string) (*acmeOrder, error) {
path := getOrderPath(userCtx.Kid, orderId)
entry, err := ac.sc.Storage.Get(ac.sc.Context, path)
if err != nil {
return nil, fmt.Errorf("error loading order: %w", err)
}
if entry == nil {
return nil, fmt.Errorf("order does not exist: %w", ErrMalformed)
}
var order acmeOrder
err = entry.DecodeJSON(&order)
if err != nil {
return nil, fmt.Errorf("error decoding order: %w", err)
}
if userCtx.Kid != order.AccountId {
return nil, ErrUnauthorized
}
order.OrderId = orderId
return &order, nil
}
func (a *acmeState) SaveOrder(ac *acmeContext, order *acmeOrder) error {
if order.OrderId == "" {
return fmt.Errorf("invalid order, missing order id")
}
if order.AccountId == "" {
return fmt.Errorf("invalid order, missing account id")
}
path := getOrderPath(order.AccountId, order.OrderId)
json, err := logical.StorageEntryJSON(path, order)
if err != nil {
return fmt.Errorf("error serializing order entry: %w", err)
}
if err = ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
return fmt.Errorf("error writing order entry: %w", err)
}
return nil
}
func (a *acmeState) ListOrderIds(sc *storageContext, accountId string) ([]string, error) {
accountOrderPrefixPath := acmeAccountPrefix + accountId + "/orders/"
rawOrderIds, err := sc.Storage.List(sc.Context, accountOrderPrefixPath)
if err != nil {
return nil, fmt.Errorf("failed listing order ids for account %s: %w", accountId, err)
}
orderIds := []string{}
for _, order := range rawOrderIds {
if strings.HasSuffix(order, "/") {
// skip any folders we might have for some reason
continue
}
orderIds = append(orderIds, order)
}
return orderIds, nil
}
type acmeCertEntry struct {
Serial string `json:"-"`
Account string `json:"-"`
Order string `json:"order"`
}
func (a *acmeState) TrackIssuedCert(ac *acmeContext, accountId string, serial string, orderId string) error {
path := getAcmeSerialToAccountTrackerPath(accountId, serial)
entry := acmeCertEntry{
Order: orderId,
}
json, err := logical.StorageEntryJSON(path, &entry)
if err != nil {
return fmt.Errorf("error serializing acme cert entry: %w", err)
}
if err = ac.sc.Storage.Put(ac.sc.Context, json); err != nil {
return fmt.Errorf("error writing acme cert entry: %w", err)
}
return nil
}
func (a *acmeState) GetIssuedCert(ac *acmeContext, accountId string, serial string) (*acmeCertEntry, error) {
path := acmeAccountPrefix + accountId + "/certs/" + normalizeSerial(serial)
entry, err := ac.sc.Storage.Get(ac.sc.Context, path)
if err != nil {
return nil, fmt.Errorf("error loading acme cert entry: %w", err)
}
if entry == nil {
return nil, fmt.Errorf("no certificate with this serial was issued for this account")
}
var cert acmeCertEntry
err = entry.DecodeJSON(&cert)
if err != nil {
return nil, fmt.Errorf("error decoding acme cert entry: %w", err)
}
cert.Serial = denormalizeSerial(serial)
cert.Account = accountId
return &cert, nil
}
func (a *acmeState) SaveEab(sc *storageContext, eab *eabType) error {
json, err := logical.StorageEntryJSON(path.Join(acmeEabPrefix, eab.KeyID), eab)
if err != nil {
return err
}
return sc.Storage.Put(sc.Context, json)
}
func (a *acmeState) LoadEab(sc *storageContext, eabKid string) (*eabType, error) {
rawEntry, err := sc.Storage.Get(sc.Context, path.Join(acmeEabPrefix, eabKid))
if err != nil {
return nil, err
}
if rawEntry == nil {
return nil, fmt.Errorf("%w: no eab found for kid %s", ErrStorageItemNotFound, eabKid)
}
var eab eabType
err = rawEntry.DecodeJSON(&eab)
if err != nil {
return nil, err
}
eab.KeyID = eabKid
return &eab, nil
}
func (a *acmeState) DeleteEab(sc *storageContext, eabKid string) (bool, error) {
rawEntry, err := sc.Storage.Get(sc.Context, path.Join(acmeEabPrefix, eabKid))
if err != nil {
return false, err
}
if rawEntry == nil {
return false, nil
}
err = sc.Storage.Delete(sc.Context, path.Join(acmeEabPrefix, eabKid))
if err != nil {
return false, err
}
return true, nil
}
func (a *acmeState) ListEabIds(sc *storageContext) ([]string, error) {
entries, err := sc.Storage.List(sc.Context, acmeEabPrefix)
if err != nil {
return nil, err
}
var ids []string
for _, entry := range entries {
if strings.HasSuffix(entry, "/") {
continue
}
ids = append(ids, entry)
}
return ids, nil
}
func getAcmeSerialToAccountTrackerPath(accountId string, serial string) string {
return acmeAccountPrefix + accountId + "/certs/" + normalizeSerial(serial)
}
func getAuthorizationPath(accountId string, authId string) string {
return acmeAccountPrefix + accountId + "/authorizations/" + authId
}
func getOrderPath(accountId string, orderId string) string {
return acmeAccountPrefix + accountId + "/orders/" + orderId
}
func getACMEToken() (string, error) {
return generateRandomBase64(tokenBytes)
}