From e84e7b2d76e288864e8952abf638371eafbb4483 Mon Sep 17 00:00:00 2001 From: Jason O'Donnell <2160810+jasonodonnell@users.noreply.github.com> Date: Tue, 23 Feb 2021 19:48:39 -0500 Subject: [PATCH] agent: add caching encryption package (#10986) * agent: add caching encryption package * Fix documentation * Add GetKey, GetPersistentKey * Remove chan from interface * Add error to interface * Fix tests --- command/agent/cache/crypto/crypto.go | 18 +++ command/agent/cache/crypto/k8s.go | 97 ++++++++++++++++ command/agent/cache/crypto/k8s_test.go | 155 +++++++++++++++++++++++++ 3 files changed, 270 insertions(+) create mode 100644 command/agent/cache/crypto/crypto.go create mode 100644 command/agent/cache/crypto/k8s.go create mode 100644 command/agent/cache/crypto/k8s_test.go diff --git a/command/agent/cache/crypto/crypto.go b/command/agent/cache/crypto/crypto.go new file mode 100644 index 0000000000..85e20c0373 --- /dev/null +++ b/command/agent/cache/crypto/crypto.go @@ -0,0 +1,18 @@ +package crypto + +import ( + "context" +) + +const ( + KeyID = "root" +) + +type KeyManager interface { + GetKey() []byte + GetPersistentKey() ([]byte, error) + Renewable() bool + Renewer(context.Context) error + Encrypt(context.Context, []byte, []byte) ([]byte, error) + Decrypt(context.Context, []byte, []byte) ([]byte, error) +} diff --git a/command/agent/cache/crypto/k8s.go b/command/agent/cache/crypto/k8s.go new file mode 100644 index 0000000000..36a6d86b8c --- /dev/null +++ b/command/agent/cache/crypto/k8s.go @@ -0,0 +1,97 @@ +package crypto + +import ( + "context" + "crypto/rand" + "fmt" + + wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/go-kms-wrapping/wrappers/aead" +) + +var _ KeyManager = (*KubeEncryptionKey)(nil) + +type KubeEncryptionKey struct { + renewable bool + wrapper *aead.Wrapper +} + +// NewK8s returns a new instance of the Kube encryption key. Kubernetes +// encryption keys aren't renewable. +func NewK8s(existingKey []byte) (*KubeEncryptionKey, error) { + k := &KubeEncryptionKey{ + renewable: false, + wrapper: aead.NewWrapper(nil), + } + + k.wrapper.SetConfig(map[string]string{"key_id": KeyID}) + + var rootKey []byte = nil + if len(existingKey) != 0 { + if len(existingKey) != 32 { + return k, fmt.Errorf("invalid key size, should be 32, got %d", len(existingKey)) + } + rootKey = existingKey + } + + if rootKey == nil { + newKey := make([]byte, 32) + _, err := rand.Read(newKey) + if err != nil { + return k, err + } + rootKey = newKey + } + + if err := k.wrapper.SetAESGCMKeyBytes(rootKey); err != nil { + return k, err + } + + return k, nil +} + +// GetKey returns the encryption key in a format optimized for storage. +// In k8s we store the key as is, so just return the key stored. +func (k *KubeEncryptionKey) GetKey() []byte { + return k.wrapper.GetKeyBytes() +} + +// GetPersistentKey returns the key which should be stored in the persisent +// cache. In k8s we store the key as is, so just return the key stored. +func (k *KubeEncryptionKey) GetPersistentKey() ([]byte, error) { + return k.wrapper.GetKeyBytes(), nil +} + +// Renewable lets the caller know if this encryption key type is +// renewable. In Kubernetes the key isn't renewable. +func (k *KubeEncryptionKey) Renewable() bool { + return k.renewable +} + +// Renewer is used when the encryption key type is renewable. Since Kubernetes +// keys aren't renewable, returning nothing. +func (k *KubeEncryptionKey) Renewer(ctx context.Context) error { + return nil +} + +// Encrypt takes plaintext values and encrypts them using the store key and additional +// data. For Kubernetes the AAD should be the service account JWT. +func (k *KubeEncryptionKey) Encrypt(ctx context.Context, plaintext, aad []byte) ([]byte, error) { + blob, err := k.wrapper.Encrypt(ctx, plaintext, aad) + if err != nil { + return nil, err + } + return blob.Ciphertext, nil +} + +// Decrypt takes ciphertext and AAD values and returns the decrypted value. For Kubernetes the AAD +// should be the service account JWT. +func (k *KubeEncryptionKey) Decrypt(ctx context.Context, ciphertext, aad []byte) ([]byte, error) { + blob := &wrapping.EncryptedBlobInfo{ + Ciphertext: ciphertext, + KeyInfo: &wrapping.KeyInfo{ + KeyID: KeyID, + }, + } + return k.wrapper.Decrypt(ctx, blob, aad) +} diff --git a/command/agent/cache/crypto/k8s_test.go b/command/agent/cache/crypto/k8s_test.go new file mode 100644 index 0000000000..01b2f883f8 --- /dev/null +++ b/command/agent/cache/crypto/k8s_test.go @@ -0,0 +1,155 @@ +package crypto + +import ( + "fmt" + "math/rand" + "testing" +) + +func TestCrypto_KubernetesNewKey(t *testing.T) { + k8sKey, err := NewK8s([]byte{}) + if err != nil { + t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) + } + + key := k8sKey.GetKey() + if key == nil { + t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key)) + } + + persistentKey, _ := k8sKey.GetPersistentKey() + if persistentKey == nil { + t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", persistentKey)) + } + + if string(key) != string(persistentKey) { + t.Fatalf("keys don't match, they should: key: %s, persistentKey: %s", key, persistentKey) + } + + plaintextInput := []byte("test") + aad := []byte("kubernetes") + + ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad) + if err != nil { + t.Fatalf(err.Error()) + } + + if ciphertext == nil { + t.Fatalf("ciphertext nil, it shouldn't be") + } + + plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad) + if err != nil { + t.Fatalf(err.Error()) + } + + if string(plaintext) != string(plaintextInput) { + t.Fatalf("expected %s, got %s", plaintextInput, plaintext) + } +} + +func TestCrypto_KubernetesExistingKey(t *testing.T) { + rootKey := make([]byte, 32) + n, err := rand.Read(rootKey) + if err != nil { + t.Fatal(err) + } + if n != 32 { + t.Fatal(n) + } + + k8sKey, err := NewK8s(rootKey) + if err != nil { + t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) + } + + key := k8sKey.GetKey() + if key == nil { + t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", key)) + } + + if string(key) != string(rootKey) { + t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, key)) + } + + persistentKey, _ := k8sKey.GetPersistentKey() + if persistentKey == nil { + t.Fatalf("key is nil, it shouldn't be") + } + + if string(persistentKey) != string(rootKey) { + t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: expected: %s, got: %s", rootKey, persistentKey)) + } + + if string(key) != string(persistentKey) { + t.Fatalf(fmt.Sprintf("expected keys to be the same, they weren't: %s %s", rootKey, persistentKey)) + } + + plaintextInput := []byte("test") + aad := []byte("kubernetes") + + ciphertext, err := k8sKey.Encrypt(nil, plaintextInput, aad) + if err != nil { + t.Fatalf(err.Error()) + } + + if ciphertext == nil { + t.Fatalf("ciphertext nil, it shouldn't be") + } + + plaintext, err := k8sKey.Decrypt(nil, ciphertext, aad) + if err != nil { + t.Fatalf(err.Error()) + } + + if string(plaintext) != string(plaintextInput) { + t.Fatalf("expected %s, got %s", plaintextInput, plaintext) + } +} + +func TestCrypto_KubernetesPassGeneratedKey(t *testing.T) { + k8sFirstKey, err := NewK8s([]byte{}) + if err != nil { + t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) + } + + firstPersistentKey := k8sFirstKey.GetKey() + if firstPersistentKey == nil { + t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", firstPersistentKey)) + } + + plaintextInput := []byte("test") + aad := []byte("kubernetes") + + ciphertext, err := k8sFirstKey.Encrypt(nil, plaintextInput, aad) + if err != nil { + t.Fatalf(err.Error()) + } + + if ciphertext == nil { + t.Fatalf("ciphertext nil, it shouldn't be") + } + + k8sLoadedKey, err := NewK8s(firstPersistentKey) + if err != nil { + t.Fatalf(fmt.Sprintf("unexpected error: %s", err)) + } + + loadedKey, _ := k8sLoadedKey.GetPersistentKey() + if loadedKey == nil { + t.Fatalf(fmt.Sprintf("key is nil, it shouldn't be: %s", loadedKey)) + } + + if string(loadedKey) != string(firstPersistentKey) { + t.Fatalf(fmt.Sprintf("keys do not match")) + } + + plaintext, err := k8sLoadedKey.Decrypt(nil, ciphertext, aad) + if err != nil { + t.Fatalf(err.Error()) + } + + if string(plaintext) != string(plaintextInput) { + t.Fatalf("expected %s, got %s", plaintextInput, plaintext) + } +}