From ff500ca1c3df00c4e64fb29c676b3f136e0bfdbb Mon Sep 17 00:00:00 2001 From: Steven Clark Date: Mon, 22 Apr 2024 13:19:04 -0400 Subject: [PATCH] Add Transit CMAC stubs in CE (#26552) --- builtin/logical/transit/backend.go | 30 ++++- builtin/logical/transit/backend_ce.go | 26 +++++ builtin/logical/transit/key_utils.go | 62 ++++++++++ builtin/logical/transit/path_cmac_ce.go | 17 +++ builtin/logical/transit/path_sign_verify.go | 119 +++++++++++++++++--- sdk/helper/keysutil/policy.go | 28 +++++ 6 files changed, 262 insertions(+), 20 deletions(-) create mode 100644 builtin/logical/transit/backend_ce.go create mode 100644 builtin/logical/transit/key_utils.go create mode 100644 builtin/logical/transit/path_cmac_ce.go diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index e30d766049..694d0c735b 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -77,10 +77,12 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) b.pathImportCertChain(), }, - Secrets: []*framework.Secret{}, - Invalidate: b.invalidate, - BackendType: logical.TypeLogical, - PeriodicFunc: b.periodicFunc, + Secrets: []*framework.Secret{}, + Invalidate: b.invalidate, + BackendType: logical.TypeLogical, + PeriodicFunc: b.periodicFunc, + InitializeFunc: b.initialize, + Clean: b.cleanup, } b.backendUUID = conf.BackendUUID @@ -107,11 +109,15 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error) return nil, err } + b.setupEnt() + return &b, nil } type backend struct { *framework.Backend + entBackend + lm *keysutil.LockManager // Lock to make changes to any of the backend's cache configuration. configMutex sync.RWMutex @@ -185,6 +191,8 @@ func (b *backend) invalidate(ctx context.Context, key string) { defer b.configMutex.Unlock() b.cacheSizeChanged = true } + + b.invalidateEnt(ctx, key) } // periodicFunc is a central collection of functions that run on an interval. @@ -203,7 +211,11 @@ func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error b.autoRotateOnce = sync.Once{} } - return err + if err != nil { + return err + } + + return b.periodicFuncEnt(ctx, req) } // autoRotateKeys retrieves all transit keys and rotates those which have an @@ -292,3 +304,11 @@ func (b *backend) rotateIfRequired(ctx context.Context, req *logical.Request, ke } return nil } + +func (b *backend) initialize(ctx context.Context, request *logical.InitializationRequest) error { + return b.initializeEnt(ctx, request) +} + +func (b *backend) cleanup(ctx context.Context) { + b.cleanupEnt(ctx) +} diff --git a/builtin/logical/transit/backend_ce.go b/builtin/logical/transit/backend_ce.go new file mode 100644 index 0000000000..4c88fc30a3 --- /dev/null +++ b/builtin/logical/transit/backend_ce.go @@ -0,0 +1,26 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package transit + +import ( + "context" + + "github.com/hashicorp/vault/sdk/logical" +) + +type entBackend struct{} + +func (b *backend) initializeEnt(_ context.Context, _ *logical.InitializationRequest) error { + return nil +} + +func (b *backend) invalidateEnt(_ context.Context, _ string) {} + +func (b *backend) periodicFuncEnt(_ context.Context, _ *logical.Request) error { return nil } + +func (b *backend) cleanupEnt(_ context.Context) {} + +func (b *backend) setupEnt() {} diff --git a/builtin/logical/transit/key_utils.go b/builtin/logical/transit/key_utils.go new file mode 100644 index 0000000000..05dfe8b0b8 --- /dev/null +++ b/builtin/logical/transit/key_utils.go @@ -0,0 +1,62 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package transit + +import ( + "context" + "errors" + "fmt" + + "github.com/hashicorp/vault/sdk/helper/keysutil" + "github.com/hashicorp/vault/sdk/logical" +) + +func (b *backend) getReadLockedPolicy(ctx context.Context, s logical.Storage, name string) (*keysutil.Policy, error) { + p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{ + Storage: s, + Name: name, + }, b.GetRandomReader()) + if err != nil { + return nil, err + } + if p == nil { + return nil, fmt.Errorf("%w: key %s not found", logical.ErrInvalidRequest, name) + } + if !b.System().CachingDisabled() { + p.Lock(false) + } + return p, nil +} + +// runWithReadLockedPolicy runs a function passing in the policy specified by keyName that has been +// locked in a read only fashion without the ability to upsert the policy +func (b *backend) runWithReadLockedPolicy(ctx context.Context, s logical.Storage, keyName string, f func(p *keysutil.Policy) (*logical.Response, error)) (*logical.Response, error) { + p, err := b.getReadLockedPolicy(ctx, s, keyName) + if err != nil { + if errors.Is(err, logical.ErrInvalidRequest) { + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + } + return nil, err + } + defer p.Unlock() + return f(p) +} + +// validateKeyVersion verifies that the passed in key version is valid for our +// current key policy, returning correct version to use within the policy. +func validateKeyVersion(p *keysutil.Policy, ver int) (int, error) { + switch { + case ver < 0: + return 0, fmt.Errorf("cannot use negative key version %d", ver) + case ver == 0: + // Allowed, will use latest; set explicitly here to ensure the string + // is generated properly + ver = p.LatestVersion + case ver == p.LatestVersion: + // Allowed + case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion: + return 0, fmt.Errorf("cannot use key version %d: version is too old (disallowed by policy) for key %s", ver, p.Name) + } + return ver, nil +} diff --git a/builtin/logical/transit/path_cmac_ce.go b/builtin/logical/transit/path_cmac_ce.go new file mode 100644 index 0000000000..c9919aaa9a --- /dev/null +++ b/builtin/logical/transit/path_cmac_ce.go @@ -0,0 +1,17 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package transit + +import ( + "context" + + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" +) + +func (b *backend) pathCMACVerify(_ context.Context, _ *logical.Request, _ *framework.FieldData) (*logical.Response, error) { + return logical.ErrorResponse("CMAC verification is only available in enterprise versions of Vault"), nil +} diff --git a/builtin/logical/transit/path_sign_verify.go b/builtin/logical/transit/path_sign_verify.go index 3307c5ca99..2043c8724e 100644 --- a/builtin/logical/transit/path_sign_verify.go +++ b/builtin/logical/transit/path_sign_verify.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" + "github.com/hashicorp/go-secure-stdlib/parseutil" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/keysutil" @@ -51,7 +52,7 @@ type batchResponseSignItem struct { // BatchRequestVerifyItem represents a request item for batch processing. // A map type allows us to distinguish between empty and missing values. -type batchRequestVerifyItem map[string]string +type batchRequestVerifyItem map[string]interface{} // BatchResponseVerifyItem represents a response item for batch processing type batchResponseVerifyItem struct { @@ -216,6 +217,11 @@ derivation is enabled; currently only available with ed25519 keys.`, Description: "The HMAC, including vault header/key version", }, + "cmac": { + Type: framework.TypeString, + Description: "The CMAC, including vault header/key version", + }, + "input": { Type: framework.TypeString, Description: "The base64-encoded input data to verify", @@ -226,6 +232,11 @@ derivation is enabled; currently only available with ed25519 keys.`, Description: `Hash algorithm to use (POST URL parameter)`, }, + "mac_length": { + Type: framework.TypeInt, + Description: `MAC length to use (POST body parameter). Valid values are:`, + }, + "hash_algorithm": { Type: framework.TypeString, Default: defaultHashAlgorithm, @@ -279,7 +290,7 @@ Options are 'auto' (the default used by Golang, causing the salt to be as large "batch_input": { Type: framework.TypeSlice, Description: `Specifies a list of items for processing. When this parameter is set, -any supplied 'input', 'hmac' or 'signature' parameters will be ignored. Responses are returned in the +any supplied 'input', 'hmac', 'cmac' or 'signature' parameters will be ignored. Responses are returned in the 'batch_results' array component of the 'data' element of the response. Any batch output will preserve the order of the batch input`, }, @@ -534,6 +545,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * if hmac, ok := d.GetOk("hmac"); ok { batchInputItems[0]["hmac"] = hmac.(string) } + if cmac, ok := d.GetOk("cmac"); ok { + batchInputItems[0]["cmac"] = cmac.(string) + } batchInputItems[0]["context"] = d.Get("context").(string) } @@ -542,26 +556,30 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * // If one batch_input item is 'hmac', they all must be 'hmac'. sigFound := false hmacFound := false + cmacFound := false missing := false for _, v := range batchInputItems { if _, ok := v["signature"]; ok { sigFound = true } else if _, ok := v["hmac"]; ok { hmacFound = true + } else if _, ok := v["cmac"]; ok { + cmacFound = true } else { missing = true } } + optionsSet := numBooleansTrue(sigFound, hmacFound, cmacFound) switch { - case batchInputRaw == nil && sigFound && hmacFound: - return logical.ErrorResponse("provide one of 'signature' or 'hmac'"), logical.ErrInvalidRequest + case batchInputRaw == nil && optionsSet > 1: + return logical.ErrorResponse("provide one of 'signature', 'hmac' or 'cmac'"), logical.ErrInvalidRequest - case batchInputRaw == nil && !sigFound && !hmacFound: - return logical.ErrorResponse("neither a 'signature' nor an 'hmac' were given to verify"), logical.ErrInvalidRequest + case batchInputRaw == nil && optionsSet == 0: + return logical.ErrorResponse("missing 'signature', 'hmac' or 'cmac' were given to verify"), logical.ErrInvalidRequest - case sigFound && hmacFound: - return logical.ErrorResponse("elements of batch_input must all provide 'signature' or all provide 'hmac'"), logical.ErrInvalidRequest + case optionsSet > 1: + return logical.ErrorResponse("elements of batch_input must all provide either 'signature', 'hmac' or 'cmac'"), logical.ErrInvalidRequest case missing && sigFound: return logical.ErrorResponse("some elements of batch_input are missing 'signature'"), logical.ErrInvalidRequest @@ -569,11 +587,17 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * case missing && hmacFound: return logical.ErrorResponse("some elements of batch_input are missing 'hmac'"), logical.ErrInvalidRequest - case missing: - return logical.ErrorResponse("no batch_input elements have 'signature' or 'hmac'"), logical.ErrInvalidRequest + case missing && cmacFound: + return logical.ErrorResponse("some elements of batch_input are missing 'cmac'"), logical.ErrInvalidRequest + + case optionsSet == 0: + return logical.ErrorResponse("no batch_input elements have 'signature', 'hmac' or 'cmac'"), logical.ErrInvalidRequest case hmacFound: return b.pathHMACVerify(ctx, req, d) + + case cmacFound: + return b.pathCMACVerify(ctx, req, d) } name := d.Get("name").(string) @@ -636,27 +660,38 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * response := make([]batchResponseVerifyItem, len(batchInputItems)) for i, item := range batchInputItems { - rawInput, ok := item["input"] if !ok { response[i].Error = "missing input" response[i].err = logical.ErrInvalidRequest continue } + strInput, err := parseutil.ParseString(rawInput) + if err != nil { + response[i].Error = fmt.Sprintf("unable to parse input as string: %s", err) + response[i].err = logical.ErrInvalidRequest + continue + } - input, err := base64.StdEncoding.DecodeString(rawInput) + input, err := base64.StdEncoding.DecodeString(strInput) if err != nil { response[i].Error = fmt.Sprintf("unable to decode input as base64: %s", err) response[i].err = logical.ErrInvalidRequest continue } - sig, ok := item["signature"] + sigRaw, ok := item["signature"].(string) if !ok { response[i].Error = "missing signature" response[i].err = logical.ErrInvalidRequest continue } + sig, err := parseutil.ParseString(sigRaw) + if err != nil { + response[i].Error = fmt.Sprintf("failed to parse signature as a string: %s", err) + response[i].err = logical.ErrInvalidRequest + continue + } if p.Type.HashSignatureInput() && !prehashed { hf := keysutil.HashFuncMap[hashAlgorithm]() @@ -666,7 +701,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * } } - contextRaw := item["context"] + contextRaw, err := parseutil.ParseString(item["context"]) + if err != nil { + response[i].Error = fmt.Sprintf("failed to parse context as a string: %s", err) + response[i].err = logical.ErrInvalidRequest + continue + } var context []byte if len(contextRaw) != 0 { context, err = base64.StdEncoding.DecodeString(contextRaw) @@ -720,7 +760,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * if batchInputRaw != nil { // Copy the references for i := range batchInputItems { - response[i].Reference = batchInputItems[i]["reference"] + if ref, err := parseutil.ParseString(batchInputItems[i]["reference"]); err == nil { + response[i].Reference = ref + } } resp.Data = map[string]interface{}{ "batch_results": response, @@ -740,6 +782,53 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d * return resp, nil } +func numBooleansTrue(bools ...bool) int { + numSet := 0 + for _, value := range bools { + if value { + numSet++ + } + } + return numSet +} + +func decodeTransitSignature(sig string) ([]byte, int, error) { + if !strings.HasPrefix(sig, "vault:v") { + return nil, 0, fmt.Errorf("prefix is not vault:v") + } + + splitVerification := strings.SplitN(strings.TrimPrefix(sig, "vault:v"), ":", 2) + if len(splitVerification) != 2 { + return nil, 0, fmt.Errorf("wrong number of fields delimited by ':', got %d expected 2", len(splitVerification)) + } + + ver, err := strconv.Atoi(splitVerification[0]) + if err != nil { + return nil, 0, fmt.Errorf("key version number %s count not be decoded", splitVerification[0]) + } + + if ver < 1 { + return nil, 0, fmt.Errorf("key version less than 1 are invalid got: %d", ver) + } + + if len(strings.TrimSpace(splitVerification[1])) == 0 { + return nil, 0, fmt.Errorf("missing base64 verification string from vault signature") + } + + verBytes, err := base64.StdEncoding.DecodeString(splitVerification[1]) + if err != nil { + return nil, 0, fmt.Errorf("unable to decode verification string as base64: %s", err) + } + + return verBytes, ver, nil +} + +func encodeTransitSignature(value []byte, keyVersion int) string { + retStr := base64.StdEncoding.EncodeToString(value) + retStr = fmt.Sprintf("vault:v%d:%s", keyVersion, retStr) + return retStr +} + const pathSignHelpSyn = `Generate a signature for input data using the named key` const pathSignHelpDesc = ` diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 2a4cadd2f0..467937b46c 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -170,6 +170,15 @@ func (kt KeyType) AssociatedDataSupported() bool { return false } +func (kt KeyType) CMACSupported() bool { + switch kt { + case KeyType_AES128_CMAC, KeyType_AES256_CMAC: + return true + default: + return false + } +} + func (kt KeyType) ImportPublicKeySupported() bool { switch kt { case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519: @@ -1077,6 +1086,25 @@ func (p *Policy) HMACKey(version int) ([]byte, error) { return keyEntry.HMACKey, nil } +func (p *Policy) CMACKey(version int) ([]byte, error) { + switch { + case version < 0: + return nil, fmt.Errorf("key version does not exist (cannot be negative)") + case version > p.LatestVersion: + return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion) + } + keyEntry, err := p.safeGetKeyEntry(version) + if err != nil { + return nil, err + } + + if p.Type.CMACSupported() { + return keyEntry.Key, nil + } + + return nil, fmt.Errorf("key type %s does not support CMAC operations", p.Type) +} + func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) { return p.SignWithOptions(ver, context, input, &SigningOptions{ HashAlgorithm: hashAlgorithm,