vault/command/command_testonly/server_testonly_test.go
Mike Palmiotto 7ad778541e
Disable Request Limiter by default (#25442)
This PR flips the logic for the Request Limiter so that it is disabled by
default.

Users may turn on the global Request Limiter, while the Listener
configuration remains a per-Listener "disable" option.
2024-02-16 17:50:18 -05:00

214 lines
4.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build testonly
package command_testonly
import (
"os"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command"
"github.com/hashicorp/vault/limits"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/require"
)
// init copies VAULT_LICENSE_CI, when it is set to a non-empty value, into the
// environment variable named by command.EnvVaultLicense. Because this runs in
// init, the copy happens before any test in the package executes.
func init() {
	license, present := os.LookupEnv("VAULT_LICENSE_CI")
	if !present || license == "" {
		return
	}
	os.Setenv(command.EnvVaultLicense, license)
}
// Server configuration fragments used by the reload test below.
// NOTE(review): baseHCL binds a fixed port (8209), so two servers started
// from it at the same time would collide on the listener address.
const (
// baseHCL is a minimal server config: "inmem" storage backend, mlock
// disabled, and one TCP listener on 127.0.0.1:8209 with TLS disabled.
// It contains no request_limiter stanza, so the server's default
// Request Limiter setting applies.
baseHCL = `
backend "inmem" { }
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:8209"
tls_disable = "true"
}
api_addr = "http://127.0.0.1:8209"
`
// requestLimiterDisableHCL is appended to baseHCL to explicitly set
// request_limiter.disable = true.
requestLimiterDisableHCL = `
request_limiter {
disable = true
}
`
// requestLimiterEnableHCL is appended to baseHCL to explicitly set
// request_limiter.disable = false.
requestLimiterEnableHCL = `
request_limiter {
disable = false
}
`
)
// TestServer_ReloadRequestLimiter tests a series of reloads and state
// transitions between RequestLimiter enable and disable.
//
// The server starts from baseHCL (no request_limiter stanza) and must report
// the limiter as disabled. Each subtest then rewrites the config file,
// triggers a reload via srv.SighupCh, and checks the status returned by
// /sys/internal/request-limiter/status. Subtests run sequentially and each
// one starts from the state left by the previous case.
func TestServer_ReloadRequestLimiter(t *testing.T) {
	t.Parallel()

	// Status expected when the Request Limiter is enabled: every limiter is
	// enabled and carries its default flags.
	enabledResponse := &vault.RequestLimiterResponse{
		GlobalDisabled:   false,
		ListenerDisabled: false,
		Limiters: map[string]*vault.LimiterStatus{
			limits.WriteLimiter: {
				Enabled: true,
				Flags:   limits.DefaultLimiterFlags[limits.WriteLimiter],
			},
			limits.SpecialPathLimiter: {
				Enabled: true,
				Flags:   limits.DefaultLimiterFlags[limits.SpecialPathLimiter],
			},
		},
	}

	// Status expected when the Request Limiter is globally disabled: every
	// limiter reports disabled with zero-valued flags.
	disabledResponse := &vault.RequestLimiterResponse{
		GlobalDisabled:   true,
		ListenerDisabled: false,
		Limiters: map[string]*vault.LimiterStatus{
			limits.WriteLimiter: {
				Enabled: false,
			},
			limits.SpecialPathLimiter: {
				Enabled: false,
			},
		},
	}

	// Each case names the transition "<new config> after <previous config>";
	// the order matters because every case builds on the previous one.
	cases := []struct {
		name             string
		configAfter      string
		expectedResponse *vault.RequestLimiterResponse
	}{
		{
			"disable after default",
			baseHCL + requestLimiterDisableHCL,
			disabledResponse,
		},
		{
			"disable after disable",
			baseHCL + requestLimiterDisableHCL,
			disabledResponse,
		},
		{
			"enable after disable",
			baseHCL + requestLimiterEnableHCL,
			enabledResponse,
		},
		{
			"default after enable",
			baseHCL,
			disabledResponse,
		},
		{
			"default after default",
			baseHCL,
			disabledResponse,
		},
		{
			"enable after default",
			baseHCL + requestLimiterEnableHCL,
			enabledResponse,
		},
		{
			"enable after enable",
			baseHCL + requestLimiterEnableHCL,
			enabledResponse,
		},
	}

	ui, srv := command.TestServerCommand(t)

	// Write the initial config and close the file right away, so nothing
	// leaks if a later step fails and the server reads a complete file.
	f, err := os.CreateTemp(t.TempDir(), "")
	require.NoError(t, err, "error creating temp file")
	configPath := f.Name()
	_, err = f.WriteString(baseHCL)
	require.NoError(t, err, "cannot write temp file contents")
	require.NoError(t, f.Close(), "unable to close temp file")

	// Run the server in the background. require/t.FailNow must only be
	// called from the test goroutine, so the goroutine only records the exit
	// code; it is checked after wg.Wait() below.
	var code int
	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		code = srv.Run([]string{"-config", configPath})
	}()

	select {
	case <-srv.StartedCh():
	case <-time.After(5 * time.Second):
		t.Fatalf("timeout")
	}
	defer func() {
		srv.ShutdownCh <- struct{}{}
		wg.Wait()
		// wg.Wait() orders the goroutine's write of code before this read.
		if code != 0 {
			output := ui.ErrorWriter.String() + ui.OutputWriter.String()
			t.Errorf("server exited with code %d: %s", code, output)
		}
	}()

	// create a client and unseal vault
	cli, err := srv.Client()
	require.NoError(t, err)
	require.NoError(t, cli.SetAddress("http://127.0.0.1:8209"))
	initResp, err := cli.Sys().Init(&api.InitRequest{SecretShares: 1, SecretThreshold: 1})
	require.NoError(t, err)
	_, err = cli.Sys().Unseal(initResp.Keys[0])
	require.NoError(t, err)
	cli.SetToken(initResp.RootToken)

	startupOutput := ui.ErrorWriter.String() + ui.OutputWriter.String()
	require.Contains(t, startupOutput, "Request Limiter: disabled")

	// verifyLimiters reads the limiter status endpoint, decodes its
	// "request_limiter" field and requires it to equal expectedResponse.
	verifyLimiters := func(t *testing.T, expectedResponse *vault.RequestLimiterResponse) {
		t.Helper()
		statusResp, err := cli.Logical().Read("/sys/internal/request-limiter/status")
		require.NoError(t, err)
		require.NotNil(t, statusResp)

		limitersResp, ok := statusResp.Data["request_limiter"]
		require.True(t, ok)
		require.NotNil(t, limitersResp)

		var limiters *vault.RequestLimiterResponse
		err = mapstructure.Decode(limitersResp, &limiters)
		require.NoError(t, err)
		require.NotNil(t, limiters)

		require.Equal(t, expectedResponse, limiters)
	}

	// Start off with default disabled
	verifyLimiters(t, disabledResponse)

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Replace the config contents. os.WriteFile truncates and closes
			// the file before the reload is triggered, and its local err does
			// not touch the parent test's variables.
			err := os.WriteFile(configPath, []byte(tc.configAfter), 0o644)
			require.NoError(t, err, "cannot write temp file contents")

			srv.SighupCh <- struct{}{}
			select {
			case <-srv.ReloadedCh():
			case <-time.After(5 * time.Second):
				t.Fatalf("test timed out")
			}

			verifyLimiters(t, tc.expectedResponse)
		})
	}
}