vault/command/server_test.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

396 lines
10 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
//go:build !race && !hsm && !fips_140_3
// NOTE: we can't use this with HSM. We can't set testing mode on and it's not
// safe to use env vars since that provides an attack vector in the real world.
//
// The server tests have a go-metrics/exp manager race condition :(.
package command
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/sdk/physical"
physInmem "github.com/hashicorp/vault/sdk/physical/inmem"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
)
// init forwards a CI-provided license to the server under test: when
// VAULT_LICENSE_CI is set to a non-empty value, that value is copied into the
// environment variable named by EnvVaultLicense.
func init() {
	signed := os.Getenv("VAULT_LICENSE_CI")
	if signed == "" {
		return
	}
	os.Setenv(EnvVaultLicense, signed)
}
// testBaseHCL renders the minimal server configuration shared by the tests in
// this file: mlock disabled and a single TLS-disabled TCP listener bound to
// 127.0.0.1 on port 0. listenerExtras is inserted verbatim as additional
// lines inside the listener stanza. Surrounding whitespace is trimmed from
// the returned configuration.
func testBaseHCL(tb testing.TB, listenerExtras string) string {
	tb.Helper()

	const baseTemplate = `
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:%d"
tls_disable = "true"
%s
}
`
	rendered := fmt.Sprintf(baseTemplate, 0, listenerExtras)
	return strings.TrimSpace(rendered)
}
// HCL fragments used to assemble the server configurations exercised by the
// tests in this file.
const (
// goodListenerTimeouts holds valid listener timeout settings. TestServer
// passes it to testBaseHCL as listener extras and expects exit code 0.
goodListenerTimeouts = `http_read_header_timeout = 12
http_read_timeout = "34s"
http_write_timeout = "56m"
http_idle_timeout = "78h"`
// The badListener* fragments each set one listener timeout to a duration
// with an unrecognized unit. TestServer expects exit code 1 and an
// "unknown unit" message in the output for each of them.
badListenerReadHeaderTimeout = `http_read_header_timeout = "12km"`
badListenerReadTimeout = `http_read_timeout = "34日"`
badListenerWriteTimeout = `http_write_timeout = "56lbs"`
badListenerIdleTimeout = `http_idle_timeout = "78gophers"`
// inmemHCL declares an "inmem_ha" storage backend; TestServer appends it to
// the base configuration in every case.
inmemHCL = `
backend "inmem_ha" {
advertise_addr = "http://127.0.0.1:8200"
}
`
// haInmemHCL declares a separate "inmem_ha" HA backend; TestServer's
// separate_ha case expects "HA Storage:" in the output when it is present.
haInmemHCL = `
ha_backend "inmem_ha" {
redirect_addr = "http://127.0.0.1:8200"
}
`
// badHAInmemHCL declares the plain "inmem" backend as HA storage; TestServer
// expects "Specified HA storage does not support HA" and exit code 1.
badHAInmemHCL = `
ha_backend "inmem" {}
`
// reloadHCL is the configuration used by TestServer_ReloadListener. The
// TMPDIR placeholder is replaced with the test's temporary directory before
// the file is written, and the listener address matches the address the
// test dials.
reloadHCL = `
backend "inmem" {}
disable_mlock = true
listener "tcp" {
address = "127.0.0.1:8203"
tls_cert_file = "TMPDIR/reload_cert.pem"
tls_key_file = "TMPDIR/reload_key.pem"
}
`
// cloudHCL adds a cloud stanza; TestServer's cloud_config case expects the
// organization ID from resource_id to be reported as "HCP Organization".
// NOTE(review): client_id/client_secret look like test placeholders —
// confirm they are not real credentials.
cloudHCL = `
cloud {
resource_id = "organization/bc58b3d0-2eab-4ab8-abf4-f61d3c9975ff/project/1c78e888-2142-4000-8918-f933bbbc7690/hashicorp.example.resource/example"
client_id = "J2TtcSYOyPUkPV2z0mSyDtvitxLVjJmu"
client_secret = "N9JtHZyOnHrIvJZs82pqa54vd4jnkyU3xCcqhFXuQKJZZuxqxxbP1xCfBZVB82vY"
}
`
)
// testServerCommand builds a ServerCommand wired to a mock UI for use in
// tests. It returns the mock UI, so callers can inspect what the command
// printed, together with the command itself. Only the "inmem" and "inmem_ha"
// physical backends are registered.
func testServerCommand(tb testing.TB) (*cli.MockUi, *ServerCommand) {
	tb.Helper()

	mockUI := cli.NewMockUi()
	backends := map[string]physical.Factory{
		"inmem":    physInmem.NewInmem,
		"inmem_ha": physInmem.NewInmemHA,
	}

	cmd := &ServerCommand{
		BaseCommand:      &BaseCommand{UI: mockUI},
		ShutdownCh:       MakeShutdownCh(),
		SighupCh:         MakeSighupCh(),
		SigUSR2Ch:        MakeSigUSR2Ch(),
		PhysicalBackends: backends,

		// Tests wait on these channels instead of sleeping for a guessed
		// amount of time.
		startedCh:         make(chan struct{}, 5),
		reloadedCh:        make(chan struct{}, 5),
		licenseReloadedCh: make(chan error),
	}
	return mockUI, cmd
}
// TestServer_ReloadListener starts a server whose TCP listener serves a
// certificate read from a temporary directory, replaces the certificate and
// key on disk, signals a reload through cmd.SighupCh, and verifies that the
// listener then presents the replacement certificate.
func TestServer_ReloadListener(t *testing.T) {
	t.Parallel()

	wd, err := os.Getwd()
	if err != nil {
		t.Fatalf("getting working directory: %s", err)
	}
	wd += "/server/test-fixtures/reload/"

	// The testing package removes this directory when the test finishes.
	td := t.TempDir()

	// installFixture copies a file from the fixture directory into the
	// temporary directory and fails the test immediately on any I/O error, so
	// a missing fixture is reported as such rather than as a later TLS or
	// startup failure.
	installFixture := func(src, dst string) {
		t.Helper()
		data, err := ioutil.ReadFile(wd + src)
		if err != nil {
			t.Fatalf("reading fixture %s: %s", src, err)
		}
		if err := ioutil.WriteFile(td+"/"+dst, data, 0o777); err != nil {
			t.Fatalf("writing %s: %s", dst, err)
		}
	}

	// writeConfig (re)writes the server configuration file.
	relhcl := strings.ReplaceAll(reloadHCL, "TMPDIR", td)
	writeConfig := func() {
		t.Helper()
		if err := ioutil.WriteFile(td+"/reload.hcl", []byte(relhcl), 0o777); err != nil {
			t.Fatalf("writing reload.hcl: %s", err)
		}
	}

	// Set up the initial certificate, key and configuration.
	installFixture("reload_foo.pem", "reload_cert.pem")
	installFixture("reload_foo.key", "reload_key.pem")
	writeConfig()

	caBytes, err := ioutil.ReadFile(wd + "reload_ca.pem")
	if err != nil {
		t.Fatalf("reading CA fixture: %s", err)
	}
	certPool := x509.NewCertPool()
	if ok := certPool.AppendCertsFromPEM(caBytes); !ok {
		t.Fatal("not ok when appending CA cert")
	}

	ui, cmd := testServerCommand(t)

	wg := &sync.WaitGroup{}
	wg.Add(1)
	args := []string{"-config", td + "/reload.hcl"}
	go func() {
		defer wg.Done()
		// Errorf (not Fatalf) because this runs outside the test goroutine.
		if code := cmd.Run(args); code != 0 {
			output := ui.ErrorWriter.String() + ui.OutputWriter.String()
			t.Errorf("got a non-zero exit status: %s", output)
		}
	}()

	// testCertificateName dials the listener and compares the common name of
	// the certificate it presents with cn.
	testCertificateName := func(cn string) error {
		conn, err := tls.Dial("tcp", "127.0.0.1:8203", &tls.Config{
			RootCAs: certPool,
		})
		if err != nil {
			return err
		}
		defer conn.Close()
		if err = conn.Handshake(); err != nil {
			return err
		}
		servName := conn.ConnectionState().PeerCertificates[0].Subject.CommonName
		if servName != cn {
			return fmt.Errorf("expected %s, got %s", cn, servName)
		}
		return nil
	}

	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Fatalf("timeout waiting for server to start")
	}

	if err := testCertificateName("foo.example.com"); err != nil {
		t.Fatalf("certificate name didn't check out: %s", err)
	}

	// Swap in the replacement certificate and key, rewrite the configuration,
	// then ask the server to reload.
	installFixture("reload_bar.pem", "reload_cert.pem")
	installFixture("reload_bar.key", "reload_key.pem")
	writeConfig()

	cmd.SighupCh <- struct{}{}
	select {
	case <-cmd.reloadedCh:
	case <-time.After(5 * time.Second):
		t.Fatalf("timeout waiting for server to reload")
	}

	if err := testCertificateName("bar.example.com"); err != nil {
		t.Fatalf("certificate name didn't check out: %s", err)
	}

	cmd.ShutdownCh <- struct{}{}
	wg.Wait()
}
// TestServer runs the server command against a table of generated
// configuration files. For each case it checks the exit code and that the
// combined error and standard output contains the expected substring.
func TestServer(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		contents string
		exp      string
		code     int
		args     []string
	}{
		{
			name:     "common_ha",
			contents: testBaseHCL(t, "") + inmemHCL,
			exp:      "(HA available)",
			code:     0,
			args:     []string{"-test-verify-only"},
		},
		{
			name:     "separate_ha",
			contents: testBaseHCL(t, "") + inmemHCL + haInmemHCL,
			exp:      "HA Storage:",
			code:     0,
			args:     []string{"-test-verify-only"},
		},
		{
			name:     "bad_separate_ha",
			contents: testBaseHCL(t, "") + inmemHCL + badHAInmemHCL,
			exp:      "Specified HA storage does not support HA",
			code:     1,
			args:     []string{"-test-verify-only"},
		},
		{
			name:     "good_listener_timeout_config",
			contents: testBaseHCL(t, goodListenerTimeouts) + inmemHCL,
			exp:      "",
			code:     0,
			args:     []string{"-test-server-config"},
		},
		{
			name:     "bad_listener_read_header_timeout_config",
			contents: testBaseHCL(t, badListenerReadHeaderTimeout) + inmemHCL,
			exp:      "unknown unit \"km\" in duration \"12km\"",
			code:     1,
			args:     []string{"-test-server-config"},
		},
		{
			name:     "bad_listener_read_timeout_config",
			contents: testBaseHCL(t, badListenerReadTimeout) + inmemHCL,
			exp:      "unknown unit \"\\xe6\\x97\\xa5\" in duration",
			code:     1,
			args:     []string{"-test-server-config"},
		},
		{
			name:     "bad_listener_write_timeout_config",
			contents: testBaseHCL(t, badListenerWriteTimeout) + inmemHCL,
			exp:      "unknown unit \"lbs\" in duration \"56lbs\"",
			code:     1,
			args:     []string{"-test-server-config"},
		},
		{
			name:     "bad_listener_idle_timeout_config",
			contents: testBaseHCL(t, badListenerIdleTimeout) + inmemHCL,
			exp:      "unknown unit \"gophers\" in duration \"78gophers\"",
			code:     1,
			args:     []string{"-test-server-config"},
		},
		{
			name:     "environment_variables_logged",
			contents: testBaseHCL(t, "") + inmemHCL,
			exp:      "Environment Variables",
			code:     0,
			args:     []string{"-test-verify-only"},
		},
		{
			name:     "cloud_config",
			contents: testBaseHCL(t, "") + inmemHCL + cloudHCL,
			exp:      "HCP Organization: bc58b3d0-2eab-4ab8-abf4-f61d3c9975ff",
			code:     0,
			args:     []string{"-test-verify-only"},
		},
		{
			name:     "recovery_mode",
			contents: testBaseHCL(t, "") + inmemHCL,
			exp:      "",
			code:     0,
			args:     []string{"-test-verify-only", "-recovery"},
		},
	}

	for _, tt := range tests {
		// Shadow the loop variable so each parallel subtest sees its own case.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			mockUI, srv := testServerCommand(t)

			// Write the case's configuration to a file in a per-test temp dir.
			cfgFile, err := os.CreateTemp(t.TempDir(), "")
			require.NoErrorf(t, err, "error creating temp dir: %v", err)

			_, err = cfgFile.WriteString(tt.contents)
			require.NoErrorf(t, err, "cannot write temp file contents")

			err = cfgFile.Close()
			require.NoErrorf(t, err, "unable to close temp file")

			runArgs := append(tt.args, "-config", cfgFile.Name())
			exitCode := srv.Run(runArgs)
			combined := mockUI.ErrorWriter.String() + mockUI.OutputWriter.String()

			require.Equal(t, tt.code, exitCode, "expected %d to be %d: %s", exitCode, tt.code, combined)
			require.Contains(t, combined, tt.exp, "expected %q to contain %q", combined, tt.exp)
		})
	}
}
// TestServer_DevTLS checks that running the server command with the -dev-tls
// flag exits with status 0 and reports TLS as enabled in its output.
func TestServer_DevTLS(t *testing.T) {
	mockUI, srv := testServerCommand(t)

	exitCode := srv.Run([]string{
		"-dev-tls",
		"-dev-listen-address=127.0.0.1:0",
		"-test-server-config",
	})

	combined := mockUI.ErrorWriter.String() + mockUI.OutputWriter.String()
	require.Equal(t, 0, exitCode, combined)
	require.Contains(t, combined, `tls: "enabled"`)
}
// TestConfigureDevTLS walks configureDevTLS through its branches: dev TLS
// disabled, dev TLS enabled with no certificate directory given, and dev TLS
// enabled with an invalid certificate directory. For each case it checks
// which of the four return values are set.
func TestConfigureDevTLS(t *testing.T) {
	scenarios := []struct {
		ServerCommand   *ServerCommand
		DeferFuncNotNil bool
		ConfigNotNil    bool
		TLSDisable      bool
		CertPathEmpty   bool
		ErrNotNil       bool
		TestDescription string
	}{
		{
			ServerCommand:   &ServerCommand{flagDevTLS: false},
			ConfigNotNil:    true,
			TLSDisable:      true,
			CertPathEmpty:   true,
			ErrNotNil:       false,
			TestDescription: "flagDev is false, nothing will be configured",
		},
		{
			ServerCommand:   &ServerCommand{flagDevTLS: true, flagDevTLSCertDir: ""},
			DeferFuncNotNil: true,
			ConfigNotNil:    true,
			ErrNotNil:       false,
			TestDescription: "flagDevTLSCertDir is empty",
		},
		{
			ServerCommand:   &ServerCommand{flagDevTLS: true, flagDevTLSCertDir: "@/#"},
			CertPathEmpty:   true,
			ErrNotNil:       true,
			TestDescription: "flagDevTLSCertDir is set to something invalid",
		},
	}

	for _, sc := range scenarios {
		desc := sc.TestDescription

		cleanup, cfg, certPath, err := configureDevTLS(sc.ServerCommand)
		gotCleanup := cleanup != nil
		if gotCleanup {
			// Run the returned cleanup straight away so the files it created
			// in the temporary directory are removed before any assertion
			// below can abort the test.
			cleanup()
		}

		require.Equal(t, sc.DeferFuncNotNil, gotCleanup, "test description %s", desc)
		require.Equal(t, sc.ConfigNotNil, cfg != nil, "test description %s", desc)
		if sc.ConfigNotNil {
			require.True(t, len(cfg.Listeners) > 0, "test description %s", desc)
			require.Equal(t, sc.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", desc)
		}
		require.Equal(t, sc.CertPathEmpty, len(certPath) == 0, "test description %s", desc)
		require.Equal(t, sc.ErrNotNil, err != nil, "test description %s", desc)
	}
}