VAULT-19232 Add static secret capability manager to Vault Proxy (#23677)

* VAULT-19232 static secret capability manager so far

* VAULT-19232 mostly finish renewal job logic

* VAULT-19232 some clean up, tests, etc

* VAULT-19232 integrate capability manager with proxy, add E2E test

* VAULT-19232 boltdb stuff

* VAULT-19232 finishing touches

* VAULT-19232 typo

* VAULT-19232 add capabilities index cachememdb tests

* Remove erroneous "the"

Co-authored-by: Kuba Wieczorek <kuba.wieczorek@hashicorp.com>

---------

Co-authored-by: Kuba Wieczorek <kuba.wieczorek@hashicorp.com>
This commit is contained in:
Violet Hynes 2023-10-25 16:43:24 -04:00 committed by GitHub
parent c0ad3f6ce2
commit 363557d045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1255 additions and 21 deletions

View File

@ -165,7 +165,7 @@ func createV1BoltSchema(tx *bolt.Tx) error {
func createV2BoltSchema(tx *bolt.Tx) error {
// Create the buckets for tokens and leases.
for _, bucket := range []string{TokenType, LeaseType, lookupType} {
for _, bucket := range []string{TokenType, LeaseType, lookupType, StaticSecretType, TokenCapabilitiesType} {
if _, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil {
return fmt.Errorf("failed to create %s bucket: %w", bucket, err)
}
@ -267,6 +267,10 @@ func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, inde
if err := meta.Put([]byte(AutoAuthToken), protoBlob); err != nil {
return fmt.Errorf("failed to set latest auto-auth token: %w", err)
}
case StaticSecretType:
key = []byte(id)
case TokenCapabilitiesType:
key = []byte(id)
default:
return fmt.Errorf("called Set for unsupported type %q", indexType)
}
@ -419,7 +423,7 @@ func (b *BoltStorage) Close() error {
// the schema/layout
func (b *BoltStorage) Clear() error {
return b.db.Update(func(tx *bolt.Tx) error {
for _, name := range []string{TokenType, LeaseType, lookupType} {
for _, name := range []string{TokenType, LeaseType, lookupType, StaticSecretType, TokenCapabilitiesType} {
b.logger.Trace("deleting bolt bucket", "name", name)
if err := tx.DeleteBucket([]byte(name)); err != nil {
return err

View File

@ -6,7 +6,6 @@ package cacheboltdb
import (
"context"
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
@ -34,7 +33,7 @@ func getTestKeyManager(t *testing.T) keymanager.KeyManager {
func TestBolt_SetGet(t *testing.T) {
ctx := context.Background()
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
@ -60,7 +59,7 @@ func TestBolt_SetGet(t *testing.T) {
func TestBoltDelete(t *testing.T) {
ctx := context.Background()
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
@ -92,7 +91,7 @@ func TestBoltDelete(t *testing.T) {
func TestBoltClear(t *testing.T) {
ctx := context.Background()
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
@ -126,6 +125,20 @@ func TestBoltClear(t *testing.T) {
require.Len(t, tokens, 1)
assert.Equal(t, []byte("hello"), tokens[0])
err = b.Set(ctx, "static-secret", []byte("hello"), StaticSecretType)
require.NoError(t, err)
staticSecrets, err := b.GetByType(ctx, StaticSecretType)
require.NoError(t, err)
require.Len(t, staticSecrets, 1)
assert.Equal(t, []byte("hello"), staticSecrets[0])
err = b.Set(ctx, "capabilities-index", []byte("hello"), TokenCapabilitiesType)
require.NoError(t, err)
capabilities, err := b.GetByType(ctx, TokenCapabilitiesType)
require.NoError(t, err)
require.Len(t, capabilities, 1)
assert.Equal(t, []byte("hello"), capabilities[0])
// Clear the bolt db, and check that it's indeed clear
err = b.Clear()
require.NoError(t, err)
@ -135,12 +148,18 @@ func TestBoltClear(t *testing.T) {
tokens, err = b.GetByType(ctx, TokenType)
require.NoError(t, err)
assert.Len(t, tokens, 0)
staticSecrets, err = b.GetByType(ctx, StaticSecretType)
require.NoError(t, err)
require.Len(t, staticSecrets, 0)
capabilities, err = b.GetByType(ctx, TokenCapabilitiesType)
require.NoError(t, err)
require.Len(t, capabilities, 0)
}
func TestBoltSetAutoAuthToken(t *testing.T) {
ctx := context.Background()
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
@ -210,11 +229,11 @@ func TestDBFileExists(t *testing.T) {
var tmpPath string
var err error
if tc.mkDir {
tmpPath, err = ioutil.TempDir("", "test-db-path")
tmpPath, err = os.MkdirTemp("", "test-db-path")
require.NoError(t, err)
}
if tc.createFile {
err = ioutil.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0o600)
err = os.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0o600)
require.NoError(t, err)
}
exists, err := DBFileExists(tmpPath)
@ -244,7 +263,7 @@ func Test_SetGetRetrievalToken(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
@ -270,7 +289,7 @@ func Test_SetGetRetrievalToken(t *testing.T) {
func TestBolt_MigrateFromV1ToV2Schema(t *testing.T) {
ctx := context.Background()
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)
@ -342,7 +361,7 @@ func TestBolt_MigrateFromV1ToV2Schema(t *testing.T) {
func TestBolt_MigrateFromInvalidToV2Schema(t *testing.T) {
ctx := context.Background()
path, err := ioutil.TempDir("", "bolt-test")
path, err := os.MkdirTemp("", "bolt-test")
require.NoError(t, err)
defer os.RemoveAll(path)

View File

@ -237,6 +237,28 @@ func (c *CacheMemDB) SetCapabilitiesIndex(index *CapabilitiesIndex) error {
return nil
}
// EvictCapabilitiesIndex removes a capabilities index from the cache based on index name and value.
// Evicting an entry that does not exist is a no-op and returns nil.
func (c *CacheMemDB) EvictCapabilitiesIndex(indexName string, indexValues ...interface{}) error {
	index, err := c.GetCapabilitiesIndex(indexName, indexValues...)
	if err == ErrCacheItemNotFound {
		// Nothing to evict.
		return nil
	}
	if err != nil {
		// Wrap with %w (not %v) so callers can inspect the underlying
		// lookup error with errors.Is/errors.As, matching the error
		// wrapping convention used elsewhere in this package.
		return fmt.Errorf("unable to fetch index on cache deletion: %w", err)
	}

	txn := c.db.Load().(*memdb.MemDB).Txn(true)
	defer txn.Abort()

	if err := txn.Delete(tableNameCapabilitiesIndexer, index); err != nil {
		return fmt.Errorf("unable to delete index from cache: %w", err)
	}
	txn.Commit()

	return nil
}
// GetByPrefix returns all the cached indexes based on the index name and the
// value prefix.
func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) {

View File

@ -7,6 +7,8 @@ import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/go-test/deep"
)
@ -393,3 +395,92 @@ func TestCacheMemDB_Flush(t *testing.T) {
t.Fatalf("expected cache to be empty, got = %v", out)
}
}
// TestCacheMemDB_EvictCapabilitiesIndex tests EvictCapabilitiesIndex works as expected.
func TestCacheMemDB_EvictCapabilitiesIndex(t *testing.T) {
	cache, err := New()
	require.Nil(t, err)

	// Test on empty cache: evicting a missing entry must be a no-op, not an error.
	err = cache.EvictCapabilitiesIndex(IndexNameID, "foo")
	require.Nil(t, err)

	capabilitiesIndex := &CapabilitiesIndex{
		ID:    "id",
		Token: "token",
	}
	err = cache.SetCapabilitiesIndex(capabilitiesIndex)
	require.Nil(t, err)

	// Evict the entry we just stored.
	err = cache.EvictCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID)
	require.Nil(t, err)

	// Verify that the cache doesn't contain the entry anymore
	index, err := cache.GetCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID)
	require.Equal(t, ErrCacheItemNotFound, err)
	require.Nil(t, index)
}
// TestCacheMemDB_GetCapabilitiesIndex tests GetCapabilitiesIndex works as expected.
func TestCacheMemDB_GetCapabilitiesIndex(t *testing.T) {
	cache, err := New()
	require.Nil(t, err)

	capabilitiesIndex := &CapabilitiesIndex{
		ID:    "id",
		Token: "token",
	}
	err = cache.SetCapabilitiesIndex(capabilitiesIndex)
	require.Nil(t, err)

	// Verify that we can retrieve the index
	index, err := cache.GetCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID)
	require.Nil(t, err)
	require.Equal(t, capabilitiesIndex, index)

	// Verify behaviour on a non-existing ID
	index, err = cache.GetCapabilitiesIndex(IndexNameID, "not a real id")
	require.Equal(t, ErrCacheItemNotFound, err)
	require.Nil(t, index)

	// Verify behaviour with a non-existing index name
	// NOTE(review): only the error is asserted here; presumably index is also
	// nil in this case — consider asserting that too.
	index, err = cache.GetCapabilitiesIndex("not a real name", capabilitiesIndex.ID)
	require.NotNil(t, err)
}
// TestCacheMemDB_SetCapabilitiesIndex tests SetCapabilitiesIndex works as expected.
func TestCacheMemDB_SetCapabilitiesIndex(t *testing.T) {
	cache, err := New()
	require.Nil(t, err)

	capabilitiesIndex := &CapabilitiesIndex{
		ID:    "id",
		Token: "token",
	}
	err = cache.SetCapabilitiesIndex(capabilitiesIndex)
	require.Nil(t, err)

	// Verify we can retrieve the index
	index, err := cache.GetCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID)
	require.Nil(t, err)
	require.Equal(t, capabilitiesIndex, index)

	// Verify behaviour on a nil index
	err = cache.SetCapabilitiesIndex(nil)
	require.NotNil(t, err)

	// Verify behaviour on an index without id
	err = cache.SetCapabilitiesIndex(&CapabilitiesIndex{
		Token: "token",
	})
	require.NotNil(t, err)

	// Verify behaviour on an index with only ID — the ID alone is sufficient.
	err = cache.SetCapabilitiesIndex(&CapabilitiesIndex{
		ID: "id",
	})
	require.Nil(t, err)
}

View File

@ -198,3 +198,22 @@ func Deserialize(indexBytes []byte) (*Index, error) {
}
return index, nil
}
// SerializeCapabilitiesIndex encodes the CapabilitiesIndex as JSON bytes,
// suitable for writing to persistent storage.
func (i CapabilitiesIndex) SerializeCapabilitiesIndex() ([]byte, error) {
	// json.Marshal already returns (nil, err) on failure, so no extra
	// branching is needed.
	return json.Marshal(i)
}
// DeserializeCapabilitiesIndex converts json bytes to a CapabilitiesIndex object.
// It is the inverse of SerializeCapabilitiesIndex, used when restoring entries
// from persistent storage.
func DeserializeCapabilitiesIndex(indexBytes []byte) (*CapabilitiesIndex, error) {
	index := new(CapabilitiesIndex)
	if err := json.Unmarshal(indexBytes, index); err != nil {
		return nil, err
	}
	return index, nil
}

View File

@ -102,6 +102,10 @@ type LeaseCache struct {
// cacheStaticSecrets is used to determine if the cache should also
// cache static secrets, as well as dynamic secrets.
cacheStaticSecrets bool
// capabilityManager is used when static secrets are enabled to
// manage the capabilities of cached tokens.
capabilityManager *StaticSecretCapabilityManager
}
// LeaseCacheConfig is the configuration for initializing a new
@ -168,9 +172,22 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) {
}, nil
}
// SetCapabilityManager is a setter for the StaticSecretCapabilityManager.
// If set, the lease cache will manage capability refresh jobs for cached
// capability indexes (see storeStaticSecretIndex and Restore).
func (c *LeaseCache) SetCapabilityManager(capabilityManager *StaticSecretCapabilityManager) {
	c.capabilityManager = capabilityManager
}
// SetShuttingDown is a setter for the shuttingDown field.
func (c *LeaseCache) SetShuttingDown(in bool) {
	c.shuttingDown.Store(in)

	// Since we're shutting down, also stop the capability manager's jobs.
	// We can do this forcibly since there's no reason to update
	// the cache when we're shutting down.
	if c.capabilityManager != nil {
		c.capabilityManager.Stop()
	}
}
// SetPersistentStorage is a setter for the persistent storage field in
@ -628,7 +645,7 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques
return err
}
capabilitiesIndex, err := c.retrieveOrCreateTokenCapabilitiesEntry(req.Token)
capabilitiesIndex, created, err := c.retrieveOrCreateTokenCapabilitiesEntry(req.Token)
if err != nil {
c.logger.Error("failed to cache the proxied response", "error", err)
return err
@ -644,27 +661,35 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques
// update the index with the new capability:
capabilitiesIndex.ReadablePaths[path] = struct{}{}
err = c.db.SetCapabilitiesIndex(capabilitiesIndex)
err = c.SetCapabilitiesIndex(ctx, capabilitiesIndex)
if err != nil {
c.logger.Error("failed to cache token capabilities as part of caching the proxied response", "error", err)
return err
}
// Lastly, ensure that we start renewing this index, if it's new.
// We require the 'created' check so that we don't renew the same
// index multiple times.
if c.capabilityManager != nil && created {
c.capabilityManager.StartRenewingCapabilities(capabilitiesIndex)
}
return nil
}
// retrieveOrCreateTokenCapabilitiesEntry will either retrieve the token
// capabilities entry from the cache, or create a new, empty one.
func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, error) {
// The bool represents if a new token capability has been created.
func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, bool, error) {
// The index ID is a hash of the token.
indexId := hashStaticSecretIndex(token)
indexFromCache, err := c.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId)
if err != nil && err != cachememdb.ErrCacheItemNotFound {
return nil, err
return nil, false, err
}
if indexFromCache != nil {
return indexFromCache, nil
return indexFromCache, false, nil
}
// Build the index to cache based on the response received
@ -674,7 +699,7 @@ func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cach
ReadablePaths: make(map[string]struct{}),
}
return index, nil
return index, true, nil
}
func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo {
@ -1266,6 +1291,28 @@ func (c *LeaseCache) Set(ctx context.Context, index *cachememdb.Index) error {
return nil
}
// SetCapabilitiesIndex stores the capabilities index in the cachememdb, and also stores it in the persistent
// cache (if enabled).
func (c *LeaseCache) SetCapabilitiesIndex(ctx context.Context, index *cachememdb.CapabilitiesIndex) error {
	// The in-memory cache is always updated first.
	if err := c.db.SetCapabilitiesIndex(index); err != nil {
		return err
	}

	// Without persistent storage configured there is nothing more to do.
	if c.ps == nil {
		return nil
	}

	plaintext, err := index.SerializeCapabilitiesIndex()
	if err != nil {
		return err
	}

	if err := c.ps.Set(ctx, index.ID, plaintext, cacheboltdb.TokenCapabilitiesType); err != nil {
		return err
	}
	c.logger.Trace("set entry in persistent storage", "type", cacheboltdb.TokenCapabilitiesType, "id", index.ID)

	return nil
}
// Evict removes an Index from the cachememdb, and also removes it from the
// persistent cache (if enabled)
func (c *LeaseCache) Evict(index *cachememdb.Index) error {
@ -1300,6 +1347,8 @@ func (c *LeaseCache) Flush() error {
// Restore loads the cachememdb from the persistent storage passed in. Loads
// tokens first, since restoring a lease's renewal context and watcher requires
// looking up the token in the cachememdb.
// Restore also restarts any capability management for managed static secret
// tokens.
func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStorage) error {
var errs *multierror.Error
@ -1348,6 +1397,51 @@ func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStora
}
}
// Then process static secrets and their capabilities
if c.cacheStaticSecrets {
staticSecrets, err := storage.GetByType(ctx, cacheboltdb.StaticSecretType)
if err != nil {
errs = multierror.Append(errs, err)
} else {
for _, staticSecret := range staticSecrets {
newIndex, err := cachememdb.Deserialize(staticSecret)
if err != nil {
errs = multierror.Append(errs, err)
continue
}
c.logger.Trace("restoring static secret index", "id", newIndex.ID, "path", newIndex.RequestPath)
if err := c.db.Set(newIndex); err != nil {
errs = multierror.Append(errs, err)
continue
}
}
}
capabilityIndexes, err := storage.GetByType(ctx, cacheboltdb.TokenCapabilitiesType)
if err != nil {
errs = multierror.Append(errs, err)
} else {
for _, capabilityIndex := range capabilityIndexes {
newIndex, err := cachememdb.DeserializeCapabilitiesIndex(capabilityIndex)
if err != nil {
errs = multierror.Append(errs, err)
continue
}
c.logger.Trace("restoring capability index", "id", newIndex.ID)
if err := c.db.SetCapabilitiesIndex(newIndex); err != nil {
errs = multierror.Append(errs, err)
continue
}
if c.capabilityManager != nil {
c.capabilityManager.StartRenewingCapabilities(newIndex)
}
}
}
}
return errs.ErrorOrNil()
}

View File

@ -0,0 +1,261 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache
import (
"context"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/gammazero/workerpool"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb"
"github.com/mitchellh/mapstructure"
"golang.org/x/exp/maps"
)
const (
// DefaultWorkers is the default number of workers for the worker pool.
DefaultWorkers = 5
// DefaultStaticSecretTokenCapabilityRefreshInterval is the default time
// between each capability poll. This is configured with the following config value:
// static_secret_token_capability_refresh_interval
DefaultStaticSecretTokenCapabilityRefreshInterval = 5 * time.Minute
)
// StaticSecretCapabilityManager is a struct that utilizes
// a worker pool to keep capabilities up to date.
type StaticSecretCapabilityManager struct {
client *api.Client
leaseCache *LeaseCache
logger hclog.Logger
workerPool *workerpool.WorkerPool
staticSecretTokenCapabilityRefreshInterval time.Duration
}
// StaticSecretCapabilityManagerConfig is the configuration for initializing a new
// StaticSecretCapabilityManager.
type StaticSecretCapabilityManagerConfig struct {
LeaseCache *LeaseCache
Logger hclog.Logger
Client *api.Client
StaticSecretTokenCapabilityRefreshInterval time.Duration
}
// NewStaticSecretCapabilityManager creates a new instance of a StaticSecretCapabilityManager.
// LeaseCache, Logger and Client are required; a zero refresh interval falls back to
// DefaultStaticSecretTokenCapabilityRefreshInterval.
func NewStaticSecretCapabilityManager(conf *StaticSecretCapabilityManagerConfig) (*StaticSecretCapabilityManager, error) {
	if conf == nil {
		return nil, errors.New("nil configuration provided")
	}

	// Validate the required dependencies up front.
	switch {
	case conf.LeaseCache == nil:
		return nil, fmt.Errorf("nil Lease Cache (a required parameter): %v", conf)
	case conf.Logger == nil:
		return nil, fmt.Errorf("nil Logger (a required parameter): %v", conf)
	case conf.Client == nil:
		return nil, fmt.Errorf("nil Client (a required parameter): %v", conf)
	}

	// Apply the default refresh interval when none was configured.
	if conf.StaticSecretTokenCapabilityRefreshInterval == 0 {
		conf.StaticSecretTokenCapabilityRefreshInterval = DefaultStaticSecretTokenCapabilityRefreshInterval
	}

	return &StaticSecretCapabilityManager{
		client:     conf.Client,
		leaseCache: conf.LeaseCache,
		logger:     conf.Logger,
		workerPool: workerpool.New(DefaultWorkers),
		staticSecretTokenCapabilityRefreshInterval: conf.StaticSecretTokenCapabilityRefreshInterval,
	}, nil
}
// submitWorkToPoolAfterInterval schedules work to be submitted to the worker
// pool once the configured staticSecretTokenCapabilityRefreshInterval elapses.
// If the pool has been stopped by then, the work is silently dropped.
func (sscm *StaticSecretCapabilityManager) submitWorkToPoolAfterInterval(work func()) {
	enqueue := func() {
		// A stopped pool must not receive new jobs.
		if sscm.workerPool.Stopped() {
			return
		}
		sscm.workerPool.Submit(work)
	}
	time.AfterFunc(sscm.staticSecretTokenCapabilityRefreshInterval, enqueue)
}
// Stop stops all ongoing jobs and ensures future jobs will not
// get added to the worker pool.
func (sscm *StaticSecretCapabilityManager) Stop() {
	// Stopping the pool makes workerPool.Stopped() return true, which both
	// submitWorkToPoolAfterInterval and the renewal jobs check before
	// doing or scheduling any further work.
	sscm.workerPool.Stop()
}
// StartRenewingCapabilities takes a polling job and submits a constant renewal of capabilities to the worker pool.
// indexToRenew is the capabilities index we'll renew the capabilities for.
// The job re-submits itself via submitWorkToPoolAfterInterval until either
// the worker pool stops, the index is evicted from the cache, or the token
// loses read access to every path it covered.
func (sscm *StaticSecretCapabilityManager) StartRenewingCapabilities(indexToRenew *cachememdb.CapabilitiesIndex) {
	// work is self-referential (it re-queues itself on transient errors),
	// so it must be declared before it is assigned.
	var work func()
	work = func() {
		if sscm.workerPool.Stopped() {
			sscm.logger.Trace("worker pool stopped, stopping renewal")
			return
		}

		// Re-fetch the index from the cache rather than trusting the captured
		// indexToRenew, since it may have been updated or evicted since the
		// last run.
		capabilitiesIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexToRenew.ID)
		if errors.Is(err, cachememdb.ErrCacheItemNotFound) {
			// This cache entry no longer exists, so there is no more work to do.
			sscm.logger.Trace("cache item not found for capabilities refresh, stopping the process")
			return
		}
		if err != nil {
			// Transient lookup failure: retry after the refresh interval.
			sscm.logger.Error("error when attempting to get capabilities index to refresh token capabilities", "indexToRenew.ID", indexToRenew.ID, "err", err)
			sscm.submitWorkToPoolAfterInterval(work)
			return
		}

		// Snapshot token and readable paths under the read lock; the Vault
		// round-trip below must not happen while holding it.
		capabilitiesIndex.IndexLock.RLock()
		token := capabilitiesIndex.Token
		indexReadablePathsMap := capabilitiesIndex.ReadablePaths
		capabilitiesIndex.IndexLock.RUnlock()
		indexReadablePaths := maps.Keys(indexReadablePathsMap)

		// Clone the client so that setting the token does not affect other
		// users of the shared client.
		client, err := sscm.client.Clone()
		if err != nil {
			sscm.logger.Error("error when attempting clone client to refresh token capabilities", "indexToRenew.ID", indexToRenew.ID, "err", err)
			sscm.submitWorkToPoolAfterInterval(work)
			return
		}
		client.SetToken(token)

		capabilities, err := getCapabilities(indexReadablePaths, client)
		if err != nil {
			sscm.logger.Error("error when attempting to retrieve updated token capabilities", "indexToRenew.ID", indexToRenew.ID, "err", err)
			sscm.submitWorkToPoolAfterInterval(work)
			return
		}

		newReadablePaths := reconcileCapabilities(indexReadablePaths, capabilities)
		if maps.Equal(indexReadablePathsMap, newReadablePaths) {
			sscm.logger.Trace("capabilities were the same for index, nothing to do", "indexToRenew.ID", indexToRenew.ID)
			// there's nothing to update!
			sscm.submitWorkToPoolAfterInterval(work)
			return
		}

		// Before updating or evicting the capabilities index, we must update
		// the token maps of the per-path static secret indexes:
		// for each path that lost readability, remove this token from the
		// corresponding index.
		for _, path := range indexReadablePaths {
			// If the old path isn't contained in the new readable paths,
			// we must delete it from the tokens map for its corresponding
			// path index.
			if _, ok := newReadablePaths[path]; !ok {
				indexId := hashStaticSecretIndex(path)
				index, err := sscm.leaseCache.db.Get(cachememdb.IndexNameID, indexId)
				if errors.Is(err, cachememdb.ErrCacheItemNotFound) {
					// Nothing to update!
					continue
				}
				if err != nil {
					sscm.logger.Error("error when attempting to update corresponding paths for capabilities index", "indexToRenew.ID", indexToRenew.ID, "err", err)
					sscm.submitWorkToPoolAfterInterval(work)
					return
				}
				sscm.logger.Trace("updating tokens for index, as capability has been lost", "index.ID", index.ID, "request_path", index.RequestPath)
				// The write lock guards the Tokens map mutation and the
				// subsequent cache write as one unit.
				index.IndexLock.Lock()
				delete(index.Tokens, capabilitiesIndex.Token)
				err = sscm.leaseCache.Set(context.Background(), index)
				if err != nil {
					sscm.logger.Error("error when attempting to update index in cache", "index.ID", index.ID, "err", err)
				}
				index.IndexLock.Unlock()
			}
		}

		// Lastly, we should update the capabilities index, either evicting or updating it
		capabilitiesIndex.IndexLock.Lock()
		defer capabilitiesIndex.IndexLock.Unlock()
		if len(newReadablePaths) == 0 {
			err := sscm.leaseCache.db.EvictCapabilitiesIndex(cachememdb.IndexNameID, indexToRenew.ID)
			if err != nil {
				sscm.logger.Error("error when attempting to evict capabilities from cache", "index.ID", indexToRenew.ID, "err", err)
				sscm.submitWorkToPoolAfterInterval(work)
				return
			}
			// If we successfully evicted the index, no need to re-submit the work to the pool.
			return
		}

		// The token still has some capabilities, so, update the capabilities index:
		capabilitiesIndex.ReadablePaths = newReadablePaths
		err = sscm.leaseCache.SetCapabilitiesIndex(context.Background(), capabilitiesIndex)
		if err != nil {
			sscm.logger.Error("error when attempting to update capabilities from cache", "index.ID", indexToRenew.ID, "err", err)
		}

		// Finally, put ourselves back on the work pool after
		sscm.submitWorkToPoolAfterInterval(work)
		return
	}

	sscm.submitWorkToPoolAfterInterval(work)
}
// getCapabilities is a wrapper around a /sys/capabilities-self call that returns
// capabilities as a map with paths as keys, and capabilities as values.
// A "permission denied" error from Vault (e.g. an expired token) is translated
// into an empty capabilities map rather than an error.
func getCapabilities(paths []string, client *api.Client) (map[string][]string, error) {
	capabilities := make(map[string][]string)
	requestBody := map[string]interface{}{
		"paths": paths,
	}

	secret, err := client.Logical().Write("sys/capabilities-self", requestBody)
	if err != nil {
		if strings.Contains(err.Error(), "permission denied") {
			// Token has expired. Return an empty set of capabilities:
			return capabilities, nil
		}
		return nil, err
	}
	if secret == nil || secret.Data == nil {
		return nil, errors.New("data from server response is empty")
	}

	// Decode the per-path capability lists out of the response data.
	for _, path := range paths {
		var pathCapabilities []string
		if err := mapstructure.Decode(secret.Data[path], &pathCapabilities); err != nil {
			return nil, err
		}
		capabilities[path] = pathCapabilities
	}

	return capabilities, nil
}
// reconcileCapabilities takes a set of known readable paths, and a set of capabilities (a response from the
// sys/capabilities-self endpoint) and returns the subset of readablePaths that still carry a
// "read" or "root" capability, represented as a set (map of strings to empty structs).
// Paths are never added: a path absent from readablePaths is excluded even if
// the capabilities map grants it read access.
func reconcileCapabilities(readablePaths []string, capabilities map[string][]string) map[string]struct{} {
	retained := make(map[string]struct{})
	for _, path := range readablePaths {
		permissions, ok := capabilities[path]
		if !ok {
			continue
		}
		if slices.Contains(permissions, "read") || slices.Contains(permissions, "root") {
			retained[path] = struct{}{}
		}
	}
	return retained
}

View File

@ -0,0 +1,432 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache
import (
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb"
"github.com/hashicorp/vault/helper/testhelpers/minimal"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/stretchr/testify/require"
)
// testNewStaticSecretCapabilityManager returns a new StaticSecretCapabilityManager
// for use in tests. It uses a short (250ms) refresh interval so tests that wait
// on renewal jobs complete quickly.
func testNewStaticSecretCapabilityManager(t *testing.T, client *api.Client) *StaticSecretCapabilityManager {
	t.Helper()
	lc := testNewLeaseCache(t, []*SendResponse{})
	updater, err := NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{
		LeaseCache: lc,
		Logger:     logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager"),
		Client:     client,
		StaticSecretTokenCapabilityRefreshInterval: 250 * time.Millisecond,
	})
	if err != nil {
		t.Fatal(err)
	}
	return updater
}
// TestNewStaticSecretCapabilityManager tests the NewStaticSecretCapabilityManager method,
// to ensure it errors out when appropriate.
func TestNewStaticSecretCapabilityManager(t *testing.T) {
	t.Parallel()
	lc := testNewLeaseCache(t, []*SendResponse{})
	logger := logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager")
	client, err := api.NewClient(api.DefaultConfig())
	require.Nil(t, err)

	// Expect an error if any of the arguments are nil:
	updater, err := NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{
		LeaseCache: nil,
		Logger:     logger,
		Client:     client,
	})
	require.Error(t, err)
	require.Nil(t, updater)

	updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{
		LeaseCache: lc,
		Logger:     nil,
		Client:     client,
	})
	require.Error(t, err)
	require.Nil(t, updater)

	updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{
		LeaseCache: lc,
		Logger:     logger,
		Client:     nil,
	})
	require.Error(t, err)
	require.Nil(t, updater)

	// Don't expect an error if the arguments are as expected
	updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{
		LeaseCache: lc,
		Logger:     logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager"),
		Client:     client,
	})
	// Use require.NoError for consistency with the rest of this test.
	require.NoError(t, err)
	require.NotNil(t, updater)
	require.NotNil(t, updater.workerPool)
	// NotZero rather than NotNil: the interval is a time.Duration value, so
	// NotNil was vacuously true; NotZero verifies the default was applied.
	require.NotZero(t, updater.staticSecretTokenCapabilityRefreshInterval)
	require.NotNil(t, updater.client)
	require.NotNil(t, updater.leaseCache)
	require.NotNil(t, updater.logger)
	require.Equal(t, DefaultStaticSecretTokenCapabilityRefreshInterval, updater.staticSecretTokenCapabilityRefreshInterval)

	// Lastly, double check that the refresh interval can be properly set
	updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{
		LeaseCache: lc,
		Logger:     logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager"),
		Client:     client,
		StaticSecretTokenCapabilityRefreshInterval: time.Hour,
	})
	require.NoError(t, err)
	require.NotNil(t, updater)
	require.NotNil(t, updater.workerPool)
	require.NotZero(t, updater.staticSecretTokenCapabilityRefreshInterval)
	require.NotNil(t, updater.client)
	require.NotNil(t, updater.leaseCache)
	require.NotNil(t, updater.logger)
	require.Equal(t, time.Hour, updater.staticSecretTokenCapabilityRefreshInterval)
}
// TestGetCapabilitiesRootToken tests the getCapabilities method with the root
// token, expecting to get "root" capabilities on valid paths.
func TestGetCapabilitiesRootToken(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	// The test cluster client is authenticated with the root token.
	client := cluster.Cores[0].Client

	capabilitiesToCheck := []string{"auth/token/create", "sys/health"}
	capabilities, err := getCapabilities(capabilitiesToCheck, client)
	require.NoError(t, err)

	expectedCapabilities := map[string][]string{
		"auth/token/create": {"root"},
		"sys/health":        {"root"},
	}
	require.Equal(t, expectedCapabilities, capabilities)
}
// TestGetCapabilitiesLowPrivilegeToken tests the getCapabilities method with
// a low privilege token, expecting to get deny or non-root capabilities.
func TestGetCapabilitiesLowPrivilegeToken(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	client := cluster.Cores[0].Client
	renewable := true
	// Set the token's policies to 'default' and nothing else
	tokenCreateRequest := &api.TokenCreateRequest{
		Policies:  []string{"default"},
		TTL:       "30m",
		Renewable: &renewable,
	}

	secret, err := client.Auth().Token().CreateOrphan(tokenCreateRequest)
	require.NoError(t, err)
	token := secret.Auth.ClientToken

	// Re-use the client, but authenticated as the low-privilege token.
	client.SetToken(token)

	capabilitiesToCheck := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"}
	capabilities, err := getCapabilities(capabilitiesToCheck, client)
	require.NoError(t, err)

	// These reflect what the 'default' policy grants.
	expectedCapabilities := map[string][]string{
		"auth/token/create":      {"deny"},
		"sys/capabilities-self":  {"update"},
		"auth/token/lookup-self": {"read"},
	}
	require.Equal(t, expectedCapabilities, capabilities)
}
// TestGetCapabilitiesBadClientToken tests that getCapabilities
// returns an empty set of capabilities if the token is bad (and it gets a 403).
func TestGetCapabilitiesBadClientToken(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	client := cluster.Cores[0].Client
	// An empty token yields a "permission denied" response, which
	// getCapabilities translates into an empty map rather than an error.
	client.SetToken("")

	capabilitiesToCheck := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"}
	capabilities, err := getCapabilities(capabilitiesToCheck, client)
	require.Nil(t, err)

	require.Equal(t, map[string][]string{}, capabilities)
}
// TestGetCapabilitiesEmptyPaths tests that getCapabilities will error on an empty
// set of paths to check.
func TestGetCapabilitiesEmptyPaths(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	client := cluster.Cores[0].Client

	// A nil slice of paths should produce an error from the server/wrapper.
	var capabilitiesToCheck []string
	_, err := getCapabilities(capabilitiesToCheck, client)
	require.Error(t, err)
}
// TestReconcileCapabilities tests that reconcileCapabilities will
// correctly remove previously readable paths that we no longer have
// read access to.
func TestReconcileCapabilities(t *testing.T) {
	t.Parallel()
	paths := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"}
	capabilities := map[string][]string{
		"auth/token/create":      {"deny"},
		"sys/capabilities-self":  {"update"},
		"auth/token/lookup-self": {"read"},
	}

	// Only the path with "read" capability should survive reconciliation.
	updatedCapabilities := reconcileCapabilities(paths, capabilities)
	expectedUpdatedCapabilities := map[string]struct{}{
		"auth/token/lookup-self": {},
	}
	require.Equal(t, expectedUpdatedCapabilities, updatedCapabilities)
}
// TestReconcileCapabilitiesNoOp tests that reconcileCapabilities will
// correctly not remove capabilities when they all remain readable.
func TestReconcileCapabilitiesNoOp(t *testing.T) {
	t.Parallel()
	paths := []string{"foo/bar", "bar/baz", "baz/foo"}
	// Both "read" and "root" count as readable.
	capabilities := map[string][]string{
		"foo/bar": {"read"},
		"bar/baz": {"root"},
		"baz/foo": {"read"},
	}

	updatedCapabilities := reconcileCapabilities(paths, capabilities)
	expectedUpdatedCapabilities := map[string]struct{}{
		"foo/bar": {},
		"bar/baz": {},
		"baz/foo": {},
	}
	require.Equal(t, expectedUpdatedCapabilities, updatedCapabilities)
}
// TestReconcileCapabilitiesNoAdding verifies that reconcileCapabilities never
// introduces paths that were absent from its first argument, even when the
// capability results mention additional readable paths.
func TestReconcileCapabilitiesNoAdding(t *testing.T) {
	t.Parallel()
	checkedPaths := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"}
	// "some/new/path" is readable but was never part of checkedPaths, so it
	// must not appear in the reconciled set.
	freshCapabilities := map[string][]string{
		"auth/token/create":      {"deny"},
		"sys/capabilities-self":  {"update"},
		"auth/token/lookup-self": {"read"},
		"some/new/path":          {"read"},
	}
	result := reconcileCapabilities(checkedPaths, freshCapabilities)
	require.Equal(t, map[string]struct{}{"auth/token/lookup-self": {}}, result)
}
// TestSubmitWorkNoOp tests that we will gracefully end if the capabilities index
// does not exist in the cache
func TestSubmitWorkNoOp(t *testing.T) {
	t.Parallel()
	apiClient, err := api.NewClient(api.DefaultConfig())
	require.NoError(t, err)
	manager := testNewStaticSecretCapabilityManager(t, apiClient)

	// No index with this ID was ever stored, so the renewal job should find
	// nothing to do and finish without queueing further work.
	missingIndex := &cachememdb.CapabilitiesIndex{
		ID: "test",
	}
	manager.StartRenewingCapabilities(missingIndex)

	// Give the submitted job a moment to run to completion.
	time.Sleep(1 * time.Second)
	require.Equal(t, 0, manager.workerPool.WaitingQueueSize())
}
// TestSubmitWorkUpdatesIndex tests that an index will be correctly updated if the capabilities differ.
func TestSubmitWorkUpdatesIndex(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	client := cluster.Cores[0].Client

	// Create a low permission token
	renewable := true
	// Set the token's policies to 'default' and nothing else
	tokenCreateRequest := &api.TokenCreateRequest{
		Policies:  []string{"default"},
		TTL:       "30m",
		Renewable: &renewable,
	}

	secret, err := client.Auth().Token().CreateOrphan(tokenCreateRequest)
	require.NoError(t, err)
	token := secret.Auth.ClientToken
	indexId := hashStaticSecretIndex(token)

	sscm := testNewStaticSecretCapabilityManager(t, client)
	index := &cachememdb.CapabilitiesIndex{
		ID:    indexId,
		Token: token,
		// The token will (perhaps obviously) not have
		// read access to /foo/bar, but will to /auth/token/lookup-self
		ReadablePaths: map[string]struct{}{
			"foo/bar":                {},
			"auth/token/lookup-self": {},
		},
	}
	err = sscm.leaseCache.db.SetCapabilitiesIndex(index)
	require.NoError(t, err)

	sscm.StartRenewingCapabilities(index)
	// Wait for the job to complete at least once...
	time.Sleep(3 * time.Second)

	// The unreadable path should have been reconciled away, leaving only
	// the path the token can actually read.
	newIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId)
	require.NoError(t, err)
	require.Equal(t, map[string]struct{}{
		"auth/token/lookup-self": {},
	}, newIndex.ReadablePaths)

	// Forcefully stop any remaining workers
	sscm.workerPool.Stop()
}
// TestSubmitWorkUpdatesIndexWithBadToken tests that an index will be correctly updated if the token
// has expired and we cannot access the sys capabilities endpoint.
func TestSubmitWorkUpdatesIndexWithBadToken(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	client := cluster.Cores[0].Client

	// This token is not known to Vault, so the capabilities check will fail.
	token := "not real token"
	indexId := hashStaticSecretIndex(token)
	sscm := testNewStaticSecretCapabilityManager(t, client)
	index := &cachememdb.CapabilitiesIndex{
		ID:    indexId,
		Token: token,
		ReadablePaths: map[string]struct{}{
			"foo/bar":                {},
			"auth/token/lookup-self": {},
		},
	}
	err := sscm.leaseCache.db.SetCapabilitiesIndex(index)
	require.NoError(t, err)

	sscm.StartRenewingCapabilities(index)
	// Wait for the job to complete at least once...
	time.Sleep(3 * time.Second)

	// This entry should be evicted, since the token has no readable paths left.
	newIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId)
	require.ErrorIs(t, err, cachememdb.ErrCacheItemNotFound)
	require.Nil(t, newIndex)

	// Forcefully stop any remaining workers
	sscm.workerPool.Stop()
}
// TestSubmitWorkUpdatesAllIndexes tests that an index will be correctly updated if the capabilities differ, as
// well as the indexes related to the paths that are being checked for.
func TestSubmitWorkUpdatesAllIndexes(t *testing.T) {
	t.Parallel()
	cluster := minimal.NewTestSoloCluster(t, nil)
	client := cluster.Cores[0].Client

	// Create a low permission token
	renewable := true
	// Set the token's policies to 'default' and nothing else
	tokenCreateRequest := &api.TokenCreateRequest{
		Policies:  []string{"default"},
		TTL:       "30m",
		Renewable: &renewable,
	}

	secret, err := client.Auth().Token().CreateOrphan(tokenCreateRequest)
	require.NoError(t, err)
	token := secret.Auth.ClientToken
	indexId := hashStaticSecretIndex(token)

	sscm := testNewStaticSecretCapabilityManager(t, client)
	index := &cachememdb.CapabilitiesIndex{
		ID:    indexId,
		Token: token,
		// The token will (perhaps obviously) not have
		// read access to /foo/bar, but will to /auth/token/lookup-self
		ReadablePaths: map[string]struct{}{
			"foo/bar":                {},
			"auth/token/lookup-self": {},
		},
	}
	err = sscm.leaseCache.db.SetCapabilitiesIndex(index)
	require.NoError(t, err)

	// Seed per-path cache indexes that reference the token, so we can verify
	// the capability refresh also updates them.
	pathIndexId1 := hashStaticSecretIndex("foo/bar")
	pathIndex1 := &cachememdb.Index{
		ID:        pathIndexId1,
		Namespace: "root/",
		Tokens: map[string]struct{}{
			token: {},
		},
		RequestPath: "foo/bar",
		Response:    []byte{},
	}

	pathIndexId2 := hashStaticSecretIndex("auth/token/lookup-self")
	pathIndex2 := &cachememdb.Index{
		ID:        pathIndexId2,
		Namespace: "root/",
		Tokens: map[string]struct{}{
			token: {},
		},
		RequestPath: "auth/token/lookup-self",
		Response:    []byte{},
	}

	err = sscm.leaseCache.db.Set(pathIndex1)
	require.NoError(t, err)
	err = sscm.leaseCache.db.Set(pathIndex2)
	require.NoError(t, err)

	sscm.StartRenewingCapabilities(index)
	// Wait for the job to complete at least once...
	time.Sleep(1 * time.Second)

	newIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId)
	require.NoError(t, err)
	require.Equal(t, map[string]struct{}{
		"auth/token/lookup-self": {},
	}, newIndex.ReadablePaths)

	// For this, we expect the token to have been deleted
	newPathIndex1, err := sscm.leaseCache.db.Get(cachememdb.IndexNameID, pathIndexId1)
	require.NoError(t, err)
	require.Equal(t, map[string]struct{}{}, newPathIndex1.Tokens)

	// For this, we expect no change: the token keeps read access, so the
	// stored index should be exactly what we put in. (The original assertion
	// compared newPathIndex2 against itself, which always passed.)
	newPathIndex2, err := sscm.leaseCache.db.Get(cachememdb.IndexNameID, pathIndexId2)
	require.NoError(t, err)
	require.Equal(t, pathIndex2, newPathIndex2)

	// Forcefully stop any remaining workers
	sscm.workerPool.Stop()
}

