mirror of
https://github.com/hashicorp/vault.git
synced 2025-09-21 05:41:08 +02:00
Persist automatic entity merges (#29568)
* Persist automatic entity merges * Local aliases write in test * Add identity entity merge unit property test * N entities merge * Persist alias duplication fix --------- Co-authored-by: Paul Banks <pbanks@hashicorp.com> Co-authored-by: Mike Palmiotto <mike.palmiotto@hashicorp.com>
This commit is contained in:
parent
735016d653
commit
6eeb228889
@ -199,7 +199,8 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
|
||||
// Look for a matching storage entries and delete them from the list.
|
||||
for i := 0; i < len(bucket.Items); i++ {
|
||||
if _, ok := itemsToRemove[bucket.Items[i].ID]; ok {
|
||||
bucket.Items[i] = bucket.Items[len(bucket.Items)-1]
|
||||
copy(bucket.Items[i:], bucket.Items[i+1:])
|
||||
bucket.Items[len(bucket.Items)-1] = nil // allow GC
|
||||
bucket.Items = bucket.Items[:len(bucket.Items)-1]
|
||||
|
||||
// Since we just moved a value to position i we need to
|
||||
|
@ -6,6 +6,7 @@ package storagepacker
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
@ -68,6 +69,49 @@ func BenchmarkStoragePacker(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStoragePacker_DeleteMultiple(b *testing.B) {
|
||||
b.StopTimer()
|
||||
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Persist a storage entry
|
||||
for i := 0; i <= 100000; i++ {
|
||||
item := &Item{
|
||||
ID: fmt.Sprintf("item%d", i),
|
||||
}
|
||||
|
||||
err = storagePacker.PutItem(ctx, item)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Verify that it can be read
|
||||
fetchedItem, err := storagePacker.GetItem(item.ID)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if fetchedItem == nil {
|
||||
b.Fatalf("failed to read the stored item")
|
||||
}
|
||||
|
||||
if item.ID != fetchedItem.ID {
|
||||
b.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item.ID, fetchedItem.ID)
|
||||
}
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
err = storagePacker.DeleteItem(ctx, fmt.Sprintf("item%d", rand.Intn(100000)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoragePacker(t *testing.T) {
|
||||
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
|
||||
if err != nil {
|
||||
|
@ -11,10 +11,12 @@ import (
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/go-secure-stdlib/strutil"
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/helper/identity"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/minimal"
|
||||
"github.com/hashicorp/vault/sdk/helper/ldaputil"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -642,3 +644,24 @@ func addRemoveLdapGroupMember(t *testing.T, cfg *ldaputil.ConfigEntry, userCN st
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityIDs []string) *identity.Entity {
|
||||
t.Helper()
|
||||
|
||||
var entity *identity.Entity
|
||||
|
||||
// Try fetch each ID and ensure exactly one is present
|
||||
found := 0
|
||||
for _, entityID := range entityIDs {
|
||||
e, err := c.IdentityStore().MemDBEntityByID(entityID, true)
|
||||
require.NoError(t, err)
|
||||
if e != nil {
|
||||
found++
|
||||
entity = e
|
||||
}
|
||||
}
|
||||
// More than one means they didn't merge as expected!
|
||||
require.Equal(t, found, 1,
|
||||
"node %s does not have exactly one duplicate from the set", c.NodeID)
|
||||
return entity
|
||||
}
|
||||
|
@ -752,7 +752,7 @@ func (i *IdentityStore) invalidateEntityBucket(ctx context.Context, key string)
|
||||
}
|
||||
}
|
||||
|
||||
err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false)
|
||||
_, err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false, false)
|
||||
if err != nil {
|
||||
i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err)
|
||||
return
|
||||
@ -1416,7 +1416,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
|
||||
}
|
||||
|
||||
// Update MemDB and persist entity object
|
||||
err = i.upsertEntityInTxn(ctx, txn, entity, nil, true)
|
||||
_, err = i.upsertEntityInTxn(ctx, txn, entity, nil, true, false)
|
||||
if err != nil {
|
||||
return entity, entityCreated, err
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ type ConflictResolver interface {
|
||||
ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error)
|
||||
ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error)
|
||||
ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error)
|
||||
Reload(ctx context.Context)
|
||||
}
|
||||
|
||||
// errorResolver is a ConflictResolver that logs a warning message when a
|
||||
@ -91,6 +92,10 @@ func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Ent
|
||||
return false, errDuplicateIdentityName
|
||||
}
|
||||
|
||||
// Reload is a no-op for the errorResolver implementation.
|
||||
func (r *errorResolver) Reload(ctx context.Context) {
|
||||
}
|
||||
|
||||
// duplicateReportingErrorResolver collects duplicate information and optionally
|
||||
// logs a report on all the duplicates. We don't embed an errorResolver here
|
||||
// because we _don't_ want it's side effect of warning on just some duplicates
|
||||
@ -144,6 +149,10 @@ func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, pa
|
||||
return false, errDuplicateIdentityName
|
||||
}
|
||||
|
||||
func (r *duplicateReportingErrorResolver) Reload(ctx context.Context) {
|
||||
r.seenEntities = make(map[string][]*identity.Entity)
|
||||
}
|
||||
|
||||
type identityDuplicateReportEntry struct {
|
||||
artifactType string
|
||||
scope string
|
||||
@ -429,3 +438,7 @@ func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate
|
||||
func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Reload is a no-op for the renameResolver implementation.
|
||||
func (r *renameResolver) Reload(ctx context.Context) {
|
||||
}
|
||||
|
@ -16,11 +16,9 @@ import (
|
||||
"github.com/hashicorp/vault/helper/identity"
|
||||
"github.com/hashicorp/vault/helper/identity/mfa"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
"github.com/hashicorp/vault/helper/storagepacker"
|
||||
"github.com/hashicorp/vault/sdk/framework"
|
||||
"github.com/hashicorp/vault/sdk/helper/consts"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func entityPathFields() map[string]*framework.FieldSchema {
|
||||
@ -881,7 +879,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
|
||||
}
|
||||
|
||||
fromEntity, err := i.MemDBEntityByID(fromEntityID, false)
|
||||
fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, false)
|
||||
if err != nil {
|
||||
return nil, err, nil
|
||||
}
|
||||
@ -984,7 +982,6 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
var fromEntityGroups []*identity.Group
|
||||
|
||||
toEntityAccessors := make(map[string][]string)
|
||||
|
||||
for _, alias := range toEntity.Aliases {
|
||||
if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok {
|
||||
// While it is not supported to have multiple aliases with the same mount accessor in one entity
|
||||
@ -1002,7 +999,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
|
||||
}
|
||||
|
||||
fromEntity, err := i.MemDBEntityByID(fromEntityID, true)
|
||||
fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, true)
|
||||
if err != nil {
|
||||
return nil, err, nil
|
||||
}
|
||||
@ -1025,13 +1022,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
}
|
||||
|
||||
for _, fromAlias := range fromEntity.Aliases {
|
||||
// We're going to modify this alias but it's still a pointer to the one in
|
||||
// MemDB that could be being read by other goroutines even though we might
|
||||
// be removing from MemDB really shortly...
|
||||
fromAlias, err = fromAlias.Clone()
|
||||
if err != nil {
|
||||
return nil, err, nil
|
||||
}
|
||||
// If true, we need to handle conflicts (conflict = both aliases share the same mount accessor)
|
||||
if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok {
|
||||
for _, toAliasId := range toAliasIds {
|
||||
// When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision
|
||||
// for the user, and keep the to_entity alias, merging the from_entity
|
||||
// for the user, and keep the from_entity alias
|
||||
// This case's code is the same as when the user selects to keep the from_entity alias
|
||||
// but is kept separate for clarity
|
||||
// but is kept separate for clarity.
|
||||
if forceMergeAliases {
|
||||
i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId)
|
||||
err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false)
|
||||
@ -1046,8 +1050,8 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aborting entity merge - failed to delete orphaned alias %q during merge into entity %q: %w", fromAlias.ID, toEntity.ID, err), nil
|
||||
}
|
||||
// Remove the alias from the entity's list in memory too!
|
||||
toEntity.DeleteAliasByID(toAliasId)
|
||||
// Don't need to alter toEntity aliases since we it never contained
|
||||
// the alias we're deleting.
|
||||
|
||||
// Continue to next alias, as there's no alias to merge left in the from_entity
|
||||
continue
|
||||
@ -1070,13 +1074,12 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
|
||||
fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID)
|
||||
|
||||
err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil
|
||||
}
|
||||
// We don't insert into MemDB right now because we'll do that for all the
|
||||
// aliases we want to end up with at the end to ensure they are inserted
|
||||
// in the same order as when they load from storage next time.
|
||||
|
||||
// Add the alias to the desired entity
|
||||
toEntity.Aliases = append(toEntity.Aliases, fromAlias)
|
||||
toEntity.UpsertAlias(fromAlias)
|
||||
}
|
||||
|
||||
// If told to, merge policies
|
||||
@ -1124,6 +1127,30 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize Alias order. We do this because we persist NonLocal and Local
|
||||
// aliases separately and so after next reload local aliases will all come
|
||||
// after non-local ones. While it's logically equivalent, it makes reasoning
|
||||
// about merges and determinism very hard if the order of things in MemDB can
|
||||
// change from one unseal to the next so we are especially careful to ensure
|
||||
// it's exactly the same whether we just merged or on a subsequent load.
|
||||
// persistEntities will already split these up and persist them separately, so
|
||||
// we're kinda duplicating effort and code here but this should't happen often
|
||||
// so I think it's fine.
|
||||
nonLocalAliases, localAliases := splitLocalAliases(toEntity)
|
||||
toEntity.Aliases = append(nonLocalAliases, localAliases...)
|
||||
|
||||
// Don't forget to insert aliases into alias table that were part of
|
||||
// `toEntity` but were not merged above (because they didn't conflict). This
|
||||
// might re-insert the same aliases we just inserted above again but that's a
|
||||
// no-op. TODO: maybe we could remove the memdb updates in the loop above and
|
||||
// have them all be inserted here.
|
||||
for _, alias := range toEntity.Aliases {
|
||||
err = i.MemDBUpsertAliasInTxn(txn, alias, false)
|
||||
if err != nil {
|
||||
return nil, err, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Update MemDB with changes to the entity we are merging to
|
||||
err = i.MemDBUpsertEntityInTxn(txn, toEntity)
|
||||
if err != nil {
|
||||
@ -1140,16 +1167,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
|
||||
|
||||
if persist && !isPerfSecondaryOrStandby {
|
||||
// Persist the entity which we are merging to
|
||||
toEntityAsAny, err := anypb.New(toEntity)
|
||||
if err != nil {
|
||||
return nil, err, nil
|
||||
}
|
||||
item := &storagepacker.Item{
|
||||
ID: toEntity.ID,
|
||||
Message: toEntityAsAny,
|
||||
}
|
||||
|
||||
err = i.entityPacker.PutItem(ctx, item)
|
||||
err = i.persistEntity(ctx, toEntity)
|
||||
if err != nil {
|
||||
return nil, err, nil
|
||||
}
|
||||
|
@ -359,7 +359,7 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
|
||||
flags.Count = 2
|
||||
}
|
||||
|
||||
ids, _, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
|
||||
ids, bucketIds, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
|
||||
if err != nil {
|
||||
i.logger.Error("error creating duplicate entities", "error", err)
|
||||
return logical.ErrorResponse("error creating duplicate entities"), err
|
||||
@ -367,7 +367,8 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
|
||||
|
||||
return &logical.Response{
|
||||
Data: map[string]interface{}{
|
||||
"entity_ids": ids,
|
||||
"entity_ids": ids,
|
||||
"bucket_keys": bucketIds,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -889,7 +889,7 @@ func TestOIDC_SignIDToken(t *testing.T) {
|
||||
|
||||
txn := c.identityStore.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
|
||||
_, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1020,7 +1020,7 @@ func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) {
|
||||
|
||||
txn := c.identityStore.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
|
||||
_, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1497,7 +1497,7 @@ func TestOIDC_Path_Introspect(t *testing.T) {
|
||||
|
||||
txn := c.identityStore.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
|
||||
_, err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -68,7 +69,7 @@ func TestIdentityStore_DeleteEntityAlias(t *testing.T) {
|
||||
BucketKey: c.identityStore.entityPacker.BucketKey("testEntityID"),
|
||||
}
|
||||
|
||||
err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false)
|
||||
_, err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.identityStore.deleteAliasesInEntityInTxn(txn, entity, []*identity.Alias{alias, alias2})
|
||||
@ -1422,51 +1423,45 @@ func TestIdentityStoreLoadingIsDeterministic(t *testing.T) {
|
||||
seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test
|
||||
defer t.Logf("Test generated with seed: %d", seedval)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
flags *determinismTestFlags
|
||||
}{
|
||||
{
|
||||
name: "error-resolver-primary",
|
||||
flags: &determinismTestFlags{
|
||||
identityDeduplication: false,
|
||||
secondary: false,
|
||||
seed: seed,
|
||||
},
|
||||
tests := map[string]*determinismTestFlags{
|
||||
"error-resolver-primary": {
|
||||
identityDeduplication: false,
|
||||
secondary: false,
|
||||
},
|
||||
{
|
||||
name: "identity-cleanup-primary",
|
||||
flags: &determinismTestFlags{
|
||||
identityDeduplication: true,
|
||||
secondary: false,
|
||||
seed: seed,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "error-resolver-secondary",
|
||||
flags: &determinismTestFlags{
|
||||
identityDeduplication: false,
|
||||
secondary: true,
|
||||
seed: seed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "identity-cleanup-secondary",
|
||||
flags: &determinismTestFlags{
|
||||
identityDeduplication: true,
|
||||
secondary: true,
|
||||
seed: seed,
|
||||
},
|
||||
"identity-cleanup-primary": {
|
||||
identityDeduplication: true,
|
||||
secondary: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(t.Name()+"-"+test.name, func(t *testing.T) {
|
||||
identityStoreLoadingIsDeterministic(t, test.flags)
|
||||
// Hook to add cases that only differ in Enterprise
|
||||
if entIdentityStoreDeterminismSupportsSecondary() {
|
||||
tests["error-resolver-secondary"] = &determinismTestFlags{
|
||||
identityDeduplication: false,
|
||||
secondary: true,
|
||||
}
|
||||
tests["identity-cleanup-secondary"] = &determinismTestFlags{
|
||||
identityDeduplication: true,
|
||||
secondary: true,
|
||||
}
|
||||
}
|
||||
|
||||
repeats := 50
|
||||
|
||||
for name, flags := range tests {
|
||||
t.Run(t.Name()+"-"+name, func(t *testing.T) {
|
||||
// Create a random source specific to this test case so every test case
|
||||
// starts out from the identical random state given the same seed. We do
|
||||
// want each iteration to explore different path though so we do it here
|
||||
// not inside the test func.
|
||||
seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test
|
||||
flags.seed = seed
|
||||
|
||||
for i := 0; i < repeats; i++ {
|
||||
identityStoreLoadingIsDeterministic(t, flags)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1538,7 +1533,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed)
|
||||
attachAlias(t, e, alias, upme, seed)
|
||||
attachAlias(t, e, localAlias, localMe, seed)
|
||||
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
|
||||
|
||||
err = c.identityStore.persistEntity(ctx, e)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Subset of entities get a duplicate alias and/or duplicate local alias.
|
||||
@ -1551,7 +1547,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
for rnd < pDup && dupeNum < 10 {
|
||||
e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker, seed)
|
||||
attachAlias(t, e, alias, upme, seed)
|
||||
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
|
||||
|
||||
err = c.identityStore.persistEntity(ctx, e)
|
||||
require.NoError(t, err)
|
||||
// Toss again to see if we continue
|
||||
rnd = seed.Float64()
|
||||
@ -1563,8 +1560,9 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
for rnd < pDup && dupeNum < 10 {
|
||||
e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker, seed)
|
||||
attachAlias(t, e, localAlias, localMe, seed)
|
||||
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
|
||||
err = c.identityStore.persistEntity(ctx, e)
|
||||
require.NoError(t, err)
|
||||
|
||||
rnd = seed.Float64()
|
||||
dupeNum++
|
||||
}
|
||||
@ -1572,7 +1570,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
rnd = seed.Float64()
|
||||
for rnd < pDup {
|
||||
e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed)
|
||||
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
|
||||
err = c.identityStore.persistEntity(ctx, e)
|
||||
require.NoError(t, err)
|
||||
rnd = seed.Float64()
|
||||
}
|
||||
@ -1624,20 +1622,20 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
}
|
||||
|
||||
// Storage is now primed for the test.
|
||||
require.NoError(t, c.Seal(rootToken))
|
||||
require.True(t, c.Sealed())
|
||||
for _, key := range sealKeys {
|
||||
unsealed, err := c.Unseal(key)
|
||||
require.NoError(t, err, "failed unseal on initial assertions")
|
||||
if unsealed {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.False(t, c.Sealed())
|
||||
|
||||
if identityDeduplication {
|
||||
// Perform an initial Seal/Unseal with duplicates injected to assert
|
||||
// initial state.
|
||||
require.NoError(t, c.Seal(rootToken))
|
||||
require.True(t, c.Sealed())
|
||||
for _, key := range sealKeys {
|
||||
unsealed, err := c.Unseal(key)
|
||||
require.NoError(t, err, "failed unseal on initial assertions")
|
||||
if unsealed {
|
||||
break
|
||||
}
|
||||
}
|
||||
require.False(t, c.Sealed())
|
||||
|
||||
// Test out the system backend ActivationFunc wiring
|
||||
c.FeatureActivationFlags.ActivateInMem(activationflags.IdentityDeduplication, true)
|
||||
@ -1657,11 +1655,15 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
// build a list of human readable ids that we can compare.
|
||||
prevLoadedNames := []string{}
|
||||
var prevErr error
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
err := c.identityStore.resetDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
logger.Info(" ==> BEGIN LOAD ARTIFACTS", "i", i)
|
||||
|
||||
err = c.identityStore.loadArtifacts(ctx, true)
|
||||
|
||||
if i > 0 {
|
||||
require.Equal(t, prevErr, err)
|
||||
}
|
||||
@ -1673,6 +1675,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
case <-c.identityStore.activateDeduplicationDone:
|
||||
default:
|
||||
}
|
||||
logger.Info(" ==> END LOAD ARTIFACTS ", "i", i)
|
||||
|
||||
// Identity store should be loaded now. Check it's contents.
|
||||
loadedNames := []string{}
|
||||
|
||||
@ -1683,24 +1687,25 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
|
||||
require.NoError(t, err)
|
||||
for item := iter.Next(); item != nil; item = iter.Next() {
|
||||
e := item.(*identity.Entity)
|
||||
loadedNames = append(loadedNames, e.Name)
|
||||
loadedNames = append(loadedNames, e.NamespaceID+"/"+e.Name+"_"+e.ID)
|
||||
for _, a := range e.Aliases {
|
||||
loadedNames = append(loadedNames, a.Name)
|
||||
loadedNames = append(loadedNames, a.MountAccessor+"/"+a.Name+"_"+a.ID)
|
||||
}
|
||||
}
|
||||
// This is a non-triviality check to make sure we actually loaded stuff and
|
||||
// are not just passing because of a bug in the test.
|
||||
numLoaded := len(loadedNames)
|
||||
require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i)
|
||||
|
||||
// Standalone alias query
|
||||
iter, err = tx.LowerBound(entityAliasesTable, "id", "")
|
||||
require.NoError(t, err)
|
||||
for item := iter.Next(); item != nil; item = iter.Next() {
|
||||
a := item.(*identity.Alias)
|
||||
loadedNames = append(loadedNames, a.Name)
|
||||
loadedNames = append(loadedNames, "fromAliasTable:"+a.MountAccessor+"/"+a.Name+"_"+a.ID)
|
||||
}
|
||||
|
||||
// This is a non-triviality check to make sure we actually loaded stuff and
|
||||
// are not just passing because of a bug in the test.
|
||||
numLoaded := len(loadedNames)
|
||||
require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i)
|
||||
|
||||
// Groups
|
||||
iter, err = tx.LowerBound(groupsTable, "id", "")
|
||||
require.NoError(t, err)
|
||||
@ -1774,23 +1779,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
|
||||
// Storage is now primed for the test.
|
||||
|
||||
// Setup a logger we can use to capture unseal logs
|
||||
var unsealLogs []string
|
||||
unsealLogger := &logFn{
|
||||
fn: func(msg string, args []interface{}) {
|
||||
pairs := make([]string, 0, len(args)/2)
|
||||
for pair := range slices.Chunk(args, 2) {
|
||||
// Yes this will panic if we didn't log an even number of args but thats
|
||||
// OK because that's a bug!
|
||||
pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1]))
|
||||
}
|
||||
unsealLogs = append(unsealLogs, fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " ")))
|
||||
},
|
||||
}
|
||||
logBuf, stopCapture := startLogCapture(t, logger)
|
||||
|
||||
logger.RegisterSink(unsealLogger)
|
||||
err = c.identityStore.loadArtifacts(ctx, true)
|
||||
stopCapture()
|
||||
|
||||
require.NoError(t, err)
|
||||
logger.DeregisterSink(unsealLogger)
|
||||
|
||||
// Identity store should be loaded now. Check it's contents.
|
||||
|
||||
@ -1799,23 +1793,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
|
||||
// many of these cases and seems strange to encode in a test that we want
|
||||
// broken behavior!
|
||||
numDupes := make(map[string]int)
|
||||
uniqueIDs := make(map[string]struct{})
|
||||
duplicateCountRe := regexp.MustCompile(`(\d+) (different-case( local)? entity alias|entity|group) duplicates found`)
|
||||
// Be sure not to match attributes like alias_id= because there are dupes
|
||||
// there. The report lines we care about always have a space before the id
|
||||
// pair.
|
||||
propsRe := regexp.MustCompile(`\s(id=(\S+))`)
|
||||
for _, log := range unsealLogs {
|
||||
for _, log := range logBuf.Lines() {
|
||||
if matches := duplicateCountRe.FindStringSubmatch(log); len(matches) >= 3 {
|
||||
num, _ := strconv.Atoi(matches[1])
|
||||
numDupes[matches[2]] = num
|
||||
}
|
||||
if propMatches := propsRe.FindStringSubmatch(log); len(propMatches) >= 3 {
|
||||
artifactID := propMatches[2]
|
||||
require.NotContains(t, uniqueIDs, artifactID,
|
||||
"duplicate ID reported in logs for different artifacts")
|
||||
uniqueIDs[artifactID] = struct{}{}
|
||||
}
|
||||
}
|
||||
t.Logf("numDupes: %v", numDupes)
|
||||
wantAliases, wantLocalAliases, wantEntities, wantGroups := identityStoreDuplicateReportTestWantDuplicateCounts()
|
||||
@ -1825,11 +1808,60 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
|
||||
require.Equal(t, wantGroups, numDupes["group"])
|
||||
}
|
||||
|
||||
// logFn is a type we used to use here before concurrentLogBuffer was added.
|
||||
// It's used in other tests in Enterprise so we need to keep it around to avoid
|
||||
// breaking the build during merge
|
||||
type logFn struct {
|
||||
fn func(msg string, args []interface{})
|
||||
}
|
||||
|
||||
// Accept implements hclog.SinkAdapter
|
||||
func (f *logFn) Accept(name string, level hclog.Level, msg string, args ...interface{}) {
|
||||
f.fn(msg, args)
|
||||
}
|
||||
|
||||
// concrrentLogBuffer is a simple hclog sink that captures log output in an
|
||||
// slice of lines in a goroutine-safe way. Use `startLogCapture` to use it.
|
||||
type concurrentLogBuffer struct {
|
||||
m sync.Mutex
|
||||
b []string
|
||||
}
|
||||
|
||||
// Accept implements hclog.SinkAdapter
|
||||
func (c *concurrentLogBuffer) Accept(name string, level hclog.Level, msg string, args ...interface{}) {
|
||||
pairs := make([]string, 0, len(args)/2)
|
||||
for pair := range slices.Chunk(args, 2) {
|
||||
// Yes this will panic if we didn't log an even number of args but thats
|
||||
// OK because that's a bug!
|
||||
pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1]))
|
||||
}
|
||||
line := fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " "))
|
||||
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
c.b = append(c.b, line)
|
||||
}
|
||||
|
||||
// Lines returns all the lines logged since starting the capture.
|
||||
func (c *concurrentLogBuffer) Lines() []string {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
return c.b
|
||||
}
|
||||
|
||||
// startLogCaptue attaches a logFn to the logger and returns a channel that will
|
||||
// be sent each line of log output in a basec flat text format. It returns a
|
||||
// cancel/stop func that should be called when capture should stop or at the end
|
||||
// of processing.
|
||||
func startLogCapture(t *testing.T, logger *corehelpers.TestLogger) (*concurrentLogBuffer, func()) {
|
||||
t.Helper()
|
||||
b := &concurrentLogBuffer{
|
||||
b: make([]string, 0, 1024),
|
||||
}
|
||||
logger.RegisterSink(b)
|
||||
stop := func() {
|
||||
logger.DeregisterSink(b)
|
||||
}
|
||||
return b, stop
|
||||
}
|
||||
|
@ -13,6 +13,16 @@ import (
|
||||
|
||||
//go:generate go run github.com/hashicorp/vault/tools/stubmaker
|
||||
|
||||
// entIdentityStoreDeterminismSupportsSecondary is a hack to drop duplicate
|
||||
// tests in CE where the secondary param will only cause the no-op methods below
|
||||
// to run which is functionally the same. It would be cleaner to define
|
||||
// different tests in CE and ENT but it's a table test with customer test-only
|
||||
// struct types which makes it a massive pain to have ent-ce specific code
|
||||
// interact with the test arguments.
|
||||
func entIdentityStoreDeterminismSupportsSecondary() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func entIdentityStoreDeterminismSecondaryTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry, seed *rand.Rand) {
|
||||
// no op
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -44,14 +45,47 @@ func (i *IdentityStore) loadArtifacts(ctx context.Context, isActive bool) error
|
||||
return nil
|
||||
}
|
||||
|
||||
loadFuncEntities := func(context.Context) error {
|
||||
reload, err := i.loadEntities(ctx, isActive, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load entities: %w", err)
|
||||
}
|
||||
if reload {
|
||||
// The persistMerges flag is used to fix a non-determinism issue in duplicate entities with global alias merges.
|
||||
// This does not solve the case for local alias merges.
|
||||
// Previously, alias merges could inconsistently flip between nodes due to bucket invalidation
|
||||
// and reprocessing on standbys, leading to divergence from the primary.
|
||||
// This flag ensures deterministic merges across all nodes by preventing unintended reinsertions
|
||||
// of previously merged entities during bucket diffs.
|
||||
persistMerges := false
|
||||
if !i.localNode.ReplicationState().HasState(
|
||||
consts.ReplicationDRSecondary|
|
||||
consts.ReplicationPerformanceSecondary,
|
||||
) && isActive {
|
||||
persistMerges = true
|
||||
}
|
||||
|
||||
// Since we're reloading entities, we need to inform the ConflictResolver about it so that it can
|
||||
// clean up data related to the previous load.
|
||||
i.conflictResolver.Reload(ctx)
|
||||
|
||||
_, err := i.loadEntities(ctx, isActive, persistMerges)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load entities: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
loadFunc := func(context.Context) error {
|
||||
i.logger.Debug("loading identity store artifacts with",
|
||||
"case_sensitive", !i.disableLowerCasedNames,
|
||||
"conflict_resolver", i.conflictResolver)
|
||||
|
||||
if err := i.loadEntities(ctx, isActive); err != nil {
|
||||
if err := loadFuncEntities(ctx); err != nil {
|
||||
return fmt.Errorf("failed to load entities: %w", err)
|
||||
}
|
||||
|
||||
if err := i.loadGroups(ctx, isActive); err != nil {
|
||||
return fmt.Errorf("failed to load groups: %w", err)
|
||||
}
|
||||
@ -440,15 +474,19 @@ func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool) error {
|
||||
func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool, persistMerges bool) (bool, error) {
|
||||
// Accumulate existing entities
|
||||
i.logger.Debug("loading entities")
|
||||
existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to scan for entities: %w", err)
|
||||
return false, fmt.Errorf("failed to scan for entities: %w", err)
|
||||
}
|
||||
i.logger.Debug("entities collected", "num_existing", len(existing))
|
||||
|
||||
// Reponsible for flagging callers if entities should be reloaded in case
|
||||
// entity merges need to be persisted.
|
||||
var reload atomic.Bool
|
||||
|
||||
duplicatedAccessors := make(map[string]struct{})
|
||||
// Make the channels used for the worker pool. We send the index into existing
|
||||
// so that we can keep results in the same order as inputs. Note that this is
|
||||
@ -585,23 +623,23 @@ LOOP:
|
||||
if err != nil && !i.disableLowerCasedNames {
|
||||
return err
|
||||
}
|
||||
|
||||
persist := false
|
||||
if modified {
|
||||
// If we modified the group we need to persist the changes to avoid bugs
|
||||
// If we modified the entity, we need to persist the changes to avoid bugs
|
||||
// where memDB and storage are out of sync in the future (e.g. after
|
||||
// invalidations of other items in the same bucket later). We do this
|
||||
// _even_ if `persist=false` because it is in general during unseal but
|
||||
// this is exactly when we need to fix these. We must be _really_
|
||||
// careful to only do this on primary active node though which is the
|
||||
// only source of truth that should have write access to groups across a
|
||||
// cluster since they are always non-local. Note that we check !Stadby
|
||||
// and !secondary because we still need to write back even if this is a
|
||||
// single cluster with no replication setup and I'm not _sure_ that we
|
||||
// report such a cluster as a primary.
|
||||
// cluster since they are always non-local. Note that we check !Standby
|
||||
// and !Secondary because we still need to write back even if this is a
|
||||
// single cluster with no replication.
|
||||
|
||||
if !i.localNode.ReplicationState().HasState(
|
||||
consts.ReplicationDRSecondary|
|
||||
consts.ReplicationPerformanceSecondary|
|
||||
consts.ReplicationPerformanceStandby,
|
||||
consts.ReplicationPerformanceSecondary,
|
||||
) && isActive {
|
||||
persist = true
|
||||
}
|
||||
@ -631,21 +669,28 @@ LOOP:
|
||||
tx = i.db.Txn(true)
|
||||
defer tx.Abort()
|
||||
}
|
||||
// Only update MemDB and don't hit the storage again
|
||||
err = i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist)
|
||||
|
||||
// Only update MemDB and don't hit the storage again unless we are
|
||||
// merging and on a primary active node.
|
||||
shouldReload, err := i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist, persistMerges)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update entity in MemDB: %w", err)
|
||||
}
|
||||
upsertedItems += toBeUpserted
|
||||
|
||||
if shouldReload {
|
||||
reload.CompareAndSwap(false, true)
|
||||
}
|
||||
}
|
||||
if upsertedItems > 0 {
|
||||
tx.Commit()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := load(entities)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
}
|
||||
@ -654,7 +699,7 @@ LOOP:
|
||||
// Let all go routines finish
|
||||
wg.Wait()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Flatten the map into a list of keys, in order to log them
|
||||
@ -669,7 +714,7 @@ LOOP:
|
||||
i.logger.Info("entities restored")
|
||||
}
|
||||
|
||||
return nil
|
||||
return reload.Load(), nil
|
||||
}
|
||||
|
||||
// loadLocalAliasesForEntity upserts local aliases into the entity by retrieving
|
||||
@ -730,17 +775,18 @@ func getAccessorsOnDuplicateAliases(aliases []*identity.Alias) []string {
|
||||
// false, then storage will not be updated. When an alias is transferred from
|
||||
// one entity to another, both the source and destination entities should get
|
||||
// updated, in which case, callers should send in both entity and
|
||||
// previousEntity.
|
||||
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
|
||||
// previousEntity. persistMerges is ignored if persist = true but if persist =
|
||||
// false it allows the caller to request that we persist the data only if it is
|
||||
// changes by a merge caused by a duplicate alias.
|
||||
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist, persistMerges bool) (reload bool, err error) {
|
||||
defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now())
|
||||
var err error
|
||||
|
||||
if txn == nil {
|
||||
return errors.New("txn is nil")
|
||||
return false, errors.New("txn is nil")
|
||||
}
|
||||
|
||||
if entity == nil {
|
||||
return errors.New("entity is nil")
|
||||
return false, errors.New("entity is nil")
|
||||
}
|
||||
|
||||
if entity.NamespaceID == "" {
|
||||
@ -748,16 +794,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
|
||||
}
|
||||
|
||||
if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID {
|
||||
return errors.New("entity and previous entity are not in the same namespace")
|
||||
return false, errors.New("entity and previous entity are not in the same namespace")
|
||||
}
|
||||
|
||||
aliasFactors := make([]string, len(entity.Aliases))
|
||||
|
||||
for index, alias := range entity.Aliases {
|
||||
// Verify that alias is not associated to a different one already
|
||||
aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false)
|
||||
aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if alias.NamespaceID == "" {
|
||||
@ -768,14 +814,14 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
|
||||
case aliasByFactors == nil:
|
||||
// Not found, no merging needed, just check namespace
|
||||
if alias.NamespaceID != entity.NamespaceID {
|
||||
return errors.New("alias and entity are not in the same namespace")
|
||||
return false, errors.New("alias and entity are not in the same namespace")
|
||||
}
|
||||
|
||||
case aliasByFactors.CanonicalID == entity.ID:
|
||||
// Lookup found the same entity, so it's already attached to the
|
||||
// right place
|
||||
if aliasByFactors.NamespaceID != entity.NamespaceID {
|
||||
return errors.New("alias from factors and entity are not in the same namespace")
|
||||
return false, errors.New("alias from factors and entity are not in the same namespace")
|
||||
}
|
||||
|
||||
case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID:
|
||||
@ -811,17 +857,18 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
|
||||
"entity_aliases", entity.Aliases,
|
||||
"alias_by_factors", aliasByFactors)
|
||||
|
||||
respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persist)
|
||||
persistMerge := persist || persistMerges
|
||||
respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persistMerge)
|
||||
switch {
|
||||
case respErr != nil:
|
||||
return respErr
|
||||
return false, respErr
|
||||
case intErr != nil:
|
||||
return intErr
|
||||
return false, intErr
|
||||
}
|
||||
|
||||
// The entity and aliases will be loaded into memdb and persisted
|
||||
// as a result of the merge, so we are done here
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// This is subtle. We want to call `ResolveAliases` so that the resolver can
|
||||
@ -852,13 +899,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
|
||||
// especially desirable to me, but we'd rather not change behavior for now.
|
||||
if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) &&
|
||||
conflictErr != nil && !i.disableLowerCasedNames {
|
||||
return conflictErr
|
||||
return false, conflictErr
|
||||
}
|
||||
|
||||
// Insert or update alias in MemDB using the transaction created above
|
||||
err = i.MemDBUpsertAliasInTxn(txn, alias, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor
|
||||
@ -868,13 +915,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
|
||||
if previousEntity != nil {
|
||||
err = i.MemDBUpsertEntityInTxn(txn, previousEntity)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if persist {
|
||||
// Persist the previous entity object
|
||||
if err := i.persistEntity(ctx, previousEntity); err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -882,16 +929,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
|
||||
// Insert or update entity in MemDB using the transaction created above
|
||||
err = i.MemDBUpsertEntityInTxn(txn, entity)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if persist {
|
||||
if err := i.persistEntity(ctx, entity); err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) {
|
||||
@ -970,7 +1017,7 @@ func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.A
|
||||
if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := i.upsertEntityInTxn(ctx, txn, entity, nil, false); err != nil {
|
||||
if _, err := i.upsertEntityInTxn(ctx, txn, entity, nil, false, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
txn.Commit()
|
||||
@ -1004,6 +1051,21 @@ func (i *IdentityStore) cacheTemporaryEntity(ctx context.Context, entity *identi
|
||||
return nil
|
||||
}
|
||||
|
||||
func splitLocalAliases(entity *identity.Entity) ([]*identity.Alias, []*identity.Alias) {
|
||||
var localAliases []*identity.Alias
|
||||
var nonLocalAliases []*identity.Alias
|
||||
|
||||
for _, alias := range entity.Aliases {
|
||||
if alias.Local {
|
||||
localAliases = append(localAliases, alias)
|
||||
} else {
|
||||
nonLocalAliases = append(nonLocalAliases, alias)
|
||||
}
|
||||
}
|
||||
|
||||
return nonLocalAliases, localAliases
|
||||
}
|
||||
|
||||
func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error {
|
||||
// If the entity that is passed into this function is resulting from a memdb
|
||||
// query without cloning, then modifying it will result in a direct DB edit,
|
||||
@ -1016,16 +1078,7 @@ func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Enti
|
||||
}
|
||||
|
||||
// Separate the local and non-local aliases.
|
||||
var localAliases []*identity.Alias
|
||||
var nonLocalAliases []*identity.Alias
|
||||
for _, alias := range entity.Aliases {
|
||||
switch alias.Local {
|
||||
case true:
|
||||
localAliases = append(localAliases, alias)
|
||||
default:
|
||||
nonLocalAliases = append(nonLocalAliases, alias)
|
||||
}
|
||||
}
|
||||
nonLocalAliases, localAliases := splitLocalAliases(entity)
|
||||
|
||||
// Store the entity with non-local aliases.
|
||||
entity.Aliases = nonLocalAliases
|
||||
@ -1076,7 +1129,7 @@ func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entit
|
||||
txn := i.db.Txn(true)
|
||||
defer txn.Abort()
|
||||
|
||||
err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist)
|
||||
_, err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user