Persist automatic entity merges (#29568)

* Persist automatic entity merges

* Local aliases write in test

* Add identity entity merge unit property test

* N entities merge

* Persist alias duplication fix

---------

Co-authored-by: Paul Banks <pbanks@hashicorp.com>
Co-authored-by: Mike Palmiotto <mike.palmiotto@hashicorp.com>
This commit is contained in:
Bianca 2025-03-06 23:06:20 +01:00 committed by GitHub
parent 735016d653
commit 6eeb228889
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 363 additions and 168 deletions

View File

@ -199,7 +199,8 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
// Look for a matching storage entries and delete them from the list. // Look for a matching storage entries and delete them from the list.
for i := 0; i < len(bucket.Items); i++ { for i := 0; i < len(bucket.Items); i++ {
if _, ok := itemsToRemove[bucket.Items[i].ID]; ok { if _, ok := itemsToRemove[bucket.Items[i].ID]; ok {
bucket.Items[i] = bucket.Items[len(bucket.Items)-1] copy(bucket.Items[i:], bucket.Items[i+1:])
bucket.Items[len(bucket.Items)-1] = nil // allow GC
bucket.Items = bucket.Items[:len(bucket.Items)-1] bucket.Items = bucket.Items[:len(bucket.Items)-1]
// Since we just moved a value to position i we need to // Since we just moved a value to position i we need to

View File

@ -6,6 +6,7 @@ package storagepacker
import ( import (
"context" "context"
"fmt" "fmt"
"math/rand"
"testing" "testing"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -68,6 +69,49 @@ func BenchmarkStoragePacker(b *testing.B) {
} }
} }
func BenchmarkStoragePacker_DeleteMultiple(b *testing.B) {
b.StopTimer()
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
if err != nil {
b.Fatal(err)
}
ctx := context.Background()
// Persist a storage entry
for i := 0; i <= 100000; i++ {
item := &Item{
ID: fmt.Sprintf("item%d", i),
}
err = storagePacker.PutItem(ctx, item)
if err != nil {
b.Fatal(err)
}
// Verify that it can be read
fetchedItem, err := storagePacker.GetItem(item.ID)
if err != nil {
b.Fatal(err)
}
if fetchedItem == nil {
b.Fatalf("failed to read the stored item")
}
if item.ID != fetchedItem.ID {
b.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item.ID, fetchedItem.ID)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
err = storagePacker.DeleteItem(ctx, fmt.Sprintf("item%d", rand.Intn(100000)))
if err != nil {
b.Fatal(err)
}
}
}
func TestStoragePacker(t *testing.T) { func TestStoragePacker(t *testing.T) {
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "") storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
if err != nil { if err != nil {

View File

@ -11,10 +11,12 @@ import (
log "github.com/hashicorp/go-hclog" log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap" ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap"
"github.com/hashicorp/vault/helper/testhelpers/minimal" "github.com/hashicorp/vault/helper/testhelpers/minimal"
"github.com/hashicorp/vault/sdk/helper/ldaputil" "github.com/hashicorp/vault/sdk/helper/ldaputil"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -642,3 +644,24 @@ func addRemoveLdapGroupMember(t *testing.T, cfg *ldaputil.ConfigEntry, userCN st
t.Fatal(err) t.Fatal(err)
} }
} }
func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityIDs []string) *identity.Entity {
t.Helper()
var entity *identity.Entity
// Try fetch each ID and ensure exactly one is present
found := 0
for _, entityID := range entityIDs {
e, err := c.IdentityStore().MemDBEntityByID(entityID, true)
require.NoError(t, err)
if e != nil {
found++
entity = e
}
}
// More than one means they didn't merge as expected!
require.Equal(t, found, 1,
"node %s does not have exactly one duplicate from the set", c.NodeID)
return entity
}

View File

@ -752,7 +752,7 @@ func (i *IdentityStore) invalidateEntityBucket(ctx context.Context, key string)
} }
} }
err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false) _, err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false, false)
if err != nil { if err != nil {
i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err) i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err)
return return
@ -1416,7 +1416,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
} }
// Update MemDB and persist entity object // Update MemDB and persist entity object
err = i.upsertEntityInTxn(ctx, txn, entity, nil, true) _, err = i.upsertEntityInTxn(ctx, txn, entity, nil, true, false)
if err != nil { if err != nil {
return entity, entityCreated, err return entity, entityCreated, err
} }

View File

