From ba53e126a23915fed966728e0027abd91a43ed01 Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Wed, 29 Apr 2026 03:09:16 -0600 Subject: [PATCH] changing cloning to a bool (#14050) (#14363) * changing cloning to a bool * fixing linting and bad error return * adding changelog * moving tests in to external tests, adding helper Co-authored-by: JMGoldsmith --- changelog/_14050.txt | 3 + .../external_tests/identity/identity_test.go | 220 ++++++++++++++++++ vault/identity_store.go | 63 ++++- vault/identity_store_test.go | 43 ++++ vault/identity_store_util.go | 29 +++ 5 files changed, 346 insertions(+), 12 deletions(-) create mode 100644 changelog/_14050.txt diff --git a/changelog/_14050.txt b/changelog/_14050.txt new file mode 100644 index 0000000000..3ac702bdf2 --- /dev/null +++ b/changelog/_14050.txt @@ -0,0 +1,3 @@ +```release-note:bug +identity: fixed a rare but possible data race issue with identities. +``` \ No newline at end of file diff --git a/vault/external_tests/identity/identity_test.go b/vault/external_tests/identity/identity_test.go index 7f45a8325a..0500612753 100644 --- a/vault/external_tests/identity/identity_test.go +++ b/vault/external_tests/identity/identity_test.go @@ -4,7 +4,10 @@ package identity import ( + "context" "fmt" + "maps" + "sync" "testing" "github.com/go-ldap/ldap/v3" @@ -16,10 +19,35 @@ import ( ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap" "github.com/hashicorp/vault/helper/testhelpers/minimal" "github.com/hashicorp/vault/sdk/helper/ldaputil" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" "github.com/stretchr/testify/require" ) +func testIdentityStoreWithGithubUserpassAuth(ctx context.Context, t *testing.T) (*vault.IdentityStore, string, string, *vault.TestClusterCore) { + t.Helper() + + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + err := client.Sys().EnableAuthWithOptions("github", &api.EnableAuthOptions{Type: "github"}) + require.NoError(t, err) + + err = client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{Type: "userpass"}) + require.NoError(t, err) + + auth, err := client.Sys().ListAuth() + require.NoError(t, err) + + githubAccessor := auth["github/"].Accessor + require.NotEmpty(t, githubAccessor) + + userpassAccessor := auth["userpass/"].Accessor + require.NotEmpty(t, userpassAccessor) + + return cluster.Cores[0].IdentityStore(), githubAccessor, userpassAccessor, cluster.Cores[0] +} + func TestIdentityStore_ExternalGroupMemberships_DifferentMounts(t *testing.T) { t.Parallel() cluster := minimal.NewTestSoloCluster(t, nil) @@ -665,3 +693,195 @@ func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityID "node %s does not have exactly one duplicate from the set", c.NodeID) return entity } + +// TestIdentityStore_CreateOrFetchEntity_ConcurrentFastPath verifies concurrent +// callers on the fast path all read the existing entity without creating new +// entities or mutating alias metadata. +func TestIdentityStore_CreateOrFetchEntity_ConcurrentFastPath(t *testing.T) { + ctx := namespace.RootContext(nil) + is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t) + + alias := &logical.Alias{ + MountType: "github", + MountAccessor: ghAccessor, + Name: "githubuser", + Metadata: map[string]string{ + "foo": "a", + }, + } + + entity, _, err := is.CreateOrFetchEntity(ctx, alias) + require.NoError(t, err) + require.NotNil(t, entity) + + const workers = 16 + var wg sync.WaitGroup + errCh := make(chan error, workers) + + for range workers { + wg.Add(1) + go func() { + defer wg.Done() + + got, created, err := is.CreateOrFetchEntity(ctx, alias) + if err != nil { + errCh <- err + return + } + if created { + errCh <- fmt.Errorf("unexpected entity creation on fast path") + return + } + if got == nil { + errCh <- fmt.Errorf("expected entity on fast path") + return + } + if len(got.Aliases) != 1 { + errCh <- fmt.Errorf("expected 1 alias, got %d", len(got.Aliases)) + return + } + if !maps.Equal(got.Aliases[0].Metadata, alias.Metadata) { + errCh <- fmt.Errorf("unexpected alias metadata: %#v", got.Aliases[0].Metadata) + } + }() + } + + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } +} + +// TestIdentityStore_CreateOrFetchEntity_ConcurrentMetadataUpdates verifies +// concurrent readers and metadata updates for the same alias remain stable and +// return a single-entity view. +func TestIdentityStore_CreateOrFetchEntity_ConcurrentMetadataUpdates(t *testing.T) { + ctx := namespace.RootContext(nil) + is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t) + + newAlias := func(metadataValue string) *logical.Alias { + return &logical.Alias{ + MountType: "github", + MountAccessor: ghAccessor, + Name: "githubuser", + Metadata: map[string]string{ + "foo": metadataValue, + }, + } + } + + entity, _, err := is.CreateOrFetchEntity(ctx, newAlias("a")) + require.NoError(t, err) + require.NotNil(t, entity) + + const workers = 8 + const iterations = 100 + + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, workers*2) + + for range workers { + wg.Add(1) + go func() { + defer wg.Done() + <-start + + for range iterations { + got, _, err := is.CreateOrFetchEntity(ctx, newAlias("a")) + if err != nil { + errCh <- err + return + } + if got == nil { + errCh <- fmt.Errorf("expected entity from read path") + return + } + if len(got.Aliases) != 1 { + errCh <- fmt.Errorf("expected 1 alias, got %d", len(got.Aliases)) + return + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-start + + for i := range iterations { + metadataValue := "b" + if i%2 == 0 { + metadataValue = "a" + } + + got, _, err := is.CreateOrFetchEntity(ctx, newAlias(metadataValue)) + if err != nil { + errCh <- err + return + } + if got == nil { + errCh <- fmt.Errorf("expected entity from update path") + return + } + if len(got.Aliases) != 1 { + errCh <- fmt.Errorf("expected 1 alias, got %d", len(got.Aliases)) + return + } + } + }() + } + + close(start) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } +} + +// TestIdentityStore_EntityByAliasFactors_BlackBoxCloneBehavior verifies +// alias-based lookup and clone behavior using only exported identity store APIs. +func TestIdentityStore_EntityByAliasFactors_BlackBoxCloneBehavior(t *testing.T) { + ctx := namespace.RootContext(nil) + is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t) + + loginAlias := &logical.Alias{ + MountType: "github", + MountAccessor: ghAccessor, + Name: "githubuser", + Metadata: map[string]string{ + "foo": "a", + }, + } + + createdEntity, _, err := is.CreateOrFetchEntity(ctx, loginAlias) + require.NoError(t, err) + require.NotNil(t, createdEntity) + + matchedAlias, err := is.MemDBAliasByFactors(ghAccessor, loginAlias.Name, true, false) + require.NoError(t, err) + require.NotNil(t, matchedAlias) + require.True(t, maps.Equal(matchedAlias.Metadata, loginAlias.Metadata)) + + matchedEntity, err := is.MemDBEntityByID(matchedAlias.CanonicalID, true) + require.NoError(t, err) + require.NotNil(t, matchedEntity) + require.Equal(t, createdEntity.ID, matchedEntity.ID) + + // Clone reads should not mutate MemDB state. + matchedEntity.Name = "mutated-name" + freshAlias, err := is.MemDBAliasByFactors(ghAccessor, loginAlias.Name, true, false) + require.NoError(t, err) + require.NotNil(t, freshAlias) + + freshEntity, err := is.MemDBEntityByID(freshAlias.CanonicalID, true) + require.NoError(t, err) + require.NotNil(t, freshEntity) + require.NotEqual(t, "mutated-name", freshEntity.Name) + + require.False(t, maps.Equal(freshAlias.Metadata, map[string]string{"foo": "does-not-match"})) +} diff --git a/vault/identity_store.go b/vault/identity_store.go index 2c16ea585f..7dfdf91c5b 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -1301,6 +1301,51 @@ func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone) } +// entityByAliasFactorsIf fetches and clones an entity by alias factors only when +// shouldReturn evaluates to true for the matching alias. +func (i *IdentityStore) entityByAliasFactorsIf(mountAccessor, aliasName string, shouldReturn func(*identity.Alias) bool) (*identity.Entity, error) { + if mountAccessor == "" { + return nil, fmt.Errorf("missing mount accessor") + } + + if aliasName == "" { + return nil, fmt.Errorf("missing alias name") + } + + txn := i.db.Txn(false) + + return i.entityByAliasFactorsInTxnIf(txn, mountAccessor, aliasName, shouldReturn) +} + +// entityByAliasFactorsInTxnIf fetches and clones an entity by alias factors only +// when shouldReturn evaluates to true for the matching alias. +func (i *IdentityStore) entityByAliasFactorsInTxnIf(txn *memdb.Txn, mountAccessor, aliasName string, shouldReturn func(*identity.Alias) bool) (*identity.Entity, error) { + var entity *identity.Entity + + if txn == nil { + return nil, fmt.Errorf("nil txn") + } + + if mountAccessor == "" { + return nil, fmt.Errorf("missing mount accessor") + } + + if aliasName == "" { + return nil, fmt.Errorf("missing alias name") + } + + alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false) + if err != nil { + return nil, err + } + + if alias == nil { + return entity, nil + } + + return i.MemDBEntityByAliasIDInTxnClonePredicate(txn, alias.ID, shouldReturn) +} + // CreateEntity creates a new entity. func (i *IdentityStore) CreateEntity(ctx context.Context) (*identity.Entity, error) { defer metrics.MeasureSince([]string{"identity", "create_entity"}, time.Now()) @@ -1359,21 +1404,15 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical. return nil, false, fmt.Errorf("mount accessor %q is not a mount of type %q", alias.MountAccessor, alias.MountType) } - // Check if an entity already exists for the given alias. - // We don't clone here to avoid unnecessary allocations - if we need to - // return early, we'll clone at that point. - entity, err = i.entityByAliasFactors(alias.MountAccessor, alias.Name, false) + // Fast path: only clone and return the entity when alias metadata is unchanged. + entity, err = i.entityByAliasFactorsIf(alias.MountAccessor, alias.Name, func(existingAlias *identity.Alias) bool { + return strutil.EqualStringMaps(existingAlias.Metadata, alias.Metadata) + }) if err != nil { return nil, false, err } - if entity != nil && changedAliasIndex(entity, alias) == -1 { - // Entity exists and no metadata changes - clone before returning - // to avoid exposing internal MemDB state to callers. - clonedEntity, err := entity.Clone() - if err != nil { - return nil, false, err - } - return clonedEntity, false, nil + if entity != nil { + return entity, false, nil } i.lock.Lock() diff --git a/vault/identity_store_test.go b/vault/identity_store_test.go index 5d5acd90ac..34d8d8af63 100644 --- a/vault/identity_store_test.go +++ b/vault/identity_store_test.go @@ -437,6 +437,49 @@ func TestIdentityStore_EntityByAliasFactors(t *testing.T) { } } +// TestIdentityStore_EntityByAliasFactorsInTxnIf verifies the predicate-gated +// lookup returns cloned entities on match and nil when the predicate rejects. +func TestIdentityStore_EntityByAliasFactorsInTxnIf(t *testing.T) { + ctx := namespace.RootContext(nil) + is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t) + + loginAlias := &logical.Alias{ + MountType: "github", + MountAccessor: ghAccessor, + Name: "githubuser", + Metadata: map[string]string{ + "foo": "a", + }, + } + + createdEntity, _, err := is.CreateOrFetchEntity(ctx, loginAlias) + require.NoError(t, err) + require.NotNil(t, createdEntity) + + txn := is.db.Txn(false) + defer txn.Abort() + + matchedEntity, err := is.entityByAliasFactorsInTxnIf(txn, ghAccessor, loginAlias.Name, func(existingAlias *identity.Alias) bool { + return maps.Equal(existingAlias.Metadata, loginAlias.Metadata) + }) + require.NoError(t, err) + require.NotNil(t, matchedEntity) + require.Equal(t, createdEntity.ID, matchedEntity.ID) + + // The conditional helper returns clones; mutating the result must not mutate MemDB state. + matchedEntity.Name = "mutated-name" + freshEntity, err := is.entityByAliasFactors(ghAccessor, loginAlias.Name, true) + require.NoError(t, err) + require.NotNil(t, freshEntity) + require.NotEqual(t, "mutated-name", freshEntity.Name) + + notMatchedEntity, err := is.entityByAliasFactorsInTxnIf(txn, ghAccessor, loginAlias.Name, func(*identity.Alias) bool { + return false + }) + require.NoError(t, err) + require.Nil(t, notMatchedEntity) +} + func TestIdentityStore_WrapInfoInheritance(t *testing.T) { var err error var resp *logical.Response diff --git a/vault/identity_store_util.go b/vault/identity_store_util.go index 04674a6559..4c7ef2f6d5 100644 --- a/vault/identity_store_util.go +++ b/vault/identity_store_util.go @@ -1901,6 +1901,35 @@ func (i *IdentityStore) MemDBEntityByAliasIDInTxn(txn *memdb.Txn, aliasID string return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, clone) } +// MemDBEntityByAliasIDInTxnClonePredicate fetches and clones an entity by alias +// ID only when shouldClone evaluates to true for the alias. +func (i *IdentityStore) MemDBEntityByAliasIDInTxnClonePredicate(txn *memdb.Txn, aliasID string, shouldClone func(*identity.Alias) bool) (*identity.Entity, error) { + var entity *identity.Entity + + if aliasID == "" { + return nil, fmt.Errorf("missing alias ID") + } + + if txn == nil { + return nil, fmt.Errorf("txn is nil") + } + + alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, false) + if err != nil { + return nil, err + } + + if alias == nil { + return entity, nil + } + + if shouldClone != nil && !shouldClone(alias) { + return entity, nil + } + + return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, true) +} + func (i *IdentityStore) MemDBEntityByAliasID(aliasID string, clone bool) (*identity.Entity, error) { if aliasID == "" { return nil, fmt.Errorf("missing alias ID")