vault/http/forwarding_test.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

609 lines
16 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package http
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/net/http2"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault/api"
credCert "github.com/hashicorp/vault/builtin/credential/cert"
"github.com/hashicorp/vault/builtin/logical/transit"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/keysutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault"
)
// TestHTTP_Fallback_Bad_Address configures the cluster with a cluster address
// that nothing is expected to listen on (127.3.4.1:8382) and then checks that
// a root-token lookup-self sent to cores 1 and 2 still reports the root token.
func TestHTTP_Fallback_Bad_Address(t *testing.T) {
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
		ClusterAddr: "https://127.3.4.1:8382",
	}, &vault.TestClusterOptions{
		HandlerFunc: Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	cores := cluster.Cores
	// Core 0 is the node we wait on to become active.
	vault.TestWaitActive(t, cores[0].Core)

	// Resolve every target address up front, before any request is made.
	standbyAddrs := make([]string, 0, 2)
	for _, idx := range []int{1, 2} {
		standbyAddrs = append(standbyAddrs, fmt.Sprintf("https://127.0.0.1:%d", cores[idx].Listeners[0].Address.Port))
	}

	for _, target := range standbyAddrs {
		cfg := api.DefaultConfig()
		cfg.Address = target
		cfg.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig()

		apiClient, err := api.NewClient(cfg)
		if err != nil {
			t.Fatal(err)
		}
		apiClient.SetToken(cluster.RootToken)

		self, err := apiClient.Auth().Token().LookupSelf()
		switch {
		case err != nil:
			t.Fatal(err)
		case self == nil:
			t.Fatal("secret is nil")
		case self.Data["id"].(string) != cluster.RootToken:
			t.Fatal("token mismatch")
		}
	}
}
// TestHTTP_Fallback_Disabled runs a cluster whose core config sets
// ClusterAddr to "empty". It checks that a root-token lookup-self issued
// against cores 1 and 2 still reports the root token.
//
// NOTE(review): going by the test name, "empty" is expected to leave request
// forwarding disabled so the standbys must fall back; confirm against
// vault.NewTestCluster.
func TestHTTP_Fallback_Disabled(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
		ClusterAddr: "empty",
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{HandlerFunc: Handler})
	cluster.Start()
	defer cluster.Cleanup()

	cores := cluster.Cores
	// Core 0 is the node we wait on to become active.
	vault.TestWaitActive(t, cores[0].Core)

	targets := []string{
		fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
		fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
	}

	// lookupViaAddr runs lookup-self with the root token against addr. It
	// fails the test unless the returned token ID equals the root token.
	lookupViaAddr := func(addr string) {
		cfg := api.DefaultConfig()
		cfg.Address = addr
		cfg.HttpClient.Transport.(*http.Transport).TLSClientConfig = cores[0].TLSConfig()

		c, err := api.NewClient(cfg)
		if err != nil {
			t.Fatal(err)
		}
		c.SetToken(cluster.RootToken)

		secret, err := c.Auth().Token().LookupSelf()
		if err != nil {
			t.Fatal(err)
		}
		if secret == nil {
			t.Fatal("secret is nil")
		}
		if id := secret.Data["id"].(string); id != cluster.RootToken {
			t.Fatal("token mismatch")
		}
	}

	for _, addr := range targets {
		lookupViaAddr(addr)
	}
}
// TestHTTP_Forwarding_Stress recreates transit's fuzzy test. It pipes a large
// number of requests from the standbys to the active node, running once with
// per-worker keys (parallel=false) and once with shared keys (parallel=true),
// 50 workers each time.
func TestHTTP_Forwarding_Stress(t *testing.T) {
	for _, parallel := range []bool{false, true} {
		testHTTP_Forwarding_Stress_Common(t, parallel, 50)
	}
}
// testHTTP_Forwarding_Stress_Common starts a test cluster and mounts transit on
// the active node (core 0). It then runs num workers for about 10 seconds.
// Each worker sends randomly chosen transit operations (encrypt, decrypt,
// rotate, change_min_version) to the standby addresses of cores 1 and 2.
//
// When parallel is false, each worker uses its own key ("key-false-<id>") on a
// fixed host. When parallel is true, all workers share the keys
// "test1-true", "test2-true" and "test3-true" and pick a random host for
// every operation.
//
// The test fails if any worker panics, if no operations ran, or if any
// attempted operation was not counted as successful.
func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint32) {
	testPlaintext := "the quick brown fox"
	testPlaintextB64 := "dGhlIHF1aWNrIGJyb3duIGZveA=="

	coreConfig := &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"transit": transit.Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()
	cores := cluster.Cores

	// make it easy to get access to the active
	core := cores[0].Core
	vault.TestWaitActive(t, core)

	var wg sync.WaitGroup

	funcs := []string{"encrypt", "decrypt", "rotate", "change_min_version"}
	keys := []string{"test1", "test2", "test3"}
	hosts := []string{
		fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[1].Listeners[0].Address.Port),
		fmt.Sprintf("https://127.0.0.1:%d/v1/transit/", cores[2].Listeners[0].Address.Port),
	}

	transport := &http.Transport{
		TLSClientConfig: cores[0].TLSConfig(),
	}
	if err := http2.ConfigureTransport(transport); err != nil {
		t.Fatal(err)
	}

	client := &http.Client{
		Transport: transport,
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return fmt.Errorf("redirects not allowed in this test")
		},
	}

	// Mount transit directly on the active node.
	req, err := http.NewRequest("POST", fmt.Sprintf("https://127.0.0.1:%d/v1/sys/mounts/transit", cores[0].Listeners[0].Address.Port),
		bytes.NewBufferString("{\"type\": \"transit\"}"))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
	mountResp, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// The body is only used to make a failure message readable, so a read
	// error is deliberately ignored.
	mountBody, _ := io.ReadAll(mountResp.Body)
	mountResp.Body.Close()
	if mountResp.StatusCode < 200 || mountResp.StatusCode >= 300 {
		t.Fatalf("mounting transit failed with status %d: %s", mountResp.StatusCode, mountBody)
	}

	var totalOps, successfulOps uint32

	// In parallel mode the workers share the keys "<key>-<parallel>". Count
	// the rotations issued against each shared key so that change_min_version
	// stays within the versions the key has presumably reached. The map is
	// keyed by the full key name, so lookups use chosenKey exactly as it is
	// sent to Vault. The map itself is only read after this point, so
	// concurrent lookups are safe.
	keyVersions := make(map[string]*int32, len(keys))
	for _, k := range keys {
		v := int32(1)
		keyVersions[fmt.Sprintf("%s-%t", k, parallel)] = &v
	}

	// Start barrier: every worker finishes its setup before any of them
	// begins the timed load.
	var numWorkersStarted uint32
	var waitLock sync.Mutex
	waitCond := sync.NewCond(&waitLock)

	// signalStarted counts one more worker as having reached the barrier and
	// wakes all waiters. The lock is taken after the increment. Any worker
	// that observed the old count still held the lock while checking, so it
	// is already parked in Wait when Broadcast runs, and no wakeup is lost.
	signalStarted := func() {
		atomic.AddUint32(&numWorkersStarted, 1)
		waitCond.L.Lock()
		waitCond.L.Unlock()
		waitCond.Broadcast()
	}

	// This is the goroutine loop
	doFuzzy := func(id int, parallel bool) {
		var myTotalOps uint32
		var mySuccessfulOps uint32
		var keyVer int32 = 1
		reachedBarrier := false

		// Check for panics, otherwise notify we're done
		defer func() {
			if err := recover(); err != nil {
				core.Logger().Error("got a panic", "error", err)
				t.Fail()
			}
			// A worker that dies during setup must still be counted at the
			// barrier. Otherwise the remaining workers wait forever and the
			// test hangs instead of failing.
			if !reachedBarrier {
				signalStarted()
			}
			atomic.AddUint32(&totalOps, myTotalOps)
			atomic.AddUint32(&successfulOps, mySuccessfulOps)
			wg.Done()
		}()

		// Holds the latest encrypted value for each key
		latestEncryptedText := map[string]string{}

		// NOTE(review): unlike the setup client above, this client sets no
		// CheckRedirect. Go's default policy therefore follows redirects, and
		// the 3xx check in doResp may never see one.
		workerClient := &http.Client{
			Transport: transport,
		}

		var chosenFunc, chosenKey, chosenHost string
		myRand := rand.New(rand.NewSource(int64(id) * 400))

		doReq := func(method, url string, body io.Reader) (*http.Response, error) {
			req, err := http.NewRequest(method, url, body)
			if err != nil {
				return nil, err
			}
			req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
			resp, err := workerClient.Do(req)
			if err != nil {
				return nil, err
			}
			return resp, nil
		}

		// doReqDiscard sends a request whose response content is deliberately
		// not inspected. The body is drained and closed so the connection or
		// HTTP/2 stream is released, and only transport errors are returned.
		doReqDiscard := func(method, url string, body io.Reader) error {
			resp, err := doReq(method, url, body)
			if err != nil {
				return err
			}
			_, _ = io.Copy(io.Discard, resp.Body)
			resp.Body.Close()
			return nil
		}

		doResp := func(resp *http.Response) (*api.Secret, error) {
			if resp == nil {
				return nil, fmt.Errorf("nil response")
			}
			defer resp.Body.Close()

			// Make sure we weren't redirected
			if resp.StatusCode >= 300 && resp.StatusCode < 400 {
				return nil, fmt.Errorf("got status code %d, resp was %#v", resp.StatusCode, *resp)
			}

			result := &api.Response{Response: resp}
			if err := result.Error(); err != nil {
				return nil, err
			}

			secret, err := api.ParseSecret(result.Body)
			if err != nil {
				return nil, err
			}
			return secret, nil
		}

		// Try to write each shared key on every host to make sure it exists
		for _, host := range hosts {
			for _, key := range keys {
				if err := doReqDiscard("POST", host+"keys/"+fmt.Sprintf("%s-%t", key, parallel), bytes.NewBufferString("{}")); err != nil {
					panic(err)
				}
			}
		}

		if !parallel {
			chosenHost = hosts[id%len(hosts)]
			chosenKey = fmt.Sprintf("key-%t-%d", parallel, id)
			if err := doReqDiscard("POST", chosenHost+"keys/"+chosenKey, bytes.NewBufferString("{}")); err != nil {
				panic(err)
			}
		}

		reachedBarrier = true
		signalStarted()
		waitCond.L.Lock()
		for atomic.LoadUint32(&numWorkersStarted) < num {
			waitCond.Wait()
		}
		waitCond.L.Unlock()

		core.Logger().Debug("Starting goroutine", "id", id)

		// Stop after 10 seconds
		startTime := time.Now()
		for time.Since(startTime) <= 10*time.Second {
			myTotalOps++

			// Pick a function and a key
			chosenFunc = funcs[myRand.Int()%len(funcs)]
			if parallel {
				chosenKey = fmt.Sprintf("%s-%t", keys[myRand.Int()%len(keys)], parallel)
				chosenHost = hosts[myRand.Int()%len(hosts)]
			}

			switch chosenFunc {
			// Encrypt our plaintext and store the result
			case "encrypt":
				resp, err := doReq("POST", chosenHost+"encrypt/"+chosenKey, bytes.NewBufferString(fmt.Sprintf("{\"plaintext\": \"%s\"}", testPlaintextB64)))
				if err != nil {
					panic(err)
				}
				secret, err := doResp(resp)
				if err != nil {
					panic(err)
				}
				latest := secret.Data["ciphertext"].(string)
				if latest == "" {
					panic(fmt.Errorf("bad ciphertext"))
				}
				latestEncryptedText[chosenKey] = latest
				mySuccessfulOps++

			// Decrypt the ciphertext and compare the result
			case "decrypt":
				ct := latestEncryptedText[chosenKey]
				if ct == "" {
					mySuccessfulOps++
					continue
				}
				resp, err := doReq("POST", chosenHost+"decrypt/"+chosenKey, bytes.NewBufferString(fmt.Sprintf("{\"ciphertext\": \"%s\"}", ct)))
				if err != nil {
					panic(err)
				}
				secret, err := doResp(resp)
				if err != nil {
					// This could well happen since the min version is jumping around
					if strings.Contains(err.Error(), keysutil.ErrTooOld) {
						mySuccessfulOps++
						continue
					}
					panic(err)
				}
				ptb64 := secret.Data["plaintext"].(string)
				pt, err := base64.StdEncoding.DecodeString(ptb64)
				if err != nil {
					panic(fmt.Errorf("got an error decoding base64 plaintext: %v", err))
				}
				if string(pt) != testPlaintext {
					panic(fmt.Errorf("got bad plaintext back: %s", pt))
				}
				mySuccessfulOps++

			// Rotate to a new key version
			case "rotate":
				if err := doReqDiscard("POST", chosenHost+"keys/"+chosenKey+"/rotate", bytes.NewBufferString("{}")); err != nil {
					panic(err)
				}
				if parallel {
					// chosenKey always comes from keys in parallel mode, so it
					// is present in keyVersions.
					atomic.AddInt32(keyVersions[chosenKey], 1)
				} else {
					keyVer++
				}
				mySuccessfulOps++

			// Change the min version, which also tests the archive functionality
			case "change_min_version":
				latestVersion := keyVer
				if parallel {
					latestVersion = atomic.LoadInt32(keyVersions[chosenKey])
				}
				setVersion := (myRand.Int31() % latestVersion) + 1
				if err := doReqDiscard("POST", chosenHost+"keys/"+chosenKey+"/config", bytes.NewBufferString(fmt.Sprintf("{\"min_decryption_version\": %d}", setVersion))); err != nil {
					panic(err)
				}
				mySuccessfulOps++
			}
		}
	}

	// Spawn some of these workers for 10 seconds
	for i := 0; i < int(num); i++ {
		wg.Add(1)
		go doFuzzy(i+1, parallel)
	}

	// Wait for them all to finish
	wg.Wait()

	total := atomic.LoadUint32(&totalOps)
	successful := atomic.LoadUint32(&successfulOps)
	if total == 0 || total != successful {
		t.Fatalf("total/successful ops zero or mismatch: %d/%d; parallel: %t, num %d", total, successful, parallel, num)
	}
	t.Logf("total operations tried: %d, total successful: %d; parallel: %t, num %d", total, successful, parallel, num)
}
// TestHTTP_Forwarding_ClientTLS tests TLS connection state forwarding by
// ensuring that a client TLS certificate can be used to authenticate against
// the cert backend through a standby.
//
// It enables cert auth on the active node (core 0) and registers the cluster
// CA as a trusted certificate with the "default" policy. It then logs in via
// auth/cert/login against cores 1 and 2 and looks up the resulting token.
func TestHTTP_Forwarding_ClientTLS(t *testing.T) {
	coreConfig := &vault.CoreConfig{
		CredentialBackends: map[string]logical.Factory{
			"cert": credCert.Factory,
		},
	}
	cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
		HandlerFunc: Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()
	cores := cluster.Cores

	// make it easy to get access to the active
	core := cores[0].Core
	vault.TestWaitActive(t, core)

	transport := cleanhttp.DefaultTransport()
	transport.TLSClientConfig = cores[0].TLSConfig()
	if err := http2.ConfigureTransport(transport); err != nil {
		t.Fatal(err)
	}
	client := &http.Client{
		Transport: transport,
	}

	activeAddr := fmt.Sprintf("https://127.0.0.1:%d", cores[0].Listeners[0].Address.Port)

	// postAsRoot sends a root-token POST of body to path on the active node.
	// It always closes the response body and fails the test unless the status
	// is 2xx, so a failed setup step is reported where it happens.
	postAsRoot := func(path string, body io.Reader) {
		t.Helper()
		req, err := http.NewRequest("POST", activeAddr+path, body)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Set(consts.AuthHeaderName, cluster.RootToken)
		resp, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		defer resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			// Best effort: the body only makes the failure message readable.
			respBody, _ := io.ReadAll(resp.Body)
			t.Fatalf("POST %s returned status %d: %s", path, resp.StatusCode, respBody)
		}
	}

	postAsRoot("/v1/sys/auth/cert", bytes.NewBufferString("{\"type\": \"cert\"}"))

	type certConfig struct {
		Certificate string `json:"certificate"`
		Policies    string `json:"policies"`
	}
	encodedCertConfig, err := json.Marshal(&certConfig{
		Certificate: string(cluster.CACertPEM),
		Policies:    "default",
	})
	if err != nil {
		t.Fatal(err)
	}
	postAsRoot("/v1/auth/cert/certs/test", bytes.NewBuffer(encodedCertConfig))

	addrs := []string{
		fmt.Sprintf("https://127.0.0.1:%d", cores[1].Listeners[0].Address.Port),
		fmt.Sprintf("https://127.0.0.1:%d", cores[2].Listeners[0].Address.Port),
	}
	for i, addr := range addrs {
		// Ensure we can't possibly use lingering connections even though it should
		// be to a different address
		tlsTransport := cleanhttp.DefaultTransport()
		// i starts at zero but cores in addrs start at 1
		tlsTransport.TLSClientConfig = cores[i+1].TLSConfig()
		if err := http2.ConfigureTransport(tlsTransport); err != nil {
			t.Fatal(err)
		}
		httpClient := &http.Client{
			Transport: tlsTransport,
			CheckRedirect: func(*http.Request, []*http.Request) error {
				return fmt.Errorf("redirects not allowed in this test")
			},
		}
		apiClient, err := api.NewClient(&api.Config{
			Address:    addr,
			HttpClient: httpClient,
		})
		if err != nil {
			t.Fatal(err)
		}

		secret, err := apiClient.Logical().Write("auth/cert/login", nil)
		if err != nil {
			t.Fatal(err)
		}
		if secret == nil {
			t.Fatal("secret is nil")
		}
		if secret.Auth == nil {
			t.Fatal("auth is nil")
		}
		// len of a nil slice is 0, so this also covers a nil Policies.
		if len(secret.Auth.Policies) == 0 || secret.Auth.Policies[0] != "default" {
			t.Fatalf("bad policies: %#v", secret.Auth.Policies)
		}
		if secret.Auth.ClientToken == "" {
			t.Fatalf("bad client token: %#v", *secret.Auth)
		}

		apiClient.SetToken(secret.Auth.ClientToken)
		secret, err = apiClient.Auth().Token().LookupSelf()
		if err != nil {
			t.Fatal(err)
		}
		if secret == nil {
			t.Fatal("secret is nil")
		}
		if len(secret.Data) == 0 {
			t.Fatal("secret data was empty")
		}
	}
}
// TestHTTP_Forwarding_HelpOperation requests help for "auth/token" through the
// API clients of cores 0 and 1. Each request must succeed and return a
// non-nil result.
func TestHTTP_Forwarding_HelpOperation(t *testing.T) {
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{}, &vault.TestClusterOptions{
		HandlerFunc: Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	vault.TestWaitActive(t, cluster.Cores[0].Core)

	for _, node := range cluster.Cores[:2] {
		help, err := node.Client.Help("auth/token")
		if err != nil {
			t.Fatal(err)
		}
		if help == nil {
			t.Fatal("help was nil")
		}
	}
}
// TestHTTP_Forwarding_LocalOnly reads "sys/config/state/sanitized" through the
// API clients of cores 1 and 2 and requires both reads to return an error.
// NOTE(review): presumably the endpoint is local-only and so is not served by
// standbys; confirm.
func TestHTTP_Forwarding_LocalOnly(t *testing.T) {
	cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{
		HandlerFunc: Handler,
	})
	cluster.Start()
	defer cluster.Cleanup()

	cores := cluster.Cores
	vault.TestWaitActive(t, cores[0].Core)

	for _, idx := range []int{1, 2} {
		if _, err := cores[idx].Client.Logical().Read("sys/config/state/sanitized"); err == nil {
			t.Fatal("expected error")
		}
	}
}