@ -28,6 +28,7 @@ type ConflictResolver interface {
ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error)
ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error)
ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error)
Reload(ctx context.Context)
} }
// errorResolver is a ConflictResolver that logs a warning message when a // errorResolver is a ConflictResolver that logs a warning message when a
@ -91,6 +92,10 @@ func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Ent
return false, errDuplicateIdentityName return false, errDuplicateIdentityName
} }
// Reload is a no-op for the errorResolver implementation.
func (r *errorResolver) Reload(ctx context.Context) {
}
// duplicateReportingErrorResolver collects duplicate information and optionally // duplicateReportingErrorResolver collects duplicate information and optionally
// logs a report on all the duplicates. We don't embed an errorResolver here // logs a report on all the duplicates. We don't embed an errorResolver here
// because we _don't_ want it's side effect of warning on just some duplicates // because we _don't_ want it's side effect of warning on just some duplicates
@ -144,6 +149,10 @@ func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, pa
return false, errDuplicateIdentityName return false, errDuplicateIdentityName
} }
func (r *duplicateReportingErrorResolver) Reload(ctx context.Context) {
r.seenEntities = make(map[string][]*identity.Entity)
}
type identityDuplicateReportEntry struct { type identityDuplicateReportEntry struct {
artifactType string artifactType string
scope string scope string
@ -429,3 +438,7 @@ func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate
func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) { func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) {
return false, nil return false, nil
} }
// Reload is a no-op for the renameResolver implementation.
func (r *renameResolver) Reload(ctx context.Context) {
}

View File