View File

@ -491,6 +491,18 @@ func (c *ProxyCommand) Run(args []string) int {
c.UI.Error(fmt.Sprintf("Error creating static secret cache updater: %v", err))
return 1
}
capabilityManager, err := cache.NewStaticSecretCapabilityManager(&cache.StaticSecretCapabilityManagerConfig{
LeaseCache: leaseCache,
Logger: c.logger.Named("cache.staticsecretcapabilitymanager"),
Client: client,
StaticSecretTokenCapabilityRefreshInterval: config.Cache.StaticSecretTokenCapabilityRefreshInterval,
})
if err != nil {
c.UI.Error(fmt.Sprintf("Error creating static secret capability manager: %v", err))
return 1
}
leaseCache.SetCapabilityManager(capabilityManager)
}
}

View File

@ -101,9 +101,11 @@ type APIProxy struct {
// Cache contains any configuration needed for Cache mode
type Cache struct {
Persist *agentproxyshared.PersistConfig `hcl:"persist"`
InProcDialer transportDialer `hcl:"-"`
CacheStaticSecrets bool `hcl:"cache_static_secrets"`
Persist *agentproxyshared.PersistConfig `hcl:"persist"`
InProcDialer transportDialer `hcl:"-"`
CacheStaticSecrets bool `hcl:"cache_static_secrets"`
StaticSecretTokenCapabilityRefreshIntervalRaw interface{} `hcl:"static_secret_token_capability_refresh_interval"`
StaticSecretTokenCapabilityRefreshInterval time.Duration `hcl:"-"`
}
// AutoAuth is the configured authentication method and sinks
@ -621,6 +623,14 @@ func parseCache(result *Config, list *ast.ObjectList) error {
return fmt.Errorf("error parsing persist: %w", err)
}
if result.Cache.StaticSecretTokenCapabilityRefreshIntervalRaw != nil {
var err error
if result.Cache.StaticSecretTokenCapabilityRefreshInterval, err = parseutil.ParseDurationSecond(result.Cache.StaticSecretTokenCapabilityRefreshIntervalRaw); err != nil {
return fmt.Errorf("error parsing static_secret_token_capability_refresh_interval, must be provided as a duration string: %w", err)
}
result.Cache.StaticSecretTokenCapabilityRefreshIntervalRaw = nil
}
return nil
}

