Move Request Limiter to enterprise (#25615)

This commit is contained in:
Mike Palmiotto 2024-02-27 16:24:06 -05:00 committed by GitHub
parent df57ff46ff
commit b54ac98a0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 116 additions and 997 deletions

View File

@ -1,5 +1,5 @@
```release-note:feature
**Request Limiter**: Add adaptive concurrency limits to write-based HTTP
methods and special-case `pki/issue` requests to prevent overloading the Vault
server.
**Request Limiter (enterprise)**: Add adaptive concurrency limits to
write-based HTTP methods and special-case `pki/issue` requests to prevent
overloading the Vault server.
```

View File

@ -31,3 +31,7 @@ func entCheckStorageType(coreConfig *vault.CoreConfig) bool {
func entGetFIPSInfoKey() string {
return ""
}
// entGetRequestLimiterStatus returns a human-readable Request Limiter status
// string for the server's startup info output. This is the CE stub: it always
// returns "", which the caller treats as "do not display a status line".
func entGetRequestLimiterStatus(coreConfig vault.CoreConfig) string {
	return ""
}

View File

@ -1,213 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build testonly
package command_testonly
import (
"os"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
// init copies a CI-provided license (VAULT_LICENSE_CI) into the environment
// variable the server command reads (command.EnvVaultLicense), so these
// testonly tests can start a licensed server in CI without extra setup.
func init() {
	if signed := os.Getenv("VAULT_LICENSE_CI"); signed != "" {
		os.Setenv(command.EnvVaultLicense, signed)
	}
}
const (
baseHCL = `
backend "inmem" { }
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:8209"
tls_disable = "true"
}
api_addr = "http://127.0.0.1:8209"
`
requestLimiterDisableHCL = `
request_limiter {
disable = true
}
`
requestLimiterEnableHCL = `
request_limiter {
disable = false
}
`
)
// TestServer_ReloadRequestLimiter tests a series of reloads and state
// transitions between RequestLimiter enable and disable.
func TestServer_ReloadRequestLimiter(t *testing.T) {
t.Parallel()
enabledResponse := &vault.RequestLimiterResponse{
GlobalDisabled: false,
ListenerDisabled: false,
Limiters: map[string]*vault.LimiterStatus{
limits.WriteLimiter: {
Enabled: true,
Flags: limits.DefaultLimiterFlags[limits.WriteLimiter],
},
limits.SpecialPathLimiter: {
Enabled: true,
Flags: limits.DefaultLimiterFlags[limits.SpecialPathLimiter],
},
},
}
disabledResponse := &vault.RequestLimiterResponse{
GlobalDisabled: true,
ListenerDisabled: false,
Limiters: map[string]*vault.LimiterStatus{
limits.WriteLimiter: {
Enabled: false,
},
limits.SpecialPathLimiter: {
Enabled: false,
},
},
}
cases := []struct {
name string
configAfter string
expectedResponse *vault.RequestLimiterResponse
}{
{
"disable after default",
baseHCL + requestLimiterDisableHCL,
disabledResponse,
},
{
"disable after disable",
baseHCL + requestLimiterDisableHCL,
disabledResponse,
},
{
"enable after disable",
baseHCL + requestLimiterEnableHCL,
enabledResponse,
},
{
"default after enable",
baseHCL,
disabledResponse,
},
{
"default after default",
baseHCL,
disabledResponse,
},
{
"enable after default",
baseHCL + requestLimiterEnableHCL,
enabledResponse,
},
{
"enable after enable",
baseHCL + requestLimiterEnableHCL,
enabledResponse,
},
}
ui, srv := command.TestServerCommand(t)
f, err := os.CreateTemp(t.TempDir(), "")
require.NoErrorf(t, err, "error creating temp dir: %v", err)
_, err = f.WriteString(baseHCL)
require.NoErrorf(t, err, "cannot write temp file contents")
configPath := f.Name()
var output string
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
code := srv.Run([]string{"-config", configPath})
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
require.Equal(t, 0, code, output)
}()
select {
case <-srv.StartedCh():
case <-time.After(5 * time.Second):
t.Fatalf("timeout")
}
defer func() {
srv.ShutdownCh <- struct{}{}
wg.Wait()
}()
err = f.Close()
require.NoErrorf(t, err, "unable to close temp file")
// create a client and unseal vault
cli, err := srv.Client()
require.NoError(t, err)
require.NoError(t, cli.SetAddress("http://127.0.0.1:8209"))
initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1})
require.NoError(t, err)
_, err = cli.Sys().Unseal(initResp.Keys[0])
require.NoError(t, err)
cli.SetToken(initResp.RootToken)
output = ui.ErrorWriter.String() + ui.OutputWriter.String()
require.Contains(t, output, "Request Limiter: disabled")
verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) {
t.Helper()
statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status")
require.NoError(t, err)
require.NotNil(t, statusResp)
limitersResp, ok := statusResp.Data["request_limiter"]
require.True(t, ok)
require.NotNil(t, limitersResp)
var limiters *vault.RequestLimiterResponse
err = mapstructure.Decode(limitersResp, &limiters)
require.NoError(t, err)
require.NotNil(t, limiters)
require.Equal(t, expectedResponse, limiters)
}
// Start off with default disabled
verifyLimiters(t, disabledResponse)
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// Write the new contents and reload the server
f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644)
require.NoError(t, err)
defer f.Close()
_, err = f.WriteString(tc.configAfter)
require.NoErrorf(t, err, "cannot write temp file contents")
srv.SighupCh <- struct{}{}
select {
case <-srv.ReloadedCh():
case <-time.After(5 * time.Second):
t.Fatalf("test timed out")
}
verifyLimiters(t, tc.expectedResponse)
})
}
}

View File

