diff --git a/command/agentproxyshared/cache/cacheboltdb/bolt.go b/command/agentproxyshared/cache/cacheboltdb/bolt.go index ff7ec5fdf8..6100ef8962 100644 --- a/command/agentproxyshared/cache/cacheboltdb/bolt.go +++ b/command/agentproxyshared/cache/cacheboltdb/bolt.go @@ -165,7 +165,7 @@ func createV1BoltSchema(tx *bolt.Tx) error { func createV2BoltSchema(tx *bolt.Tx) error { // Create the buckets for tokens and leases. - for _, bucket := range []string{TokenType, LeaseType, lookupType} { + for _, bucket := range []string{TokenType, LeaseType, lookupType, StaticSecretType, TokenCapabilitiesType} { if _, err := tx.CreateBucketIfNotExists([]byte(bucket)); err != nil { return fmt.Errorf("failed to create %s bucket: %w", bucket, err) } @@ -267,6 +267,10 @@ func (b *BoltStorage) Set(ctx context.Context, id string, plaintext []byte, inde if err := meta.Put([]byte(AutoAuthToken), protoBlob); err != nil { return fmt.Errorf("failed to set latest auto-auth token: %w", err) } + case StaticSecretType: + key = []byte(id) + case TokenCapabilitiesType: + key = []byte(id) default: return fmt.Errorf("called Set for unsupported type %q", indexType) } @@ -419,7 +423,7 @@ func (b *BoltStorage) Close() error { // the schema/layout func (b *BoltStorage) Clear() error { return b.db.Update(func(tx *bolt.Tx) error { - for _, name := range []string{TokenType, LeaseType, lookupType} { + for _, name := range []string{TokenType, LeaseType, lookupType, StaticSecretType, TokenCapabilitiesType} { b.logger.Trace("deleting bolt bucket", "name", name) if err := tx.DeleteBucket([]byte(name)); err != nil { return err diff --git a/command/agentproxyshared/cache/cacheboltdb/bolt_test.go b/command/agentproxyshared/cache/cacheboltdb/bolt_test.go index c2959fc9f6..dbfafdce7b 100644 --- a/command/agentproxyshared/cache/cacheboltdb/bolt_test.go +++ b/command/agentproxyshared/cache/cacheboltdb/bolt_test.go @@ -6,7 +6,6 @@ package cacheboltdb import ( "context" "fmt" - "io/ioutil" "os" "path" "path/filepath" @@ -34,7 +33,7 @@ func getTestKeyManager(t *testing.T) keymanager.KeyManager { func TestBolt_SetGet(t *testing.T) { ctx := context.Background() - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) @@ -60,7 +59,7 @@ func TestBolt_SetGet(t *testing.T) { func TestBoltDelete(t *testing.T) { ctx := context.Background() - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) @@ -92,7 +91,7 @@ func TestBoltDelete(t *testing.T) { func TestBoltClear(t *testing.T) { ctx := context.Background() - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) @@ -126,6 +125,20 @@ func TestBoltClear(t *testing.T) { require.Len(t, tokens, 1) assert.Equal(t, []byte("hello"), tokens[0]) + err = b.Set(ctx, "static-secret", []byte("hello"), StaticSecretType) + require.NoError(t, err) + staticSecrets, err := b.GetByType(ctx, StaticSecretType) + require.NoError(t, err) + require.Len(t, staticSecrets, 1) + assert.Equal(t, []byte("hello"), staticSecrets[0]) + + err = b.Set(ctx, "capabilities-index", []byte("hello"), TokenCapabilitiesType) + require.NoError(t, err) + capabilities, err := b.GetByType(ctx, TokenCapabilitiesType) + require.NoError(t, err) + require.Len(t, capabilities, 1) + assert.Equal(t, []byte("hello"), capabilities[0]) + // Clear the bolt db, and check that it's indeed clear err = b.Clear() 
require.NoError(t, err) @@ -135,12 +148,18 @@ func TestBoltClear(t *testing.T) { tokens, err = b.GetByType(ctx, TokenType) require.NoError(t, err) assert.Len(t, tokens, 0) + staticSecrets, err = b.GetByType(ctx, StaticSecretType) + require.NoError(t, err) + require.Len(t, staticSecrets, 0) + capabilities, err = b.GetByType(ctx, TokenCapabilitiesType) + require.NoError(t, err) + require.Len(t, capabilities, 0) } func TestBoltSetAutoAuthToken(t *testing.T) { ctx := context.Background() - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) @@ -210,11 +229,11 @@ func TestDBFileExists(t *testing.T) { var tmpPath string var err error if tc.mkDir { - tmpPath, err = ioutil.TempDir("", "test-db-path") + tmpPath, err = os.MkdirTemp("", "test-db-path") require.NoError(t, err) } if tc.createFile { - err = ioutil.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0o600) + err = os.WriteFile(path.Join(tmpPath, DatabaseFileName), []byte("test-db-path"), 0o600) require.NoError(t, err) } exists, err := DBFileExists(tmpPath) @@ -244,7 +263,7 @@ func Test_SetGetRetrievalToken(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) @@ -270,7 +289,7 @@ func Test_SetGetRetrievalToken(t *testing.T) { func TestBolt_MigrateFromV1ToV2Schema(t *testing.T) { ctx := context.Background() - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) @@ -342,7 +361,7 @@ func TestBolt_MigrateFromV1ToV2Schema(t *testing.T) { func TestBolt_MigrateFromInvalidToV2Schema(t *testing.T) { ctx := context.Background() - path, err := ioutil.TempDir("", "bolt-test") + path, err := os.MkdirTemp("", "bolt-test") require.NoError(t, err) defer os.RemoveAll(path) diff --git a/command/agentproxyshared/cache/cachememdb/cache_memdb.go b/command/agentproxyshared/cache/cachememdb/cache_memdb.go index daa9a747df..ed2cd0ac80 100644 --- a/command/agentproxyshared/cache/cachememdb/cache_memdb.go +++ b/command/agentproxyshared/cache/cachememdb/cache_memdb.go @@ -237,6 +237,28 @@ func (c *CacheMemDB) SetCapabilitiesIndex(index *CapabilitiesIndex) error { return nil } +// EvictCapabilitiesIndex removes a capabilities index from the cache based on index name and value. +func (c *CacheMemDB) EvictCapabilitiesIndex(indexName string, indexValues ...interface{}) error { + index, err := c.GetCapabilitiesIndex(indexName, indexValues...) + if err == ErrCacheItemNotFound { + return nil + } + if err != nil { + return fmt.Errorf("unable to fetch index on cache deletion: %v", err) + } + + txn := c.db.Load().(*memdb.MemDB).Txn(true) + defer txn.Abort() + + if err := txn.Delete(tableNameCapabilitiesIndexer, index); err != nil { + return fmt.Errorf("unable to delete index from cache: %v", err) + } + + txn.Commit() + + return nil +} + // GetByPrefix returns all the cached indexes based on the index name and the // value prefix. 
func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) { diff --git a/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go b/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go index 47fa75ee54..63959141d6 100644 --- a/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go +++ b/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go @@ -7,6 +7,8 @@ import ( "context" "testing" + "github.com/stretchr/testify/require" + "github.com/go-test/deep" ) @@ -393,3 +395,92 @@ func TestCacheMemDB_Flush(t *testing.T) { t.Fatalf("expected cache to be empty, got = %v", out) } } + +// TestCacheMemDB_EvictCapabilitiesIndex tests EvictCapabilitiesIndex works as expected. +func TestCacheMemDB_EvictCapabilitiesIndex(t *testing.T) { + cache, err := New() + require.Nil(t, err) + + // Test on empty cache + err = cache.EvictCapabilitiesIndex(IndexNameID, "foo") + require.Nil(t, err) + + capabilitiesIndex := &CapabilitiesIndex{ + ID: "id", + Token: "token", + } + + err = cache.SetCapabilitiesIndex(capabilitiesIndex) + require.Nil(t, err) + + err = cache.EvictCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID) + require.Nil(t, err) + + // Verify that the cache doesn't contain the entry anymore + index, err := cache.GetCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID) + require.Equal(t, ErrCacheItemNotFound, err) + require.Nil(t, index) +} + +// TestCacheMemDB_GetCapabilitiesIndex tests GetCapabilitiesIndex works as expected. +func TestCacheMemDB_GetCapabilitiesIndex(t *testing.T) { + cache, err := New() + require.Nil(t, err) + + capabilitiesIndex := &CapabilitiesIndex{ + ID: "id", + Token: "token", + } + + err = cache.SetCapabilitiesIndex(capabilitiesIndex) + require.Nil(t, err) + + // Verify that we can retrieve the index + index, err := cache.GetCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID) + require.Nil(t, err) + require.Equal(t, capabilitiesIndex, index) + + // Verify behaviour on a non-existing ID + index, err = cache.GetCapabilitiesIndex(IndexNameID, "not a real id") + require.Equal(t, ErrCacheItemNotFound, err) + require.Nil(t, index) + + // Verify behaviour with a non-existing index name + index, err = cache.GetCapabilitiesIndex("not a real name", capabilitiesIndex.ID) + require.NotNil(t, err) +} + +// TestCacheMemDB_SetCapabilitiesIndex tests SetCapabilitiesIndex works as expected. 
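+// It also verifies error behaviour for nil indexes and for indexes missing an ID.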
+func TestCacheMemDB_SetCapabilitiesIndex(t *testing.T) { + cache, err := New() + require.Nil(t, err) + + capabilitiesIndex := &CapabilitiesIndex{ + ID: "id", + Token: "token", + } + + err = cache.SetCapabilitiesIndex(capabilitiesIndex) + require.Nil(t, err) + + // Verify we can retrieve the index + index, err := cache.GetCapabilitiesIndex(IndexNameID, capabilitiesIndex.ID) + require.Nil(t, err) + require.Equal(t, capabilitiesIndex, index) + + // Verify behaviour on a nil index + err = cache.SetCapabilitiesIndex(nil) + require.NotNil(t, err) + + // Verify behaviour on an index without id + err = cache.SetCapabilitiesIndex(&CapabilitiesIndex{ + Token: "token", + }) + require.NotNil(t, err) + + // Verify behaviour on an index with only ID + err = cache.SetCapabilitiesIndex(&CapabilitiesIndex{ + ID: "id", + }) + require.Nil(t, err) +} diff --git a/command/agentproxyshared/cache/cachememdb/index.go b/command/agentproxyshared/cache/cachememdb/index.go index 3a602cab6c..348688d7ad 100644 --- a/command/agentproxyshared/cache/cachememdb/index.go +++ b/command/agentproxyshared/cache/cachememdb/index.go @@ -198,3 +198,22 @@ func Deserialize(indexBytes []byte) (*Index, error) { } return index, nil } + +// SerializeCapabilitiesIndex returns a JSON-marshaled CapabilitiesIndex object +func (i CapabilitiesIndex) SerializeCapabilitiesIndex() ([]byte, error) { + indexBytes, err := json.Marshal(i) + if err != nil { + return nil, err + } + + return indexBytes, nil +} + +// DeserializeCapabilitiesIndex converts JSON bytes to a CapabilitiesIndex object +func DeserializeCapabilitiesIndex(indexBytes []byte) (*CapabilitiesIndex, error) { + index := new(CapabilitiesIndex) + if err := json.Unmarshal(indexBytes, index); err != nil { + return nil, err + } + return index, nil +} diff --git a/command/agentproxyshared/cache/lease_cache.go b/command/agentproxyshared/cache/lease_cache.go index 38ca6b4b0d..0db186d580 100644 --- a/command/agentproxyshared/cache/lease_cache.go +++ b/command/agentproxyshared/cache/lease_cache.go @@ -102,6 +102,10 @@ type LeaseCache struct { // cacheStaticSecrets is used to determine if the cache should also // cache static secrets, as well as dynamic secrets. cacheStaticSecrets bool + + // capabilityManager is used when static secrets are enabled to + // manage the capabilities of cached tokens. + capabilityManager *StaticSecretCapabilityManager } // LeaseCacheConfig is the configuration for initializing a new @@ -168,9 +172,22 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { }, nil } +// SetCapabilityManager is a setter for capabilityManager. If set, it will manage +// capabilities for cached capability indexes. +func (c *LeaseCache) SetCapabilityManager(capabilityManager *StaticSecretCapabilityManager) { + c.capabilityManager = capabilityManager +} + // SetShuttingDown is a setter for the shuttingDown field func (c *LeaseCache) SetShuttingDown(in bool) { c.shuttingDown.Store(in) + + // Since we're shutting down, also stop the capability manager's jobs. + // We can do this forcibly since there's no reason to update + // the cache when we're shutting down.
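+ // Stopping the manager also prevents any queued capability refresh jobs from being resubmitted to the worker pool.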
+ if c.capabilityManager != nil { + c.capabilityManager.Stop() + } } // SetPersistentStorage is a setter for the persistent storage field in @@ -628,7 +645,7 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques return err } - capabilitiesIndex, err := c.retrieveOrCreateTokenCapabilitiesEntry(req.Token) + capabilitiesIndex, created, err := c.retrieveOrCreateTokenCapabilitiesEntry(req.Token) if err != nil { c.logger.Error("failed to cache the proxied response", "error", err) return err } @@ -644,27 +661,35 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques // update the index with the new capability: capabilitiesIndex.ReadablePaths[path] = struct{}{} - err = c.db.SetCapabilitiesIndex(capabilitiesIndex) + err = c.SetCapabilitiesIndex(ctx, capabilitiesIndex) if err != nil { c.logger.Error("failed to cache token capabilities as part of caching the proxied response", "error", err) return err } + // Lastly, ensure that we start renewing this index, if it's new. + // We require the 'created' check so that we don't renew the same + // index multiple times. + if c.capabilityManager != nil && created { + c.capabilityManager.StartRenewingCapabilities(capabilitiesIndex) + } + return nil } // retrieveOrCreateTokenCapabilitiesEntry will either retrieve the token // capabilities entry from the cache, or create a new, empty one. -func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, error) { +// The returned bool indicates whether a new token capabilities entry was created. +func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, bool, error) { // The index ID is a hash of the token. indexId := hashStaticSecretIndex(token) indexFromCache, err := c.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) if err != nil && err != cachememdb.ErrCacheItemNotFound { - return nil, err + return nil, false, err } if indexFromCache != nil { - return indexFromCache, nil + return indexFromCache, false, nil } // Build the index to cache based on the response received @@ -674,7 +699,7 @@ func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cach ReadablePaths: make(map[string]struct{}), } - return index, nil + return index, true, nil } func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo { @@ -1266,6 +1291,28 @@ func (c *LeaseCache) Set(ctx context.Context, index *cachememdb.Index) error { return nil } +// SetCapabilitiesIndex stores the capabilities index in the cachememdb, and also stores it in the persistent +// cache (if enabled) +func (c *LeaseCache) SetCapabilitiesIndex(ctx context.Context, index *cachememdb.CapabilitiesIndex) error { + if err := c.db.SetCapabilitiesIndex(index); err != nil { + return err + } + + if c.ps != nil { + plaintext, err := index.SerializeCapabilitiesIndex() + if err != nil { + return err + } + + if err := c.ps.Set(ctx, index.ID, plaintext, cacheboltdb.TokenCapabilitiesType); err != nil { + return err + } + c.logger.Trace("set entry in persistent storage", "type", cacheboltdb.TokenCapabilitiesType, "id", index.ID) + } + + return nil +} + // Evict removes an Index from the cachememdb, and also removes it from the // persistent cache (if enabled) func (c *LeaseCache) Evict(index *cachememdb.Index) error { @@ -1300,6 +1347,8 @@ func (c *LeaseCache) Flush() error { // Restore loads the cachememdb from the persistent storage passed in.
Loads // tokens first, since restoring a lease's renewal context and watcher requires // looking up the token in the cachememdb. +// Restore also restarts any capability management for managed static secret +// tokens. func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStorage) error { var errs *multierror.Error @@ -1348,6 +1397,51 @@ func (c *LeaseCache) Restore(ctx context.Context, storage *cacheboltdb.BoltStora } } + // Then process static secrets and their capabilities + if c.cacheStaticSecrets { + staticSecrets, err := storage.GetByType(ctx, cacheboltdb.StaticSecretType) + if err != nil { + errs = multierror.Append(errs, err) + } else { + for _, staticSecret := range staticSecrets { + newIndex, err := cachememdb.Deserialize(staticSecret) + if err != nil { + errs = multierror.Append(errs, err) + continue + } + + c.logger.Trace("restoring static secret index", "id", newIndex.ID, "path", newIndex.RequestPath) + if err := c.db.Set(newIndex); err != nil { + errs = multierror.Append(errs, err) + continue + } + } + } + + capabilityIndexes, err := storage.GetByType(ctx, cacheboltdb.TokenCapabilitiesType) + if err != nil { + errs = multierror.Append(errs, err) + } else { + for _, capabilityIndex := range capabilityIndexes { + newIndex, err := cachememdb.DeserializeCapabilitiesIndex(capabilityIndex) + if err != nil { + errs = multierror.Append(errs, err) + continue + } + + c.logger.Trace("restoring capability index", "id", newIndex.ID) + if err := c.db.SetCapabilitiesIndex(newIndex); err != nil { + errs = multierror.Append(errs, err) + continue + } + + if c.capabilityManager != nil { + c.capabilityManager.StartRenewingCapabilities(newIndex) + } + } + } + } + return errs.ErrorOrNil() } diff --git a/command/agentproxyshared/cache/static_secret_capability_manager.go b/command/agentproxyshared/cache/static_secret_capability_manager.go new file mode 100644 index 0000000000..46de4740dd --- /dev/null +++ b/command/agentproxyshared/cache/static_secret_capability_manager.go @@ -0,0 +1,261 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + "time" + + "github.com/gammazero/workerpool" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb" + "github.com/mitchellh/mapstructure" + "golang.org/x/exp/maps" +) + +const ( + // DefaultWorkers is the default number of workers for the worker pool. + DefaultWorkers = 5 + + // DefaultStaticSecretTokenCapabilityRefreshInterval is the default time + // between each capability poll. This is configured with the following config value: + // static_secret_token_capability_refresh_interval + DefaultStaticSecretTokenCapabilityRefreshInterval = 5 * time.Minute +) + +// StaticSecretCapabilityManager is a struct that utilizes +// a worker pool to keep capabilities up to date. +type StaticSecretCapabilityManager struct { + client *api.Client + leaseCache *LeaseCache + logger hclog.Logger + workerPool *workerpool.WorkerPool + staticSecretTokenCapabilityRefreshInterval time.Duration +} + +// StaticSecretCapabilityManagerConfig is the configuration for initializing a new +// StaticSecretCapabilityManager. 
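+// A minimal construction sketch (field values here are illustrative, not prescribed defaults): +// +// sscm, err := NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ +// LeaseCache: leaseCache, +// Logger: logger, +// Client: client, +// StaticSecretTokenCapabilityRefreshInterval: 15 * time.Minute, +// })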
+type StaticSecretCapabilityManagerConfig struct { + LeaseCache *LeaseCache + Logger hclog.Logger + Client *api.Client + StaticSecretTokenCapabilityRefreshInterval time.Duration +} + +// NewStaticSecretCapabilityManager creates a new instance of a StaticSecretCapabilityManager. +func NewStaticSecretCapabilityManager(conf *StaticSecretCapabilityManagerConfig) (*StaticSecretCapabilityManager, error) { + if conf == nil { + return nil, errors.New("nil configuration provided") + } + + if conf.LeaseCache == nil { + return nil, fmt.Errorf("nil Lease Cache (a required parameter): %v", conf) + } + + if conf.Logger == nil { + return nil, fmt.Errorf("nil Logger (a required parameter): %v", conf) + } + + if conf.Client == nil { + return nil, fmt.Errorf("nil Client (a required parameter): %v", conf) + } + + if conf.StaticSecretTokenCapabilityRefreshInterval == 0 { + conf.StaticSecretTokenCapabilityRefreshInterval = DefaultStaticSecretTokenCapabilityRefreshInterval + } + + workerPool := workerpool.New(DefaultWorkers) + + return &StaticSecretCapabilityManager{ + client: conf.Client, + leaseCache: conf.LeaseCache, + logger: conf.Logger, + workerPool: workerPool, + staticSecretTokenCapabilityRefreshInterval: conf.StaticSecretTokenCapabilityRefreshInterval, + }, nil +} + +// submitWorkToPoolAfterInterval submits work to the pool after the defined +// staticSecretTokenCapabilityRefreshInterval +func (sscm *StaticSecretCapabilityManager) submitWorkToPoolAfterInterval(work func()) { + time.AfterFunc(sscm.staticSecretTokenCapabilityRefreshInterval, func() { + if !sscm.workerPool.Stopped() { + sscm.workerPool.Submit(work) + } + }) +} + +// Stop stops all ongoing jobs and ensures future jobs will not +// get added to the worker pool. +func (sscm *StaticSecretCapabilityManager) Stop() { + sscm.workerPool.Stop() +} + +// StartRenewingCapabilities submits a recurring polling job to the worker pool that keeps a token's capabilities up to date. +// indexToRenew is the capabilities index we'll renew the capabilities for. +func (sscm *StaticSecretCapabilityManager) StartRenewingCapabilities(indexToRenew *cachememdb.CapabilitiesIndex) { + var work func() + work = func() { + if sscm.workerPool.Stopped() { + sscm.logger.Trace("worker pool stopped, stopping renewal") + return + } + + capabilitiesIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexToRenew.ID) + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { + // This cache entry no longer exists, so there is no more work to do.
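+ // (For example, because the index was evicted after the token lost read access to every path.)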
+ sscm.logger.Trace("cache item not found for capabilities refresh, stopping the process") + return + } + if err != nil { + sscm.logger.Error("error when attempting to get capabilities index to refresh token capabilities", "indexToRenew.ID", indexToRenew.ID, "err", err) + sscm.submitWorkToPoolAfterInterval(work) + return + } + + capabilitiesIndex.IndexLock.RLock() + token := capabilitiesIndex.Token + indexReadablePathsMap := capabilitiesIndex.ReadablePaths + capabilitiesIndex.IndexLock.RUnlock() + indexReadablePaths := maps.Keys(indexReadablePathsMap) + + client, err := sscm.client.Clone() + if err != nil { + sscm.logger.Error("error when attempting clone client to refresh token capabilities", "indexToRenew.ID", indexToRenew.ID, "err", err) + sscm.submitWorkToPoolAfterInterval(work) + return + } + + client.SetToken(token) + + capabilities, err := getCapabilities(indexReadablePaths, client) + if err != nil { + sscm.logger.Error("error when attempting to retrieve updated token capabilities", "indexToRenew.ID", indexToRenew.ID, "err", err) + sscm.submitWorkToPoolAfterInterval(work) + return + } + + newReadablePaths := reconcileCapabilities(indexReadablePaths, capabilities) + if maps.Equal(indexReadablePathsMap, newReadablePaths) { + sscm.logger.Trace("capabilities were the same for index, nothing to do", "indexToRenew.ID", indexToRenew.ID) + // there's nothing to update! + sscm.submitWorkToPoolAfterInterval(work) + return + } + + // before updating or evicting the index, we must update the tokens on + // for each path, update the corresponding index with the diff + for _, path := range indexReadablePaths { + // If the old path isn't contained in the new readable paths, + // we must delete it from the tokens map for its corresponding + // path index. + if _, ok := newReadablePaths[path]; !ok { + indexId := hashStaticSecretIndex(path) + index, err := sscm.leaseCache.db.Get(cachememdb.IndexNameID, indexId) + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { + // Nothing to update! + continue + } + if err != nil { + sscm.logger.Error("error when attempting to update corresponding paths for capabilities index", "indexToRenew.ID", indexToRenew.ID, "err", err) + sscm.submitWorkToPoolAfterInterval(work) + return + } + sscm.logger.Trace("updating tokens for index, as capability has been lost", "index.ID", index.ID, "request_path", index.RequestPath) + index.IndexLock.Lock() + delete(index.Tokens, capabilitiesIndex.Token) + err = sscm.leaseCache.Set(context.Background(), index) + if err != nil { + sscm.logger.Error("error when attempting to update index in cache", "index.ID", index.ID, "err", err) + } + index.IndexLock.Unlock() + } + } + + // Lastly, we should update the capabilities index, either evicting or updating it + capabilitiesIndex.IndexLock.Lock() + defer capabilitiesIndex.IndexLock.Unlock() + if len(newReadablePaths) == 0 { + err := sscm.leaseCache.db.EvictCapabilitiesIndex(cachememdb.IndexNameID, indexToRenew.ID) + if err != nil { + sscm.logger.Error("error when attempting to evict capabilities from cache", "index.ID", indexToRenew.ID, "err", err) + sscm.submitWorkToPoolAfterInterval(work) + return + } + // If we successfully evicted the index, no need to re-submit the work to the pool. 
+ return + } + + // The token still has some capabilities, so update the capabilities index: + capabilitiesIndex.ReadablePaths = newReadablePaths + err = sscm.leaseCache.SetCapabilitiesIndex(context.Background(), capabilitiesIndex) + if err != nil { + sscm.logger.Error("error when attempting to update capabilities from cache", "index.ID", indexToRenew.ID, "err", err) + } + + // Finally, put ourselves back on the work pool after the refresh interval + sscm.submitWorkToPoolAfterInterval(work) + return + } + + sscm.submitWorkToPoolAfterInterval(work) +} + +// getCapabilities is a wrapper around a /sys/capabilities-self call that returns +// capabilities as a map with paths as keys, and capabilities as values. +func getCapabilities(paths []string, client *api.Client) (map[string][]string, error) { + body := make(map[string]interface{}) + body["paths"] = paths + capabilities := make(map[string][]string) + + secret, err := client.Logical().Write("sys/capabilities-self", body) + if err != nil && strings.Contains(err.Error(), "permission denied") { + // The token is invalid or has expired. Return an empty set of capabilities: + return capabilities, nil + } + if err != nil { + return nil, err + } + + if secret == nil || secret.Data == nil { + return nil, errors.New("data from server response is empty") + } + + for _, path := range paths { + var res []string + err = mapstructure.Decode(secret.Data[path], &res) + if err != nil { + return nil, err + } + + capabilities[path] = res + } + + return capabilities, nil +} + +// reconcileCapabilities takes a set of known readable paths, and a set of capabilities (a response from the +// sys/capabilities-self endpoint) and returns the subset of readablePaths that remain readable, +// represented as a set: a map of strings to empty structs. +// It will delete any path in readablePaths if it does not have a "root" or "read" capability listed in the +// capabilities map. +func reconcileCapabilities(readablePaths []string, capabilities map[string][]string) map[string]struct{} { + newReadablePaths := make(map[string]struct{}) + for pathName, permissions := range capabilities { + if slices.Contains(permissions, "read") || slices.Contains(permissions, "root") { + // We do this as an additional sanity check. We never want to + // add permissions that weren't there before. + if slices.Contains(readablePaths, pathName) { + newReadablePaths[pathName] = struct{}{} + } + } + } + + return newReadablePaths +} diff --git a/command/agentproxyshared/cache/static_secret_capability_manager_test.go b/command/agentproxyshared/cache/static_secret_capability_manager_test.go new file mode 100644 index 0000000000..0bbb5973a8 --- /dev/null +++ b/command/agentproxyshared/cache/static_secret_capability_manager_test.go @@ -0,0 +1,432 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "testing" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb" + "github.com/hashicorp/vault/helper/testhelpers/minimal" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/stretchr/testify/require" +) + +// testNewStaticSecretCapabilityManager returns a new StaticSecretCapabilityManager +// for use in tests.
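+// The helper uses a 250ms refresh interval so that tests exercise several refresh cycles quickly.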
+func testNewStaticSecretCapabilityManager(t *testing.T, client *api.Client) *StaticSecretCapabilityManager { + t.Helper() + + lc := testNewLeaseCache(t, []*SendResponse{}) + + updater, err := NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ + LeaseCache: lc, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager"), + Client: client, + StaticSecretTokenCapabilityRefreshInterval: 250 * time.Millisecond, + }) + if err != nil { + t.Fatal(err) + } + + return updater +} + +// TestNewStaticSecretCapabilityManager tests the NewStaticSecretCapabilityManager method, +// to ensure it errors out when appropriate. +func TestNewStaticSecretCapabilityManager(t *testing.T) { + t.Parallel() + + lc := testNewLeaseCache(t, []*SendResponse{}) + logger := logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager") + client, err := api.NewClient(api.DefaultConfig()) + require.Nil(t, err) + + // Expect an error if any of the arguments are nil: + updater, err := NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ + LeaseCache: nil, + Logger: logger, + Client: client, + }) + require.Error(t, err) + require.Nil(t, updater) + + updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ + LeaseCache: lc, + Logger: nil, + Client: client, + }) + require.Error(t, err) + require.Nil(t, updater) + + updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ + LeaseCache: lc, + Logger: logger, + Client: nil, + }) + require.Error(t, err) + require.Nil(t, updater) + + // Don't expect an error if the arguments are as expected + updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ + LeaseCache: lc, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager"), + Client: client, + }) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, updater) + require.NotNil(t, updater.workerPool) + require.NotNil(t, updater.staticSecretTokenCapabilityRefreshInterval) + require.NotNil(t, updater.client) + require.NotNil(t, updater.leaseCache) + require.NotNil(t, updater.logger) + require.Equal(t, DefaultStaticSecretTokenCapabilityRefreshInterval, updater.staticSecretTokenCapabilityRefreshInterval) + + // Lastly, double check that the refresh interval can be properly set + updater, err = NewStaticSecretCapabilityManager(&StaticSecretCapabilityManagerConfig{ + LeaseCache: lc, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.capabilitiesmanager"), + Client: client, + StaticSecretTokenCapabilityRefreshInterval: time.Hour, + }) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, updater) + require.NotNil(t, updater.workerPool) + require.NotNil(t, updater.staticSecretTokenCapabilityRefreshInterval) + require.NotNil(t, updater.client) + require.NotNil(t, updater.leaseCache) + require.NotNil(t, updater.logger) + require.Equal(t, time.Hour, updater.staticSecretTokenCapabilityRefreshInterval) +} + +// TestGetCapabilitiesRootToken tests the getCapabilities method with the root +// token, expecting to get "root" capabilities on valid paths +func TestGetCapabilitiesRootToken(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + capabilitiesToCheck := []string{"auth/token/create", "sys/health"} + capabilities, err := getCapabilities(capabilitiesToCheck, client) + require.NoError(t, err) + + expectedCapabilities := map[string][]string{ + "auth/token/create": {"root"}, + "sys/health": {"root"}, 
+ } + require.Equal(t, expectedCapabilities, capabilities) +} + +// TestGetCapabilitiesLowPrivilegeToken tests the getCapabilities method with +// a low privilege token, expecting to get deny or non-root capabilities +func TestGetCapabilitiesLowPrivilegeToken(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + renewable := true + // Set the token's policies to 'default' and nothing else + tokenCreateRequest := &api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + Renewable: &renewable, + } + + secret, err := client.Auth().Token().CreateOrphan(tokenCreateRequest) + require.NoError(t, err) + token := secret.Auth.ClientToken + + client.SetToken(token) + + capabilitiesToCheck := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"} + capabilities, err := getCapabilities(capabilitiesToCheck, client) + require.NoError(t, err) + + expectedCapabilities := map[string][]string{ + "auth/token/create": {"deny"}, + "sys/capabilities-self": {"update"}, + "auth/token/lookup-self": {"read"}, + } + require.Equal(t, expectedCapabilities, capabilities) +} + +// TestGetCapabilitiesBadClientToken tests that getCapabilities +// returns an empty set of capabilities if the token is bad (and it gets a 403) +func TestGetCapabilitiesBadClientToken(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + client.SetToken("") + + capabilitiesToCheck := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"} + capabilities, err := getCapabilities(capabilitiesToCheck, client) + require.Nil(t, err) + require.Equal(t, map[string][]string{}, capabilities) +} + +// TestGetCapabilitiesEmptyPaths tests that getCapabilities will error on an empty +// set of paths to check +func TestGetCapabilitiesEmptyPaths(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + var capabilitiesToCheck []string + _, err := getCapabilities(capabilitiesToCheck, client) + require.Error(t, err) +} + +// TestReconcileCapabilities tests that reconcileCapabilities will +// correctly remove previously readable paths that we no longer have read access to. +func TestReconcileCapabilities(t *testing.T) { + t.Parallel() + paths := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"} + capabilities := map[string][]string{ + "auth/token/create": {"deny"}, + "sys/capabilities-self": {"update"}, + "auth/token/lookup-self": {"read"}, + } + + updatedCapabilities := reconcileCapabilities(paths, capabilities) + expectedUpdatedCapabilities := map[string]struct{}{ + "auth/token/lookup-self": {}, + } + require.Equal(t, expectedUpdatedCapabilities, updatedCapabilities) +} + +// TestReconcileCapabilitiesNoOp tests that reconcileCapabilities will +// correctly not remove capabilities when they all remain readable.
+func TestReconcileCapabilitiesNoOp(t *testing.T) { + t.Parallel() + paths := []string{"foo/bar", "bar/baz", "baz/foo"} + capabilities := map[string][]string{ + "foo/bar": {"read"}, + "bar/baz": {"root"}, + "baz/foo": {"read"}, + } + + updatedCapabilities := reconcileCapabilities(paths, capabilities) + expectedUpdatedCapabilities := map[string]struct{}{ + "foo/bar": {}, + "bar/baz": {}, + "baz/foo": {}, + } + require.Equal(t, expectedUpdatedCapabilities, updatedCapabilities) +} + +// TestReconcileCapabilitiesNoAdding tests that reconcileCapabilities will +// not add any capabilities that weren't present in the first argument to the function +func TestReconcileCapabilitiesNoAdding(t *testing.T) { + t.Parallel() + paths := []string{"auth/token/create", "sys/capabilities-self", "auth/token/lookup-self"} + capabilities := map[string][]string{ + "auth/token/create": {"deny"}, + "sys/capabilities-self": {"update"}, + "auth/token/lookup-self": {"read"}, + "some/new/path": {"read"}, + } + + updatedCapabilities := reconcileCapabilities(paths, capabilities) + expectedUpdatedCapabilities := map[string]struct{}{ + "auth/token/lookup-self": {}, + } + require.Equal(t, expectedUpdatedCapabilities, updatedCapabilities) +} + +// TestSubmitWorkNoOp tests that we will gracefully end if the capabilities index +// does not exist in the cache +func TestSubmitWorkNoOp(t *testing.T) { + t.Parallel() + client, err := api.NewClient(api.DefaultConfig()) + require.Nil(t, err) + sscm := testNewStaticSecretCapabilityManager(t, client) + // This index will be a no-op, as this does not exist in the cache + index := &cachememdb.CapabilitiesIndex{ + ID: "test", + } + sscm.StartRenewingCapabilities(index) + + // Wait for the job to complete... + time.Sleep(1 * time.Second) + require.Equal(t, 0, sscm.workerPool.WaitingQueueSize()) +} + +// TestSubmitWorkUpdatesIndex tests that an index will be correctly updated if the capabilities differ. +func TestSubmitWorkUpdatesIndex(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + // Create a low permission token + renewable := true + // Set the token's policies to 'default' and nothing else + tokenCreateRequest := &api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + Renewable: &renewable, + } + + secret, err := client.Auth().Token().CreateOrphan(tokenCreateRequest) + require.NoError(t, err) + token := secret.Auth.ClientToken + indexId := hashStaticSecretIndex(token) + + sscm := testNewStaticSecretCapabilityManager(t, client) + index := &cachememdb.CapabilitiesIndex{ + ID: indexId, + Token: token, + // The token will (perhaps obviously) not have + // read access to /foo/bar, but will to /auth/token/lookup-self + ReadablePaths: map[string]struct{}{ + "foo/bar": {}, + "auth/token/lookup-self": {}, + }, + } + err = sscm.leaseCache.db.SetCapabilitiesIndex(index) + require.Nil(t, err) + + sscm.StartRenewingCapabilities(index) + + // Wait for the job to complete at least once... + time.Sleep(3 * time.Second) + + newIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) + require.Nil(t, err) + require.Equal(t, map[string]struct{}{ + "auth/token/lookup-self": {}, + }, newIndex.ReadablePaths) + + // Forcefully stop any remaining workers + sscm.workerPool.Stop() +} + +// TestSubmitWorkUpdatesIndexWithBadToken tests that an index will be correctly updated if the token +// has expired and we cannot access the sys capabilities endpoint. 
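+// Since getCapabilities treats a 403 as an empty set of capabilities, the whole index should be evicted.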
+func TestSubmitWorkUpdatesIndexWithBadToken(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + token := "not real token" + indexId := hashStaticSecretIndex(token) + + sscm := testNewStaticSecretCapabilityManager(t, client) + index := &cachememdb.CapabilitiesIndex{ + ID: indexId, + Token: token, + ReadablePaths: map[string]struct{}{ + "foo/bar": {}, + "auth/token/lookup-self": {}, + }, + } + err := sscm.leaseCache.db.SetCapabilitiesIndex(index) + require.Nil(t, err) + + sscm.StartRenewingCapabilities(index) + + // Wait for the job to complete at least once... + time.Sleep(3 * time.Second) + + // This entry should be evicted. + newIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) + require.Equal(t, err, cachememdb.ErrCacheItemNotFound) + require.Nil(t, newIndex) + + // Forcefully stop any remaining workers + sscm.workerPool.Stop() +} + +// TestSubmitWorkUpdatesAllIndexes tests that an index will be correctly updated if the capabilities differ, as +// well as the indexes related to the paths that are being checked for. +func TestSubmitWorkUpdatesAllIndexes(t *testing.T) { + t.Parallel() + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + // Create a low permission token + renewable := true + // Set the token's policies to 'default' and nothing else + tokenCreateRequest := &api.TokenCreateRequest{ + Policies: []string{"default"}, + TTL: "30m", + Renewable: &renewable, + } + + secret, err := client.Auth().Token().CreateOrphan(tokenCreateRequest) + require.NoError(t, err) + token := secret.Auth.ClientToken + indexId := hashStaticSecretIndex(token) + + sscm := testNewStaticSecretCapabilityManager(t, client) + index := &cachememdb.CapabilitiesIndex{ + ID: indexId, + Token: token, + // The token will (perhaps obviously) not have + // read access to /foo/bar, but will to /auth/token/lookup-self + ReadablePaths: map[string]struct{}{ + "foo/bar": {}, + "auth/token/lookup-self": {}, + }, + } + err = sscm.leaseCache.db.SetCapabilitiesIndex(index) + require.Nil(t, err) + + pathIndexId1 := hashStaticSecretIndex("foo/bar") + pathIndex1 := &cachememdb.Index{ + ID: pathIndexId1, + Namespace: "root/", + Tokens: map[string]struct{}{ + token: {}, + }, + RequestPath: "foo/bar", + Response: []byte{}, + } + + pathIndexId2 := hashStaticSecretIndex("auth/token/lookup-self") + pathIndex2 := &cachememdb.Index{ + ID: pathIndexId2, + Namespace: "root/", + Tokens: map[string]struct{}{ + token: {}, + }, + RequestPath: "auth/token/lookup-self", + Response: []byte{}, + } + + err = sscm.leaseCache.db.Set(pathIndex1) + require.Nil(t, err) + + err = sscm.leaseCache.db.Set(pathIndex2) + require.Nil(t, err) + + sscm.StartRenewingCapabilities(index) + + // Wait for the job to complete at least once... 
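+ // (The test helper's refresh interval is 250ms, so one second allows several refresh cycles.)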
+ time.Sleep(1 * time.Second) + + newIndex, err := sscm.leaseCache.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) + require.Nil(t, err) + require.Equal(t, map[string]struct{}{ + "auth/token/lookup-self": {}, + }, newIndex.ReadablePaths) + + // For this, we expect the token to have been deleted + newPathIndex1, err := sscm.leaseCache.db.Get(cachememdb.IndexNameID, pathIndexId1) + require.Nil(t, err) + require.Equal(t, map[string]struct{}{}, newPathIndex1.Tokens) + + // For this, we expect no change + newPathIndex2, err := sscm.leaseCache.db.Get(cachememdb.IndexNameID, pathIndexId2) + require.Nil(t, err) + require.Equal(t, pathIndex2, newPathIndex2) + + // Forcefully stop any remaining workers + sscm.workerPool.Stop() +} diff --git a/command/proxy.go b/command/proxy.go index ec5daab603..2d7c74bf05 100644 --- a/command/proxy.go +++ b/command/proxy.go @@ -491,6 +491,18 @@ func (c *ProxyCommand) Run(args []string) int { c.UI.Error(fmt.Sprintf("Error creating static secret cache updater: %v", err)) return 1 } + + capabilityManager, err := cache.NewStaticSecretCapabilityManager(&cache.StaticSecretCapabilityManagerConfig{ + LeaseCache: leaseCache, + Logger: c.logger.Named("cache.staticsecretcapabilitymanager"), + Client: client, + StaticSecretTokenCapabilityRefreshInterval: config.Cache.StaticSecretTokenCapabilityRefreshInterval, + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating static secret capability manager: %v", err)) + return 1 + } + + leaseCache.SetCapabilityManager(capabilityManager) } } diff --git a/command/proxy/config/config.go b/command/proxy/config/config.go index c0afd50d6b..2103fb11ee 100644 --- a/command/proxy/config/config.go +++ b/command/proxy/config/config.go @@ -101,9 +101,11 @@ type APIProxy struct { // Cache contains any configuration needed for Cache mode type Cache struct { - Persist *agentproxyshared.PersistConfig `hcl:"persist"` - InProcDialer transportDialer `hcl:"-"` - CacheStaticSecrets bool `hcl:"cache_static_secrets"` + Persist *agentproxyshared.PersistConfig `hcl:"persist"` + InProcDialer transportDialer `hcl:"-"` + CacheStaticSecrets bool `hcl:"cache_static_secrets"` + StaticSecretTokenCapabilityRefreshIntervalRaw interface{} `hcl:"static_secret_token_capability_refresh_interval"` + StaticSecretTokenCapabilityRefreshInterval time.Duration `hcl:"-"` } // AutoAuth is the configured authentication method and sinks @@ -621,6 +623,14 @@ func parseCache(result *Config, list *ast.ObjectList) error { return fmt.Errorf("error parsing persist: %w", err) } + if result.Cache.StaticSecretTokenCapabilityRefreshIntervalRaw != nil { + var err error + if result.Cache.StaticSecretTokenCapabilityRefreshInterval, err = parseutil.ParseDurationSecond(result.Cache.StaticSecretTokenCapabilityRefreshIntervalRaw); err != nil { + return fmt.Errorf("error parsing static_secret_token_capability_refresh_interval, must be provided as a duration string: %w", err) + } + result.Cache.StaticSecretTokenCapabilityRefreshIntervalRaw = nil + } + + return nil } diff --git a/command/proxy/config/config_test.go b/command/proxy/config/config_test.go index c92e1e1579..114436eb6f 100644 --- a/command/proxy/config/config_test.go +++ b/command/proxy/config/config_test.go @@ -5,6 +5,7 @@ package config import ( "testing" + "time" "github.com/go-test/deep" "github.com/hashicorp/vault/command/agentproxyshared" @@ -130,3 +131,62 @@ func TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth(t *testing.T) { t.Fatalf("expected error, as static secret caching requires auto-auth") } } + +//
TestLoadConfigFile_ProxyCacheStaticSecrets tests loading a config file containing a cache +// as well as a valid proxy config with static secret caching enabled +func TestLoadConfigFile_ProxyCacheStaticSecrets(t *testing.T) { + config, err := LoadConfigFile("./test-fixtures/config-cache-static-secret-cache.hcl") + if err != nil { + t.Fatal(err) + } + + expected := &Config{ + SharedConfig: &configutil.SharedConfig{ + PidFile: "./pidfile", + Listeners: []*configutil.Listener{ + { + Type: "tcp", + Address: "127.0.0.1:8300", + TLSDisable: true, + }, + }, + }, + AutoAuth: &AutoAuth{ + Method: &Method{ + Type: "aws", + MountPath: "auth/aws", + Config: map[string]interface{}{ + "role": "foobar", + }, + }, + Sinks: []*Sink{ + { + Type: "file", + DHType: "curve25519", + DHPath: "/tmp/file-foo-dhpath", + AAD: "foobar", + Config: map[string]interface{}{ + "path": "/tmp/file-foo", + }, + }, + }, + }, + Cache: &Cache{ + CacheStaticSecrets: true, + StaticSecretTokenCapabilityRefreshInterval: 1 * time.Hour, + }, + Vault: &Vault{ + Address: "http://127.0.0.1:1111", + TLSSkipVerify: true, + TLSSkipVerifyRaw: interface{}("true"), + Retry: &Retry{ + NumRetries: 12, + }, + }, + } + + config.Prune() + if diff := deep.Equal(config, expected); diff != nil { + t.Fatal(diff) + } +} diff --git a/command/proxy/config/test-fixtures/config-cache-static-secret-cache.hcl b/command/proxy/config/test-fixtures/config-cache-static-secret-cache.hcl new file mode 100644 index 0000000000..fa395bd8bd --- /dev/null +++ b/command/proxy/config/test-fixtures/config-cache-static-secret-cache.hcl @@ -0,0 +1,38 @@ +# Copyright (c) HashiCorp, Inc. +# SPDX-License-Identifier: BUSL-1.1 + +pid_file = "./pidfile" + +auto_auth { + method { + type = "aws" + config = { + role = "foobar" + } + } + + sink { + type = "file" + config = { + path = "/tmp/file-foo" + } + aad = "foobar" + dh_type = "curve25519" + dh_path = "/tmp/file-foo-dhpath" + } +} + +cache { + cache_static_secrets = true + static_secret_token_capability_refresh_interval = "1h" +} + +listener "tcp" { + address = "127.0.0.1:8300" + tls_disable = true +} + +vault { + address = "http://127.0.0.1:1111" + tls_skip_verify = "true" +} diff --git a/command/proxy_test.go b/command/proxy_test.go index 0ee60e1fa1..b310a41cbd 100644 --- a/command/proxy_test.go +++ b/command/proxy_test.go @@ -1156,6 +1156,178 @@ log_level = "trace" wg.Wait() } +// TestProxy_Cache_StaticSecretPermissionsLost tests that the cache successfully caches a static +// KVv2 secret going through the Proxy, and that once the calling client loses permissions to the secret, +// it can no longer access the cached entry. +func TestProxy_Cache_StaticSecretPermissionsLost(t *testing.T) { + logger := logging.NewVaultLogger(hclog.Trace) + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": logicalKv.VersionedKVFactory, + }, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + + serverClient := cluster.Cores[0].Client + + // Unset the environment variable so that proxy picks up the right test + // cluster address + defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) + os.Unsetenv(api.EnvVaultAddress) + + tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) + defer os.Remove(tokenFileName) + // We need auto-auth so that the event system can run. + // For ease, we use the token file path with the root token.
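+ // Note that static secret caching requires auto-auth to be configured.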
+ autoAuthConfig := fmt.Sprintf(` +auto_auth { + method { + type = "token_file" + config = { + token_file_path = "%s" + } + } +}`, tokenFileName) + + // We make the token capability refresh interval one second, for ease of testing + cacheConfig := ` +cache { + cache_static_secrets = true + static_secret_token_capability_refresh_interval = "1s" +} +` + listenAddr := generateListenerAddress(t) + listenConfig := fmt.Sprintf(` +listener "tcp" { + address = "%s" + tls_disable = true +} +`, listenAddr) + + config := fmt.Sprintf(` +vault { + address = "%s" + tls_skip_verify = true +} +%s +%s +%s +log_level = "trace" +`, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) + configPath := makeTempFile(t, "config.hcl", config) + defer os.Remove(configPath) + + // Start proxy + ui, cmd := testProxyCommand(t, logger) + cmd.startedCh = make(chan struct{}) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + cmd.Run([]string{"-config", configPath}) + wg.Done() + }() + + select { + case <-cmd.startedCh: + case <-time.After(5 * time.Second): + t.Errorf("timeout") + t.Errorf("stdout: %s", ui.OutputWriter.String()) + t.Errorf("stderr: %s", ui.ErrorWriter.String()) + } + + proxyClient, err := api.NewClient(api.DefaultConfig()) + require.Nil(t, err) + proxyClient.SetMaxRetries(0) + err = proxyClient.SetAddress("http://" + listenAddr) + require.Nil(t, err) + + secretData := map[string]interface{}{ + "foo": "bar", + } + + // Mount the KVV2 engine + err = serverClient.Sys().Mount("secret-v2", &api.MountInput{ + Type: "kv-v2", + }) + require.Nil(t, err) + + err = serverClient.Sys().PutPolicy("kv-policy", ` + path "secret-v2/*" { + capabilities = ["update", "read"] + }`) + require.Nil(t, err) + + // Set up a token that we can later revoke: + renewable := true + // Give the token the 'default' and 'kv-policy' policies + tokenCreateRequest := &api.TokenCreateRequest{ + Policies: []string{"default", "kv-policy"}, + TTL: "2s", + Renewable: &renewable, + } + + secret, err := serverClient.Auth().Token().CreateOrphan(tokenCreateRequest) + require.Nil(t, err) + token := secret.Auth.ClientToken + proxyClient.SetToken(token) + + // Create kvv2 secret + _, err = serverClient.KVv2("secret-v2").Put(context.Background(), "my-secret", secretData) + require.Nil(t, err) + + // We use raw requests so we can check the headers for cache hit/miss. + req := proxyClient.NewRequest(http.MethodGet, "/v1/secret-v2/data/my-secret") + resp1, err := proxyClient.RawRequest(req) + require.Nil(t, err) + + cacheValue := resp1.Header.Get("X-Cache") + require.Equal(t, "MISS", cacheValue) + + // We expect this to be a cache hit, with the same value + resp2, err := proxyClient.RawRequest(req) + require.Nil(t, err) + + cacheValue = resp2.Header.Get("X-Cache") + require.Equal(t, "HIT", cacheValue) + + // Lastly, we check to make sure the actual data we received is + // as we expect. We must use ParseSecret due to the raw requests. + secret1, err := api.ParseSecret(resp1.Body) + if err != nil { + t.Fatal(err) + } + data, ok := secret1.Data["data"] + require.True(t, ok) + require.Equal(t, secretData, data) + + secret2, err := api.ParseSecret(resp2.Body) + if err != nil { + t.Fatal(err) + } + data2, ok := secret2.Data["data"] + require.True(t, ok) + // We expect the cached value to match the secret data we originally wrote.
+ require.Equal(t, secretData, data2) + + // Wait for the token to expire, and for the permissions to be revoked + // The TTL on the token was 2s, and the capability refresh is every 1s, + // so this should give us more than enough time! + time.Sleep(5 * time.Second) + kvSecret, err := proxyClient.KVv2("secret-v2").Get(context.Background(), "my-secret") + if err == nil { + t.Fatalf("expected error, but none found, secret:%v, err:%v", kvSecret, err) + } + // Make sure it's a permission denied error + if !strings.Contains(err.Error(), "permission denied") { + t.Fatalf("expected error on GET to secret after token revocation, secret:%v, err:%v", kvSecret, err) + } + + close(cmd.ShutdownCh) + wg.Wait() +} + // TestProxy_ApiProxy_Retry Tests the retry functionalities of Vault Proxy's API Proxy func TestProxy_ApiProxy_Retry(t *testing.T) { //----------------------------------------------------