View File

@ -5,6 +5,7 @@ package config
import (
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/vault/command/agentproxyshared"
@ -130,3 +131,62 @@ func TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth(t *testing.T) {
t.Fatalf("expected error, as static secret caching requires auto-auth")
}
}
// TestLoadConfigFile_ProxyCacheStaticSecrets tests loading a config file containing a cache
// as well as a valid proxy config with static secret caching enabled
func TestLoadConfigFile_ProxyCacheStaticSecrets(t *testing.T) {
	config, err := LoadConfigFile("./test-fixtures/config-cache-static-secret-cache.hcl")
	if err != nil {
		t.Fatal(err)
	}

	// Expected config mirrors the fixture field-for-field; in particular,
	// static_secret_token_capability_refresh_interval = "1h" in the fixture
	// should be parsed into a time.Duration of one hour.
	expected := &Config{
		SharedConfig: &configutil.SharedConfig{
			PidFile: "./pidfile",
			Listeners: []*configutil.Listener{
				{
					Type:       "tcp",
					Address:    "127.0.0.1:8300",
					TLSDisable: true,
				},
			},
		},
		AutoAuth: &AutoAuth{
			Method: &Method{
				Type:      "aws",
				MountPath: "auth/aws",
				Config: map[string]interface{}{
					"role": "foobar",
				},
			},
			Sinks: []*Sink{
				{
					Type:   "file",
					DHType: "curve25519",
					DHPath: "/tmp/file-foo-dhpath",
					AAD:    "foobar",
					Config: map[string]interface{}{
						"path": "/tmp/file-foo",
					},
				},
			},
		},
		Cache: &Cache{
			CacheStaticSecrets: true,
			StaticSecretTokenCapabilityRefreshInterval: 1 * time.Hour,
		},
		Vault: &Vault{
			Address:          "http://127.0.0.1:1111",
			TLSSkipVerify:    true,
			TLSSkipVerifyRaw: interface{}("true"),
			Retry: &Retry{
				NumRetries: 12,
			},
		},
	}

	// Prune clears raw/intermediate fields before the deep comparison.
	config.Prune()
	if diff := deep.Equal(config, expected); diff != nil {
		t.Fatal(diff)
	}
}

