diff --git a/vault/logical_system_use_case_billing.go b/vault/logical_system_use_case_billing.go index 428a3ce250..484547dd3c 100644 --- a/vault/logical_system_use_case_billing.go +++ b/vault/logical_system_use_case_billing.go @@ -19,6 +19,9 @@ const ( WarningRefreshIgnoredOnStandby = "refresh_data parameter is supported only on the active node. " + "Since this parameter was set on a performance standby, the billing data was not refreshed " + "and retrieved from storage without update." + + WarningStartEndMonthOutOfRetentionRange = "the specified start_month and/or end_month fall outside the range of the current billing data retention period." + + "Months that are not covered in the retention period will show a zero updated_at timestamp and no metrics." ) func (b *SystemBackend) useCaseConsumptionBillingPaths() []*framework.Path { @@ -31,18 +34,28 @@ func (b *SystemBackend) useCaseConsumptionBillingPaths() []*framework.Path { Description: "If set, updates the billing counts for the current month before returning. This is an expensive operation with potential performance impact and should be used sparingly.", Query: true, }, + "start_month": { + Type: framework.TypeString, + Description: "Start month in YYYY-MM format (inclusive). If not specified, defaults to the oldest available month within BillingRetentionMonths.", + Query: true, + }, + "end_month": { + Type: framework.TypeString, + Description: "End month in YYYY-MM format (inclusive). If not specified, defaults to the current month.", + Query: true, + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.ReadOperation: &framework.PathOperation{ Callback: b.handleUseCaseConsumption, - Summary: fmt.Sprintf("Reports consumption billing metrics for %d months (current month + previous %d months).", billing.BillingRetentionMonths, billing.BillingRetentionMonths-1), + Summary: "Reports consumption billing metrics on a monthly granularity.", Responses: map[int][]framework.Response{ http.StatusOK: {{ Description: http.StatusText(http.StatusOK), Fields: map[string]*framework.FieldSchema{ "months": { Type: framework.TypeSlice, - Description: fmt.Sprintf("List of monthly billing data for %d months (current month + previous %d months).", billing.BillingRetentionMonths, billing.BillingRetentionMonths-1), + Description: "List of monthly billing data.", }, }, }}, @@ -77,24 +90,31 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic refreshData = false } - // Build billing data for BillingRetentionMonths (current month + previous months) - months := make([]interface{}, 0, billing.BillingRetentionMonths) - - // Handle current month first (with optional refresh) - currentMonthTime := timeutil.StartOfMonth(currentMonth) - currentMonthData, err := b.buildMonthBillingData(ctx, currentMonthTime, refreshData) + startMonth, endMonth, isOutOfRetention, err := parseStartEndMonths(data, currentMonth) if err != nil { - return nil, fmt.Errorf("error building billing data for month %s: %w", currentMonthTime.Format("2006-01"), err) + return nil, err } - months = append(months, currentMonthData) - // Handle previous months (no refresh needed) - for i := 1; i < billing.BillingRetentionMonths; i++ { - monthTime := timeutil.StartOfMonth(currentMonth).AddDate(0, -i, 0) + if isOutOfRetention { + warnings = append(warnings, WarningStartEndMonthOutOfRetentionRange) + } - monthData, err := b.buildMonthBillingData(ctx, monthTime, false) + // Build list of months to retrieve (from end to start, newest first) + monthsToRetrieve := []time.Time{} + for month := endMonth; !month.Before(startMonth); month = month.AddDate(0, -1, 0) { + monthsToRetrieve = append(monthsToRetrieve, month) + } + + // Build billing data for requested months + months := make([]interface{}, 0, len(monthsToRetrieve)) + + for _, month := range monthsToRetrieve { + // Only refresh current month if refresh_data is true + shouldRefresh := refreshData && month.Equal(timeutil.StartOfMonth(currentMonth)) + + monthData, err := b.buildMonthBillingData(ctx, month, shouldRefresh) if err != nil { - return nil, fmt.Errorf("error building billing data for month %s: %w", monthTime.Format("2006-01"), err) + return nil, fmt.Errorf("error building billing data for month %s: %w", month.Format("2006-01"), err) } months = append(months, monthData) @@ -110,6 +130,48 @@ func (b *SystemBackend) handleUseCaseConsumption(ctx context.Context, req *logic }, nil } +// parseStartEndMonths parses the start and end month parameters from the request and validates if they are valid. +// If they are outside of the BillingRetentionMonths range, it returns a warning. If no parameter is specified, +// the start and end defaults to the start of the BillingRetentionMonths range and the current month, respectively. +func parseStartEndMonths(data *framework.FieldData, currentMonth time.Time) (time.Time, time.Time, bool, error) { + defaultStartMonth := timeutil.StartOfMonth(currentMonth).AddDate(0, -billing.BillingRetentionMonths+1, 0) + defaultEndMonth := timeutil.StartOfMonth(currentMonth) + + parseMonth := func(key string, defaultMonth time.Time) (time.Time, error) { + if monthStr := data.Get(key).(string); monthStr != "" { + return time.Parse("2006-01", monthStr) + } + return defaultMonth, nil + } + + var startMonth, endMonth time.Time + var isOutOfRetention bool + var err error + + startMonth, err = parseMonth("start_month", defaultStartMonth) + if err != nil { + return time.Time{}, time.Time{}, false, fmt.Errorf("invalid start_month format: %w", err) + } + + endMonth, err = parseMonth("end_month", defaultEndMonth) + if err != nil { + return time.Time{}, time.Time{}, false, fmt.Errorf("invalid end_month format: %w", err) + } + + if startMonth.After(endMonth) { + return time.Time{}, time.Time{}, false, fmt.Errorf("start_month is later than end_month") + } + + // We don't need to check for startMonth after the current month because either an even later endMonth is + // specified which would be caught by the second condition, or no end was set and it defaulted to the current month, + // which would have been caught in the check above. Vice versa for endMonth before the default start month. + if startMonth.Before(defaultStartMonth) || endMonth.After(defaultEndMonth) { + isOutOfRetention = true + } + + return startMonth, endMonth, isOutOfRetention, nil +} + // buildMonthBillingData constructs billing data for a specific month func (b *SystemBackend) buildMonthBillingData(ctx context.Context, month time.Time, refreshData bool) (map[string]interface{}, error) { currentMonth := timeutil.StartOfMonth(time.Now().UTC()) diff --git a/vault/logical_system_use_case_billing_test.go b/vault/logical_system_use_case_billing_test.go index b3229cd166..949567dd67 100644 --- a/vault/logical_system_use_case_billing_test.go +++ b/vault/logical_system_use_case_billing_test.go @@ -76,6 +76,170 @@ func TestSystemBackend_BillingOverviewMonthFormat(t *testing.T) { } } +// TestSystemBackend_BillingOverview_StartEndMonthParams tests the billing overview +// endpoint with different combinations of start_month and end_month parameters. It +// verifies that the correct range of months is returned along with any expected warnings +// or errors. +func TestSystemBackend_BillingOverview_StartEndMonthParams(t *testing.T) { + now := time.Now().UTC() + currentMonth := now.Format("2006-01") + previousMonth := timeutil.StartOfPreviousMonth(now).Format("2006-01") + nextMonth := timeutil.StartOfNextMonth(now).Format("2006-01") + twoMonthsAfterCurrent := timeutil.StartOfMonth(now).AddDate(0, 2, 0).Format("2006-01") + retentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths+1, 0).Format("2006-01") + beforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths, 0).Format("2006-01") + twoMonthsBeforeRetentionStart := timeutil.StartOfMonth(now).AddDate(0, -billing.BillingRetentionMonths-1, 0).Format("2006-01") + + testCases := []struct { + name string + startMonth interface{} + endMonth interface{} + expectedMonths int + expectedWarning string + expectedError string + }{ + { + name: "start and end in retention period", + startMonth: previousMonth, + endMonth: currentMonth, + expectedMonths: 2, + }, + { + name: "start before retention period, default end", + startMonth: beforeRetentionStart, + expectedMonths: billing.BillingRetentionMonths + 1, + expectedWarning: WarningStartEndMonthOutOfRetentionRange, + }, + { + name: "end after retention period, default start", + endMonth: nextMonth, + expectedMonths: billing.BillingRetentionMonths + 1, + expectedWarning: WarningStartEndMonthOutOfRetentionRange, + }, + { + name: "start is exactly start of retention period", + startMonth: retentionStart, + endMonth: previousMonth, + expectedMonths: billing.BillingRetentionMonths - 1, + }, + { + name: "start and end after retention period", + startMonth: nextMonth, + endMonth: twoMonthsAfterCurrent, + expectedMonths: 2, + expectedWarning: WarningStartEndMonthOutOfRetentionRange, + }, + { + name: "start and end before retention period", + startMonth: twoMonthsBeforeRetentionStart, + endMonth: beforeRetentionStart, + expectedMonths: 2, + expectedWarning: WarningStartEndMonthOutOfRetentionRange, + }, + { + name: "start after retention period, default end", + startMonth: nextMonth, + expectedError: "start_month is later than end_month", + }, + { + name: "no parameters, default start and end", + expectedMonths: billing.BillingRetentionMonths, + }, + { + name: "start after end", + startMonth: previousMonth, + endMonth: retentionStart, + expectedError: "start_month is later than end_month", + }, + { + name: "same month", + startMonth: currentMonth, + endMonth: currentMonth, + expectedMonths: 1, + }, + { + name: "invalid date format", + startMonth: "2023/01", + endMonth: previousMonth, + expectedError: "invalid start_month format", + }, + { + name: "invalid month", + startMonth: "2023-13", + endMonth: previousMonth, + expectedError: "invalid start_month format", + }, + { + name: "invalid data type", + startMonth: previousMonth, + endMonth: 45, + expectedError: "invalid end_month format", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + _, b, _ := testCoreSystemBackend(t) + ctx := namespace.RootContext(nil) + + req := logical.TestRequest(t, logical.ReadOperation, "billing/overview") + req.Data["start_month"] = test.startMonth + req.Data["end_month"] = test.endMonth + resp, err := b.HandleRequest(ctx, req) + + if test.expectedError != "" { + require.Nil(t, resp) + require.Error(t, err) + require.Contains(t, err.Error(), test.expectedError) + return + } + + require.NoError(t, err) + require.NotNil(t, resp) + + if test.expectedWarning != "" { + require.NotEmpty(t, resp.Warnings) + require.Contains(t, resp.Warnings, test.expectedWarning) + } else { + require.Empty(t, resp.Warnings) + } + + // Verify the correct number of months are returned + months := resp.Data["months"].([]interface{}) + require.Len(t, months, test.expectedMonths) + + // expected start and end months are the test parameters if specified, + // or default to the retention start and current month + var expectedStartMonth, expectedEndMonth string + if test.startMonth != nil { + expectedStartMonth = test.startMonth.(string) + } else { + expectedStartMonth = retentionStart + } + if test.endMonth != nil { + expectedEndMonth = test.endMonth.(string) + } else { + expectedEndMonth = currentMonth + } + + // Months are ordered from most recent to oldest, so the first month returned + // should be the expected endMonth and the last month the expected startMonth + firstMonth, ok := months[0].(map[string]interface{}) + require.True(t, ok) + firstMonthStr, ok := firstMonth["month"].(string) + require.True(t, ok) + require.Equal(t, expectedEndMonth, firstMonthStr) + + lastMonth, ok := months[len(months)-1].(map[string]interface{}) + require.True(t, ok) + lastMonthStr, ok := lastMonth["month"].(string) + require.True(t, ok) + require.Equal(t, expectedStartMonth, lastMonthStr) + }) + } +} + // TestSystemBackend_BillingOverview_WithMetrics tests the billing overview endpoint // with actual KV secrets created to generate billing metrics. It verifies that KV v2 // secrets are properly counted in billing, the static_secrets metric appears in the