From 44bace874fbe78098b16cc205ff27737ddec278e Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Wed, 6 May 2026 14:38:10 -0600 Subject: [PATCH] Add oidc tokens to consumption billing for tokens created using pathOIDCToken (#14536) (#14555) (#14566) * add oidc tokens to consumption billing for tokens created using the oidc provider * moving oidc test to ce * add changelog * modify tag Co-authored-by: akshya96 <87045294+akshya96@users.noreply.github.com> --- changelog/_14536.txt | 3 + vault/external_tests/billing/billing_test.go | 198 +++++++++++++++++++ vault/identity_store_oidc.go | 2 +- vault/identity_store_oidc_provider.go | 14 ++ vault/identity_store_oidc_provider_test.go | 67 +++++++ 5 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 changelog/_14536.txt diff --git a/changelog/_14536.txt b/changelog/_14536.txt new file mode 100644 index 0000000000..76729af580 --- /dev/null +++ b/changelog/_14536.txt @@ -0,0 +1,3 @@ +```release-note:improvement +consumption-billing: Added consumption billing metrics for OIDC tokens. +``` \ No newline at end of file diff --git a/vault/external_tests/billing/billing_test.go b/vault/external_tests/billing/billing_test.go index 1f4a888e85..2937d7658a 100644 --- a/vault/external_tests/billing/billing_test.go +++ b/vault/external_tests/billing/billing_test.go @@ -5,10 +5,18 @@ package billing import ( "context" + "encoding/base64" + "encoding/json" + "fmt" "testing" "time" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/namespace" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/internalshared/configutil" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault/billing" "github.com/stretchr/testify/require" @@ -104,3 +112,193 @@ func TestGcpKmsDataProtectionCallCounts(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(5), counts) } + +// TestOidcTokenBillingBothMethods tests OIDC token billing for both token creation methods: +// 1. Simple role-based tokens via identity/oidc/token/{role} +// 2. Provider-based tokens via the full authorization code flow (pathOIDCToken) +// This test runs on a single primary cluster and verifies that both methods correctly +// track duration-adjusted billing counts. +func TestOidcTokenBillingBothMethods(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + BillingConfig: billing.BillingConfig{ + MetricsUpdateCadence: 5 * time.Second, + }, + } + clusterOpts := &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + DefaultHandlerProperties: vault.HandlerProperties{ + ListenerConfig: &configutil.Listener{}, + }, + NumCores: 1, + } + cluster := vault.NewTestCluster(t, coreConfig, clusterOpts) + defer cluster.Cleanup() + + core := cluster.Cores[0].Core + vault.TestWaitActive(t, core) + client := cluster.Cores[0].Client + ctx := context.Background() + + // Create a policy that allows reading OIDC tokens + oidcPolicy := `path "identity/oidc/token/*" { capabilities = ["read"] }` + _, err := client.Logical().Write("sys/policy/oidc-reader", map[string]interface{}{ + "policy": oidcPolicy, + }) + require.NoError(t, err) + + // Enable userpass for entity creation + err = client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + require.NoError(t, err) + + // Create a userpass user with the OIDC reader policy + _, err = client.Logical().Write("auth/userpass/users/testuser", map[string]interface{}{ + "password": "testpass", + "policies": "oidc-reader", + }) + require.NoError(t, err) + + // Login to create entity + loginResp, err := client.Logical().Write("auth/userpass/login/testuser", map[string]interface{}{ + "password": "testpass", + }) + require.NoError(t, err) + userToken := loginResp.Auth.ClientToken + + // METHOD 1: Configure simple role-based OIDC tokens (identity/oidc/token/{role}) + // Create OIDC key + _, err = client.Logical().Write("identity/oidc/key/role-key", map[string]interface{}{}) + require.NoError(t, err) + + // Create OIDC role with 1-hour TTL + _, err = client.Logical().Write("identity/oidc/role/test-role", map[string]interface{}{ + "key": "role-key", + "ttl": "1h", + }) + require.NoError(t, err) + + // Get the auto-generated client_id for the role + secret, err := client.Logical().Read("identity/oidc/role/test-role") + require.NoError(t, err) + roleClientID := secret.Data["client_id"].(string) + + // Configure the key to allow this role's client_id + _, err = client.Logical().Write("identity/oidc/key/role-key", map[string]interface{}{ + "allowed_client_ids": roleClientID, + }) + require.NoError(t, err) + + // METHOD 2: Configure OIDC provider for authorization code flow (pathOIDCToken) + // Create OIDC client with 2-hour ID token TTL and 1-hour access token TTL + _, err = client.Logical().Write("identity/oidc/client/provider-client", map[string]interface{}{ + "redirect_uris": []string{"https://localhost:8251/callback"}, + "assignments": []string{"allow_all"}, + "id_token_ttl": "2h", + "access_token_ttl": "1h", + }) + require.NoError(t, err) + + // Read the client to get client_id and client_secret + clientResp, err := client.Logical().Read("identity/oidc/client/provider-client") + require.NoError(t, err) + providerClientID := clientResp.Data["client_id"].(string) + providerClientSecret := clientResp.Data["client_secret"].(string) + + // Create OIDC provider + _, err = client.Logical().Write("identity/oidc/provider/test-provider", map[string]interface{}{ + "allowed_client_ids": []string{providerClientID}, + }) + require.NoError(t, err) + + // Generate tokens using METHOD 1: role-based (identity/oidc/token/{role}) + // 2 tokens × 1 hour = 2 hours + client.SetToken(userToken) + for i := 0; i < 2; i++ { + _, err := client.Logical().Read("identity/oidc/token/test-role") + require.NoError(t, err) + } + + // Generate tokens using METHOD 2: provider-based (authorization code flow) + // 3 tokens × 2 hours (max of 2h ID token and 1h access token) = 6 hours + client.SetToken(client.Token()) // Reset to root token + for i := 0; i < 3; i++ { + code := getAuthorizationCode(t, ctx, client, "test-provider", providerClientID, userToken) + exchangeCodeForToken(t, ctx, client, "test-provider", code, providerClientID, providerClientSecret) + } + + currentMonth := time.Now().UTC() + + // Total expected: 2 hours (role-based) + 6 hours (provider-based) = 8 hours + expectedDurationAdjustedCount := vault.DurationAdjustedTokenCount(8 * time.Hour.Seconds()) + delta := 0.0001 + + require.Eventually(t, func() bool { + count, err := core.GetStoredOidcDurationAdjustedCount(ctx, currentMonth) + if err != nil { + return false + } + return count >= (expectedDurationAdjustedCount-delta) && count <= (expectedDurationAdjustedCount+delta) + }, 10*time.Second, 500*time.Millisecond, "OIDC count not flushed to storage within timeout") + + // Verify exact value + count, err := core.GetStoredOidcDurationAdjustedCount(ctx, currentMonth) + require.NoError(t, err) + require.InDelta(t, expectedDurationAdjustedCount, count, delta, + "Expected 8 hours total: 2 hours from role-based tokens (2×1h) + 6 hours from provider tokens (3×2h)") +} + +// exchangeCodeForToken is a test helper function to exchange authorization code for tokens via the OIDC provider token endpoint +func exchangeCodeForToken(t *testing.T, ctx context.Context, client *api.Client, providerName, code, clientID, clientSecret string) { + // Prepare the token request with basic auth + req := client.NewRequest("POST", "/v1/identity/oidc/provider/"+providerName+"/token") + req.Headers.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(clientID+":"+clientSecret))) + req.BodyBytes = []byte(fmt.Sprintf(`{"code":"%s","grant_type":"authorization_code","redirect_uri":"https://localhost:8251/callback"}`, code)) + req.Headers.Set("Content-Type", "application/json") + + resp, err := client.RawRequestWithContext(ctx, req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, 200, resp.StatusCode) +} + +// getAuthorizationCode is a test helper function to get authorization code from the OIDC provider authorize endpoint +func getAuthorizationCode(t *testing.T, ctx context.Context, client *api.Client, providerName, clientID, userToken string) string { + // Save the original token + originalToken := client.Token() + + // Use the user token (from userpass login) to authorize + client.SetToken(userToken) + + // Use RawRequestWithContext to make the authorize request + req := client.NewRequest("POST", "/v1/identity/oidc/provider/"+providerName+"/authorize") + req.BodyBytes, _ = json.Marshal(map[string]interface{}{ + "client_id": clientID, + "scope": "openid", + "redirect_uri": "https://localhost:8251/callback", + "response_type": "code", + "state": "test-state", + "nonce": "test-nonce", + }) + + resp, err := client.RawRequestWithContext(ctx, req) + require.NoError(t, err) + defer resp.Body.Close() + + // Restore the original token + client.SetToken(originalToken) + + // Parse the JSON response + var authResult struct { + Code string `json:"code"` + State string `json:"state"` + } + err = json.NewDecoder(resp.Body).Decode(&authResult) + require.NoError(t, err) + require.NotEmpty(t, authResult.Code) + + return authResult.Code +} diff --git a/vault/identity_store_oidc.go b/vault/identity_store_oidc.go index d7b80abd95..7ff7a8a5b7 100644 --- a/vault/identity_store_oidc.go +++ b/vault/identity_store_oidc.go @@ -1081,7 +1081,7 @@ func (i *IdentityStore) pathOIDCGenerateToken(ctx context.Context, req *logical. } // Track OIDC token generation for billing - // Store raw count and duration (seconds), normalize later during storage flush + // Store duration (seconds), normalize later during storage flush validity := expiry.Seconds() if i.billingCounter != nil { i.billingCounter.IncrementOidcTokenCount(validity) diff --git a/vault/identity_store_oidc_provider.go b/vault/identity_store_oidc_provider.go index 3d967443fe..5716a0331c 100644 --- a/vault/identity_store_oidc_provider.go +++ b/vault/identity_store_oidc_provider.go @@ -2157,6 +2157,12 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request, return tokenResponse(nil, ErrTokenServerError, err.Error()) } + // Track OIDC token generated for billing + // Store duration (seconds), normalize later during storage flush + if i.billingCounter != nil { + i.billingCounter.IncrementOidcTokenCount(getMaxTokenTTL(client.AccessTokenTTL, client.IDTokenTTL).Seconds()) + } + return tokenResponse(map[string]interface{}{ "token_type": "Bearer", "access_token": accessToken.ID, @@ -2165,6 +2171,14 @@ func (i *IdentityStore) pathOIDCToken(ctx context.Context, req *logical.Request, }, "", "") } +// getMaxTokenTTL returns the maximum of the given access token and ID token +func getMaxTokenTTL(accessTokenTTL, idTokenTTL time.Duration) time.Duration { + if accessTokenTTL > idTokenTTL { + return accessTokenTTL + } + return idTokenTTL +} + // tokenResponse returns the OIDC Token Response. An error response is // returned if the given error code is non-empty. For details, see spec at // - https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse diff --git a/vault/identity_store_oidc_provider_test.go b/vault/identity_store_oidc_provider_test.go index 7aa73c386d..0fda840cc3 100644 --- a/vault/identity_store_oidc_provider_test.go +++ b/vault/identity_store_oidc_provider_test.go @@ -3724,3 +3724,70 @@ func TestOIDC_Path_OpenIDProviderConfig_ProviderDoesNotExist(t *testing.T) { t.Fatalf("expected empty response but got success; error:\n%v\nresp: %#v", err, resp) } } + +// TestGetMaxTokenTTL tests the getMaxTokenTTL utility function +func TestGetMaxTokenTTL(t *testing.T) { + tests := []struct { + name string + accessTokenTTL time.Duration + idTokenTTL time.Duration + expected time.Duration + }{ + { + name: "access token TTL is greater", + accessTokenTTL: 2 * time.Hour, + idTokenTTL: 1 * time.Hour, + expected: 2 * time.Hour, + }, + { + name: "id token TTL is greater", + accessTokenTTL: 1 * time.Hour, + idTokenTTL: 2 * time.Hour, + expected: 2 * time.Hour, + }, + { + name: "both TTLs are equal", + accessTokenTTL: 1 * time.Hour, + idTokenTTL: 1 * time.Hour, + expected: 1 * time.Hour, + }, + { + name: "access token TTL is zero", + accessTokenTTL: 0, + idTokenTTL: 1 * time.Hour, + expected: 1 * time.Hour, + }, + { + name: "id token TTL is zero", + accessTokenTTL: 1 * time.Hour, + idTokenTTL: 0, + expected: 1 * time.Hour, + }, + { + name: "both TTLs are zero", + accessTokenTTL: 0, + idTokenTTL: 0, + expected: 0, + }, + { + name: "large TTL values", + accessTokenTTL: 24 * time.Hour, + idTokenTTL: 48 * time.Hour, + expected: 48 * time.Hour, + }, + { + name: "small TTL values", + accessTokenTTL: 30 * time.Second, + idTokenTTL: 15 * time.Second, + expected: 30 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getMaxTokenTTL(tt.accessTokenTTL, tt.idTokenTTL) + require.Equal(t, tt.expected, result, "getMaxTokenTTL(%v, %v) = %v, want %v", + tt.accessTokenTTL, tt.idTokenTTL, result, tt.expected) + }) + } +}