vault/sdk/helper/keysutil/policy.go
Alexander Scheel daf72aa427
Fix transit import/export of hmac-only keys (#20864)
* Fix export of HMAC typed keys

When initially implemented, exporting HMAC keys resulted in returning
the unused, internal HMACKey value rather than the main Key value that
is used for HMAC operations.

This is a breaking change.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Consistently handle HMAC keys in keysutil

When generating HMAC-typed keys, set HMACKey = Key consistently, to
allow users of HMAC-typed keys to use them backwards compatibly.

Notably, this could discard the (unused) HMACKey field set today.

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Add test proving export of HMAC keys work

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

* Add changelog entry

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>

---------

Signed-off-by: Alexander Scheel <alex.scheel@hashicorp.com>
2023-05-31 18:04:08 +00:00

2396 lines
68 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package keysutil
import (
"bytes"
"context"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"hash"
"io"
"math/big"
"path"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/hkdf"
"github.com/hashicorp/errwrap"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/helper/jsonutil"
"github.com/hashicorp/vault/sdk/helper/kdf"
"github.com/hashicorp/vault/sdk/logical"
"github.com/google/tink/go/kwp/subtle"
)
// Careful with iota; don't put anything before it in this const block because
// we need the default of zero to be the old-style KDF
const (
	// Kdf_hmac_sha256_counter selects the counter-mode KDF built on the
	// HMAC-SHA256 PRF from the kdf helper package. It must stay at zero so
	// that policies stored without an explicit KDF keep using it.
	Kdf_hmac_sha256_counter = iota // built-in helper
	// Kdf_hkdf_sha256 selects HKDF with SHA-256.
	Kdf_hkdf_sha256 // golang.org/x/crypto/hkdf

	// HmacMinKeySize is 256 bits expressed in bytes.
	// NOTE(review): presumably the lower bound on HMAC key length; the
	// enforcing code is not in this part of the file.
	HmacMinKeySize = 256 / 8
	// HmacMaxKeySize is 4096 bits expressed in bytes.
	// NOTE(review): presumably the upper bound on HMAC key length.
	HmacMaxKeySize = 4096 / 8
)
// Or this one...we need the default of zero to be the original AES256-GCM96
//
// These values are what Policy.Type holds and are serialized in the policy's
// "type" JSON field, so new key types must only ever be appended; reordering
// or inserting would change the meaning of already-stored policies.
const (
	KeyType_AES256_GCM96 = iota
	KeyType_ECDSA_P256
	KeyType_ED25519
	KeyType_RSA2048
	KeyType_RSA4096
	KeyType_ChaCha20_Poly1305
	KeyType_ECDSA_P384
	KeyType_ECDSA_P521
	KeyType_AES128_GCM96
	KeyType_RSA3072
	KeyType_MANAGED_KEY
	KeyType_HMAC
)
const (
	// ErrTooOld is returned when the ciphertext's or signature's key version
	// is too old, i.e. below the policy's minimum allowed version.
	ErrTooOld = "ciphertext or signature version is disallowed by policy (too old)"
	// DefaultVersionTemplate is used when no version template is provided.
	DefaultVersionTemplate = "vault:v{{version}}:"
)
// AEADFactory is an optional factory callers may pass to the *WithFactory
// methods to supply the cipher.AEAD used for symmetric operations.
type AEADFactory interface {
	// GetAEAD returns the AEAD to use for the given iv.
	GetAEAD(iv []byte) (cipher.AEAD, error)
}
// AssociatedDataFactory is an optional factory callers may pass to the
// *WithFactory methods to supply associated (additional authenticated) data.
type AssociatedDataFactory interface {
	// GetAssociatedData returns the associated data, or an error if it cannot
	// be produced.
	GetAssociatedData() ([]byte, error)
}
// ManagedKeyFactory is a factory callers pass to the *WithFactory methods when
// the policy's key type is KeyType_MANAGED_KEY; decryption with a managed key
// fails if no such factory is supplied.
type ManagedKeyFactory interface {
	// GetManagedKeyParameters returns the parameters needed to use the
	// managed key.
	GetManagedKeyParameters() ManagedKeyParameters
}
// RestoreInfo records information about a restore action taken on a policy.
type RestoreInfo struct {
	// Time of the restore. NOTE(review): presumably when the restore was
	// performed; the code that sets it is not in this part of the file.
	Time time.Time `json:"time"`
	// Version associated with the restore. NOTE(review): presumably the
	// policy's key version at restore time; confirm against the writer.
	Version int `json:"version"`
}
// BackupInfo records information about a backup action taken on a policy.
type BackupInfo struct {
	// Time of the backup. NOTE(review): presumably when the backup was
	// taken; the code that sets it is not in this part of the file.
	Time time.Time `json:"time"`
	// Version associated with the backup. NOTE(review): presumably the
	// policy's key version at backup time; confirm against the writer.
	Version int `json:"version"`
}
// SigningOptions bundles the optional parameters accepted by
// Policy.SignWithOptions.
type SigningOptions struct {
	// HashAlgorithm is looked up in CryptoHashMap for RSA signing.
	HashAlgorithm HashType
	// Marshaling selects the ECDSA signature encoding (ASN.1 or JWS).
	Marshaling MarshalingType
	// SaltLength is the RSA-PSS salt length; Policy.Sign passes
	// rsa.PSSSaltLengthAuto.
	SaltLength int
	// SigAlgorithm selects the RSA signature scheme; an empty string
	// defaults to "pss".
	SigAlgorithm string
	// ManagedKeyParams carries parameters for managed keys.
	// NOTE(review): its use is outside the visible part of the file.
	ManagedKeyParams ManagedKeyParameters
}
// SigningResult is the value returned by the signing methods.
type SigningResult struct {
	// Signature is the produced signature.
	// NOTE(review): presumably version-prefixed and encoded; the code that
	// fills it is outside the visible part of the file.
	Signature string
	// PublicKey is populated for derived Ed25519 keys, where the public key
	// depends on the caller-supplied context.
	PublicKey []byte
}
// ecdsaSignature is the (R, S) pair of an ECDSA signature; it is the structure
// passed to asn1.Marshal when ASN.1 marshaling is requested.
type ecdsaSignature struct {
	R, S *big.Int
}
// KeyType identifies the kind of key a policy holds; its values are the
// KeyType_* constants.
type KeyType int
// EncryptionSupported reports whether keys of this type can be used to
// encrypt data.
func (kt KeyType) EncryptionSupported() bool {
	switch kt {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// Symmetric AEAD key types.
		return true
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		// RSA key types.
		return true
	case KeyType_MANAGED_KEY:
		return true
	default:
		return false
	}
}
// DecryptionSupported reports whether keys of this type can be used to
// decrypt data.
func (kt KeyType) DecryptionSupported() bool {
	switch kt {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// Symmetric AEAD key types.
		return true
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		// RSA key types.
		return true
	case KeyType_MANAGED_KEY:
		return true
	default:
		return false
	}
}
// SigningSupported reports whether keys of this type can be used to produce
// signatures.
func (kt KeyType) SigningSupported() bool {
	switch kt {
	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
		// Elliptic-curve (ECDSA) key types.
		return true
	case KeyType_ED25519:
		return true
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		// RSA key types.
		return true
	case KeyType_MANAGED_KEY:
		return true
	default:
		return false
	}
}
// HashSignatureInput reports whether this key type is one of the ECDSA or RSA
// types, i.e. the types for which this method returns true. Ed25519 is
// deliberately absent: it performs its own hashing and must not be pre-hashed.
func (kt KeyType) HashSignatureInput() bool {
	switch kt {
	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
		// Elliptic-curve (ECDSA) key types.
		return true
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		// RSA key types.
		return true
	default:
		return false
	}
}
// DerivationSupported reports whether per-context key derivation is available
// for this key type. Policy.DeriveKey rejects any type for which this returns
// false.
func (kt KeyType) DerivationSupported() bool {
	switch kt {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// Symmetric AEAD key types.
		return true
	case KeyType_ED25519:
		return true
	default:
		return false
	}
}
// AssociatedDataSupported reports whether operations with this key type accept
// associated (additional authenticated) data.
func (kt KeyType) AssociatedDataSupported() bool {
	switch kt {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// Symmetric AEAD key types.
		return true
	case KeyType_MANAGED_KEY:
		return true
	default:
		return false
	}
}
// ImportPublicKeySupported reports whether this is one of the asymmetric key
// types (RSA, ECDSA, Ed25519) for which the method returns true.
func (kt KeyType) ImportPublicKeySupported() bool {
	switch kt {
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		// RSA key types.
		return true
	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
		// Elliptic-curve (ECDSA) key types.
		return true
	case KeyType_ED25519:
		return true
	default:
		return false
	}
}
// String returns the textual name of the key type, or "[unknown]" for a value
// that is not one of the KeyType_* constants.
func (kt KeyType) String() string {
	name := "[unknown]"
	switch kt {
	// Symmetric AEAD key types.
	case KeyType_AES256_GCM96:
		name = "aes256-gcm96"
	case KeyType_AES128_GCM96:
		name = "aes128-gcm96"
	case KeyType_ChaCha20_Poly1305:
		name = "chacha20-poly1305"

	// Elliptic-curve key types.
	case KeyType_ECDSA_P256:
		name = "ecdsa-p256"
	case KeyType_ECDSA_P384:
		name = "ecdsa-p384"
	case KeyType_ECDSA_P521:
		name = "ecdsa-p521"
	case KeyType_ED25519:
		name = "ed25519"

	// RSA key types.
	case KeyType_RSA2048:
		name = "rsa-2048"
	case KeyType_RSA3072:
		name = "rsa-3072"
	case KeyType_RSA4096:
		name = "rsa-4096"

	// Everything else.
	case KeyType_MANAGED_KEY:
		name = "managed_key"
	case KeyType_HMAC:
		name = "hmac"
	}
	return name
}
// KeyData bundles a policy together with its archived keys.
// NOTE(review): presumably the payload used for backup/restore; the code that
// uses it is not in this part of the file.
type KeyData struct {
	Policy       *Policy       `json:"policy"`
	ArchivedKeys *archivedKeys `json:"archived_keys"`
}
// KeyEntry stores the key and metadata for a single key version.
type KeyEntry struct {
	// AES or some other kind that is a pure byte slice like ED25519
	Key []byte `json:"key"`

	// Key used for HMAC functions
	HMACKey []byte `json:"hmac_key"`

	// Time of creation
	CreationTime time.Time `json:"time"`

	// EC_X and EC_Y are the public point coordinates and EC_D the private
	// scalar of an ECDSA key; signing assembles an ecdsa.PrivateKey from them.
	EC_X *big.Int `json:"ec_x"`
	EC_Y *big.Int `json:"ec_y"`
	EC_D *big.Int `json:"ec_d"`

	// RSAKey is the RSA private key; it is nil when this version has no
	// private part. RSAPublicKey is the corresponding public key.
	RSAKey       *rsa.PrivateKey `json:"rsa_key"`
	RSAPublicKey *rsa.PublicKey  `json:"rsa_public_key"`

	// The public key in an appropriate format for the type of key
	FormattedPublicKey string `json:"public_key"`

	// If convergent is enabled, the version (falling back to what's in the
	// policy)
	ConvergentVersion int `json:"convergent_version"`

	// This is deprecated (but still filled) in favor of the value above which
	// is more precise
	DeprecatedCreationTime int64 `json:"creation_time"`

	// ManagedKeyUUID identifies the managed key backing this entry.
	// NOTE(review): presumably only set for KeyType_MANAGED_KEY policies.
	ManagedKeyUUID string `json:"managed_key_id,omitempty"`
}
// IsPrivateKeyMissing reports whether the entry carries no private key
// material at all: no RSA private key, no EC private scalar, and an empty raw
// key.
func (ke *KeyEntry) IsPrivateKeyMissing() bool {
	hasRSA := ke.RSAKey != nil
	hasEC := ke.EC_D != nil
	hasRaw := len(ke.Key) != 0
	return !hasRSA && !hasEC && !hasRaw
}
// deprecatedKeyEntryMap is used to allow JSON marshal/unmarshal of key entries
// indexed by integer version; its methods convert the integer keys to and
// from their decimal string form.
type deprecatedKeyEntryMap map[int]KeyEntry
// MarshalJSON implements json.Marshaler by encoding the map as a JSON object
// whose keys are the decimal string form of each integer version.
func (kem deprecatedKeyEntryMap) MarshalJSON() ([]byte, error) {
	byVersion := make(map[string]KeyEntry, len(kem))
	for version, entry := range kem {
		byVersion[strconv.Itoa(version)] = entry
	}
	return json.Marshal(byVersion)
}
// UnmarshalJSON implements json.Unmarshaler. It decodes a JSON object keyed by
// decimal version strings and stores each entry in the receiver under the
// parsed integer version. Entries are added to the existing map, so the
// receiver must be non-nil; a key that is not an integer aborts with the
// strconv error.
func (kem deprecatedKeyEntryMap) UnmarshalJSON(data []byte) error {
	byVersion := make(map[string]KeyEntry)
	err := jsonutil.DecodeJSON(data, &byVersion)
	if err != nil {
		return err
	}

	for rawVersion, entry := range byVersion {
		version, convErr := strconv.Atoi(rawVersion)
		if convErr != nil {
			return convErr
		}
		kem[version] = entry
	}
	return nil
}
// keyEntryMap is used to allow JSON marshal/unmarshal. It maps the decimal
// string form of a key version (strconv.Itoa) to that version's KeyEntry.
// Entries are stored by value, so a looked-up or ranged-over entry is a copy
// and must be written back to the map for a modification to take effect.
type keyEntryMap map[string]KeyEntry
// PolicyConfig is used to create a new policy
type PolicyConfig struct {
	// The name of the policy
	Name string `json:"name"`

	// The type of key
	Type KeyType

	// Derived keys MUST provide a context and the master underlying key is
	// never used.
	Derived bool
	// KDF selects the key derivation function; its values are the Kdf_*
	// constants.
	KDF int
	// ConvergentEncryption is copied to the policy's field of the same name.
	ConvergentEncryption bool

	// Whether the key is exportable
	Exportable bool

	// Whether the key is allowed to be deleted
	DeletionAllowed bool

	// AllowPlaintextBackup allows taking backup of the policy in plaintext
	AllowPlaintextBackup bool

	// VersionTemplate is used to prefix the ciphertext with information about
	// the key version. It must include {{version}} and a delimiter between the
	// version prefix and the ciphertext.
	VersionTemplate string

	// StoragePrefix is used to add a prefix when storing and retrieving the
	// policy object.
	StoragePrefix string
}
// NewPolicy builds a Policy from the given config. Besides copying the config
// fields it allocates the policy's lock and sets ConvergentVersion to -1,
// the value used when the convergent version is tracked per key entry
// instead of on the policy.
func NewPolicy(config PolicyConfig) *Policy {
	policy := new(Policy)
	policy.l = new(sync.RWMutex)

	policy.Name = config.Name
	policy.Type = config.Type
	policy.Derived = config.Derived
	policy.KDF = config.KDF
	policy.ConvergentEncryption = config.ConvergentEncryption
	policy.ConvergentVersion = -1
	policy.Exportable = config.Exportable
	policy.DeletionAllowed = config.DeletionAllowed
	policy.AllowPlaintextBackup = config.AllowPlaintextBackup
	policy.VersionTemplate = config.VersionTemplate
	policy.StoragePrefix = config.StoragePrefix

	return policy
}
// LoadPolicy will load a policy from the provided storage path and set the
// necessary un-exported variables. It is particularly useful when accessing a
// policy without the lock manager.
//
// It returns (nil, nil) when nothing is stored at path, and the storage or
// decoding error otherwise. The returned policy has a fresh lock.
func LoadPolicy(ctx context.Context, s logical.Storage, path string) (*Policy, error) {
	raw, err := s.Get(ctx, path)
	if err != nil {
		return nil, err
	}
	if raw == nil {
		return nil, nil
	}

	var policy Policy
	if err := jsonutil.DecodeJSON(raw.Value, &policy); err != nil {
		return nil, err
	}

	// Migrate RSA private keys to include their public counterpart. This lets
	// us reference RSAPublicKey whenever we need to, without necessarily
	// needing the private key handy, synchronizing the behavior with EC and
	// Ed25519 key pairs.
	switch policy.Type {
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		for ver, entry := range policy.Keys {
			if entry.RSAPublicKey == nil && entry.RSAKey != nil {
				entry.RSAPublicKey = entry.RSAKey.Public().(*rsa.PublicKey)
				// The map stores KeyEntry by value, so the range variable
				// is a copy; write it back or the migration is lost.
				policy.Keys[ver] = entry
			}
		}
	}

	policy.l = new(sync.RWMutex)

	return &policy, nil
}
// Policy is the struct used to store metadata
type Policy struct {
	// This is a pointer on purpose: if we are running with cache disabled we
	// need to actually swap in the lock manager's lock for this policy with
	// the local lock.
	l *sync.RWMutex
	// writeLocked allows us to implement Lock() and Unlock()
	writeLocked bool
	// Stores whether it's been deleted. This acts as a guard for operations
	// that may write data, e.g. if one request rotates and that request is
	// served after a delete.
	deleted uint32

	// Name of the policy; it is the last path element under which the policy
	// and its archive are stored.
	Name string `json:"name"`
	Key  []byte `json:"key,omitempty"`      // DEPRECATED
	KeySize int `json:"key_size,omitempty"` // For algorithms with variable key sizes
	// Keys holds the currently loaded key versions, indexed by the decimal
	// string form of the version number.
	Keys keyEntryMap `json:"keys"`

	// Derived keys MUST provide a context and the master underlying key is
	// never used. If convergent encryption is true, the context will be used
	// as the nonce as well.
	Derived bool `json:"derived"`
	// KDF selects the key derivation function (one of the Kdf_* constants).
	KDF int `json:"kdf"`
	// ConvergentEncryption records whether convergent encryption is enabled.
	ConvergentEncryption bool `json:"convergent_encryption"`

	// Whether the key is exportable
	Exportable bool `json:"exportable"`

	// The minimum version of the key allowed to be used for decryption
	MinDecryptionVersion int `json:"min_decryption_version"`

	// The minimum version of the key allowed to be used for encryption
	MinEncryptionVersion int `json:"min_encryption_version"`

	// The latest key version in this policy
	LatestVersion int `json:"latest_version"`

	// The latest key version in the archive. We never delete these, so this is
	// a max.
	ArchiveVersion int `json:"archive_version"`

	// ArchiveMinVersion is the minimum version of the key in the archive.
	ArchiveMinVersion int `json:"archive_min_version"`

	// MinAvailableVersion is the minimum version of the key present. All key
	// versions before this would have been deleted.
	MinAvailableVersion int `json:"min_available_version"`

	// Whether the key is allowed to be deleted
	DeletionAllowed bool `json:"deletion_allowed"`

	// The version of the convergent nonce to use
	ConvergentVersion int `json:"convergent_version"`

	// The type of key
	Type KeyType `json:"type"`

	// BackupInfo indicates the information about the backup action taken on
	// this policy
	BackupInfo *BackupInfo `json:"backup_info"`

	// RestoreInfo indicates the information about the restore action taken on
	// this policy
	RestoreInfo *RestoreInfo `json:"restore_info"`

	// AllowPlaintextBackup allows taking backup of the policy in plaintext
	AllowPlaintextBackup bool `json:"allow_plaintext_backup"`

	// VersionTemplate is used to prefix the ciphertext with information about
	// the key version. It must include {{version}} and a delimiter between the
	// version prefix and the ciphertext.
	VersionTemplate string `json:"version_template"`

	// StoragePrefix is used to add a prefix when storing and retrieving the
	// policy object.
	StoragePrefix string `json:"storage_prefix"`

	// AutoRotatePeriod defines how frequently the key should automatically
	// rotate. Setting this to zero disables automatic rotation for the key.
	AutoRotatePeriod time.Duration `json:"auto_rotate_period"`

	// versionPrefixCache stores caches of version prefix strings and the split
	// version template.
	versionPrefixCache sync.Map

	// Imported indicates whether the key was generated by Vault or imported
	// from an external source. The field has no JSON tag, so it is serialized
	// under its Go name.
	Imported bool

	// AllowImportedKeyRotation indicates whether an imported key may be rotated by Vault.
	// The field has no JSON tag, so it is serialized under its Go name.
	AllowImportedKeyRotation bool
}
// Lock acquires the policy's lock: the write lock when exclusive is true,
// otherwise the read lock. Taking the write lock also sets writeLocked so that
// Unlock knows which kind of lock to release.
func (p *Policy) Lock(exclusive bool) {
	if !exclusive {
		p.l.RLock()
		return
	}

	p.l.Lock()
	p.writeLocked = true
}
// Unlock releases whichever lock Lock acquired: if writeLocked is set it
// clears the flag and releases the write lock, otherwise it releases a read
// lock.
func (p *Policy) Unlock() {
	if !p.writeLocked {
		p.l.RUnlock()
		return
	}

	// Clear the flag while still holding the write lock.
	p.writeLocked = false
	p.l.Unlock()
}
// archivedKeys stores old keys. This is used to keep the key loading time sane
// when there are huge numbers of rotations.
type archivedKeys struct {
	// Keys is indexed by key version minus the policy's minimum available
	// version (see handleArchiving).
	Keys []KeyEntry `json:"keys"`
}
// LoadArchive reads the policy's key archive from storage at
// <StoragePrefix>/archive/<Name>. When nothing is stored there it returns an
// archive with an empty, non-nil key slice. Storage and decoding errors are
// returned as-is.
func (p *Policy) LoadArchive(ctx context.Context, storage logical.Storage) (*archivedKeys, error) {
	entry, err := storage.Get(ctx, path.Join(p.StoragePrefix, "archive", p.Name))
	if err != nil {
		return nil, err
	}

	loaded := &archivedKeys{}
	if entry == nil {
		// Nothing archived yet.
		loaded.Keys = []KeyEntry{}
		return loaded, nil
	}

	if decodeErr := jsonutil.DecodeJSON(entry.Value, loaded); decodeErr != nil {
		return nil, decodeErr
	}
	return loaded, nil
}
// storeArchive JSON-encodes the given archive and writes it to storage at
// <StoragePrefix>/archive/<Name>, returning any encoding or storage error.
func (p *Policy) storeArchive(ctx context.Context, storage logical.Storage, archive *archivedKeys) error {
	// Encode the archive
	encoded, err := json.Marshal(archive)
	if err != nil {
		return err
	}

	// Write the archive into storage
	entry := &logical.StorageEntry{
		Key:   path.Join(p.StoragePrefix, "archive", p.Name),
		Value: encoded,
	}
	return storage.Put(ctx, entry)
}
// handleArchiving manages the movement of keys to and from the policy archive.
// This should *ONLY* be called from Persist() since it assumes that the policy
// will be persisted afterwards.
//
// After validating the policy's version fields it does one of two things:
//   - if the key for MinDecryptionVersion is not loaded, it copies versions
//     MinDecryptionVersion..LatestVersion from the archive into p.Keys and
//     returns without writing the archive;
//   - otherwise it grows the archive as needed, copies every version newer
//     than ArchiveVersion into it, trims archive entries older than
//     MinAvailableVersion, stores the archive, and finally removes versions
//     below MinDecryptionVersion from p.Keys.
//
// It mutates p.Keys, p.ArchiveVersion and p.ArchiveMinVersion.
func (p *Policy) handleArchiving(ctx context.Context, storage logical.Storage) error {
	// We need to move keys that are no longer accessible to archivedKeys, and keys
	// that now need to be accessible back here.
	//
	// For safety, because there isn't really a good reason to, we never delete
	// keys from the archive even when we move them back.

	// Check if we have the latest minimum version in the current set of keys
	_, keysContainsMinimum := p.Keys[strconv.Itoa(p.MinDecryptionVersion)]

	// Sanity checks
	switch {
	case p.MinDecryptionVersion < 1:
		return fmt.Errorf("minimum decryption version of %d is less than 1", p.MinDecryptionVersion)
	case p.LatestVersion < 1:
		return fmt.Errorf("latest version of %d is less than 1", p.LatestVersion)
	case !keysContainsMinimum && p.ArchiveVersion != p.LatestVersion:
		return fmt.Errorf("need to move keys from archive but archive version not up-to-date")
	case p.ArchiveVersion > p.LatestVersion:
		return fmt.Errorf("archive version of %d is greater than the latest version %d",
			p.ArchiveVersion, p.LatestVersion)
	case p.MinEncryptionVersion > 0 && p.MinEncryptionVersion < p.MinDecryptionVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than minimum encryption version %d",
			p.MinDecryptionVersion, p.MinEncryptionVersion)
	case p.MinDecryptionVersion > p.LatestVersion:
		return fmt.Errorf("minimum decryption version of %d is greater than the latest version %d",
			p.MinDecryptionVersion, p.LatestVersion)
	}

	archive, err := p.LoadArchive(ctx, storage)
	if err != nil {
		return err
	}

	if !keysContainsMinimum {
		// Need to move keys *from* archive
		// NOTE(review): this indexes archive.Keys without a length check; it
		// relies on the stored archive really covering every version up to
		// LatestVersion, as the ArchiveVersion sanity check above implies.
		for i := p.MinDecryptionVersion; i <= p.LatestVersion; i++ {
			p.Keys[strconv.Itoa(i)] = archive.Keys[i-p.MinAvailableVersion]
		}

		return nil
	}

	// Need to move keys *to* archive

	// We need a size that is equivalent to the latest version (number of keys)
	// but adding one since slice numbering starts at 0 and we're indexing by
	// key version
	if len(archive.Keys)+p.MinAvailableVersion < p.LatestVersion+1 {
		// Increase the size of the archive slice
		newKeys := make([]KeyEntry, p.LatestVersion-p.MinAvailableVersion+1)
		copy(newKeys, archive.Keys)
		archive.Keys = newKeys
	}

	// We are storing all keys in the archive, so we ensure that it is up to
	// date up to p.LatestVersion
	for i := p.ArchiveVersion + 1; i <= p.LatestVersion; i++ {
		archive.Keys[i-p.MinAvailableVersion] = p.Keys[strconv.Itoa(i)]
		p.ArchiveVersion = i
	}

	// Trim the keys if required
	if p.ArchiveMinVersion < p.MinAvailableVersion {
		archive.Keys = archive.Keys[p.MinAvailableVersion-p.ArchiveMinVersion:]
		p.ArchiveMinVersion = p.MinAvailableVersion
	}

	err = p.storeArchive(ctx, storage, archive)
	if err != nil {
		return err
	}

	// Perform deletion afterwards so that if there is an error saving we
	// haven't messed with the current policy
	for i := p.LatestVersion - len(p.Keys) + 1; i < p.MinDecryptionVersion; i++ {
		delete(p.Keys, strconv.Itoa(i))
	}

	return nil
}
// Persist runs archive handling and then writes the serialized policy to
// storage at <StoragePrefix>/policy/<Name>. It refuses to write a policy that
// has been marked deleted. If archiving, serialization or the write fails,
// ArchiveVersion and Keys are restored to the values they had on entry.
func (p *Policy) Persist(ctx context.Context, storage logical.Storage) (retErr error) {
	if atomic.LoadUint32(&p.deleted) == 1 {
		return errors.New("key has been deleted, not persisting")
	}

	// Other functions will take care of restoring other values; this is just
	// responsible for archiving and keys since the archive function can modify
	// keys. At the moment one of the other functions calling persist will also
	// roll back keys, but better safe than sorry and this doesn't happen
	// enough to worry about the speed tradeoff.
	savedArchiveVersion := p.ArchiveVersion
	var savedKeys keyEntryMap
	if p.Keys != nil {
		savedKeys = make(keyEntryMap, len(p.Keys))
		for version, entry := range p.Keys {
			savedKeys[version] = entry
		}
	}

	defer func() {
		if retErr == nil {
			return
		}
		p.ArchiveVersion = savedArchiveVersion
		p.Keys = savedKeys
	}()

	if err := p.handleArchiving(ctx, storage); err != nil {
		return err
	}

	// Encode the policy
	serialized, err := p.Serialize()
	if err != nil {
		return err
	}

	// Write the policy into storage
	return storage.Put(ctx, &logical.StorageEntry{
		Key:   path.Join(p.StoragePrefix, "policy", p.Name),
		Value: serialized,
	})
}
// Serialize returns the JSON encoding of the policy.
func (p *Policy) Serialize() ([]byte, error) {
	encoded, err := json.Marshal(p)
	return encoded, err
}
// NeedsUpgrade reports whether the stored policy predates one of the current
// invariants and therefore has to go through Upgrade. The conditions mirror
// the ones Upgrade acts on.
func (p *Policy) NeedsUpgrade() bool {
	switch {
	// Ensure we've moved from Key -> Keys
	case len(p.Key) > 0:
		return true

	// With archiving, past assumptions about the length of the keys map are
	// no longer valid
	case p.LatestVersion == 0 && len(p.Keys) != 0:
		return true

	// We disallow setting the version to 0, since they start at 1 since
	// moving to rotate-able keys, so update if it's set to 0
	case p.MinDecryptionVersion == 0:
		return true

	// On first load after an upgrade, copy keys to the archive
	case p.ArchiveVersion == 0:
		return true

	// Need to write the version if zero; for version 3 on we set this to -1
	// to ignore it since we store this information in each key entry
	case p.ConvergentEncryption && p.ConvergentVersion == 0:
		return true
	}

	// Finally, the latest key version must carry an HMAC key.
	latest := p.Keys[strconv.Itoa(p.LatestVersion)]
	return len(latest.HMACKey) == 0
}
// Upgrade brings a policy loaded from storage up to the current invariants and
// persists it if anything changed. It handles, in order: migrating the
// deprecated single Key into Keys, filling in a zero LatestVersion, raising a
// zero MinDecryptionVersion to 1, populating the archive on first load,
// setting a zero ConvergentVersion to 1 for convergent policies, and making
// sure the latest key version has an HMAC key.
//
// randReader supplies the randomness for a newly generated HMAC key. On error
// the fields touched here are rolled back to their values on entry.
func (p *Policy) Upgrade(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
	priorKey := p.Key
	priorLatestVersion := p.LatestVersion
	priorMinDecryptionVersion := p.MinDecryptionVersion
	priorConvergentVersion := p.ConvergentVersion
	var priorKeys keyEntryMap

	if p.Keys != nil {
		priorKeys = keyEntryMap{}
		for k, v := range p.Keys {
			priorKeys[k] = v
		}
	}

	defer func() {
		if retErr != nil {
			p.Key = priorKey
			p.LatestVersion = priorLatestVersion
			p.MinDecryptionVersion = priorMinDecryptionVersion
			p.ConvergentVersion = priorConvergentVersion
			p.Keys = priorKeys
		}
	}()

	persistNeeded := false
	// Ensure we've moved from Key -> Keys
	if len(p.Key) > 0 {
		p.MigrateKeyToKeysMap()
		persistNeeded = true
	}

	// With archiving, past assumptions about the length of the keys map are no
	// longer valid
	if p.LatestVersion == 0 && len(p.Keys) != 0 {
		p.LatestVersion = len(p.Keys)
		persistNeeded = true
	}

	// We disallow setting the version to 0, since they start at 1 since moving
	// to rotate-able keys, so update if it's set to 0
	if p.MinDecryptionVersion == 0 {
		p.MinDecryptionVersion = 1
		persistNeeded = true
	}

	// On first load after an upgrade, copy keys to the archive
	if p.ArchiveVersion == 0 {
		persistNeeded = true
	}

	if p.ConvergentEncryption && p.ConvergentVersion == 0 {
		p.ConvergentVersion = 1
		persistNeeded = true
	}

	// Computed only now because LatestVersion may have been adjusted above.
	latest := strconv.Itoa(p.LatestVersion)
	if entry := p.Keys[latest]; len(entry.HMACKey) == 0 {
		hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
		if err != nil {
			return err
		}
		entry.HMACKey = hmacKey

		// HMAC-typed keys use Key itself for HMAC operations, so keep
		// HMACKey identical to it. This must happen before the entry is
		// stored: the map holds KeyEntry by value, so changes made to the
		// local copy afterwards would be lost. An empty Key keeps the
		// freshly generated value rather than being replaced by nothing.
		if p.Type == KeyType_HMAC && len(entry.Key) > 0 {
			entry.HMACKey = entry.Key
		}

		// Writing to a nil map panics; with a map in place a key-less
		// policy is rejected by Persist's sanity checks instead.
		if p.Keys == nil {
			p.Keys = keyEntryMap{}
		}
		p.Keys[latest] = entry
		persistNeeded = true
	}

	if persistNeeded {
		if err := p.Persist(ctx, storage); err != nil {
			return err
		}
	}

	return nil
}
// GetKey returns the key material to use for the given key version. For a
// derived policy it delegates to DeriveKey with the supplied context and a nil
// salt; otherwise it returns the stored raw key for that version and ignores
// context and numBytes.
func (p *Policy) GetKey(context []byte, ver, numBytes int) ([]byte, error) {
	if p.Derived {
		return p.DeriveKey(context, nil, ver, numBytes)
	}

	// Non-derived keys are used as stored.
	entry, err := p.safeGetKeyEntry(ver)
	if err != nil {
		return nil, err
	}
	return entry.Key, nil
}
// DeriveKey is used to derive a symmetric key given a context and salt. This does not
// check the policy's Derived flag, but just implements the derivation logic. GetKey
// is responsible for switching on the policy config.
//
// ver must be in 1..LatestVersion and context must be non-empty. The
// derivation function is chosen by p.KDF:
//   - Kdf_hmac_sha256_counter runs the kdf package's counter mode over the
//     stored key with context followed by salt as input;
//   - Kdf_hkdf_sha256 reads numBytes from HKDF-SHA256 and either returns them
//     directly (AES/ChaCha20 types) or feeds them to ed25519.GenerateKey and
//     returns the resulting private key (Ed25519).
//
// Invalid caller input is reported as errutil.UserError; unexpected states as
// errutil.InternalError.
func (p *Policy) DeriveKey(context, salt []byte, ver int, numBytes int) ([]byte, error) {
	if !p.Type.DerivationSupported() {
		return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)}
	}

	if p.Keys == nil || p.LatestVersion == 0 {
		return nil, errutil.InternalError{Err: "unable to access the key; no key versions found"}
	}

	if ver <= 0 || ver > p.LatestVersion {
		return nil, errutil.UserError{Err: "invalid key version"}
	}

	// Ensure a context is provided
	if len(context) == 0 {
		return nil, errutil.UserError{Err: "missing 'context' for key derivation; the key was created using a derived key, which means additional, per-request information must be included in order to perform operations with the key"}
	}

	keyEntry, err := p.safeGetKeyEntry(ver)
	if err != nil {
		return nil, err
	}

	switch p.KDF {
	case Kdf_hmac_sha256_counter:
		prf := kdf.HMACSHA256PRF
		prfLen := kdf.HMACSHA256PRFLen
		// NOTE(review): the requested size is the constant 256 and numBytes
		// is not consulted on this path; presumably 256 is in bits, which
		// would yield 32 bytes - confirm against kdf.CounterMode.
		return kdf.CounterMode(prf, prfLen, keyEntry.Key, append(context, salt...), 256)

	case Kdf_hkdf_sha256:
		reader := hkdf.New(sha256.New, keyEntry.Key, salt, context)
		derBytes := bytes.NewBuffer(nil)
		derBytes.Grow(numBytes)
		// Cap the HKDF stream at exactly numBytes.
		limReader := &io.LimitedReader{
			R: reader,
			N: int64(numBytes),
		}

		switch p.Type {
		case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
			n, err := derBytes.ReadFrom(limReader)
			if err != nil {
				return nil, errutil.InternalError{Err: fmt.Sprintf("error reading returned derived bytes: %v", err)}
			}
			if n != int64(numBytes) {
				return nil, errutil.InternalError{Err: fmt.Sprintf("unable to read enough derived bytes, needed %d, got %d", numBytes, n)}
			}
			return derBytes.Bytes(), nil

		case KeyType_ED25519:
			// We use the limited reader containing the derived bytes as the
			// "random" input to the generation function
			_, pri, err := ed25519.GenerateKey(limReader)
			if err != nil {
				return nil, errutil.InternalError{Err: fmt.Sprintf("error generating derived key: %v", err)}
			}
			return pri, nil

		default:
			return nil, errutil.InternalError{Err: "unsupported key type for derivation"}
		}

	default:
		return nil, errutil.InternalError{Err: "unsupported key derivation mode"}
	}
}
// safeGetKeyEntry looks up the key entry for the given version in p.Keys. It
// returns a zero KeyEntry and an errutil.UserError when that version is not
// present.
func (p *Policy) safeGetKeyEntry(ver int) (KeyEntry, error) {
	if entry, found := p.Keys[strconv.Itoa(ver)]; found {
		return entry, nil
	}
	return KeyEntry{}, errutil.UserError{Err: "no such key version"}
}
// convergentVersion returns the convergent encryption version that applies to
// key version ver. It is 0 when convergent encryption is disabled. Otherwise a
// non-zero version recorded on the key entry wins; failing that the policy's
// own version is used, with an unset (zero) policy version treated as 1.
func (p *Policy) convergentVersion(ver int) int {
	if !p.ConvergentEncryption {
		return 0
	}

	// A per-key version takes precedence over the policy-wide one.
	if perKey := p.Keys[strconv.Itoa(ver)].ConvergentVersion; perKey != 0 {
		return perKey
	}

	if p.ConvergentVersion == 0 {
		// For some reason, not upgraded yet
		return 1
	}
	return p.ConvergentVersion
}
// Encrypt is a convenience wrapper around EncryptWithFactory that passes a
// single nil factory.
func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, error) {
	ciphertext, err := p.EncryptWithFactory(ver, context, nonce, value, nil)
	return ciphertext, err
}
// Decrypt is a convenience wrapper around DecryptWithFactory that passes a
// single nil factory.
func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) {
	plaintext, err := p.DecryptWithFactory(context, nonce, value, nil)
	return plaintext, err
}
// DecryptWithFactory decrypts value and returns the plaintext base64-encoded
// (standard encoding).
//
// value must consist of the policy's version-template prefix, the decimal key
// version, the template's delimiter, and the base64-encoded ciphertext. A
// version of 0 is treated as 1 for compatibility; versions newer than
// LatestVersion or older than MinDecryptionVersion are rejected.
//
// factories may contain AEADFactory, AssociatedDataFactory and
// ManagedKeyFactory values; nil entries are skipped on the symmetric path.
// How decryption proceeds depends on the policy's key type:
//   - AES-GCM / ChaCha20-Poly1305: the key comes from GetKey (derived from
//     context when the policy is derived) and SymmetricDecryptRaw is used;
//   - RSA: OAEP with SHA-256 using the version's private key;
//   - managed key: delegated to decryptWithManagedKey, which requires a
//     ManagedKeyFactory.
//
// Problems with caller input are reported as errutil.UserError, unexpected
// states as errutil.InternalError.
func (p *Policy) DecryptWithFactory(context, nonce []byte, value string, factories ...interface{}) (string, error) {
	if !p.Type.DecryptionSupported() {
		return "", errutil.UserError{Err: fmt.Sprintf("message decryption not supported for key type %v", p.Type)}
	}

	// tplParts[0] is the text before the version number, tplParts[1] the
	// delimiter that follows it.
	tplParts, err := p.getTemplateParts()
	if err != nil {
		return "", err
	}

	// Verify the prefix
	if !strings.HasPrefix(value, tplParts[0]) {
		return "", errutil.UserError{Err: "invalid ciphertext: no prefix"}
	}

	splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, tplParts[0]), tplParts[1], 2)
	if len(splitVerCiphertext) != 2 {
		return "", errutil.UserError{Err: "invalid ciphertext: wrong number of fields"}
	}

	ver, err := strconv.Atoi(splitVerCiphertext[0])
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: version number could not be decoded"}
	}

	if ver == 0 {
		// Compatibility mode with initial implementation, where keys start at
		// zero
		ver = 1
	}

	if ver > p.LatestVersion {
		return "", errutil.UserError{Err: "invalid ciphertext: version is too new"}
	}

	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
		return "", errutil.UserError{Err: ErrTooOld}
	}

	// Convergent version 1 requires the caller to supply the nonce.
	convergentVersion := p.convergentVersion(ver)
	if convergentVersion == 1 && (nonce == nil || len(nonce) == 0) {
		return "", errutil.UserError{Err: "invalid convergent nonce supplied"}
	}

	// Decode the base64
	decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1])
	if err != nil {
		return "", errutil.UserError{Err: "invalid ciphertext: could not decode base64"}
	}

	var plain []byte

	switch p.Type {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// Key length in bytes: 16 for AES-128, 32 for the other types.
		numBytes := 32
		if p.Type == KeyType_AES128_GCM96 {
			numBytes = 16
		}

		encKey, err := p.GetKey(context, ver, numBytes)
		if err != nil {
			return "", err
		}

		if len(encKey) != numBytes {
			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
		}

		// NOTE(review): this passes the policy-level ConvergentVersion, not
		// the per-key value computed into convergentVersion above - confirm
		// that is what SymmetricDecryptRaw expects.
		symopts := SymmetricOpts{
			Convergent:        p.ConvergentEncryption,
			ConvergentVersion: p.ConvergentVersion,
		}
		for index, rawFactory := range factories {
			if rawFactory == nil {
				continue
			}
			switch factory := rawFactory.(type) {
			case AEADFactory:
				symopts.AEADFactory = factory
			case AssociatedDataFactory:
				symopts.AdditionalData, err = factory.GetAssociatedData()
				if err != nil {
					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
				}
			case ManagedKeyFactory:
				// Accepted but unused for symmetric key types.
			default:
				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
			}
		}

		plain, err = p.SymmetricDecryptRaw(encKey, decoded, symopts)
		if err != nil {
			return "", err
		}
	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return "", err
		}
		key := keyEntry.RSAKey
		if key == nil {
			return "", errutil.InternalError{Err: fmt.Sprintf("cannot decrypt ciphertext, key version does not have a private counterpart")}
		}
		plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil)
		if err != nil {
			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA decrypt the ciphertext: %v", err)}
		}
	case KeyType_MANAGED_KEY:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return "", err
		}
		var aad []byte
		var managedKeyFactory ManagedKeyFactory
		// Unlike the symmetric path, unrecognized factory types are silently
		// ignored here.
		for _, f := range factories {
			switch factory := f.(type) {
			case AssociatedDataFactory:
				aad, err = factory.GetAssociatedData()
				if err != nil {
					return "", err
				}
			case ManagedKeyFactory:
				managedKeyFactory = factory
			}
		}

		if managedKeyFactory == nil {
			return "", errors.New("key type is managed_key, but managed key parameters were not provided")
		}

		plain, err = p.decryptWithManagedKey(managedKeyFactory.GetManagedKeyParameters(), keyEntry, decoded, nonce, aad)
		if err != nil {
			return "", err
		}

	default:
		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	return base64.StdEncoding.EncodeToString(plain), nil
}
// HMACKey returns the key to use for HMAC operations with the given key
// version. For KeyType_HMAC policies that is the entry's main Key; for every
// other type it is the entry's dedicated HMACKey. An error is returned when
// the version is negative, newer than LatestVersion, not present, or has no
// HMAC key.
func (p *Policy) HMACKey(version int) ([]byte, error) {
	if version < 0 {
		return nil, fmt.Errorf("key version does not exist (cannot be negative)")
	}
	if version > p.LatestVersion {
		return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
	}

	entry, err := p.safeGetKeyEntry(version)
	if err != nil {
		return nil, err
	}

	switch {
	case p.Type == KeyType_HMAC:
		// HMAC-typed keys operate on the main key directly.
		return entry.Key, nil
	case entry.HMACKey == nil:
		return nil, fmt.Errorf("no HMAC key exists for that key version")
	default:
		return entry.HMACKey, nil
	}
}
// Sign is a convenience wrapper around SignWithOptions. It packs the hash
// algorithm, signature algorithm and marshaling type into a SigningOptions
// value and uses rsa.PSSSaltLengthAuto as the salt length.
func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
	opts := &SigningOptions{
		HashAlgorithm: hashAlgorithm,
		SigAlgorithm:  sigAlgorithm,
		Marshaling:    marshaling,
		SaltLength:    rsa.PSSSaltLengthAuto,
	}
	return p.SignWithOptions(ver, context, input, opts)
}
// minRSAPSSSaltLength returns the smallest salt length value accepted for
// RSA-PSS, which is the sentinel rsa.PSSSaltLengthEqualsHash.
func (p *Policy) minRSAPSSSaltLength() int {
	// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=247
	const lowest = rsa.PSSSaltLengthEqualsHash
	return lowest
}
// maxRSAPSSSaltLength returns the largest RSA-PSS salt length, in bytes, for a
// key of keyBitLen bits and the given hash: the byte length of keyBitLen-1
// bits rounded up, minus 2, minus the hash's digest size.
func (p *Policy) maxRSAPSSSaltLength(keyBitLen int, hash crypto.Hash) int {
	// https://cs.opensource.google/go/go/+/refs/tags/go1.19:src/crypto/rsa/pss.go;l=288
	encodedLen := (keyBitLen - 1 + 7) / 8
	return encodedLen - 2 - hash.Size()
}
// validRSAPSSSaltLength reports whether saltLength lies within the inclusive
// range from minRSAPSSSaltLength to maxRSAPSSSaltLength for the given key size
// and hash.
func (p *Policy) validRSAPSSSaltLength(keyBitLen int, hash crypto.Hash, saltLength int) bool {
	if saltLength < p.minRSAPSSSaltLength() {
		return false
	}
	return saltLength <= p.maxRSAPSSSaltLength(keyBitLen, hash)
}
// SignWithOptions signs input with the requested version of the policy's key
// and returns the encoded signature prefixed with the policy's version prefix.
//
// ver selects the key version; 0 means p.LatestVersion. Negative versions,
// versions above p.LatestVersion and versions below a non-zero
// p.MinEncryptionVersion are rejected with an errutil.UserError. context is
// only used for derived Ed25519 keys, where it is passed to p.GetKey. input is
// handed to the underlying signing primitive unchanged; no hashing is
// performed in this function. options carries the hash algorithm, marshaling,
// PSS salt length and RSA signature algorithm and is dereferenced
// unconditionally, so it must not be nil.
//
// The returned SigningResult.PublicKey is only populated for derived Ed25519
// keys; for every other key type it is nil.
func (p *Policy) SignWithOptions(ver int, context, input []byte, options *SigningOptions) (*SigningResult, error) {
	if !p.Type.SigningSupported() {
		return nil, fmt.Errorf("message signing not supported for key type %v", p.Type)
	}

	// Resolve the effective key version and validate it against the policy.
	switch {
	case ver == 0:
		ver = p.LatestVersion
	case ver < 0:
		return nil, errutil.UserError{Err: "requested version for signing is negative"}
	case ver > p.LatestVersion:
		return nil, errutil.UserError{Err: "requested version for signing is higher than the latest key version"}
	case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
		return nil, errutil.UserError{Err: "requested version for signing is less than the minimum encryption key version"}
	}

	var sig []byte
	var pubKey []byte
	var err error

	keyParams, err := p.safeGetKeyEntry(ver)
	if err != nil {
		return nil, err
	}

	// Before signing, check if key has its private part, if not return error
	if keyParams.IsPrivateKeyMissing() {
		return nil, errutil.UserError{Err: "requested version for signing does not contain a private part"}
	}

	hashAlgorithm := options.HashAlgorithm
	marshaling := options.Marshaling
	saltLength := options.SaltLength
	sigAlgorithm := options.SigAlgorithm

	switch p.Type {
	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
		// Pick the curve for the policy type; anything that is not P-384 or
		// P-521 in this case list is P-256.
		var curveBits int
		var curve elliptic.Curve
		switch p.Type {
		case KeyType_ECDSA_P384:
			curveBits = 384
			curve = elliptic.P384()
		case KeyType_ECDSA_P521:
			curveBits = 521
			curve = elliptic.P521()
		default:
			curveBits = 256
			curve = elliptic.P256()
		}

		// Rebuild the private key from the stored coordinates and scalar.
		key := &ecdsa.PrivateKey{
			PublicKey: ecdsa.PublicKey{
				Curve: curve,
				X:     keyParams.EC_X,
				Y:     keyParams.EC_Y,
			},
			D: keyParams.EC_D,
		}

		r, s, err := ecdsa.Sign(rand.Reader, key, input)
		if err != nil {
			return nil, err
		}

		switch marshaling {
		case MarshalingTypeASN1:
			// This is used by openssl and X.509
			sig, err = asn1.Marshal(ecdsaSignature{
				R: r,
				S: s,
			})
			if err != nil {
				return nil, err
			}

		case MarshalingTypeJWS:
			// This is used by JWS

			// First we have to get the length of the curve in bytes, rounding
			// up when the bit size is not a multiple of eight (521 bits is
			// 65.125 bytes, so P-521 needs 66 bytes per value).
			keyLen := curveBits / 8
			if curveBits%8 > 0 {
				keyLen++
			}

			// Now create the output array: r and s, each left-padded with
			// zeroes to keyLen bytes, concatenated.
			sig = make([]byte, keyLen*2)
			rb := r.Bytes()
			sb := s.Bytes()
			copy(sig[keyLen-len(rb):], rb)
			copy(sig[2*keyLen-len(sb):], sb)

		default:
			return nil, errutil.UserError{Err: "requested marshaling type is invalid"}
		}

	case KeyType_ED25519:
		var key ed25519.PrivateKey

		if p.Derived {
			// Derive the key that should be used
			var err error
			key, err = p.GetKey(context, ver, 32)
			if err != nil {
				return nil, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
			}
			// The derived public key is returned to the caller in the result.
			pubKey = key.Public().(ed25519.PublicKey)
		} else {
			key = ed25519.PrivateKey(keyParams.Key)
		}

		// Per docs, do not pre-hash ed25519; it does two passes and performs
		// its own hashing
		sig, err = key.Sign(rand.Reader, input, crypto.Hash(0))
		if err != nil {
			return nil, err
		}

	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		key := keyParams.RSAKey

		algo, ok := CryptoHashMap[hashAlgorithm]
		if !ok {
			return nil, errutil.InternalError{Err: "unsupported hash algorithm"}
		}

		// An empty signature algorithm defaults to PSS.
		if sigAlgorithm == "" {
			sigAlgorithm = "pss"
		}

		switch sigAlgorithm {
		case "pss":
			if !p.validRSAPSSSaltLength(key.N.BitLen(), algo, saltLength) {
				return nil, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
			}
			sig, err = rsa.SignPSS(rand.Reader, key, algo, input, &rsa.PSSOptions{SaltLength: saltLength})
			if err != nil {
				return nil, err
			}
		case "pkcs1v15":
			sig, err = rsa.SignPKCS1v15(rand.Reader, key, algo, input)
			if err != nil {
				return nil, err
			}
		default:
			return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
		}

	case KeyType_MANAGED_KEY:
		// NOTE(review): this fetches the same key entry that keyParams
		// already holds; the second lookup looks redundant.
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return nil, err
		}

		sig, err = p.signWithManagedKey(options, keyEntry, input)
		if err != nil {
			return nil, err
		}

	default:
		return nil, fmt.Errorf("unsupported key type %v", p.Type)
	}

	// Convert to base64
	// NOTE(review): there is no default case here, so for non-ECDSA key types
	// a marshaling value other than ASN1 or JWS leaves encoded empty and the
	// result contains only the version prefix.
	var encoded string
	switch marshaling {
	case MarshalingTypeASN1:
		encoded = base64.StdEncoding.EncodeToString(sig)
	case MarshalingTypeJWS:
		encoded = base64.RawURLEncoding.EncodeToString(sig)
	}
	res := &SigningResult{
		Signature: p.getVersionPrefix(ver) + encoded,
		PublicKey: pubKey,
	}

	return res, nil
}
// VerifySignature is a convenience wrapper around VerifySignatureWithOptions
// that packs its arguments into a SigningOptions value, using
// rsa.PSSSaltLengthAuto as the salt length.
func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType, sig string) (bool, error) {
	opts := SigningOptions{
		SigAlgorithm:  sigAlgorithm,
		SaltLength:    rsa.PSSSaltLengthAuto,
		Marshaling:    marshaling,
		HashAlgorithm: hashAlgorithm,
	}
	return p.VerifySignatureWithOptions(context, input, sig, &opts)
}
// VerifySignatureWithOptions checks sig, a signature string carrying the
// policy's version prefix, against input using the key version encoded in
// that prefix.
//
// The boolean result reports whether the signature verified. A non-nil error
// is returned when the signature string is malformed, the encoded version is
// newer than p.LatestVersion or older than a non-zero p.MinDecryptionVersion,
// the marshaling type is not ASN1 or JWS, the payload is not valid base64 for
// that marshaling, or the key material cannot be loaded. For RSA keys a failed
// cryptographic verification is reported as (false, nil). context is only
// used for derived Ed25519 keys. options is dereferenced unconditionally, so
// it must not be nil.
func (p *Policy) VerifySignatureWithOptions(context, input []byte, sig string, options *SigningOptions) (bool, error) {
	if !p.Type.SigningSupported() {
		return false, errutil.UserError{Err: fmt.Sprintf("message verification not supported for key type %v", p.Type)}
	}

	// tplParts holds the text before and after the version number in the
	// policy's version template.
	tplParts, err := p.getTemplateParts()
	if err != nil {
		return false, err
	}

	// Verify the prefix
	if !strings.HasPrefix(sig, tplParts[0]) {
		return false, errutil.UserError{Err: "invalid signature: no prefix"}
	}

	// Split "<version><separator><payload>" into its two fields.
	splitVerSig := strings.SplitN(strings.TrimPrefix(sig, tplParts[0]), tplParts[1], 2)
	if len(splitVerSig) != 2 {
		return false, errutil.UserError{Err: "invalid signature: wrong number of fields"}
	}

	ver, err := strconv.Atoi(splitVerSig[0])
	if err != nil {
		return false, errutil.UserError{Err: "invalid signature: version number could not be decoded"}
	}

	if ver > p.LatestVersion {
		return false, errutil.UserError{Err: "invalid signature: version is too new"}
	}

	if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion {
		return false, errutil.UserError{Err: ErrTooOld}
	}

	hashAlgorithm := options.HashAlgorithm
	marshaling := options.Marshaling
	saltLength := options.SaltLength
	sigAlgorithm := options.SigAlgorithm

	// Decode the payload with the base64 alphabet that matches the
	// marshaling: standard for ASN1, unpadded URL-safe for JWS.
	var sigBytes []byte
	switch marshaling {
	case MarshalingTypeASN1:
		sigBytes, err = base64.StdEncoding.DecodeString(splitVerSig[1])
	case MarshalingTypeJWS:
		sigBytes, err = base64.RawURLEncoding.DecodeString(splitVerSig[1])
	default:
		return false, errutil.UserError{Err: "requested marshaling type is invalid"}
	}
	if err != nil {
		return false, errutil.UserError{Err: "invalid base64 signature value"}
	}

	switch p.Type {
	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
		var curve elliptic.Curve
		switch p.Type {
		case KeyType_ECDSA_P384:
			curve = elliptic.P384()
		case KeyType_ECDSA_P521:
			curve = elliptic.P521()
		default:
			curve = elliptic.P256()
		}

		var ecdsaSig ecdsaSignature

		switch marshaling {
		case MarshalingTypeASN1:
			rest, err := asn1.Unmarshal(sigBytes, &ecdsaSig)
			if err != nil {
				return false, errutil.UserError{Err: "supplied signature is invalid"}
			}
			// Trailing bytes after the ASN.1 structure are rejected.
			if rest != nil && len(rest) != 0 {
				return false, errutil.UserError{Err: "supplied signature contains extra data"}
			}

		case MarshalingTypeJWS:
			// The JWS form is r followed by s; split the payload in half.
			paramLen := len(sigBytes) / 2
			rb := sigBytes[:paramLen]
			sb := sigBytes[paramLen:]
			ecdsaSig.R = new(big.Int)
			ecdsaSig.R.SetBytes(rb)
			ecdsaSig.S = new(big.Int)
			ecdsaSig.S.SetBytes(sb)
		}

		keyParams, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return false, err
		}
		key := &ecdsa.PublicKey{
			Curve: curve,
			X:     keyParams.EC_X,
			Y:     keyParams.EC_Y,
		}

		return ecdsa.Verify(key, input, ecdsaSig.R, ecdsaSig.S), nil

	case KeyType_ED25519:
		var pub ed25519.PublicKey

		if p.Derived {
			// Derive the key that should be used
			key, err := p.GetKey(context, ver, 32)
			if err != nil {
				return false, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)}
			}
			pub = ed25519.PrivateKey(key).Public().(ed25519.PublicKey)
		} else {
			// Non-derived keys store the public key base64-encoded in
			// FormattedPublicKey.
			keyEntry, err := p.safeGetKeyEntry(ver)
			if err != nil {
				return false, err
			}

			raw, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
			if err != nil {
				return false, err
			}

			pub = ed25519.PublicKey(raw)
		}

		return ed25519.Verify(pub, input, sigBytes), nil

	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return false, err
		}

		algo, ok := CryptoHashMap[hashAlgorithm]
		if !ok {
			return false, errutil.InternalError{Err: "unsupported hash algorithm"}
		}

		// An empty signature algorithm defaults to PSS.
		if sigAlgorithm == "" {
			sigAlgorithm = "pss"
		}

		switch sigAlgorithm {
		case "pss":
			// Use the stored public key, or the one embedded in the private
			// key when the private part is present.
			publicKey := keyEntry.RSAPublicKey
			if !keyEntry.IsPrivateKeyMissing() {
				publicKey = &keyEntry.RSAKey.PublicKey
			}
			if !p.validRSAPSSSaltLength(publicKey.N.BitLen(), algo, saltLength) {
				return false, errutil.UserError{Err: fmt.Sprintf("requested salt length %d is invalid", saltLength)}
			}
			err = rsa.VerifyPSS(publicKey, algo, input, sigBytes, &rsa.PSSOptions{SaltLength: saltLength})
		case "pkcs1v15":
			publicKey := keyEntry.RSAPublicKey
			if !keyEntry.IsPrivateKeyMissing() {
				publicKey = &keyEntry.RSAKey.PublicKey
			}
			err = rsa.VerifyPKCS1v15(publicKey, algo, input, sigBytes)
		default:
			return false, errutil.InternalError{Err: fmt.Sprintf("unsupported rsa signature algorithm %s", sigAlgorithm)}
		}

		// A verification failure is reported as "not valid", not as an error.
		return err == nil, nil

	case KeyType_MANAGED_KEY:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return false, err
		}

		return p.verifyWithManagedKey(options, keyEntry, input, sigBytes)

	default:
		return false, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}
}
// Import adds key to the policy as private key material. It delegates to
// ImportPublicOrPrivate with isPrivateKey set to true.
func (p *Policy) Import(ctx context.Context, storage logical.Storage, key []byte, randReader io.Reader) error {
	const isPrivateKey = true
	return p.ImportPublicOrPrivate(ctx, storage, key, isPrivateKey, randReader)
}
// ImportPublicOrPrivate imports key into the policy and persists the policy
// to storage.
//
// When the latest key version has no private part and isPrivateKey is true,
// the key is first offered to ImportPrivateKeyForVersion to complete that
// version; if that succeeds no new version is created. Otherwise a new key
// version is appended.
//
// For AES-GCM, ChaCha20-Poly1305 and HMAC policies key is the raw symmetric
// key and must have the size required by the policy type. For all other
// policy types key is a PKCS#8 private key when isPrivateKey is true, or a
// PEM-encoded PKIX public key when it is false. randReader is the entropy
// source for the per-version HMAC key generated for non-HMAC policies.
//
// An error is returned when the key has the wrong size, cannot be parsed
// (including a public key that is not PEM), does not match the policy type,
// or when persisting fails.
func (p *Policy) ImportPublicOrPrivate(ctx context.Context, storage logical.Storage, key []byte, isPrivateKey bool, randReader io.Reader) error {
	now := time.Now()
	entry := KeyEntry{
		CreationTime:           now,
		DeprecatedCreationTime: now.Unix(),
	}

	// Before we insert this entry, check if the latest version is incomplete
	// and this entry matches the current version; if so, return without
	// updating to the next version.
	if p.LatestVersion > 0 {
		latestKey := p.Keys[strconv.Itoa(p.LatestVersion)]
		if latestKey.IsPrivateKeyMissing() && isPrivateKey {
			if err := p.ImportPrivateKeyForVersion(ctx, storage, p.LatestVersion, key); err == nil {
				return nil
			}
		}
	}

	// HMAC policies reuse the imported key as the HMAC key (set below); every
	// other type gets a freshly generated one.
	if p.Type != KeyType_HMAC {
		hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
		if err != nil {
			return err
		}
		entry.HMACKey = hmacKey
	}

	if p.Type == KeyType_ED25519 && p.Derived && !isPrivateKey {
		return fmt.Errorf("unable to import only public key for derived Ed25519 key: imported key should not be an Ed25519 key pair but is instead an HKDF key")
	}

	if (p.Type == KeyType_AES128_GCM96 && len(key) != 16) ||
		((p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305) && len(key) != 32) ||
		(p.Type == KeyType_HMAC && (len(key) < HmacMinKeySize || len(key) > HmacMaxKeySize)) {
		return fmt.Errorf("invalid key size %d bytes for key type %s", len(key), p.Type)
	}

	if p.Type == KeyType_AES128_GCM96 || p.Type == KeyType_AES256_GCM96 || p.Type == KeyType_ChaCha20_Poly1305 || p.Type == KeyType_HMAC {
		entry.Key = key
		if p.Type == KeyType_HMAC {
			p.KeySize = len(key)
			entry.HMACKey = key
		}
	} else {
		var parsedKey any
		var err error
		if isPrivateKey {
			parsedKey, err = x509.ParsePKCS8PrivateKey(key)
			if err != nil {
				// The standard parser rejects some encodings; retry with the
				// dedicated parsers when the error text identifies them.
				if strings.Contains(err.Error(), "unknown elliptic curve") {
					var edErr error
					parsedKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
					if edErr != nil {
						return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
					}

					// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
				} else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) {
					var rsaErr error
					parsedKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key)
					if rsaErr != nil {
						return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err)
					}

					// Parsing as RSA-PSS in PKCS8 succeeded!
				} else {
					return fmt.Errorf("error parsing asymmetric key: %s", err)
				}
			}
		} else {
			// pem.Decode returns a nil block when the input holds no PEM
			// data; reject that instead of dereferencing it.
			pemBlock, _ := pem.Decode(key)
			if pemBlock == nil {
				return fmt.Errorf("error parsing public key: not in PEM format")
			}
			parsedKey, err = x509.ParsePKIXPublicKey(pemBlock.Bytes)
			if err != nil {
				return fmt.Errorf("error parsing public key: %s", err)
			}
		}

		err = entry.parseFromKey(p.Type, parsedKey)
		if err != nil {
			return err
		}
	}

	p.LatestVersion += 1

	if p.Keys == nil {
		// This is an initial key rotation when generating a new policy. We
		// don't need to call migrate here because if we've called getPolicy to
		// get the policy in the first place it will have been run.
		p.Keys = keyEntryMap{}
	}
	p.Keys[strconv.Itoa(p.LatestVersion)] = entry

	// This ensures that with new key creations min decryption version is set
	// to 1 rather than the int default of 0, since keys start at 1 (either
	// fresh or after migration to the key map)
	if p.MinDecryptionVersion == 0 {
		p.MinDecryptionVersion = 1
	}

	return p.Persist(ctx, storage)
}
// Rotate rotates the policy and persists it to storage.
// If the rotation partially fails, the policy state will be restored:
// LatestVersion, MinDecryptionVersion, Keys and Imported are reset to the
// values they had on entry.
//
// Rotation is refused for imported keys unless AllowImportedKeyRotation is
// set. After a successful rotation the policy is no longer marked as
// imported. randReader is forwarded to RotateInMemory as the entropy source.
func (p *Policy) Rotate(ctx context.Context, storage logical.Storage, randReader io.Reader) (retErr error) {
	if p.Imported && !p.AllowImportedKeyRotation {
		return fmt.Errorf("imported key %s does not allow rotation within Vault", p.Name)
	}

	// Snapshot everything this function (or RotateInMemory) mutates so a
	// failure leaves the in-memory policy as it was.
	priorLatestVersion := p.LatestVersion
	priorMinDecryptionVersion := p.MinDecryptionVersion
	priorImported := p.Imported
	var priorKeys keyEntryMap
	if p.Keys != nil {
		priorKeys = make(keyEntryMap, len(p.Keys))
		for k, v := range p.Keys {
			priorKeys[k] = v
		}
	}

	defer func() {
		if retErr != nil {
			p.LatestVersion = priorLatestVersion
			p.MinDecryptionVersion = priorMinDecryptionVersion
			p.Keys = priorKeys
			p.Imported = priorImported
		}
	}()

	if err := p.RotateInMemory(randReader); err != nil {
		return err
	}

	p.Imported = false
	return p.Persist(ctx, storage)
}
// RotateInMemory rotates the policy but does not persist it to storage.
//
// A new key entry is generated for the policy's key type, stored under the
// next version number, and LatestVersion is incremented. Every entry gets a
// random 32-byte HMAC key, except HMAC policies where the HMAC key is the
// generated key itself. randReader is the entropy source for all generated
// key material. MinDecryptionVersion is raised to 1 if it was 0.
//
// An error is returned when random data or a key pair cannot be generated,
// when an HMAC policy has a KeySize outside [HmacMinKeySize, HmacMaxKeySize],
// or when an ECDSA public key cannot be encoded. The policy is not modified
// in those cases.
func (p *Policy) RotateInMemory(randReader io.Reader) (retErr error) {
	now := time.Now()
	entry := KeyEntry{
		CreationTime:           now,
		DeprecatedCreationTime: now.Unix(),
	}

	hmacKey, err := uuid.GenerateRandomBytesWithReader(32, randReader)
	if err != nil {
		return err
	}
	entry.HMACKey = hmacKey

	switch p.Type {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC:
		// Default to 256 bit key
		numBytes := 32
		if p.Type == KeyType_AES128_GCM96 {
			numBytes = 16
		} else if p.Type == KeyType_HMAC {
			numBytes = p.KeySize
			if numBytes < HmacMinKeySize || numBytes > HmacMaxKeySize {
				return fmt.Errorf("invalid key size for HMAC key, must be between %d and %d bytes", HmacMinKeySize, HmacMaxKeySize)
			}
		}
		newKey, err := uuid.GenerateRandomBytesWithReader(numBytes, randReader)
		if err != nil {
			return err
		}
		entry.Key = newKey

		if p.Type == KeyType_HMAC {
			// To avoid causing problems, ensure HMACKey = Key.
			entry.HMACKey = newKey
		}

	case KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521:
		var curve elliptic.Curve
		switch p.Type {
		case KeyType_ECDSA_P384:
			curve = elliptic.P384()
		case KeyType_ECDSA_P521:
			curve = elliptic.P521()
		default:
			curve = elliptic.P256()
		}

		// Use the caller-supplied entropy source, like every other key type
		// generated in this function.
		privKey, err := ecdsa.GenerateKey(curve, randReader)
		if err != nil {
			return err
		}
		entry.EC_D = privKey.D
		entry.EC_X = privKey.X
		entry.EC_Y = privKey.Y
		derBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
		if err != nil {
			return errwrap.Wrapf("error marshaling public key: {{err}}", err)
		}
		pemBlock := &pem.Block{
			Type:  "PUBLIC KEY",
			Bytes: derBytes,
		}
		pemBytes := pem.EncodeToMemory(pemBlock)
		if len(pemBytes) == 0 {
			return fmt.Errorf("error PEM-encoding public key")
		}
		entry.FormattedPublicKey = string(pemBytes)

	case KeyType_ED25519:
		// Go uses a 64-byte private key for Ed25519 keys (private+public, each
		// 32-bytes long). When we do Key derivation, we still generate a 32-byte
		// random value (and compute the corresponding Ed25519 public key), but
		// use this entire 64-byte key as if it was an HKDF key. The corresponding
		// underlying public key is never returned (which is probably good, because
		// doing so would leak half of our HKDF key...), but means we cannot import
		// derived-enabled Ed25519 public key components.
		pub, pri, err := ed25519.GenerateKey(randReader)
		if err != nil {
			return err
		}
		entry.Key = pri
		entry.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub)

	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		bitSize := 2048
		if p.Type == KeyType_RSA3072 {
			bitSize = 3072
		}
		if p.Type == KeyType_RSA4096 {
			bitSize = 4096
		}

		entry.RSAKey, err = rsa.GenerateKey(randReader, bitSize)
		if err != nil {
			return err
		}
	}

	// Stamp the entry with the current convergent version when the policy
	// uses convergent encryption and its version is -1 or greater than 1.
	if p.ConvergentEncryption {
		if p.ConvergentVersion == -1 || p.ConvergentVersion > 1 {
			entry.ConvergentVersion = currentConvergentVersion
		}
	}

	p.LatestVersion += 1

	if p.Keys == nil {
		// This is an initial key rotation when generating a new policy. We
		// don't need to call migrate here because if we've called getPolicy to
		// get the policy in the first place it will have been run.
		p.Keys = keyEntryMap{}
	}
	p.Keys[strconv.Itoa(p.LatestVersion)] = entry

	// This ensures that with new key creations min decryption version is set
	// to 1 rather than the int default of 0, since keys start at 1 (either
	// fresh or after migration to the key map)
	if p.MinDecryptionVersion == 0 {
		p.MinDecryptionVersion = 1
	}

	return nil
}
// MigrateKeyToKeysMap moves the policy's single Key field into a fresh Keys
// map as version "1", stamped with the current time, and clears Key. Any
// existing Keys map is replaced.
func (p *Policy) MigrateKeyToKeysMap() {
	migratedAt := time.Now()
	first := KeyEntry{
		Key:                    p.Key,
		CreationTime:           migratedAt,
		DeprecatedCreationTime: migratedAt.Unix(),
	}

	p.Keys = keyEntryMap{"1": first}
	p.Key = nil
}
// Backup should be called with an exclusive lock held on the policy.
//
// It records the backup (time and latest version) in the policy, persists the
// policy, loads the key archive and returns the JSON encoding of both as a
// standard base64 string. Both Exportable and AllowPlaintextBackup must be set
// on the policy. If any step fails, the policy's previous BackupInfo is put
// back in memory.
func (p *Policy) Backup(ctx context.Context, storage logical.Storage) (out string, retErr error) {
	switch {
	case !p.Exportable:
		return "", fmt.Errorf("exporting is disallowed on the policy")
	case !p.AllowPlaintextBackup:
		return "", fmt.Errorf("plaintext backup is disallowed on the policy")
	}

	previous := p.BackupInfo
	defer func() {
		if retErr == nil {
			return
		}
		p.BackupInfo = previous
	}()

	// Create a record of this backup operation in the policy
	p.BackupInfo = &BackupInfo{
		Time:    time.Now(),
		Version: p.LatestVersion,
	}
	if persistErr := p.Persist(ctx, storage); persistErr != nil {
		return "", errwrap.Wrapf("failed to persist policy with backup info: {{err}}", persistErr)
	}

	// Load the archive only after persisting the policy as the archive can get
	// adjusted while persisting the policy
	archive, err := p.LoadArchive(ctx, storage)
	if err != nil {
		return "", err
	}

	payload, err := jsonutil.EncodeJSON(&KeyData{
		Policy:       p,
		ArchivedKeys: archive,
	})
	if err != nil {
		return "", err
	}

	return base64.StdEncoding.EncodeToString(payload), nil
}
// getTemplateParts returns the two pieces of the version template that
// surround the "{{version}}" placeholder. The result is stored in
// versionPrefixCache under "template-parts" and reused on later calls. The
// policy's VersionTemplate is used when set, DefaultVersionTemplate otherwise.
// An error is returned unless the template splits into exactly two parts.
func (p *Policy) getTemplateParts() ([]string, error) {
	if cached, found := p.versionPrefixCache.Load("template-parts"); found {
		return cached.([]string), nil
	}

	tmpl := p.VersionTemplate
	if len(tmpl) == 0 {
		tmpl = DefaultVersionTemplate
	}

	pieces := strings.Split(tmpl, "{{version}}")
	if len(pieces) != 2 {
		return nil, errutil.InternalError{Err: "error parsing version template"}
	}

	p.versionPrefixCache.Store("template-parts", pieces)
	return pieces, nil
}
// getVersionPrefix renders the version template for ver by replacing every
// "{{version}}" placeholder with the decimal version number. Results are
// stored in versionPrefixCache keyed by ver. The policy's VersionTemplate is
// used when set, DefaultVersionTemplate otherwise.
func (p *Policy) getVersionPrefix(ver int) string {
	if cached, found := p.versionPrefixCache.Load(ver); found {
		return cached.(string)
	}

	tmpl := p.VersionTemplate
	if len(tmpl) == 0 {
		tmpl = DefaultVersionTemplate
	}

	rendered := strings.Replace(tmpl, "{{version}}", strconv.Itoa(ver), -1)
	p.versionPrefixCache.Store(ver, rendered)
	return rendered
}
// SymmetricOpts are the arguments to symmetric operations that are "optional", e.g.
// not always used. This improves the aesthetics of calls to those functions.
// It is consumed by SymmetricEncryptRaw and SymmetricDecryptRaw.
type SymmetricOpts struct {
	// Whether to use convergent encryption
	Convergent bool
	// The version of the convergent encryption scheme. SymmetricDecryptRaw
	// reads it to decide whether the nonce is embedded in the ciphertext;
	// SymmetricEncryptRaw takes the version from the policy instead.
	ConvergentVersion int
	// The nonce, if not randomly generated
	Nonce []byte
	// Additional data to include in AEAD authentication
	AdditionalData []byte
	// The HMAC key, for generating IVs in convergent encryption
	HMACKey []byte
	// Allows an external provider of the AEAD, for e.g. managed keys
	AEADFactory AEADFactory
}
// SymmetricEncryptRaw symmetrically encrypts plaintext given the convergence
// configuration and appropriate keys.
//
// ver is the key version, used to look up the convergent-encryption version.
// encKey is the raw key for AES-GCM and ChaCha20-Poly1305 policies; managed
// key policies obtain their AEAD from opts.AEADFactory instead. The nonce is
// chosen as follows:
//   - convergent version 1: opts.Nonce, which must match the AEAD nonce size;
//   - convergent versions 2 and 3: a prefix of HMAC-SHA256(opts.HMACKey, plaintext);
//   - otherwise: opts.Nonce if given (it must match the AEAD nonce size), or
//     a freshly generated random nonce.
//
// The returned bytes are nonce||ciphertext, except for convergent version 1
// where only the ciphertext is returned. Policy types without a symmetric
// AEAD yield an errutil.InternalError.
func (p *Policy) SymmetricEncryptRaw(ver int, encKey, plaintext []byte, opts SymmetricOpts) ([]byte, error) {
	var aead cipher.AEAD
	var err error
	nonce := opts.Nonce

	switch p.Type {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
		// Setup the cipher
		aesCipher, err := aes.NewCipher(encKey)
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}

		// Setup the GCM AEAD
		gcm, err := cipher.NewGCM(aesCipher)
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}

		aead = gcm

	case KeyType_ChaCha20_Poly1305:
		cha, err := chacha20poly1305.New(encKey)
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}

		aead = cha

	case KeyType_MANAGED_KEY:
		if opts.Convergent || len(opts.Nonce) != 0 {
			return nil, errutil.UserError{Err: "cannot use convergent encryption or provide a nonce to managed-key backed encryption"}
		}
		if opts.AEADFactory == nil {
			return nil, errors.New("expected AEAD factory from managed key, none provided")
		}
		aead, err = opts.AEADFactory.GetAEAD(nonce)
		if err != nil {
			return nil, err
		}

	default:
		// Without this, aead stays nil and the NonceSize calls below panic.
		return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	if opts.Convergent {
		convergentVersion := p.convergentVersion(ver)
		switch convergentVersion {
		case 1:
			if len(opts.Nonce) != aead.NonceSize() {
				return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())}
			}
		case 2, 3:
			if len(opts.HMACKey) == 0 {
				return nil, errutil.InternalError{Err: "invalid hmac key length of zero"}
			}
			// Derive the nonce deterministically from the plaintext.
			nonceHmac := hmac.New(sha256.New, opts.HMACKey)
			nonceHmac.Write(plaintext)
			nonceSum := nonceHmac.Sum(nil)
			nonce = nonceSum[:aead.NonceSize()]
		default:
			return nil, errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)}
		}
	} else if len(nonce) == 0 {
		// Compute random nonce
		nonce, err = uuid.GenerateRandomBytes(aead.NonceSize())
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}
	} else if len(nonce) != aead.NonceSize() {
		return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long but given %d bytes", aead.NonceSize(), len(nonce))}
	}

	// Encrypt and tag with AEAD
	ciphertext := aead.Seal(nil, nonce, plaintext, opts.AdditionalData)

	// Place the encrypted data after the nonce
	if !opts.Convergent || p.convergentVersion(ver) > 1 {
		// Assemble into a fresh buffer: appending to nonce directly could
		// write into spare capacity of the caller's opts.Nonce slice.
		out := make([]byte, 0, len(nonce)+len(ciphertext))
		out = append(out, nonce...)
		out = append(out, ciphertext...)
		ciphertext = out
	}
	return ciphertext, nil
}
// SymmetricDecryptRaw symmetrically decrypts ciphertext given the convergence
// configuration and appropriate keys.
//
// encKey is the raw key for AES-GCM and ChaCha20-Poly1305 policies; managed
// key policies obtain their AEAD from opts.AEADFactory, which must be set.
// Unless opts selects convergent version 1, the leading AEAD-nonce-sized
// bytes of ciphertext are treated as the nonce and the remainder as the
// sealed data. For convergent version 1 the whole input is the sealed data
// and a nil nonce is passed to the AEAD.
//
// A ciphertext shorter than the nonce size or one that fails authentication
// yields an errutil.UserError. Policy types without a symmetric AEAD yield an
// errutil.InternalError.
func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOpts) ([]byte, error) {
	var aead cipher.AEAD
	var err error
	var nonce []byte

	switch p.Type {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96:
		// Setup the cipher
		aesCipher, err := aes.NewCipher(encKey)
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}

		// Setup the GCM AEAD
		gcm, err := cipher.NewGCM(aesCipher)
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}

		aead = gcm

	case KeyType_ChaCha20_Poly1305:
		cha, err := chacha20poly1305.New(encKey)
		if err != nil {
			return nil, errutil.InternalError{Err: err.Error()}
		}

		aead = cha

	case KeyType_MANAGED_KEY:
		// Mirror the check in SymmetricEncryptRaw instead of panicking on a
		// nil factory.
		if opts.AEADFactory == nil {
			return nil, errors.New("expected AEAD factory from managed key, none provided")
		}
		aead, err = opts.AEADFactory.GetAEAD(nonce)
		if err != nil {
			return nil, err
		}

	default:
		// Without this, aead stays nil and the NonceSize call below panics.
		return nil, errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	if len(ciphertext) < aead.NonceSize() {
		return nil, errutil.UserError{Err: "invalid ciphertext length"}
	}

	// Extract the nonce and ciphertext
	var trueCT []byte
	if opts.Convergent && opts.ConvergentVersion == 1 {
		trueCT = ciphertext
	} else {
		nonce = ciphertext[:aead.NonceSize()]
		trueCT = ciphertext[aead.NonceSize():]
	}

	// Verify and Decrypt
	plain, err := aead.Open(nil, nonce, trueCT, opts.AdditionalData)
	if err != nil {
		return nil, errutil.UserError{Err: err.Error()}
	}
	return plain, nil
}
// EncryptWithFactory encrypts value, a standard-base64 encoded plaintext,
// with the requested key version and returns the standard-base64 ciphertext
// prefixed with the policy's version prefix.
//
// ver selects the key version; 0 means p.LatestVersion. Negative versions,
// versions above p.LatestVersion and versions below p.MinEncryptionVersion
// are rejected with an errutil.UserError. context is passed to p.GetKey for
// symmetric key types. nonce is forwarded to the symmetric or managed-key
// encryption routine.
//
// factories may contain AEADFactory, AssociatedDataFactory and
// ManagedKeyFactory values. For symmetric key types nil entries are skipped
// and any other type is an error; managed-key policies require a
// ManagedKeyFactory. RSA policies encrypt with OAEP using SHA-256 and ignore
// the factories.
func (p *Policy) EncryptWithFactory(ver int, context []byte, nonce []byte, value string, factories ...interface{}) (string, error) {
	if !p.Type.EncryptionSupported() {
		return "", errutil.UserError{Err: fmt.Sprintf("message encryption not supported for key type %v", p.Type)}
	}

	// Decode the plaintext value
	plaintext, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return "", errutil.UserError{Err: err.Error()}
	}

	switch {
	case ver == 0:
		ver = p.LatestVersion
	case ver < 0:
		return "", errutil.UserError{Err: "requested version for encryption is negative"}
	case ver > p.LatestVersion:
		return "", errutil.UserError{Err: "requested version for encryption is higher than the latest key version"}
	case ver < p.MinEncryptionVersion:
		return "", errutil.UserError{Err: "requested version for encryption is less than the minimum encryption key version"}
	}

	var ciphertext []byte

	switch p.Type {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305:
		// hmacKey starts out as context and is replaced by derived key
		// material when the convergent version is above 2.
		hmacKey := context

		var encKey []byte
		var deriveHMAC bool

		encBytes := 32
		hmacBytes := 0
		if p.convergentVersion(ver) > 2 {
			deriveHMAC = true
			hmacBytes = 32
		}
		if p.Type == KeyType_AES128_GCM96 {
			encBytes = 16
		}

		// Fetch enough key material for the encryption key plus, when
		// needed, the HMAC key.
		key, err := p.GetKey(context, ver, encBytes+hmacBytes)
		if err != nil {
			return "", err
		}

		if len(key) < encBytes+hmacBytes {
			return "", errutil.InternalError{Err: "could not derive key, length too small"}
		}

		encKey = key[:encBytes]
		if len(encKey) != encBytes {
			return "", errutil.InternalError{Err: "could not derive enc key, length not correct"}
		}
		if deriveHMAC {
			hmacKey = key[encBytes:]
			if len(hmacKey) != hmacBytes {
				return "", errutil.InternalError{Err: "could not derive hmac key, length not correct"}
			}
		}

		symopts := SymmetricOpts{
			Convergent: p.ConvergentEncryption,
			HMACKey:    hmacKey,
			Nonce:      nonce,
		}
		for index, rawFactory := range factories {
			if rawFactory == nil {
				continue
			}
			switch factory := rawFactory.(type) {
			case AEADFactory:
				symopts.AEADFactory = factory
			case AssociatedDataFactory:
				symopts.AdditionalData, err = factory.GetAssociatedData()
				if err != nil {
					return "", errutil.InternalError{Err: fmt.Sprintf("unable to get associated_data/additional_data from factory[%d]: %v", index, err)}
				}
			case ManagedKeyFactory:
				// Accepted but unused for symmetric key types.
			default:
				return "", errutil.InternalError{Err: fmt.Sprintf("unknown type of factory[%d]: %T", index, rawFactory)}
			}
		}

		ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, symopts)
		if err != nil {
			return "", err
		}

	case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return "", err
		}
		// Prefer the public key embedded in the private key; fall back to the
		// stored public key when only that was imported.
		var publicKey *rsa.PublicKey
		if keyEntry.RSAKey != nil {
			publicKey = &keyEntry.RSAKey.PublicKey
		} else {
			publicKey = keyEntry.RSAPublicKey
		}
		ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, publicKey, plaintext, nil)
		if err != nil {
			return "", errutil.InternalError{Err: fmt.Sprintf("failed to RSA encrypt the plaintext: %v", err)}
		}

	case KeyType_MANAGED_KEY:
		keyEntry, err := p.safeGetKeyEntry(ver)
		if err != nil {
			return "", err
		}

		var aad []byte
		var managedKeyFactory ManagedKeyFactory
		for _, f := range factories {
			switch factory := f.(type) {
			case AssociatedDataFactory:
				aad, err = factory.GetAssociatedData()
				if err != nil {
					// Propagate the failure; returning a nil error here would
					// report success with an empty ciphertext.
					return "", err
				}
			case ManagedKeyFactory:
				managedKeyFactory = factory
			}
		}

		if managedKeyFactory == nil {
			return "", errors.New("key type is managed_key, but managed key parameters were not provided")
		}

		ciphertext, err = p.encryptWithManagedKey(managedKeyFactory.GetManagedKeyParameters(), keyEntry, plaintext, nonce, aad)
		if err != nil {
			return "", err
		}

	default:
		return "", errutil.InternalError{Err: fmt.Sprintf("unsupported key type %v", p.Type)}
	}

	// Convert to base64
	encoded := base64.StdEncoding.EncodeToString(ciphertext)

	// Prepend some information
	encoded = p.getVersionPrefix(ver) + encoded

	return encoded, nil
}
// KeyVersionCanBeUpdated reports, as an error, why keyVersion cannot receive
// additional key material; nil means it can. The version must exist, the
// policy type must support public-key import, the version must still be
// missing its private key, and the material being added must be a private key
// (isPrivateKey).
func (p *Policy) KeyVersionCanBeUpdated(keyVersion int, isPrivateKey bool) error {
	keyEntry, err := p.safeGetKeyEntry(keyVersion)
	if err != nil {
		return err
	}

	if !p.Type.ImportPublicKeySupported() {
		return errors.New("provided type does not support importing key versions")
	}

	switch missing := keyEntry.IsPrivateKeyMissing(); {
	case !missing:
		return errors.New("private key imported, key version cannot be updated")
	case !isPrivateKey:
		return errors.New("cannot add a public key to a key version that already has a public key set")
	}

	return nil
}
// ImportPrivateKeyForVersion attaches the PKCS#8 private key in key to the
// existing key version keyVersion and persists the policy.
//
// The private key is parsed with x509.ParsePKCS8PrivateKey, falling back to
// ParsePKCS8Ed25519PrivateKey or ParsePKCS8RSAPSSPrivateKey when the parse
// error text points at those encodings. For ECDSA, RSA and Ed25519 keys the
// public half must equal the public key already stored for that version;
// otherwise "cannot import key, key pair does not match" is returned. The
// same error is returned when the stored entry holds no comparable public key
// of that kind. The entry is then updated via parseFromKey, which also
// enforces that the key matches the policy type.
func (p *Policy) ImportPrivateKeyForVersion(ctx context.Context, storage logical.Storage, keyVersion int, key []byte) error {
	keyEntry, err := p.safeGetKeyEntry(keyVersion)
	if err != nil {
		return err
	}

	// Parse key
	parsedPrivateKey, err := x509.ParsePKCS8PrivateKey(key)
	if err != nil {
		if strings.Contains(err.Error(), "unknown elliptic curve") {
			var edErr error
			parsedPrivateKey, edErr = ParsePKCS8Ed25519PrivateKey(key)
			if edErr != nil {
				return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an ed25519 private key: %s\n - original error: %v", edErr, err)
			}

			// Parsing as Ed25519-in-PKCS8-ECPrivateKey succeeded!
		} else if strings.Contains(err.Error(), oidSignatureRSAPSS.String()) {
			var rsaErr error
			parsedPrivateKey, rsaErr = ParsePKCS8RSAPSSPrivateKey(key)
			if rsaErr != nil {
				return fmt.Errorf("error parsing asymmetric key:\n - assuming contents are an RSA/PSS private key: %v\n - original error: %w", rsaErr, err)
			}

			// Parsing as RSA-PSS in PKCS8 succeeded!
		} else {
			return fmt.Errorf("error parsing asymmetric key: %s", err)
		}
	}

	// Make sure the private key belongs to the public key already on record.
	switch privateKey := parsedPrivateKey.(type) {
	case *ecdsa.PrivateKey:
		// pem.Decode returns a nil block when the stored value holds no PEM
		// data (for example when the entry is not an ECDSA key).
		pemBlock, _ := pem.Decode([]byte(keyEntry.FormattedPublicKey))
		if pemBlock == nil {
			return fmt.Errorf("failed to parse key entry public key: invalid PEM blob")
		}
		publicKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes)
		if err != nil || publicKey == nil {
			return fmt.Errorf("failed to parse key entry public key: %v", err)
		}
		// A stored public key of another algorithm can never match.
		ecdsaPublicKey, ok := publicKey.(*ecdsa.PublicKey)
		if !ok || !ecdsaPublicKey.Equal(&privateKey.PublicKey) {
			return fmt.Errorf("cannot import key, key pair does not match")
		}
	case *rsa.PrivateKey:
		// Equal dereferences its argument, so guard against an entry that
		// has no RSA public key.
		if keyEntry.RSAPublicKey == nil || !privateKey.PublicKey.Equal(keyEntry.RSAPublicKey) {
			return fmt.Errorf("cannot import key, key pair does not match")
		}
	case ed25519.PrivateKey:
		publicKey, err := base64.StdEncoding.DecodeString(keyEntry.FormattedPublicKey)
		if err != nil {
			return fmt.Errorf("failed to parse key entry public key: %v", err)
		}
		if !ed25519.PublicKey(publicKey).Equal(privateKey.Public()) {
			return fmt.Errorf("cannot import key, key pair does not match")
		}
	}

	err = keyEntry.parseFromKey(p.Type, parsedPrivateKey)
	if err != nil {
		return err
	}

	p.Keys[strconv.Itoa(keyVersion)] = keyEntry

	return p.Persist(ctx, storage)
}
// parseFromKey copies the material of parsedKey into the key entry after
// checking that it is compatible with PolKeyType.
//
// Supported values are ECDSA, Ed25519 and RSA keys, each as private or public
// key. ECDSA keys must use the curve implied by the policy type and RSA keys
// the matching modulus size. ECDSA public keys are stored PEM-encoded and
// Ed25519 public keys base64-encoded in FormattedPublicKey. Any other
// combination yields an "invalid key type" error.
func (ke *KeyEntry) parseFromKey(PolKeyType KeyType, parsedKey any) error {
	// wrongType builds the error shared by every policy/key mismatch.
	wrongType := func() error {
		return fmt.Errorf("invalid key type: expected %s, got %T", PolKeyType, parsedKey)
	}

	switch typed := parsedKey.(type) {
	case *ecdsa.PrivateKey, *ecdsa.PublicKey:
		var want elliptic.Curve
		switch PolKeyType {
		case KeyType_ECDSA_P256:
			want = elliptic.P256()
		case KeyType_ECDSA_P384:
			want = elliptic.P384()
		case KeyType_ECDSA_P521:
			want = elliptic.P521()
		default:
			return wrongType()
		}

		// Work on the public half regardless of which form was supplied.
		var pub *ecdsa.PublicKey
		priv, isPrivate := typed.(*ecdsa.PrivateKey)
		if isPrivate {
			pub = &priv.PublicKey
		} else {
			pub = typed.(*ecdsa.PublicKey)
		}

		if pub.Curve != want {
			return fmt.Errorf("invalid curve: expected %s, got %s", want.Params().Name, pub.Curve.Params().Name)
		}

		if isPrivate {
			ke.EC_D = priv.D
		}
		ke.EC_X = pub.X
		ke.EC_Y = pub.Y

		der, err := x509.MarshalPKIXPublicKey(pub)
		if err != nil {
			return errwrap.Wrapf("error marshaling public key: {{err}}", err)
		}

		encoded := pem.EncodeToMemory(&pem.Block{
			Type:  "PUBLIC KEY",
			Bytes: der,
		})
		if len(encoded) == 0 {
			return fmt.Errorf("error PEM-encoding public key")
		}
		ke.FormattedPublicKey = string(encoded)

	case ed25519.PrivateKey, ed25519.PublicKey:
		if PolKeyType != KeyType_ED25519 {
			return wrongType()
		}

		var pub ed25519.PublicKey
		if priv, isPrivate := typed.(ed25519.PrivateKey); isPrivate {
			ke.Key = priv
			pub = priv.Public().(ed25519.PublicKey)
		} else {
			pub = typed.(ed25519.PublicKey)
		}
		ke.FormattedPublicKey = base64.StdEncoding.EncodeToString(pub)

	case *rsa.PrivateKey, *rsa.PublicKey:
		// Expected modulus size in bytes for the policy type.
		var wantBytes int
		switch PolKeyType {
		case KeyType_RSA2048:
			wantBytes = 256
		case KeyType_RSA3072:
			wantBytes = 384
		case KeyType_RSA4096:
			wantBytes = 512
		default:
			return wrongType()
		}

		var pub *rsa.PublicKey
		priv, isPrivate := typed.(*rsa.PrivateKey)
		if isPrivate {
			pub = &priv.PublicKey
		} else {
			pub = typed.(*rsa.PublicKey)
		}

		if got := pub.Size(); got != wantBytes {
			return fmt.Errorf("invalid key size: expected %d bytes, got %d bytes", wantBytes, got)
		}

		if isPrivate {
			ke.RSAKey = priv
		}
		ke.RSAPublicKey = pub

	default:
		return wrongType()
	}

	return nil
}
// WrapKey wraps targetKey for import using one version of this policy's key
// material and returns the wrapped blob produced by KeyEntry.WrapKey.
//
// ver selects the wrapping key version:
//   - 0 means p.LatestVersion;
//   - a negative version, a version above p.LatestVersion, or a version below
//     a non-zero p.MinEncryptionVersion is rejected with an errutil.UserError.
//
// targetKeyType describes targetKey and hash is forwarded unchanged to the
// selected key entry.
func (p *Policy) WrapKey(ver int, targetKey interface{}, targetKeyType KeyType, hash hash.Hash) (string, error) {
	// Policy types are gated on SigningSupported; KeyEntry.WrapKey further
	// requires the selected entry to carry an RSA public key. The error text
	// describes the operation actually being attempted (wrapping, not
	// signing).
	if !p.Type.SigningSupported() {
		return "", fmt.Errorf("key wrapping not supported for key type %v", p.Type)
	}

	switch {
	case ver == 0:
		ver = p.LatestVersion
	case ver < 0:
		return "", errutil.UserError{Err: "requested version for key wrapping is negative"}
	case ver > p.LatestVersion:
		return "", errutil.UserError{Err: "requested version for key wrapping is higher than the latest key version"}
	case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
		return "", errutil.UserError{Err: "requested version for key wrapping is less than the minimum encryption key version"}
	}

	keyEntry, err := p.safeGetKeyEntry(ver)
	if err != nil {
		return "", err
	}

	return keyEntry.WrapKey(targetKey, targetKeyType, hash)
}
// WrapKey wraps targetKey under this entry's RSA public key and returns the
// result of wrapTargetPKCS8ForImport.
//
// The scheme is intended to be CKM_RSA_AES_KEY_WRAP-compatible, so the entry
// must have RSAPublicKey set. Symmetric target types (AES-GCM, ChaCha20-Poly1305
// and HMAC) must be supplied as a []byte and are wrapped as-is; every other
// target type is first serialized with x509.MarshalPKCS8PrivateKey.
func (ke *KeyEntry) WrapKey(targetKey interface{}, targetKeyType KeyType, hash hash.Hash) (string, error) {
	wrappingKey := ke.RSAPublicKey
	if wrappingKey == nil {
		return "", fmt.Errorf("unsupported key type in use; must be a rsa key")
	}

	var material []byte
	switch targetKeyType {
	case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305, KeyType_HMAC:
		// Symmetric keys are wrapped as raw bytes.
		raw, isBytes := targetKey.([]byte)
		if !isBytes {
			return "", fmt.Errorf("failed to wrap target key for import: symmetric key not provided in byte format (%T)", targetKey)
		}
		material = raw
	default:
		// Everything else is treated as a private key and PKCS#8-encoded.
		der, marshalErr := x509.MarshalPKCS8PrivateKey(targetKey)
		if marshalErr != nil {
			return "", fmt.Errorf("failed to wrap target key for import: %w", marshalErr)
		}
		material = der
	}

	wrapped, wrapErr := wrapTargetPKCS8ForImport(wrappingKey, material, hash)
	if wrapErr != nil {
		return wrapped, fmt.Errorf("failed to wrap target key for import: %w", wrapErr)
	}
	return wrapped, nil
}
// wrapTargetPKCS8ForImport wraps preppedTargetKey for import in two layers:
//
//  1. a freshly generated 32-byte AES key is encrypted to wrappingKey with
//     RSA-OAEP, using hash and an empty label;
//  2. preppedTargetKey is wrapped under that AES key with KWP.
//
// The two ciphertexts are concatenated, RSA-encrypted key first, and returned
// as a standard-base64 string.
//
// wrappingKey and hash must be non-nil. An error is returned otherwise, rather
// than letting rsa.EncryptOAEP panic on a nil dereference.
func wrapTargetPKCS8ForImport(wrappingKey *rsa.PublicKey, preppedTargetKey []byte, hash hash.Hash) (string, error) {
	if wrappingKey == nil {
		return "", errors.New("wrapping public key must not be nil")
	}
	if hash == nil {
		return "", errors.New("hash function must not be nil")
	}

	// Generate an ephemeral AES-256 key
	ephKey, err := uuid.GenerateRandomBytes(32)
	if err != nil {
		return "", fmt.Errorf("failed to generate an ephemeral AES wrapping key: %w", err)
	}

	// Wrap ephemeral AES key with public wrapping key
	ephKeyWrapped, err := rsa.EncryptOAEP(hash, rand.Reader, wrappingKey, ephKey, []byte{} /* label */)
	if err != nil {
		return "", fmt.Errorf("failed to encrypt ephemeral wrapping key with public key: %w", err)
	}

	// Create KWP instance for wrapping target key
	kwp, err := subtle.NewKWP(ephKey)
	if err != nil {
		return "", fmt.Errorf("failed to generate new KWP from AES key: %w", err)
	}

	// Wrap target key with KWP
	targetKeyWrapped, err := kwp.Wrap(preppedTargetKey)
	if err != nil {
		return "", fmt.Errorf("failed to wrap target key with KWP: %w", err)
	}

	// Combine the wrapped keys into a single blob and base64 encode
	wrappedKeys := append(ephKeyWrapped, targetKeyWrapped...)
	return base64.StdEncoding.EncodeToString(wrappedKeys), nil
}