From 6eeb2288892378a026421c8c2c32288aa33a3e4a Mon Sep 17 00:00:00 2001 From: Bianca <48203644+biazmoreira@users.noreply.github.com> Date: Thu, 6 Mar 2025 23:06:20 +0100 Subject: [PATCH] Persist automatic entity merges (#29568) * Persist automatic entity merges * Local aliases write in test * Add identity entity merge unit property test * N entities merge * Persist alias duplication fix --------- Co-authored-by: Paul Banks Co-authored-by: Mike Palmiotto --- helper/storagepacker/storagepacker.go | 3 +- helper/storagepacker/storagepacker_test.go | 44 ++++ .../external_tests/identity/identity_test.go | 23 ++ vault/identity_store.go | 4 +- vault/identity_store_conflicts.go | 13 ++ vault/identity_store_entities.go | 66 ++++-- vault/identity_store_injector_testonly.go | 5 +- vault/identity_store_oidc_test.go | 6 +- vault/identity_store_test.go | 208 ++++++++++-------- vault/identity_store_test_stubs_oss.go | 10 + vault/identity_store_util.go | 149 +++++++++---- 11 files changed, 363 insertions(+), 168 deletions(-) diff --git a/helper/storagepacker/storagepacker.go b/helper/storagepacker/storagepacker.go index 219049b1bb..b89fbdca7c 100644 --- a/helper/storagepacker/storagepacker.go +++ b/helper/storagepacker/storagepacker.go @@ -199,7 +199,8 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo // Look for a matching storage entries and delete them from the list. 
for i := 0; i < len(bucket.Items); i++ { if _, ok := itemsToRemove[bucket.Items[i].ID]; ok { - bucket.Items[i] = bucket.Items[len(bucket.Items)-1] + copy(bucket.Items[i:], bucket.Items[i+1:]) + bucket.Items[len(bucket.Items)-1] = nil // allow GC bucket.Items = bucket.Items[:len(bucket.Items)-1] // Since we just moved a value to position i we need to diff --git a/helper/storagepacker/storagepacker_test.go b/helper/storagepacker/storagepacker_test.go index d1f4f66e74..90aca0457a 100644 --- a/helper/storagepacker/storagepacker_test.go +++ b/helper/storagepacker/storagepacker_test.go @@ -6,6 +6,7 @@ package storagepacker import ( "context" "fmt" + "math/rand" "testing" "github.com/golang/protobuf/proto" @@ -68,6 +69,49 @@ func BenchmarkStoragePacker(b *testing.B) { } } +func BenchmarkStoragePacker_DeleteMultiple(b *testing.B) { + b.StopTimer() + storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "") + if err != nil { + b.Fatal(err) + } + + ctx := context.Background() + + // Persist a storage entry + for i := 0; i <= 100000; i++ { + item := &Item{ + ID: fmt.Sprintf("item%d", i), + } + + err = storagePacker.PutItem(ctx, item) + if err != nil { + b.Fatal(err) + } + + // Verify that it can be read + fetchedItem, err := storagePacker.GetItem(item.ID) + if err != nil { + b.Fatal(err) + } + if fetchedItem == nil { + b.Fatalf("failed to read the stored item") + } + + if item.ID != fetchedItem.ID { + b.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item.ID, fetchedItem.ID) + } + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + err = storagePacker.DeleteItem(ctx, fmt.Sprintf("item%d", rand.Intn(100000))) + if err != nil { + b.Fatal(err) + } + } +} + func TestStoragePacker(t *testing.T) { storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "") if err != nil { diff --git a/vault/external_tests/identity/identity_test.go 
b/vault/external_tests/identity/identity_test.go index a816d97b6c..ed8117d2db 100644 --- a/vault/external_tests/identity/identity_test.go +++ b/vault/external_tests/identity/identity_test.go @@ -11,10 +11,12 @@ import ( log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/namespace" ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap" "github.com/hashicorp/vault/helper/testhelpers/minimal" "github.com/hashicorp/vault/sdk/helper/ldaputil" + "github.com/hashicorp/vault/vault" "github.com/stretchr/testify/require" ) @@ -642,3 +644,24 @@ func addRemoveLdapGroupMember(t *testing.T, cfg *ldaputil.ConfigEntry, userCN st t.Fatal(err) } } + +func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityIDs []string) *identity.Entity { + t.Helper() + + var entity *identity.Entity + + // Try fetch each ID and ensure exactly one is present + found := 0 + for _, entityID := range entityIDs { + e, err := c.IdentityStore().MemDBEntityByID(entityID, true) + require.NoError(t, err) + if e != nil { + found++ + entity = e + } + } + // More than one means they didn't merge as expected! + require.Equal(t, found, 1, + "node %s does not have exactly one duplicate from the set", c.NodeID) + return entity +} diff --git a/vault/identity_store.go b/vault/identity_store.go index 9763116d7c..d2769a04d6 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -752,7 +752,7 @@ func (i *IdentityStore) invalidateEntityBucket(ctx context.Context, key string) } } - err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false) + _, err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false, false) if err != nil { i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err) return @@ -1416,7 +1416,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical. 
} // Update MemDB and persist entity object - err = i.upsertEntityInTxn(ctx, txn, entity, nil, true) + _, err = i.upsertEntityInTxn(ctx, txn, entity, nil, true, false) if err != nil { return entity, entityCreated, err } diff --git a/vault/identity_store_conflicts.go b/vault/identity_store_conflicts.go index ec3729bbe0..3be5fee7ef 100644 --- a/vault/identity_store_conflicts.go +++ b/vault/identity_store_conflicts.go @@ -28,6 +28,7 @@ type ConflictResolver interface { ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) + Reload(ctx context.Context) } // errorResolver is a ConflictResolver that logs a warning message when a @@ -91,6 +92,10 @@ func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Ent return false, errDuplicateIdentityName } +// Reload is a no-op for the errorResolver implementation. +func (r *errorResolver) Reload(ctx context.Context) { +} + // duplicateReportingErrorResolver collects duplicate information and optionally // logs a report on all the duplicates. 
We don't embed an errorResolver here // because we _don't_ want it's side effect of warning on just some duplicates @@ -144,6 +149,10 @@ func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, pa return false, errDuplicateIdentityName } +func (r *duplicateReportingErrorResolver) Reload(ctx context.Context) { + r.seenEntities = make(map[string][]*identity.Entity) +} + type identityDuplicateReportEntry struct { artifactType string scope string @@ -429,3 +438,7 @@ func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) { return false, nil } + +// Reload is a no-op for the renameResolver implementation. +func (r *renameResolver) Reload(ctx context.Context) { +} diff --git a/vault/identity_store_entities.go b/vault/identity_store_entities.go index 8f668023f5..001888d55b 100644 --- a/vault/identity_store_entities.go +++ b/vault/identity_store_entities.go @@ -16,11 +16,9 @@ import ( "github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/identity/mfa" "github.com/hashicorp/vault/helper/namespace" - "github.com/hashicorp/vault/helper/storagepacker" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/logical" - "google.golang.org/protobuf/types/known/anypb" ) func entityPathFields() map[string]*framework.FieldSchema { @@ -881,7 +879,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil } - fromEntity, err := i.MemDBEntityByID(fromEntityID, false) + fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, false) if err != nil { return nil, err, nil } @@ -984,7 +982,6 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit var fromEntityGroups 
[]*identity.Group toEntityAccessors := make(map[string][]string) - for _, alias := range toEntity.Aliases { if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok { // While it is not supported to have multiple aliases with the same mount accessor in one entity @@ -1002,7 +999,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil } - fromEntity, err := i.MemDBEntityByID(fromEntityID, true) + fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, true) if err != nil { return nil, err, nil } @@ -1025,13 +1022,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit } for _, fromAlias := range fromEntity.Aliases { + // We're going to modify this alias but it's still a pointer to the one in + // MemDB that could be being read by other goroutines even though we might + // be removing from MemDB really shortly... + fromAlias, err = fromAlias.Clone() + if err != nil { + return nil, err, nil + } // If true, we need to handle conflicts (conflict = both aliases share the same mount accessor) if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok { for _, toAliasId := range toAliasIds { // When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision - // for the user, and keep the to_entity alias, merging the from_entity + // for the user, and keep the from_entity alias // This case's code is the same as when the user selects to keep the from_entity alias - // but is kept separate for clarity + // but is kept separate for clarity. 
if forceMergeAliases { i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) @@ -1046,8 +1050,8 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit if err != nil { return nil, fmt.Errorf("aborting entity merge - failed to delete orphaned alias %q during merge into entity %q: %w", fromAlias.ID, toEntity.ID, err), nil } - // Remove the alias from the entity's list in memory too! - toEntity.DeleteAliasByID(toAliasId) + // Don't need to alter toEntity aliases since we it never contained + // the alias we're deleting. // Continue to next alias, as there's no alias to merge left in the from_entity continue @@ -1070,13 +1074,12 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID) - err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false) - if err != nil { - return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil - } + // We don't insert into MemDB right now because we'll do that for all the + // aliases we want to end up with at the end to ensure they are inserted + // in the same order as when they load from storage next time. // Add the alias to the desired entity - toEntity.Aliases = append(toEntity.Aliases, fromAlias) + toEntity.UpsertAlias(fromAlias) } // If told to, merge policies @@ -1124,6 +1127,30 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit } } + // Normalize Alias order. We do this because we persist NonLocal and Local + // aliases separately and so after next reload local aliases will all come + // after non-local ones. 
While it's logically equivalent, it makes reasoning + // about merges and determinism very hard if the order of things in MemDB can + // change from one unseal to the next so we are especially careful to ensure + // it's exactly the same whether we just merged or on a subsequent load. + // persistEntities will already split these up and persist them separately, so + // we're kinda duplicating effort and code here but this should't happen often + // so I think it's fine. + nonLocalAliases, localAliases := splitLocalAliases(toEntity) + toEntity.Aliases = append(nonLocalAliases, localAliases...) + + // Don't forget to insert aliases into alias table that were part of + // `toEntity` but were not merged above (because they didn't conflict). This + // might re-insert the same aliases we just inserted above again but that's a + // no-op. TODO: maybe we could remove the memdb updates in the loop above and + // have them all be inserted here. + for _, alias := range toEntity.Aliases { + err = i.MemDBUpsertAliasInTxn(txn, alias, false) + if err != nil { + return nil, err, nil + } + } + // Update MemDB with changes to the entity we are merging to err = i.MemDBUpsertEntityInTxn(txn, toEntity) if err != nil { @@ -1140,16 +1167,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit if persist && !isPerfSecondaryOrStandby { // Persist the entity which we are merging to - toEntityAsAny, err := anypb.New(toEntity) - if err != nil { - return nil, err, nil - } - item := &storagepacker.Item{ - ID: toEntity.ID, - Message: toEntityAsAny, - } - - err = i.entityPacker.PutItem(ctx, item) + err = i.persistEntity(ctx, toEntity) if err != nil { return nil, err, nil } diff --git a/vault/identity_store_injector_testonly.go b/vault/identity_store_injector_testonly.go index dcd7a53c19..5cb4c0780f 100644 --- a/vault/identity_store_injector_testonly.go +++ b/vault/identity_store_injector_testonly.go @@ -359,7 +359,7 @@ func (i *IdentityStore) 
createDuplicateEntityAliases() framework.OperationFunc { flags.Count = 2 } - ids, _, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags) + ids, bucketIds, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags) if err != nil { i.logger.Error("error creating duplicate entities", "error", err) return logical.ErrorResponse("error creating duplicate entities"), err @@ -367,7 +367,8 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc { return &logical.Response{ Data: map[string]interface{}{ - "entity_ids": ids, + "entity_ids": ids, + "bucket_keys": bucketIds, }, }, nil } diff --git a/vault/identity_store_oidc_test.go b/vault/identity_store_oidc_test.go index 652b1f1492..c56bd439da 100644 --- a/vault/identity_store_oidc_test.go +++ b/vault/identity_store_oidc_test.go @@ -889,7 +889,7 @@ func TestOIDC_SignIDToken(t *testing.T) { txn := c.identityStore.db.Txn(true) defer txn.Abort() - err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) + _, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false) if err != nil { t.Fatal(err) } @@ -1020,7 +1020,7 @@ func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) { txn := c.identityStore.db.Txn(true) defer txn.Abort() - err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) + _, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false) if err != nil { t.Fatal(err) } @@ -1497,7 +1497,7 @@ func TestOIDC_Path_Introspect(t *testing.T) { txn := c.identityStore.db.Txn(true) defer txn.Abort() - err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) + _, err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false) if err != nil { t.Fatal(err) } diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index 1dcde690f8..8156ced7f5 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -12,6 +12,7 @@ import ( "slices" "strconv" "strings" + 
"sync" "testing" "time" @@ -68,7 +69,7 @@ func TestIdentityStore_DeleteEntityAlias(t *testing.T) { BucketKey: c.identityStore.entityPacker.BucketKey("testEntityID"), } - err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false) + _, err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false, false) require.NoError(t, err) err = c.identityStore.deleteAliasesInEntityInTxn(txn, entity, []*identity.Alias{alias, alias2}) @@ -1422,51 +1423,45 @@ func TestIdentityStoreLoadingIsDeterministic(t *testing.T) { seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64) require.NoError(t, err) } - seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test defer t.Logf("Test generated with seed: %d", seedval) - tests := []struct { - name string - flags *determinismTestFlags - }{ - { - name: "error-resolver-primary", - flags: &determinismTestFlags{ - identityDeduplication: false, - secondary: false, - seed: seed, - }, + tests := map[string]*determinismTestFlags{ + "error-resolver-primary": { + identityDeduplication: false, + secondary: false, }, - { - name: "identity-cleanup-primary", - flags: &determinismTestFlags{ - identityDeduplication: true, - secondary: false, - seed: seed, - }, - }, - - { - name: "error-resolver-secondary", - flags: &determinismTestFlags{ - identityDeduplication: false, - secondary: true, - seed: seed, - }, - }, - { - name: "identity-cleanup-secondary", - flags: &determinismTestFlags{ - identityDeduplication: true, - secondary: true, - seed: seed, - }, + "identity-cleanup-primary": { + identityDeduplication: true, + secondary: false, }, } - for _, test := range tests { - t.Run(t.Name()+"-"+test.name, func(t *testing.T) { - identityStoreLoadingIsDeterministic(t, test.flags) + // Hook to add cases that only differ in Enterprise + if entIdentityStoreDeterminismSupportsSecondary() { + tests["error-resolver-secondary"] = &determinismTestFlags{ + 
identityDeduplication: false, + secondary: true, + } + tests["identity-cleanup-secondary"] = &determinismTestFlags{ + identityDeduplication: true, + secondary: true, + } + } + + repeats := 50 + + for name, flags := range tests { + t.Run(t.Name()+"-"+name, func(t *testing.T) { + // Create a random source specific to this test case so every test case + // starts out from the identical random state given the same seed. We do + // want each iteration to explore different path though so we do it here + // not inside the test func. + seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test + flags.seed = seed + + for i := 0; i < repeats; i++ { + identityStoreLoadingIsDeterministic(t, flags) + } }) } } @@ -1538,7 +1533,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed) attachAlias(t, e, alias, upme, seed) attachAlias(t, e, localAlias, localMe, seed) - err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) + + err = c.identityStore.persistEntity(ctx, e) require.NoError(t, err) // Subset of entities get a duplicate alias and/or duplicate local alias. 
@@ -1551,7 +1547,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla for rnd < pDup && dupeNum < 10 { e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker, seed) attachAlias(t, e, alias, upme, seed) - err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) + + err = c.identityStore.persistEntity(ctx, e) require.NoError(t, err) // Toss again to see if we continue rnd = seed.Float64() @@ -1563,8 +1560,9 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla for rnd < pDup && dupeNum < 10 { e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker, seed) attachAlias(t, e, localAlias, localMe, seed) - err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) + err = c.identityStore.persistEntity(ctx, e) require.NoError(t, err) + rnd = seed.Float64() dupeNum++ } @@ -1572,7 +1570,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla rnd = seed.Float64() for rnd < pDup { e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed) - err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) + err = c.identityStore.persistEntity(ctx, e) require.NoError(t, err) rnd = seed.Float64() } @@ -1624,20 +1622,20 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla } // Storage is now primed for the test. + require.NoError(t, c.Seal(rootToken)) + require.True(t, c.Sealed()) + for _, key := range sealKeys { + unsealed, err := c.Unseal(key) + require.NoError(t, err, "failed unseal on initial assertions") + if unsealed { + break + } + } + require.False(t, c.Sealed()) if identityDeduplication { // Perform an initial Seal/Unseal with duplicates injected to assert // initial state. 
- require.NoError(t, c.Seal(rootToken)) - require.True(t, c.Sealed()) - for _, key := range sealKeys { - unsealed, err := c.Unseal(key) - require.NoError(t, err, "failed unseal on initial assertions") - if unsealed { - break - } - } - require.False(t, c.Sealed()) // Test out the system backend ActivationFunc wiring c.FeatureActivationFlags.ActivateInMem(activationflags.IdentityDeduplication, true) @@ -1657,11 +1655,15 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla // build a list of human readable ids that we can compare. prevLoadedNames := []string{} var prevErr error + for i := 0; i < 10; i++ { err := c.identityStore.resetDB() require.NoError(t, err) + logger.Info(" ==> BEGIN LOAD ARTIFACTS", "i", i) + err = c.identityStore.loadArtifacts(ctx, true) + if i > 0 { require.Equal(t, prevErr, err) } @@ -1673,6 +1675,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla case <-c.identityStore.activateDeduplicationDone: default: } + logger.Info(" ==> END LOAD ARTIFACTS ", "i", i) + // Identity store should be loaded now. Check it's contents. loadedNames := []string{} @@ -1683,24 +1687,25 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla require.NoError(t, err) for item := iter.Next(); item != nil; item = iter.Next() { e := item.(*identity.Entity) - loadedNames = append(loadedNames, e.Name) + loadedNames = append(loadedNames, e.NamespaceID+"/"+e.Name+"_"+e.ID) for _, a := range e.Aliases { - loadedNames = append(loadedNames, a.Name) + loadedNames = append(loadedNames, a.MountAccessor+"/"+a.Name+"_"+a.ID) } } - // This is a non-triviality check to make sure we actually loaded stuff and - // are not just passing because of a bug in the test. 
- numLoaded := len(loadedNames) - require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i) // Standalone alias query iter, err = tx.LowerBound(entityAliasesTable, "id", "") require.NoError(t, err) for item := iter.Next(); item != nil; item = iter.Next() { a := item.(*identity.Alias) - loadedNames = append(loadedNames, a.Name) + loadedNames = append(loadedNames, "fromAliasTable:"+a.MountAccessor+"/"+a.Name+"_"+a.ID) } + // This is a non-triviality check to make sure we actually loaded stuff and + // are not just passing because of a bug in the test. + numLoaded := len(loadedNames) + require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i) + // Groups iter, err = tx.LowerBound(groupsTable, "id", "") require.NoError(t, err) @@ -1774,23 +1779,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) { // Storage is now primed for the test. // Setup a logger we can use to capture unseal logs - var unsealLogs []string - unsealLogger := &logFn{ - fn: func(msg string, args []interface{}) { - pairs := make([]string, 0, len(args)/2) - for pair := range slices.Chunk(args, 2) { - // Yes this will panic if we didn't log an even number of args but thats - // OK because that's a bug! - pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1])) - } - unsealLogs = append(unsealLogs, fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " "))) - }, - } + logBuf, stopCapture := startLogCapture(t, logger) - logger.RegisterSink(unsealLogger) err = c.identityStore.loadArtifacts(ctx, true) + stopCapture() + require.NoError(t, err) - logger.DeregisterSink(unsealLogger) // Identity store should be loaded now. Check it's contents. @@ -1799,23 +1793,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) { // many of these cases and seems strange to encode in a test that we want // broken behavior! 
numDupes := make(map[string]int) - uniqueIDs := make(map[string]struct{}) duplicateCountRe := regexp.MustCompile(`(\d+) (different-case( local)? entity alias|entity|group) duplicates found`) - // Be sure not to match attributes like alias_id= because there are dupes - // there. The report lines we care about always have a space before the id - // pair. - propsRe := regexp.MustCompile(`\s(id=(\S+))`) - for _, log := range unsealLogs { + for _, log := range logBuf.Lines() { if matches := duplicateCountRe.FindStringSubmatch(log); len(matches) >= 3 { num, _ := strconv.Atoi(matches[1]) numDupes[matches[2]] = num } - if propMatches := propsRe.FindStringSubmatch(log); len(propMatches) >= 3 { - artifactID := propMatches[2] - require.NotContains(t, uniqueIDs, artifactID, - "duplicate ID reported in logs for different artifacts") - uniqueIDs[artifactID] = struct{}{} - } } t.Logf("numDupes: %v", numDupes) wantAliases, wantLocalAliases, wantEntities, wantGroups := identityStoreDuplicateReportTestWantDuplicateCounts() @@ -1825,11 +1808,60 @@ require.Equal(t, wantGroups, numDupes["group"]) } +// logFn is a type we used to use here before concurrentLogBuffer was added. +// It's used in other tests in Enterprise so we need to keep it around to avoid +// breaking the build during merge type logFn struct { fn func(msg string, args []interface{}) } -// Accept implements hclog.SinkAdapter func (f *logFn) Accept(name string, level hclog.Level, msg string, args ...interface{}) { f.fn(msg, args) } + +// concurrentLogBuffer is a simple hclog sink that captures log output in a +// slice of lines in a goroutine-safe way. Use `startLogCapture` to use it. 
+type concurrentLogBuffer struct { + m sync.Mutex + b []string +} + +// Accept implements hclog.SinkAdapter +func (c *concurrentLogBuffer) Accept(name string, level hclog.Level, msg string, args ...interface{}) { + pairs := make([]string, 0, len(args)/2) + for pair := range slices.Chunk(args, 2) { + // Yes this will panic if we didn't log an even number of args but that's + // OK because that's a bug! + pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1])) + } + line := fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " ")) + + c.m.Lock() + defer c.m.Unlock() + + c.b = append(c.b, line) +} + +// Lines returns all the lines logged since starting the capture. +func (c *concurrentLogBuffer) Lines() []string { + c.m.Lock() + defer c.m.Unlock() + + return c.b +} + +// startLogCapture attaches a concurrentLogBuffer sink to the logger and +// returns it so captured log lines can be read in a basic flat text format. It +// also returns a cancel/stop func that should be called when capture should +// stop or at the end of processing. +func startLogCapture(t *testing.T, logger *corehelpers.TestLogger) (*concurrentLogBuffer, func()) { + t.Helper() + b := &concurrentLogBuffer{ + b: make([]string, 0, 1024), + } + logger.RegisterSink(b) + stop := func() { + logger.DeregisterSink(b) + } + return b, stop +} diff --git a/vault/identity_store_test_stubs_oss.go b/vault/identity_store_test_stubs_oss.go index 02a78851a6..3e3a843adf 100644 --- a/vault/identity_store_test_stubs_oss.go +++ b/vault/identity_store_test_stubs_oss.go @@ -13,6 +13,16 @@ import ( //go:generate go run github.com/hashicorp/vault/tools/stubmaker +// entIdentityStoreDeterminismSupportsSecondary is a hack to drop duplicate +// tests in CE where the secondary param will only cause the no-op methods below +// to run which is functionally the same. 
It would be cleaner to define +// different tests in CE and ENT but it's a table test with customer test-only +// struct types which makes it a massive pain to have ent-ce specific code +// interact with the test arguments. +func entIdentityStoreDeterminismSupportsSecondary() bool { + return false +} + func entIdentityStoreDeterminismSecondaryTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry, seed *rand.Rand) { // no op } diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 193ed9db2b..f520b4785e 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -10,6 +10,7 @@ import ( "math/rand" "strings" "sync" + "sync/atomic" "testing" "time" @@ -44,14 +45,47 @@ func (i *IdentityStore) loadArtifacts(ctx context.Context, isActive bool) error return nil } + loadFuncEntities := func(context.Context) error { + reload, err := i.loadEntities(ctx, isActive, false) + if err != nil { + return fmt.Errorf("failed to load entities: %w", err) + } + if reload { + // The persistMerges flag is used to fix a non-determinism issue in duplicate entities with global alias merges. + // This does not solve the case for local alias merges. + // Previously, alias merges could inconsistently flip between nodes due to bucket invalidation + // and reprocessing on standbys, leading to divergence from the primary. + // This flag ensures deterministic merges across all nodes by preventing unintended reinsertions + // of previously merged entities during bucket diffs. + persistMerges := false + if !i.localNode.ReplicationState().HasState( + consts.ReplicationDRSecondary| + consts.ReplicationPerformanceSecondary, + ) && isActive { + persistMerges = true + } + + // Since we're reloading entities, we need to inform the ConflictResolver about it so that it can + // clean up data related to the previous load. 
+ i.conflictResolver.Reload(ctx) + + _, err := i.loadEntities(ctx, isActive, persistMerges) + if err != nil { + return fmt.Errorf("failed to load entities: %w", err) + } + } + return nil + } + loadFunc := func(context.Context) error { i.logger.Debug("loading identity store artifacts with", "case_sensitive", !i.disableLowerCasedNames, "conflict_resolver", i.conflictResolver) - if err := i.loadEntities(ctx, isActive); err != nil { + if err := loadFuncEntities(ctx); err != nil { return fmt.Errorf("failed to load entities: %w", err) } + if err := i.loadGroups(ctx, isActive); err != nil { return fmt.Errorf("failed to load groups: %w", err) } @@ -440,15 +474,19 @@ func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) er return nil } -func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool) error { +func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool, persistMerges bool) (bool, error) { // Accumulate existing entities i.logger.Debug("loading entities") existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix) if err != nil { - return fmt.Errorf("failed to scan for entities: %w", err) + return false, fmt.Errorf("failed to scan for entities: %w", err) } i.logger.Debug("entities collected", "num_existing", len(existing)) + // Reponsible for flagging callers if entities should be reloaded in case + // entity merges need to be persisted. + var reload atomic.Bool + duplicatedAccessors := make(map[string]struct{}) // Make the channels used for the worker pool. We send the index into existing // so that we can keep results in the same order as inputs. 
Note that this is @@ -585,23 +623,23 @@ LOOP: if err != nil && !i.disableLowerCasedNames { return err } + persist := false if modified { - // If we modified the group we need to persist the changes to avoid bugs + // If we modified the entity, we need to persist the changes to avoid bugs // where memDB and storage are out of sync in the future (e.g. after // invalidations of other items in the same bucket later). We do this // _even_ if `persist=false` because it is in general during unseal but // this is exactly when we need to fix these. We must be _really_ // careful to only do this on primary active node though which is the // only source of truth that should have write access to groups across a - // cluster since they are always non-local. Note that we check !Stadby - // and !secondary because we still need to write back even if this is a - // single cluster with no replication setup and I'm not _sure_ that we - // report such a cluster as a primary. + // cluster since they are always non-local. Note that we check !Standby + // and !Secondary because we still need to write back even if this is a + // single cluster with no replication. + if !i.localNode.ReplicationState().HasState( consts.ReplicationDRSecondary| - consts.ReplicationPerformanceSecondary| - consts.ReplicationPerformanceStandby, + consts.ReplicationPerformanceSecondary, ) && isActive { persist = true } @@ -631,21 +669,28 @@ LOOP: tx = i.db.Txn(true) defer tx.Abort() } - // Only update MemDB and don't hit the storage again - err = i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist) + + // Only update MemDB and don't hit the storage again unless we are + // merging and on a primary active node. 
+ shouldReload, err := i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist, persistMerges) if err != nil { return fmt.Errorf("failed to update entity in MemDB: %w", err) } upsertedItems += toBeUpserted + + if shouldReload { + reload.CompareAndSwap(false, true) + } } if upsertedItems > 0 { tx.Commit() } return nil } + err := load(entities) if err != nil { - return err + return false, err } } @@ -654,7 +699,7 @@ LOOP: // Let all go routines finish wg.Wait() if err != nil { - return err + return false, err } // Flatten the map into a list of keys, in order to log them @@ -669,7 +714,7 @@ LOOP: i.logger.Info("entities restored") } - return nil + return reload.Load(), nil } // loadLocalAliasesForEntity upserts local aliases into the entity by retrieving @@ -730,17 +775,18 @@ func getAccessorsOnDuplicateAliases(aliases []*identity.Alias) []string { // false, then storage will not be updated. When an alias is transferred from // one entity to another, both the source and destination entities should get // updated, in which case, callers should send in both entity and -// previousEntity. +// previousEntity. persistMerges is ignored if persist = true but if persist = +// false it allows the caller to request that we persist the data only if it is +// changed by a merge caused by a duplicate alias.
+func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist, persistMerges bool) (reload bool, err error) { defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now()) - var err error if txn == nil { - return errors.New("txn is nil") + return false, errors.New("txn is nil") } if entity == nil { - return errors.New("entity is nil") + return false, errors.New("entity is nil") } if entity.NamespaceID == "" { @@ -748,16 +794,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e } if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID { - return errors.New("entity and previous entity are not in the same namespace") + return false, errors.New("entity and previous entity are not in the same namespace") } aliasFactors := make([]string, len(entity.Aliases)) for index, alias := range entity.Aliases { // Verify that alias is not associated to a different one already - aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false) + aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, false, false) if err != nil { - return err + return false, err } if alias.NamespaceID == "" { @@ -768,14 +814,14 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e case aliasByFactors == nil: // Not found, no merging needed, just check namespace if alias.NamespaceID != entity.NamespaceID { - return errors.New("alias and entity are not in the same namespace") + return false, errors.New("alias and entity are not in the same namespace") } case aliasByFactors.CanonicalID == entity.ID: // Lookup found the same entity, so it's already attached to the // right place if aliasByFactors.NamespaceID != entity.NamespaceID { - return errors.New("alias from factors and entity are not in the same namespace") + return false, errors.New("alias from factors 
and entity are not in the same namespace") } case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID: @@ -811,17 +857,18 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e "entity_aliases", entity.Aliases, "alias_by_factors", aliasByFactors) - respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persist) + persistMerge := persist || persistMerges + respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persistMerge) switch { case respErr != nil: - return respErr + return false, respErr case intErr != nil: - return intErr + return false, intErr } // The entity and aliases will be loaded into memdb and persisted // as a result of the merge, so we are done here - return nil + return true, nil } // This is subtle. We want to call `ResolveAliases` so that the resolver can @@ -852,13 +899,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e // especially desirable to me, but we'd rather not change behavior for now. 
if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) && conflictErr != nil && !i.disableLowerCasedNames { - return conflictErr + return false, conflictErr } // Insert or update alias in MemDB using the transaction created above err = i.MemDBUpsertAliasInTxn(txn, alias, false) if err != nil { - return err + return false, err } aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor @@ -868,13 +915,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e if previousEntity != nil { err = i.MemDBUpsertEntityInTxn(txn, previousEntity) if err != nil { - return err + return false, err } if persist { // Persist the previous entity object if err := i.persistEntity(ctx, previousEntity); err != nil { - return err + return false, err } } } @@ -882,16 +929,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e // Insert or update entity in MemDB using the transaction created above err = i.MemDBUpsertEntityInTxn(txn, entity) if err != nil { - return err + return false, err } if persist { if err := i.persistEntity(ctx, entity); err != nil { - return err + return false, err } } - return nil + return false, nil } func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) { @@ -970,7 +1017,7 @@ func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.A if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil { return nil, err } - if err := i.upsertEntityInTxn(ctx, txn, entity, nil, false); err != nil { + if _, err := i.upsertEntityInTxn(ctx, txn, entity, nil, false, false); err != nil { return nil, err } txn.Commit() @@ -1004,6 +1051,21 @@ func (i *IdentityStore) cacheTemporaryEntity(ctx context.Context, entity *identi return nil } +func splitLocalAliases(entity *identity.Entity) ([]*identity.Alias, []*identity.Alias) { + var localAliases 
[]*identity.Alias + var nonLocalAliases []*identity.Alias + + for _, alias := range entity.Aliases { + if alias.Local { + localAliases = append(localAliases, alias) + } else { + nonLocalAliases = append(nonLocalAliases, alias) + } + } + + return nonLocalAliases, localAliases +} + func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error { // If the entity that is passed into this function is resulting from a memdb // query without cloning, then modifying it will result in a direct DB edit, @@ -1016,16 +1078,7 @@ func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Enti } // Separate the local and non-local aliases. - var localAliases []*identity.Alias - var nonLocalAliases []*identity.Alias - for _, alias := range entity.Aliases { - switch alias.Local { - case true: - localAliases = append(localAliases, alias) - default: - nonLocalAliases = append(nonLocalAliases, alias) - } - } + nonLocalAliases, localAliases := splitLocalAliases(entity) // Store the entity with non-local aliases. entity.Aliases = nonLocalAliases @@ -1076,7 +1129,7 @@ func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entit txn := i.db.Txn(true) defer txn.Abort() - err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist) + _, err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist, false) if err != nil { return err }