From bf8d7efb072018f36bf3072f0c1ab0f5b21277e5 Mon Sep 17 00:00:00 2001 From: Scott Miller Date: Thu, 1 Oct 2020 21:04:36 -0500 Subject: [PATCH] Expose generic versions of KDF and symmetric crypto (#10076) * Support salt in DeriveKey * Revert "Support salt in DeriveKey" This reverts commit b295ae42673308a2d66d66b53527c6f9aba92ac9. * Refactor out key derivation, symmetric encryption, and symmetric decryption into generic functions * comments * comments * go mod vendor * bump both go.mods * This one too * bump * bump * bump * Make the lesser used params of symmetric ops a struct * go fmt * Call GetKey instead of DeriveKey * Address feedback * Wrong rv * Rename calls * Assign the nonce field * trivial change * Check nonce len instead * go mod vendor --- api/go.mod | 2 +- api/go.sum | 1 + builtin/logical/transit/backend_test.go | 10 +- builtin/logical/transit/path_keys.go | 2 +- go.mod | 4 +- sdk/go.mod | 3 +- sdk/go.sum | 8 + sdk/helper/keysutil/policy.go | 272 +++++++++++------- sdk/helper/keysutil/policy_test.go | 28 ++ .../vault/sdk/helper/keysutil/policy.go | 272 +++++++++++------- vendor/modules.txt | 4 +- 11 files changed, 374 insertions(+), 232 deletions(-) diff --git a/api/go.mod b/api/go.mod index c6cee73f22..abc93cd9eb 100644 --- a/api/go.mod +++ b/api/go.mod @@ -12,7 +12,7 @@ require ( github.com/hashicorp/go-retryablehttp v0.6.6 github.com/hashicorp/go-rootcerts v1.0.2 github.com/hashicorp/hcl v1.0.0 - github.com/hashicorp/vault/sdk v0.1.14-0.20200519221838-e0cfd64bc267 + github.com/hashicorp/vault/sdk v0.0.0-20201001212527-2e121bafe1e4 github.com/mitchellh/mapstructure v1.3.2 golang.org/x/net v0.0.0-20200602114024-627f9648deb9 golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 diff --git a/api/go.sum b/api/go.sum index ae27f2521e..a1ff1c60b0 100644 --- a/api/go.sum +++ b/api/go.sum @@ -131,6 +131,7 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/vault/api v0.0.0-20201001211907-38d91b749c77/go.mod h1:R3Umvhlxi2TN7Ex2hzOowyeNb+SfbVWI973N+ctaFMk= github.com/hashicorp/vault/api v1.0.5-0.20200519221902-385fac77e20f/go.mod h1:euTFbi2YJgwcju3imEt919lhJKF68nN1cQPq3aA+kBE= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 562b6d3b90..7b258a1627 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -24,7 +24,7 @@ import ( ) const ( - testPlaintext = "the quick brown fox" + testPlaintext = "The quick brown fox" ) func createBackendWithStorage(t testing.TB) (*backend, logical.Storage) { @@ -930,12 +930,12 @@ func testDerivedKeyUpgrade(t *testing.T, keyType keysutil.KeyType) { t.Fatalf("bad KDF value by default; counter val is %d, KDF val is %d, policy is %#v", keysutil.Kdf_hmac_sha256_counter, p.KDF, *p) } - derBytesOld, err := p.DeriveKey(keyContext, 1, 0) + derBytesOld, err := p.GetKey(keyContext, 1, 0) if err != nil { t.Fatal(err) } - derBytesOld2, err := p.DeriveKey(keyContext, 1, 0) + derBytesOld2, err := p.GetKey(keyContext, 1, 0) if err != nil { t.Fatal(err) } @@ -949,12 +949,12 @@ func testDerivedKeyUpgrade(t *testing.T, keyType keysutil.KeyType) { t.Fatal("expected no upgrade needed") } - derBytesNew, err := p.DeriveKey(keyContext, 1, 64) + derBytesNew, err := p.GetKey(keyContext, 1, 64) if err != nil { t.Fatal(err) } - derBytesNew2, err := p.DeriveKey(keyContext, 1, 64) + derBytesNew2, err := p.GetKey(keyContext, 1, 64) if err != nil { t.Fatal(err) } diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 2725020354..18300c7dc2 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -298,7 +298,7 @@ func (b *backend) pathPolicyRead(ctx context.Context, req *logical.Request, d *f if err != nil { return nil, errwrap.Wrapf(fmt.Sprintf("invalid version %q: {{err}}", k), err) } - derived, err := p.DeriveKey(context, ver, 32) + derived, err := p.GetKey(context, ver, 32) if err != nil { return nil, fmt.Errorf("failed to derive key to return public component") } diff --git a/go.mod b/go.mod index ce46fa4d5c..6f52bbf4a6 100644 --- a/go.mod +++ b/go.mod @@ -94,8 +94,8 @@ require ( github.com/hashicorp/vault-plugin-secrets-kv v0.5.6 github.com/hashicorp/vault-plugin-secrets-mongodbatlas v0.1.2 github.com/hashicorp/vault-plugin-secrets-openldap v0.1.5 - github.com/hashicorp/vault/api v1.0.5-0.20200805123347-1ef507638af6 - github.com/hashicorp/vault/sdk v0.1.14-0.20200916184745-5576096032f8 + github.com/hashicorp/vault/api v1.0.5-0.20201001211907-38d91b749c77 + github.com/hashicorp/vault/sdk v0.1.14-0.20201001211907-38d91b749c77 github.com/influxdata/influxdb v0.0.0-20190411212539-d24b7ba8c4c4 github.com/jcmturner/gokrb5/v8 v8.0.0 github.com/jefferai/isbadcipher v0.0.0-20190226160619-51d2077c035f diff --git a/sdk/go.mod b/sdk/go.mod index 54b79cfb7e..a636913229 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -33,7 +33,7 @@ require ( github.com/hashicorp/go-version v1.2.0 github.com/hashicorp/golang-lru v0.5.3 github.com/hashicorp/hcl v1.0.0 - github.com/hashicorp/vault/api v1.0.5-0.20200519221902-385fac77e20f + github.com/hashicorp/vault/api v0.0.0-20201001212527-2e121bafe1e4 github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.6 // indirect github.com/mitchellh/copystructure v1.0.0 @@ -46,6 +46,7 @@ require ( github.com/pierrec/lz4 v2.5.2+incompatible github.com/pkg/errors v0.9.1 github.com/ryanuber/go-glob v1.0.0 + github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 golang.org/x/net v0.0.0-20200602114024-627f9648deb9 golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 diff --git a/sdk/go.sum b/sdk/go.sum index 6acd468237..556433fbea 100644 --- a/sdk/go.sum +++ b/sdk/go.sum @@ -157,6 +157,8 @@ github.com/hashicorp/go-retryablehttp v0.6.6 h1:HJunrbHTDDbBb/ay4kxa1n+dLmttUlnP github.com/hashicorp/go-retryablehttp v0.6.6/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= github.com/hashicorp/go-rootcerts v1.0.1 h1:DMo4fmknnz0E0evoNYnV48RjWndOsmd6OW+09R3cEP8= github.com/hashicorp/go-rootcerts v1.0.1/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= github.com/hashicorp/go-sockaddr v1.0.2 h1:ztczhD1jLxIRjVejw8gFomI1BQZOe2WoVOu0SyteCQc= github.com/hashicorp/go-sockaddr v1.0.2/go.mod h1:rB4wwRAUzs07qva3c5SdrY/NEtAUjGlgmH/UkBUC97A= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= @@ -171,9 +173,14 @@ github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8 github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/vault/api v0.0.0-20201001211907-38d91b749c77 h1:ksVfuH3dKQPzqBH+LDVYMXeYC8rzuz526qBIQrIps7U= +github.com/hashicorp/vault/api v0.0.0-20201001211907-38d91b749c77/go.mod h1:R3Umvhlxi2TN7Ex2hzOowyeNb+SfbVWI973N+ctaFMk= +github.com/hashicorp/vault/api v0.0.0-20201001212527-2e121bafe1e4 h1:jNcfITMv2iB4wm+VGeB3uWGf8Gz3YXFHQJOAQZPjpVU= +github.com/hashicorp/vault/api v0.0.0-20201001212527-2e121bafe1e4/go.mod h1:R3Umvhlxi2TN7Ex2hzOowyeNb+SfbVWI973N+ctaFMk= github.com/hashicorp/vault/api v1.0.5-0.20200519221902-385fac77e20f h1:PYtnlUZzFSZxPcq7mYp5oC9N+BcJ8IKYf6/EG0GHM2Y= github.com/hashicorp/vault/api v1.0.5-0.20200519221902-385fac77e20f/go.mod h1:euTFbi2YJgwcju3imEt919lhJKF68nN1cQPq3aA+kBE= github.com/hashicorp/vault/sdk v0.1.14-0.20200519221530-14615acda45f/go.mod h1:WX57W2PwkrOPQ6rVQk+dy5/htHIaB4aBM70EwKThu10= +github.com/hashicorp/vault/sdk v0.1.14-0.20200519221838-e0cfd64bc267/go.mod h1:WX57W2PwkrOPQ6rVQk+dy5/htHIaB4aBM70EwKThu10= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= @@ -297,6 +304,7 @@ github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKv github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/sdk/helper/keysutil/policy.go b/sdk/helper/keysutil/policy.go index 949c626505..5e3e8d7c0d 100644 --- a/sdk/helper/keysutil/policy.go +++ b/sdk/helper/keysutil/policy.go @@ -692,16 +692,23 @@ func (p *Policy) Upgrade(ctx context.Context, storage logical.Storage, randReade return nil } -// DeriveKey is used to derive the encryption key that should be used depending +// GetKey is used to derive the encryption key that should be used depending // on the policy. If derivation is disabled the raw key is used and no context // is required, otherwise the KDF mode is used with the context to derive the // proper key. -func (p *Policy) DeriveKey(context []byte, ver, numBytes int) ([]byte, error) { +func (p *Policy) GetKey(context []byte, ver, numBytes int) ([]byte, error) { // Fast-path non-derived keys if !p.Derived { return p.Keys[strconv.Itoa(ver)].Key, nil } + return p.DeriveKey(context, nil, ver, numBytes) +} + +// DeriveKey is used to derive a symmetric key given a context and salt. This does not +// check the policies Derived flag, but just implements the derivation logic. GetKey +// is responsible for switching on the policy config. +func (p *Policy) DeriveKey(context, salt []byte, ver int, numBytes int) ([]byte, error) { if !p.Type.DerivationSupported() { return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)} } @@ -723,10 +730,10 @@ func (p *Policy) DeriveKey(context []byte, ver, numBytes int) ([]byte, error) { case Kdf_hmac_sha256_counter: prf := kdf.HMACSHA256PRF prfLen := kdf.HMACSHA256PRFLen - return kdf.CounterMode(prf, prfLen, p.Keys[strconv.Itoa(ver)].Key, context, 256) + return kdf.CounterMode(prf, prfLen, p.Keys[strconv.Itoa(ver)].Key, append(context, salt...), 256) case Kdf_hkdf_sha256: - reader := hkdf.New(sha256.New, p.Keys[strconv.Itoa(ver)].Key, nil, context) + reader := hkdf.New(sha256.New, p.Keys[strconv.Itoa(ver)].Key, salt, context) derBytes := bytes.NewBuffer(nil) derBytes.Grow(numBytes) limReader := &io.LimitedReader{ @@ -809,7 +816,6 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305: hmacKey := context - var aead cipher.AEAD var encKey []byte var deriveHMAC bool @@ -823,7 +829,7 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, encBytes = 16 } - key, err := p.DeriveKey(context, ver, encBytes+hmacBytes) + key, err := p.GetKey(context, ver, encBytes+hmacBytes) if err != nil { return "", err } @@ -843,65 +849,16 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, } } - switch p.Type { - case KeyType_AES128_GCM96, KeyType_AES256_GCM96: - // Setup the cipher - aesCipher, err := aes.NewCipher(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, + SymmetricOpts{ + Convergent: p.ConvergentEncryption, + HMACKey: hmacKey, + Nonce: nonce, + }) - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = gcm - - case KeyType_ChaCha20_Poly1305: - cha, err := chacha20poly1305.New(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = cha + if err != nil { + return "", err } - - if p.ConvergentEncryption { - convergentVersion := p.convergentVersion(ver) - switch convergentVersion { - case 1: - if len(nonce) != aead.NonceSize() { - return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())} - } - case 2, 3: - if len(hmacKey) == 0 { - return "", errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")} - } - nonceHmac := hmac.New(sha256.New, hmacKey) - nonceHmac.Write(plaintext) - nonceSum := nonceHmac.Sum(nil) - nonce = nonceSum[:aead.NonceSize()] - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)} - } - } else { - // Compute random nonce - nonce, err = uuid.GenerateRandomBytes(aead.NonceSize()) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - } - - // Encrypt and tag with AEAD - ciphertext = aead.Seal(nil, nonce, plaintext, nil) - - // Place the encrypted data after the nonce - if !p.ConvergentEncryption || p.convergentVersion(ver) > 1 { - ciphertext = append(nonce, ciphertext...) - } - case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: key := p.Keys[strconv.Itoa(ver)].RSAKey ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &key.PublicKey, plaintext, nil) @@ -976,14 +933,12 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { switch p.Type { case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305: - var aead cipher.AEAD - numBytes := 32 if p.Type == KeyType_AES128_GCM96 { numBytes = 16 } - encKey, err := p.DeriveKey(context, ver, numBytes) + encKey, err := p.GetKey(context, ver, numBytes) if err != nil { return "", err } @@ -992,50 +947,14 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.InternalError{Err: "could not derive enc key, length not correct"} } - switch p.Type { - case KeyType_AES128_GCM96, KeyType_AES256_GCM96: - // Setup the cipher - aesCipher, err := aes.NewCipher(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = gcm - - case KeyType_ChaCha20_Poly1305: - cha, err := chacha20poly1305.New(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = cha - } - - if len(decoded) < aead.NonceSize() { - return "", errutil.UserError{Err: "invalid ciphertext length"} - } - - // Extract the nonce and ciphertext - var ciphertext []byte - if p.ConvergentEncryption && convergentVersion == 1 { - ciphertext = decoded - } else { - nonce = decoded[:aead.NonceSize()] - ciphertext = decoded[aead.NonceSize():] - } - - // Verify and Decrypt - plain, err = aead.Open(nil, nonce, ciphertext, nil) + plain, err = p.SymmetricDecryptRaw(encKey, decoded, + SymmetricOpts{ + Convergent: p.ConvergentEncryption, + ConvergentVersion: p.ConvergentVersion, + }) if err != nil { - return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + return "", err } - case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: key := p.Keys[strconv.Itoa(ver)].RSAKey plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil) @@ -1156,7 +1075,7 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si if p.Derived { // Derive the key that should be used var err error - key, err = p.DeriveKey(context, ver, 32) + key, err = p.GetKey(context, ver, 32) if err != nil { return nil, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)} } @@ -1325,7 +1244,7 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, if p.Derived { // Derive the key that should be used var err error - key, err = p.DeriveKey(context, ver, 32) + key, err = p.GetKey(context, ver, 32) if err != nil { return false, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)} } @@ -1596,3 +1515,136 @@ func (p *Policy) getVersionPrefix(ver int) string { return prefix } + +// SymmetricOpts are the arguments to symmetric operations that are "optional", e.g. +// not always used. This improves the aesthetics of calls to those functions. +type SymmetricOpts struct { + // Whether to use convergent encryption + Convergent bool + // The version of the convergent encryption scheme + ConvergentVersion int + // The nonce, if not randomly generated + Nonce []byte + // Additional data to include in AEAD authentication + AdditionalData []byte + // The HMAC key, for generating IVs in convergent encryption + HMACKey []byte +} + +// Symmetrically encrypt a plaintext given the convergence configuration and appropriate keys +func (p *Policy) SymmetricEncryptRaw(ver int, encKey, plaintext []byte, opts SymmetricOpts) ([]byte, error) { + var aead cipher.AEAD + var err error + nonce := opts.Nonce + + switch p.Type { + case KeyType_AES128_GCM96, KeyType_AES256_GCM96: + // Setup the cipher + aesCipher, err := aes.NewCipher(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = gcm + + case KeyType_ChaCha20_Poly1305: + cha, err := chacha20poly1305.New(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = cha + } + + if opts.Convergent { + convergentVersion := p.convergentVersion(ver) + switch convergentVersion { + case 1: + if len(opts.Nonce) != aead.NonceSize() { + return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())} + } + case 2, 3: + if len(opts.HMACKey) == 0 { + return nil, errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")} + } + nonceHmac := hmac.New(sha256.New, opts.HMACKey) + nonceHmac.Write(plaintext) + nonceSum := nonceHmac.Sum(nil) + nonce = nonceSum[:aead.NonceSize()] + default: + return nil, errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)} + } + } else if len(nonce) == 0 { + // Compute random nonce + nonce, err = uuid.GenerateRandomBytes(aead.NonceSize()) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + } + + // Encrypt and tag with AEAD + ciphertext := aead.Seal(nil, nonce, plaintext, opts.AdditionalData) + + // Place the encrypted data after the nonce + if !opts.Convergent || p.convergentVersion(ver) > 1 { + ciphertext = append(nonce, ciphertext...) + } + return ciphertext, nil +} + +// Symmetrically decrypt a ciphertext given the convergence configuration and appropriate keys +func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOpts) ([]byte, error) { + var aead cipher.AEAD + var nonce []byte + + switch p.Type { + case KeyType_AES128_GCM96, KeyType_AES256_GCM96: + // Setup the cipher + aesCipher, err := aes.NewCipher(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = gcm + + case KeyType_ChaCha20_Poly1305: + cha, err := chacha20poly1305.New(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = cha + } + + if len(ciphertext) < aead.NonceSize() { + return nil, errutil.UserError{Err: "invalid ciphertext length"} + } + + // Extract the nonce and ciphertext + var trueCT []byte + if opts.Convergent && opts.ConvergentVersion == 1 { + trueCT = ciphertext + } else { + nonce = ciphertext[:aead.NonceSize()] + trueCT = ciphertext[aead.NonceSize():] + } + + // Verify and Decrypt + plain, err := aead.Open(nil, nonce, trueCT, opts.AdditionalData) + if err != nil { + return nil, err + } + return plain, nil +} diff --git a/sdk/helper/keysutil/policy_test.go b/sdk/helper/keysutil/policy_test.go index 341ecda701..40af77d8e6 100644 --- a/sdk/helper/keysutil/policy_test.go +++ b/sdk/helper/keysutil/policy_test.go @@ -1,6 +1,7 @@ package keysutil import ( + "bytes" "context" "crypto/rand" "reflect" @@ -613,3 +614,30 @@ func Test_BadArchive(t *testing.T) { t.Fatalf("unexpected key length %d", len(p.Keys)) } } + +func BenchmarkSymmetric(b *testing.B) { + ctx := context.Background() + lm, _ := NewLockManager(true, 0) + storage := &logical.InmemStorage{} + p, _, _ := lm.GetPolicy(ctx, PolicyRequest{ + Upsert: true, + Storage: storage, + KeyType: KeyType_AES256_GCM96, + Name: "test", + }, rand.Reader) + key, _ := p.GetKey(nil, 1, 32) + pt := make([]byte, 10) + ad := make([]byte, 10) + for i := 0; i < b.N; i++ { + ct, _ := p.SymmetricEncryptRaw(1, key, pt, + SymmetricOpts{ + AdditionalData: ad, + }) + pt2, _ := p.SymmetricDecryptRaw(key, ct, SymmetricOpts{ + AdditionalData: ad, + }) + if !bytes.Equal(pt, pt2) { + b.Fail() + } + } +} diff --git a/vendor/github.com/hashicorp/vault/sdk/helper/keysutil/policy.go b/vendor/github.com/hashicorp/vault/sdk/helper/keysutil/policy.go index 949c626505..5e3e8d7c0d 100644 --- a/vendor/github.com/hashicorp/vault/sdk/helper/keysutil/policy.go +++ b/vendor/github.com/hashicorp/vault/sdk/helper/keysutil/policy.go @@ -692,16 +692,23 @@ func (p *Policy) Upgrade(ctx context.Context, storage logical.Storage, randReade return nil } -// DeriveKey is used to derive the encryption key that should be used depending +// GetKey is used to derive the encryption key that should be used depending // on the policy. If derivation is disabled the raw key is used and no context // is required, otherwise the KDF mode is used with the context to derive the // proper key. -func (p *Policy) DeriveKey(context []byte, ver, numBytes int) ([]byte, error) { +func (p *Policy) GetKey(context []byte, ver, numBytes int) ([]byte, error) { // Fast-path non-derived keys if !p.Derived { return p.Keys[strconv.Itoa(ver)].Key, nil } + return p.DeriveKey(context, nil, ver, numBytes) +} + +// DeriveKey is used to derive a symmetric key given a context and salt. This does not +// check the policies Derived flag, but just implements the derivation logic. GetKey +// is responsible for switching on the policy config. +func (p *Policy) DeriveKey(context, salt []byte, ver int, numBytes int) ([]byte, error) { if !p.Type.DerivationSupported() { return nil, errutil.UserError{Err: fmt.Sprintf("derivation not supported for key type %v", p.Type)} } @@ -723,10 +730,10 @@ func (p *Policy) DeriveKey(context []byte, ver, numBytes int) ([]byte, error) { case Kdf_hmac_sha256_counter: prf := kdf.HMACSHA256PRF prfLen := kdf.HMACSHA256PRFLen - return kdf.CounterMode(prf, prfLen, p.Keys[strconv.Itoa(ver)].Key, context, 256) + return kdf.CounterMode(prf, prfLen, p.Keys[strconv.Itoa(ver)].Key, append(context, salt...), 256) case Kdf_hkdf_sha256: - reader := hkdf.New(sha256.New, p.Keys[strconv.Itoa(ver)].Key, nil, context) + reader := hkdf.New(sha256.New, p.Keys[strconv.Itoa(ver)].Key, salt, context) derBytes := bytes.NewBuffer(nil) derBytes.Grow(numBytes) limReader := &io.LimitedReader{ @@ -809,7 +816,6 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305: hmacKey := context - var aead cipher.AEAD var encKey []byte var deriveHMAC bool @@ -823,7 +829,7 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, encBytes = 16 } - key, err := p.DeriveKey(context, ver, encBytes+hmacBytes) + key, err := p.GetKey(context, ver, encBytes+hmacBytes) if err != nil { return "", err } @@ -843,65 +849,16 @@ func (p *Policy) Encrypt(ver int, context, nonce []byte, value string) (string, } } - switch p.Type { - case KeyType_AES128_GCM96, KeyType_AES256_GCM96: - // Setup the cipher - aesCipher, err := aes.NewCipher(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } + ciphertext, err = p.SymmetricEncryptRaw(ver, encKey, plaintext, + SymmetricOpts{ + Convergent: p.ConvergentEncryption, + HMACKey: hmacKey, + Nonce: nonce, + }) - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = gcm - - case KeyType_ChaCha20_Poly1305: - cha, err := chacha20poly1305.New(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = cha + if err != nil { + return "", err } - - if p.ConvergentEncryption { - convergentVersion := p.convergentVersion(ver) - switch convergentVersion { - case 1: - if len(nonce) != aead.NonceSize() { - return "", errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())} - } - case 2, 3: - if len(hmacKey) == 0 { - return "", errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")} - } - nonceHmac := hmac.New(sha256.New, hmacKey) - nonceHmac.Write(plaintext) - nonceSum := nonceHmac.Sum(nil) - nonce = nonceSum[:aead.NonceSize()] - default: - return "", errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)} - } - } else { - // Compute random nonce - nonce, err = uuid.GenerateRandomBytes(aead.NonceSize()) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - } - - // Encrypt and tag with AEAD - ciphertext = aead.Seal(nil, nonce, plaintext, nil) - - // Place the encrypted data after the nonce - if !p.ConvergentEncryption || p.convergentVersion(ver) > 1 { - ciphertext = append(nonce, ciphertext...) - } - case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: key := p.Keys[strconv.Itoa(ver)].RSAKey ciphertext, err = rsa.EncryptOAEP(sha256.New(), rand.Reader, &key.PublicKey, plaintext, nil) @@ -976,14 +933,12 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { switch p.Type { case KeyType_AES128_GCM96, KeyType_AES256_GCM96, KeyType_ChaCha20_Poly1305: - var aead cipher.AEAD - numBytes := 32 if p.Type == KeyType_AES128_GCM96 { numBytes = 16 } - encKey, err := p.DeriveKey(context, ver, numBytes) + encKey, err := p.GetKey(context, ver, numBytes) if err != nil { return "", err } @@ -992,50 +947,14 @@ func (p *Policy) Decrypt(context, nonce []byte, value string) (string, error) { return "", errutil.InternalError{Err: "could not derive enc key, length not correct"} } - switch p.Type { - case KeyType_AES128_GCM96, KeyType_AES256_GCM96: - // Setup the cipher - aesCipher, err := aes.NewCipher(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = gcm - - case KeyType_ChaCha20_Poly1305: - cha, err := chacha20poly1305.New(encKey) - if err != nil { - return "", errutil.InternalError{Err: err.Error()} - } - - aead = cha - } - - if len(decoded) < aead.NonceSize() { - return "", errutil.UserError{Err: "invalid ciphertext length"} - } - - // Extract the nonce and ciphertext - var ciphertext []byte - if p.ConvergentEncryption && convergentVersion == 1 { - ciphertext = decoded - } else { - nonce = decoded[:aead.NonceSize()] - ciphertext = decoded[aead.NonceSize():] - } - - // Verify and Decrypt - plain, err = aead.Open(nil, nonce, ciphertext, nil) + plain, err = p.SymmetricDecryptRaw(encKey, decoded, + SymmetricOpts{ + Convergent: p.ConvergentEncryption, + ConvergentVersion: p.ConvergentVersion, + }) if err != nil { - return "", errutil.UserError{Err: "invalid ciphertext: unable to decrypt"} + return "", err } - case KeyType_RSA2048, KeyType_RSA3072, KeyType_RSA4096: key := p.Keys[strconv.Itoa(ver)].RSAKey plain, err = rsa.DecryptOAEP(sha256.New(), rand.Reader, key, decoded, nil) @@ -1156,7 +1075,7 @@ func (p *Policy) Sign(ver int, context, input []byte, hashAlgorithm HashType, si if p.Derived { // Derive the key that should be used var err error - key, err = p.DeriveKey(context, ver, 32) + key, err = p.GetKey(context, ver, 32) if err != nil { return nil, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)} } @@ -1325,7 +1244,7 @@ func (p *Policy) VerifySignature(context, input []byte, hashAlgorithm HashType, if p.Derived { // Derive the key that should be used var err error - key, err = p.DeriveKey(context, ver, 32) + key, err = p.GetKey(context, ver, 32) if err != nil { return false, errutil.InternalError{Err: fmt.Sprintf("error deriving key: %v", err)} } @@ -1596,3 +1515,136 @@ func (p *Policy) getVersionPrefix(ver int) string { return prefix } + +// SymmetricOpts are the arguments to symmetric operations that are "optional", e.g. +// not always used. This improves the aesthetics of calls to those functions. +type SymmetricOpts struct { + // Whether to use convergent encryption + Convergent bool + // The version of the convergent encryption scheme + ConvergentVersion int + // The nonce, if not randomly generated + Nonce []byte + // Additional data to include in AEAD authentication + AdditionalData []byte + // The HMAC key, for generating IVs in convergent encryption + HMACKey []byte +} + +// Symmetrically encrypt a plaintext given the convergence configuration and appropriate keys +func (p *Policy) SymmetricEncryptRaw(ver int, encKey, plaintext []byte, opts SymmetricOpts) ([]byte, error) { + var aead cipher.AEAD + var err error + nonce := opts.Nonce + + switch p.Type { + case KeyType_AES128_GCM96, KeyType_AES256_GCM96: + // Setup the cipher + aesCipher, err := aes.NewCipher(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = gcm + + case KeyType_ChaCha20_Poly1305: + cha, err := chacha20poly1305.New(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = cha + } + + if opts.Convergent { + convergentVersion := p.convergentVersion(ver) + switch convergentVersion { + case 1: + if len(opts.Nonce) != aead.NonceSize() { + return nil, errutil.UserError{Err: fmt.Sprintf("base64-decoded nonce must be %d bytes long when using convergent encryption with this key", aead.NonceSize())} + } + case 2, 3: + if len(opts.HMACKey) == 0 { + return nil, errutil.InternalError{Err: fmt.Sprintf("invalid hmac key length of zero")} + } + nonceHmac := hmac.New(sha256.New, opts.HMACKey) + nonceHmac.Write(plaintext) + nonceSum := nonceHmac.Sum(nil) + nonce = nonceSum[:aead.NonceSize()] + default: + return nil, errutil.InternalError{Err: fmt.Sprintf("unhandled convergent version %d", convergentVersion)} + } + } else if len(nonce) == 0 { + // Compute random nonce + nonce, err = uuid.GenerateRandomBytes(aead.NonceSize()) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + } + + // Encrypt and tag with AEAD + ciphertext := aead.Seal(nil, nonce, plaintext, opts.AdditionalData) + + // Place the encrypted data after the nonce + if !opts.Convergent || p.convergentVersion(ver) > 1 { + ciphertext = append(nonce, ciphertext...) + } + return ciphertext, nil +} + +// Symmetrically decrypt a ciphertext given the convergence configuration and appropriate keys +func (p *Policy) SymmetricDecryptRaw(encKey, ciphertext []byte, opts SymmetricOpts) ([]byte, error) { + var aead cipher.AEAD + var nonce []byte + + switch p.Type { + case KeyType_AES128_GCM96, KeyType_AES256_GCM96: + // Setup the cipher + aesCipher, err := aes.NewCipher(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = gcm + + case KeyType_ChaCha20_Poly1305: + cha, err := chacha20poly1305.New(encKey) + if err != nil { + return nil, errutil.InternalError{Err: err.Error()} + } + + aead = cha + } + + if len(ciphertext) < aead.NonceSize() { + return nil, errutil.UserError{Err: "invalid ciphertext length"} + } + + // Extract the nonce and ciphertext + var trueCT []byte + if opts.Convergent && opts.ConvergentVersion == 1 { + trueCT = ciphertext + } else { + nonce = ciphertext[:aead.NonceSize()] + trueCT = ciphertext[aead.NonceSize():] + } + + // Verify and Decrypt + plain, err := aead.Open(nil, nonce, trueCT, opts.AdditionalData) + if err != nil { + return nil, err + } + return plain, nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt index c4354c4eae..5b6f4fc931 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -508,9 +508,9 @@ github.com/hashicorp/vault-plugin-secrets-mongodbatlas # github.com/hashicorp/vault-plugin-secrets-openldap v0.1.5 github.com/hashicorp/vault-plugin-secrets-openldap github.com/hashicorp/vault-plugin-secrets-openldap/client -# github.com/hashicorp/vault/api v1.0.5-0.20200805123347-1ef507638af6 => ./api +# github.com/hashicorp/vault/api v1.0.5-0.20201001211907-38d91b749c77 => ./api github.com/hashicorp/vault/api -# github.com/hashicorp/vault/sdk v0.1.14-0.20200916184745-5576096032f8 => ./sdk +# github.com/hashicorp/vault/sdk v0.1.14-0.20201001211907-38d91b749c77 => ./sdk github.com/hashicorp/vault/sdk/database/dbplugin github.com/hashicorp/vault/sdk/database/helper/connutil github.com/hashicorp/vault/sdk/database/helper/credsutil