Add Transit CMAC stubs in CE (#26552)

This commit is contained in:
Steven Clark 2024-04-22 13:19:04 -04:00 committed by GitHub
parent 0a505f9651
commit ff500ca1c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 262 additions and 20 deletions

View File

@ -77,10 +77,12 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
b.pathImportCertChain(),
},
Secrets: []*framework.Secret{},
Invalidate: b.invalidate,
BackendType: logical.TypeLogical,
PeriodicFunc: b.periodicFunc,
Secrets: []*framework.Secret{},
Invalidate: b.invalidate,
BackendType: logical.TypeLogical,
PeriodicFunc: b.periodicFunc,
InitializeFunc: b.initialize,
Clean: b.cleanup,
}
b.backendUUID = conf.BackendUUID
@ -107,11 +109,15 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (*backend, error)
return nil, err
}
b.setupEnt()
return &b, nil
}
type backend struct {
*framework.Backend
entBackend
lm *keysutil.LockManager
// Lock to make changes to any of the backend's cache configuration.
configMutex sync.RWMutex
@ -185,6 +191,8 @@ func (b *backend) invalidate(ctx context.Context, key string) {
defer b.configMutex.Unlock()
b.cacheSizeChanged = true
}
b.invalidateEnt(ctx, key)
}
// periodicFunc is a central collection of functions that run on an interval.
@ -203,7 +211,11 @@ func (b *backend) periodicFunc(ctx context.Context, req *logical.Request) error
b.autoRotateOnce = sync.Once{}
}
return err
if err != nil {
return err
}
return b.periodicFuncEnt(ctx, req)
}
// autoRotateKeys retrieves all transit keys and rotates those which have an
@ -292,3 +304,11 @@ func (b *backend) rotateIfRequired(ctx context.Context, req *logical.Request, ke
}
return nil
}
func (b *backend) initialize(ctx context.Context, request *logical.InitializationRequest) error {
return b.initializeEnt(ctx, request)
}
func (b *backend) cleanup(ctx context.Context) {
b.cleanupEnt(ctx)
}

View File

@ -0,0 +1,26 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package transit
import (
"context"
"github.com/hashicorp/vault/sdk/logical"
)
type entBackend struct{}
func (b *backend) initializeEnt(_ context.Context, _ *logical.InitializationRequest) error {
return nil
}
func (b *backend) invalidateEnt(_ context.Context, _ string) {}
func (b *backend) periodicFuncEnt(_ context.Context, _ *logical.Request) error { return nil }
func (b *backend) cleanupEnt(_ context.Context) {}
func (b *backend) setupEnt() {}

View File

