diff --git a/.github/workflows/test-go.yml b/.github/workflows/test-go.yml index e0002d731a..962682d634 100644 --- a/.github/workflows/test-go.yml +++ b/.github/workflows/test-go.yml @@ -156,7 +156,7 @@ jobs: # testonly tagged tests need an additional tag to be included # also running some extra tests for sanity checking with the testonly build tag ( - go list -tags=testonly ./vault/external_tests/{kv,token,*replication-perf*,*testonly*} ./vault/ | gotestsum tool ci-matrix --debug \ + go list -tags=testonly ./vault/external_tests/{kv,token,*replication-perf*,*testonly*} ./command/*testonly* ./vault/ | gotestsum tool ci-matrix --debug \ --partitions "${{ inputs.total-runners }}" \ --timing-files 'test-results/go-test/*.json' > matrix.json ) diff --git a/command/command_testonly/server_testonly_test.go b/command/command_testonly/server_testonly_test.go new file mode 100644 index 0000000000..e146d71c03 --- /dev/null +++ b/command/command_testonly/server_testonly_test.go @@ -0,0 +1,207 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build testonly + +package command_testonly + +import ( + "os" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command" + "github.com/hashicorp/vault/limits" + "github.com/hashicorp/vault/vault" + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/require" +) + +const ( + baseHCL = ` + backend "inmem" { } + disable_mlock = true + listener "tcp" { + address = "127.0.0.1:8209" + tls_disable = "true" + } + api_addr = "http://127.0.0.1:8209" + ` + requestLimiterDisableHCL = ` + request_limiter { + disable = true + } +` + requestLimiterEnableHCL = ` + request_limiter { + disable = false + } +` +) + +// TestServer_ReloadRequestLimiter tests a series of reloads and state +// transitions between RequestLimiter enable and disable. +func TestServer_ReloadRequestLimiter(t *testing.T) { + t.Parallel() + + enabledResponse := &vault.RequestLimiterResponse{ + GlobalDisabled: false, + ListenerDisabled: false, + Limiters: map[string]*vault.LimiterStatus{ + limits.WriteLimiter: { + Enabled: true, + Flags: limits.DefaultLimiterFlags[limits.WriteLimiter], + }, + limits.SpecialPathLimiter: { + Enabled: true, + Flags: limits.DefaultLimiterFlags[limits.SpecialPathLimiter], + }, + }, + } + + disabledResponse := &vault.RequestLimiterResponse{ + GlobalDisabled: true, + ListenerDisabled: false, + Limiters: map[string]*vault.LimiterStatus{ + limits.WriteLimiter: { + Enabled: false, + }, + limits.SpecialPathLimiter: { + Enabled: false, + }, + }, + } + + cases := []struct { + name string + configAfter string + expectedResponse *vault.RequestLimiterResponse + }{ + { + "enable after default", + baseHCL + requestLimiterEnableHCL, + enabledResponse, + }, + { + "enable after enable", + baseHCL + requestLimiterEnableHCL, + enabledResponse, + }, + { + "disable after enable", + baseHCL + requestLimiterDisableHCL, + disabledResponse, + }, + { + "default after disable", + baseHCL, + enabledResponse, + }, + { + "default after default", + baseHCL, + enabledResponse, + }, + { + "disable after default", + baseHCL + requestLimiterDisableHCL, + disabledResponse, + }, + { + "disable after disable", + baseHCL + requestLimiterDisableHCL, + disabledResponse, + }, + } + + ui, srv := command.TestServerCommand(t) + + f, err := os.CreateTemp(t.TempDir(), "") + require.NoErrorf(t, err, "error creating temp dir: %v", err) + + _, err = f.WriteString(baseHCL) + require.NoErrorf(t, err, "cannot write temp file contents") + + configPath := f.Name() + + var output string + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + code := srv.Run([]string{"-config", configPath}) + output = ui.ErrorWriter.String() + ui.OutputWriter.String() + require.Equal(t, 0, code, output) + }() + + select { + case <-srv.StartedCh(): + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } + defer func() { + srv.ShutdownCh <- struct{}{} + wg.Wait() + }() + + err = f.Close() + require.NoErrorf(t, err, "unable to close temp file") + + // create a client and unseal vault + cli, err := srv.Client() + require.NoError(t, err) + require.NoError(t, cli.SetAddress("http://127.0.0.1:8209")) + initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1}) + require.NoError(t, err) + _, err = cli.Sys().Unseal(initResp.Keys[0]) + require.NoError(t, err) + cli.SetToken(initResp.RootToken) + + output = ui.ErrorWriter.String() + ui.OutputWriter.String() + require.Contains(t, output, "Request Limiter: enabled") + + verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) { + t.Helper() + + statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status") + require.NoError(t, err) + require.NotNil(t, statusResp) + + limitersResp, ok := statusResp.Data["request_limiter"] + require.True(t, ok) + require.NotNil(t, limitersResp) + + var limiters *vault.RequestLimiterResponse + err = mapstructure.Decode(limitersResp, &limiters) + require.NoError(t, err) + require.NotNil(t, limiters) + + require.Equal(t, expectedResponse, limiters) + } + + // Start off with default enabled + verifyLimiters(t, enabledResponse) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Write the new contents and reload the server + f, err = os.OpenFile(configPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o644) + require.NoError(t, err) + defer f.Close() + + _, err = f.WriteString(tc.configAfter) + require.NoErrorf(t, err, "cannot write temp file contents") + + srv.SighupCh <- struct{}{} + select { + case <-srv.ReloadedCh(): + case <-time.After(5 * time.Second): + t.Fatalf("test timed out") + } + + verifyLimiters(t, tc.expectedResponse) + }) + } +} diff --git a/command/server_test.go b/command/server_test.go index e46d5bedba..160fdc8a34 100644 --- a/command/server_test.go +++ b/command/server_test.go @@ -22,11 +22,9 @@ import ( "testing" "time" - "github.com/hashicorp/cli" "github.com/hashicorp/vault/command/server" "github.com/hashicorp/vault/helper/testhelpers/corehelpers" "github.com/hashicorp/vault/internalshared/configutil" - "github.com/hashicorp/vault/sdk/physical" physInmem "github.com/hashicorp/vault/sdk/physical/inmem" "github.com/hashicorp/vault/vault" "github.com/hashicorp/vault/vault/seal" @@ -97,29 +95,6 @@ cloud { ` ) -func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { - tb.Helper() - - ui := cli.NewMockUi() - return ui, &ServerCommand{ - BaseCommand: &BaseCommand{ - UI: ui, - }, - ShutdownCh: MakeShutdownCh(), - SighupCh: MakeSighupCh(), - SigUSR2Ch: MakeSigUSR2Ch(), - PhysicalBackends: map[string]physical.Factory{ - "inmem": physInmem.NewInmem, - "inmem_ha": physInmem.NewInmemHA, - }, - - // These prevent us from random sleep guessing... - startedCh: make(chan struct{}, 5), - reloadedCh: make(chan struct{}, 5), - licenseReloadedCh: make(chan error), - } -} - func TestServer_ReloadListener(t *testing.T) { t.Parallel() diff --git a/command/server_util.go b/command/server_util.go new file mode 100644 index 0000000000..667b958595 --- /dev/null +++ b/command/server_util.go @@ -0,0 +1,48 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package command + +import ( + "testing" + + "github.com/hashicorp/cli" + "github.com/hashicorp/vault/sdk/physical" + physInmem "github.com/hashicorp/vault/sdk/physical/inmem" +) + +func TestServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { + tb.Helper() + return testServerCommand(tb) +} + +func (c *ServerCommand) StartedCh() chan struct{} { + return c.startedCh +} + +func (c *ServerCommand) ReloadedCh() chan struct{} { + return c.reloadedCh +} + +func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) { + tb.Helper() + + ui := cli.NewMockUi() + return ui, &ServerCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + ShutdownCh: MakeShutdownCh(), + SighupCh: MakeSighupCh(), + SigUSR2Ch: MakeSigUSR2Ch(), + PhysicalBackends: map[string]physical.Factory{ + "inmem": physInmem.NewInmem, + "inmem_ha": physInmem.NewInmemHA, + }, + + // These prevent us from random sleep guessing... + startedCh: make(chan struct{}, 5), + reloadedCh: make(chan struct{}, 5), + licenseReloadedCh: make(chan error), + } +} diff --git a/http/handler.go b/http/handler.go index fd920394fa..847f5eb133 100644 --- a/http/handler.go +++ b/http/handler.go @@ -919,6 +919,7 @@ func acquireLimiterListener(core *vault.Core, rawReq *http.Request, r *logical.R if disableRequestLimiter != nil { disable = disableRequestLimiter.(bool) } + r.RequestLimiterDisabled = disable if disable { return &limits.RequestListener{}, true } diff --git a/limits/limiter.go b/limits/limiter.go index d9758eaf71..15a811ddd9 100644 --- a/limits/limiter.go +++ b/limits/limiter.go @@ -47,6 +47,7 @@ const ( // RequestLimiter is a thin wrapper for limiter.DefaultLimiter. type RequestLimiter struct { *limiter.DefaultLimiter + Flags LimiterFlags } // Acquire consults the underlying RequestLimiter to see if a new @@ -103,34 +104,31 @@ func concurrencyChanger(limit int) int { return int(change) } -var ( - // DefaultWriteLimiterFlags have a less conservative MinLimit to prevent +var DefaultLimiterFlags = map[string]LimiterFlags{ + // WriteLimiter default flags have a less conservative MinLimit to prevent // over-optimizing the request latency, which would result in // under-utilization and client starvation. - DefaultWriteLimiterFlags = LimiterFlags{ - Name: WriteLimiter, - MinLimit: 100, - MaxLimit: 5000, - } + WriteLimiter: { + MinLimit: 100, + MaxLimit: 5000, + InitialLimit: 100, + }, - // DefaultSpecialPathLimiterFlags have a conservative MinLimit to allow more - // aggressive concurrency throttling for CPU-bound workloads such as + // SpecialPathLimiter default flags have a conservative MinLimit to allow + // more aggressive concurrency throttling for CPU-bound workloads such as // `pki/issue`. - DefaultSpecialPathLimiterFlags = LimiterFlags{ - Name: SpecialPathLimiter, - MinLimit: 5, - MaxLimit: 5000, - } -) + SpecialPathLimiter: { + MinLimit: 5, + MaxLimit: 5000, + InitialLimit: 5, + }, +} // LimiterFlags establish some initial configuration for a new request limiter. type LimiterFlags struct { - // Name specifies the limiter Name for registry lookup and logging. - Name string - // MinLimit defines the minimum concurrency floor to prevent over-throttling // requests during periods of high traffic. - MinLimit int + MinLimit int `json:"min_limit,omitempty" mapstructure:"min_limit,omitempty"` // MaxLimit defines the maximum concurrency ceiling to prevent skewing to a // point of no return. @@ -139,7 +137,7 @@ type LimiterFlags struct { // high-performing specs will tolerate higher limits, while the algorithm // will find its own steady-state concurrency well below this threshold in // most cases. - MaxLimit int + MaxLimit int `json:"max_limit,omitempty" mapstructure:"max_limit,omitempty"` // InitialLimit defines the starting concurrency limit prior to any // measurements. @@ -150,13 +148,13 @@ type LimiterFlags struct { // rejection; however, the adaptive nature of the algorithm will prevent // this from being a prolonged state as the allowed concurrency will // increase during normal operation. - InitialLimit int + InitialLimit int `json:"initial_limit,omitempty" mapstructure:"initial_limit,omitempty"` } // NewRequestLimiter is a basic constructor for the RequestLimiter wrapper. It // is responsible for setting up the Gradient2 Limit and instantiating a new // wrapped DefaultLimiter. -func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter, error) { +func NewRequestLimiter(logger hclog.Logger, name string, flags LimiterFlags) (*RequestLimiter, error) { logger.Info("setting up new request limiter", "initialLimit", flags.InitialLimit, "maxLimit", flags.MaxLimit, @@ -167,7 +165,7 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter // decisions. It gathers latency measurements and calculates an Exponential // Moving Average to determine whether latency deviation warrants a change // in the current concurrency limit. - lim, err := limit.NewGradient2Limit(flags.Name, + lim, err := limit.NewGradient2Limit(name, flags.InitialLimit, flags.MaxLimit, flags.MinLimit, @@ -178,7 +176,7 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter DefaultMetricsRegistry, ) if err != nil { - return nil, fmt.Errorf("failed to create gradient2 limit: %w", err) + return &RequestLimiter{}, fmt.Errorf("failed to create gradient2 limit: %w", err) } strategy := strategy.NewSimpleStrategy(flags.InitialLimit) @@ -187,5 +185,5 @@ func NewRequestLimiter(logger hclog.Logger, flags LimiterFlags) (*RequestLimiter return &RequestLimiter{}, err } - return &RequestLimiter{defLimiter}, nil + return &RequestLimiter{Flags: flags, DefaultLimiter: defLimiter}, nil } diff --git a/limits/registry.go b/limits/registry.go index 1868bf6a44..8d1cb8bca7 100644 --- a/limits/registry.go +++ b/limits/registry.go @@ -67,8 +67,8 @@ func NewLimiterRegistry(logger hclog.Logger) *LimiterRegistry { // processEnvVars consults Limiter-specific environment variables and tells the // caller if the Limiter should be disabled. If not, it adjusts the passed-in // limiterFlags as appropriate. -func (r *LimiterRegistry) processEnvVars(flags *LimiterFlags, envDisabled, envMin, envMax string) bool { - envFlagsLogger := r.Logger.With("name", flags.Name) +func (r *LimiterRegistry) processEnvVars(name string, flags *LimiterFlags, envDisabled, envMin, envMax string) bool { + envFlagsLogger := r.Logger.With("name", name) if disabledRaw := os.Getenv(envDisabled); disabledRaw != "" { disabled, err := strconv.ParseBool(disabledRaw) if err != nil { @@ -147,20 +147,22 @@ func (r *LimiterRegistry) Enable() { r.Logger.Info("enabling request limiters") r.Limiters = map[string]*RequestLimiter{} - r.Register(DefaultWriteLimiterFlags) - r.Register(DefaultSpecialPathLimiterFlags) + + for name, flags := range DefaultLimiterFlags { + r.Register(name, flags) + } r.Enabled = true } // Register creates a new request limiter and assigns it a slot in the // LimiterRegistry. Locking should be done in the caller. -func (r *LimiterRegistry) Register(flags LimiterFlags) { +func (r *LimiterRegistry) Register(name string, flags LimiterFlags) { var disabled bool - switch flags.Name { + switch name { case WriteLimiter: - disabled = r.processEnvVars(&flags, + disabled = r.processEnvVars(name, &flags, EnvVaultDisableWriteLimiter, EnvVaultWriteLimiterMin, EnvVaultWriteLimiterMax, @@ -169,7 +171,7 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) { return } case SpecialPathLimiter: - disabled = r.processEnvVars(&flags, + disabled = r.processEnvVars(name, &flags, EnvVaultDisableSpecialPathLimiter, EnvVaultSpecialPathLimiterMin, EnvVaultSpecialPathLimiterMax, @@ -178,7 +180,7 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) { return } default: - r.Logger.Warn("skipping invalid limiter type", "key", flags.Name) + r.Logger.Warn("skipping invalid limiter type", "key", name) return } @@ -186,18 +188,19 @@ func (r *LimiterRegistry) Register(flags LimiterFlags) { // equilibrium, since max might be too high. flags.InitialLimit = flags.MinLimit - limiter, err := NewRequestLimiter(r.Logger.Named(flags.Name), flags) + limiter, err := NewRequestLimiter(r.Logger.Named(name), name, flags) if err != nil { - r.Logger.Error("failed to register limiter", "name", flags.Name, "error", err) + r.Logger.Error("failed to register limiter", "name", name, "error", err) return } - r.Limiters[flags.Name] = limiter + r.Limiters[name] = limiter } // Disable drops its references to underlying limiters. func (r *LimiterRegistry) Disable() { r.Lock() + defer r.Unlock() if !r.Enabled { return @@ -209,7 +212,6 @@ func (r *LimiterRegistry) Disable() { // here and the garbage-collector should take care of the rest. r.Limiters = map[string]*RequestLimiter{} r.Enabled = false - r.Unlock() } // GetLimiter looks up a RequestLimiter by key in the LimiterRegistry. diff --git a/sdk/logical/request.go b/sdk/logical/request.go index a291795648..24fddeff08 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -255,6 +255,9 @@ type Request struct { // Name of the chroot namespace for the listener that the request was made against ChrootNamespace string `json:"chroot_namespace,omitempty"` + + // RequestLimiterDisabled tells whether the request context has Request Limiter applied. + RequestLimiterDisabled bool `json:"request_limiter_disabled,omitempty"` } // Clone returns a deep copy (almost) of the request. diff --git a/vault/core.go b/vault/core.go index 44ac048988..7d8480f86d 100644 --- a/vault/core.go +++ b/vault/core.go @@ -725,6 +725,8 @@ func (c *Core) EchoDuration() time.Duration { } func (c *Core) GetRequestLimiter(key string) *limits.RequestLimiter { + c.limiterRegistryLock.Lock() + defer c.limiterRegistryLock.Unlock() return c.limiterRegistry.GetLimiter(key) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 90ef973a4c..be59afa27a 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -226,6 +226,10 @@ func NewSystemBackend(core *Core, logger log.Logger, config *logical.BackendConf b.Backend.Paths = append(b.Backend.Paths, b.experimentPaths()...) b.Backend.Paths = append(b.Backend.Paths, b.introspectionPaths()...) + if requestLimiterRead := b.requestLimiterReadPath(); requestLimiterRead != nil { + b.Backend.Paths = append(b.Backend.Paths, b.requestLimiterReadPath()) + } + if core.rawEnabled { b.Backend.Paths = append(b.Backend.Paths, b.rawPaths()...) } diff --git a/vault/logical_system_limits.go b/vault/logical_system_limits.go new file mode 100644 index 0000000000..76443a23ba --- /dev/null +++ b/vault/logical_system_limits.go @@ -0,0 +1,12 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !testonly + +package vault + +import ( + "github.com/hashicorp/vault/sdk/framework" +) + +func (b *SystemBackend) requestLimiterReadPath() *framework.Path { return nil } diff --git a/vault/logical_system_limits_testonly.go b/vault/logical_system_limits_testonly.go new file mode 100644 index 0000000000..a4e0755bfb --- /dev/null +++ b/vault/logical_system_limits_testonly.go @@ -0,0 +1,88 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +//go:build testonly + +package vault + +import ( + "context" + "net/http" + + "github.com/hashicorp/vault/limits" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" +) + +// RequestLimiterResponse is a struct for marshalling Request Limiter status responses. +type RequestLimiterResponse struct { + GlobalDisabled bool `json:"global_disabled" mapstructure:"global_disabled"` + ListenerDisabled bool `json:"listener_disabled" mapstructure:"listener_disabled"` + Limiters map[string]*LimiterStatus `json:"types" mapstructure:"types"` +} + +// LimiterStatus holds the per-limiter status and flags for testing. +type LimiterStatus struct { + Enabled bool `json:"enabled" mapstructure:"enabled"` + Flags limits.LimiterFlags `json:"flags,omitempty" mapstructure:"flags,omitempty"` +} + +const readRequestLimiterHelpText = ` +Read the current status of the request limiter. +` + +func (b *SystemBackend) requestLimiterReadPath() *framework.Path { + return &framework.Path{ + Pattern: "internal/request-limiter/status$", + HelpDescription: readRequestLimiterHelpText, + HelpSynopsis: readRequestLimiterHelpText, + Operations: map[logical.Operation]framework.OperationHandler{ + logical.ReadOperation: &framework.PathOperation{ + Callback: b.handleReadRequestLimiter, + DisplayAttrs: &framework.DisplayAttributes{ + OperationVerb: "read", + OperationSuffix: "verbosity-level-for", + }, + Responses: map[int][]framework.Response{ + http.StatusOK: {{ + Description: "OK", + }}, + }, + Summary: "Read the current status of the request limiter.", + }, + }, + } +} + +// handleReadRequestLimiter returns the enabled Request Limiter status for this node. +func (b *SystemBackend) handleReadRequestLimiter(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + resp := &RequestLimiterResponse{ + Limiters: make(map[string]*LimiterStatus), + } + + b.Core.limiterRegistryLock.Lock() + registry := b.Core.limiterRegistry + b.Core.limiterRegistryLock.Unlock() + + resp.GlobalDisabled = !registry.Enabled + resp.ListenerDisabled = req.RequestLimiterDisabled + enabled := !(resp.GlobalDisabled || resp.ListenerDisabled) + + for name := range limits.DefaultLimiterFlags { + var flags limits.LimiterFlags + if requestLimiter := b.Core.GetRequestLimiter(name); requestLimiter != nil && enabled { + flags = requestLimiter.Flags + } + + resp.Limiters[name] = &LimiterStatus{ + Enabled: enabled, + Flags: flags, + } + } + + return &logical.Response{ + Data: map[string]interface{}{ + "request_limiter": resp, + }, + }, nil +}