View File

@ -0,0 +1,38 @@
# Copyright (c) HashiCorp, Inc.
# SPDX-License-Identifier: BUSL-1.1

pid_file = "./pidfile"

# Auto-auth block; static secret caching relies on the Proxy having its own
# auth method configured (see TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth).
auto_auth {
  method {
    type = "aws"
    config = {
      role = "foobar"
    }
  }

  sink {
    type = "file"
    config = {
      path = "/tmp/file-foo"
    }
    aad     = "foobar"
    dh_type = "curve25519"
    dh_path = "/tmp/file-foo-dhpath"
  }
}

# Enable static secret caching; cached token capabilities are re-checked
# every hour (parsed as a duration string).
cache {
  cache_static_secrets                            = true
  static_secret_token_capability_refresh_interval = "1h"
}

listener "tcp" {
  address     = "127.0.0.1:8300"
  tls_disable = true
}

vault {
  address         = "http://127.0.0.1:1111"
  tls_skip_verify = "true"
}

View File

@ -1156,6 +1156,178 @@ log_level = "trace"
wg.Wait()
}
// TestProxy_Cache_StaticSecretPermissionsLost Tests that the cache successfully caches a static secret
// going through the Proxy for a KVV2 secret, and then the calling client loses permissions to the secret,
// so it can no longer access the cache.
func TestProxy_Cache_StaticSecretPermissionsLost(t *testing.T) {
	logger := logging.NewVaultLogger(hclog.Trace)
	cluster := vault.NewTestCluster(t, &vault.CoreConfig{
		LogicalBackends: map[string]logical.Factory{
			"kv": logicalKv.VersionedKVFactory,
		},
	}, &vault.TestClusterOptions{
		HandlerFunc: vaulthttp.Handler,
	})

	serverClient := cluster.Cores[0].Client

	// Unset the environment variable so that proxy picks up the right test
	// cluster address. The defer's arguments are evaluated now, so the
	// original value is restored when the test ends.
	defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress))
	os.Unsetenv(api.EnvVaultAddress)

	tokenFileName := makeTempFile(t, "token-file", serverClient.Token())
	defer os.Remove(tokenFileName)
	// We need auto-auth so that the event system can run.
	// For ease, we use the token file path with the root token.
	autoAuthConfig := fmt.Sprintf(`
auto_auth {
    method {
        type = "token_file"
        config = {
            token_file_path = "%s"
        }
    }
}`, tokenFileName)

	// We make the token capability refresh interval one second, for ease of testing
	cacheConfig := `
cache {
	cache_static_secrets = true
	static_secret_token_capability_refresh_interval = "1s"
}
`
	listenAddr := generateListenerAddress(t)
	listenConfig := fmt.Sprintf(`
listener "tcp" {
    address = "%s"
    tls_disable = true
}
`, listenAddr)

	config := fmt.Sprintf(`
vault {
  address = "%s"
  tls_skip_verify = true
}
%s
%s
%s
log_level = "trace"
`, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig)
	configPath := makeTempFile(t, "config.hcl", config)
	defer os.Remove(configPath)

	// Start proxy
	ui, cmd := testProxyCommand(t, logger)
	cmd.startedCh = make(chan struct{})

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		cmd.Run([]string{"-config", configPath})
		wg.Done()
	}()

	select {
	case <-cmd.startedCh:
	case <-time.After(5 * time.Second):
		t.Errorf("timeout")
		t.Errorf("stdout: %s", ui.OutputWriter.String())
		t.Errorf("stderr: %s", ui.ErrorWriter.String())
	}

	proxyClient, err := api.NewClient(api.DefaultConfig())
	require.NoError(t, err)
	proxyClient.SetMaxRetries(0)
	err = proxyClient.SetAddress("http://" + listenAddr)
	require.NoError(t, err)

	secretData := map[string]interface{}{
		"foo": "bar",
	}

	// Mount the KVV2 engine
	err = serverClient.Sys().Mount("secret-v2", &api.MountInput{
		Type: "kv-v2",
	})
	require.NoError(t, err)

	err = serverClient.Sys().PutPolicy("kv-policy", `
path "secret-v2/*" {
    capabilities = ["update", "read"]
}`)
	require.NoError(t, err)

	// Setup a token that we can later revoke:
	renewable := true
	// Set the token's policies to 'default' and nothing else
	tokenCreateRequest := &api.TokenCreateRequest{
		Policies:  []string{"default", "kv-policy"},
		TTL:       "2s",
		Renewable: &renewable,
	}

	secret, err := serverClient.Auth().Token().CreateOrphan(tokenCreateRequest)
	require.NoError(t, err)
	token := secret.Auth.ClientToken
	proxyClient.SetToken(token)

	// Create kvv2 secret
	_, err = serverClient.KVv2("secret-v2").Put(context.Background(), "my-secret", secretData)
	require.NoError(t, err)

	// We use raw requests so we can check the headers for cache hit/miss.
	// First request populates the cache...
	req := proxyClient.NewRequest(http.MethodGet, "/v1/secret-v2/data/my-secret")
	resp1, err := proxyClient.RawRequest(req)
	require.NoError(t, err)
	cacheValue := resp1.Header.Get("X-Cache")
	require.Equal(t, "MISS", cacheValue)

	// We expect this to be a cache hit, with the new value
	resp2, err := proxyClient.RawRequest(req)
	require.NoError(t, err)
	cacheValue = resp2.Header.Get("X-Cache")
	require.Equal(t, "HIT", cacheValue)

	// Lastly, we check to make sure the actual data we received is
	// as we expect. We must use ParseSecret due to the raw requests.
	secret1, err := api.ParseSecret(resp1.Body)
	if err != nil {
		t.Fatal(err)
	}
	data, ok := secret1.Data["data"]
	require.True(t, ok)
	require.Equal(t, secretData, data)

	secret2, err := api.ParseSecret(resp2.Body)
	if err != nil {
		t.Fatal(err)
	}
	data2, ok := secret2.Data["data"]
	require.True(t, ok)
	// We expect that the cached value got updated by the event system.
	require.Equal(t, secretData, data2)

	// Wait for the token to expire, and for the permissions to be revoked
	// The TTL on the token was 2s, and the capability refresh is every 1s,
	// so this should give us more than enough time!
	time.Sleep(5 * time.Second)

	kvSecret, err := proxyClient.KVv2("secret-v2").Get(context.Background(), "my-secret")
	if err == nil {
		t.Fatalf("expected error, but none found, secret:%v, err:%v", kvSecret, err)
	}
	// Make sure it's a permission denied error
	if !strings.Contains(err.Error(), "permission denied") {
		t.Fatalf("expected error on GET to secret after token revocation, secret:%v, err:%v", kvSecret, err)
	}

	close(cmd.ShutdownCh)
	wg.Wait()
}
// TestProxy_ApiProxy_Retry Tests the retry functionalities of Vault Proxy's API Proxy
func TestProxy_ApiProxy_Retry(t *testing.T) {
//----------------------------------------------------