@ -0,0 +1,62 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package transit
import (
"context"
"errors"
"fmt"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
)
func (b *backend) getReadLockedPolicy(ctx context.Context, s logical.Storage, name string) (*keysutil.Policy, error) {
p, _, err := b.GetPolicy(ctx, keysutil.PolicyRequest{
Storage: s,
Name: name,
}, b.GetRandomReader())
if err != nil {
return nil, err
}
if p == nil {
return nil, fmt.Errorf("%w: key %s not found", logical.ErrInvalidRequest, name)
}
if !b.System().CachingDisabled() {
p.Lock(false)
}
return p, nil
}
// runWithReadLockedPolicy runs a function passing in the policy specified by keyName that has been
// locked in a read only fashion without the ability to upsert the policy
func (b *backend) runWithReadLockedPolicy(ctx context.Context, s logical.Storage, keyName string, f func(p *keysutil.Policy) (*logical.Response, error)) (*logical.Response, error) {
p, err := b.getReadLockedPolicy(ctx, s, keyName)
if err != nil {
if errors.Is(err, logical.ErrInvalidRequest) {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
return nil, err
}
defer p.Unlock()
return f(p)
}
// validateKeyVersion verifies that the passed in key version is valid for our
// current key policy, returning correct version to use within the policy.
func validateKeyVersion(p *keysutil.Policy, ver int) (int, error) {
switch {
case ver < 0:
return 0, fmt.Errorf("cannot use negative key version %d", ver)
case ver == 0:
// Allowed, will use latest; set explicitly here to ensure the string
// is generated properly
ver = p.LatestVersion
case ver == p.LatestVersion:
// Allowed
case p.MinEncryptionVersion > 0 && ver < p.MinEncryptionVersion:
return 0, fmt.Errorf("cannot use key version %d: version is too old (disallowed by policy) for key %s", ver, p.Name)
}
return ver, nil
}

View File

@ -0,0 +1,17 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package transit
import (
"context"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
func (b *backend) pathCMACVerify(_ context.Context, _ *logical.Request, _ *framework.FieldData) (*logical.Response, error) {
return logical.ErrorResponse("CMAC verification is only available in enterprise versions of Vault"), nil
}

View File

@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/errutil"
"github.com/hashicorp/vault/sdk/helper/keysutil"
@ -51,7 +52,7 @@ type batchResponseSignItem struct {
// BatchRequestVerifyItem represents a request item for batch processing.
// A map type allows us to distinguish between empty and missing values.
type batchRequestVerifyItem map[string]string
type batchRequestVerifyItem map[string]interface{}
// BatchResponseVerifyItem represents a response item for batch processing
type batchResponseVerifyItem struct {
@ -216,6 +217,11 @@ derivation is enabled; currently only available with ed25519 keys.`,
Description: "The HMAC, including vault header/key version",
},
"cmac": {
Type: framework.TypeString,
Description: "The CMAC, including vault header/key version",
},
"input": {
Type: framework.TypeString,
Description: "The base64-encoded input data to verify",
@ -226,6 +232,11 @@ derivation is enabled; currently only available with ed25519 keys.`,
Description: `Hash algorithm to use (POST URL parameter)`,
},
"mac_length": {
Type: framework.TypeInt,
Description: `MAC length to use (POST body parameter). Valid values are:`,
},
"hash_algorithm": {
Type: framework.TypeString,
Default: defaultHashAlgorithm,
@ -279,7 +290,7 @@ Options are 'auto' (the default used by Golang, causing the salt to be as large
"batch_input": {
Type: framework.TypeSlice,
Description: `Specifies a list of items for processing. When this parameter is set,
any supplied 'input', 'hmac' or 'signature' parameters will be ignored. Responses are returned in the
any supplied 'input', 'hmac', 'cmac' or 'signature' parameters will be ignored. Responses are returned in the
'batch_results' array component of the 'data' element of the response. Any batch output will
preserve the order of the batch input`,
},
@ -534,6 +545,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
if hmac, ok := d.GetOk("hmac"); ok {
batchInputItems[0]["hmac"] = hmac.(string)
}
if cmac, ok := d.GetOk("cmac"); ok {
batchInputItems[0]["cmac"] = cmac.(string)
}
batchInputItems[0]["context"] = d.Get("context").(string)
}
@ -542,26 +556,30 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
// If one batch_input item is 'hmac', they all must be 'hmac'.
sigFound := false
hmacFound := false
cmacFound := false
missing := false
for _, v := range batchInputItems {
if _, ok := v["signature"]; ok {
sigFound = true
} else if _, ok := v["hmac"]; ok {
hmacFound = true
} else if _, ok := v["cmac"]; ok {
cmacFound = true
} else {
missing = true
}
}
optionsSet := numBooleansTrue(sigFound, hmacFound, cmacFound)
switch {
case batchInputRaw == nil && sigFound && hmacFound:
return logical.ErrorResponse("provide one of 'signature' or 'hmac'"), logical.ErrInvalidRequest
case batchInputRaw == nil && optionsSet > 1:
return logical.ErrorResponse("provide one of 'signature', 'hmac' or 'cmac'"), logical.ErrInvalidRequest
case batchInputRaw == nil && !sigFound && !hmacFound:
return logical.ErrorResponse("neither a 'signature' nor an 'hmac' were given to verify"), logical.ErrInvalidRequest
case batchInputRaw == nil && optionsSet == 0:
return logical.ErrorResponse("missing 'signature', 'hmac' or 'cmac' were given to verify"), logical.ErrInvalidRequest
case sigFound && hmacFound:
return logical.ErrorResponse("elements of batch_input must all provide 'signature' or all provide 'hmac'"), logical.ErrInvalidRequest
case optionsSet > 1:
return logical.ErrorResponse("elements of batch_input must all provide either 'signature', 'hmac' or 'cmac'"), logical.ErrInvalidRequest
case missing && sigFound:
return logical.ErrorResponse("some elements of batch_input are missing 'signature'"), logical.ErrInvalidRequest
@ -569,11 +587,17 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
case missing && hmacFound:
return logical.ErrorResponse("some elements of batch_input are missing 'hmac'"), logical.ErrInvalidRequest
case missing:
return logical.ErrorResponse("no batch_input elements have 'signature' or 'hmac'"), logical.ErrInvalidRequest
case missing && cmacFound:
return logical.ErrorResponse("some elements of batch_input are missing 'cmac'"), logical.ErrInvalidRequest
case optionsSet == 0:
return logical.ErrorResponse("no batch_input elements have 'signature', 'hmac' or 'cmac'"), logical.ErrInvalidRequest
case hmacFound:
return b.pathHMACVerify(ctx, req, d)
case cmacFound:
return b.pathCMACVerify(ctx, req, d)
}
name := d.Get("name").(string)
@ -636,27 +660,38 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
response := make([]batchResponseVerifyItem, len(batchInputItems))
for i, item := range batchInputItems {
rawInput, ok := item["input"]
if !ok {
response[i].Error = "missing input"
response[i].err = logical.ErrInvalidRequest
continue
}
strInput, err := parseutil.ParseString(rawInput)
if err != nil {
response[i].Error = fmt.Sprintf("unable to parse input as string: %s", err)
response[i].err = logical.ErrInvalidRequest
continue
}
input, err := base64.StdEncoding.DecodeString(rawInput)
input, err := base64.StdEncoding.DecodeString(strInput)
if err != nil {
response[i].Error = fmt.Sprintf("unable to decode input as base64: %s", err)
response[i].err = logical.ErrInvalidRequest
continue
}
sig, ok := item["signature"]
sigRaw, ok := item["signature"].(string)
if !ok {
response[i].Error = "missing signature"
response[i].err = logical.ErrInvalidRequest
continue
}
sig, err := parseutil.ParseString(sigRaw)
if err != nil {
response[i].Error = fmt.Sprintf("failed to parse signature as a string: %s", err)
response[i].err = logical.ErrInvalidRequest
continue
}
if p.Type.HashSignatureInput() && !prehashed {
hf := keysutil.HashFuncMap[hashAlgorithm]()
@ -666,7 +701,12 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
}
}
contextRaw := item["context"]
contextRaw, err := parseutil.ParseString(item["context"])
if err != nil {
response[i].Error = fmt.Sprintf("failed to parse context as a string: %s", err)
response[i].err = logical.ErrInvalidRequest
continue
}
var context []byte
if len(contextRaw) != 0 {
context, err = base64.StdEncoding.DecodeString(contextRaw)
@ -720,7 +760,9 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
if batchInputRaw != nil {
// Copy the references
for i := range batchInputItems {
response[i].Reference = batchInputItems[i]["reference"]
if ref, err := parseutil.ParseString(batchInputItems[i]["reference"]); err == nil {
response[i].Reference = ref
}
}
resp.Data = map[string]interface{}{
"batch_results": response,
@ -740,6 +782,53 @@ func (b *backend) pathVerifyWrite(ctx context.Context, req *logical.Request, d *
return resp, nil
}
func numBooleansTrue(bools ...bool) int {
numSet := 0
for _, value := range bools {
if value {
numSet++
}
}
return numSet
}
func decodeTransitSignature(sig string) ([]byte, int, error) {
if !strings.HasPrefix(sig, "vault:v") {
return nil, 0, fmt.Errorf("prefix is not vault:v")
}
splitVerification := strings.SplitN(strings.TrimPrefix(sig, "vault:v"), ":", 2)
if len(splitVerification) != 2 {
return nil, 0, fmt.Errorf("wrong number of fields delimited by ':', got %d expected 2", len(splitVerification))
}
ver, err := strconv.Atoi(splitVerification[0])
if err != nil {
return nil, 0, fmt.Errorf("key version number %s count not be decoded", splitVerification[0])
}
if ver < 1 {
return nil, 0, fmt.Errorf("key version less than 1 are invalid got: %d", ver)
}
if len(strings.TrimSpace(splitVerification[1])) == 0 {
return nil, 0, fmt.Errorf("missing base64 verification string from vault signature")
}
verBytes, err := base64.StdEncoding.DecodeString(splitVerification[1])
if err != nil {
return nil, 0, fmt.Errorf("unable to decode verification string as base64: %s", err)
}
return verBytes, ver, nil
}
func encodeTransitSignature(value []byte, keyVersion int) string {
retStr := base64.StdEncoding.EncodeToString(value)
retStr = fmt.Sprintf("vault:v%d:%s", keyVersion, retStr)
return retStr
}
const pathSignHelpSyn = `Generate a signature for input data using the named key`
const pathSignHelpDesc = `

View File

@ -170,6 +170,15 @@ func (kt KeyType) AssociatedDataSupported() bool {
return false
}
func (kt KeyType) CMACSupported() bool {
switch kt {
case KeyType_AES128_CMAC, KeyType_AES256_CMAC:
return true
default:
return false
}
}
func (kt KeyType) ImportPublicKeySupported() bool {
switch kt {
case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096, KeyType_ECDSA_P256, KeyType_ECDSA_P384, KeyType_ECDSA_P521, KeyType_ED25519:
@ -1077,6 +1086,25 @@ func (p *Policy) HMACKey(version int) ([]byte, error) {
return keyEntry.HMACKey, nil
}
func (p *Policy) CMACKey(version int) ([]byte, error) {
switch {
case version < 0:
return nil, fmt.Errorf("key version does not exist (cannot be negative)")
case version > p.LatestVersion:
return nil, fmt.Errorf("key version does not exist; latest key version is %d", p.LatestVersion)
}
keyEntry, err := p.safeGetKeyEntry(version)
if err != nil {
return nil, err
}
if p.Type.CMACSupported() {
return keyEntry.Key, nil
}
return nil, fmt.Errorf("key type %s does not support CMAC operations", p.Type)
}
func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, sigAlgorithm string, marshaling MarshalingType) (*SigningResult, error) {
return p.SignWithOptions(ver, context, input, &SigningOptions{
HashAlgorithm: hashAlgorithm,