diff --git a/vault/external_tests/raft/raft_test.go b/vault/external_tests/raft/raft_test.go index efe209c44e..73eab754bf 100644 --- a/vault/external_tests/raft/raft_test.go +++ b/vault/external_tests/raft/raft_test.go @@ -514,7 +514,7 @@ func TestRaft_SnapshotAPI_MidstreamFailure(t *testing.T) { // defer goleak.VerifyNone(t) t.Parallel() - seal, setErr := vaultseal.NewToggleableTestSeal(nil) + seal, wrappers := vaultseal.NewTestSeal(nil) autoSeal := vault.NewAutoSeal(seal) cluster, _ := raftCluster(t, &RaftClusterOpts{ NumCores: 1, @@ -547,7 +547,7 @@ func TestRaft_SnapshotAPI_MidstreamFailure(t *testing.T) { wg.Done() }() - setErr[0](errors.New("seal failure")) + wrappers[0].SetError(errors.New("seal failure")) // Take a snapshot err := leaderClient.Sys().RaftSnapshot(w) w.Close() diff --git a/vault/seal/seal.go b/vault/seal/seal.go index 3786f37c86..cd49f052ed 100644 --- a/vault/seal/seal.go +++ b/vault/seal/seal.go @@ -203,13 +203,13 @@ func haveCommonSeal(existingSealKmsConfigs, newSealKmsConfigs []*configutil.KMS) } func findRenamedDisabledSeals(configs []*configutil.KMS) []*configutil.KMS { - diabledSeals := []*configutil.KMS{} + disabledSeals := []*configutil.KMS{} for _, seal := range configs { if seal.Disabled && strings.HasSuffix(seal.Name, configutil.KmsRenameDisabledSuffix) { - diabledSeals = append(diabledSeals, seal) + disabledSeals = append(disabledSeals, seal) } } - return diabledSeals + return disabledSeals } func compareKMSConfigByNameAndType() cmp.Option { @@ -468,7 +468,10 @@ func (a *access) Init(ctx context.Context, options ...wrapping.Option) error { a.logger.Warn("cannot determine key ID for seal", "seal", sealWrapper.Name, "err", err) return fmt.Errorf("cannod determine key ID for seal %s: %w", sealWrapper.Name, err) } - keyIds = append(keyIds, keyId) + if keyId != "" { + // Some wrappers may not yet know their key id. For emample, see gcpkms.Wrapper. + keyIds = append(keyIds, keyId) + } } } a.keyIdSet.setIds(keyIds) @@ -477,7 +480,7 @@ func (a *access) Init(ctx context.Context, options ...wrapping.Option) error { func (a *access) IsUpToDate(ctx context.Context, value *MultiWrapValue, forceKeyIdRefresh bool) (bool, error) { // Note that we don't compare generations when the value is transitory, since all single-blobInfo - // values are unmarshalled as transitory values. + // values (i.e. not yet upgraded to MultiWrapValues) are unmarshalled as transitory values. if value.Generation != 0 && value.Generation != a.Generation() { return false, nil } @@ -558,6 +561,34 @@ GATHER_RESULTS: } } + { + // Check for duplicate Key IDs. + // If any wrappers produce duplicated IDs, their BlobInfo will be replaced by an error. + + keyIdToSealWrapperNameMap := make(map[string]string) + for _, sealWrapper := range enabledWrappersByPriority { + wrapperName := sealWrapper.Name + if result, ok := results[wrapperName]; ok { + if result.err != nil { + continue + } + if result.ciphertext.KeyInfo == nil { + // Can this really happen? Probably not? + continue + } + keyId := result.ciphertext.KeyInfo.KeyId + duplicateWrapperName, isDuplicate := keyIdToSealWrapperNameMap[keyId] + if isDuplicate { + for _, name := range []string{wrapperName, duplicateWrapperName} { + results[name].err = fmt.Errorf("seal %s has returned duplicate key ID %s, key IDs must be unique", name, keyId) + results[name].ciphertext = nil + } + } + keyIdToSealWrapperNameMap[keyId] = wrapperName + } + } + } + // Sort out the successful results from the errors var slots []*wrapping.BlobInfo errs := make(map[string]error) @@ -587,6 +618,7 @@ GATHER_RESULTS: a.logger.Trace("successfully encrypted value", "encryption seal wrappers", len(slots), "total enabled seal wrappers", len(a.GetEnabledSealWrappersByPriority())) + ret := &MultiWrapValue{ Generation: a.Generation(), Slots: slots, @@ -748,7 +780,7 @@ GATHER_RESULTS: return nil, false, errors.New("context timeout exceeded") } -// tryDecrypt returns the plaintext and a flad indicating whether the decryption was done by the "unwrapSeal" (see +// tryDecrypt returns the plaintext and a flag indicating whether the decryption was done by the "unwrapSeal" (see // sealWrapMigration.Decrypt). func (a *access) tryDecrypt(ctx context.Context, sealWrapper *SealWrapper, ciphertextByKeyId map[string]*wrapping.BlobInfo, options []wrapping.Option) ([]byte, bool, error) { now := time.Now() diff --git a/vault/seal/seal_test.go b/vault/seal/seal_test.go index 07abe465db..c5bc522413 100644 --- a/vault/seal/seal_test.go +++ b/vault/seal/seal_test.go @@ -4,6 +4,9 @@ package seal import ( + "context" + "fmt" + "github.com/stretchr/testify/require" "testing" wrapping "github.com/hashicorp/go-kms-wrapping/v2" @@ -95,3 +98,47 @@ func Test_keyIdSet(t *testing.T) { runTest(tt.name+".setIDs", useSetIds) } } + +// Test_Encrypt_duplicate_keyIds verifies that if two seal wrappers produce the same Key ID, an error +// will be returned for both. +func Test_Encrypt_duplicate_keyIds(t *testing.T) { + ctx := context.Background() + + setId := func(w *SealWrapper, keyId string) { + testWrapper := w.Wrapper.(*ToggleableWrapper).Wrapper.(*wrapping.TestWrapper) + testWrapper.SetKeyId(keyId) + } + + getId := func(w *SealWrapper) string { + id, err := w.Wrapper.KeyId(ctx) + if err != nil { + t.Fatal(err) + } + return id + } + + access, _ := NewTestSeal(&TestSealOpts{WrapperCount: 3}) + + // Set up - make the key IDs the same for the last two wrappers + wrappers := access.GetAllSealWrappersByPriority() + setId(wrappers[1], "this-key-is-duplicated") + setId(wrappers[2], "this-key-is-duplicated") + + // Some sanity checks + require.NotEqual(t, wrappers[0].Name, wrappers[1].Name) + require.NotEqual(t, wrappers[1].Name, wrappers[2].Name) + require.NotEqual(t, getId(wrappers[0]), getId(wrappers[1])) + require.Equal(t, getId(wrappers[1]), getId(wrappers[2])) + + // Encrypt a value + mwv, errorMap := access.Encrypt(ctx, []byte("Rinconete y Cortadillo")) + + // Assertions + require.NotNilf(t, mwv, "seal 0 should have succeeded") + + requireDuplicateErr := func(w *SealWrapper) { + require.ErrorContains(t, errorMap[w.Name], fmt.Sprintf("seal %v has returned duplicate key ID", w.Name)) + } + requireDuplicateErr(wrappers[1]) + requireDuplicateErr(wrappers[2]) +} diff --git a/vault/seal/seal_testing.go b/vault/seal/seal_testing.go index 56a7bcb46c..f376564c20 100644 --- a/vault/seal/seal_testing.go +++ b/vault/seal/seal_testing.go @@ -6,6 +6,7 @@ package seal import ( "context" "fmt" + UUID "github.com/hashicorp/go-uuid" "sync" "github.com/hashicorp/vault/sdk/helper/logging" @@ -17,7 +18,7 @@ import ( type TestSealOpts struct { Logger hclog.Logger StoredKeys StoredKeysSupport - Secret []byte + Secrets [][]byte Name wrapping.WrapperType WrapperCount int Generation uint64 @@ -37,6 +38,29 @@ func NewTestSealOpts(opts *TestSealOpts) *TestSealOpts { // we might at some point need to allow Generation == 0 opts.Generation = 1 } + switch len(opts.Secrets) { + case opts.WrapperCount: + // all good, each wrapper has its own secret + + case 0: + if opts.WrapperCount == 1 { + // If there is only one wrapper, the default TestWrapper behaviour of reversing + // the bytes slice is fine. + opts.Secrets = [][]byte{nil} + } else { + // If there is more than one wrapper, each one needs a different secret + for i := 0; i < opts.WrapperCount; i++ { + uuid, err := UUID.GenerateUUID() + if err != nil { + panic(fmt.Sprintf("error generating secret: %v", err)) + } + opts.Secrets = append(opts.Secrets, []byte(uuid)) + } + } + + default: + panic(fmt.Sprintf("wrong number of secrets %d vs %d wrappers", len(opts.Secrets), opts.WrapperCount)) + } return opts } @@ -46,7 +70,12 @@ func NewTestSeal(opts *TestSealOpts) (Access, []*ToggleableWrapper) { sealWrappers := make([]*SealWrapper, opts.WrapperCount) ctx := context.Background() for i := 0; i < opts.WrapperCount; i++ { - wrappers[i] = &ToggleableWrapper{Wrapper: wrapping.NewTestWrapper(opts.Secret)} + wrapperName := fmt.Sprintf("%s-%d", opts.Name, i+1) + wrappers[i] = &ToggleableWrapper{Wrapper: wrapping.NewTestWrapper(opts.Secrets[i])} + _, err := wrappers[i].Wrapper.SetConfig(context.Background(), wrapping.WithKeyId(wrapperName)) + if err != nil { + panic(err) + } wrapperType, err := wrappers[i].Type(ctx) if err != nil { panic(err) @@ -54,7 +83,7 @@ func NewTestSeal(opts *TestSealOpts) (Access, []*ToggleableWrapper) { sealWrappers[i] = NewSealWrapper( wrappers[i], i+1, - fmt.Sprintf("%s-%d", opts.Name, i+1), + wrapperName, wrapperType.String(), false, true, @@ -75,77 +104,6 @@ type TestSealWrapperOpts struct { WrapperCount int } -func CreateTestSealWrapperOpts(opts *TestSealWrapperOpts) *TestSealWrapperOpts { - if opts == nil { - opts = new(TestSealWrapperOpts) - } - if opts.WrapperCount == 0 { - opts.WrapperCount = 1 - } - if opts.Logger == nil { - opts.Logger = logging.NewVaultLogger(hclog.Debug) - } - return opts -} - -func CreateTestSealWrappers(opts *TestSealWrapperOpts) []*SealWrapper { - opts = CreateTestSealWrapperOpts(opts) - wrappers := make([]*ToggleableWrapper, opts.WrapperCount) - sealWrappers := make([]*SealWrapper, opts.WrapperCount) - ctx := context.Background() - for i := 0; i < opts.WrapperCount; i++ { - wrappers[i] = &ToggleableWrapper{Wrapper: wrapping.NewTestWrapper(opts.Secret)} - wrapperType, err := wrappers[i].Type(ctx) - if err != nil { - panic(err) - } - sealWrappers[i] = NewSealWrapper( - wrappers[i], - i+1, - fmt.Sprintf("%s-%d", opts.Name, i+1), - wrapperType.String(), - false, - true, - ) - } - - return sealWrappers -} - -func NewToggleableTestSeal(opts *TestSealOpts) (Access, []func(error)) { - opts = NewTestSealOpts(opts) - - wrappers := make([]*ToggleableWrapper, opts.WrapperCount) - sealWrappers := make([]*SealWrapper, opts.WrapperCount) - funcs := make([]func(error), opts.WrapperCount) - ctx := context.Background() - for i := 0; i < opts.WrapperCount; i++ { - w := &ToggleableWrapper{Wrapper: wrapping.NewTestWrapper(opts.Secret)} - wrapperType, err := w.Type(ctx) - if err != nil { - panic(err) - } - - wrappers[i] = w - sealWrappers[i] = NewSealWrapper( - wrappers[i], - i+1, - fmt.Sprintf("%s-%d", opts.Name, i+1), - wrapperType.String(), - false, - true, - ) - funcs[i] = w.SetError - } - - sealAccess, err := NewAccessFromSealWrappers(nil, opts.Generation, true, sealWrappers) - if err != nil { - panic(err) - } - - return sealAccess, funcs -} - type ToggleableWrapper struct { wrapping.Wrapper wrapperType *wrapping.WrapperType diff --git a/vault/seal_autoseal_test.go b/vault/seal_autoseal_test.go index d4f52ecc18..e070471a35 100644 --- a/vault/seal_autoseal_test.go +++ b/vault/seal_autoseal_test.go @@ -183,7 +183,7 @@ func TestAutoSeal_HealthCheck(t *testing.T) { metrics.NewGlobal(metricsConf, inmemSink) pBackend := newTestBackend(t) - testSealAccess, setErrs := seal.NewToggleableTestSeal(&seal.TestSealOpts{Name: "health-test"}) + testSealAccess, wrappers := seal.NewTestSeal(&seal.TestSealOpts{Name: "health-test"}) core, _, _ := TestCoreUnsealedWithConfig(t, &CoreConfig{ MetricSink: metricsutil.NewClusterMetricSink("", inmemSink), Physical: pBackend, @@ -195,7 +195,7 @@ func TestAutoSeal_HealthCheck(t *testing.T) { core.seal = autoSeal autoSeal.StartHealthCheck() defer autoSeal.StopHealthCheck() - setErrs[0](errors.New("disconnected")) + wrappers[0].SetError(errors.New("disconnected")) tries := 10 for tries = 10; tries > 0; tries-- { @@ -208,7 +208,7 @@ func TestAutoSeal_HealthCheck(t *testing.T) { t.Fatalf("Expected to detect unhealthy seals") } - setErrs[0](nil) + wrappers[0].SetError(nil) time.Sleep(50 * time.Millisecond) if !autoSeal.Healthy() { t.Fatal("Expected seals to be healthy") @@ -216,8 +216,8 @@ func TestAutoSeal_HealthCheck(t *testing.T) { } func TestAutoSeal_BarrierSealConfigType(t *testing.T) { - singleWrapperAccess, _ := seal.NewToggleableTestSeal(&seal.TestSealOpts{WrapperCount: 1}) - multipleWrapperAccess, _ := seal.NewToggleableTestSeal(&seal.TestSealOpts{WrapperCount: 2}) + singleWrapperAccess, _ := seal.NewTestSeal(&seal.TestSealOpts{WrapperCount: 1}) + multipleWrapperAccess, _ := seal.NewTestSeal(&seal.TestSealOpts{WrapperCount: 2}) require.Equalf(t, singleWrapperAccess.GetAllSealWrappersByPriority()[0].SealConfigType, NewAutoSeal(singleWrapperAccess).BarrierSealConfigType().String(), "autoseals that have a single seal wrapper report that wrapper's as the barrier seal type") diff --git a/vault/seal_testing_util.go b/vault/seal_testing_util.go index 6139ed25eb..ea735b0875 100644 --- a/vault/seal_testing_util.go +++ b/vault/seal_testing_util.go @@ -10,6 +10,8 @@ import ( testing "github.com/mitchellh/go-testing-interface" ) +// NewTestSeal creates a new seal for testing. If you want to use the same seal multiple times, such as for +// a cluster, use NewTestSealFunc instead. func NewTestSeal(t testing.T, opts *seal.TestSealOpts) Seal { t.Helper() opts = seal.NewTestSealOpts(opts) @@ -50,3 +52,27 @@ func NewTestSeal(t testing.T, opts *seal.TestSealOpts) Seal { return NewAutoSeal(access) } } + +// NewTestSealFunc returns a function that creates seals. All such seals will have TestWrappers that +// share the same secret, thus making them equivalent. +func NewTestSealFunc(t testing.T, opts *seal.TestSealOpts) func() Seal { + testSeal := NewTestSeal(t, opts) + + return func() Seal { + return cloneTestSeal(t, testSeal) + } +} + +// CloneTestSeal creates a new test seal that shares the same seal wrappers as `testSeal`. +func cloneTestSeal(t testing.T, testSeal Seal) Seal { + logger := corehelpers.NewTestLogger(t).Named("sealAccess") + + access, err := seal.NewAccessFromSealWrappers(logger, testSeal.GetAccess().Generation(), testSeal.GetAccess().GetSealGenerationInfo().IsRewrapped(), testSeal.GetAccess().GetAllSealWrappersByPriority()) + if err != nil { + t.Fatal("error cloning seal %v", err) + } + if testSeal.StoredKeysSupported() == seal.StoredKeysNotSupported { + return NewDefaultSeal(access) + } + return NewAutoSeal(access) +}