@ -16,11 +16,9 @@ import (
"github.com/hashicorp/vault/helper/identity" "github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/identity/mfa" "github.com/hashicorp/vault/helper/identity/mfa"
"github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/sdk/logical"
"google.golang.org/protobuf/types/known/anypb"
) )
func entityPathFields() map[string]*framework.FieldSchema { func entityPathFields() map[string]*framework.FieldSchema {
@ -881,7 +879,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
} }
fromEntity, err := i.MemDBEntityByID(fromEntityID, false) fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, false)
if err != nil { if err != nil {
return nil, err, nil return nil, err, nil
} }
@ -984,7 +982,6 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
var fromEntityGroups []*identity.Group var fromEntityGroups []*identity.Group
toEntityAccessors := make(map[string][]string) toEntityAccessors := make(map[string][]string)
for _, alias := range toEntity.Aliases { for _, alias := range toEntity.Aliases {
if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok { if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok {
// While it is not supported to have multiple aliases with the same mount accessor in one entity // While it is not supported to have multiple aliases with the same mount accessor in one entity
@ -1002,7 +999,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
} }
fromEntity, err := i.MemDBEntityByID(fromEntityID, true) fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, true)
if err != nil { if err != nil {
return nil, err, nil return nil, err, nil
} }
@ -1025,13 +1022,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
} }
for _, fromAlias := range fromEntity.Aliases { for _, fromAlias := range fromEntity.Aliases {
// We're going to modify this alias but it's still a pointer to the one in
// MemDB that could be being read by other goroutines even though we might
// be removing from MemDB really shortly...
fromAlias, err = fromAlias.Clone()
if err != nil {
return nil, err, nil
}
// If true, we need to handle conflicts (conflict = both aliases share the same mount accessor) // If true, we need to handle conflicts (conflict = both aliases share the same mount accessor)
if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok { if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok {
for _, toAliasId := range toAliasIds { for _, toAliasId := range toAliasIds {
// When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision // When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision
// for the user, and keep the to_entity alias, merging the from_entity // for the user, and keep the from_entity alias
// This case's code is the same as when the user selects to keep the from_entity alias // This case's code is the same as when the user selects to keep the from_entity alias
// but is kept separate for clarity // but is kept separate for clarity.
if forceMergeAliases { if forceMergeAliases {
i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId) i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId)
err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false) err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false)
@ -1046,8 +1050,8 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
if err != nil { if err != nil {
return nil, fmt.Errorf("aborting entity merge - failed to delete orphaned alias %q during merge into entity %q: %w", fromAlias.ID, toEntity.ID, err), nil return nil, fmt.Errorf("aborting entity merge - failed to delete orphaned alias %q during merge into entity %q: %w", fromAlias.ID, toEntity.ID, err), nil
} }
// Remove the alias from the entity's list in memory too! // Don't need to alter toEntity aliases since we it never contained
toEntity.DeleteAliasByID(toAliasId) // the alias we're deleting.
// Continue to next alias, as there's no alias to merge left in the from_entity // Continue to next alias, as there's no alias to merge left in the from_entity
continue continue
@ -1070,13 +1074,12 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID) fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID)
err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false) // We don't insert into MemDB right now because we'll do that for all the
if err != nil { // aliases we want to end up with at the end to ensure they are inserted
return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil // in the same order as when they load from storage next time.
}
// Add the alias to the desired entity // Add the alias to the desired entity
toEntity.Aliases = append(toEntity.Aliases, fromAlias) toEntity.UpsertAlias(fromAlias)
} }
// If told to, merge policies // If told to, merge policies
@ -1124,6 +1127,30 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
} }
} }
// Normalize Alias order. We do this because we persist NonLocal and Local
// aliases separately and so after next reload local aliases will all come
// after non-local ones. While it's logically equivalent, it makes reasoning
// about merges and determinism very hard if the order of things in MemDB can
// change from one unseal to the next so we are especially careful to ensure
// it's exactly the same whether we just merged or on a subsequent load.
// persistEntities will already split these up and persist them separately, so
// we're kinda duplicating effort and code here but this should't happen often
// so I think it's fine.
nonLocalAliases, localAliases := splitLocalAliases(toEntity)
toEntity.Aliases = append(nonLocalAliases, localAliases...)
// Don't forget to insert aliases into alias table that were part of
// `toEntity` but were not merged above (because they didn't conflict). This
// might re-insert the same aliases we just inserted above again but that's a
// no-op. TODO: maybe we could remove the memdb updates in the loop above and
// have them all be inserted here.
for _, alias := range toEntity.Aliases {
err = i.MemDBUpsertAliasInTxn(txn, alias, false)
if err != nil {
return nil, err, nil
}
}
// Update MemDB with changes to the entity we are merging to // Update MemDB with changes to the entity we are merging to
err = i.MemDBUpsertEntityInTxn(txn, toEntity) err = i.MemDBUpsertEntityInTxn(txn, toEntity)
if err != nil { if err != nil {
@ -1140,16 +1167,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
if persist && !isPerfSecondaryOrStandby { if persist && !isPerfSecondaryOrStandby {
// Persist the entity which we are merging to // Persist the entity which we are merging to
toEntityAsAny, err := anypb.New(toEntity) err = i.persistEntity(ctx, toEntity)
if err != nil {
return nil, err, nil
}
item := &storagepacker.Item{
ID: toEntity.ID,
Message: toEntityAsAny,
}
err = i.entityPacker.PutItem(ctx, item)
if err != nil { if err != nil {
return nil, err, nil return nil, err, nil
} }

View File

@ -359,7 +359,7 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
flags.Count = 2 flags.Count = 2
} }
ids, _, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags) ids, bucketIds, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
if err != nil { if err != nil {
i.logger.Error("error creating duplicate entities", "error", err) i.logger.Error("error creating duplicate entities", "error", err)
return logical.ErrorResponse("error creating duplicate entities"), err return logical.ErrorResponse("error creating duplicate entities"), err
@ -367,7 +367,8 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
return &logical.Response{ return &logical.Response{
Data: map[string]interface{}{ Data: map[string]interface{}{
"entity_ids": ids, "entity_ids": ids,
"bucket_keys": bucketIds,
}, },
}, nil }, nil
} }

View File

@ -889,7 +889,7 @@ func TestOIDC_SignIDToken(t *testing.T) {
txn := c.identityStore.db.Txn(true) txn := c.identityStore.db.Txn(true)
defer txn.Abort() defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) _, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1020,7 +1020,7 @@ func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) {
txn := c.identityStore.db.Txn(true) txn := c.identityStore.db.Txn(true)
defer txn.Abort() defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) _, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1497,7 +1497,7 @@ func TestOIDC_Path_Introspect(t *testing.T) {
txn := c.identityStore.db.Txn(true) txn := c.identityStore.db.Txn(true)
defer txn.Abort() defer txn.Abort()
err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true) _, err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -12,6 +12,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -68,7 +69,7 @@ func TestIdentityStore_DeleteEntityAlias(t *testing.T) {
BucketKey: c.identityStore.entityPacker.BucketKey("testEntityID"), BucketKey: c.identityStore.entityPacker.BucketKey("testEntityID"),
} }
err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false) _, err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false, false)
require.NoError(t, err) require.NoError(t, err)
err = c.identityStore.deleteAliasesInEntityInTxn(txn, entity, []*identity.Alias{alias, alias2}) err = c.identityStore.deleteAliasesInEntityInTxn(txn, entity, []*identity.Alias{alias, alias2})
@ -1422,51 +1423,45 @@ func TestIdentityStoreLoadingIsDeterministic(t *testing.T) {
seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64) seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64)
require.NoError(t, err) require.NoError(t, err)
} }
seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test
defer t.Logf("Test generated with seed: %d", seedval) defer t.Logf("Test generated with seed: %d", seedval)
tests := []struct { tests := map[string]*determinismTestFlags{
name string "error-resolver-primary": {
flags *determinismTestFlags identityDeduplication: false,
}{ secondary: false,
{
name: "error-resolver-primary",
flags: &determinismTestFlags{
identityDeduplication: false,
secondary: false,
seed: seed,
},
}, },
{ "identity-cleanup-primary": {
name: "identity-cleanup-primary", identityDeduplication: true,
flags: &determinismTestFlags{ secondary: false,
identityDeduplication: true,
secondary: false,
seed: seed,
},
},
{
name: "error-resolver-secondary",
flags: &determinismTestFlags{
identityDeduplication: false,
secondary: true,
seed: seed,
},
},
{
name: "identity-cleanup-secondary",
flags: &determinismTestFlags{
identityDeduplication: true,
secondary: true,
seed: seed,
},
}, },
} }
for _, test := range tests { // Hook to add cases that only differ in Enterprise
t.Run(t.Name()+"-"+test.name, func(t *testing.T) { if entIdentityStoreDeterminismSupportsSecondary() {
identityStoreLoadingIsDeterministic(t, test.flags) tests["error-resolver-secondary"] = &determinismTestFlags{
identityDeduplication: false,
secondary: true,
}
tests["identity-cleanup-secondary"] = &determinismTestFlags{
identityDeduplication: true,
secondary: true,
}
}
repeats := 50
for name, flags := range tests {
t.Run(t.Name()+"-"+name, func(t *testing.T) {
// Create a random source specific to this test case so every test case
// starts out from the identical random state given the same seed. We do
// want each iteration to explore different path though so we do it here
// not inside the test func.
seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test
flags.seed = seed
for i := 0; i < repeats; i++ {
identityStoreLoadingIsDeterministic(t, flags)
}
}) })
} }
} }
@ -1538,7 +1533,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed) e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed)
attachAlias(t, e, alias, upme, seed) attachAlias(t, e, alias, upme, seed)
attachAlias(t, e, localAlias, localMe, seed) attachAlias(t, e, localAlias, localMe, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err) require.NoError(t, err)
// Subset of entities get a duplicate alias and/or duplicate local alias. // Subset of entities get a duplicate alias and/or duplicate local alias.
@ -1551,7 +1547,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
for rnd < pDup && dupeNum < 10 { for rnd < pDup && dupeNum < 10 {
e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker, seed) e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker, seed)
attachAlias(t, e, alias, upme, seed) attachAlias(t, e, alias, upme, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err) require.NoError(t, err)
// Toss again to see if we continue // Toss again to see if we continue
rnd = seed.Float64() rnd = seed.Float64()
@ -1563,8 +1560,9 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
for rnd < pDup && dupeNum < 10 { for rnd < pDup && dupeNum < 10 {
e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker, seed) e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker, seed)
attachAlias(t, e, localAlias, localMe, seed) attachAlias(t, e, localAlias, localMe, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err) require.NoError(t, err)
rnd = seed.Float64() rnd = seed.Float64()
dupeNum++ dupeNum++
} }
@ -1572,7 +1570,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
rnd = seed.Float64() rnd = seed.Float64()
for rnd < pDup { for rnd < pDup {
e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed) e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e) err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err) require.NoError(t, err)
rnd = seed.Float64() rnd = seed.Float64()
} }
@ -1624,20 +1622,20 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
} }
// Storage is now primed for the test. // Storage is now primed for the test.
require.NoError(t, c.Seal(rootToken))
require.True(t, c.Sealed())
for _, key := range sealKeys {
unsealed, err := c.Unseal(key)
require.NoError(t, err, "failed unseal on initial assertions")
if unsealed {
break
}
}
require.False(t, c.Sealed())
if identityDeduplication { if identityDeduplication {
// Perform an initial Seal/Unseal with duplicates injected to assert // Perform an initial Seal/Unseal with duplicates injected to assert
// initial state. // initial state.
require.NoError(t, c.Seal(rootToken))
require.True(t, c.Sealed())
for _, key := range sealKeys {
unsealed, err := c.Unseal(key)
require.NoError(t, err, "failed unseal on initial assertions")
if unsealed {
break
}
}
require.False(t, c.Sealed())
// Test out the system backend ActivationFunc wiring // Test out the system backend ActivationFunc wiring
c.FeatureActivationFlags.ActivateInMem(activationflags.IdentityDeduplication, true) c.FeatureActivationFlags.ActivateInMem(activationflags.IdentityDeduplication, true)
@ -1657,11 +1655,15 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
// build a list of human readable ids that we can compare. // build a list of human readable ids that we can compare.
prevLoadedNames := []string{} prevLoadedNames := []string{}
var prevErr error var prevErr error
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
err := c.identityStore.resetDB() err := c.identityStore.resetDB()
require.NoError(t, err) require.NoError(t, err)
logger.Info(" ==> BEGIN LOAD ARTIFACTS", "i", i)
err = c.identityStore.loadArtifacts(ctx, true) err = c.identityStore.loadArtifacts(ctx, true)
if i > 0 { if i > 0 {
require.Equal(t, prevErr, err) require.Equal(t, prevErr, err)
} }
@ -1673,6 +1675,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
case <-c.identityStore.activateDeduplicationDone: case <-c.identityStore.activateDeduplicationDone:
default: default:
} }
logger.Info(" ==> END LOAD ARTIFACTS ", "i", i)
// Identity store should be loaded now. Check it's contents. // Identity store should be loaded now. Check it's contents.
loadedNames := []string{} loadedNames := []string{}
@ -1683,24 +1687,25 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
require.NoError(t, err) require.NoError(t, err)
for item := iter.Next(); item != nil; item = iter.Next() { for item := iter.Next(); item != nil; item = iter.Next() {
e := item.(*identity.Entity) e := item.(*identity.Entity)
loadedNames = append(loadedNames, e.Name) loadedNames = append(loadedNames, e.NamespaceID+"/"+e.Name+"_"+e.ID)
for _, a := range e.Aliases { for _, a := range e.Aliases {
loadedNames = append(loadedNames, a.Name) loadedNames = append(loadedNames, a.MountAccessor+"/"+a.Name+"_"+a.ID)
} }
} }
// This is a non-triviality check to make sure we actually loaded stuff and
// are not just passing because of a bug in the test.
numLoaded := len(loadedNames)
require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i)
// Standalone alias query // Standalone alias query
iter, err = tx.LowerBound(entityAliasesTable, "id", "") iter, err = tx.LowerBound(entityAliasesTable, "id", "")
require.NoError(t, err) require.NoError(t, err)
for item := iter.Next(); item != nil; item = iter.Next() { for item := iter.Next(); item != nil; item = iter.Next() {
a := item.(*identity.Alias) a := item.(*identity.Alias)
loadedNames = append(loadedNames, a.Name) loadedNames = append(loadedNames, "fromAliasTable:"+a.MountAccessor+"/"+a.Name+"_"+a.ID)
} }
// This is a non-triviality check to make sure we actually loaded stuff and
// are not just passing because of a bug in the test.
numLoaded := len(loadedNames)
require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i)
// Groups // Groups
iter, err = tx.LowerBound(groupsTable, "id", "") iter, err = tx.LowerBound(groupsTable, "id", "")
require.NoError(t, err) require.NoError(t, err)
@ -1774,23 +1779,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
// Storage is now primed for the test. // Storage is now primed for the test.
// Setup a logger we can use to capture unseal logs // Setup a logger we can use to capture unseal logs
var unsealLogs []string logBuf, stopCapture := startLogCapture(t, logger)
unsealLogger := &logFn{
fn: func(msg string, args []interface{}) {
pairs := make([]string, 0, len(args)/2)
for pair := range slices.Chunk(args, 2) {
// Yes this will panic if we didn't log an even number of args but thats
// OK because that's a bug!
pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1]))
}
unsealLogs = append(unsealLogs, fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " ")))
},
}
logger.RegisterSink(unsealLogger)
err = c.identityStore.loadArtifacts(ctx, true) err = c.identityStore.loadArtifacts(ctx, true)
stopCapture()
require.NoError(t, err) require.NoError(t, err)
logger.DeregisterSink(unsealLogger)
// Identity store should be loaded now. Check it's contents. // Identity store should be loaded now. Check it's contents.
@ -1799,23 +1793,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
// many of these cases and seems strange to encode in a test that we want // many of these cases and seems strange to encode in a test that we want
// broken behavior! // broken behavior!
numDupes := make(map[string]int) numDupes := make(map[string]int)
uniqueIDs := make(map[string]struct{})
duplicateCountRe := regexp.MustCompile(`(\d+) (different-case( local)? entity alias|entity|group) duplicates found`) duplicateCountRe := regexp.MustCompile(`(\d+) (different-case( local)? entity alias|entity|group) duplicates found`)
// Be sure not to match attributes like alias_id= because there are dupes for _, log := range logBuf.Lines() {
// there. The report lines we care about always have a space before the id
// pair.
propsRe := regexp.MustCompile(`\s(id=(\S+))`)
for _, log := range unsealLogs {
if matches := duplicateCountRe.FindStringSubmatch(log); len(matches) >= 3 { if matches := duplicateCountRe.FindStringSubmatch(log); len(matches) >= 3 {
num, _ := strconv.Atoi(matches[1]) num, _ := strconv.Atoi(matches[1])
numDupes[matches[2]] = num numDupes[matches[2]] = num
} }
if propMatches := propsRe.FindStringSubmatch(log); len(propMatches) >= 3 {
artifactID := propMatches[2]
require.NotContains(t, uniqueIDs, artifactID,
"duplicate ID reported in logs for different artifacts")
uniqueIDs[artifactID] = struct{}{}
}
} }
t.Logf("numDupes: %v", numDupes) t.Logf("numDupes: %v", numDupes)
wantAliases, wantLocalAliases, wantEntities, wantGroups := identityStoreDuplicateReportTestWantDuplicateCounts() wantAliases, wantLocalAliases, wantEntities, wantGroups := identityStoreDuplicateReportTestWantDuplicateCounts()
@ -1825,11 +1808,60 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
require.Equal(t, wantGroups, numDupes["group"]) require.Equal(t, wantGroups, numDupes["group"])
} }
// logFn is a type we used to use here before concurrentLogBuffer was added.
// It's used in other tests in Enterprise so we need to keep it around to avoid
// breaking the build during merge
type logFn struct { type logFn struct {
fn func(msg string, args []interface{}) fn func(msg string, args []interface{})
} }
// Accept implements hclog.SinkAdapter
func (f *logFn) Accept(name string, level hclog.Level, msg string, args ...interface{}) { func (f *logFn) Accept(name string, level hclog.Level, msg string, args ...interface{}) {
f.fn(msg, args) f.fn(msg, args)
} }
// concrrentLogBuffer is a simple hclog sink that captures log output in an
// slice of lines in a goroutine-safe way. Use `startLogCapture` to use it.
type concurrentLogBuffer struct {
m sync.Mutex
b []string
}
// Accept implements hclog.SinkAdapter
func (c *concurrentLogBuffer) Accept(name string, level hclog.Level, msg string, args ...interface{}) {
pairs := make([]string, 0, len(args)/2)
for pair := range slices.Chunk(args, 2) {
// Yes this will panic if we didn't log an even number of args but thats
// OK because that's a bug!
pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1]))
}
line := fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " "))
c.m.Lock()
defer c.m.Unlock()
c.b = append(c.b, line)
}
// Lines returns all the lines logged since starting the capture.
func (c *concurrentLogBuffer) Lines() []string {
c.m.Lock()
defer c.m.Unlock()
return c.b
}
// startLogCaptue attaches a logFn to the logger and returns a channel that will
// be sent each line of log output in a basec flat text format. It returns a
// cancel/stop func that should be called when capture should stop or at the end
// of processing.
func startLogCapture(t *testing.T, logger *corehelpers.TestLogger) (*concurrentLogBuffer, func()) {
t.Helper()
b := &concurrentLogBuffer{
b: make([]string, 0, 1024),
}
logger.RegisterSink(b)
stop := func() {
logger.DeregisterSink(b)
}
return b, stop
}