@ -1437,15 +1437,15 @@ func (c *ServerCommand) Run(args []string) int {
info["HCP resource ID"] = config.HCPLinkConf.Resource.ID
}
requestLimiterStatus := entGetRequestLimiterStatus(coreConfig)
if requestLimiterStatus != "" {
infoKeys = append(infoKeys, "request_limiter")
info["request_limiter"] = requestLimiterStatus
}
infoKeys = append(infoKeys, "administrative namespace")
info["administrative namespace"] = config.AdministrativeNamespacePath
infoKeys = append(infoKeys, "request limiter")
info["request limiter"] = "disabled"
if config.RequestLimiter != nil && !config.RequestLimiter.Disable {
info["request limiter"] = "enabled"
}
sort.Strings(infoKeys)
c.UI.Output("==> Vault server configuration:\n")
@ -3118,12 +3118,6 @@ func createCoreConfig(c *ServerCommand, config *server.Config, backend physical.
AdministrativeNamespacePath: config.AdministrativeNamespacePath,
}
if config.RequestLimiter != nil {
coreConfig.DisableRequestLimiter = config.RequestLimiter.Disable
} else {
coreConfig.DisableRequestLimiter = true
}
if c.flagDev {
coreConfig.EnableRaw = true
coreConfig.EnableIntrospection = true

View File

@ -613,7 +613,6 @@ func testLoadConfigFile_json(t *testing.T) {
Type: "tcp",
Address: "127.0.0.1:443",
CustomResponseHeaders: DefaultCustomHeaders,
DisableRequestLimiter: false,
},
},
@ -904,6 +903,7 @@ listener "unix" {
redact_addresses = true
redact_cluster_name = true
redact_version = true
disable_request_limiter = true
}`))
config := Config{
@ -968,6 +968,7 @@ listener "unix" {
RedactAddresses: false,
RedactClusterName: false,
RedactVersion: false,
DisableRequestLimiter: true,
},
},
},

View File

@ -6,7 +6,6 @@
package server
import (
"fmt"
"testing"
"github.com/hashicorp/vault/internalshared/configutil"
@ -87,55 +86,3 @@ func TestCheckSealConfig(t *testing.T) {
})
}
}
// TestRequestLimiterConfig verifies that the census config is correctly instantiated from HCL
func TestRequestLimiterConfig(t *testing.T) {
testCases := []struct {
name string
inConfig string
outErr bool
outRequestLimiter *configutil.RequestLimiter
}{
{
name: "empty",
outRequestLimiter: nil,
},
{
name: "disabled",
inConfig: `
request_limiter {
disable = true
}`,
outRequestLimiter: &configutil.RequestLimiter{Disable: true},
},
{
name: "invalid disable",
inConfig: `
request_limiter {
disable = "people make mistakes"
}`,
outErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
config := fmt.Sprintf(`
ui = false
storage "file" {
path = "/tmp/test"
}
listener "tcp" {
address = "0.0.0.0:8200"
}
%s`, tc.inConfig)
gotConfig, err := ParseConfig(config, "")
if tc.outErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.outRequestLimiter, gotConfig.RequestLimiter)
}
})
}
}

1
go.mod
View File

@ -198,7 +198,6 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pires/go-proxyproto v0.6.1
github.com/pkg/errors v0.9.1
github.com/platinummonkey/go-concurrency-limits v0.7.0
github.com/posener/complete v1.2.3
github.com/pquerna/otp v1.2.1-0.20191009055518-468c2dd2b58d
github.com/prometheus/client_golang v1.14.0

5
go.sum
View File

@ -1274,7 +1274,6 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym
github.com/DataDog/datadog-go v2.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/DataDog/datadog-go v3.2.0+incompatible h1:qSG2N4FghB1He/r2mFrWKCaL7dXCilEuNEeAn20fdD4=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/DataDog/datadog-go/v5 v5.0.2/go.mod h1:ZI9JFB4ewXbw1sBnF4sxsR2k1H3xjV+PUAOUsHvKpcU=
github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo=
github.com/Jeffail/gabs v1.1.1 h1:V0uzR08Hj22EX8+8QMhyI9sX2hwRu+/RJhJUmnwda/E=
github.com/Jeffail/gabs v1.1.1/go.mod h1:6xMvQMK4k33lb7GUUpaAPh6nKMmemQeg5d4gn7/bOXc=
@ -1299,7 +1298,6 @@ github.com/Microsoft/go-winio v0.4.16/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugX
github.com/Microsoft/go-winio v0.4.17-0.20210211115548-6eac466e5fa3/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
github.com/Microsoft/go-winio v0.4.17-0.20210324224401-5516f17a5958/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
github.com/Microsoft/go-winio v0.4.17/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
github.com/Microsoft/go-winio v0.5.1/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84=
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE=
@ -3140,8 +3138,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI=
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/platinummonkey/go-concurrency-limits v0.7.0 h1:Bl9E74+67BrlRLBeryHOaFy0e1L3zD9g436/3vo6akQ=
github.com/platinummonkey/go-concurrency-limits v0.7.0/go.mod h1:Xxr6BywMVH3QyLyd0PanLnkkkmByTTPET3azMpdfmng=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -3210,7 +3206,6 @@ github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0ua
github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
github.com/rboyer/safeio v0.2.1 h1:05xhhdRNAdS3apYm7JRjOqngf4xruaW959jmRxGDuSU=
github.com/rboyer/safeio v0.2.1/go.mod h1:Cq/cEPK+YXFn622lsQ0K4KsPZSPtaptHHEldsy7Fmig=
github.com/rcrowley/go-metrics v0.0.0-20180503174638-e2704e165165/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=

View File

@ -918,35 +918,15 @@ func forwardRequest(core *vault.Core, w http.ResponseWriter, r *http.Request) {
w.Write(retBytes)
}
// acquireLimiterListener selects the appropriate RequestLimiter for this
// request and attempts to acquire a RequestListener token from it.
//
// Selection order:
//  1. If the listener config disabled limiting (flag threaded through the
//     request context by the handler wrapper), return an empty no-op
//     listener and let the request proceed.
//  2. Special-cased paths (r.PathLimited) consult the special-path limiter.
//  3. Otherwise, only non-read HTTP methods consult the write limiter;
//     GET/HEAD/TRACE/OPTIONS pass through unlimited.
//
// The boolean return is false when the limiter is at capacity; no listener
// is returned in that case and the caller should reject the request.
func acquireLimiterListener(core *vault.Core, rawReq *http.Request, r *logical.Request) (*limits.RequestListener, bool) {
	var disable bool
	disableRequestLimiter := rawReq.Context().Value(logical.CtxKeyDisableRequestLimiter{})
	if disableRequestLimiter != nil {
		disable = disableRequestLimiter.(bool)
	}
	// Record the effective limiter state on the logical request.
	r.RequestLimiterDisabled = disable
	if disable {
		// An empty RequestListener makes all later measurements no-ops.
		return &limits.RequestListener{}, true
	}
	lim := &limits.RequestLimiter{}
	if r.PathLimited {
		lim = core.GetRequestLimiter(limits.SpecialPathLimiter)
	} else {
		switch rawReq.Method {
		case http.MethodGet, http.MethodHead, http.MethodTrace, http.MethodOptions:
			// We're only interested in the inverse, so do nothing here.
		default:
			lim = core.GetRequestLimiter(limits.WriteLimiter)
		}
	}
	return lim.Acquire(rawReq.Context())
}
// request is a helper to perform a request and properly exit in the
// case of an error.
func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *logical.Request) (*logical.Response, bool, bool) {
lsnr, ok := acquireLimiterListener(core, rawReq, r)
lim := &limits.HTTPLimiter{
Method: rawReq.Method,
PathLimited: r.PathLimited,
LookupFunc: core.GetRequestLimiter,
}
lsnr, ok := lim.Acquire(rawReq.Context())
if !ok {
resp := &logical.Response{}
logical.RespondWithStatusCode(resp, r, http.StatusServiceUnavailable)

View File

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
"github.com/hashicorp/vault/vault/quotas"
@ -47,7 +48,7 @@ func wrapRequestLimiterHandler(handler http.Handler, props *vault.HandlerPropert
request := r.WithContext(
context.WithValue(
r.Context(),
logical.CtxKeyDisableRequestLimiter{},
limits.CtxKeyDisableRequestLimiter{},
props.ListenerConfig.DisableRequestLimiter,
),
)

View File

@ -55,8 +55,6 @@ type SharedConfig struct {
ClusterName string `hcl:"cluster_name"`
AdministrativeNamespacePath string `hcl:"administrative_namespace_path"`
RequestLimiter *RequestLimiter `hcl:"request_limiter"`
}
func ParseConfig(d string) (*SharedConfig, error) {
@ -158,13 +156,6 @@ func ParseConfig(d string) (*SharedConfig, error) {
}
}
if o := list.Filter("request_limiter"); len(o.Items) > 0 {
result.found("request_limiter", "RequestLimiter")
if err := parseRequestLimiter(&result, o); err != nil {
return nil, fmt.Errorf("error parsing 'request_limiter': %w", err)
}
}
entConfig := &(result.EntSharedConfig)
if err := entConfig.ParseConfig(list); err != nil {
return nil, fmt.Errorf("error parsing enterprise config: %w", err)
@ -293,13 +284,6 @@ func (c *SharedConfig) Sanitized() map[string]interface{} {
result["telemetry"] = sanitizedTelemetry
}
if c.RequestLimiter != nil {
sanitizedRequestLimiter := map[string]interface{}{
"disable": c.RequestLimiter.Disable,
}
result["request_limiter"] = sanitizedRequestLimiter
}
return result
}

View File

@ -98,10 +98,5 @@ func (c *SharedConfig) Merge(c2 *SharedConfig) *SharedConfig {
result.ClusterName = c2.ClusterName
}
result.RequestLimiter = c.RequestLimiter
if c2.RequestLimiter != nil {
result.RequestLimiter = c2.RequestLimiter
}
return result
}

View File

@ -1,58 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package configutil
import (
"fmt"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/go-secure-stdlib/parseutil"
"github.com/hashicorp/hcl"
"github.com/hashicorp/hcl/hcl/ast"
)
// RequestLimiter holds the parsed `request_limiter` HCL stanza.
type RequestLimiter struct {
	// UnusedKeys collects unrecognized stanza keys so Validate can report them.
	UnusedKeys UnusedKeyMap `hcl:",unusedKeyPositions"`
	// Disable is the coerced boolean form of DisableRaw.
	Disable bool `hcl:"-"`
	// DisableRaw carries the raw `disable` value until it is parsed to a bool.
	DisableRaw interface{} `hcl:"disable"`
}

// Validate surfaces any unknown keys found in the stanza as config errors.
func (r *RequestLimiter) Validate(source string) []ConfigError {
	return ValidateUnusedFields(r.UnusedKeys, source)
}

// GoString implements fmt.GoStringer for readable config dumps.
func (r *RequestLimiter) GoString() string {
	return fmt.Sprintf("*%#v", *r)
}

// DefaultRequestLimiter is the stanza default: request limiting disabled.
var DefaultRequestLimiter = &RequestLimiter{
	Disable: true,
}

// parseRequestLimiter decodes at most one `request_limiter` stanza into
// result.RequestLimiter, coercing the raw `disable` value to a bool. When
// no explicit `disable` is present, limiting defaults to disabled.
//
// NOTE(review): result.RequestLimiter starts as the shared
// DefaultRequestLimiter pointer; if hcl.DecodeObject decodes in place
// rather than replacing the pointer, the subsequent field writes would
// mutate the package-level default — confirm, and copy the struct if so.
func parseRequestLimiter(result *SharedConfig, list *ast.ObjectList) error {
	if len(list.Items) > 1 {
		return fmt.Errorf("only one 'request_limiter' block is permitted")
	}
	result.RequestLimiter = DefaultRequestLimiter
	// Get our one item
	item := list.Items[0]
	if err := hcl.DecodeObject(&result.RequestLimiter, item.Val); err != nil {
		return multierror.Prefix(err, "request_limiter:")
	}
	// Default to disabled, then honor an explicit `disable = ...` value.
	result.RequestLimiter.Disable = true
	if result.RequestLimiter.DisableRaw != nil {
		var err error
		if result.RequestLimiter.Disable, err = parseutil.ParseBool(result.RequestLimiter.DisableRaw); err != nil {
			return err
		}
		result.RequestLimiter.DisableRaw = nil
	}
	return nil
}

56
limits/http_limiter.go Normal file
View File

@ -0,0 +1,56 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package limits
import (
"context"
"errors"
"net/http"
)
// ErrCapacity indicates Vault is not accepting new requests; request-path
// callers translate it to HTTP 503 Service Unavailable.
//lint:ignore ST1005 Vault is the product name
var ErrCapacity = errors.New("Vault server temporarily overloaded")

const (
	// WriteLimiter keys the limiter guarding write-based HTTP methods.
	WriteLimiter = "write"
	// SpecialPathLimiter keys the limiter guarding special-cased paths
	// (e.g. `pki/issue`), flagged on the request as PathLimited.
	SpecialPathLimiter = "special-path"
)

// HTTPLimiter is a convenience struct that we use to wrap some logical request
// context and prevent dependence on Core.
type HTTPLimiter struct {
	// Method is the incoming request's HTTP method.
	Method string
	// PathLimited is true when the request targets a special-cased path.
	PathLimited bool
	// LookupFunc resolves a limiter key to the active RequestLimiter.
	LookupFunc func(key string) *RequestLimiter
}

// CtxKeyDisableRequestLimiter holds the HTTP Listener's disable config if set.
type CtxKeyDisableRequestLimiter struct{}

func (c CtxKeyDisableRequestLimiter) String() string {
	return "disable_request_limiter"
}

// Acquire checks the HTTPLimiter metadata to determine if an HTTP request
// should be limited, or simply passed through as a no-op.
//
// Selection order:
//  1. A true CtxKeyDisableRequestLimiter context value short-circuits to a
//     no-op listener.
//  2. PathLimited requests consult the special-path limiter.
//  3. Otherwise, only non-read methods consult the write limiter;
//     GET/HEAD/TRACE/OPTIONS pass through unlimited.
func (h *HTTPLimiter) Acquire(ctx context.Context) (*RequestListener, bool) {
	// If the limiter is disabled, return an empty wrapper so the limiter is a
	// no-op and indicate that the request can proceed.
	if disable := ctx.Value(CtxKeyDisableRequestLimiter{}); disable != nil && disable.(bool) {
		return &RequestListener{}, true
	}
	lim := &RequestLimiter{}
	if h.PathLimited {
		lim = h.LookupFunc(SpecialPathLimiter)
	} else {
		switch h.Method {
		case http.MethodGet, http.MethodHead, http.MethodTrace, http.MethodOptions:
			// We're only interested in the inverse, so do nothing here.
		default:
			lim = h.LookupFunc(WriteLimiter)
		}
	}
	return lim.Acquire(ctx)
}

View File

@ -1,189 +1,20 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package limits
import (
"context"
"errors"
"fmt"
"math"
"sync/atomic"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-hclog"
"github.com/platinummonkey/go-concurrency-limits/core"
"github.com/platinummonkey/go-concurrency-limits/limit"
"github.com/platinummonkey/go-concurrency-limits/limiter"
"github.com/platinummonkey/go-concurrency-limits/strategy"
)
var (
// ErrCapacity is a new error type to indicate that Vault is not accepting new
// requests. This should be handled by callers in request paths to return
// http.StatusServiceUnavailable to the client.
ErrCapacity = errors.New("Vault server temporarily overloaded")
type RequestLimiter struct{}
// DefaultDebugLogger opts out of the go-concurrency-limits internal Debug
// logger, since it's rather noisy. We're generating logs of interest in
// Vault.
DefaultDebugLogger limit.Logger = nil
// DefaultMetricsRegistry opts out of the go-concurrency-limits internal
// metrics because we're tracking what we care about in Vault.
DefaultMetricsRegistry core.MetricRegistry = core.EmptyMetricRegistryInstance
)
const (
// Smoothing adjusts how heavily we weight newer high-latency detection.
// Higher values (>1) place more emphasis on recent measurements. We set
// this below 1 to better tolerate short-lived spikes in request rate.
DefaultSmoothing = .1
// DefaultLongWindow is chosen as a minimum of 1000 samples. longWindow
// defines sliding window size used for the Exponential Moving Average.
DefaultLongWindow = 1000
)
// RequestLimiter is a thin wrapper for limiter.DefaultLimiter.
type RequestLimiter struct {
*limiter.DefaultLimiter
Flags LimiterFlags
}
// Acquire consults the underlying RequestLimiter to see if a new
// RequestListener can be acquired.
//
// The return values are a *RequestListener, which the caller can use to perform
// latency measurements, and a bool to indicate whether or not a RequestListener
// was acquired.
//
// The returned RequestListener is short-lived and eventually garbage-collected;
// however, the RequestLimiter keeps track of in-flight concurrency using a
// token bucket implementation. The caller must release the resulting Limiter
// token by conducting a measurement.
//
// There are three return cases:
//
// 1) If Request Limiting is disabled, we return an empty RequestListener so all
// measurements are no-ops.
//
// 2) If the request limit has been exceeded, we will not acquire a
// RequestListener and instead return nil, false. No measurement is required,
// since we immediately return from callers with ErrCapacity.
//
// 3) If we have not exceeded the request limit, the caller must call one of
// OnSuccess(), OnDropped(), or OnIgnore() to return a measurement and release
// the underlying Limiter token.
func (l *RequestLimiter) Acquire(ctx context.Context) (*RequestListener, bool) {
// Transparently handle the case where the limiter is disabled.
if l == nil || l.DefaultLimiter == nil {
// Acquire is a no-op on CE
func (l *RequestLimiter) Acquire(_ctx context.Context) (*RequestListener, bool) {
return &RequestListener{}, true
}
lsnr, ok := l.DefaultLimiter.Acquire(ctx)
if !ok {
metrics.IncrCounter(([]string{"limits", "concurrency", "service_unavailable"}), 1)
// If the token acquisition fails, we've reached capacity and we won't
// get a listener, so just return nil.
return nil, false
}
return &RequestListener{
DefaultListener: lsnr.(*limiter.DefaultListener),
released: new(atomic.Bool),
}, true
}
// concurrencyChanger reports how far the allowed concurrency should move
// in a single adjustment step: the square root of the current limit,
// floored at 1 so adjustment never stalls — an exponential backoff as the
// limit grows toward its maximum.
func concurrencyChanger(limit int) int {
	return int(math.Max(1.0, math.Sqrt(float64(limit))))
}
var DefaultLimiterFlags = map[string]LimiterFlags{
// WriteLimiter default flags have a less conservative MinLimit to prevent
// over-optimizing the request latency, which would result in
// under-utilization and client starvation.
WriteLimiter: {
MinLimit: 100,
MaxLimit: 5000,
InitialLimit: 100,
},
// SpecialPathLimiter default flags have a conservative MinLimit to allow
// more aggressive concurrency throttling for CPU-bound workloads such as
// `pki/issue`.
SpecialPathLimiter: {
MinLimit: 5,
MaxLimit: 5000,
InitialLimit: 5,
},
}
// LimiterFlags establish some initial configuration for a new request limiter.
type LimiterFlags struct {
// MinLimit defines the minimum concurrency floor to prevent over-throttling
// requests during periods of high traffic.
MinLimit int `json:"min_limit,omitempty" mapstructure:"min_limit,omitempty"`
// MaxLimit defines the maximum concurrency ceiling to prevent skewing to a
// point of no return.
//
// We set this to a high value (5000) with the expectation that systems with
// high-performing specs will tolerate higher limits, while the algorithm
// will find its own steady-state concurrency well below this threshold in
// most cases.
MaxLimit int `json:"max_limit,omitempty" mapstructure:"max_limit,omitempty"`
// InitialLimit defines the starting concurrency limit prior to any
// measurements.
//
// If we start this value off too high, Vault could become
// overloaded before the algorithm has a chance to adapt. Setting the value
// to the minimum is a safety measure which could result in early request
// rejection; however, the adaptive nature of the algorithm will prevent
// this from being a prolonged state as the allowed concurrency will
// increase during normal operation.
InitialLimit int `json:"initial_limit,omitempty" mapstructure:"initial_limit,omitempty"`
}
// NewRequestLimiter is a basic constructor for the RequestLimiter wrapper. It
// is responsible for setting up the Gradient2 Limit and instantiating a new
// wrapped DefaultLimiter.
func NewRequestLimiter(logger hclog.Logger, name string, flags LimiterFlags) (*RequestLimiter, error) {
logger.Info("setting up new request limiter",
"initialLimit", flags.InitialLimit,
"maxLimit", flags.MaxLimit,
"minLimit", flags.MinLimit,
)
// NewGradient2Limit is the algorithm which drives request limiting
// decisions. It gathers latency measurements and calculates an Exponential
// Moving Average to determine whether latency deviation warrants a change
// in the current concurrency limit.
lim, err := limit.NewGradient2Limit(name,
flags.InitialLimit,
flags.MaxLimit,
flags.MinLimit,
concurrencyChanger,
DefaultSmoothing,
DefaultLongWindow,
DefaultDebugLogger,
DefaultMetricsRegistry,
)
if err != nil {
return &RequestLimiter{}, fmt.Errorf("failed to create gradient2 limit: %w", err)
}
strategy := strategy.NewSimpleStrategy(flags.InitialLimit)
defLimiter, err := limiter.NewDefaultLimiter(lim, 1e9, 1e9, 10, 100, strategy, nil, DefaultMetricsRegistry)
if err != nil {
return &RequestLimiter{}, err
}
return &RequestLimiter{Flags: flags, DefaultLimiter: defLimiter}, nil
}
// EstimatedLimit is effectively 0, since we're not limiting requests on CE.
func (l *RequestLimiter) EstimatedLimit() int { return 0 }

View File

@ -1,51 +1,14 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package limits
import (
"sync/atomic"
type RequestListener struct{}
"github.com/armon/go-metrics"
"github.com/platinummonkey/go-concurrency-limits/limiter"
)
func (l *RequestListener) OnSuccess() {}
// RequestListener is a thin wrapper for limiter.DefaultLimiter to handle the
// case where request limiting is turned off.
type RequestListener struct {
*limiter.DefaultListener
released *atomic.Bool
}
func (l *RequestListener) OnDropped() {}
// OnSuccess is called as a notification that the operation succeeded and
// internally measured latency should be used as an RTT sample.
func (l *RequestListener) OnSuccess() {
if l.DefaultListener != nil {
metrics.IncrCounter(([]string{"limits", "concurrency", "success"}), 1)
l.DefaultListener.OnSuccess()
l.released.Store(true)
}
}
// OnDropped is called to indicate the request failed and was dropped due to an
// internal server error. Note that this does not include ErrCapacity.
func (l *RequestListener) OnDropped() {
if l.DefaultListener != nil {
metrics.IncrCounter(([]string{"limits", "concurrency", "dropped"}), 1)
l.DefaultListener.OnDropped()
l.released.Store(true)
}
}
// OnIgnore is called to indicate the operation failed before any meaningful RTT
// measurement could be made and should be ignored to not introduce an
// artificially low RTT. It also provides an extra layer of protection against
// leaks of the underlying StrategyToken during recoverable panics in the
// request handler. We treat these as Ignored, discard the measurement, and mark
// the listener as released.
func (l *RequestListener) OnIgnore() {
if l.DefaultListener != nil && l.released.Load() != true {
metrics.IncrCounter(([]string{"limits", "concurrency", "ignored"}), 1)
l.DefaultListener.OnIgnore()
l.released.Store(true)
}
}
func (l *RequestListener) OnIgnore() {}

View File

@ -1,222 +1,9 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package limits
import (
"os"
"strconv"
"sync"
"github.com/hashicorp/go-hclog"
)
const (
WriteLimiter = "write"
SpecialPathLimiter = "special-path"
LimitsBadEnvVariable = "failed to process limiter environment variable, using default"
)
// NOTE: Great care should be taken when setting any of these variables to avoid
// adverse effects in optimal request servicing. It is strongly advised that
// these variables not be used unless there is a very good reason. These are
// intentionally undocumented environment variables that may be removed in
// future versions of Vault.
const (
// EnvVaultDisableWriteLimiter is used to turn off the
// RequestLimiter for write-based HTTP methods.
EnvVaultDisableWriteLimiter = "VAULT_DISABLE_WRITE_LIMITER"
// EnvVaultWriteLimiterMin is used to modify the minimum
// concurrency limit for write-based HTTP methods.
EnvVaultWriteLimiterMin = "VAULT_WRITE_LIMITER_MIN"
// EnvVaultWriteLimiterMax is used to modify the maximum
// concurrency limit for write-based HTTP methods.
EnvVaultWriteLimiterMax = "VAULT_WRITE_LIMITER_MAX"
// EnvVaultDisablePathBasedRequestLimiting is used to turn off the
// RequestLimiter for special-cased paths, specified in
// Backend.PathsSpecial.
EnvVaultDisableSpecialPathLimiter = "VAULT_DISABLE_SPECIAL_PATH_LIMITER"
// EnvVaultSpecialPathLimiterMin is used to modify the minimum
// concurrency limit for write-based HTTP methods.
EnvVaultSpecialPathLimiterMin = "VAULT_SPECIAL_PATH_LIMITER_MIN"
// EnvVaultSpecialPathLimiterMax is used to modify the maximum
// concurrency limit for write-based HTTP methods.
EnvVaultSpecialPathLimiterMax = "VAULT_SPECIAL_PATH_LIMITER_MAX"
)
// LimiterRegistry holds the map of RequestLimiters mapped to keys.
type LimiterRegistry struct {
Limiters map[string]*RequestLimiter
Logger hclog.Logger
Enabled bool
sync.RWMutex
}
// NewLimiterRegistry constructs an empty, disabled LimiterRegistry that
// logs through the provided logger. Limiters are populated later by
// Enable/Register.
func NewLimiterRegistry(logger hclog.Logger) *LimiterRegistry {
	reg := &LimiterRegistry{
		Logger:   logger,
		Limiters: map[string]*RequestLimiter{},
	}
	return reg
}
// processEnvVars consults Limiter-specific environment variables and tells the
// caller if the Limiter should be disabled. If not, it adjusts the passed-in
// limiterFlags as appropriate.
func (r *LimiterRegistry) processEnvVars(name string, flags *LimiterFlags, envDisabled, envMin, envMax string) bool {
envFlagsLogger := r.Logger.With("name", name)
if disabledRaw := os.Getenv(envDisabled); disabledRaw != "" {
disabled, err := strconv.ParseBool(disabledRaw)
if err != nil {
envFlagsLogger.Warn(LimitsBadEnvVariable,
"env", envDisabled,
"val", disabledRaw,
"default", false,
"error", err,
)
}
if disabled {
envFlagsLogger.Warn("limiter disabled by environment variable", "env", envDisabled, "val", disabledRaw)
return true
}
}
envFlags := &LimiterFlags{}
if minRaw := os.Getenv(envMin); minRaw != "" {
min, err := strconv.Atoi(minRaw)
if err != nil {
envFlagsLogger.Warn(LimitsBadEnvVariable,
"env", envMin,
"val", minRaw,
"default", flags.MinLimit,
"error", err,
)
} else {
envFlags.MinLimit = min
}
}
if maxRaw := os.Getenv(envMax); maxRaw != "" {
max, err := strconv.Atoi(maxRaw)
if err != nil {
envFlagsLogger.Warn(LimitsBadEnvVariable,
"env", envMax,
"val", maxRaw,
"default", flags.MaxLimit,
"error", err,
)
} else {
envFlags.MaxLimit = max
}
}
switch {
case envFlags.MinLimit == 0:
// Assume no environment variable was provided.
case envFlags.MinLimit > 0:
flags.MinLimit = envFlags.MinLimit
default:
r.Logger.Warn("min limit must be greater than zero, falling back to defaults", "minLimit", flags.MinLimit)
}
switch {
case envFlags.MaxLimit == 0:
// Assume no environment variable was provided.
case envFlags.MaxLimit > flags.MinLimit:
flags.MaxLimit = envFlags.MaxLimit
default:
r.Logger.Warn("max limit must be greater than min, falling back to defaults", "maxLimit", flags.MaxLimit)
}
return false
}
// Enable sets up a new LimiterRegistry and marks it Enabled. It is
// idempotent: calling Enable on an already-enabled registry is a no-op.
func (r *LimiterRegistry) Enable() {
	r.Lock()
	defer r.Unlock()
	if r.Enabled {
		return
	}
	r.Logger.Info("enabling request limiters")
	// Start from a fresh map and build one limiter per known key using
	// its default tuning flags.
	r.Limiters = map[string]*RequestLimiter{}
	for name, flags := range DefaultLimiterFlags {
		r.Register(name, flags)
	}
	r.Enabled = true
}
// Register builds a RequestLimiter for the given name and stores it in the
// registry. Only the known limiter types (WriteLimiter, SpecialPathLimiter)
// are accepted; any other name is logged and ignored. Locking should be done
// in the caller.
func (r *LimiterRegistry) Register(name string, flags LimiterFlags) {
	// Map each known limiter type to its disable/min/max environment
	// variables so the override handling below is uniform.
	var envDisable, envMin, envMax string
	switch name {
	case WriteLimiter:
		envDisable, envMin, envMax = EnvVaultDisableWriteLimiter, EnvVaultWriteLimiterMin, EnvVaultWriteLimiterMax
	case SpecialPathLimiter:
		envDisable, envMin, envMax = EnvVaultDisableSpecialPathLimiter, EnvVaultSpecialPathLimiterMin, EnvVaultSpecialPathLimiterMax
	default:
		r.Logger.Warn("skipping invalid limiter type", "key", name)
		return
	}

	// Environment variables may tune the flags or disable this limiter
	// entirely; a true return means disabled.
	if r.processEnvVars(name, &flags, envDisable, envMin, envMax) {
		return
	}

	// Always set the initial limit to min so the system can find its own
	// equilibrium, since max might be too high.
	flags.InitialLimit = flags.MinLimit

	limiter, err := NewRequestLimiter(r.Logger.Named(name), name, flags)
	if err != nil {
		r.Logger.Error("failed to register limiter", "name", name, "error", err)
		return
	}
	r.Limiters[name] = limiter
}
// Disable marks the registry as disabled and drops its references to the
// underlying limiters. Requests that already acquired a token are
// unaffected: outstanding tokens flush when their request completes, and the
// dropped limiters are reclaimed by the garbage collector.
func (r *LimiterRegistry) Disable() {
	r.Lock()
	defer r.Unlock()

	if !r.Enabled {
		// Nothing to tear down.
		return
	}

	r.Logger.Info("disabling request limiters")
	r.Enabled = false
	r.Limiters = make(map[string]*RequestLimiter)
}
// GetLimiter returns the RequestLimiter registered under key, or nil when no
// such limiter exists (e.g. the registry is disabled).
func (r *LimiterRegistry) GetLimiter(key string) *RequestLimiter {
	r.RLock()
	limiter := r.Limiters[key]
	r.RUnlock()
	return limiter
}
type LimiterRegistry struct{}

View File

@ -546,9 +546,3 @@ func ContextOriginalBodyValue(ctx context.Context) (io.ReadCloser, bool) {
// CreateContextOriginalBody returns a copy of parent carrying body under the
// ctxKeyOriginalBody key, for later retrieval via ContextOriginalBodyValue.
func CreateContextOriginalBody(parent context.Context, body io.ReadCloser) context.Context {
	return context.WithValue(parent, ctxKeyOriginalBody{}, body)
}
// CtxKeyDisableRequestLimiter is a context key signaling that request
// limiting should be bypassed for the request carrying it.
type CtxKeyDisableRequestLimiter struct{}

// String implements fmt.Stringer for the context key.
func (CtxKeyDisableRequestLimiter) String() string {
	return "disable_request_limiter"
}

View File

@ -49,7 +49,6 @@ import (
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/helper/osutil"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/physical/raft"
"github.com/hashicorp/vault/plugins/event"
"github.com/hashicorp/vault/sdk/helper/certutil"
@ -715,9 +714,6 @@ type Core struct {
periodicLeaderRefreshInterval time.Duration
clusterAddrBridge *raft.ClusterAddrBridge
limiterRegistry *limits.LimiterRegistry
limiterRegistryLock sync.Mutex
}
func (c *Core) ActiveNodeClockSkewMillis() int64 {
@ -728,12 +724,6 @@ func (c *Core) EchoDuration() time.Duration {
return c.echoDuration.Load()
}
func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter {
c.limiterRegistryLock.Lock()
defer c.limiterRegistryLock.Unlock()
return c.limiterRegistry.GetLimiter(key)
}
// c.stateLock needs to be held in read mode before calling this function.
func (c *Core) HAState() consts.HAState {
switch {
@ -902,9 +892,6 @@ type CoreConfig struct {
PeriodicLeaderRefreshInterval time.Duration
ClusterAddrBridge *raft.ClusterAddrBridge
DisableRequestLimiter bool
LimiterRegistry *limits.LimiterRegistry
}
// GetServiceRegistration returns the config's ServiceRegistration, or nil if it does
@ -1007,10 +994,6 @@ func CreateCore(conf *CoreConfig) (*Core, error) {
}
}
if conf.LimiterRegistry == nil {
conf.LimiterRegistry = limits.NewLimiterRegistry(conf.Logger.Named("limits"))
}
// Use imported logging deadlock if requested
var stateLock locking.RWMutex
stateLock = &locking.SyncRWMutex{}
@ -1315,14 +1298,6 @@ func NewCore(conf *CoreConfig) (*Core, error) {
return nil, err
}
c.limiterRegistry = conf.LimiterRegistry
c.limiterRegistryLock.Lock()
c.limiterRegistry.Disable()
if !conf.DisableRequestLimiter {
c.limiterRegistry.Enable()
}
c.limiterRegistryLock.Unlock()
err = c.adjustForSealMigration(conf.UnwrapSeal)
if err != nil {
return nil, err
@ -4109,27 +4084,6 @@ func (c *Core) ReloadLogRequestsLevel() {
}
}
func (c *Core) ReloadRequestLimiter() {
c.limiterRegistry.Logger.Info("reloading request limiter config")
conf := c.rawConfig.Load()
if conf == nil {
return
}
disable := true
requestLimiterConfig := conf.(*server.Config).RequestLimiter
if requestLimiterConfig != nil {
disable = requestLimiterConfig.Disable
}
switch disable {
case true:
c.limiterRegistry.Disable()
default:
c.limiterRegistry.Enable()
}
}
func (c *Core) ReloadIntrospectionEndpointEnabled() {
conf := c.rawConfig.Load()
if conf == nil {

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/namespace"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/helper/license"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
@ -213,3 +214,11 @@ func DiagnoseCheckLicense(ctx context.Context, vaultCore *Core, coreConfig CoreC
// createCustomMessageManager constructs the CE CustomMessagesManager backed
// by the given storage; the Core argument is unused on CE.
func createCustomMessageManager(storage logical.Storage, _ *Core) CustomMessagesManager {
	return uicustommessages.NewManager(storage)
}
// GetRequestLimiter is a stub for CE; it always returns nil regardless of
// key. The caller will handle the nil case as a no-op.
func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter {
	return nil
}
// ReloadRequestLimiter is a no-op on CE; Request Limiter functionality is
// enterprise-only.
func (c *Core) ReloadRequestLimiter() {}

View File

@ -226,10 +226,6 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf
b.Backend.Paths = append(b.Backend.Paths, b.experimentPaths()...)
b.Backend.Paths = append(b.Backend.Paths, b.introspectionPaths()...)
if requestLimiterRead := b.requestLimiterReadPath(); requestLimiterRead != nil {
b.Backend.Paths = append(b.Backend.Paths, b.requestLimiterReadPath())
}
if core.rawEnabled {
b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...)
}

View File

@ -1,88 +0,0 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build testonly
package vault
import (
"context"
"net/http"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)
// RequestLimiterResponse is a struct for marshalling Request Limiter status
// responses returned by the internal/request-limiter/status endpoint.
type RequestLimiterResponse struct {
	// GlobalDisabled reports whether the limiter registry is disabled
	// node-wide.
	GlobalDisabled bool `json:"global_disabled" mapstructure:"global_disabled"`
	// ListenerDisabled reports whether the listener that received the
	// request has request limiting disabled.
	ListenerDisabled bool `json:"listener_disabled" mapstructure:"listener_disabled"`
	// Limiters maps each limiter type name to its per-limiter status.
	Limiters map[string]*LimiterStatus `json:"types" mapstructure:"types"`
}

// LimiterStatus holds the per-limiter status and flags for testing.
type LimiterStatus struct {
	// Enabled is true only when neither the registry nor the listener
	// disables limiting.
	Enabled bool `json:"enabled" mapstructure:"enabled"`
	// Flags are the effective limiter flags; left zero-valued when the
	// limiter is disabled.
	Flags limits.LimiterFlags `json:"flags,omitempty" mapstructure:"flags,omitempty"`
}

// readRequestLimiterHelpText is the help/synopsis text for the status path.
const readRequestLimiterHelpText = `
Read the current status of the request limiter.
`
// requestLimiterReadPath returns the framework.Path serving read access to
// the current Request Limiter status at internal/request-limiter/status.
func (b *SystemBackend) requestLimiterReadPath() *framework.Path {
	return &framework.Path{
		Pattern:         "internal/request-limiter/status$",
		HelpDescription: readRequestLimiterHelpText,
		HelpSynopsis:    readRequestLimiterHelpText,
		Operations: map[logical.Operation]framework.OperationHandler{
			logical.ReadOperation: &framework.PathOperation{
				Callback: b.handleReadRequestLimiter,
				DisplayAttrs: &framework.DisplayAttributes{
					OperationVerb: "read",
					// Was "verbosity-level-for", a copy-paste from the
					// loggers verbosity endpoint; use a suffix that matches
					// this endpoint so the generated operation ID is correct.
					OperationSuffix: "request-limiter-status",
				},
				Responses: map[int][]framework.Response{
					http.StatusOK: {{
						Description: "OK",
					}},
				},
				Summary: "Read the current status of the request limiter.",
			},
		},
	}
}
// handleReadRequestLimiter reports the Request Limiter status for this node:
// the global (registry) and per-listener disabled states, plus the effective
// flags for each known limiter type.
func (b *SystemBackend) handleReadRequestLimiter(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
	// Snapshot the registry reference under the lock; reads of its fields
	// below go through that snapshot.
	b.Core.limiterRegistryLock.Lock()
	registry := b.Core.limiterRegistry
	b.Core.limiterRegistryLock.Unlock()

	status := &RequestLimiterResponse{
		GlobalDisabled:   !registry.Enabled,
		ListenerDisabled: req.RequestLimiterDisabled,
		Limiters:         make(map[string]*LimiterStatus),
	}
	// Limiting is effective only when neither disable source applies.
	enabled := !status.GlobalDisabled && !status.ListenerDisabled

	for name := range limits.DefaultLimiterFlags {
		var flags limits.LimiterFlags
		if limiter := b.Core.GetRequestLimiter(name); limiter != nil && enabled {
			flags = limiter.Flags
		}
		status.Limiters[name] = &LimiterStatus{
			Enabled: enabled,
			Flags:   flags,
		}
	}

	return &logical.Response{
		Data: map[string]interface{}{
			"request_limiter": status,
		},
	}, nil
}

View File

@ -44,7 +44,6 @@ import (
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/helper/testhelpers/pluginhelpers"
"github.com/hashicorp/vault/internalshared/configutil"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/logging"
@ -1133,8 +1132,6 @@ type TestClusterOptions struct {
// ABCDLoggerNames names the loggers according to our ABCD convention when generating 4 clusters
ABCDLoggerNames bool
LimiterRegistry *limits.LimiterRegistry
}
type TestPluginConfig struct {
@ -1425,7 +1422,6 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
EnableUI: true,
EnableRaw: true,
BuiltinRegistry: corehelpers.NewMockBuiltinRegistry(),
LimiterRegistry: limits.NewLimiterRegistry(testCluster.Logger),
}
if base != nil {
@ -1515,10 +1511,6 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te
coreConfig.PeriodicLeaderRefreshInterval = base.PeriodicLeaderRefreshInterval
coreConfig.ClusterAddrBridge = base.ClusterAddrBridge
if base.LimiterRegistry != nil {
coreConfig.LimiterRegistry = base.LimiterRegistry
}
testApplyEntBaseConfig(coreConfig, base)
}
if coreConfig.ClusterName == "" {
@ -1912,10 +1904,6 @@ func (testCluster *TestCluster) newCore(t testing.T, idx int, coreConfig *CoreCo
localConfig.NumExpirationWorkers = numExpirationWorkersTest
if opts != nil && opts.LimiterRegistry != nil {
localConfig.LimiterRegistry = opts.LimiterRegistry
}
c, err := NewCore(&localConfig)
if err != nil {
t.Fatalf("err: %v", err)