mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-12 16:16:47 +02:00
Merge remote-tracking branch 'remotes/from/ce/release/2.x.x' into release/2.x.x
This commit is contained in:
commit
0aab908fb9
3
changelog/_14536.txt
Normal file
3
changelog/_14536.txt
Normal file
@ -0,0 +1,3 @@
|
||||
```release-note:improvement
|
||||
consumption-billing: Added consumption billing metrics for OIDC tokens.
|
||||
```
|
||||
@ -5,10 +5,18 @@ package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/api"
|
||||
"github.com/hashicorp/vault/builtin/credential/userpass"
|
||||
"github.com/hashicorp/vault/helper/namespace"
|
||||
vaulthttp "github.com/hashicorp/vault/http"
|
||||
"github.com/hashicorp/vault/internalshared/configutil"
|
||||
"github.com/hashicorp/vault/sdk/logical"
|
||||
"github.com/hashicorp/vault/vault"
|
||||
"github.com/hashicorp/vault/vault/billing"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -104,3 +112,193 @@ func TestGcpKmsDataProtectionCallCounts(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint64(5), counts)
|
||||
}
|
||||
|
||||
// TestOidcTokenBillingBothMethods tests OIDC token billing for both token creation methods:
|
||||
// 1. Simple role-based tokens via identity/oidc/token/{role}
|
||||
// 2. Provider-based tokens via the full authorization code flow (pathOIDCToken)
|
||||
// This test runs on a single primary cluster and verifies that both methods correctly
|
||||
// track duration-adjusted billing counts.
|
||||
func TestOidcTokenBillingBothMethods(t *testing.T) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
CredentialBackends: map[string]logical.Factory{
|
||||
"userpass": userpass.Factory,
|
||||
},
|
||||
BillingConfig: billing.BillingConfig{
|
||||
MetricsUpdateCadence: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
clusterOpts := &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
DefaultHandlerProperties: vault.HandlerProperties{
|
||||
ListenerConfig: &configutil.Listener{},
|
||||
},
|
||||
NumCores: 1,
|
||||
}
|
||||
cluster := vault.NewTestCluster(t, coreConfig, clusterOpts)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
client := cluster.Cores[0].Client
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a policy that allows reading OIDC tokens
|
||||
oidcPolicy := `path "identity/oidc/token/*" { capabilities = ["read"] }`
|
||||
_, err := client.Logical().Write("sys/policy/oidc-reader", map[string]interface{}{
|
||||
"policy": oidcPolicy,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Enable userpass for entity creation
|
||||
err = client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{
|
||||
Type: "userpass",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a userpass user with the OIDC reader policy
|
||||
_, err = client.Logical().Write("auth/userpass/users/testuser", map[string]interface{}{
|
||||
"password": "testpass",
|
||||
"policies": "oidc-reader",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Login to create entity
|
||||
loginResp, err := client.Logical().Write("auth/userpass/login/testuser", map[string]interface{}{
|
||||
"password": "testpass",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
userToken := loginResp.Auth.ClientToken
|
||||
|
||||
// METHOD 1: Configure simple role-based OIDC tokens (identity/oidc/token/{role})
|
||||
// Create OIDC key
|
||||
_, err = client.Logical().Write("identity/oidc/key/role-key", map[string]interface{}{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create OIDC role with 1-hour TTL
|
||||
_, err = client.Logical().Write("identity/oidc/role/test-role", map[string]interface{}{
|
||||
"key": "role-key",
|
||||
"ttl": "1h",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the auto-generated client_id for the role
|
||||
secret, err := client.Logical().Read("identity/oidc/role/test-role")
|
||||
require.NoError(t, err)
|
||||
roleClientID := secret.Data["client_id"].(string)
|
||||
|
||||
// Configure the key to allow this role's client_id
|
||||
_, err = client.Logical().Write("identity/oidc/key/role-key", map[string]interface{}{
|
||||
"allowed_client_ids": roleClientID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// METHOD 2: Configure OIDC provider for authorization code flow (pathOIDCToken)
|
||||
// Create OIDC client with 2-hour ID token TTL and 1-hour access token TTL
|
||||
_, err = client.Logical().Write("identity/oidc/client/provider-client", map[string]interface{}{
|
||||
"redirect_uris": []string{"https://localhost:8251/callback"},
|
||||
"assignments": []string{"allow_all"},
|
||||
"id_token_ttl": "2h",
|
||||
"access_token_ttl": "1h",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read the client to get client_id and client_secret
|
||||
clientResp, err := client.Logical().Read("identity/oidc/client/provider-client")
|
||||
require.NoError(t, err)
|
||||
providerClientID := clientResp.Data["client_id"].(string)
|
||||
providerClientSecret := clientResp.Data["client_secret"].(string)
|
||||
|
||||
// Create OIDC provider
|
||||
_, err = client.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{
|
||||
"allowed_client_ids": []string{providerClientID},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate tokens using METHOD 1: role-based (identity/oidc/token/{role})
|
||||
// 2 tokens × 1 hour = 2 hours
|
||||
client.SetToken(userToken)
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err := client.Logical().Read("identity/oidc/token/test-role")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Generate tokens using METHOD 2: provider-based (authorization code flow)
|
||||
// 3 tokens × 2 hours (max of 2h ID token and 1h access token) = 6 hours
|
||||
client.SetToken(client.Token()) // Reset to root token
|
||||
for i := 0; i < 3; i++ {
|
||||
code := getAuthorizationCode(t, ctx, client, "test-provider", providerClientID, userToken)
|
||||
exchangeCodeForToken(t, ctx, client, "test-provider", code, providerClientID, providerClientSecret)
|
||||
}
|
||||
|
||||
currentMonth := time.Now().UTC()
|
||||
|
||||
// Total expected: 2 hours (role-based) + 6 hours (provider-based) = 8 hours
|
||||
expectedDurationAdjustedCount := vault.DurationAdjustedTokenCount(8 * time.Hour.Seconds())
|
||||
delta := 0.0001
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
count, err := core.GetStoredOidcDurationAdjustedCount(ctx, currentMonth)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return count >= (expectedDurationAdjustedCount-delta) && count <= (expectedDurationAdjustedCount+delta)
|
||||
}, 10*time.Second, 500*time.Millisecond, "OIDC count not flushed to storage within timeout")
|
||||
|
||||
// Verify exact value
|
||||
count, err := core.GetStoredOidcDurationAdjustedCount(ctx, currentMonth)
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, expectedDurationAdjustedCount, count, delta,
|
||||
"Expected 8 hours total: 2 hours from role-based tokens (2×1h) + 6 hours from provider tokens (3×2h)")
|
||||
}
|
||||
|
||||
// exchangeCodeForToken is a test helper function to exchange authorization code for tokens via the OIDC provider token endpoint
|
||||
func exchangeCodeForToken(t *testing.T, ctx context.Context, client *api.Client, providerName, code, clientID, clientSecret string) {
|
||||
// Prepare the token request with basic auth
|
||||
req := client.NewRequest("POST", "/v1/identity/oidc/provider/"+providerName+"/token")
|
||||
req.Headers.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(clientID+":"+clientSecret)))
|
||||
req.BodyBytes = []byte(fmt.Sprintf(`{"code":"%s","grant_type":"authorization_code","redirect_uri":"https://localhost:8251/callback"}`, code))
|
||||
req.Headers.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.RawRequestWithContext(ctx, req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, 200, resp.StatusCode)
|
||||
}
|
||||
|
||||
// getAuthorizationCode is a test helper function to get authorization code from the OIDC provider authorize endpoint
|
||||
func getAuthorizationCode(t *testing.T, ctx context.Context, client *api.Client, providerName, clientID, userToken string) string {
|
||||
// Save the original token
|
||||
originalToken := client.Token()
|
||||
|
||||
// Use the user token (from userpass login) to authorize
|
||||
client.SetToken(userToken)
|
||||
|
||||
// Use RawRequestWithContext to make the authorize request
|
||||
req := client.NewRequest("POST", "/v1/identity/oidc/provider/"+providerName+"/authorize")
|
||||
req.BodyBytes, _ = json.Marshal(map[string]interface{}{
|
||||
"client_id": clientID,
|
||||
"scope": "openid",
|
||||
"redirect_uri": "https://localhost:8251/callback",
|
||||
"response_type": "code",
|
||||
"state": "test-state",
|
||||
"nonce": "test-nonce",
|
||||
})
|
||||
|
||||
resp, err := client.RawRequestWithContext(ctx, req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Restore the original token
|
||||
client.SetToken(originalToken)
|
||||
|
||||
// Parse the JSON response
|
||||
var authResult struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
err = json.NewDecoder(resp.Body).Decode(&authResult)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, authResult.Code)
|
||||
|
||||
return authResult.Code
|
||||
}
|
||||
|
||||
@ -1081,7 +1081,7 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical.
|
||||
}
|
||||
|
||||
// Track OIDC token generation for billing
|
||||
// Store raw count and duration (seconds), normalize later during storage flush
|
||||
// Store duration (seconds), normalize later during storage flush
|
||||
validity := expiry.Seconds()
|
||||
if i.billingCounter != nil {
|
||||
i.billingCounter.IncrementOidcTokenCount(validity)
|
||||
|
||||
@ -2157,6 +2157,12 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request,
|
||||
return tokenResponse(nil, ErrTokenServerError, err.Error())
|
||||
}
|
||||
|
||||
// Track OIDC token generated for billing
|
||||
// Store duration (seconds), normalize later during storage flush
|
||||
if i.billingCounter != nil {
|
||||
i.billingCounter.IncrementOidcTokenCount(getMaxTokenTTL(client.AccessTokenTTL, client.IDTokenTTL).Seconds())
|
||||
}
|
||||
|
||||
return tokenResponse(map[string]interface{}{
|
||||
"token_type": "Bearer",
|
||||
"access_token": accessToken.ID,
|
||||
@ -2165,6 +2171,14 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request,
|
||||
}, "", "")
|
||||
}
|
||||
|
||||
// getMaxTokenTTL returns the maximum of the given access token and ID token
|
||||
func getMaxTokenTTL(accessTokenTTL, idTokenTTL time.Duration) time.Duration {
|
||||
if accessTokenTTL > idTokenTTL {
|
||||
return accessTokenTTL
|
||||
}
|
||||
return idTokenTTL
|
||||
}
|
||||
|
||||
// tokenResponse returns the OIDC Token Response. An error response is
|
||||
// returned if the given error code is non-empty. For details, see spec at
|
||||
// - https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||
|
||||
@ -3724,3 +3724,70 @@ func TestOIDC_Path_OpenIDProviderConfig_ProviderDoesNotExist(t *testing.T) {
|
||||
t.Fatalf("expected empty response but got success; error:\n%v\nresp: %#v", err, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetMaxTokenTTL tests the getMaxTokenTTL utility function
|
||||
func TestGetMaxTokenTTL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessTokenTTL time.Duration
|
||||
idTokenTTL time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "access token TTL is greater",
|
||||
accessTokenTTL: 2 * time.Hour,
|
||||
idTokenTTL: 1 * time.Hour,
|
||||
expected: 2 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "id token TTL is greater",
|
||||
accessTokenTTL: 1 * time.Hour,
|
||||
idTokenTTL: 2 * time.Hour,
|
||||
expected: 2 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "both TTLs are equal",
|
||||
accessTokenTTL: 1 * time.Hour,
|
||||
idTokenTTL: 1 * time.Hour,
|
||||
expected: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "access token TTL is zero",
|
||||
accessTokenTTL: 0,
|
||||
idTokenTTL: 1 * time.Hour,
|
||||
expected: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "id token TTL is zero",
|
||||
accessTokenTTL: 1 * time.Hour,
|
||||
idTokenTTL: 0,
|
||||
expected: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "both TTLs are zero",
|
||||
accessTokenTTL: 0,
|
||||
idTokenTTL: 0,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "large TTL values",
|
||||
accessTokenTTL: 24 * time.Hour,
|
||||
idTokenTTL: 48 * time.Hour,
|
||||
expected: 48 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "small TTL values",
|
||||
accessTokenTTL: 30 * time.Second,
|
||||
idTokenTTL: 15 * time.Second,
|
||||
expected: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getMaxTokenTTL(tt.accessTokenTTL, tt.idTokenTTL)
|
||||
require.Equal(t, tt.expected, result, "getMaxTokenTTL(%v, %v) = %v, want %v",
|
||||
tt.accessTokenTTL, tt.idTokenTTL, result, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user