mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-10 00:27:02 +02:00
We have many hand-written String() methods (and similar) for enums. These require more maintenance and are more error-prone than using automatically generated methods. In addition, the auto-generated versions can be more efficient. Here, we switch to using https://github.com/loggerhead/enumer, a fork of the no-longer-maintained https://github.com/diegostamigni/enumer, which is itself a fork of the mostly standard tool https://pkg.go.dev/golang.org/x/tools/cmd/stringer. We use this fork of enumer for Go 1.20+ compatibility and because we require the `-transform` flag to be able to generate constants that match our current code base. Some enums were not targeted for this change:
464 lines
12 KiB
Go
464 lines
12 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package testcluster
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/api"
|
|
"github.com/hashicorp/vault/sdk/helper/xor"
|
|
)
|
|
|
|
// Note that OSS standbys will not accept seal requests. And ent perf standbys
|
|
// may fail it as well if they haven't yet been able to get "elected" as perf standbys.
|
|
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
|
|
if nodeIdx >= len(cluster.Nodes()) {
|
|
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
|
|
}
|
|
node := cluster.Nodes()[nodeIdx]
|
|
client := node.APIClient()
|
|
|
|
err := client.Sys().SealWithContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return NodeSealed(ctx, cluster, nodeIdx)
|
|
}
|
|
|
|
func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
|
|
for i := range cluster.Nodes() {
|
|
if err := SealNode(ctx, cluster, i); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
|
|
if nodeIdx >= len(cluster.Nodes()) {
|
|
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
|
|
}
|
|
node := cluster.Nodes()[nodeIdx]
|
|
client := node.APIClient()
|
|
|
|
for _, key := range cluster.GetBarrierOrRecoveryKeys() {
|
|
_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return NodeHealthy(ctx, cluster, nodeIdx)
|
|
}
|
|
|
|
func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
|
|
for i := range cluster.Nodes() {
|
|
if err := UnsealNode(ctx, cluster, i); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
|
|
if nodeIdx >= len(cluster.Nodes()) {
|
|
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
|
|
}
|
|
node := cluster.Nodes()[nodeIdx]
|
|
client := node.APIClient()
|
|
|
|
var health *api.HealthResponse
|
|
var err error
|
|
for ctx.Err() == nil {
|
|
health, err = client.Sys().HealthWithContext(ctx)
|
|
switch {
|
|
case err != nil:
|
|
case !health.Sealed:
|
|
err = fmt.Errorf("unsealed: %#v", health)
|
|
default:
|
|
return nil
|
|
}
|
|
time.Sleep(500 * time.Millisecond)
|
|
}
|
|
return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
|
|
}
|
|
|
|
func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
errs := make(chan error)
|
|
for i := range cluster.Nodes() {
|
|
go func(i int) {
|
|
var err error
|
|
for ctx.Err() == nil {
|
|
err = NodeSealed(ctx, cluster, i)
|
|
if err == nil {
|
|
errs <- nil
|
|
return
|
|
}
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
if err == nil {
|
|
err = ctx.Err()
|
|
}
|
|
errs <- err
|
|
}(i)
|
|
}
|
|
|
|
var merr *multierror.Error
|
|
var sealed int
|
|
for range cluster.Nodes() {
|
|
err := <-errs
|
|
if err != nil {
|
|
merr = multierror.Append(merr, err)
|
|
} else {
|
|
sealed++
|
|
if sealed == n {
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("%d cores were not sealed, errs: %v", n, merr.ErrorOrNil())
|
|
}
|
|
|
|
func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
|
|
if nodeIdx >= len(cluster.Nodes()) {
|
|
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
|
|
}
|
|
node := cluster.Nodes()[nodeIdx]
|
|
client := node.APIClient()
|
|
|
|
var health *api.HealthResponse
|
|
var err error
|
|
for ctx.Err() == nil {
|
|
health, err = client.Sys().HealthWithContext(ctx)
|
|
switch {
|
|
case err != nil:
|
|
case health == nil:
|
|
err = fmt.Errorf("nil response to health check")
|
|
case health.Sealed:
|
|
err = fmt.Errorf("sealed: %#v", health)
|
|
default:
|
|
return nil
|
|
}
|
|
time.Sleep(500 * time.Millisecond)
|
|
}
|
|
return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
|
|
}
|
|
|
|
func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
|
|
// Be robust to multiple nodes thinking they are active. This is possible in
|
|
// certain network partition situations where the old leader has not
|
|
// discovered it's lost leadership yet. In tests this is only likely to come
|
|
// up when we are specifically provoking it, but it's possible it could happen
|
|
// at any point if leadership flaps of connectivity suffers transient errors
|
|
// etc. so be robust against it. The best solution would be to have some sort
|
|
// of epoch like the raft term that is guaranteed to be monotonically
|
|
// increasing through elections, however we don't have that abstraction for
|
|
// all HABackends in general. The best we have is the ActiveTime. In a
|
|
// distributed systems text book this would be bad to rely on due to clock
|
|
// sync issues etc. but for our tests it's likely fine because even if we are
|
|
// running separate Vault containers, they are all using the same hardware
|
|
// clock in the system.
|
|
leaderActiveTimes := make(map[int]time.Time)
|
|
for i, node := range cluster.Nodes() {
|
|
client := node.APIClient()
|
|
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
|
resp, err := client.Sys().LeaderWithContext(ctx)
|
|
cancel()
|
|
if err != nil || resp == nil || !resp.IsSelf {
|
|
continue
|
|
}
|
|
leaderActiveTimes[i] = resp.ActiveTime
|
|
}
|
|
if len(leaderActiveTimes) == 0 {
|
|
return -1, fmt.Errorf("no leader found")
|
|
}
|
|
// At least one node thinks it is active. If multiple, pick the one with the
|
|
// most recent ActiveTime. Note if there is only one then this just returns
|
|
// it.
|
|
var newestLeaderIdx int
|
|
var newestActiveTime time.Time
|
|
for i, at := range leaderActiveTimes {
|
|
if at.After(newestActiveTime) {
|
|
newestActiveTime = at
|
|
newestLeaderIdx = i
|
|
}
|
|
}
|
|
return newestLeaderIdx, nil
|
|
}
|
|
|
|
func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
|
|
for ctx.Err() == nil {
|
|
if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
|
|
return idx, nil
|
|
}
|
|
time.Sleep(500 * time.Millisecond)
|
|
}
|
|
return -1, ctx.Err()
|
|
}
|
|
|
|
func WaitForStandbyNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
|
|
if nodeIdx >= len(cluster.Nodes()) {
|
|
return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
|
|
}
|
|
node := cluster.Nodes()[nodeIdx]
|
|
client := node.APIClient()
|
|
|
|
var err error
|
|
for ctx.Err() == nil {
|
|
var resp *api.LeaderResponse
|
|
|
|
resp, err = client.Sys().LeaderWithContext(ctx)
|
|
switch {
|
|
case err != nil:
|
|
case resp.IsSelf:
|
|
return fmt.Errorf("waiting for standby but node is leader")
|
|
case resp.LeaderAddress == "":
|
|
err = fmt.Errorf("node doesn't know leader address")
|
|
default:
|
|
return nil
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
}
|
|
if err == nil {
|
|
err = ctx.Err()
|
|
}
|
|
return err
|
|
}
|
|
|
|
func WaitForActiveNodeAndStandbys(ctx context.Context, cluster VaultCluster) (int, error) {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
leaderIdx, err := WaitForActiveNode(ctx, cluster)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if len(cluster.Nodes()) == 1 {
|
|
return 0, nil
|
|
}
|
|
|
|
errs := make(chan error)
|
|
for i := range cluster.Nodes() {
|
|
if i == leaderIdx {
|
|
continue
|
|
}
|
|
go func(i int) {
|
|
errs <- WaitForStandbyNode(ctx, cluster, i)
|
|
}(i)
|
|
}
|
|
|
|
var merr *multierror.Error
|
|
expectedStandbys := len(cluster.Nodes()) - 1
|
|
for i := 0; i < expectedStandbys; i++ {
|
|
merr = multierror.Append(merr, <-errs)
|
|
}
|
|
|
|
return leaderIdx, merr.ErrorOrNil()
|
|
}
|
|
|
|
// WaitForActiveNodeAndPerfStandbys waits for exactly one active node and for
// every other node to become a caught-up performance standby. It verifies
// replication progress by mounting a temporary local KV engine on the leader,
// writing to it repeatedly, and watching each standby's remote WAL position
// advance. The temporary mount is removed before returning.
func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
	logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
	// This WaitForActiveNode was added because after a Raft cluster is sealed
	// and then unsealed, when it comes up it may have a different leader than
	// Core0, making this helper fail.
	// A sleep before calling WaitForActiveNodeAndPerfStandbys seems to sort
	// things out, but so apparently does this. We should be able to eliminate
	// this call to WaitForActiveNode by reworking the logic in this method.
	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return fmt.Errorf("did not find leader: %w", err)
	}

	// Single-node cluster: there are no standbys to wait for.
	if len(cluster.Nodes()) == 1 {
		return nil
	}

	expectedStandbys := len(cluster.Nodes()) - 1

	// Use a random mount point so concurrent/repeated invocations don't
	// collide with each other or with existing mounts.
	mountPoint, err := uuid.GenerateUUID()
	if err != nil {
		return err
	}
	leaderClient := cluster.Nodes()[leaderIdx].APIClient()

	// Retry the mount until it succeeds or ctx expires; the leader may not
	// be ready to accept writes immediately.
	for ctx.Err() == nil {
		err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
			Type: "kv",
			Local: true,
		})
		if err == nil {
			break
		}
		time.Sleep(1 * time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to mount KV engine: %w", err)
	}
	path := mountPoint + "/waitforactivenodeandperfstandbys"
	// Counters shared across the per-node goroutines below; updated with
	// atomics since each goroutine increments exactly one of them.
	var standbys, actives int64
	// Buffered so every goroutine's final send succeeds without blocking.
	errchan := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(coreNo int) {
			node := cluster.Nodes()[coreNo]
			client := node.APIClient()
			val := 1
			var err error
			// Always report a result (possibly nil) for this node.
			defer func() {
				errchan <- err
			}()

			var lastWAL uint64
			for ctx.Err() == nil {
				// Write through the LEADER so replication generates WAL
				// entries that the standbys must catch up to.
				_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
					"bar": val,
				})
				val++
				time.Sleep(250 * time.Millisecond)
				if err != nil {
					continue
				}
				// Query THIS node's view of leadership/standby status.
				var leader *api.LeaderResponse
				leader, err = client.Sys().LeaderWithContext(ctx)
				if err != nil {
					logger.Trace("waiting for core", "core", coreNo, "err", err)
					continue
				}
				switch {
				case leader.IsSelf:
					// This node is the active one.
					logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
					atomic.AddInt64(&actives, 1)
					return
				case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
					// A perf standby counts only once we've seen its remote
					// WAL position ADVANCE between two observations, proving
					// replication is actually flowing.
					switch {
					case lastWAL == 0:
						lastWAL = leader.PerfStandbyLastRemoteWAL
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
					case lastWAL < leader.PerfStandbyLastRemoteWAL:
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
						atomic.AddInt64(&standbys, 1)
						return
					}
				default:
					// Not leader, not a caught-up perf standby yet; log the
					// state and keep polling.
					logger.Trace("waiting for core", "core", coreNo,
						"ha_enabled", leader.HAEnabled,
						"is_self", leader.IsSelf,
						"perf_standby", leader.PerfStandby,
						"perf_standby_remote_wal", leader.PerfStandbyLastRemoteWAL)
				}
			}
		}(i)
	}

	// Wait for every per-node goroutine to report before reading the
	// counters (this is also the synchronization point for the atomics).
	errs := make([]error, 0, len(cluster.Nodes()))
	for range cluster.Nodes() {
		errs = append(errs, <-errchan)
	}
	if actives != 1 || int(standbys) != expectedStandbys {
		return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
			expectedStandbys, actives, standbys, errs)
	}

	// Clean up the temporary mount, retrying until success or ctx expiry.
	for ctx.Err() == nil {
		err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to unmount KV engine: %w", err)
	}
	return nil
}
|
|
|
|
func Clients(vc VaultCluster) []*api.Client {
|
|
var ret []*api.Client
|
|
for _, n := range vc.Nodes() {
|
|
ret = append(ret, n.APIClient())
|
|
}
|
|
return ret
|
|
}
|
|
|
|
//go:generate enumer -type=GenerateRootKind -trimprefix=GenerateRoot

