mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-05 04:16:31 +02:00
Merge remote-tracking branch 'remotes/from/ce/main'
This commit is contained in:
commit
63f4ea79a0
3
changelog/_14050.txt
Normal file
3
changelog/_14050.txt
Normal file
@ -0,0 +1,3 @@
|
||||
```release-note:bug
|
||||
identity: fixed a rare but possible data race issue with identities.
|
||||
```
|
||||
@ -4,7 +4,10 @@
|
||||
package identity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"maps"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
@ -16,10 +19,35 @@ import (
|
||||
ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/minimal"
|
||||
"github.com/hashicorp/vault/sdk/helper/ldaputil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testIdentityStoreWithGithubUserpassAuth(ctx context.Context, t *testing.T) (*vault.IdentityStore, string, string, *vault.TestClusterCore) {
|
||||
t.Helper()
|
||||
|
||||
cluster := minimal.NewTestSoloCluster(t, nil)
|
||||
client := cluster.Cores[0].Client
|
||||
|
||||
err := client.Sys().EnableAuthWithOptions("github", &api.EnableAuthOptions{Type: "github"})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{Type: "userpass"})
|
||||
require.NoError(t, err)
|
||||
|
||||
auth, err := client.Sys().ListAuth()
|
||||
require.NoError(t, err)
|
||||
|
||||
githubAccessor := auth["github/"].Accessor
|
||||
require.NotEmpty(t, githubAccessor)
|
||||
|
||||
userpassAccessor := auth["userpass/"].Accessor
|
||||
require.NotEmpty(t, userpassAccessor)
|
||||
|
||||
return cluster.Cores[0].IdentityStore(), githubAccessor, userpassAccessor, cluster.Cores[0]
|
||||
}
|
||||
|
||||
func TestIdentityStore_ExternalGroupMemberships_DifferentMounts(t *testing.T) {
|
||||
t.Parallel()
|
||||
cluster := minimal.NewTestSoloCluster(t, nil)
|
||||
@ -665,3 +693,195 @@ func findEntityFromDuplicateSet(t *testing.T, c *vault.TestClusterCore, entityID
|
||||
"node %s does not have exactly one duplicate from the set", c.NodeID)
|
||||
return entity
|
||||
}
|
||||
|
||||
// TestIdentityStore_CreateOrFetchEntity_ConcurrentFastPath verifies concurrent
|
||||
// callers on the fast path all read the existing entity without creating new
|
||||
// entities or mutating alias metadata.
|
||||
func TestIdentityStore_CreateOrFetchEntity_ConcurrentFastPath(t *testing.T) {
|
||||
ctx := namespace.RootContext(nil)
|
||||
is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t)
|
||||
|
||||
alias := &logical.Alias{
|
||||
MountType: "github",
|
||||
MountAccessor: ghAccessor,
|
||||
Name: "githubuser",
|
||||
Metadata: map[string]string{
|
||||
"foo": "a",
|
||||
},
|
||||
}
|
||||
|
||||
entity, _, err := is.CreateOrFetchEntity(ctx, alias)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, entity)
|
||||
|
||||
const workers = 16
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, workers)
|
||||
|
||||
for range workers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
got, created, err := is.CreateOrFetchEntity(ctx, alias)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if created {
|
||||
errCh <- fmt.Errorf("unexpected entity creation on fast path")
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
errCh <- fmt.Errorf("expected entity on fast path")
|
||||
return
|
||||
}
|
||||
if len(got.Aliases) != 1 {
|
||||
errCh <- fmt.Errorf("expected 1 alias, got %d", len(got.Aliases))
|
||||
return
|
||||
}
|
||||
if !maps.Equal(got.Aliases[0].Metadata, alias.Metadata) {
|
||||
errCh <- fmt.Errorf("unexpected alias metadata: %#v", got.Aliases[0].Metadata)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIdentityStore_CreateOrFetchEntity_ConcurrentMetadataUpdates verifies
|
||||
// concurrent readers and metadata updates for the same alias remain stable and
|
||||
// return a single-entity view.
|
||||
func TestIdentityStore_CreateOrFetchEntity_ConcurrentMetadataUpdates(t *testing.T) {
|
||||
ctx := namespace.RootContext(nil)
|
||||
is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t)
|
||||
|
||||
newAlias := func(metadataValue string) *logical.Alias {
|
||||
return &logical.Alias{
|
||||
MountType: "github",
|
||||
MountAccessor: ghAccessor,
|
||||
Name: "githubuser",
|
||||
Metadata: map[string]string{
|
||||
"foo": metadataValue,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
entity, _, err := is.CreateOrFetchEntity(ctx, newAlias("a"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, entity)
|
||||
|
||||
const workers = 8
|
||||
const iterations = 100
|
||||
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, workers*2)
|
||||
|
||||
for range workers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
|
||||
for range iterations {
|
||||
got, _, err := is.CreateOrFetchEntity(ctx, newAlias("a"))
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
errCh <- fmt.Errorf("expected entity from read path")
|
||||
return
|
||||
}
|
||||
if len(got.Aliases) != 1 {
|
||||
errCh <- fmt.Errorf("expected 1 alias, got %d", len(got.Aliases))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
|
||||
for i := range iterations {
|
||||
metadataValue := "b"
|
||||
if i%2 == 0 {
|
||||
metadataValue = "a"
|
||||
}
|
||||
|
||||
got, _, err := is.CreateOrFetchEntity(ctx, newAlias(metadataValue))
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
errCh <- fmt.Errorf("expected entity from update path")
|
||||
return
|
||||
}
|
||||
if len(got.Aliases) != 1 {
|
||||
errCh <- fmt.Errorf("expected 1 alias, got %d", len(got.Aliases))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIdentityStore_EntityByAliasFactors_BlackBoxCloneBehavior verifies
|
||||
// alias-based lookup and clone behavior using only exported identity store APIs.
|
||||
func TestIdentityStore_EntityByAliasFactors_BlackBoxCloneBehavior(t *testing.T) {
|
||||
ctx := namespace.RootContext(nil)
|
||||
is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t)
|
||||
|
||||
loginAlias := &logical.Alias{
|
||||
MountType: "github",
|
||||
MountAccessor: ghAccessor,
|
||||
Name: "githubuser",
|
||||
Metadata: map[string]string{
|
||||
"foo": "a",
|
||||
},
|
||||
}
|
||||
|
||||
createdEntity, _, err := is.CreateOrFetchEntity(ctx, loginAlias)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdEntity)
|
||||
|
||||
matchedAlias, err := is.MemDBAliasByFactors(ghAccessor, loginAlias.Name, true, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, matchedAlias)
|
||||
require.True(t, maps.Equal(matchedAlias.Metadata, loginAlias.Metadata))
|
||||
|
||||
matchedEntity, err := is.MemDBEntityByID(matchedAlias.CanonicalID, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, matchedEntity)
|
||||
require.Equal(t, createdEntity.ID, matchedEntity.ID)
|
||||
|
||||
// Clone reads should not mutate MemDB state.
|
||||
matchedEntity.Name = "mutated-name"
|
||||
freshAlias, err := is.MemDBAliasByFactors(ghAccessor, loginAlias.Name, true, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, freshAlias)
|
||||
|
||||
freshEntity, err := is.MemDBEntityByID(freshAlias.CanonicalID, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, freshEntity)
|
||||
require.NotEqual(t, "mutated-name", freshEntity.Name)
|
||||
|
||||
require.False(t, maps.Equal(freshAlias.Metadata, map[string]string{"foo": "does-not-match"}))
|
||||
}
|
||||
|
||||
@ -1301,6 +1301,51 @@ func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor,
|
||||
return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone)
|
||||
}
|
||||
|
||||
// entityByAliasFactorsIf fetches and clones an entity by alias factors only when
|
||||
// shouldReturn evaluates to true for the matching alias.
|
||||
func (i *IdentityStore) entityByAliasFactorsIf(mountAccessor, aliasName string, shouldReturn func(*identity.Alias) bool) (*identity.Entity, error) {
|
||||
if mountAccessor == "" {
|
||||
return nil, fmt.Errorf("missing mount accessor")
|
||||
}
|
||||
|
||||
if aliasName == "" {
|
||||
return nil, fmt.Errorf("missing alias name")
|
||||
}
|
||||
|
||||
txn := i.db.Txn(false)
|
||||
|
||||
return i.entityByAliasFactorsInTxnIf(txn, mountAccessor, aliasName, shouldReturn)
|
||||
}
|
||||
|
||||
// entityByAliasFactorsInTxnIf fetches and clones an entity by alias factors only
|
||||
// when shouldReturn evaluates to true for the matching alias.
|
||||
func (i *IdentityStore) entityByAliasFactorsInTxnIf(txn *memdb.Txn, mountAccessor, aliasName string, shouldReturn func(*identity.Alias) bool) (*identity.Entity, error) {
|
||||
var entity *identity.Entity
|
||||
|
||||
if txn == nil {
|
||||
return nil, fmt.Errorf("nil txn")
|
||||
}
|
||||
|
||||
if mountAccessor == "" {
|
||||
return nil, fmt.Errorf("missing mount accessor")
|
||||
}
|
||||
|
||||
if aliasName == "" {
|
||||
return nil, fmt.Errorf("missing alias name")
|
||||
}
|
||||
|
||||
alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if alias == nil {
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
return i.MemDBEntityByAliasIDInTxnClonePredicate(txn, alias.ID, shouldReturn)
|
||||
}
|
||||
|
||||
// CreateEntity creates a new entity.
|
||||
func (i *IdentityStore) CreateEntity(ctx context.Context) (*identity.Entity, error) {
|
||||
defer metrics.MeasureSince([]string{"identity", "create_entity"}, time.Now())
|
||||
@ -1359,21 +1404,15 @@ func (i *IdentityStore) CreateOrFetchEntity(ctx context.Context, alias *logical.
|
||||
return nil, false, fmt.Errorf("mount accessor %q is not a mount of type %q", alias.MountAccessor, alias.MountType)
|
||||
}
|
||||
|
||||
// Check if an entity already exists for the given alias.
|
||||
// We don't clone here to avoid unnecessary allocations - if we need to
|
||||
// return early, we'll clone at that point.
|
||||
entity, err = i.entityByAliasFactors(alias.MountAccessor, alias.Name, false)
|
||||
// Fast path: only clone and return the entity when alias metadata is unchanged.
|
||||
entity, err = i.entityByAliasFactorsIf(alias.MountAccessor, alias.Name, func(existingAlias *identity.Alias) bool {
|
||||
return strutil.EqualStringMaps(existingAlias.Metadata, alias.Metadata)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if entity != nil && changedAliasIndex(entity, alias) == -1 {
|
||||
// Entity exists and no metadata changes - clone before returning
|
||||
// to avoid exposing internal MemDB state to callers.
|
||||
clonedEntity, err := entity.Clone()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return clonedEntity, false, nil
|
||||
if entity != nil {
|
||||
return entity, false, nil
|
||||
}
|
||||
|
||||
i.lock.Lock()
|
||||
|
||||
@ -437,6 +437,49 @@ func TestIdentityStore_EntityByAliasFactors(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIdentityStore_EntityByAliasFactorsInTxnIf verifies the predicate-gated
|
||||
// lookup returns cloned entities on match and nil when the predicate rejects.
|
||||
func TestIdentityStore_EntityByAliasFactorsInTxnIf(t *testing.T) {
|
||||
ctx := namespace.RootContext(nil)
|
||||
is, ghAccessor, _, _ := testIdentityStoreWithGithubUserpassAuth(ctx, t)
|
||||
|
||||
loginAlias := &logical.Alias{
|
||||
MountType: "github",
|
||||
MountAccessor: ghAccessor,
|
||||
Name: "githubuser",
|
||||
Metadata: map[string]string{
|
||||
"foo": "a",
|
||||
},
|
||||
}
|
||||
|
||||
createdEntity, _, err := is.CreateOrFetchEntity(ctx, loginAlias)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdEntity)
|
||||
|
||||
txn := is.db.Txn(false)
|
||||
defer txn.Abort()
|
||||
|
||||
matchedEntity, err := is.entityByAliasFactorsInTxnIf(txn, ghAccessor, loginAlias.Name, func(existingAlias *identity.Alias) bool {
|
||||
return maps.Equal(existingAlias.Metadata, loginAlias.Metadata)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, matchedEntity)
|
||||
require.Equal(t, createdEntity.ID, matchedEntity.ID)
|
||||
|
||||
// The conditional helper returns clones; mutating the result must not mutate MemDB state.
|
||||
matchedEntity.Name = "mutated-name"
|
||||
freshEntity, err := is.entityByAliasFactors(ghAccessor, loginAlias.Name, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, freshEntity)
|
||||
require.NotEqual(t, "mutated-name", freshEntity.Name)
|
||||
|
||||
notMatchedEntity, err := is.entityByAliasFactorsInTxnIf(txn, ghAccessor, loginAlias.Name, func(*identity.Alias) bool {
|
||||
return false
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, notMatchedEntity)
|
||||
}
|
||||
|
||||
func TestIdentityStore_WrapInfoInheritance(t *testing.T) {
|
||||
var err error
|
||||
var resp *logical.Response
|
||||
|
||||
@ -1901,6 +1901,35 @@ func (i *IdentityStore) MemDBEntityByAliasIDInTxn(txn *memdb.Txn, aliasID string
|
||||
return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, clone)
|
||||
}
|
||||
|
||||
// MemDBEntityByAliasIDInTxnClonePredicate fetches and clones an entity by alias
|
||||
// ID only when shouldClone evaluates to true for the alias.
|
||||
func (i *IdentityStore) MemDBEntityByAliasIDInTxnClonePredicate(txn *memdb.Txn, aliasID string, shouldClone func(*identity.Alias) bool) (*identity.Entity, error) {
|
||||
var entity *identity.Entity
|
||||
|
||||
if aliasID == "" {
|
||||
return nil, fmt.Errorf("missing alias ID")
|
||||
}
|
||||
|
||||
if txn == nil {
|
||||
return nil, fmt.Errorf("txn is nil")
|
||||
}
|
||||
|
||||
alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if alias == nil {
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
if shouldClone != nil && !shouldClone(alias) {
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, true)
|
||||
}
|
||||
|
||||
func (i *IdentityStore) MemDBEntityByAliasID(aliasID string, clone bool) (*identity.Entity, error) {
|
||||
if aliasID == "" {
|
||||
return nil, fmt.Errorf("missing alias ID")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user