Persist automatic entity merges (#29568)

* Persist automatic entity merges

* Local aliases write in test

* Add identity entity merge unit property test

* N entities merge

* Persist alias duplication fix

---------

Co-authored-by: Paul Banks <pbanks@hashicorp.com>
Co-authored-by: Mike Palmiotto <mike.palmiotto@hashicorp.com>
Bianca 2025-03-06 23:06:20 +01:00, committed by GitHub
11 changed files with 363 additions and 168 deletions

View File

@@ -199,7 +199,8 @@ func (s *StoragePacker) DeleteMultipleItems(ctx context.Context, logger hclog.Lo
// Look for matching storage entries and delete them from the list.
for i := 0; i < len(bucket.Items); i++ {
if _, ok := itemsToRemove[bucket.Items[i].ID]; ok {
bucket.Items[i] = bucket.Items[len(bucket.Items)-1]
copy(bucket.Items[i:], bucket.Items[i+1:])
bucket.Items[len(bucket.Items)-1] = nil // allow GC
bucket.Items = bucket.Items[:len(bucket.Items)-1]
// Since we just moved a value to position i we need to
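The hunk above replaces an order-destroying swap-remove (move the last item into the vacated slot) with an order-preserving shift via copy; the freed tail slot is nilled so the pointer can be garbage-collected. A minimal standalone Go sketch (illustration only, not Vault code) contrasts the two strategies; deterministic item order inside buckets is what the rest of this PR depends on.

package main

import "fmt"

// swapRemove deletes index i by moving the last element into its place.
// O(1) per delete, but the relative order of the survivors changes.
func swapRemove(s []string, i int) []string {
	s[i] = s[len(s)-1]
	return s[:len(s)-1]
}

// shiftRemove deletes index i by shifting the tail left with copy.
// O(n) per delete, but the relative order of the survivors is preserved.
func shiftRemove(s []string, i int) []string {
	copy(s[i:], s[i+1:])
	return s[:len(s)-1]
}

func main() {
	fmt.Println(swapRemove([]string{"a", "b", "c", "d"}, 1))  // [a d c]
	fmt.Println(shiftRemove([]string{"a", "b", "c", "d"}, 1)) // [a c d]
}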

View File

@@ -6,6 +6,7 @@ package storagepacker
import (
"context"
"fmt"
"math/rand"
"testing"
"github.com/golang/protobuf/proto"
@@ -68,6 +69,49 @@ func BenchmarkStoragePacker(b *testing.B) {
}
}
func BenchmarkStoragePacker_DeleteMultiple(b *testing.B) {
b.StopTimer()
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
if err != nil {
b.Fatal(err)
}
ctx := context.Background()
// Persist a set of storage entries
for i := 0; i <= 100000; i++ {
item := &Item{
ID: fmt.Sprintf("item%d", i),
}
err = storagePacker.PutItem(ctx, item)
if err != nil {
b.Fatal(err)
}
// Verify that it can be read
fetchedItem, err := storagePacker.GetItem(item.ID)
if err != nil {
b.Fatal(err)
}
if fetchedItem == nil {
b.Fatalf("failed to read the stored item")
}
if item.ID != fetchedItem.ID {
b.Fatalf("bad: item ID; expected: %q\n actual: %q\n", item.ID, fetchedItem.ID)
}
}
b.StartTimer()
for i := 0; i < b.N; i++ {
err = storagePacker.DeleteItem(ctx, fmt.Sprintf("item%d", rand.Intn(100000)))
if err != nil {
b.Fatal(err)
}
}
}
func TestStoragePacker(t *testing.T) {
storagePacker, err := NewStoragePacker(&logical.InmemStorage{}, log.New(&log.LoggerOptions{Name: "storagepackertest"}), "")
if err != nil {

View File

@@ -11,10 +11,12 @@ import (
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/namespace"
ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap"
"github.com/hashicorp/vault/helper/testhelpers/minimal"
"github.com/hashicorp/vault/sdk/helper/ldaputil"
"github.com/hashicorp/vault/vault"
"github.com/stretchr/testify/require"
)
@@ -642,3 +644,24 @@ func addRemoveLdapGroupMember(t *testing.T, cfg *ldaputil.ConfigEntry, userCN st
t.Fatal(err)
}
}
func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityIDs []string) *identity.Entity {
t.Helper()
var entity *identity.Entity
// Try to fetch each ID and ensure exactly one is present
found := 0
for _, entityID := range entityIDs {
e, err := c.IdentityStore().MemDBEntityByID(entityID, true)
require.NoError(t, err)
if e != nil {
found++
entity = e
}
}
// More than one means they didn't merge as expected!
require.Equal(t, 1, found,
"node %s does not have exactly one duplicate from the set", c.NodeID)
return entity
}
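A hedged usage sketch (cluster, duplicateIDs, and the loop below are assumed for illustration and are not part of this diff): after automatic merges run, every node should converge on the same surviving entity from a duplicate set.

// Hypothetical assertion that all nodes agree on the merge winner.
var winnerID string
for _, core := range cluster.Cores {
	e := findEntityFromDuplicateSet(t, core, duplicateIDs)
	require.NotNil(t, e)
	if winnerID == "" {
		winnerID = e.ID
	}
	require.Equal(t, winnerID, e.ID, "nodes disagree on merge winner")
}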

View File

@@ -752,7 +752,8 @@ func (i *IdentityStore) invalidateEntityBucket(ctx context.Context, key string)
}
}
err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false)
_, err = i.upsertEntityInTxn(ctx, txn, bucketEntity, nil, false, false)
if err != nil {
i.logger.Error("failed to update entity in MemDB", "entity_id", bucketEntity.ID, "error", err)
return
@@ -1416,7 +1416,7 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
}
// Update MemDB and persist entity object
err = i.upsertEntityInTxn(ctx, txn, entity, nil, true)
_, err = i.upsertEntityInTxn(ctx, txn, entity, nil, true, false)
if err != nil {
return entity, entityCreated, err
}

View File

@@ -28,6 +28,7 @@ type ConflictResolver interface {
ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error)
ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error)
ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error)
Reload(ctx context.Context)
}
// errorResolver is a ConflictResolver that logs a warning message when a
@@ -91,6 +92,10 @@ func (r *errorResolver) ResolveAliases(ctx context.Context, parent *identity.Ent
return false, errDuplicateIdentityName
}
// Reload is a no-op for the errorResolver implementation.
func (r *errorResolver) Reload(ctx context.Context) {
}
// duplicateReportingErrorResolver collects duplicate information and optionally
// logs a report on all the duplicates. We don't embed an errorResolver here
// because we _don't_ want its side effect of warning on just some duplicates
@@ -144,6 +149,10 @@ func (r *duplicateReportingErrorResolver) ResolveAliases(ctx context.Context, pa
return false, errDuplicateIdentityName
}
func (r *duplicateReportingErrorResolver) Reload(ctx context.Context) {
r.seenEntities = make(map[string][]*identity.Entity)
}
type identityDuplicateReportEntry struct {
artifactType string
scope string
@@ -429,3 +438,7 @@ func (r *renameResolver) ResolveGroups(ctx context.Context, existing, duplicate
func (r *renameResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) {
return false, nil
}
// Reload is a no-op for the renameResolver implementation.
func (r *renameResolver) Reload(ctx context.Context) {
}
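Reload exists because entities can now be loaded twice within a single unseal (see the loadEntities changes below), so any resolver that accumulates per-load state, as duplicateReportingErrorResolver does with seenEntities, must reset it between passes. A minimal sketch of a hypothetical stateful implementation (countingResolver is invented for illustration and assumes this file's imports):

// countingResolver tallies conflicts seen during a load; Reload resets the
// tally so a second load pass doesn't double-count.
type countingResolver struct {
	entities, groups, aliases int
}

func (r *countingResolver) ResolveEntities(ctx context.Context, existing, duplicate *identity.Entity) (bool, error) {
	r.entities++
	return false, nil
}

func (r *countingResolver) ResolveGroups(ctx context.Context, existing, duplicate *identity.Group) (bool, error) {
	r.groups++
	return false, nil
}

func (r *countingResolver) ResolveAliases(ctx context.Context, parent *identity.Entity, existing, duplicate *identity.Alias) (bool, error) {
	r.aliases++
	return false, nil
}

func (r *countingResolver) Reload(ctx context.Context) {
	r.entities, r.groups, r.aliases = 0, 0, 0
}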

View File

@@ -16,11 +16,9 @@ import (
"github.com/hashicorp/vault/helper/identity"
"github.com/hashicorp/vault/helper/identity/mfa"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/storagepacker"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/logical"
"google.golang.org/protobuf/types/known/anypb"
)
func entityPathFields() map[string]*framework.FieldSchema {
@@ -881,7 +879,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
}
fromEntity, err := i.MemDBEntityByID(fromEntityID, false)
fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, false)
if err != nil {
return nil, err, nil
}
@@ -984,7 +982,6 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
var fromEntityGroups []*identity.Group
toEntityAccessors := make(map[string][]string)
for _, alias := range toEntity.Aliases {
if accessors, ok := toEntityAccessors[alias.MountAccessor]; !ok {
// While it is not supported to have multiple aliases with the same mount accessor in one entity
@@ -1002,7 +999,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
return errors.New("to_entity_id should not be present in from_entity_ids"), nil, nil
}
fromEntity, err := i.MemDBEntityByID(fromEntityID, true)
fromEntity, err := i.MemDBEntityByIDInTxn(txn, fromEntityID, true)
if err != nil {
return nil, err, nil
}
@@ -1025,13 +1022,20 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
}
for _, fromAlias := range fromEntity.Aliases {
// We're going to modify this alias, but it's still a pointer to the one in
// MemDB that could be read concurrently by other goroutines even though we
// might remove it from MemDB very shortly...
fromAlias, err = fromAlias.Clone()
if err != nil {
return nil, err, nil
}
// If true, we need to handle conflicts (conflict = both aliases share the same mount accessor)
if toAliasIds, ok := toEntityAccessors[fromAlias.MountAccessor]; ok {
for _, toAliasId := range toAliasIds {
// When forceMergeAliases is true (as part of the merge-during-upsert case), we make the decision
// for the user, and keep the to_entity alias, merging the from_entity
// for the user, and keep the from_entity alias
// This case's code is the same as when the user selects to keep the from_entity alias
// but is kept separate for clarity
// but is kept separate for clarity.
if forceMergeAliases {
i.logger.Info("Deleting to_entity alias during entity merge", "to_entity", toEntity.ID, "deleted_alias", toAliasId)
err := i.MemDBDeleteAliasByIDInTxn(txn, toAliasId, false)
@@ -1046,8 +1050,8 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
if err != nil {
return nil, fmt.Errorf("aborting entity merge - failed to delete orphaned alias %q during merge into entity %q: %w", fromAlias.ID, toEntity.ID, err), nil
}
// Remove the alias from the entity's list in memory too!
toEntity.DeleteAliasByID(toAliasId)
// No need to alter toEntity's aliases since it never contained
// the alias we're deleting.
// Continue to next alias, as there's no alias to merge left in the from_entity
continue
@@ -1070,13 +1074,12 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
fromAlias.MergedFromCanonicalIDs = append(fromAlias.MergedFromCanonicalIDs, fromEntity.ID)
err = i.MemDBUpsertAliasInTxn(txn, fromAlias, false)
if err != nil {
return nil, fmt.Errorf("failed to update alias during merge: %w", err), nil
}
// We don't insert into MemDB right now because we'll do that for all the
// aliases we want to end up with at the end, to ensure they are inserted
// in the same order as when they are loaded from storage next time.
// Add the alias to the desired entity
toEntity.Aliases = append(toEntity.Aliases, fromAlias)
toEntity.UpsertAlias(fromAlias)
}
// If told to, merge policies
@@ -1124,6 +1127,30 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
}
}
// Normalize alias order. We do this because we persist NonLocal and Local
// aliases separately, so after the next reload local aliases will all come
// after non-local ones. While it's logically equivalent, it makes reasoning
// about merges and determinism very hard if the order of things in MemDB can
// change from one unseal to the next, so we are especially careful to ensure
// it's exactly the same whether we just merged or are on a subsequent load.
// persistEntity will already split these up and persist them separately, so
// we're duplicating some effort and code here, but this shouldn't happen
// often so I think it's fine.
nonLocalAliases, localAliases := splitLocalAliases(toEntity)
toEntity.Aliases = append(nonLocalAliases, localAliases...)
// Don't forget to insert aliases into the alias table that were part of
// `toEntity` but were not merged above (because they didn't conflict). This
// might re-insert the same aliases we just inserted above again but that's a
// no-op. TODO: maybe we could remove the memdb updates in the loop above and
// have them all be inserted here.
for _, alias := range toEntity.Aliases {
err = i.MemDBUpsertAliasInTxn(txn, alias, false)
if err != nil {
return nil, err, nil
}
}
// Update MemDB with changes to the entity we are merging to
err = i.MemDBUpsertEntityInTxn(txn, toEntity)
if err != nil {
@@ -1140,16 +1167,7 @@ func (i *IdentityStore) mergeEntity(ctx context.Context, txn *memdb.Txn, toEntit
if persist && !isPerfSecondaryOrStandby {
// Persist the entity which we are merging to
toEntityAsAny, err := anypb.New(toEntity)
if err != nil {
return nil, err, nil
}
item := &storagepacker.Item{
ID: toEntity.ID,
Message: toEntityAsAny,
}
err = i.entityPacker.PutItem(ctx, item)
err = i.persistEntity(ctx, toEntity)
if err != nil {
return nil, err, nil
}

View File

@@ -359,7 +359,7 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
flags.Count = 2
}
ids, _, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
ids, bucketIds, err := i.CreateDuplicateEntityAliasesInStorage(ctx, flags)
if err != nil {
i.logger.Error("error creating duplicate entities", "error", err)
return logical.ErrorResponse("error creating duplicate entities"), err
@@ -367,7 +367,8 @@ func (i *IdentityStore) createDuplicateEntityAliases() framework.OperationFunc {
return &logical.Response{
Data: map[string]interface{}{
"entity_ids": ids,
"entity_ids": ids,
"bucket_keys": bucketIds,
},
}, nil
}

View File

@@ -889,7 +889,7 @@ func TestOIDC_SignIDToken(t *testing.T) {
txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
_, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil {
t.Fatal(err)
}
@@ -1020,7 +1020,7 @@ func TestOIDC_SignIDToken_NilSigningKey(t *testing.T) {
txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
_, err := c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil {
t.Fatal(err)
}
@@ -1497,7 +1497,7 @@ func TestOIDC_Path_Introspect(t *testing.T) {
txn := c.identityStore.db.Txn(true)
defer txn.Abort()
err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true)
_, err = c.identityStore.upsertEntityInTxn(ctx, txn, testEntity, nil, true, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -12,6 +12,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"testing"
"time"
@@ -68,7 +69,7 @@ func TestIdentityStore_DeleteEntityAlias(t *testing.T) {
BucketKey: c.identityStore.entityPacker.BucketKey("testEntityID"),
}
err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false)
_, err := c.identityStore.upsertEntityInTxn(context.Background(), txn, entity, nil, false, false)
require.NoError(t, err)
err = c.identityStore.deleteAliasesInEntityInTxn(txn, entity, []*identity.Alias{alias, alias2})
@@ -1422,51 +1423,45 @@ func TestIdentityStoreLoadingIsDeterministic(t *testing.T) {
seedval, err = strconv.ParseInt(os.Getenv("VAULT_TEST_IDENTITY_STORE_SEED"), 10, 64)
require.NoError(t, err)
}
seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test
defer t.Logf("Test generated with seed: %d", seedval)
tests := []struct {
name string
flags *determinismTestFlags
}{
{
name: "error-resolver-primary",
flags: &determinismTestFlags{
identityDeduplication: false,
secondary: false,
seed: seed,
},
tests := map[string]*determinismTestFlags{
"error-resolver-primary": {
identityDeduplication: false,
secondary: false,
},
{
name: "identity-cleanup-primary",
flags: &determinismTestFlags{
identityDeduplication: true,
secondary: false,
seed: seed,
},
},
{
name: "error-resolver-secondary",
flags: &determinismTestFlags{
identityDeduplication: false,
secondary: true,
seed: seed,
},
},
{
name: "identity-cleanup-secondary",
flags: &determinismTestFlags{
identityDeduplication: true,
secondary: true,
seed: seed,
},
"identity-cleanup-primary": {
identityDeduplication: true,
secondary: false,
},
}
for _, test := range tests {
t.Run(t.Name()+"-"+test.name, func(t *testing.T) {
identityStoreLoadingIsDeterministic(t, test.flags)
// Hook to add cases that only differ in Enterprise
if entIdentityStoreDeterminismSupportsSecondary() {
tests["error-resolver-secondary"] = &determinismTestFlags{
identityDeduplication: false,
secondary: true,
}
tests["identity-cleanup-secondary"] = &determinismTestFlags{
identityDeduplication: true,
secondary: true,
}
}
repeats := 50
for name, flags := range tests {
t.Run(t.Name()+"-"+name, func(t *testing.T) {
// Create a random source specific to this test case so every test case
// starts out from the identical random state given the same seed. We do
// want each iteration to explore a different path though, so we do it
// here, not inside the test func.
seed := rand.New(rand.NewSource(seedval)) // Seed for deterministic test
flags.seed = seed
for i := 0; i < repeats; i++ {
identityStoreLoadingIsDeterministic(t, flags)
}
})
}
}
@@ -1538,7 +1533,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed)
attachAlias(t, e, alias, upme, seed)
attachAlias(t, e, localAlias, localMe, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err)
// Subset of entities get a duplicate alias and/or duplicate local alias.
@@ -1551,7 +1547,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
for rnd < pDup && dupeNum < 10 {
e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-dup-%d", i, dupeNum), c.identityStore.entityPacker, seed)
attachAlias(t, e, alias, upme, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err)
// Toss again to see if we continue
rnd = seed.Float64()
@@ -1563,8 +1560,9 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
for rnd < pDup && dupeNum < 10 {
e := makeEntityForPacker(t, fmt.Sprintf("entity-%d-localdup-%d", i, dupeNum), c.identityStore.entityPacker, seed)
attachAlias(t, e, localAlias, localMe, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err)
rnd = seed.Float64()
dupeNum++
}
@@ -1572,7 +1570,7 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
rnd = seed.Float64()
for rnd < pDup {
e := makeEntityForPacker(t, name, c.identityStore.entityPacker, seed)
err = TestHelperWriteToStoragePacker(ctx, c.identityStore.entityPacker, e.ID, e)
err = c.identityStore.persistEntity(ctx, e)
require.NoError(t, err)
rnd = seed.Float64()
}
@@ -1624,20 +1622,20 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
}
// Storage is now primed for the test.
require.NoError(t, c.Seal(rootToken))
require.True(t, c.Sealed())
for _, key := range sealKeys {
unsealed, err := c.Unseal(key)
require.NoError(t, err, "failed unseal on initial assertions")
if unsealed {
break
}
}
require.False(t, c.Sealed())
if identityDeduplication {
// Perform an initial Seal/Unseal with duplicates injected to assert
// initial state.
require.NoError(t, c.Seal(rootToken))
require.True(t, c.Sealed())
for _, key := range sealKeys {
unsealed, err := c.Unseal(key)
require.NoError(t, err, "failed unseal on initial assertions")
if unsealed {
break
}
}
require.False(t, c.Sealed())
// Test out the system backend ActivationFunc wiring
c.FeatureActivationFlags.ActivateInMem(activationflags.IdentityDeduplication, true)
@@ -1657,11 +1655,15 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
// build a list of human readable ids that we can compare.
prevLoadedNames := []string{}
var prevErr error
for i := 0; i < 10; i++ {
err := c.identityStore.resetDB()
require.NoError(t, err)
logger.Info(" ==> BEGIN LOAD ARTIFACTS", "i", i)
err = c.identityStore.loadArtifacts(ctx, true)
if i > 0 {
require.Equal(t, prevErr, err)
}
@@ -1673,6 +1675,8 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
case <-c.identityStore.activateDeduplicationDone:
default:
}
logger.Info(" ==> END LOAD ARTIFACTS ", "i", i)
// Identity store should be loaded now. Check its contents.
loadedNames := []string{}
@@ -1683,24 +1687,25 @@ func identityStoreLoadingIsDeterministic(t *testing.T, flags *determinismTestFla
require.NoError(t, err)
for item := iter.Next(); item != nil; item = iter.Next() {
e := item.(*identity.Entity)
loadedNames = append(loadedNames, e.Name)
loadedNames = append(loadedNames, e.NamespaceID+"/"+e.Name+"_"+e.ID)
for _, a := range e.Aliases {
loadedNames = append(loadedNames, a.Name)
loadedNames = append(loadedNames, a.MountAccessor+"/"+a.Name+"_"+a.ID)
}
}
// This is a non-triviality check to make sure we actually loaded stuff and
// are not just passing because of a bug in the test.
numLoaded := len(loadedNames)
require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i)
// Standalone alias query
iter, err = tx.LowerBound(entityAliasesTable, "id", "")
require.NoError(t, err)
for item := iter.Next(); item != nil; item = iter.Next() {
a := item.(*identity.Alias)
loadedNames = append(loadedNames, a.Name)
loadedNames = append(loadedNames, "fromAliasTable:"+a.MountAccessor+"/"+a.Name+"_"+a.ID)
}
// This is a non-triviality check to make sure we actually loaded stuff and
// are not just passing because of a bug in the test.
numLoaded := len(loadedNames)
require.Greater(t, numLoaded, 300, "not enough entities and aliases loaded on attempt %d", i)
// Groups
iter, err = tx.LowerBound(groupsTable, "id", "")
require.NoError(t, err)
@@ -1774,23 +1779,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
// Storage is now primed for the test.
// Setup a logger we can use to capture unseal logs
var unsealLogs []string
unsealLogger := &logFn{
fn: func(msg string, args []interface{}) {
pairs := make([]string, 0, len(args)/2)
for pair := range slices.Chunk(args, 2) {
// Yes this will panic if we didn't log an even number of args but thats
// OK because that's a bug!
pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1]))
}
unsealLogs = append(unsealLogs, fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " ")))
},
}
logBuf, stopCapture := startLogCapture(t, logger)
logger.RegisterSink(unsealLogger)
err = c.identityStore.loadArtifacts(ctx, true)
stopCapture()
require.NoError(t, err)
logger.DeregisterSink(unsealLogger)
// Identity store should be loaded now. Check its contents.
@@ -1799,23 +1793,12 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
// many of these cases and seems strange to encode in a test that we want
// broken behavior!
numDupes := make(map[string]int)
uniqueIDs := make(map[string]struct{})
duplicateCountRe := regexp.MustCompile(`(\d+) (different-case( local)? entity alias|entity|group) duplicates found`)
// Be sure not to match attributes like alias_id= because there are dupes
// there. The report lines we care about always have a space before the id
// pair.
propsRe := regexp.MustCompile(`\s(id=(\S+))`)
for _, log := range unsealLogs {
for _, log := range logBuf.Lines() {
if matches := duplicateCountRe.FindStringSubmatch(log); len(matches) >= 3 {
num, _ := strconv.Atoi(matches[1])
numDupes[matches[2]] = num
}
if propMatches := propsRe.FindStringSubmatch(log); len(propMatches) >= 3 {
artifactID := propMatches[2]
require.NotContains(t, uniqueIDs, artifactID,
"duplicate ID reported in logs for different artifacts")
uniqueIDs[artifactID] = struct{}{}
}
}
t.Logf("numDupes: %v", numDupes)
wantAliases, wantLocalAliases, wantEntities, wantGroups := identityStoreDuplicateReportTestWantDuplicateCounts()
@@ -1825,11 +1808,60 @@ func TestIdentityStoreLoadingDuplicateReporting(t *testing.T) {
require.Equal(t, wantGroups, numDupes["group"])
}
// logFn is a type we used to use here before concurrentLogBuffer was added.
// It's used in other tests in Enterprise, so we need to keep it around to
// avoid breaking the build during merge.
type logFn struct {
fn func(msg string, args []interface{})
}
// Accept implements hclog.SinkAdapter
func (f *logFn) Accept(name string, level hclog.Level, msg string, args ...interface{}) {
f.fn(msg, args)
}
// concurrentLogBuffer is a simple hclog sink that captures log output in a
// slice of lines in a goroutine-safe way. Use `startLogCapture` to use it.
type concurrentLogBuffer struct {
m sync.Mutex
b []string
}
// Accept implements hclog.SinkAdapter
func (c *concurrentLogBuffer) Accept(name string, level hclog.Level, msg string, args ...interface{}) {
pairs := make([]string, 0, len(args)/2)
for pair := range slices.Chunk(args, 2) {
// Yes, this will panic if we didn't log an even number of args, but that's
// OK because that's a bug!
pairs = append(pairs, fmt.Sprintf("%s=%s", pair[0], pair[1]))
}
line := fmt.Sprintf("%s: %s", msg, strings.Join(pairs, " "))
c.m.Lock()
defer c.m.Unlock()
c.b = append(c.b, line)
}
// Lines returns all the lines logged since starting the capture.
func (c *concurrentLogBuffer) Lines() []string {
c.m.Lock()
defer c.m.Unlock()
return c.b
}
// startLogCapture attaches a concurrentLogBuffer to the logger and returns
// it; each line of log output is captured in a basic flat text format. It
// also returns a cancel/stop func that should be called when capture should
// stop or at the end of processing.
func startLogCapture(t *testing.T, logger *corehelpers.TestLogger) (*concurrentLogBuffer, func()) {
t.Helper()
b := &concurrentLogBuffer{
b: make([]string, 0, 1024),
}
logger.RegisterSink(b)
stop := func() {
logger.DeregisterSink(b)
}
return b, stop
}

View File

@@ -13,6 +13,16 @@ import (
//go:generate go run github.com/hashicorp/vault/tools/stubmaker
// entIdentityStoreDeterminismSupportsSecondary is a hack to drop duplicate
// tests in CE, where the secondary param would only cause the no-op methods
// below to run, which is functionally the same. It would be cleaner to define
// different tests in CE and ENT, but it's a table test with custom test-only
// struct types, which makes it a massive pain to have ent-ce specific code
// interact with the test arguments.
func entIdentityStoreDeterminismSupportsSecondary() bool {
return false
}
func entIdentityStoreDeterminismSecondaryTestSetup(t *testing.T, ctx context.Context, c *Core, me, localme *MountEntry, seed *rand.Rand) {
// no op
}

View File

@@ -10,6 +10,7 @@ import (
"math/rand"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -44,14 +45,47 @@ func (i *IdentityStore) loadArtifacts(ctx context.Context, isActive bool) error
return nil
}
loadFuncEntities := func(context.Context) error {
reload, err := i.loadEntities(ctx, isActive, false)
if err != nil {
return fmt.Errorf("failed to load entities: %w", err)
}
if reload {
// The persistMerges flag is used to fix a non-determinism issue in duplicate entities with global alias merges.
// This does not solve the case for local alias merges.
// Previously, alias merges could inconsistently flip between nodes due to bucket invalidation
// and reprocessing on standbys, leading to divergence from the primary.
// This flag ensures deterministic merges across all nodes by preventing unintended reinsertions
// of previously merged entities during bucket diffs.
persistMerges := false
if !i.localNode.ReplicationState().HasState(
consts.ReplicationDRSecondary|
consts.ReplicationPerformanceSecondary,
) && isActive {
persistMerges = true
}
// Since we're reloading entities, we need to inform the ConflictResolver about it so that it can
// clean up data related to the previous load.
i.conflictResolver.Reload(ctx)
_, err := i.loadEntities(ctx, isActive, persistMerges)
if err != nil {
return fmt.Errorf("failed to load entities: %w", err)
}
}
return nil
}
loadFunc := func(context.Context) error {
i.logger.Debug("loading identity store artifacts with",
"case_sensitive", !i.disableLowerCasedNames,
"conflict_resolver", i.conflictResolver)
if err := i.loadEntities(ctx, isActive); err != nil {
if err := loadFuncEntities(ctx); err != nil {
return fmt.Errorf("failed to load entities: %w", err)
}
if err := i.loadGroups(ctx, isActive); err != nil {
return fmt.Errorf("failed to load groups: %w", err)
}
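Condensed, the loadFuncEntities closure above is a detect-then-retry flow; this sketch restates it with error handling trimmed and is not a separate API:

// Pass 1: load without persisting; reload is true if an automatic merge
// happened that should be made durable.
reload, err := i.loadEntities(ctx, isActive, false)
if reload && err == nil {
	// Only the active node of a primary (or of an unreplicated cluster)
	// may write merged entities back to storage.
	persistMerges := isActive && !i.localNode.ReplicationState().HasState(
		consts.ReplicationDRSecondary|consts.ReplicationPerformanceSecondary)
	// Reset per-load resolver state, then reload, persisting merges.
	i.conflictResolver.Reload(ctx)
	_, err = i.loadEntities(ctx, isActive, persistMerges)
}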
@@ -440,15 +474,19 @@ func (i *IdentityStore) loadCachedEntitiesOfLocalAliases(ctx context.Context) er
return nil
}
func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool) error {
func (i *IdentityStore) loadEntities(ctx context.Context, isActive bool, persistMerges bool) (bool, error) {
// Accumulate existing entities
i.logger.Debug("loading entities")
existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix)
if err != nil {
return fmt.Errorf("failed to scan for entities: %w", err)
return false, fmt.Errorf("failed to scan for entities: %w", err)
}
i.logger.Debug("entities collected", "num_existing", len(existing))
// Responsible for flagging to callers that entities should be reloaded in
// case entity merges need to be persisted.
var reload atomic.Bool
duplicatedAccessors := make(map[string]struct{})
// Make the channels used for the worker pool. We send the index into existing
// so that we can keep results in the same order as inputs. Note that this is
@@ -585,23 +623,23 @@ LOOP:
if err != nil && !i.disableLowerCasedNames {
return err
}
persist := false
if modified {
// If we modified the group we need to persist the changes to avoid bugs
// If we modified the entity, we need to persist the changes to avoid bugs
// where memDB and storage are out of sync in the future (e.g. after
// invalidations of other items in the same bucket later). We do this
// _even_ if `persist=false` because in general this happens during unseal,
// and that is exactly when we need to fix these. We must be _really_
// careful to only do this on primary active node though which is the
// only source of truth that should have write access to groups across a
// cluster since they are always non-local. Note that we check !Stadby
// and !secondary because we still need to write back even if this is a
// single cluster with no replication setup and I'm not _sure_ that we
// report such a cluster as a primary.
// cluster since they are always non-local. Note that we check !Standby
// and !Secondary because we still need to write back even if this is a
// single cluster with no replication.
if !i.localNode.ReplicationState().HasState(
consts.ReplicationDRSecondary|
consts.ReplicationPerformanceSecondary|
consts.ReplicationPerformanceStandby,
consts.ReplicationPerformanceSecondary,
) && isActive {
persist = true
}
@@ -631,21 +669,28 @@ LOOP:
tx = i.db.Txn(true)
defer tx.Abort()
}
// Only update MemDB and don't hit the storage again
err = i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist)
// Only update MemDB and don't hit the storage again unless we are
// merging and on a primary active node.
shouldReload, err := i.upsertEntityInTxn(nsCtx, tx, entity, nil, persist, persistMerges)
if err != nil {
return fmt.Errorf("failed to update entity in MemDB: %w", err)
}
upsertedItems += toBeUpserted
if shouldReload {
reload.CompareAndSwap(false, true)
}
}
if upsertedItems > 0 {
tx.Commit()
}
return nil
}
err := load(entities)
if err != nil {
return err
return false, err
}
}
@@ -654,7 +699,7 @@ LOOP:
// Let all go routines finish
wg.Wait()
if err != nil {
return err
return false, err
}
// Flatten the map into a list of keys, in order to log them
@@ -669,7 +714,7 @@ LOOP:
i.logger.Info("entities restored")
}
return nil
return reload.Load(), nil
}
// loadLocalAliasesForEntity upserts local aliases into the entity by retrieving
@@ -730,17 +775,18 @@ func getAccessorsOnDuplicateAliases(aliases []*identity.Alias) []string {
// false, then storage will not be updated. When an alias is transferred from
// one entity to another, both the source and destination entities should get
// updated, in which case, callers should send in both entity and
// previousEntity.
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
// previousEntity. persistMerges is ignored if persist = true, but if persist =
// false it allows the caller to request that we persist the data only if it
// is changed by a merge caused by a duplicate alias.
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist, persistMerges bool) (reload bool, err error) {
defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now())
var err error
if txn == nil {
return errors.New("txn is nil")
return false, errors.New("txn is nil")
}
if entity == nil {
return errors.New("entity is nil")
return false, errors.New("entity is nil")
}
if entity.NamespaceID == "" {
@@ -748,16 +794,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
}
if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID {
return errors.New("entity and previous entity are not in the same namespace")
return false, errors.New("entity and previous entity are not in the same namespace")
}
aliasFactors := make([]string, len(entity.Aliases))
for index, alias := range entity.Aliases {
// Verify that alias is not associated to a different one already
aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false)
aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, false, false)
if err != nil {
return err
return false, err
}
if alias.NamespaceID == "" {
@@ -768,14 +814,14 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
case aliasByFactors == nil:
// Not found, no merging needed, just check namespace
if alias.NamespaceID != entity.NamespaceID {
return errors.New("alias and entity are not in the same namespace")
return false, errors.New("alias and entity are not in the same namespace")
}
case aliasByFactors.CanonicalID == entity.ID:
// Lookup found the same entity, so it's already attached to the
// right place
if aliasByFactors.NamespaceID != entity.NamespaceID {
return errors.New("alias from factors and entity are not in the same namespace")
return false, errors.New("alias from factors and entity are not in the same namespace")
}
case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID:
@@ -811,17 +857,18 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
"entity_aliases", entity.Aliases,
"alias_by_factors", aliasByFactors)
respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persist)
persistMerge := persist || persistMerges
respErr, intErr := i.mergeEntityAsPartOfUpsert(ctx, txn, entity, aliasByFactors.CanonicalID, persistMerge)
switch {
case respErr != nil:
return respErr
return false, respErr
case intErr != nil:
return intErr
return false, intErr
}
// The entity and aliases will be loaded into memdb and persisted
// as a result of the merge, so we are done here
return nil
return true, nil
}
// This is subtle. We want to call `ResolveAliases` so that the resolver can
@@ -852,13 +899,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
// especially desirable to me, but we'd rather not change behavior for now.
if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) &&
conflictErr != nil && !i.disableLowerCasedNames {
return conflictErr
return false, conflictErr
}
// Insert or update alias in MemDB using the transaction created above
err = i.MemDBUpsertAliasInTxn(txn, alias, false)
if err != nil {
return err
return false, err
}
aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor
@@ -868,13 +915,13 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
if previousEntity != nil {
err = i.MemDBUpsertEntityInTxn(txn, previousEntity)
if err != nil {
return err
return false, err
}
if persist {
// Persist the previous entity object
if err := i.persistEntity(ctx, previousEntity); err != nil {
return err
return false, err
}
}
}
@@ -882,16 +929,16 @@ func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, e
// Insert or update entity in MemDB using the transaction created above
err = i.MemDBUpsertEntityInTxn(txn, entity)
if err != nil {
return err
return false, err
}
if persist {
if err := i.persistEntity(ctx, entity); err != nil {
return err
return false, err
}
}
return nil
return false, nil
}
func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.Alias, entity *identity.Entity, updateDb bool) (*identity.Alias, error) {
@@ -970,7 +1017,7 @@ func (i *IdentityStore) processLocalAlias(ctx context.Context, lAlias *logical.A
if err := i.MemDBUpsertAliasInTxn(txn, alias, false); err != nil {
return nil, err
}
if err := i.upsertEntityInTxn(ctx, txn, entity, nil, false); err != nil {
if _, err := i.upsertEntityInTxn(ctx, txn, entity, nil, false, false); err != nil {
return nil, err
}
txn.Commit()
@@ -1004,6 +1051,21 @@ func (i *IdentityStore) cacheTemporaryEntity(ctx context.Context, entity *identi
return nil
}
func splitLocalAliases(entity *identity.Entity) ([]*identity.Alias, []*identity.Alias) {
var localAliases []*identity.Alias
var nonLocalAliases []*identity.Alias
for _, alias := range entity.Aliases {
if alias.Local {
localAliases = append(localAliases, alias)
} else {
nonLocalAliases = append(nonLocalAliases, alias)
}
}
return nonLocalAliases, localAliases
}
func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Entity) error {
// If the entity that is passed into this function is resulting from a memdb
// query without cloning, then modifying it will result in a direct DB edit,
@@ -1016,16 +1078,7 @@ func (i *IdentityStore) persistEntity(ctx context.Context, entity *identity.Enti
}
// Separate the local and non-local aliases.
var localAliases []*identity.Alias
var nonLocalAliases []*identity.Alias
for _, alias := range entity.Aliases {
switch alias.Local {
case true:
localAliases = append(localAliases, alias)
default:
nonLocalAliases = append(nonLocalAliases, alias)
}
}
nonLocalAliases, localAliases := splitLocalAliases(entity)
// Store the entity with non-local aliases.
entity.Aliases = nonLocalAliases
@@ -1076,7 +1129,7 @@ func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entit
txn := i.db.Txn(true)
defer txn.Abort()
err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist)
_, err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist, false)
if err != nil {
return err
}