// GenerateRootKind selects which token-generation flow GenerateRoot drives:
// the regular root token flow, the DR operation token flow, or the recovery
// operation token flow.
type GenerateRootKind int

const (
	// GenerateRootRegular generates a regular root token.
	GenerateRootRegular GenerateRootKind = iota
	// GenerateRootDR generates a DR operation token.
	GenerateRootDR
	// GenerateRecovery generates a recovery operation token.
	GenerateRecovery
)
|
|
|
|
func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
|
|
// If recovery keys supported, use those to perform root token generation instead
|
|
keys := cluster.GetBarrierOrRecoveryKeys()
|
|
|
|
client := cluster.Nodes()[0].APIClient()
|
|
|
|
var err error
|
|
var status *api.GenerateRootStatusResponse
|
|
switch kind {
|
|
case GenerateRootRegular:
|
|
status, err = client.Sys().GenerateRootInit("", "")
|
|
case GenerateRootDR:
|
|
status, err = client.Sys().GenerateDROperationTokenInit("", "")
|
|
case GenerateRecovery:
|
|
status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if status.Required > len(keys) {
|
|
return "", fmt.Errorf("need more keys than have, need %d have %d", status.Required, len(keys))
|
|
}
|
|
|
|
otp := status.OTP
|
|
|
|
for i, key := range keys {
|
|
if i >= status.Required {
|
|
break
|
|
}
|
|
|
|
strKey := base64.StdEncoding.EncodeToString(key)
|
|
switch kind {
|
|
case GenerateRootRegular:
|
|
status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
|
|
case GenerateRootDR:
|
|
status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
|
|
case GenerateRecovery:
|
|
status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if !status.Complete {
|
|
return "", fmt.Errorf("generate root operation did not end successfully")
|
|
}
|
|
|
|
tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(tokenBytes), nil
|
|
}
|