diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index bfe348b8cc..9ff64896c7 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -15,15 +15,19 @@ func Backend() *framework.Backend { PathsSpecial: &logical.Paths{ Root: []string{ "keys/*", - "raw/*", }, }, Paths: []*framework.Path{ + // Rotate/Config needs to come before Keys + // as the handler is greedy + pathConfig(), + pathRotate(), + pathRewrap(), pathKeys(), - pathRaw(), pathEncrypt(), pathDecrypt(), + pathDatakey(), }, Secrets: []*framework.Secret{}, diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 88caed93fb..5081644a24 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -3,6 +3,8 @@ package transit import ( "encoding/base64" "fmt" + "strconv" + "strings" "testing" "github.com/hashicorp/vault/logical" @@ -21,12 +23,90 @@ func TestBackend_basic(t *testing.T) { Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", false), testAccStepReadPolicy(t, "test", false, false), - testAccStepReadRaw(t, "test", false, false), testAccStepEncrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepDeleteNotDisabledPolicy(t, "test"), + testAccStepEnableDeletion(t, "test"), + testAccStepDeletePolicy(t, "test"), + testAccStepWritePolicy(t, "test", false), + testAccStepEnableDeletion(t, "test"), + testAccStepDisableDeletion(t, "test"), + testAccStepDeleteNotDisabledPolicy(t, "test"), + testAccStepEnableDeletion(t, "test"), + testAccStepDeletePolicy(t, "test"), + testAccStepReadPolicy(t, "test", true, false), + }, + }) +} + +func TestBackend_datakey(t *testing.T) { + dataKeyInfo := make(map[string]interface{}) + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testAccStepWritePolicy(t, "test", false), + testAccStepReadPolicy(t, "test", false, false), + testAccStepWriteDatakey(t, "test", false, 256, dataKeyInfo), + testAccStepDecryptDatakey(t, "test", dataKeyInfo), + testAccStepWriteDatakey(t, "test", true, 128, dataKeyInfo), + }, + }) +} + +func TestBackend_rotation(t *testing.T) { + decryptData := make(map[string]interface{}) + encryptHistory := make(map[int]map[string]interface{}) + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testAccStepWritePolicy(t, "test", false), + testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 0, encryptHistory), + testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 1, encryptHistory), + testAccStepRotate(t, "test"), // now v2 + testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 2, encryptHistory), + testAccStepRotate(t, "test"), // now v3 + testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 3, encryptHistory), + testAccStepRotate(t, "test"), // now v4 + testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 4, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepEncryptVX(t, "test", testPlaintext, decryptData, 99, encryptHistory), + testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 2, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 3, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 99, encryptHistory), + testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 4, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepDeleteNotDisabledPolicy(t, "test"), + testAccStepAdjustPolicy(t, "test", 3), + testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory), + testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory), + testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 2, encryptHistory), + testAccStepDecryptExpectFailure(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 3, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 4, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepAdjustPolicy(t, "test", 1), + testAccStepLoadVX(t, "test", decryptData, 0, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 1, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepLoadVX(t, "test", decryptData, 2, encryptHistory), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepRewrap(t, "test", decryptData, 4), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepEnableDeletion(t, "test"), testAccStepDeletePolicy(t, "test"), testAccStepReadPolicy(t, "test", true, false), - testAccStepReadRaw(t, "test", true, false), }, }) } @@ -40,6 +120,7 @@ func TestBackend_upsert(t *testing.T) { testAccStepEncrypt(t, "test", testPlaintext, decryptData), testAccStepReadPolicy(t, "test", false, false), testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepEnableDeletion(t, "test"), testAccStepDeletePolicy(t, "test"), testAccStepReadPolicy(t, "test", true, false), }, @@ -53,12 +134,11 @@ func TestBackend_basic_derived(t *testing.T) { Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test", true), testAccStepReadPolicy(t, "test", false, true), - testAccStepReadRaw(t, "test", false, true), testAccStepEncryptContext(t, "test", testPlaintext, "my-cool-context", decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepEnableDeletion(t, "test"), testAccStepDeletePolicy(t, "test"), testAccStepReadPolicy(t, "test", true, true), - testAccStepReadRaw(t, "test", true, true), }, }) } @@ -73,6 +153,36 @@ func testAccStepWritePolicy(t *testing.T, name string, derived bool) logicaltest } } +func testAccStepAdjustPolicy(t *testing.T, name string, minVer int) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "keys/" + name + "/config", + Data: map[string]interface{}{ + "min_decryption_version": minVer, + }, + } +} + +func testAccStepDisableDeletion(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "keys/" + name + "/config", + Data: map[string]interface{}{ + "deletion_allowed": false, + }, + } +} + +func testAccStepEnableDeletion(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "keys/" + name + "/config", + Data: map[string]interface{}{ + "deletion_allowed": true, + }, + } +} + func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.DeleteOperation, @@ -80,6 +190,23 @@ func testAccStepDeletePolicy(t *testing.T, name string) logicaltest.TestStep { } } +func testAccStepDeleteNotDisabledPolicy(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.DeleteOperation, + Path: "keys/" + name, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if resp == nil { + return fmt.Errorf("Got nil response instead of error") + } + if resp.IsError() { + return nil + } + return fmt.Errorf("expected error but did not get one") + }, + } +} + func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.ReadOperation, @@ -94,11 +221,13 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) return nil } var d struct { - Name string `mapstructure:"name"` - Key []byte `mapstructure:"key"` - CipherMode string `mapstructure:"cipher_mode"` - Derived bool `mapstructure:"derived"` - KDFMode string `mapstructure:"kdf_mode"` + Name string `mapstructure:"name"` + Key []byte `mapstructure:"key"` + Keys map[string]int64 `mapstructure:"keys"` + CipherMode string `mapstructure:"cipher_mode"` + Derived bool `mapstructure:"derived"` + KDFMode string `mapstructure:"kdf_mode"` + DeletionAllowed bool `mapstructure:"deletion_allowed"` } if err := mapstructure.Decode(resp.Data, &d); err != nil { return err @@ -114,48 +243,10 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone, derived bool) if d.Key != nil { return fmt.Errorf("bad: %#v", d) } - if d.Derived != derived { + if d.Keys == nil { return fmt.Errorf("bad: %#v", d) } - if derived && d.KDFMode != kdfMode { - return fmt.Errorf("bad: %#v", d) - } - return nil - }, - } -} - -func testAccStepReadRaw(t *testing.T, name string, expectNone, derived bool) logicaltest.TestStep { - return logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "raw/" + name, - Check: func(resp *logical.Response) error { - if resp == nil && !expectNone { - return fmt.Errorf("missing response") - } else if expectNone { - if resp != nil { - return fmt.Errorf("response when expecting none") - } - return nil - } - var d struct { - Name string `mapstructure:"name"` - Key []byte `mapstructure:"key"` - CipherMode string `mapstructure:"cipher_mode"` - Derived bool `mapstructure:"derived"` - KDFMode string `mapstructure:"kdf_mode"` - } - if err := mapstructure.Decode(resp.Data, &d); err != nil { - return err - } - - if d.Name != name { - return fmt.Errorf("bad: %#v", d) - } - if d.CipherMode != "aes-gcm" { - return fmt.Errorf("bad: %#v", d) - } - if len(d.Key) != 32 { + if d.DeletionAllowed == true { return fmt.Errorf("bad: %#v", d) } if d.Derived != derived { @@ -240,9 +331,192 @@ func testAccStepDecrypt( } if string(plainRaw) != plaintext { - return fmt.Errorf("plaintext mismatch: %s expect: %s", plainRaw, plaintext) + return fmt.Errorf("plaintext mismatch: %s expect: %s, decryptData was %#v", plainRaw, plaintext, decryptData) } return nil }, } } + +func testAccStepRewrap( + t *testing.T, name string, decryptData map[string]interface{}, expectedVer int) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "rewrap/" + name, + Data: decryptData, + Check: func(resp *logical.Response) error { + var d struct { + Ciphertext string `mapstructure:"ciphertext"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if d.Ciphertext == "" { + return fmt.Errorf("missing ciphertext") + } + splitStrings := strings.Split(d.Ciphertext, ":") + verString := splitStrings[1][1:] + ver, err := strconv.Atoi(verString) + if err != nil { + return fmt.Errorf("Error pulling out version from verString '%s', ciphertext was %s", verString, d.Ciphertext) + } + if ver != expectedVer { + return fmt.Errorf("Did not get expected version") + } + decryptData["ciphertext"] = d.Ciphertext + return nil + }, + } +} + +func testAccStepEncryptVX( + t *testing.T, name, plaintext string, decryptData map[string]interface{}, + ver int, encryptHistory map[int]map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "encrypt/" + name, + Data: map[string]interface{}{ + "plaintext": base64.StdEncoding.EncodeToString([]byte(plaintext)), + }, + Check: func(resp *logical.Response) error { + var d struct { + Ciphertext string `mapstructure:"ciphertext"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if d.Ciphertext == "" { + return fmt.Errorf("missing ciphertext") + } + splitStrings := strings.Split(d.Ciphertext, ":") + splitStrings[1] = "v" + strconv.Itoa(ver) + ciphertext := strings.Join(splitStrings, ":") + decryptData["ciphertext"] = ciphertext + encryptHistory[ver] = map[string]interface{}{ + "ciphertext": ciphertext, + } + return nil + }, + } +} + +func testAccStepLoadVX( + t *testing.T, name string, decryptData map[string]interface{}, + ver int, encryptHistory map[int]map[string]interface{}) logicaltest.TestStep { + // This is really a no-op to allow us to do data manip in the check function + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "keys/" + name, + Check: func(resp *logical.Response) error { + decryptData["ciphertext"] = encryptHistory[ver]["ciphertext"].(string) + return nil + }, + } +} + +func testAccStepDecryptExpectFailure( + t *testing.T, name, plaintext string, decryptData map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "decrypt/" + name, + Data: decryptData, + ErrorOk: true, + Check: func(resp *logical.Response) error { + if !resp.IsError() { + return fmt.Errorf("expected error") + } + return nil + }, + } +} + +func testAccStepRotate(t *testing.T, name string) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "keys/" + name + "/rotate", + } +} + +func testAccStepWriteDatakey(t *testing.T, name string, + noPlaintext bool, bits int, + dataKeyInfo map[string]interface{}) logicaltest.TestStep { + data := map[string]interface{}{} + subPath := "plaintext" + if noPlaintext { + subPath = "wrapped" + } + if bits != 256 { + data["bits"] = bits + } + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "datakey/" + subPath + "/" + name, + Data: data, + Check: func(resp *logical.Response) error { + var d struct { + Plaintext string `mapstructure:"plaintext"` + Ciphertext string `mapstructure:"ciphertext"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if noPlaintext && len(d.Plaintext) != 0 { + return fmt.Errorf("received plaintxt when we disabled it") + } + if !noPlaintext { + if len(d.Plaintext) == 0 { + return fmt.Errorf("did not get plaintext when we expected it") + } + dataKeyInfo["plaintext"] = d.Plaintext + plainBytes, err := base64.StdEncoding.DecodeString(d.Plaintext) + if err != nil { + return fmt.Errorf("could not base64 decode plaintext string '%s'", d.Plaintext) + } + if len(plainBytes)*8 != bits { + return fmt.Errorf("returned key does not have correct bit length") + } + } + dataKeyInfo["ciphertext"] = d.Ciphertext + return nil + }, + } +} + +func testAccStepDecryptDatakey(t *testing.T, name string, + dataKeyInfo map[string]interface{}) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.WriteOperation, + Path: "decrypt/" + name, + Data: dataKeyInfo, + Check: func(resp *logical.Response) error { + var d struct { + Plaintext string `mapstructure:"plaintext"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + + if d.Plaintext != dataKeyInfo["plaintext"].(string) { + return fmt.Errorf("plaintext mismatch: got '%s', expected '%s', decryptData was %#v", d.Plaintext, dataKeyInfo["plaintext"].(string)) + } + return nil + }, + } +} + +func TestKeyUpgrade(t *testing.T) { + p := &Policy{ + Name: "test", + Key: []byte(testPlaintext), + CipherMode: "aes-gcm", + } + + p.migrateKeyToKeysMap() + + if p.Key != nil || + p.Keys == nil || + len(p.Keys) != 1 || + string(p.Keys[1].Key) != testPlaintext { + t.Errorf("bad key migration, result is %#v", p.Keys) + } +} diff --git a/builtin/logical/transit/path_config.go b/builtin/logical/transit/path_config.go new file mode 100644 index 0000000000..d368c64b9a --- /dev/null +++ b/builtin/logical/transit/path_config.go @@ -0,0 +1,86 @@ +package transit + +import ( + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathConfig() *framework.Path { + return &framework.Path{ + Pattern: "keys/" + framework.GenericNameRegex("name") + "/config", + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key", + }, + + "min_decryption_version": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `If set, the minimum version of the key allowed +to be decrypted.`, + }, + + "deletion_allowed": &framework.FieldSchema{ + Type: framework.TypeBool, + Description: "Whether to allow deletion of the key", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathConfigWrite, + }, + + HelpSynopsis: pathConfigHelpSyn, + HelpDescription: pathConfigHelpDesc, + } +} + +func pathConfigWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + // Check if the policy already exists + policy, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if policy == nil { + return logical.ErrorResponse( + fmt.Sprintf("no existing role named %s could be found", name)), + logical.ErrInvalidRequest + } + + persistNeeded := false + + minDecryptionVersion := d.Get("min_decryption_version").(int) + if minDecryptionVersion != 0 && + minDecryptionVersion != policy.MinDecryptionVersion { + policy.MinDecryptionVersion = minDecryptionVersion + persistNeeded = true + } + + allowDeletionInt, ok := d.GetOk("deletion_allowed") + if ok { + allowDeletion := allowDeletionInt.(bool) + if allowDeletion != policy.DeletionAllowed { + policy.DeletionAllowed = allowDeletion + persistNeeded = true + } + } + + if !persistNeeded { + return nil, nil + } + + return nil, policy.Persist(req.Storage, name) +} + +const pathConfigHelpSyn = `Configure a named encryption key` + +const pathConfigHelpDesc = ` +This path is used to configure the named key. Currently, this +supports adjusting the minimum version of the key allowed to +be used for decryption via the min_decryption_version paramter. +` diff --git a/builtin/logical/transit/path_datakey.go b/builtin/logical/transit/path_datakey.go new file mode 100644 index 0000000000..43652733a3 --- /dev/null +++ b/builtin/logical/transit/path_datakey.go @@ -0,0 +1,142 @@ +package transit + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathDatakey() *framework.Path { + return &framework.Path{ + Pattern: "datakey/" + framework.GenericNameRegex("plaintext") + "/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "The backend key used for encrypting the data key", + }, + + "plaintext": &framework.FieldSchema{ + Type: framework.TypeString, + Description: `"plaintext" will return the key in both plaintext and +ciphertext; "wrapped" will return the ciphertext only.`, + }, + + "context": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Context for key derivation. Required for derived keys.", + }, + + "bits": &framework.FieldSchema{ + Type: framework.TypeInt, + Description: `Number of bits for the key; currently 128 and +256 are supported. Defaults to 256.`, + Default: 256, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathDatakeyWrite, + }, + + HelpSynopsis: pathDatakeyHelpSyn, + HelpDescription: pathDatakeyHelpDesc, + } +} + +func pathDatakeyWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + plaintext := d.Get("plaintext").(string) + plaintextAllowed := false + switch plaintext { + case "plaintext": + plaintextAllowed = true + case "wrapped": + default: + return logical.ErrorResponse("Invalid path, must be 'plaintext' or 'wrapped'"), logical.ErrInvalidRequest + } + + // Decode the context if any + contextRaw := d.Get("context").(string) + var context []byte + if len(contextRaw) != 0 { + var err error + context, err = base64.StdEncoding.DecodeString(contextRaw) + if err != nil { + return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest + } + } + + // Get the policy + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + + // Error if invalid policy + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + } + + newKey := make([]byte, 32) + bits := d.Get("bits").(int) + switch bits { + case 512: + newKey = make([]byte, 64) + case 256: + case 128: + newKey = make([]byte, 16) + default: + return logical.ErrorResponse("invalid bit length"), logical.ErrInvalidRequest + } + _, err = rand.Read(newKey) + if err != nil { + return nil, err + } + + ciphertext, err := p.Encrypt(context, base64.StdEncoding.EncodeToString(newKey)) + if err != nil { + switch err.(type) { + case certutil.UserError: + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + case certutil.InternalError: + return nil, err + default: + return nil, err + } + } + + if ciphertext == "" { + return nil, fmt.Errorf("empty ciphertext returned") + } + + // Generate the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "ciphertext": ciphertext, + }, + } + + if plaintextAllowed { + resp.Data["plaintext"] = base64.StdEncoding.EncodeToString(newKey) + } + + return resp, nil +} + +const pathDatakeyHelpSyn = `Generate a data key` + +const pathDatakeyHelpDesc = ` +This path can be used to generate a data key: a random +key of a certain length that can be used for encryption +and decryption, protected by the named backend key. 128, 256, +or 512 bits can be specified; if not specified, the default +is 256 bits. Call with the the "wrapped" path to prevent the +(base64-encoded) plaintext key from being returned along with +the encrypted key, the "plaintext" path returns both. +` diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index 091e1da5d3..8ea94da554 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -1,11 +1,10 @@ package transit import ( - "crypto/aes" - "crypto/cipher" "encoding/base64" - "strings" + "fmt" + "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -42,8 +41,8 @@ func pathDecrypt() *framework.Path { func pathDecryptWrite( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - value := d.Get("ciphertext").(string) - if len(value) == 0 { + ciphertext := d.Get("ciphertext").(string) + if len(ciphertext) == 0 { return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest } @@ -69,56 +68,26 @@ func pathDecryptWrite( return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest } - // Derive the key that should be used - key, err := p.DeriveKey(context) + plaintext, err := p.Decrypt(context, ciphertext) if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + switch err.(type) { + case certutil.UserError: + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + case certutil.InternalError: + return nil, err + default: + return nil, err + } } - // Guard against a potentially invalid cipher-mode - switch p.CipherMode { - case "aes-gcm": - default: - return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest - } - - // Verify the prefix - if !strings.HasPrefix(value, "vault:v0:") { - return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest - } - - // Decode the base64 - decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(value, "vault:v0:")) - if err != nil { - return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest - } - - // Setup the cipher - aesCipher, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return nil, err - } - - // Extract the nonce and ciphertext - nonce := decoded[:gcm.NonceSize()] - ciphertext := decoded[gcm.NonceSize():] - - // Verify and Decrypt - plain, err := gcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return logical.ErrorResponse("invalid ciphertext"), logical.ErrInvalidRequest + if plaintext == "" { + return nil, fmt.Errorf("empty plaintext returned") } // Generate the response resp := &logical.Response{ Data: map[string]interface{}{ - "plaintext": base64.StdEncoding.EncodeToString(plain), + "plaintext": plaintext, }, } return resp, nil diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index add5371e96..7a402292dd 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -1,12 +1,10 @@ package transit import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" "encoding/base64" "fmt" + "github.com/hashicorp/vault/helper/certutil" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -48,10 +46,10 @@ func pathEncryptWrite( return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest } - // Decode the plaintext value - plaintext, err := base64.StdEncoding.DecodeString(value) + // Get the policy + p, err := getPolicy(req, name) if err != nil { - return logical.ErrorResponse("failed to decode plaintext as base64"), logical.ErrInvalidRequest + return nil, err } // Decode the context if any @@ -65,12 +63,6 @@ func pathEncryptWrite( } } - // Get the policy - p, err := getPolicy(req, name) - if err != nil { - return nil, err - } - // Error if invalid policy if p == nil { isDerived := len(context) != 0 @@ -80,54 +72,26 @@ func pathEncryptWrite( } } - // Derive the key that should be used - key, err := p.DeriveKey(context) + ciphertext, err := p.Encrypt(context, value) if err != nil { - return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + switch err.(type) { + case certutil.UserError: + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + case certutil.InternalError: + return nil, err + default: + return nil, err + } } - // Guard against a potentially invalid cipher-mode - switch p.CipherMode { - case "aes-gcm": - default: - return logical.ErrorResponse("unsupported cipher mode"), logical.ErrInvalidRequest + if ciphertext == "" { + return nil, fmt.Errorf("empty ciphertext returned") } - // Setup the cipher - aesCipher, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - - // Setup the GCM AEAD - gcm, err := cipher.NewGCM(aesCipher) - if err != nil { - return nil, err - } - - // Compute random nonce - nonce := make([]byte, gcm.NonceSize()) - _, err = rand.Read(nonce) - if err != nil { - return nil, err - } - - // Encrypt and tag with GCM - out := gcm.Seal(nil, nonce, plaintext, nil) - - // Place the encrypted data after the nonce - full := append(nonce, out...) - - // Convert to base64 - encoded := base64.StdEncoding.EncodeToString(full) - - // Prepend some information - encoded = "vault:v0:" + encoded - // Generate the response resp := &logical.Response{ Data: map[string]interface{}{ - "ciphertext": encoded, + "ciphertext": ciphertext, }, } return resp, nil diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 0b341b7f76..7c6d1d1573 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -1,126 +1,13 @@ package transit import ( - "crypto/rand" - "encoding/json" "fmt" + "strconv" - "github.com/hashicorp/vault/helper/kdf" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) -const ( - // kdfMode is the only KDF mode currently supported - kdfMode = "hmac-sha256-counter" -) - -// Policy is the struct used to store metadata -type Policy struct { - Name string `json:"name"` - Key []byte `json:"key"` - CipherMode string `json:"cipher"` - - // Derived keys MUST provide a context and the - // master underlying key is never used. - Derived bool `json:"derived"` - KDFMode string `json:"kdf_mode"` -} - -func (p *Policy) Serialize() ([]byte, error) { - return json.Marshal(p) -} - -// DeriveKey is used to derive the encryption key that should -// be used depending on the policy. If derivation is disabled the -// raw key is used and no context is required, otherwise the KDF -// mode is used with the context to derive the proper key. -func (p *Policy) DeriveKey(context []byte) ([]byte, error) { - // Fast-path non-derived keys - if !p.Derived { - return p.Key, nil - } - - // Ensure a context is provided - if len(context) == 0 { - return nil, fmt.Errorf("missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information.") - } - - switch p.KDFMode { - case kdfMode: - prf := kdf.HMACSHA256PRF - prfLen := kdf.HMACSHA256PRFLen - return kdf.CounterMode(prf, prfLen, p.Key, context, 256) - default: - return nil, fmt.Errorf("unsupported key derivation mode") - } -} - -func DeserializePolicy(buf []byte) (*Policy, error) { - p := new(Policy) - if err := json.Unmarshal(buf, p); err != nil { - return nil, err - } - return p, nil -} - -func getPolicy(req *logical.Request, name string) (*Policy, error) { - // Check if the policy already exists - raw, err := req.Storage.Get("policy/" + name) - if err != nil { - return nil, err - } - if raw == nil { - return nil, nil - } - - // Decode the policy - p, err := DeserializePolicy(raw.Value) - if err != nil { - return nil, err - } - return p, nil -} - -// generatePolicy is used to create a new named policy with -// a randomly generated key -func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) { - // Create the policy object - p := &Policy{ - Name: name, - CipherMode: "aes-gcm", - Derived: derived, - } - if derived { - p.KDFMode = kdfMode - } - - // Generate a 256bit key - p.Key = make([]byte, 32) - _, err := rand.Read(p.Key) - if err != nil { - return nil, err - } - - // Encode the policy - buf, err := p.Serialize() - if err != nil { - return nil, err - } - - // Write the policy into storage - err = storage.Put(&logical.StorageEntry{ - Key: "policy/" + name, - Value: buf, - }) - if err != nil { - return nil, err - } - - // Return the policy - return p, nil -} - func pathKeys() *framework.Path { return &framework.Path{ Pattern: "keys/" + framework.GenericNameRegex("name"), @@ -169,6 +56,7 @@ func pathPolicyWrite( func pathPolicyRead( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) + p, err := getPolicy(req, name) if err != nil { return nil, err @@ -180,14 +68,22 @@ func pathPolicyRead( // Return the response resp := &logical.Response{ Data: map[string]interface{}{ - "name": p.Name, - "cipher_mode": p.CipherMode, - "derived": p.Derived, + "name": p.Name, + "cipher_mode": p.CipherMode, + "derived": p.Derived, + "deletion_allowed": p.DeletionAllowed, }, } if p.Derived { resp.Data["kdf_mode"] = p.KDFMode } + + retKeys := map[string]int64{} + for k, v := range p.Keys { + retKeys[strconv.Itoa(k)] = v.CreationTime + } + resp.Data["keys"] = retKeys + return resp, nil } @@ -195,14 +91,26 @@ func pathPolicyDelete( req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) - err := req.Storage.Delete("policy/" + name) + p, err := getPolicy(req, name) if err != nil { - return nil, err + return logical.ErrorResponse(fmt.Sprintf("error looking up policy %s, error is %s", name, err)), err + } + if p == nil { + return logical.ErrorResponse(fmt.Sprintf("no such key %s", name)), logical.ErrInvalidRequest + } + + if !p.DeletionAllowed { + return logical.ErrorResponse(fmt.Sprintf("'allow_deletion' config value is not set")), logical.ErrInvalidRequest + } + + err = req.Storage.Delete("policy/" + name) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("error deleting policy %s: %s", name, err)), err } return nil, nil } -const pathPolicyHelpSyn = `Managed named encrption keys` +const pathPolicyHelpSyn = `Managed named encryption keys` const pathPolicyHelpDesc = ` This path is used to manage the named keys that are available. diff --git a/builtin/logical/transit/path_raw.go b/builtin/logical/transit/path_raw.go deleted file mode 100644 index 6c349a20dd..0000000000 --- a/builtin/logical/transit/path_raw.go +++ /dev/null @@ -1,58 +0,0 @@ -package transit - -import ( - "github.com/hashicorp/vault/logical" - "github.com/hashicorp/vault/logical/framework" -) - -func pathRaw() *framework.Path { - return &framework.Path{ - Pattern: "raw/" + framework.GenericNameRegex("name"), - Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "Name of the key", - }, - }, - - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: pathRawRead, - }, - - HelpSynopsis: pathPolicyHelpSyn, - HelpDescription: pathPolicyHelpDesc, - } -} - -func pathRawRead( - req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - name := d.Get("name").(string) - p, err := getPolicy(req, name) - if err != nil { - return nil, err - } - if p == nil { - return nil, nil - } - - // Return the response - resp := &logical.Response{ - Data: map[string]interface{}{ - "name": p.Name, - "key": p.Key, - "cipher_mode": p.CipherMode, - "derived": p.Derived, - }, - } - if p.Derived { - resp.Data["kdf_mode"] = p.KDFMode - } - return resp, nil -} - -const pathRawHelpSyn = `Fetch raw keys for named encrption keys` - -const pathRawHelpDesc = ` -This path is used to get the underlying encryption keys used for the -named keys that are available. -` diff --git a/builtin/logical/transit/path_rewrap.go b/builtin/logical/transit/path_rewrap.go new file mode 100644 index 0000000000..d6b48d0356 --- /dev/null +++ b/builtin/logical/transit/path_rewrap.go @@ -0,0 +1,120 @@ +package transit + +import ( + "encoding/base64" + "fmt" + + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRewrap() *framework.Path { + return &framework.Path{ + Pattern: "rewrap/" + framework.GenericNameRegex("name"), + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key", + }, + + "ciphertext": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Ciphertext value to rewrap", + }, + + "context": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Context for key derivation. Required for derived keys.", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathRewrapWrite, + }, + + HelpSynopsis: pathRewrapHelpSyn, + HelpDescription: pathRewrapHelpDesc, + } +} + +func pathRewrapWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + value := d.Get("ciphertext").(string) + if len(value) == 0 { + return logical.ErrorResponse("missing ciphertext to decrypt"), logical.ErrInvalidRequest + } + + // Decode the context if any + contextRaw := d.Get("context").(string) + var context []byte + if len(contextRaw) != 0 { + var err error + context, err = base64.StdEncoding.DecodeString(contextRaw) + if err != nil { + return logical.ErrorResponse("failed to decode context as base64"), logical.ErrInvalidRequest + } + } + + // Get the policy + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + + // Error if invalid policy + if p == nil { + return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + } + + plaintext, err := p.Decrypt(context, value) + if err != nil { + switch err.(type) { + case certutil.UserError: + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + case certutil.InternalError: + return nil, err + default: + return nil, err + } + } + + if plaintext == "" { + return nil, fmt.Errorf("empty plaintext returned during rewrap") + } + + ciphertext, err := p.Encrypt(context, plaintext) + if err != nil { + switch err.(type) { + case certutil.UserError: + return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest + case certutil.InternalError: + return nil, err + default: + return nil, err + } + } + + if ciphertext == "" { + return nil, fmt.Errorf("empty ciphertext returned") + } + + // Generate the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "ciphertext": ciphertext, + }, + } + return resp, nil +} + +const pathRewrapHelpSyn = `Rewrap ciphertext` + +const pathRewrapHelpDesc = ` +After key rotation, this function can be used to rewrap the +given ciphertext with the latest version of the named key. +If the given ciphertext is already using the latest version +of the key, this function is a no-op. +` diff --git a/builtin/logical/transit/path_rotate.go b/builtin/logical/transit/path_rotate.go new file mode 100644 index 0000000000..c557dc9278 --- /dev/null +++ b/builtin/logical/transit/path_rotate.go @@ -0,0 +1,56 @@ +package transit + +import ( + "fmt" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRotate() *framework.Path { + return &framework.Path{ + Pattern: "keys/" + framework.GenericNameRegex("name") + "/rotate", + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.WriteOperation: pathRotateWrite, + }, + + HelpSynopsis: pathRotateHelpSyn, + HelpDescription: pathRotateHelpDesc, + } +} + +func pathRotateWrite( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + + // Check if the policy already exists + policy, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if policy == nil { + return logical.ErrorResponse( + fmt.Sprintf("no existing role named %s could be found", name)), + logical.ErrInvalidRequest + } + + // Generate the policy + err = policy.rotate(req.Storage) + + return nil, err +} + +const pathRotateHelpSyn = `Rotate named encryption key` + +const pathRotateHelpDesc = ` +This path is used to rotate the named key. After rotation, +new encryption requests using this name will use the new key, +but decryption will still be supported for older versions. +` diff --git a/builtin/logical/transit/policy.go b/builtin/logical/transit/policy.go new file mode 100644 index 0000000000..fbda93111e --- /dev/null +++ b/builtin/logical/transit/policy.go @@ -0,0 +1,360 @@ +package transit + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "strconv" + "strings" + "time" + + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/kdf" + "github.com/hashicorp/vault/logical" +) + +const ( + // kdfMode is the only KDF mode currently supported + kdfMode = "hmac-sha256-counter" +) + +// KeyEntry stores the key and metadata +type KeyEntry struct { + Key []byte `json:"key"` + CreationTime int64 `json:"creation_time"` +} + +// KeyEntryMap is used to allow JSON marshal/unmarshal +type KeyEntryMap map[int]KeyEntry + +// MarshalJSON implements JSON marshaling +func (kem KeyEntryMap) MarshalJSON() ([]byte, error) { + intermediate := map[string]KeyEntry{} + for k, v := range kem { + intermediate[strconv.Itoa(k)] = v + } + return json.Marshal(&intermediate) +} + +// MarshalJSON implements JSON unmarshaling +func (kem KeyEntryMap) UnmarshalJSON(data []byte) error { + intermediate := map[string]KeyEntry{} + err := json.Unmarshal(data, &intermediate) + if err != nil { + return err + } + for k, v := range intermediate { + keyval, err := strconv.Atoi(k) + if err != nil { + return err + } + kem[keyval] = v + } + + return nil +} + +// Policy is the struct used to store metadata +type Policy struct { + Name string `json:"name"` + Key []byte `json:"key,omitempty"` //DEPRECATED + Keys KeyEntryMap `json:"keys"` + CipherMode string `json:"cipher"` + + // Derived keys MUST provide a context and the + // master underlying key is never used. + Derived bool `json:"derived"` + KDFMode string `json:"kdf_mode"` + + // The minimum version of the key allowed to be used + // for decryption + MinDecryptionVersion int `json:"min_decryption_version"` + + // Whether the key is allowed to be deleted + DeletionAllowed bool `json:"deletion_allowed"` +} + +func (p *Policy) Persist(storage logical.Storage, name string) error { + // Encode the policy + buf, err := p.Serialize() + if err != nil { + return err + } + + // Write the policy into storage + err = storage.Put(&logical.StorageEntry{ + Key: "policy/" + name, + Value: buf, + }) + if err != nil { + return err + } + + return nil +} + +func (p *Policy) Serialize() ([]byte, error) { + return json.Marshal(p) +} + +// DeriveKey is used to derive the encryption key that should +// be used depending on the policy. If derivation is disabled the +// raw key is used and no context is required, otherwise the KDF +// mode is used with the context to derive the proper key. +func (p *Policy) DeriveKey(context []byte, ver int) ([]byte, error) { + if p.Keys == nil || len(p.Keys) == 0 { + if p.Key == nil || len(p.Key) == 0 { + return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"} + } + p.migrateKeyToKeysMap() + } + + if len(p.Keys) == 0 { + return nil, certutil.InternalError{Err: "unable to access the key; no key versions found"} + } + + if ver <= 0 || ver > len(p.Keys) { + return nil, certutil.UserError{Err: "invalid key version"} + } + + // Fast-path non-derived keys + if !p.Derived { + return p.Keys[ver].Key, nil + } + + // Ensure a context is provided + if len(context) == 0 { + return nil, certutil.UserError{Err: "missing 'context' for key deriviation. The key was created using a derived key, which means additional, per-request information must be included in order to encrypt or decrypt information"} + } + + switch p.KDFMode { + case kdfMode: + prf := kdf.HMACSHA256PRF + prfLen := kdf.HMACSHA256PRFLen + return kdf.CounterMode(prf, prfLen, p.Keys[ver].Key, context, 256) + default: + return nil, certutil.InternalError{Err: "unsupported key derivation mode"} + } +} + +func (p *Policy) Encrypt(context []byte, value string) (string, error) { + // Decode the plaintext value + plaintext, err := base64.StdEncoding.DecodeString(value) + if err != nil { + return "", certutil.UserError{Err: "failed to decode plaintext as base64"} + } + + // Derive the key that should be used + key, err := p.DeriveKey(context, len(p.Keys)) + if err != nil { + return "", certutil.InternalError{Err: err.Error()} + } + + // Guard against a potentially invalid cipher-mode + switch p.CipherMode { + case "aes-gcm": + default: + return "", certutil.InternalError{Err: "unsupported cipher mode"} + } + + // Setup the cipher + aesCipher, err := aes.NewCipher(key) + if err != nil { + return "", certutil.InternalError{Err: err.Error()} + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return "", certutil.InternalError{Err: err.Error()} + } + + // Compute random nonce + nonce := make([]byte, gcm.NonceSize()) + _, err = rand.Read(nonce) + if err != nil { + return "", certutil.InternalError{Err: err.Error()} + } + + // Encrypt and tag with GCM + out := gcm.Seal(nil, nonce, plaintext, nil) + + // Place the encrypted data after the nonce + full := append(nonce, out...) + + // Convert to base64 + encoded := base64.StdEncoding.EncodeToString(full) + + // Prepend some information + encoded = "vault:v" + strconv.Itoa(len(p.Keys)) + ":" + encoded + + return encoded, nil +} + +func (p *Policy) Decrypt(context []byte, value string) (string, error) { + // Verify the prefix + if !strings.HasPrefix(value, "vault:v") { + return "", certutil.UserError{Err: "invalid ciphertext"} + } + + splitVerCiphertext := strings.SplitN(strings.TrimPrefix(value, "vault:v"), ":", 2) + if len(splitVerCiphertext) != 2 { + return "", certutil.UserError{Err: "invalid ciphertext"} + } + + ver, err := strconv.Atoi(splitVerCiphertext[0]) + if err != nil { + return "", certutil.UserError{Err: "invalid ciphertext"} + } + + if ver == 0 { + // Compatibility mode with initial implementation, where keys start at zero + ver = 1 + } + + if p.MinDecryptionVersion > 0 && ver < p.MinDecryptionVersion { + return "", certutil.UserError{Err: "ciphertext version is disallowed by policy (too old)"} + } + + // Derive the key that should be used + key, err := p.DeriveKey(context, ver) + if err != nil { + return "", err + } + + // Guard against a potentially invalid cipher-mode + switch p.CipherMode { + case "aes-gcm": + default: + return "", certutil.InternalError{Err: "unsupported cipher mode"} + } + + // Decode the base64 + decoded, err := base64.StdEncoding.DecodeString(splitVerCiphertext[1]) + if err != nil { + return "", certutil.UserError{Err: "invalid ciphertext"} + } + + // Setup the cipher + aesCipher, err := aes.NewCipher(key) + if err != nil { + return "", certutil.InternalError{Err: err.Error()} + } + + // Setup the GCM AEAD + gcm, err := cipher.NewGCM(aesCipher) + if err != nil { + return "", certutil.InternalError{Err: err.Error()} + } + + // Extract the nonce and ciphertext + nonce := decoded[:gcm.NonceSize()] + ciphertext := decoded[gcm.NonceSize():] + + // Verify and Decrypt + plain, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", certutil.UserError{Err: "invalid ciphertext"} + } + + return base64.StdEncoding.EncodeToString(plain), nil +} + +func (p *Policy) rotate(storage logical.Storage) error { + if p.Keys == nil { + p.migrateKeyToKeysMap() + } + + // Generate a 256bit key + newKey := make([]byte, 32) + _, err := rand.Read(newKey) + if err != nil { + return err + } + p.Keys[len(p.Keys)+1] = KeyEntry{ + Key: newKey, + CreationTime: time.Now().Unix(), + } + + return p.Persist(storage, p.Name) +} + +func (p *Policy) migrateKeyToKeysMap() { + if p.Key == nil || len(p.Key) == 0 { + p.Key = nil + p.Keys = KeyEntryMap{} + return + } + + p.Keys = KeyEntryMap{ + 1: KeyEntry{ + Key: p.Key, + CreationTime: time.Now().Unix(), + }, + } + p.Key = nil +} + +func deserializePolicy(buf []byte) (*Policy, error) { + p := &Policy{ + Keys: KeyEntryMap{}, + } + if err := json.Unmarshal(buf, p); err != nil { + return nil, err + } + + return p, nil +} + +func getPolicy(req *logical.Request, name string) (*Policy, error) { + // Check if the policy already exists + raw, err := req.Storage.Get("policy/" + name) + if err != nil { + return nil, err + } + if raw == nil { + return nil, nil + } + + // Decode the policy + p, err := deserializePolicy(raw.Value) + if err != nil { + return nil, err + } + + // Ensure we've moved from Key -> Keys + if p.Key != nil && len(p.Key) > 0 { + p.migrateKeyToKeysMap() + + err = p.Persist(req.Storage, name) + if err != nil { + return nil, err + } + } + + return p, nil +} + +// generatePolicy is used to create a new named policy with +// a randomly generated key +func generatePolicy(storage logical.Storage, name string, derived bool) (*Policy, error) { + // Create the policy object + p := &Policy{ + Name: name, + CipherMode: "aes-gcm", + Derived: derived, + } + if derived { + p.KDFMode = kdfMode + } + + err := p.rotate(storage) + if err != nil { + return nil, err + } + + // Return the policy + return p, nil +} diff --git a/logical/testing/testing.go b/logical/testing/testing.go index e829f003a3..117ee17a47 100644 --- a/logical/testing/testing.go +++ b/logical/testing/testing.go @@ -209,7 +209,18 @@ func Test(t TestT, c TestCase) { Path: "sys/revoke/" + resp.Secret.LeaseID, }) } - if err == nil && resp.IsError() && !s.ErrorOk { + // If it's an error, but an error is expected, and one is also + // returned as a logical.ErrorResponse, let it go to the check + if err != nil { + if !resp.IsError() || (resp.IsError() && !s.ErrorOk) { + t.Error(fmt.Sprintf("Failed step %d: %s", i+1, err)) + break + } + // Set it to nil here as we're catching on the + // logical.ErrorResponse instead + err = nil + } + if resp.IsError() && !s.ErrorOk { err = fmt.Errorf("Erroneous response:\n\n%#v", resp) } if err == nil && s.Check != nil { diff --git a/website/source/docs/secrets/transit/index.html.md b/website/source/docs/secrets/transit/index.html.md index f177ff1b96..b002a7667a 100644 --- a/website/source/docs/secrets/transit/index.html.md +++ b/website/source/docs/secrets/transit/index.html.md @@ -62,21 +62,6 @@ cipher_mode aes-gcm derived false ```` -We can read from the `raw/` endpoint to see the encryption key itself: - -``` -$ vault read transit/raw/foo -Key Value -name foo -cipher_mode aes-gcm -key PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8= -derived false -```` - -Here we can see that the randomly generated encryption key being used, as -well as the AES-GCM cipher mode. We don't need to know any of this to use -the key however. - Now, if we wanted to encrypt a piece of plain text, we use the encrypt endpoint using our named key: @@ -299,44 +284,3 @@ only encrypt or decrypt using the named keys they need access to. - -### /transit/raw/ -#### GET - -
-
Description
-
- Returns raw information about a named encryption key, - Including the underlying encryption key. This is a root protected endpoint. -
- -
Method
-
GET
- -
URL
-
`/transit/raw/`
- -
Parameters
-
- None -
- -
Returns
-
- - ```javascript - { - "data": { - "name": "foo", - "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" - "derived": "true", - "kdf_mode": "hmac-sha256-counter", - } - } - ``` - -
-
- -