vault/limits/registry.go
Mike Palmiotto e4a11ae7cd
Request Limiter Reload tests (#25126)
This PR introduces a new testonly endpoint for introspecting the
RequestLimiter state. It makes use of the endpoint to verify that changes to
the request_limiter config are honored across reload.

In the future, we may choose to make the sys/internal/request-limiter/status
endpoint available in normal binaries, but this is an expedient way to expose
the status for testing without having to rush the design.

In order to re-use as much of the existing command package utility functionality
as possible without introducing sprawling code changes, I introduced a new
server_util.go and exported some fields via accessors.

The tests shook out a couple of bugs (including a deadlock and lack of
locking around the core limiterRegistry state).
2024-02-01 09:11:08 -05:00

223 lines
6.1 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package limits
import (
"os"
"strconv"
"sync"
"github.com/hashicorp/go-hclog"
)
// Limiter names used as keys in the LimiterRegistry, plus the shared log
// message for unparsable limiter environment variables.
const (
// WriteLimiter is the registry key of the limiter for write-based HTTP
// methods.
WriteLimiter = "write"
// SpecialPathLimiter is the registry key of the limiter for special-cased
// paths.
SpecialPathLimiter = "special-path"
// LimitsBadEnvVariable is the warning logged when a limiter environment
// variable cannot be parsed and the default value is kept instead.
LimitsBadEnvVariable = "failed to process limiter environment variable, using default"
)
// NOTE: Great care should be taken when setting any of these variables to avoid
// adverse effects on optimal request servicing. It is strongly advised that
// these variables not be used unless there is a very good reason. These are
// intentionally undocumented environment variables that may be removed in
// future versions of Vault.
const (
// EnvVaultDisableWriteLimiter is used to turn off the
// RequestLimiter for write-based HTTP methods.
EnvVaultDisableWriteLimiter = "VAULT_DISABLE_WRITE_LIMITER"
// EnvVaultWriteLimiterMin is used to modify the minimum
// concurrency limit for write-based HTTP methods.
EnvVaultWriteLimiterMin = "VAULT_WRITE_LIMITER_MIN"
// EnvVaultWriteLimiterMax is used to modify the maximum
// concurrency limit for write-based HTTP methods.
EnvVaultWriteLimiterMax = "VAULT_WRITE_LIMITER_MAX"
// EnvVaultDisableSpecialPathLimiter is used to turn off the
// RequestLimiter for special-cased paths, specified in
// Backend.PathsSpecial.
EnvVaultDisableSpecialPathLimiter = "VAULT_DISABLE_SPECIAL_PATH_LIMITER"
// EnvVaultSpecialPathLimiterMin is used to modify the minimum
// concurrency limit for special-cased paths.
EnvVaultSpecialPathLimiterMin = "VAULT_SPECIAL_PATH_LIMITER_MIN"
// EnvVaultSpecialPathLimiterMax is used to modify the maximum
// concurrency limit for special-cased paths.
EnvVaultSpecialPathLimiterMax = "VAULT_SPECIAL_PATH_LIMITER_MAX"
)
// LimiterRegistry holds the map of RequestLimiters mapped to keys.
//
// Enable and Disable take the embedded lock for writing and GetLimiter takes
// it for reading. Register does not lock; its caller is expected to.
type LimiterRegistry struct {
// Limiters maps a limiter name (such as WriteLimiter) to its
// RequestLimiter. It is replaced wholesale by Enable and Disable.
Limiters map[string]*RequestLimiter
// Logger receives registry-level messages and is the parent of the
// named loggers handed to each RequestLimiter.
Logger hclog.Logger
// Enabled is set by Enable and cleared by Disable; both are no-ops when
// the registry is already in the requested state.
Enabled bool
// RWMutex guards Limiters and Enabled.
sync.RWMutex
}
// NewLimiterRegistry returns a LimiterRegistry that logs through logger and
// starts out with an empty, non-nil Limiters map. The registry is not enabled
// until Enable is called.
func NewLimiterRegistry(logger hclog.Logger) *LimiterRegistry {
	registry := new(LimiterRegistry)
	registry.Logger = logger
	registry.Limiters = map[string]*RequestLimiter{}
	return registry
}
// processEnvVars consults the Limiter-specific environment variables and
// tells the caller whether the Limiter should be disabled. If not, it adjusts
// the passed-in flags in place.
//
// envDisabled names a boolean variable; when it parses as true the function
// returns true and leaves flags untouched. envMin and envMax name integer
// variables. A value that is unset, empty, unparsable, or zero is treated as
// "not provided" and leaves the corresponding default alone. A min override
// must be greater than zero, and a max override must be greater than the
// (possibly overridden) min; rejected overrides are logged and ignored.
//
// Every warning is emitted through a logger tagged with the limiter name so
// that a rejected value can be traced back to the limiter it was meant for.
func (r *LimiterRegistry) processEnvVars(name string, flags *LimiterFlags, envDisabled, envMin, envMax string) bool {
	envFlagsLogger := r.Logger.With("name", name)

	if disabledRaw := os.Getenv(envDisabled); disabledRaw != "" {
		disabled, err := strconv.ParseBool(disabledRaw)
		if err != nil {
			// ParseBool reports false alongside the error, so a malformed
			// value leaves the limiter enabled.
			envFlagsLogger.Warn(LimitsBadEnvVariable,
				"env", envDisabled,
				"val", disabledRaw,
				"default", false,
				"error", err,
			)
		}
		if disabled {
			envFlagsLogger.Warn("limiter disabled by environment variable", "env", envDisabled, "val", disabledRaw)
			return true
		}
	}

	// readLimit returns the integer stored in env, or 0 when the variable is
	// unset, empty, or unparsable. fallback is only reported in the warning.
	readLimit := func(env string, fallback int) int {
		raw := os.Getenv(env)
		if raw == "" {
			return 0
		}
		limit, err := strconv.Atoi(raw)
		if err != nil {
			envFlagsLogger.Warn(LimitsBadEnvVariable,
				"env", env,
				"val", raw,
				"default", fallback,
				"error", err,
			)
			return 0
		}
		return limit
	}

	// Both overrides are read before flags is touched so that the defaults
	// reported in parse warnings are the original ones.
	minOverride := readLimit(envMin, flags.MinLimit)
	maxOverride := readLimit(envMax, flags.MaxLimit)

	switch {
	case minOverride == 0:
		// Assume no environment variable was provided.
	case minOverride > 0:
		flags.MinLimit = minOverride
	default:
		envFlagsLogger.Warn("min limit must be greater than zero, falling back to defaults", "minLimit", flags.MinLimit)
	}

	switch {
	case maxOverride == 0:
		// Assume no environment variable was provided.
	case maxOverride > flags.MinLimit:
		// Compared against the effective min, which may have just been
		// overridden above.
		flags.MaxLimit = maxOverride
	default:
		envFlagsLogger.Warn("max limit must be greater than min, falling back to defaults", "maxLimit", flags.MaxLimit)
	}

	return false
}
// Enable rebuilds the Limiters map by registering every entry of
// DefaultLimiterFlags and then marks the registry Enabled. It does nothing if
// the registry is already enabled. The write lock is held throughout.
func (r *LimiterRegistry) Enable() {
	r.Lock()
	defer r.Unlock()

	if r.Enabled {
		// Already set up; keep the existing limiters.
		return
	}

	r.Logger.Info("enabling request limiters")

	r.Limiters = make(map[string]*RequestLimiter)
	for limiterName, limiterFlags := range DefaultLimiterFlags {
		r.Register(limiterName, limiterFlags)
	}

	r.Enabled = true
}
// Register builds a RequestLimiter for name and stores it in the registry
// under that key. Only WriteLimiter and SpecialPathLimiter are accepted; any
// other name is logged and skipped. The limiter's environment variables are
// applied to flags first and may disable it outright, in which case nothing is
// stored. Register takes no lock itself, so the caller must hold it.
func (r *LimiterRegistry) Register(name string, flags LimiterFlags) {
	// Pick the trio of environment variables that govern this limiter.
	var envDisable, envMin, envMax string
	switch name {
	case WriteLimiter:
		envDisable = EnvVaultDisableWriteLimiter
		envMin = EnvVaultWriteLimiterMin
		envMax = EnvVaultWriteLimiterMax
	case SpecialPathLimiter:
		envDisable = EnvVaultDisableSpecialPathLimiter
		envMin = EnvVaultSpecialPathLimiterMin
		envMax = EnvVaultSpecialPathLimiterMax
	default:
		r.Logger.Warn("skipping invalid limiter type", "key", name)
		return
	}

	if r.processEnvVars(name, &flags, envDisable, envMin, envMax) {
		// Switched off through the environment; leave the slot empty.
		return
	}

	// Start from the minimum and let the limiter climb to its own
	// equilibrium, because starting from max might be too high.
	flags.InitialLimit = flags.MinLimit

	newLimiter, err := NewRequestLimiter(r.Logger.Named(name), name, flags)
	if err != nil {
		r.Logger.Error("failed to register limiter", "name", name, "error", err)
		return
	}

	r.Limiters[name] = newLimiter
}
// Disable replaces the Limiters map with an empty one and clears Enabled. It
// does nothing if the registry is not currently enabled. The write lock is
// held throughout.
func (r *LimiterRegistry) Disable() {
	r.Lock()
	defer r.Unlock()

	if !r.Enabled {
		// Nothing registered, nothing to drop.
		return
	}

	r.Logger.Info("disabling request limiters")

	// Requests that already hold a listener flush their tokens when they
	// complete, so it is enough to let go of the limiter references and leave
	// the rest to the garbage collector.
	r.Enabled = false
	r.Limiters = make(map[string]*RequestLimiter)
}
// GetLimiter returns the RequestLimiter stored under key, or nil when no
// limiter is registered for it. The lookup is done under the read lock.
func (r *LimiterRegistry) GetLimiter(key string) *RequestLimiter {
	r.RLock()
	limiter := r.Limiters[key]
	r.RUnlock()

	return limiter
}