diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index c48df88ac6..9d732e1cb7 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -67,7 +67,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d var batchInputItems []BatchRequestItem var err error if batchInputRaw != nil { - err = decodeBatchRequestItems(batchInputRaw, &batchInputItems) + err = decodeDecryptBatchRequestItems(batchInputRaw, &batchInputItems) if err != nil { return nil, fmt.Errorf("failed to parse batch input: %w", err) } diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index 0eefcbbb7b..5e27b549d0 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -127,10 +127,19 @@ to the min_encryption_version configured on the key.`, } } +func decodeEncryptBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error { + return decodeBatchRequestItems(src, true, false, dst) +} + +func decodeDecryptBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error { + return decodeBatchRequestItems(src, false, true, dst) +} + // decodeBatchRequestItems is a fast path alternative to mapstructure.Decode to decode []BatchRequestItem. // It aims to behave as closely possible to the original mapstructure.Decode and will return the same errors. +// Note, however, that an error will also be returned if one of the required fields is missing. // https://github.com/hashicorp/vault/pull/8775/files#r437709722 -func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error { +func decodeBatchRequestItems(src interface{}, requirePlaintext bool, requireCiphertext bool, dst *[]BatchRequestItem) error { if src == nil || dst == nil { return nil } @@ -173,15 +182,18 @@ func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error { } else { errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' expected type 'string', got unconvertible type '%T'", i, item["ciphertext"])) } + } else if requireCiphertext { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' missing ciphertext to decrypt", i)) } - // don't allow "null" to be passed in for the plaintext value if v, has := item["plaintext"]; has { if casted, ok := v.(string); ok { (*dst)[i].Plaintext = casted } else { errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' expected type 'string', got unconvertible type '%T'", i, item["plaintext"])) } + } else if requirePlaintext { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' missing plaintext to encrypt", i)) } if v, has := item["nonce"]; has { @@ -240,7 +252,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d batchInputRaw := d.Raw["batch_input"] var batchInputItems []BatchRequestItem if batchInputRaw != nil { - err = decodeBatchRequestItems(batchInputRaw, &batchInputItems) + err = decodeEncryptBatchRequestItems(batchInputRaw, &batchInputItems) if err != nil { return nil, fmt.Errorf("failed to parse batch input: %w", err) } @@ -249,14 +261,18 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d return logical.ErrorResponse("missing batch input to process"), logical.ErrInvalidRequest } } else { - valueRaw, ok := d.GetOk("plaintext") + valueRaw, ok := d.Raw["plaintext"] if !ok { return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest } + plaintext, ok := valueRaw.(string) + if !ok { + return logical.ErrorResponse("expected plaintext of type 'string', got unconvertible type '%T'", valueRaw), logical.ErrInvalidRequest + } batchInputItems = make([]BatchRequestItem, 1) batchInputItems[0] = BatchRequestItem{ - Plaintext: valueRaw.(string), + Plaintext: plaintext, Context: d.Get("context").(string), Nonce: d.Get("nonce").(string), KeyVersion: d.Get("key_version").(int), diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index d0a4ed38dd..d9a7081ae4 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -13,6 +13,75 @@ import ( "github.com/mitchellh/mapstructure" ) +func TestTransit_MissingPlaintext(t *testing.T) { + var resp *logical.Response + var err error + + b, s := createBackendWithStorage(t) + + // Create the policy + policyReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "keys/existing_key", + Storage: s, + } + resp, err = b.HandleRequest(context.Background(), policyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + encData := map[string]interface{}{ + "plaintext": nil, + } + + encReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "encrypt/existing_key", + Storage: s, + Data: encData, + } + resp, err = b.HandleRequest(context.Background(), encReq) + if resp == nil || !resp.IsError() { + t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) + } +} + +func TestTransit_MissingPlaintextInBatchInput(t *testing.T) { + var resp *logical.Response + var err error + + b, s := createBackendWithStorage(t) + + // Create the policy + policyReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "keys/existing_key", + Storage: s, + } + resp, err = b.HandleRequest(context.Background(), policyReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + batchInput := []interface{}{ + map[string]interface{}{}, // Note that there is no map entry for plaintext + } + + batchData := map[string]interface{}{ + "batch_input": batchInput, + } + batchReq := &logical.Request{ + Operation: logical.CreateOperation, + Path: "encrypt/upserted_key", + Storage: s, + Data: batchData, + } + resp, err = b.HandleRequest(context.Background(), batchReq) + if err == nil { + t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp) + } +} + // Case1: Ensure that batch encryption did not affect the normal flow of // encrypting the plaintext with a pre-existing key. func TestTransit_BatchEncryptionCase1(t *testing.T) { @@ -607,10 +676,12 @@ func TestTransit_BatchEncryptionCase13(t *testing.T) { // Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem. func TestTransit_decodeBatchRequestItems(t *testing.T) { tests := []struct { - name string - src interface{} - dest []BatchRequestItem - wantErrContains string + name string + src interface{} + requirePlaintext bool + requireCiphertext bool + dest []BatchRequestItem + wantErrContains string }{ // basic edge cases of nil values {name: "nil-nil", src: nil, dest: nil}, @@ -729,16 +800,51 @@ func TestTransit_decodeBatchRequestItems(t *testing.T) { src: []interface{}{map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "nonce": "null"}}, dest: []BatchRequestItem{}, }, + // required fields + { + name: "required_plaintext_present", + src: []interface{}{map[string]interface{}{"plaintext": ""}}, + requirePlaintext: true, + dest: []BatchRequestItem{}, + }, + { + name: "required_plaintext_missing", + src: []interface{}{map[string]interface{}{}}, + requirePlaintext: true, + dest: []BatchRequestItem{}, + wantErrContains: "missing plaintext", + }, + { + name: "required_ciphertext_present", + src: []interface{}{map[string]interface{}{"ciphertext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}}, + requireCiphertext: true, + dest: []BatchRequestItem{}, + }, + { + name: "required_ciphertext_missing", + src: []interface{}{map[string]interface{}{}}, + requireCiphertext: true, + dest: []BatchRequestItem{}, + wantErrContains: "missing ciphertext", + }, + { + name: "required_plaintext_and_ciphertext_missing", + src: []interface{}{map[string]interface{}{}}, + requirePlaintext: true, + requireCiphertext: true, + dest: []BatchRequestItem{}, + wantErrContains: "missing ciphertext", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { expectedDest := append(tt.dest[:0:0], tt.dest...) // copy of the dest state - expectedErr := mapstructure.Decode(tt.src, &expectedDest) + expectedErr := mapstructure.Decode(tt.src, &expectedDest) != nil || tt.wantErrContains != "" - gotErr := decodeBatchRequestItems(tt.src, &tt.dest) + gotErr := decodeBatchRequestItems(tt.src, tt.requirePlaintext, tt.requireCiphertext, &tt.dest) gotDest := tt.dest - if expectedErr != nil { + if expectedErr { if gotErr == nil { t.Fatal("decodeBatchRequestItems unexpected error value; expected error but got none") } diff --git a/builtin/logical/transit/path_hash.go b/builtin/logical/transit/path_hash.go index 5e4c750530..2b894f00ef 100644 --- a/builtin/logical/transit/path_hash.go +++ b/builtin/logical/transit/path_hash.go @@ -63,7 +63,16 @@ Defaults to "sha2-256".`, } func (b *backend) pathHashWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - inputB64 := d.Get("input").(string) + rawInput, ok := d.Raw["input"] + if !ok { + return logical.ErrorResponse("input missing"), logical.ErrInvalidRequest + } + + inputB64, ok := rawInput.(string) + if !ok { + return logical.ErrorResponse("expected input of type 'string', got unconvertible type '%T'", rawInput), logical.ErrInvalidRequest + } + format := d.Get("format").(string) algorithm := d.Get("urlalgorithm").(string) if algorithm == "" { diff --git a/builtin/logical/transit/path_hash_test.go b/builtin/logical/transit/path_hash_test.go index 98ce87889a..0492cf03a4 100644 --- a/builtin/logical/transit/path_hash_test.go +++ b/builtin/logical/transit/path_hash_test.go @@ -29,7 +29,7 @@ func TestTransit_Hash(t *testing.T) { } if errExpected { if !resp.IsError() { - t.Fatalf("bad: got error response: %#v", *resp) + t.Fatalf("bad: did not get error response: %#v", *resp) } return } @@ -86,6 +86,10 @@ func TestTransit_Hash(t *testing.T) { doRequest(req, false, "98rFrYMEIqVAizamCmBiBoe+GAdlo+KJW8O9vYV8nggkbIMGTU42EvDLkn8+rSCEE6uYYkv3sGF68PA/YggJdg==") // Test bad input/format/algorithm + req.Data["input"] = nil + doRequest(req, true, "") + + req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA==" req.Data["format"] = "base92" doRequest(req, true, "") diff --git a/builtin/logical/transit/path_trim.go b/builtin/logical/transit/path_trim.go index d8587f1c18..fd01b60a2c 100644 --- a/builtin/logical/transit/path_trim.go +++ b/builtin/logical/transit/path_trim.go @@ -55,11 +55,14 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc { } defer p.Unlock() - minAvailableVersionRaw, ok := d.GetOk("min_available_version") + minAvailableVersionRaw, ok := d.Raw["min_available_version"] if !ok { return logical.ErrorResponse("missing min_available_version"), nil } - minAvailableVersion := minAvailableVersionRaw.(int) + minAvailableVersion, ok := minAvailableVersionRaw.(int) + if !ok { + return logical.ErrorResponse("expected min_available_version of type 'int', got unconvertible type '%T'", minAvailableVersionRaw), logical.ErrInvalidRequest + } originalMinAvailableVersion := p.MinAvailableVersion diff --git a/builtin/logical/transit/path_trim_test.go b/builtin/logical/transit/path_trim_test.go index be989b1642..db38aad938 100644 --- a/builtin/logical/transit/path_trim_test.go +++ b/builtin/logical/transit/path_trim_test.go @@ -79,6 +79,20 @@ func TestTransit_Trim(t *testing.T) { } doErrReq(t, req) + // Set min_encryption_version to 0 + req.Path = "keys/aes/config" + req.Data = map[string]interface{}{ + "min_encryption_version": 0, + } + doReq(t, req) + + // Min available version should not be converted to 0 for nil values + req.Path = "keys/aes/trim" + req.Data = map[string]interface{}{ + "min_available_version": nil, + } + doErrReq(t, req) + // Set min_encryption_version to 4 req.Path = "keys/aes/config" req.Data = map[string]interface{}{ diff --git a/changelog/14074.txt b/changelog/14074.txt new file mode 100644 index 0000000000..9d12642482 --- /dev/null +++ b/changelog/14074.txt @@ -0,0 +1,3 @@ +```release-note:bug +secrets/transit: Return an error if any required parameter is missing or nil. Do not encrypt nil plaintext as if it was an empty string. +```