View File

@ -13,6 +13,16 @@ import (
//go:generate go run github.com/hashicorp/vault/tools/stubmaker //go:generate go run github.com/hashicorp/vault/tools/stubmaker
// entIdentityStoreDeterminismSupportsSecondary is a hack to drop duplicate
// tests in CE where the secondary param will only cause the no-op methods below
// to run which is functionally the same. It would be cleaner to define
// different tests in CE and ENT but it's a table test with customer test-only
// struct types which makes it a massive pain to have ent-ce specific code
// interact with the test arguments.
func entIdentityStoreDeterminismSupportsSecondary() bool {
return false
}
func entIdentityStoreDeterminismSecondaryTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry, seed *rand.Rand) { func entIdentityStoreDeterminismSecondaryTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry, seed *rand.Rand) {
// no op // no op
} }

View File

@ -10,6 +10,7 @@ import (
"math/rand" "math/rand"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -44,14 +45,47 @@ func (i *IdentityStore) loadArtifacts(ctx context.Context, isActive bool) error
return nil return nil
} }
loadFuncEntities := func(context.Context) error {
reload, err := i.loadEntities(ctx, isActive, false)
if err != nil {
return fmt.Errorf("failed to load entities: %w", err)
}
if reload {
// The persistMerges flag is used to fix a non-determinism issue in duplicate entities with global alias merges.
// This does not solve the case for local alias merges.
// Previously, alias merges could inconsistently flip between nodes due to bucket invalidation
// and reprocessing on standbys, leading to divergence from the primary.
// This flag ensures deterministic merges across all nodes by preventing unintended reinsertions
// of previously merged entities during bucket diffs.
persistMerges := false
if !i.localNode.ReplicationState().HasState(
consts.ReplicationDRSecondary|
consts.ReplicationPerformanceSecondary,
) && isActive {
persistMerges = true
}
// Since we're reloading entities, we need to inform the ConflictResolver about it so that it can
// clean up data related to the previous load.
i.conflictResolver.Reload(ctx)
_, err := i.loadEntities(ctx, isActive, persistMerges)
if err != nil {
return fmt.Errorf("failed to load entities: %w", err)
}
}
return nil
}
loadFunc := func(context.Context) error { loadFunc := func(context.Context) error {
i.logger.Debug("loading identity store artifacts with", i.logger.Debug("loading identity store artifacts with",
"case_sensitive", !i.disableLowerCasedNames, "case_sensitive", !i.disableLowerCasedNames,
"conflict_resolver", i.conflictResolver) "conflict_resolver", i.conflictResolver)
if err := i.loadEntities(ctx, isActive); err != nil { if err := loadFuncEntities(ctx); err != nil {
return fmt.Errorf("failed to load entities: %w", err) return fmt.Errorf("failed to load entities: %w", err)
} }
if err := i.loadGroups(ctx, isActive); err != nil { if err := i.loadGroups(ctx, isActive); err != nil {
return fmt.Errorf("failed to load groups: %w", err) return fmt.Errorf("failed to load groups: %w", err)
} }
@ -440,15 +474,19 @@ func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) er
return nil return nil
} }
func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool) error { func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool, persistMerges bool) (bool, error) {
// Accumulate existing entities // Accumulate existing entities
i.logger.Debug("loading entities") i.logger.Debug("loading entities")
existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix) existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan for entities: %w", err) return false, fmt.Errorf("failed to scan for entities: %w", err)
} }
i.logger.Debug("entities collected", "num_existing", len(existing)) i.logger.Debug("entities collected", "num_existing", len(existing))
// Reponsible for flagging callers if entities should be reloaded in case
// entity merges need to be persisted.
var reload atomic.Bool
duplicatedAccessors := make(map[string]struct{}) duplicatedAccessors := make(map[string]struct{})
// Make the channels used for the worker pool. We send the index into existing // Make the channels used for the worker pool. We send the index into existing
// so that we can keep results in the same order as inputs. Note that this is // so that we can keep results in the same order as inputs. Note that this is
@ -585,23 +623,23 @@ LOOP:
if err != nil && !i.disableLowerCasedNames { if err != nil && !i.disableLowerCasedNames {
return err return err
} }
persist := false persist := false
if modified { if modified {
// If we modified the group we need to persist the changes to avoid bugs // If we modified the entity, we need to persist the changes to avoid bugs
// where memDB and storage are out of sync in the future (e.g. after // where memDB and storage are out of sync in the future (e.g. after
// invalidations of other items in the same bucket later). We do this // invalidations of other items in the same bucket later). We do this
// _even_ if `persist=false` because it is in general during unseal but // _even_ if `persist=false` because it is in general during unseal but
// this is exactly when we need to fix these. We must be _really_ // this is exactly when we need to fix these. We must be _really_
// careful to only do this on primary active node though which is the // careful to only do this on primary active node though which is the
// only source of truth that should have write access to groups across a // only source of truth that should have write access to groups across a
// cluster since they are always non-local. Note that we check !Stadby // cluster since they are always non-local. Note that we check !Standby
// and !secondary because we still need to write back even if this is a // and !Secondary because we still need to write back even if this is a
// single cluster with no replication setup and I'm not _sure_ that we // single cluster with no replication.
// report such a cluster as a primary.
if !i.localNode.ReplicationState().HasState( if !i.localNode.ReplicationState().HasState(
consts.ReplicationDRSecondary| consts.ReplicationDRSecondary|
consts.ReplicationPerformanceSecondary| consts.ReplicationPerformanceSecondary,
consts.ReplicationPerformanceStandby,
) && isActive { ) && isActive {
persist = true persist = true
} }
@ -631,21 +669,28 @@ LOOP:
tx = i.db.Txn(true) tx = i.db.Txn(true)
defer tx.Abort() defer tx.Abort()
} }
// Only update MemDB and don't hit the storage again
err = i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist) // Only update MemDB and don't hit the storage again unless we are
// merging and on a primary active node.
shouldReload, err := i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist, persistMerges)
if err != nil { if err != nil {
return fmt.Errorf("failed to update entity in MemDB: %w", err) return fmt.Errorf("failed to update entity in MemDB: %w", err)
} }
upsertedItems += toBeUpserted upsertedItems += toBeUpserted
if shouldReload {
reload.CompareAndSwap(false, true)
}
} }
if upsertedItems > 0 { if upsertedItems > 0 {
tx.Commit() tx.Commit()
} }
return nil return nil
} }
err := load(entities) err := load(entities)
if err != nil { if err != nil {
return err return false, err
} }
} }
@ -654,7 +699,7 @@ LOOP:
// Let all go routines finish // Let all go routines finish
wg.Wait() wg.Wait()
if err != nil { if err != nil {
return err return false, err
} }
// Flatten the map into a list of keys, in order to log them // Flatten the map into a list of keys, in order to log them
@ -669,7 +714,7 @@ LOOP:
i.logger.Info("entities restored") i.logger.Info("entities restored")
} }
return nil return reload.Load(), nil
} }
// loadLocalAliasesForEntity upserts local aliases into the entity by retrieving // loadLocalAliasesForEntity upserts local aliases into the entity by retrieving
@ -730,17 +775,18 @@ func getAccessorsOnDuplicateAliases(aliases []*identity.Alias) []string {
// false, then storage will not be updated. When an alias is transferred from // false, then storage will not be updated. When an alias is transferred from
// one entity to another, both the source and destination entities should get // one entity to another, both the source and destination entities should get
// updated, in which case, callers should send in both entity and // updated, in which case, callers should send in both entity and
// previousEntity. // previousEntity. persistMerges is ignored if persist = true but if persist =
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error { // false it allows the caller to request that we persist the data only if it is
// changes by a merge caused by a duplicate alias.
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist, persistMerges bool) (reload bool, err error) {
defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now()) defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now())
var err error
if txn == nil { if txn == nil {
return errors.New("txn is nil") return false, errors.New("txn is nil")
} }
if entity == nil { if entity == nil {
return errors.New("entity is nil") return false, errors.New("entity is nil")
} }
if entity.NamespaceID == "" { if entity.NamespaceID == "" {
@ -748,16 +794,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
} }
if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID { if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID {
return errors.New("entity and previous entity are not in the same namespace") return false, errors.New("entity and previous entity are not in the same namespace")
} }
aliasFactors := make([]string, len(entity.Aliases)) aliasFactors := make([]string, len(entity.Aliases))
for index, alias := range entity.Aliases { for index, alias := range entity.Aliases {
// Verify that alias is not associated to a different one already // Verify that alias is not associated to a different one already
aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false) aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, false, false)
if err != nil { if err != nil {
return err return false, err
} }
if alias.NamespaceID == "" { if alias.NamespaceID == "" {
@ -768,14 +814,14 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
case aliasByFactors == nil: case aliasByFactors == nil:
// Not found, no merging needed, just check namespace // Not found, no merging needed, just check namespace
if alias.NamespaceID != entity.NamespaceID { if alias.NamespaceID != entity.NamespaceID {
return errors.New("alias and entity are not in the same namespace") return false, errors.New("alias and entity are not in the same namespace")
} }
case aliasByFactors.CanonicalID == entity.ID: case aliasByFactors.CanonicalID == entity.ID:
// Lookup found the same entity, so it's already attached to the // Lookup found the same entity, so it's already attached to the
// right place // right place
if aliasByFactors.NamespaceID != entity.NamespaceID { if aliasByFactors.NamespaceID != entity.NamespaceID {
return errors.New("alias from factors and entity are not in the same namespace") return false, errors.New("alias from factors and entity are not in the same namespace")
} }
case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID: case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID:
@ -811,17 +857,18 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
"entity_aliases", entity.Aliases, "entity_aliases", entity.Aliases,
"alias_by_factors", aliasByFactors) "alias_by_factors", aliasByFactors)
respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persist) persistMerge := persist || persistMerges
respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persistMerge)
switch { switch {
case respErr != nil: case respErr != nil:
return respErr return false, respErr
case intErr != nil: case intErr != nil:
return intErr return false, intErr
} }
// The entity and aliases will be loaded into memdb and persisted // The entity and aliases will be loaded into memdb and persisted
// as a result of the merge, so we are done here // as a result of the merge, so we are done here
return nil return true, nil
} }
// This is subtle. We want to call `ResolveAliases` so that the resolver can // This is subtle. We want to call `ResolveAliases` so that the resolver can
@ -852,13 +899,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
// especially desirable to me, but we'd rather not change behavior for now. // especially desirable to me, but we'd rather not change behavior for now.
if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) && if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) &&
conflictErr != nil && !i.disableLowerCasedNames { conflictErr != nil && !i.disableLowerCasedNames {
return conflictErr return false, conflictErr
} }
// Insert or update alias in MemDB using the transaction created above // Insert or update alias in MemDB using the transaction created above
err = i.MemDBUpsertAliasInTxn(txn, alias, false) err = i.MemDBUpsertAliasInTxn(txn, alias, false)
if err != nil { if err != nil {
return err return false, err
} }
aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor
@ -868,13 +915,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
if previousEntity != nil { if previousEntity != nil {
err = i.MemDBUpsertEntityInTxn(txn, previousEntity) err = i.MemDBUpsertEntityInTxn(txn, previousEntity)
if err != nil { if err != nil {
return err return false, err
} }
if persist { if persist {
// Persist the previous entity object // Persist the previous entity object
if err := i.persistEntity(ctx, previousEntity); err != nil { if err := i.persistEntity(ctx, previousEntity); err != nil {
return err return false, err
} }
} }
} }
@ -882,16 +929,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
// Insert or update entity in MemDB using the transaction created above // Insert or update entity in MemDB using the transaction created above
err = i.MemDBUpsertEntityInTxn(txn, entity) err = i.MemDBUpsertEntityInTxn(txn, entity)
if err != nil { if err != nil {
return err return false, err
} }
if persist { if persist {
if err := i.persistEntity(ctx, entity); err != nil { if err := i.persistEntity(ctx, entity); err != nil {
return err return false, err
} }
} }
return nil return false, nil
} }
func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) { func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) {
@ -970,7 +1017,7 @@ func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.A
if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil { if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil {
return nil, err return nil, err
} }
if err := i.upsertEntityInTxn(ctx, txn, entity, nil, false); err != nil { if _, err := i.upsertEntityInTxn(ctx, txn, entity, nil, false, false); err != nil {
return nil, err return nil, err
} }
txn.Commit() txn.Commit()
@ -1004,6 +1051,21 @@ func (i *IdentityStore) cacheTemporaryEntity(ctx context.Context, entity *identi
return nil return nil
} }
func splitLocalAliases(entity *identity.Entity) ([]*identity.Alias, []*identity.Alias) {
var localAliases []*identity.Alias
var nonLocalAliases []*identity.Alias
for _, alias := range entity.Aliases {
if alias.Local {
localAliases = append(localAliases, alias)
} else {
nonLocalAliases = append(nonLocalAliases, alias)
}
}
return nonLocalAliases, localAliases
}
func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error { func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error {
// If the entity that is passed into this function is resulting from a memdb // If the entity that is passed into this function is resulting from a memdb
// query without cloning, then modifying it will result in a direct DB edit, // query without cloning, then modifying it will result in a direct DB edit,
@ -1016,16 +1078,7 @@ func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Enti
} }
// Separate the local and non-local aliases. // Separate the local and non-local aliases.
var localAliases []*identity.Alias nonLocalAliases, localAliases := splitLocalAliases(entity)
var nonLocalAliases []*identity.Alias
for _, alias := range entity.Aliases {
switch alias.Local {
case true:
localAliases = append(localAliases, alias)
default:
nonLocalAliases = append(nonLocalAliases, alias)
}
}
// Store the entity with non-local aliases. // Store the entity with non-local aliases.
entity.Aliases = nonLocalAliases entity.Aliases = nonLocalAliases
@ -1076,7 +1129,7 @@ func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entit
txn := i.db.Txn(true) txn := i.db.Txn(true)
defer txn.Abort() defer txn.Abort()
err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist) _, err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist, false)
if err != nil { if err != nil {
return err return err
} }