diff --git a/audit/entry_formatter.go b/audit/entry_formatter.go index d89743c4e1..56611544ac 100644 --- a/audit/entry_formatter.go +++ b/audit/entry_formatter.go @@ -263,6 +263,7 @@ func mergeEnterpriseTokenMetadata(a *auth, req *logical.Request) error { if req.EnterpriseTokenMetadata == "" && req.EnterpriseTokenIssuer == "" && + req.EnterpriseTokenTransaction == "" && len(req.EnterpriseTokenAudience) == 0 && len(req.EnterpriseTokenAuthorizationDetails) == 0 { return nil @@ -277,6 +278,9 @@ func mergeEnterpriseTokenMetadata(a *auth, req *logical.Request) error { if req.EnterpriseTokenIssuer != "" { a.Metadata["enterprise_token_issuer"] = req.EnterpriseTokenIssuer } + if req.EnterpriseTokenTransaction != "" { + a.Metadata["enterprise_token_transaction"] = req.EnterpriseTokenTransaction + } if len(req.EnterpriseTokenAudience) > 0 { audJSON, err := json.Marshal(req.EnterpriseTokenAudience) if err != nil { diff --git a/audit/entry_formatter_test.go b/audit/entry_formatter_test.go index ea68d10ff9..f1832daff8 100644 --- a/audit/entry_formatter_test.go +++ b/audit/entry_formatter_test.go @@ -593,9 +593,10 @@ func TestMergeEnterpriseTokenMetadata(t *testing.T) { t.Parallel() requestTests := map[string]struct { - Input *logical.Request - ExpectedMetadata string - ExpectedIssuer string + Input *logical.Request + ExpectedMetadata string + ExpectedIssuer string + ExpectedTransaction string }{ "metadata-present": { Input: &logical.Request{ID: "req-1", EnterpriseTokenMetadata: "token-abc"}, @@ -614,6 +615,15 @@ func TestMergeEnterpriseTokenMetadata(t *testing.T) { ExpectedMetadata: "token-xyz", ExpectedIssuer: "https://issuer.example.com", }, + "transaction-present": { + Input: &logical.Request{ + ID: "req-4", + EnterpriseTokenMetadata: "token-txn", + EnterpriseTokenTransaction: "txn-123", + }, + ExpectedMetadata: "token-txn", + ExpectedTransaction: "txn-123", + }, } for name, tc := range requestTests { @@ -623,7 +633,7 @@ func TestMergeEnterpriseTokenMetadata(t *testing.T) { a := &auth{} err := mergeEnterpriseTokenMetadata(a, tc.Input) require.NoError(t, err) - if tc.ExpectedMetadata == "" && tc.ExpectedIssuer == "" { + if tc.ExpectedMetadata == "" && tc.ExpectedIssuer == "" && tc.ExpectedTransaction == "" { require.Nil(t, a.Metadata) } @@ -640,12 +650,14 @@ func TestMergeEnterpriseTokenMetadata(t *testing.T) { assertMetadataField("enterprise_token_metadata", tc.ExpectedMetadata) assertMetadataField("enterprise_token_issuer", tc.ExpectedIssuer) + assertMetadataField("enterprise_token_transaction", tc.ExpectedTransaction) }) } } // TestEntryFormatter_Process_JSON_EnterpriseToken verifies that enterprise token fields // (actor_entity_id, actor_entity_name, enterprise_token_metadata, enterprise_token_issuer, +// enterprise_token_transaction, // enterprise_token_audience, enterprise_token_authorization_details) are correctly // serialized into auth.metadata in the JSON audit output, and absent when not set. func TestEntryFormatter_Process_JSON_EnterpriseToken(t *testing.T) { @@ -664,6 +676,7 @@ func TestEntryFormatter_Process_JSON_EnterpriseToken(t *testing.T) { WantActorEntityName string WantMetadata string WantIssuer string + WantTransaction string WantAudience string WantAuthorizationDetails string }{ @@ -683,6 +696,7 @@ func TestEntryFormatter_Process_JSON_EnterpriseToken(t *testing.T) { Path: "/cubbyhole/test", EnterpriseTokenMetadata: "test-token-abc", EnterpriseTokenIssuer: "https://issuer.example.com", + EnterpriseTokenTransaction: "txn-actor-1", EnterpriseTokenAudience: []string{"vault"}, EnterpriseTokenAuthorizationDetails: authzDetails, Connection: &logical.Connection{ @@ -693,6 +707,7 @@ func TestEntryFormatter_Process_JSON_EnterpriseToken(t *testing.T) { WantActorEntityName: "actor-service", WantMetadata: "test-token-abc", WantIssuer: "https://issuer.example.com", + WantTransaction: "txn-actor-1", WantAudience: `["vault"]`, WantAuthorizationDetails: `[{"currency":"USD","type":"payment_initiation"}]`, }, @@ -706,18 +721,20 @@ func TestEntryFormatter_Process_JSON_EnterpriseToken(t *testing.T) { TokenType: logical.TokenTypeDefault, }, Req: &logical.Request{ - Operation: logical.ReadOperation, - Path: "/cubbyhole/test", - EnterpriseTokenMetadata: "test-token-xyz", - EnterpriseTokenIssuer: "https://issuer.example.com", - EnterpriseTokenAudience: []string{"vault"}, + Operation: logical.ReadOperation, + Path: "/cubbyhole/test", + EnterpriseTokenMetadata: "test-token-xyz", + EnterpriseTokenIssuer: "https://issuer.example.com", + EnterpriseTokenTransaction: "txn-base-1", + EnterpriseTokenAudience: []string{"vault"}, Connection: &logical.Connection{ RemoteAddr: "127.0.0.1", }, }, - WantMetadata: "test-token-xyz", - WantIssuer: "https://issuer.example.com", - WantAudience: `["vault"]`, + WantMetadata: "test-token-xyz", + WantIssuer: "https://issuer.example.com", + WantTransaction: "txn-base-1", + WantAudience: `["vault"]`, }, } @@ -766,6 +783,7 @@ func TestEntryFormatter_Process_JSON_EnterpriseToken(t *testing.T) { require.NotNil(t, result.Request) require.Equal(t, tc.WantMetadata, result.Auth.Metadata["enterprise_token_metadata"]) require.Equal(t, tc.WantIssuer, result.Auth.Metadata["enterprise_token_issuer"]) + require.Equal(t, tc.WantTransaction, result.Auth.Metadata["enterprise_token_transaction"]) require.Equal(t, tc.WantAudience, result.Auth.Metadata["enterprise_token_audience"]) require.Equal(t, tc.WantAuthorizationDetails, result.Auth.Metadata["enterprise_token_authorization_details"]) }) @@ -798,11 +816,12 @@ func TestEntryFormatter_Process_Response_EnterpriseToken(t *testing.T) { TokenType: logical.TokenTypeDefault, }, Request: &logical.Request{ - Operation: logical.ReadOperation, - Path: "/secret/data/test", - EnterpriseTokenMetadata: "resp-token-abc", - EnterpriseTokenIssuer: "https://issuer.example.com", - EnterpriseTokenAudience: []string{"vault", "api"}, + Operation: logical.ReadOperation, + Path: "/secret/data/test", + EnterpriseTokenMetadata: "resp-token-abc", + EnterpriseTokenIssuer: "https://issuer.example.com", + EnterpriseTokenTransaction: "txn-response-1", + EnterpriseTokenAudience: []string{"vault", "api"}, Connection: &logical.Connection{ RemoteAddr: "127.0.0.1", }, @@ -847,6 +866,7 @@ func TestEntryFormatter_Process_Response_EnterpriseToken(t *testing.T) { require.Equal(t, "actor-service", result.Auth.Metadata["actor_entity_name"]) require.Equal(t, "resp-token-abc", result.Auth.Metadata["enterprise_token_metadata"]) require.Equal(t, "https://issuer.example.com", result.Auth.Metadata["enterprise_token_issuer"]) + require.Equal(t, "txn-response-1", result.Auth.Metadata["enterprise_token_transaction"]) require.Equal(t, `["vault","api"]`, result.Auth.Metadata["enterprise_token_audience"]) // Response auth must also have enterprise token fields in metadata @@ -854,6 +874,7 @@ func TestEntryFormatter_Process_Response_EnterpriseToken(t *testing.T) { require.NotNil(t, result.Response.Auth) require.Equal(t, "resp-token-abc", result.Response.Auth.Metadata["enterprise_token_metadata"]) require.Equal(t, "https://issuer.example.com", result.Response.Auth.Metadata["enterprise_token_issuer"]) + require.Equal(t, "txn-response-1", result.Response.Auth.Metadata["enterprise_token_transaction"]) require.Equal(t, `["vault","api"]`, result.Response.Auth.Metadata["enterprise_token_audience"]) } @@ -888,6 +909,7 @@ func TestEntryFormatter_EnterpriseTokenFieldsNotOnRequestOrAuthTopLevel(t *testi Path: "/secret/data/test", EnterpriseTokenMetadata: "test-token-123", EnterpriseTokenIssuer: "https://issuer.example.com", + EnterpriseTokenTransaction: "txn-top-level-1", EnterpriseTokenAudience: []string{"vault"}, EnterpriseTokenAuthorizationDetails: []logical.AuthorizationDetail{{"type": "access"}}, Connection: &logical.Connection{ @@ -951,6 +973,10 @@ func TestEntryFormatter_EnterpriseTokenFieldsNotOnRequestOrAuthTopLevel(t *testi require.True(t, ok) require.Equal(t, "https://issuer.example.com", tokenIssuer) + tokenTransaction, ok := metadataMap["enterprise_token_transaction"] + require.True(t, ok) + require.Equal(t, "txn-top-level-1", tokenTransaction) + tokenAudience, ok := metadataMap["enterprise_token_audience"] require.True(t, ok) require.Equal(t, `["vault"]`, tokenAudience) diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index 092e453ee1..b8396a3e6f 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -442,9 +442,10 @@ func TestCopy_request_EnterpriseTokenFields(t *testing.T) { Data: map[string]interface{}{ "foo": "bar", }, - EnterpriseTokenMetadata: "test-token-abc", - EnterpriseTokenIssuer: "https://issuer.example.com", - EnterpriseTokenAudience: []string{"vault", "api"}, + EnterpriseTokenMetadata: "test-token-abc", + EnterpriseTokenIssuer: "https://issuer.example.com", + EnterpriseTokenTransaction: "txn-copy-1", + EnterpriseTokenAudience: []string{"vault", "api"}, EnterpriseTokenAuthorizationDetails: []logical.AuthorizationDetail{ { "type": "vault:path_access", @@ -476,10 +477,11 @@ func TestHashRequest_EnterpriseTokenFieldsInMetadata(t *testing.T) { auditAuth := &auth{ ClientToken: "secret-token", Metadata: map[string]string{ - "enterprise_token_metadata": "test-token-xyz", - "enterprise_token_issuer": "https://issuer.example.com", - "actor_entity_id": "actor-123", - "actor_entity_name": "actor-service", + "enterprise_token_metadata": "test-token-xyz", + "enterprise_token_issuer": "https://issuer.example.com", + "enterprise_token_transaction": "txn-hash-1", + "actor_entity_id": "actor-123", + "actor_entity_name": "actor-service", }, } @@ -494,6 +496,7 @@ func TestHashRequest_EnterpriseTokenFieldsInMetadata(t *testing.T) { // Metadata values must pass through unchanged — they are not secrets. require.Equal(t, "test-token-xyz", auditAuth.Metadata["enterprise_token_metadata"]) require.Equal(t, "https://issuer.example.com", auditAuth.Metadata["enterprise_token_issuer"]) + require.Equal(t, "txn-hash-1", auditAuth.Metadata["enterprise_token_transaction"]) require.Equal(t, "actor-123", auditAuth.Metadata["actor_entity_id"]) require.Equal(t, "actor-service", auditAuth.Metadata["actor_entity_name"]) } diff --git a/command/base_predict_test.go b/command/base_predict_test.go index 18c55eb6d9..27795ddfb5 100644 --- a/command/base_predict_test.go +++ b/command/base_predict_test.go @@ -455,7 +455,7 @@ func TestPredict_Policies(t *testing.T) { { "good_path", client, - []string{"default", "root"}, + []string{"default", "default-ceiling", "root"}, }, } diff --git a/command/policy_delete.go b/command/policy_delete.go index d3e241adc1..b7cb70c3ec 100644 --- a/command/policy_delete.go +++ b/command/policy_delete.go @@ -35,8 +35,8 @@ Usage: vault policy delete [options] NAME $ vault policy delete my-policy - Note that it is not possible to delete the "default" or "root" policies. - These are built-in policies. + Note that it is not possible to delete the "default", "default-ceiling", + or "root" policies. These are built-in policies. ` + c.Flags().Help() diff --git a/command/policy_delete_test.go b/command/policy_delete_test.go index 06068110b2..076e22a3ae 100644 --- a/command/policy_delete_test.go +++ b/command/policy_delete_test.go @@ -105,7 +105,7 @@ func TestPolicyDeleteCommand_Run(t *testing.T) { t.Fatal(err) } - list := []string{"default", "root"} + list := []string{"default", "default-ceiling", "root"} if !reflect.DeepEqual(policies, list) { t.Errorf("expected %q to be %q", policies, list) } diff --git a/command/policy_list_test.go b/command/policy_list_test.go index aec04c1b10..8a39cd7de6 100644 --- a/command/policy_list_test.go +++ b/command/policy_list_test.go @@ -80,7 +80,7 @@ func TestPolicyListCommand_Run(t *testing.T) { t.Errorf("expected %d to be %d", code, exp) } - expected := "default\nroot" + expected := "default\ndefault-ceiling\nroot" combined := ui.OutputWriter.String() + ui.ErrorWriter.String() if !strings.Contains(combined, expected) { t.Errorf("expected %q to contain %q", combined, expected) diff --git a/command/policy_write_test.go b/command/policy_write_test.go index 4a67f229a3..76b8997517 100644 --- a/command/policy_write_test.go +++ b/command/policy_write_test.go @@ -131,7 +131,7 @@ func TestPolicyWriteCommand_Run(t *testing.T) { t.Fatal(err) } - list := []string{"default", "my-policy", "root"} + list := []string{"default", "default-ceiling", "my-policy", "root"} if !reflect.DeepEqual(policies, list) { t.Errorf("expected %q to be %q", policies, list) } @@ -172,7 +172,7 @@ func TestPolicyWriteCommand_Run(t *testing.T) { t.Fatal(err) } - list := []string{"default", "my-policy", "root"} + list := []string{"default", "default-ceiling", "my-policy", "root"} if !reflect.DeepEqual(policies, list) { t.Errorf("expected %q to be %q", policies, list) } diff --git a/http/sys_policy_test.go b/http/sys_policy_test.go index c6924bc1ff..91880bb6a8 100644 --- a/http/sys_policy_test.go +++ b/http/sys_policy_test.go @@ -29,11 +29,11 @@ func TestSysPolicies(t *testing.T) { "auth": nil, "mount_type": "system", "data": map[string]interface{}{ - "policies": []interface{}{"default", "root"}, - "keys": []interface{}{"default", "root"}, + "policies": []interface{}{"default", "default-ceiling", "root"}, + "keys": []interface{}{"default", "default-ceiling", "root"}, }, - "policies": []interface{}{"default", "root"}, - "keys": []interface{}{"default", "root"}, + "policies": []interface{}{"default", "default-ceiling", "root"}, + "keys": []interface{}{"default", "default-ceiling", "root"}, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -98,11 +98,11 @@ func TestSysWritePolicy(t *testing.T) { "auth": nil, "mount_type": "system", "data": map[string]interface{}{ - "policies": []interface{}{"default", "foo", "root"}, - "keys": []interface{}{"default", "foo", "root"}, + "policies": []interface{}{"default", "default-ceiling", "foo", "root"}, + "keys": []interface{}{"default", "default-ceiling", "foo", "root"}, }, - "policies": []interface{}{"default", "foo", "root"}, - "keys": []interface{}{"default", "foo", "root"}, + "policies": []interface{}{"default", "default-ceiling", "foo", "root"}, + "keys": []interface{}{"default", "default-ceiling", "foo", "root"}, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) @@ -134,6 +134,7 @@ func TestSysDeletePolicy(t *testing.T) { // Also attempt to delete these since they should not be allowed (ignore // responses, if they exist later that's sufficient) resp = testHttpDelete(t, token, addr+"/v1/sys/policy/default") + resp = testHttpDelete(t, token, addr+"/v1/sys/policy/default-ceiling") resp = testHttpDelete(t, token, addr+"/v1/sys/policy/response-wrapping") resp = testHttpGet(t, token, addr+"/v1/sys/policy") @@ -148,11 +149,11 @@ func TestSysDeletePolicy(t *testing.T) { "auth": nil, "mount_type": "system", "data": map[string]interface{}{ - "policies": []interface{}{"default", "root"}, - "keys": []interface{}{"default", "root"}, + "policies": []interface{}{"default", "default-ceiling", "root"}, + "keys": []interface{}{"default", "default-ceiling", "root"}, }, - "policies": []interface{}{"default", "root"}, - "keys": []interface{}{"default", "root"}, + "policies": []interface{}{"default", "default-ceiling", "root"}, + "keys": []interface{}{"default", "default-ceiling", "root"}, } testResponseStatus(t, resp, 200) testResponseBody(t, resp, &actual) diff --git a/sdk/helper/locksutil/locks.go b/sdk/helper/locksutil/locks.go index 378cd8f5f0..1897daf96c 100644 --- a/sdk/helper/locksutil/locks.go +++ b/sdk/helper/locksutil/locks.go @@ -54,10 +54,19 @@ func LockIndexForKey(key string) uint8 { return uint8(cryptoutil.Blake2b256Hash(key)[0]) } +// LockForKey returns the striped lock entry for a key. +// Different logical keys can hash to the same underlying lock, so callers must +// not assume two keys imply two distinct RWMutexes. If a code path needs to +// lock more than one key, prefer LocksForKeys to deduplicate aliased stripes and +// avoid self-deadlocking by re-entering the same lock. func LockForKey(locks []*LockEntry, key string) *LockEntry { return locks[LockIndexForKey(key)] } +// LocksForKeys returns the unique striped lock entries for a set of keys in a +// stable slice order. Use this when a code path needs more than one keyed lock: +// it deduplicates keys that alias to the same stripe and supports consistent +// acquisition ordering across callers. func LocksForKeys(locks []*LockEntry, keys []string) []*LockEntry { lockIndexes := make(map[uint8]struct{}, len(keys)) for _, k := range keys { diff --git a/sdk/logical/request.go b/sdk/logical/request.go index 8184764f0f..c6155a3a9f 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -147,12 +147,20 @@ type Request struct { // EnterpriseTokenIssuer stores the enterprise token issuer. EnterpriseTokenIssuer string `json:"enterprise_token_issuer,omitempty" structs:"enterprise_token_issuer" mapstructure:"enterprise_token_issuer"` + // EnterpriseTokenTransaction stores the enterprise token transaction claim. + EnterpriseTokenTransaction string `json:"enterprise_token_transaction,omitempty" structs:"enterprise_token_transaction" mapstructure:"enterprise_token_transaction"` + // EnterpriseTokenAudience stores enterprise token audience values. EnterpriseTokenAudience []string `json:"enterprise_token_audience,omitempty" structs:"enterprise_token_audience" mapstructure:"enterprise_token_audience"` // EnterpriseTokenAuthorizationDetails stores enterprise token authorization details. EnterpriseTokenAuthorizationDetails []AuthorizationDetail `json:"enterprise_token_authorization_details,omitempty" structs:"enterprise_token_authorization_details" mapstructure:"enterprise_token_authorization_details"` + // EnterpriseTokenAuthorizationDetailsPresent indicates whether the inbound + // enterprise token included an authorization_details claim at all. This lets + // callers distinguish "claim missing" from "claim present but empty". + EnterpriseTokenAuthorizationDetailsPresent bool `json:"enterprise_token_authorization_details_present,omitempty" structs:"enterprise_token_authorization_details_present" mapstructure:"enterprise_token_authorization_details_present"` + // ClientTokenAccessor is provided to the core so that the it can get // logged as part of request audit logging. ClientTokenAccessor string `json:"client_token_accessor" structs:"client_token_accessor" mapstructure:"client_token_accessor" sentinel:""` @@ -279,8 +287,11 @@ type Request struct { // client token. ClientID string `json:"client_id" structs:"client_id" mapstructure:"client_id" sentinel:""` - // InboundSSCToken is the token that arrives on an inbound request, supplied - // by the vault user. + // InboundSSCToken stores the original token value as supplied by the caller + // on the inbound request (header/body), before token decoding or + // normalization (for example SSCT decoding or enterprise token normalization + // to internal IDs). This allows response/forwarding paths to preserve the + // caller-visible token representation when needed. InboundSSCToken string // When a request has been forwarded, contains information of the host the request was forwarded 'from' diff --git a/vault/core_metrics_test.go b/vault/core_metrics_test.go index fc9d31502d..9e9a1d9d32 100644 --- a/vault/core_metrics_test.go +++ b/vault/core_metrics_test.go @@ -413,8 +413,8 @@ func TestCoreMetrics_AvailablePolicies(t *testing.T) { }, }, ExpectedValues: map[string]float32{ - // The "default" policy will always be included - "acl": 2, + // The built-in ACL policies are always included. + "acl": 3, "egp": 0, "rgp": 0, }, @@ -429,8 +429,8 @@ func TestCoreMetrics_AvailablePolicies(t *testing.T) { }, }, ExpectedValues: map[string]float32{ - // The "default" policy will always be included - "acl": 3, + // The built-in ACL policies are always included. + "acl": 4, "egp": 0, "rgp": 0, }, diff --git a/vault/enterprise_token_lookup_ce.go b/vault/enterprise_token_lookup_ce.go new file mode 100644 index 0000000000..0bca3dd7f4 --- /dev/null +++ b/vault/enterprise_token_lookup_ce.go @@ -0,0 +1,12 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package vault + +import "errors" + +func resolveEnterpriseTokenIDForLookup(_ string) (string, error) { + return "", errors.New("enterprise build required") +} diff --git a/vault/expiration.go b/vault/expiration.go index 92efd1009f..b6e83c5ee0 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -1094,7 +1094,7 @@ func (m *ExpirationManager) revokeCommon(ctx context.Context, leaseID string, fo // Delete the secondary index, but only if it's a leased secret (not auth) if le.Secret != nil { var indexToken string - // Maintain secondary index by token, except for orphan batch tokens and ent tokens + // Maintain secondary index by token, except for orphan batch tokens and enterprise tokens switch le.ClientTokenType { case logical.TokenTypeBatch: te, err := m.tokenStore.lookupBatchTokenInternal(ctx, le.ClientToken) @@ -1558,6 +1558,9 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request // Register is used to take a request and response with an associated // lease. The secret gets assigned a LeaseID and the management of // the lease is assumed by the expiration manager. +// +// For enterprise tokens, Register uses the token entry ID for indexing and +// caps secret leases at token expiration, marking them non-renewable. func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response, loginRole string) (id string, retErr error) { defer metrics.MeasureSince([]string{"expire", "register"}, time.Now()) @@ -1606,6 +1609,9 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, namespace: ns, Version: 1, } + if te.Type == logical.TokenTypeEnt { + le.ClientToken = te.ID + } var indexToken string // Maintain secondary index by token, except for orphan batch tokens @@ -1658,14 +1664,19 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, le.ExpireTime = tokenLeaseTimes.ExpireTime } } - - // If the token is an ent token, derive TTL from the ent token if te.Type == logical.TokenTypeEnt { - entTokenExpireTime := deriveExpireTimeFromEntToken(te) - if !entTokenExpireTime.IsZero() && le.ExpireTime.After(entTokenExpireTime) { - // Use the ent token's expiration time for the lease - le.ExpireTime = entTokenExpireTime + tokenExpireTime := deriveExpireTimeFromEntToken(te) + if tokenExpireTime.IsZero() && te.TTL > 0 { + tokenExpireTime = time.Unix(te.CreationTime, 0).Add(te.TTL) } + if !tokenExpireTime.IsZero() && (le.ExpireTime.IsZero() || le.ExpireTime.After(tokenExpireTime)) { + le.ExpireTime = tokenExpireTime + le.Secret.TTL = le.ExpireTime.Sub(le.IssueTime) + if le.Secret.TTL < 0 { + le.Secret.TTL = 0 + } + } + le.Secret.Renewable = false } // Acquire the lock here so persistEntry and updatePending are atomic, @@ -2339,20 +2350,18 @@ func (m *ExpirationManager) deleteEntry(ctx context.Context, le *leaseEntry) err // createIndexByToken creates a secondary index from the token to a lease entry func (m *ExpirationManager) createIndexByToken(ctx context.Context, le *leaseEntry, token string) error { - var tokenNS *namespace.Namespace - var saltCtx context.Context - var err error - - if IsEnterpriseToken(token) { - // fetch the namespace from the lease rather than the req context to allow for cross namespace access - tokenNS, err = m.getNamespaceFromLeaseID(ctx, le.LeaseID) + tokenNS := namespace.RootNamespace + saltCtx := namespace.ContextWithNamespace(ctx, namespace.RootNamespace) + // For enterprise token IDs, derive namespace context from the lease rather than + // parsing token segments. + if IsEnterpriseTokenId(token) { + ns, err := m.getNamespaceFromLeaseID(ctx, le.LeaseID) if err != nil { return err } + tokenNS = ns saltCtx = namespace.ContextWithNamespace(ctx, tokenNS) } else { - tokenNS = namespace.RootNamespace - saltCtx = namespace.ContextWithNamespace(ctx, namespace.RootNamespace) _, nsID := namespace.SplitIDFromString(token) if nsID != "" { var err error @@ -2391,15 +2400,24 @@ func (m *ExpirationManager) createIndexByToken(ctx context.Context, le *leaseEnt func (m *ExpirationManager) indexByToken(ctx context.Context, le *leaseEntry) (*logical.StorageEntry, error) { tokenNS := namespace.RootNamespace saltCtx := namespace.ContextWithNamespace(ctx, tokenNS) - _, nsID := namespace.SplitIDFromString(le.ClientToken) - if nsID != "" { - var err error - tokenNS, err = NamespaceByID(ctx, nsID, m.core) + if IsEnterpriseTokenId(le.ClientToken) { + ns, err := m.getNamespaceFromLeaseID(ctx, le.LeaseID) if err != nil { return nil, err } - if tokenNS != nil { - saltCtx = namespace.ContextWithNamespace(ctx, tokenNS) + tokenNS = ns + saltCtx = namespace.ContextWithNamespace(ctx, tokenNS) + } else { + _, nsID := namespace.SplitIDFromString(le.ClientToken) + if nsID != "" { + var err error + tokenNS, err = NamespaceByID(ctx, nsID, m.core) + if err != nil { + return nil, err + } + if tokenNS != nil { + saltCtx = namespace.ContextWithNamespace(ctx, tokenNS) + } } } @@ -2424,20 +2442,16 @@ func (m *ExpirationManager) indexByToken(ctx context.Context, le *leaseEntry) (* // removeIndexByToken removes the secondary index from the token to a lease entry func (m *ExpirationManager) removeIndexByToken(ctx context.Context, le *leaseEntry, token string) error { - var tokenNS *namespace.Namespace - var saltCtx context.Context - var err error - - if IsEnterpriseToken(token) { - tokenNS, err = namespace.FromContext(ctx) + tokenNS := namespace.RootNamespace + saltCtx := namespace.ContextWithNamespace(ctx, namespace.RootNamespace) + if IsEnterpriseTokenId(token) { + ns, err := m.getNamespaceFromLeaseID(ctx, le.LeaseID) if err != nil { return err } - + tokenNS = ns saltCtx = namespace.ContextWithNamespace(ctx, tokenNS) } else { - tokenNS = namespace.RootNamespace - saltCtx = namespace.ContextWithNamespace(ctx, namespace.RootNamespace) _, nsID := namespace.SplitIDFromString(token) if nsID != "" { var err error @@ -2832,6 +2846,11 @@ func (m *ExpirationManager) markLeaseIrrevocable(ctx context.Context, le *leaseE m.nonexpiring.Delete(le.LeaseID) } +// getNamespaceFromLeaseID resolves the namespace encoded in a lease ID suffix. +// Lease IDs are generated by the expiration manager and include "." for +// non-root namespaces; root-namespace lease IDs have no namespace suffix. +// For persisted leaseEntry records, LeaseID is always set at creation, so +// namespace derivation from LeaseID is expected to be stable. func (m *ExpirationManager) getNamespaceFromLeaseID(ctx context.Context, leaseID string) (*namespace.Namespace, error) { _, nsID := namespace.SplitIDFromString(leaseID) diff --git a/vault/identity_store.go b/vault/identity_store.go index cd364b2a05..8377f502e5 100644 --- a/vault/identity_store.go +++ b/vault/identity_store.go @@ -67,25 +67,25 @@ func (i *IdentityStore) resetDB() error { func NewIdentityStore(ctx context.Context, core *Core, config *logical.BackendConfig, logger log.Logger) (*IdentityStore, error) { iStore := &IdentityStore{ - view: config.StorageView, - logger: logger, - router: core.router, - redirectAddr: core.redirectAddr, - localNode: core, - namespacer: core, - metrics: core.MetricSink(), - totpPersister: core, - groupUpdater: core, - tokenStorer: core, - entityCreator: core, - mountLister: core, - billingCounter: core, - mfaBackend: core.loginMFABackend, - aliasLocks: locksutil.CreateLocks(), - activationManager: core.FeatureActivationFlags, - activationErrorHandler: core, + view: config.StorageView, + logger: logger, + router: core.router, + redirectAddr: core.redirectAddr, + localNode: core, + namespacer: core, + metrics: core.MetricSink(), + totpPersister: core, + groupUpdater: core, + tokenStorer: core, + entityCreator: core, + mountLister: core, + billingCounter: core, + syntheticAliasAccessorValidator: core, + mfaBackend: core.loginMFABackend, + aliasLocks: locksutil.CreateLocks(), + activationManager: core.FeatureActivationFlags, + activationErrorHandler: core, } - // Create a memdb instance, which by default, operates on lower cased // identity names err := iStore.resetDB() diff --git a/vault/identity_store_aliases.go b/vault/identity_store_aliases.go index e2b7db272a..f47c54151c 100644 --- a/vault/identity_store_aliases.go +++ b/vault/identity_store_aliases.go @@ -136,6 +136,42 @@ This field is deprecated, use canonical_id.`, } } +// validateAliasMountAccessor validates mount_accessor values for entity aliases. +// +// It accepts either a real mounted backend accessor or a supported synthetic +// accessor validated by the synthetic alias accessor validator extension point. +// +// For mounted backend accessors, this returns the matched mount entry. For +// synthetic accessors, this returns a minimal entry carrying namespace/local +// semantics used by alias create/update checks. +func (i *IdentityStore) validateAliasMountAccessor(ctx context.Context, mountAccessor string) (*MountEntry, error) { + if mountAccessor == "" { + return nil, fmt.Errorf("invalid mount accessor %q", mountAccessor) + } + if mountEntry := i.router.MatchingMountByAccessor(mountAccessor); mountEntry != nil { + return mountEntry, nil + } + + if i.syntheticAliasAccessorValidator == nil { + i.logger.Error("synthetic alias accessor validator is not configured", "mount_accessor", mountAccessor) + return nil, fmt.Errorf("failed to validate mount accessor %q due to internal configuration error", mountAccessor) + } + + valid, err := i.syntheticAliasAccessorValidator.validateSyntheticAliasAccessor(ctx, mountAccessor) + if err != nil { + return nil, err + } + if !valid { + return nil, fmt.Errorf("invalid mount accessor %q", mountAccessor) + } + + ns, err := namespace.FromContext(ctx) + if err != nil { + return nil, err + } + return &MountEntry{NamespaceID: ns.ID}, nil +} + func aliasFieldSchema() map[string]*framework.FieldSchema { return map[string]*framework.FieldSchema{ "id": { @@ -279,20 +315,31 @@ func (i *IdentityStore) handleAliasCreateUpdate() framework.OperationFunc { } } + // If they didn't provide an ID or Mount Accessor, but provided an issuer, validate that the issuer has been + // registered. Return error if issuer has not been registered. + if mountAccessor == "" && issuer != "" { + // Generate synthetic Mount Accessor + syntheticAccessor, err := i.syntheticAliasAccessorValidator.generateSyntheticAliasAccessor(ctx, issuer) + if err != nil { + return logical.ErrorResponse(err.Error()), nil + } + mountAccessor = syntheticAccessor + } + // If they didn't provide an ID, we must have both accessor and name provided if mountAccessor == "" || name == "" { return logical.ErrorResponse("'id' or 'mount_accessor' and 'name' must be provided"), nil } - mountEntry := i.router.MatchingMountByAccessor(mountAccessor) - if mountEntry == nil { - return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil + mountEntry, err := i.validateAliasMountAccessor(ctx, mountAccessor) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } - if mountEntry.NamespaceID != ns.ID { + if mountEntry != nil && mountEntry.NamespaceID != ns.ID { return logical.ErrorResponse("matching mount is in a different namespace than request"), logical.ErrPermissionDenied } - localMount := mountEntry.Local + localMount := mountEntry != nil && mountEntry.Local // Look up the alias by factors; if it's found it's an update return i.handleAliasCreateUpdateCommon(ctx, ns, mountAccessor, name, canonicalID, externalID, issuer, customMetadata, localMount, "") @@ -497,11 +544,11 @@ func (i *IdentityStore) handleAliasUpdate(ctx context.Context, canonicalID, name !strutil.EqualStringMaps(customMetadata, alias.CustomMetadata) || issuer != alias.Issuer || externalID != alias.ExternalID { // Check here to see if such an alias already exists, if so bail - mountEntry := i.router.MatchingMountByAccessor(mountAccessor) - if mountEntry == nil { - return logical.ErrorResponse(fmt.Sprintf("invalid mount accessor %q", mountAccessor)), nil + mountEntry, err := i.validateAliasMountAccessor(ctx, mountAccessor) + if err != nil { + return logical.ErrorResponse(err.Error()), nil } - if mountEntry.NamespaceID != alias.NamespaceID { + if mountEntry != nil && mountEntry.NamespaceID != alias.NamespaceID { return logical.ErrorResponse("given mount accessor is not in the same namespace as the existing alias"), logical.ErrPermissionDenied } @@ -536,15 +583,16 @@ func (i *IdentityStore) handleAliasUpdate(ctx context.Context, canonicalID, name alias.CustomMetadata = customMetadata } - mountValidationResp := i.router.ValidateMountByAccessor(alias.MountAccessor) - if mountValidationResp == nil { - return nil, fmt.Errorf("invalid mount accessor %q", alias.MountAccessor) + mountEntry, err := i.validateAliasMountAccessor(ctx, alias.MountAccessor) + if err != nil { + return nil, err } + mountIsLocal := mountEntry != nil && mountEntry.Local newEntity := currentEntity if canonicalID != "" && canonicalID != alias.CanonicalID { // Don't allow moving local aliases between entities. - if mountValidationResp.MountLocal { + if mountIsLocal { return logical.ErrorResponse("local aliases can't be moved between entities"), nil } @@ -590,11 +638,11 @@ func (i *IdentityStore) handleAliasUpdate(ctx context.Context, canonicalID, name currentEntity = nil } - if mountValidationResp.MountLocal { + if mountIsLocal { alias, err = i.processLocalAlias(ctx, &logical.Alias{ MountAccessor: mountAccessor, Name: name, - Local: mountValidationResp.MountLocal, + Local: mountIsLocal, CustomMetadata: customMetadata, Issuer: issuer, ExternalID: externalID, diff --git a/vault/identity_store_aliases_stubs_oss.go b/vault/identity_store_aliases_stubs_oss.go new file mode 100644 index 0000000000..122bef6485 --- /dev/null +++ b/vault/identity_store_aliases_stubs_oss.go @@ -0,0 +1,16 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package vault + +import "context" + +func (c *Core) validateSyntheticAliasAccessor(context.Context, string) (bool, error) { + return false, nil +} + +func (c *Core) generateSyntheticAliasAccessor(context.Context, string) (string, error) { + return "", nil +} diff --git a/vault/identity_store_structs.go b/vault/identity_store_structs.go index 3edd659248..49f4359670 100644 --- a/vault/identity_store_structs.go +++ b/vault/identity_store_structs.go @@ -104,18 +104,19 @@ type IdentityStore struct { // operated case insensitively disableLowerCasedNames bool - router *Router - redirectAddr string - localNode LocalNode - namespacer Namespacer - metrics metricsutil.Metrics - totpPersister TOTPPersister - groupUpdater GroupUpdater - tokenStorer TokenStorer - entityCreator EntityCreator - mountLister MountLister - mfaBackend *LoginMFABackend - billingCounter BillingCounter + router *Router + redirectAddr string + localNode LocalNode + namespacer Namespacer + metrics metricsutil.Metrics + totpPersister TOTPPersister + groupUpdater GroupUpdater + tokenStorer TokenStorer + entityCreator EntityCreator + mountLister MountLister + syntheticAliasAccessorValidator SyntheticAliasAccessorValidator + mfaBackend *LoginMFABackend + billingCounter BillingCounter // aliasLocks is used to protect modifications to alias entries based on the uniqueness factor // which is name + accessor @@ -203,6 +204,13 @@ type MountLister interface { var _ MountLister = &Core{} +type SyntheticAliasAccessorValidator interface { + validateSyntheticAliasAccessor(context.Context, string) (bool, error) + generateSyntheticAliasAccessor(context.Context, string) (string, error) +} + +var _ SyntheticAliasAccessorValidator = &Core{} + type Sealer interface { Shutdown() error } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 60b3a0bd32..27f4bfeece 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -2722,8 +2722,8 @@ func TestSystemBackend_policyList(t *testing.T) { ) exp := map[string]interface{}{ - "keys": []string{"default", "root"}, - "policies": []string{"default", "root"}, + "keys": []string{"default", "default-ceiling", "root"}, + "policies": []string{"default", "default-ceiling", "root"}, } if !reflect.DeepEqual(resp.Data, exp) { t.Fatalf("got: %#v expect: %#v", resp.Data, exp) @@ -2801,8 +2801,8 @@ func TestSystemBackend_policyCRUD(t *testing.T) { } exp = map[string]interface{}{ - "keys": []string{"default", "foo", "root"}, - "policies": []string{"default", "foo", "root"}, + "keys": []string{"default", "default-ceiling", "foo", "root"}, + "policies": []string{"default", "default-ceiling", "foo", "root"}, } if !reflect.DeepEqual(resp.Data, exp) { t.Fatalf("got: %#v expect: %#v", resp.Data, exp) @@ -2844,8 +2844,8 @@ func TestSystemBackend_policyCRUD(t *testing.T) { } exp = map[string]interface{}{ - "keys": []string{"default", "root"}, - "policies": []string{"default", "root"}, + "keys": []string{"default", "default-ceiling", "root"}, + "policies": []string{"default", "default-ceiling", "root"}, } if !reflect.DeepEqual(resp.Data, exp) { t.Fatalf("got: %#v expect: %#v", resp.Data, exp) diff --git a/vault/policy_store.go b/vault/policy_store.go index 06e4dd5197..253ebe81a2 100644 --- a/vault/policy_store.go +++ b/vault/policy_store.go @@ -36,6 +36,9 @@ const ( // defaultPolicyName is the name of the default policy defaultPolicyName = "default" + // defaultCeilingPolicyName is the name of the default ceiling policy. + defaultCeilingPolicyName = "default-ceiling" + // responseWrappingPolicyName is the name of the fixed policy responseWrappingPolicyName = "response-wrapping" @@ -159,6 +162,22 @@ path "sys/control-group/request" { path "identity/oidc/provider/+/authorize" { capabilities = ["read", "update"] } +` + + // defaultCeilingPolicy is the default ceiling policy. + defaultCeilingPolicy = ` +# Allow an entity to inspect its own registration information +path "agent-registry/registration/entity_id/{{identity.entity.id}}" { + capabilities = ["read"] +} + +# Allow an entity to read the default policies +path "policy/default" { + capabilities = ["read"] +} +path "policy/default-ceiling" { + capabilities = ["read"] +} ` ) @@ -280,6 +299,10 @@ func (c *Core) setupPolicyStore(ctx context.Context) error { if err := c.policyStore.loadACLPolicy(ctx, defaultPolicyName, defaultPolicy); err != nil { return err } + // Ensure that the default ceiling policy exists, and if not, create it + if err := c.policyStore.loadACLPolicy(ctx, defaultCeilingPolicyName, defaultCeilingPolicy); err != nil { + return err + } // Ensure that the response wrapping policy exists if err := c.policyStore.loadACLPolicy(ctx, responseWrappingPolicyName, responseWrappingPolicy); err != nil { return err @@ -835,8 +858,8 @@ func (ps *PolicyStore) switchedDeletePolicy(ctx context.Context, name string, po if strutil.StrListContains(immutablePolicies, name) { return fmt.Errorf("cannot delete %q policy", name) } - if name == "default" { - return fmt.Errorf("cannot delete default policy") + if name == defaultPolicyName || name == defaultCeilingPolicyName { + return fmt.Errorf("cannot delete %s policy", name) } } diff --git a/vault/policy_store_test.go b/vault/policy_store_test.go index 1f7dfa88cb..1542d43eb0 100644 --- a/vault/policy_store_test.go +++ b/vault/policy_store_test.go @@ -117,7 +117,8 @@ func testPolicyStoreCRUD(t *testing.T, ps *PolicyStore, ns *namespace.Namespace) if err != nil { t.Fatalf("err: %v", err) } - if len(out) != 1 { + expected := []string{defaultPolicyName, defaultCeilingPolicyName} + if !reflect.DeepEqual(expected, out) { t.Fatalf("bad: %v", out) } @@ -139,17 +140,17 @@ func testPolicyStoreCRUD(t *testing.T, ps *PolicyStore, ns *namespace.Namespace) t.Fatalf("bad: %v", p) } - // List should contain two elements + // List should contain the two built-in assignable policies plus the new policy. ctx = namespace.ContextWithNamespace(context.Background(), ns) out, err = ps.ListPolicies(ctx, PolicyTypeACL) if err != nil { t.Fatalf("err: %v", err) } - if len(out) != 2 { + if len(out) != 3 { t.Fatalf("bad: %v", out) } - expected := []string{"default", "dev"} + expected = []string{defaultPolicyName, defaultCeilingPolicyName, "dev"} if !reflect.DeepEqual(expected, out) { t.Fatalf("expected: %v\ngot: %v", expected, out) } @@ -167,7 +168,8 @@ func testPolicyStoreCRUD(t *testing.T, ps *PolicyStore, ns *namespace.Namespace) if err != nil { t.Fatalf("err: %v", err) } - if len(out) != 1 || out[0] != "default" { + expected = []string{defaultPolicyName, defaultCeilingPolicyName} + if !reflect.DeepEqual(expected, out) { t.Fatalf("bad: %v", out) } @@ -191,17 +193,57 @@ func TestPolicyStore_Predefined(t *testing.T) { // Test predefined policy handling func testPolicyStorePredefined(t *testing.T, ps *PolicyStore, ns *namespace.Namespace) { - // List should be two elements + // List should contain the built-in assignable ACL policies. ctx := namespace.ContextWithNamespace(context.Background(), ns) out, err := ps.ListPolicies(ctx, PolicyTypeACL) if err != nil { t.Fatalf("err: %v", err) } - // This shouldn't contain response-wrapping since it's non-assignable - if len(out) != 1 || out[0] != "default" { + // This shouldn't contain response-wrapping since it's non-assignable. + expected := []string{defaultPolicyName, defaultCeilingPolicyName} + if !reflect.DeepEqual(expected, out) { t.Fatalf("bad: %v", out) } + ctx = namespace.ContextWithNamespace(context.Background(), ns) + pDefaultCeiling, err := ps.GetPolicy(ctx, defaultCeilingPolicyName, PolicyTypeToken) + if err != nil { + t.Fatalf("err: %v", err) + } + if pDefaultCeiling == nil { + t.Fatal("nil default ceiling policy") + } + if pDefaultCeiling.Raw != defaultCeilingPolicy { + t.Fatalf("bad: expected\n%s\ngot\n%s\n", defaultCeilingPolicy, pDefaultCeiling.Raw) + } + ctx = namespace.ContextWithNamespace(context.Background(), ns) + err = ps.DeletePolicy(ctx, pDefaultCeiling.Name, PolicyTypeACL) + if err == nil { + t.Fatalf("expected err deleting %s", pDefaultCeiling.Name) + } + + ctx = namespace.ContextWithNamespace(context.Background(), ns) + updatedDefaultCeiling, err := ParseACLPolicy(ns, aclPolicy) + if err != nil { + t.Fatalf("err: %v", err) + } + updatedDefaultCeiling.Name = defaultCeilingPolicyName + err = ps.SetPolicy(ctx, updatedDefaultCeiling) + if err != nil { + t.Fatalf("expected err to be nil updating %s: %v", updatedDefaultCeiling.Name, err) + } + ctx = namespace.ContextWithNamespace(context.Background(), ns) + pDefaultCeiling, err = ps.GetPolicy(ctx, defaultCeilingPolicyName, PolicyTypeToken) + if err != nil { + t.Fatalf("err: %v", err) + } + if pDefaultCeiling == nil { + t.Fatal("nil updated default ceiling policy") + } + if pDefaultCeiling.Raw != updatedDefaultCeiling.Raw { + t.Fatalf("bad: expected\n%s\ngot\n%s\n", updatedDefaultCeiling.Raw, pDefaultCeiling.Raw) + } + // Response-wrapping policy checks ctx = namespace.ContextWithNamespace(context.Background(), ns) pCubby, err := ps.GetPolicy(ctx, "response-wrapping", PolicyTypeToken) @@ -353,7 +395,7 @@ func TestPolicyStore_PoliciesByNamespaces(t *testing.T) { t.Fatalf("err: %v", err) } - expectedResult := []string{"default", "dev"} + expectedResult := []string{defaultPolicyName, defaultCeilingPolicyName, "dev"} if !reflect.DeepEqual(expectedResult, out) { t.Fatalf("expected: %v\ngot: %v", expectedResult, out) } diff --git a/vault/request_handling.go b/vault/request_handling.go index 8754eea141..a96d5d3d68 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -242,7 +242,7 @@ func (c *Core) fetchACLTokenEntryAndEntity(ctx context.Context, req *logical.Req var secondEntity *identity.Entity if IsEnterpriseToken(req.ClientToken) { - isValidEnterpriseToken, tokenMetadataContainer, entity, actorEntity, err := c.validateEnterpriseTokenAndFetchEntity(ctx, req.ClientToken) + isValidEnterpriseToken, tokenMetadataContainer, entity, actorEntity, chosenProfile, err := c.validateEnterpriseTokenAndFetchEntity(ctx, req.ClientToken) if err != nil { c.logger.Error("failed to validate enterprise token", "error", err) } @@ -251,10 +251,12 @@ func (c *Core) fetchACLTokenEntryAndEntity(ctx context.Context, req *logical.Req } req.EnterpriseTokenMetadata = getEnterpriseTokenMetadata(tokenMetadataContainer) req.EnterpriseTokenIssuer = getEnterpriseTokenIssuer(tokenMetadataContainer) + req.EnterpriseTokenTransaction = getEnterpriseTokenTransaction(tokenMetadataContainer) req.EnterpriseTokenAudience = getEnterpriseTokenAudience(tokenMetadataContainer) + _, req.EnterpriseTokenAuthorizationDetailsPresent = tokenMetadataContainer["authorization_details"] req.EnterpriseTokenAuthorizationDetails = getEnterpriseTokenAuthorizationDetails(tokenMetadataContainer) secondEntity = actorEntity - err = c.createAndStoreEnterpriseTokenEntry(ctx, req, tokenMetadataContainer, entity, actorEntity) + err = c.createAndStoreEnterpriseTokenEntry(ctx, req, tokenMetadataContainer, entity, actorEntity, chosenProfile) if err != nil { if c.perfStandby && errors.Is(err, logical.ErrReadOnly) { return nil, nil, nil, nil, logical.ErrPerfStandbyPleaseForward @@ -419,9 +421,19 @@ func (c *Core) fetchACLTokenEntryAndEntity(ctx context.Context, req *logical.Req } // restoreForwardingTokenHeaders restores client token headers so forwarded -// requests preserve the original auth source on the active node. +// requests preserve the caller's original token representation on the active +// node. It prefers Request.InboundSSCToken (captured before any token +// normalization) and falls back to Request.ClientToken when no inbound value is +// available. func restoreForwardingTokenHeaders(req *logical.Request) { - if req == nil || req.ClientToken == "" { + if req == nil { + return + } + tokenToForward := req.InboundSSCToken + if tokenToForward == "" { + tokenToForward = req.ClientToken + } + if tokenToForward == "" { return } if req.Headers == nil { @@ -429,9 +441,9 @@ func restoreForwardingTokenHeaders(req *logical.Request) { } switch req.ClientTokenSource { case logical.ClientTokenFromVaultHeader: - req.Headers[consts.AuthHeaderName] = []string{req.ClientToken} + req.Headers[consts.AuthHeaderName] = []string{tokenToForward} case logical.ClientTokenFromAuthzHeader: - req.Headers["Authorization"] = append(req.Headers["Authorization"], fmt.Sprintf("Bearer %s", req.ClientToken)) + req.Headers["Authorization"] = append(req.Headers["Authorization"], fmt.Sprintf("Bearer %s", tokenToForward)) } } @@ -675,12 +687,7 @@ func (c *Core) CheckToken(ctx context.Context, req *logical.Request, unauth bool // forward this request properly to the active node. if retErr.ErrorOrNil() != nil && checkErrControlGroupTokenNeedsCreated(retErr) && c.perfStandby && len(req.ClientToken) != 0 { - switch req.ClientTokenSource { - case logical.ClientTokenFromVaultHeader: - req.Headers[consts.AuthHeaderName] = []string{req.ClientToken} - case logical.ClientTokenFromAuthzHeader: - req.Headers["Authorization"] = append(req.Headers["Authorization"], fmt.Sprintf("Bearer %s", req.ClientToken)) - } + restoreForwardingTokenHeaders(req) // We also return the appropriate error so that the caller can forward the // request to the active node return auth, te, logical.ErrPerfStandbyPleaseForward diff --git a/vault/request_handling_ce.go b/vault/request_handling_ce.go index 80c57adda8..487740fa98 100644 --- a/vault/request_handling_ce.go +++ b/vault/request_handling_ce.go @@ -13,11 +13,13 @@ import ( "github.com/hashicorp/vault/sdk/logical" ) -func (c *Core) validateEnterpriseTokenAndFetchEntity(ctx context.Context, tokenString string) (bool, map[string]interface{}, *identity.Entity, *identity.Entity, error) { - return false, nil, nil, nil, errors.New("not implemented") +type OAuthResourceServerConfigProfile struct{} + +func (c *Core) validateEnterpriseTokenAndFetchEntity(ctx context.Context, tokenString string) (bool, map[string]interface{}, *identity.Entity, *identity.Entity, *OAuthResourceServerConfigProfile, error) { + return false, nil, nil, nil, nil, errors.New("not implemented") } -func (c *Core) createAndStoreEnterpriseTokenEntry(ctx context.Context, req *logical.Request, allClaims map[string]interface{}, entity *identity.Entity, actorEntity *identity.Entity) error { +func (c *Core) createAndStoreEnterpriseTokenEntry(ctx context.Context, req *logical.Request, allClaims map[string]interface{}, entity *identity.Entity, actorEntity *identity.Entity, chosenProfile *OAuthResourceServerConfigProfile) error { return nil } @@ -33,6 +35,10 @@ func getEnterpriseTokenIssuer(_ map[string]interface{}) string { return "" } +func getEnterpriseTokenTransaction(_ map[string]interface{}) string { + return "" +} + func getEnterpriseTokenAudience(_ map[string]interface{}) []string { return nil } diff --git a/vault/request_handling_test.go b/vault/request_handling_test.go index 3972378640..55362fd69d 100644 --- a/vault/request_handling_test.go +++ b/vault/request_handling_test.go @@ -54,6 +54,56 @@ func TestRequiresMaterializedTokenState(t *testing.T) { } } +// TestRestoreForwardingTokenHeaders_UsesInboundToken verifies Authorization +// forwarding prefers the original inbound token when present. +func TestRestoreForwardingTokenHeaders_UsesInboundToken(t *testing.T) { + t.Parallel() + + req := &logical.Request{ + ClientToken: "jwt.internal-id", + InboundSSCToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload.sig", + ClientTokenSource: logical.ClientTokenFromAuthzHeader, + Headers: map[string][]string{ + "Authorization": {"Basic abc123"}, + }, + } + + restoreForwardingTokenHeaders(req) + + require.Equal(t, []string{"Basic abc123", "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload.sig"}, req.Headers["Authorization"]) +} + +// TestRestoreForwardingTokenHeaders_FallsBackToClientToken verifies fallback to +// req.ClientToken when no inbound token is present. +func TestRestoreForwardingTokenHeaders_FallsBackToClientToken(t *testing.T) { + t.Parallel() + + req := &logical.Request{ + ClientToken: "jwt.jti-value", + ClientTokenSource: logical.ClientTokenFromVaultHeader, + } + + restoreForwardingTokenHeaders(req) + + require.Equal(t, []string{"jwt.jti-value"}, req.Headers["X-Vault-Token"]) +} + +// TestRestoreForwardingTokenHeaders_UsesInboundTokenForVaultHeader verifies +// X-Vault-Token forwarding prefers the original inbound token. +func TestRestoreForwardingTokenHeaders_UsesInboundTokenForVaultHeader(t *testing.T) { + t.Parallel() + + req := &logical.Request{ + ClientToken: "jwt.jti-value", + InboundSSCToken: "jwt.raw.value", + ClientTokenSource: logical.ClientTokenFromVaultHeader, + } + + restoreForwardingTokenHeaders(req) + + require.Equal(t, []string{"jwt.raw.value"}, req.Headers["X-Vault-Token"]) +} + func TestRequestHandling_Wrapping(t *testing.T) { core, _, root := TestCoreUnsealed(t) diff --git a/vault/token_store.go b/vault/token_store.go index fbb1106f97..93ffd8496a 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -1518,17 +1518,16 @@ func (ts *TokenStore) Lookup(ctx context.Context, id string) (*logical.TokenEntr if id == "" { return nil, fmt.Errorf("cannot lookup blank token") } + normalizedID := normalizeEnterpriseTokenToID(id) // If it starts with "b." it's a batch token - if IsBatchToken(id) { - return ts.lookupBatchToken(ctx, id) + if IsBatchToken(normalizedID) { + return ts.lookupBatchToken(ctx, normalizedID) } - - lock := locksutil.LockForKey(ts.tokenLocks, id) + lock := locksutil.LockForKey(ts.tokenLocks, normalizedID) lock.RLock() defer lock.RUnlock() - - return ts.lookupInternal(ctx, id, false, false) + return ts.lookupInternal(ctx, normalizedID, false, false) } func (ts *TokenStore) stripBatchPrefix(id string) string { @@ -2684,7 +2683,8 @@ func (ts *TokenStore) handleCreate(ctx context.Context, req *logical.Request, d // handleCreateCommon handles the auth/token/create path for creation of new tokens func (ts *TokenStore) handleCreateCommon(ctx context.Context, req *logical.Request, d *framework.FieldData, orphan bool, role *tsRoleEntry) (*logical.Response, error) { - if !orphan && IsEnterpriseToken(req.ClientToken) { + normalizedClientToken := normalizeEnterpriseTokenToID(req.ClientToken) + if !orphan && IsEnterpriseTokenId(normalizedClientToken) { return logical.ErrorResponse("enterprise tokens cannot create child tokens"), logical.ErrInvalidRequest } @@ -3355,7 +3355,8 @@ func (ts *TokenStore) handleRevokeTree(ctx context.Context, req *logical.Request } func (ts *TokenStore) revokeCommon(ctx context.Context, req *logical.Request, data *framework.FieldData, id string) (*logical.Response, error) { - if IsEnterpriseToken(id) { + normalizedID := normalizeEnterpriseTokenToID(id) + if IsEnterpriseTokenId(normalizedID) { return logical.ErrorResponse("cannot revoke ent token"), nil } te, err := ts.Lookup(ctx, id) @@ -3402,7 +3403,8 @@ func (ts *TokenStore) handleRevokeOrphan(ctx context.Context, req *logical.Reque return logical.ErrorResponse("missing token ID"), logical.ErrInvalidRequest } - if IsEnterpriseToken(id) { + normalizedID := normalizeEnterpriseTokenToID(id) + if IsEnterpriseTokenId(normalizedID) { return logical.ErrorResponse("enterprise token cannot be revoked"), nil } @@ -3444,7 +3446,19 @@ func (ts *TokenStore) handleLookup(ctx context.Context, req *logical.Request, da return logical.ErrorResponse("missing token ID"), logical.ErrInvalidRequest } if IsEnterpriseToken(id) { - id = getEnterpriseTokenId(req.EnterpriseTokenMetadata) + // If the token specified in the request body is different from the caller's + // token, resolve the token ID based on the body token's claims (JTI) instead + // of req.EnterpriseTokenMetadata, otherwise we may silently return the caller's + // own token entry or fail for non-Enterprise token callers. + if id == req.ClientToken { + id = getEnterpriseTokenId(req.EnterpriseTokenMetadata) + } else { + resolvedID, err := resolveEnterpriseTokenIDForLookup(id) + if err != nil { + return logical.ErrorResponse("invalid token"), logical.ErrInvalidRequest + } + id = resolvedID + } } lock := locksutil.LockForKey(ts.tokenLocks, id) lock.RLock() @@ -3557,7 +3571,8 @@ func (ts *TokenStore) handleRenew(ctx context.Context, req *logical.Request, dat if id == "" { return logical.ErrorResponse("missing token ID"), logical.ErrInvalidRequest } - if IsEnterpriseToken(id) { + normalizedID := normalizeEnterpriseTokenToID(id) + if IsEnterpriseTokenId(normalizedID) { return logical.ErrorResponse("enterprise tokens cannot be renewed"), nil } incrementRaw := data.Get("increment").(int) diff --git a/vault/token_store_ce.go b/vault/token_store_ce.go index 8a60254e70..8685c7dbda 100644 --- a/vault/token_store_ce.go +++ b/vault/token_store_ce.go @@ -16,6 +16,10 @@ func getEnterpriseTokenId(_ string) string { return "" } +func normalizeEnterpriseTokenToID(token string) string { + return token +} + func (ts *TokenStore) handleTidyEnterpriseTokens(_ context.Context, _ *namespace.Namespace, _ *multierror.Error) error { return nil } diff --git a/vault/version_store_ce.go b/vault/version_store_ce.go index 04baeb05d8..2e0d8e3cd6 100644 --- a/vault/version_store_ce.go +++ b/vault/version_store_ce.go @@ -8,3 +8,7 @@ package vault func IsEnterpriseToken(token string) bool { return false } + +func IsEnterpriseTokenId(tokenID string) bool { + return false +}