diff --git a/command/agentproxyshared/cache/cachememdb/cache_memdb.go b/command/agentproxyshared/cache/cachememdb/cache_memdb.go index d95555a765..daa9a747df 100644 --- a/command/agentproxyshared/cache/cachememdb/cache_memdb.go +++ b/command/agentproxyshared/cache/cachememdb/cache_memdb.go @@ -274,7 +274,7 @@ func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ( // Evict removes an index from the cache based on index name and value. func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error { index, err := c.Get(indexName, indexValues...) - if err == ErrCacheItemNotFound { + if errors.Is(err, ErrCacheItemNotFound) { return nil } if err != nil { diff --git a/command/agentproxyshared/cache/lease_cache.go b/command/agentproxyshared/cache/lease_cache.go index 95a5d5ee78..38ca6b4b0d 100644 --- a/command/agentproxyshared/cache/lease_cache.go +++ b/command/agentproxyshared/cache/lease_cache.go @@ -205,25 +205,8 @@ func (c *LeaseCache) checkCacheForStaticSecretRequest(id string, req *SendReques // If a token is provided, it will validate that the token is allowed to retrieve this // cache entry, and return nil if it isn't. func (c *LeaseCache) checkCacheForRequest(id string, req *SendRequest) (*SendResponse, error) { - var token string - if req != nil { - token = req.Token - // HEAD and OPTIONS are included as future-proofing, since neither of those modify the resource either. - if req.Request.Method != http.MethodGet && req.Request.Method != http.MethodHead && req.Request.Method != http.MethodOptions { - // This must be an update to the resource, so we should short-circuit and invalidate the cache - // as we know the cache is now stale. - c.logger.Debug("evicting index from cache, as non-GET received", "id", id, "method", req.Request.Method, "path", req.Request.URL.Path) - err := c.db.Evict(cachememdb.IndexNameID, id) - if err != nil { - return nil, err - } - - return nil, nil - } - } - index, err := c.db.Get(cachememdb.IndexNameID, id) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { return nil, nil } if err != nil { @@ -233,8 +216,17 @@ func (c *LeaseCache) checkCacheForRequest(id string, req *SendRequest) (*SendRes index.IndexLock.RLock() defer index.IndexLock.RUnlock() + var token string + if req != nil { + // Req will be non-nil if we're checking for a static secret. + // Token might still be "" if it's going to an unauthenticated + // endpoint, or similar. For static secrets, we only care about + // requests with tokens attached, as KV is authenticated. + token = req.Token + } + if token != "" { - // This is a static secret check. We need to ensure that this token + // We are checking for a static secret. We need to ensure that this token // has previously demonstrated access to this static secret. // We could check the capabilities cache here, but since these // indexes should be in sync, this saves us an extra cache get. @@ -381,7 +373,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, } // Check if the response for this request is already in the static secret cache - if staticSecretCacheId != "" { + if staticSecretCacheId != "" && req.Request.Method == http.MethodGet { cachedResp, err = c.checkCacheForStaticSecretRequest(staticSecretCacheId, req) if err != nil { return nil, err @@ -446,7 +438,9 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, // There shouldn't be a situation where secret.MountType == "kv" and // staticSecretCacheId == "", but just in case. - if c.cacheStaticSecrets && secret.MountType == "kv" && staticSecretCacheId != "" { + // We restrict this to GETs as those are all we want to cache. + if c.cacheStaticSecrets && secret.MountType == "kv" && + staticSecretCacheId != "" && req.Request.Method == http.MethodGet { index.Type = cacheboltdb.StaticSecretType index.ID = staticSecretCacheId err := c.cacheStaticSecret(ctx, req, resp, index) @@ -475,7 +469,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, case secret.LeaseID != "": c.logger.Debug("processing lease response", "method", req.Request.Method, "path", req.Request.URL.Path) entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { // If the lease belongs to a token that is not managed by the lease cache, // return the response without caching it. c.logger.Debug("pass-through lease response; token not managed by lease cache", "method", req.Request.Method, "path", req.Request.URL.Path) @@ -501,7 +495,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, var parentCtx context.Context if !secret.Auth.Orphan { entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { // If the lease belongs to a token that is not managed by the lease cache, // return the response without caching it. c.logger.Debug("pass-through lease response; parent token not managed by lease cache", "method", req.Request.Method, "path", req.Request.URL.Path) @@ -564,7 +558,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, if index.Type != cacheboltdb.StaticSecretType { // Store the index in the cache - c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path) + c.logger.Debug("storing dynamic secret response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path, "id", index.ID) err = c.Set(ctx, index) if err != nil { c.logger.Error("failed to cache the proxied response", "error", err) @@ -587,7 +581,7 @@ func (c *LeaseCache) cacheStaticSecret(ctx context.Context, req *SendRequest, re } // The index already exists, so all we need to do is add our token - // to the index's allowed token list, then re-store it + // to the index's allowed token list, then re-store it. if indexFromCache != nil { // We must hold a lock for the index while it's being updated. // We keep the two locking mechanisms distinct, so that it's only writes @@ -627,7 +621,7 @@ func (c *LeaseCache) cacheStaticSecret(ctx context.Context, req *SendRequest, re func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendRequest, index *cachememdb.Index) error { // Store the index in the cache - c.logger.Debug("storing response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path) + c.logger.Debug("storing static secret response into the cache", "method", req.Request.Method, "path", req.Request.URL.Path, "id", index.ID) err := c.Set(ctx, index) if err != nil { c.logger.Error("failed to cache the proxied response", "error", err) @@ -663,7 +657,7 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques // capabilities entry from the cache, or create a new, empty one. func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, error) { // The index ID is a hash of the token. - indexId := hex.EncodeToString(cryptoutil.Blake2b256Hash(token)) + indexId := hashStaticSecretIndex(token) indexFromCache, err := c.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) if err != nil && err != cachememdb.ErrCacheItemNotFound { return nil, err @@ -860,6 +854,12 @@ func getStaticSecretPathFromRequest(req *SendRequest) string { return canonicalizeStaticSecretPath(path, namespace) } +// hashStaticSecretIndex is a simple function that hashes the path into +// a function. This is kept as a helper function for ease of use by downstream functions. +func hashStaticSecretIndex(unhashedIndex string) string { + return hex.EncodeToString(cryptoutil.Blake2b256Hash(unhashedIndex)) +} + // computeStaticSecretCacheIndex results in a value that uniquely identifies a static // secret's cached ID. Notably, we intentionally ignore headers (for example, // the X-Vault-Token header) to remain agnostic to which token is being @@ -871,7 +871,7 @@ func computeStaticSecretCacheIndex(req *SendRequest) string { if path == "" { return path } - return hex.EncodeToString(cryptoutil.Blake2b256Hash(path)) + return hashStaticSecretIndex(path) } // HandleCacheClear returns a handlerFunc that can perform cache clearing operations. @@ -973,7 +973,7 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) // Get the context for the given token and cancel its context index, err := c.db.Get(cachememdb.IndexNameToken, in.Token) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { return nil } if err != nil { @@ -992,7 +992,7 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) // Get the cached index and cancel the corresponding lifetime watcher // context index, err := c.db.Get(cachememdb.IndexNameTokenAccessor, in.TokenAccessor) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { return nil } if err != nil { @@ -1011,7 +1011,7 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) // Get the cached index and cancel the corresponding lifetime watcher // context index, err := c.db.Get(cachememdb.IndexNameLease, in.Lease) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { return nil } if err != nil { @@ -1147,7 +1147,7 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque // Kill the lifetime watchers of the revoked token index, err := c.db.Get(cachememdb.IndexNameToken, token) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { return true, nil } if err != nil { @@ -1395,7 +1395,7 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { switch { case secret.LeaseID != "": entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) } if err != nil { @@ -1409,7 +1409,7 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { var parentCtx context.Context if !secret.Auth.Orphan { entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) - if err == cachememdb.ErrCacheItemNotFound { + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { // If parent token is not managed by the cache, child shouldn't be // either. if entry == nil { diff --git a/command/agentproxyshared/cache/static_secret_cache_updater.go b/command/agentproxyshared/cache/static_secret_cache_updater.go new file mode 100644 index 0000000000..486d6f5291 --- /dev/null +++ b/command/agentproxyshared/cache/static_secret_cache_updater.go @@ -0,0 +1,385 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb" + "github.com/hashicorp/vault/command/agentproxyshared/sink" + "github.com/hashicorp/vault/helper/useragent" + "golang.org/x/exp/maps" + "nhooyr.io/websocket" +) + +// Example Event: +//{ +// "id": "a3be9fb1-b514-519f-5b25-b6f144a8c1ce", +// "source": "https://vaultproject.io/", +// "specversion": "1.0", +// "type": "*", +// "data": { +// "event": { +// "id": "a3be9fb1-b514-519f-5b25-b6f144a8c1ce", +// "metadata": { +// "current_version": "1", +// "data_path": "secret/data/foo", +// "modified": "true", +// "oldest_version": "0", +// "operation": "data-write", +// "path": "secret/data/foo" +// } +// }, +// "event_type": "kv-v2/data-write", +// "plugin_info": { +// "mount_class": "secret", +// "mount_accessor": "kv_5dc4d18e", +// "mount_path": "secret/", +// "plugin": "kv" +// } +// }, +// "datacontentype": "application/cloudevents", +// "time": "2023-09-12T15:19:49.394915-07:00" +//} + +// StaticSecretCacheUpdater is a struct that utilizes +// the event system to keep the static secret cache up to date. +type StaticSecretCacheUpdater struct { + client *api.Client + leaseCache *LeaseCache + logger hclog.Logger + tokenSink sink.Sink +} + +// StaticSecretCacheUpdaterConfig is the configuration for initializing a new +// StaticSecretCacheUpdater. +type StaticSecretCacheUpdaterConfig struct { + Client *api.Client + LeaseCache *LeaseCache + Logger hclog.Logger + // TokenSink is a token sync that will have the latest + // token from auto-auth in it, to be used in event system + // connections. + TokenSink sink.Sink +} + +// NewStaticSecretCacheUpdater creates a new instance of a StaticSecretCacheUpdater. +func NewStaticSecretCacheUpdater(conf *StaticSecretCacheUpdaterConfig) (*StaticSecretCacheUpdater, error) { + if conf == nil { + return nil, errors.New("nil configuration provided") + } + + if conf.LeaseCache == nil { + return nil, fmt.Errorf("nil Lease Cache (a required parameter): %v", conf) + } + + if conf.Logger == nil { + return nil, fmt.Errorf("nil Logger (a required parameter): %v", conf) + } + + if conf.Client == nil { + return nil, fmt.Errorf("nil API client (a required parameter): %v", conf) + } + + if conf.TokenSink == nil { + return nil, fmt.Errorf("nil token sink (a required parameter): %v", conf) + } + + return &StaticSecretCacheUpdater{ + client: conf.Client, + leaseCache: conf.LeaseCache, + logger: conf.Logger, + tokenSink: conf.TokenSink, + }, nil +} + +// streamStaticSecretEvents streams static secret events and updates +// the cache when updates are notified. This method will return errors in cases +// of failed updates, malformed events, and other. +// For best results, the caller of this function should retry on error with backoff, +// if it is desired for the cache to always remain up to date. +func (updater *StaticSecretCacheUpdater) streamStaticSecretEvents(ctx context.Context) error { + // First, ensure our token is up-to-date: + updater.client.SetToken(updater.tokenSink.(sink.SinkReader).Token()) + conn, err := updater.openWebSocketConnection(ctx) + if err != nil { + return fmt.Errorf("error when opening event stream: %w", err) + } + defer conn.Close(websocket.StatusNormalClosure, "") + + // before we check for events, update all of our cached + // kv secrets, in case we missed any events + // TODO: to be implemented in a future PR + + for { + select { + case <-ctx.Done(): + return nil + default: + _, message, err := conn.Read(ctx) + if err != nil { + // The caller of this function should make the decision on if to retry. If it does, then + // the websocket connection will be retried, and we will check for missed events. + return fmt.Errorf("error when attempting to read from event stream, reopening websocket: %w", err) + } + updater.logger.Trace("received event", "message", string(message)) + messageMap := make(map[string]interface{}) + err = json.Unmarshal(message, &messageMap) + if err != nil { + return fmt.Errorf("error when unmarshaling event, message: %s\nerror: %w", string(message), err) + } + data, ok := messageMap["data"].(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected event format when decoding 'data' element, message: %s\nerror: %w", string(message), err) + } + event, ok := data["event"].(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected event format when decoding 'event' element, message: %s\nerror: %w", string(message), err) + } + metadata, ok := event["metadata"].(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected event format when decoding 'metadata' element, message: %s\nerror: %w", string(message), err) + } + modified, ok := metadata["modified"].(string) + if ok && modified == "true" { + path, ok := metadata["path"].(string) + if !ok { + return fmt.Errorf("unexpected event format when decoding 'path' element, message: %s\nerror: %w", string(message), err) + } + err := updater.updateStaticSecret(ctx, path) + if err != nil { + // While we are kind of 'missing' an event this way, re-calling this function will + // result in the secret remaining up to date. + return fmt.Errorf("error updating static secret: path: %q, message: %s error: %w", path, message, err) + } + } else { + // This is an event we're not interested in, ignore it and + // carry on. + continue + } + } + } + + return nil +} + +// updateStaticSecret checks for updates for a static secret on the path given, +// and updates the cache if appropriate +func (updater *StaticSecretCacheUpdater) updateStaticSecret(ctx context.Context, path string) error { + // We clone the client, as we won't be using the same token. + client, err := updater.client.Clone() + if err != nil { + return err + } + + indexId := hashStaticSecretIndex(path) + + updater.logger.Debug("received update static secret request", "path", path, "indexId", indexId) + + index, err := updater.leaseCache.db.Get(cachememdb.IndexNameID, indexId) + if errors.Is(err, cachememdb.ErrCacheItemNotFound) { + // This event doesn't correspond to a secret in our cache + // so this is a no-op. + return nil + } + if err != nil { + return err + } + + // We use a raw request so that we can store all the + // request information, just like we do in the Proxier Send methods. + request := client.NewRequest(http.MethodGet, "/v1/"+path) + if request.Headers == nil { + request.Headers = make(http.Header) + } + request.Headers.Set("User-Agent", useragent.ProxyString()) + + var resp *api.Response + var tokensToRemove []string + var successfulAttempt bool + for _, token := range maps.Keys(index.Tokens) { + client.SetToken(token) + request.Headers.Set(api.AuthHeaderName, token) + resp, err = client.RawRequestWithContext(ctx, request) + if err != nil { + updater.logger.Trace("received error when trying to update cache", "path", path, "err", err, "token", token) + // We cannot access this secret with this token for whatever reason, + // so token for removal. + tokensToRemove = append(tokensToRemove, token) + continue + } else { + // We got our updated secret! + successfulAttempt = true + break + } + } + + if successfulAttempt { + // We need to update the index, so first, hold the lock. + index.IndexLock.Lock() + defer index.IndexLock.Unlock() + + // First, remove the tokens we noted couldn't access the secret from the token index + for _, token := range tokensToRemove { + delete(index.Tokens, token) + } + + sendResponse, err := NewSendResponse(resp, nil) + if err != nil { + return err + } + + // Serialize the response to store it in the cached index + var respBytes bytes.Buffer + err = sendResponse.Response.Write(&respBytes) + if err != nil { + updater.logger.Error("failed to serialize response", "error", err) + return err + } + + // Set the index's Response + index.Response = respBytes.Bytes() + index.LastRenewed = time.Now().UTC() + + // Lastly, store the secret + updater.logger.Debug("storing response into the cache due to event update", "path", path) + err = updater.leaseCache.db.Set(index) + if err != nil { + return err + } + } else { + // No token could successfully update the secret, or secret was deleted. + // We should evict the cache instead of re-storing the secret. + updater.logger.Debug("evicting response from cache", "path", path) + err = updater.leaseCache.db.Evict(cachememdb.IndexNameID, indexId) + if err != nil { + return err + } + } + + return nil +} + +// openWebSocketConnection opens a websocket connection to the event system for +// the events that the static secret cache updater is interested in. +func (updater *StaticSecretCacheUpdater) openWebSocketConnection(ctx context.Context) (*websocket.Conn, error) { + // We parse this into a URL object to get the specific host and scheme + // information without nasty string parsing. + vaultURL, err := url.Parse(updater.client.Address()) + if err != nil { + return nil, err + } + vaultHost := vaultURL.Host + // If we're using https, use wss, otherwise ws + scheme := "wss" + if vaultURL.Scheme == "http" { + scheme = "ws" + } + + webSocketURL := url.URL{ + Path: "/v1/sys/events/subscribe/kv*", + Host: vaultHost, + Scheme: scheme, + } + query := webSocketURL.Query() + query.Set("json", "true") + webSocketURL.RawQuery = query.Encode() + + updater.client.AddHeader(api.AuthHeaderName, updater.client.Token()) + updater.client.AddHeader(api.NamespaceHeaderName, updater.client.Namespace()) + + // Populate these now to avoid recreating them in the upcoming for loop. + headers := updater.client.Headers() + wsURL := webSocketURL.String() + httpClient := updater.client.CloneConfig().HttpClient + + // We do ten attempts, to ensure we follow forwarding to the leader. + var conn *websocket.Conn + for attempt := 0; attempt < 10; attempt++ { + var resp *http.Response + conn, resp, err = websocket.Dial(ctx, wsURL, &websocket.DialOptions{ + HTTPClient: httpClient, + HTTPHeader: headers, + }) + if err == nil { + break + } + + switch { + case resp == nil: + break + case resp.StatusCode == http.StatusTemporaryRedirect: + wsURL = resp.Header.Get("Location") + continue + default: + break + } + } + + if err != nil { + return nil, fmt.Errorf("error returned when opening event stream web socket to %s, ensure auto-auth token"+ + " has correct permissions and Vault is version 1.16 or above: %w", wsURL, err) + } + + if conn == nil { + return nil, errors.New(fmt.Sprintf("too many redirects as part of establishing web socket connection to %s", wsURL)) + } + + return conn, nil +} + +// Run is intended to be the method called by Vault Proxy, that runs the subsystem. +// Once a token is provided to the sink, we will start the websocket and start consuming +// events and updating secrets. +// Run will shut down gracefully when the context is cancelled. +func (updater *StaticSecretCacheUpdater) Run(ctx context.Context) error { + updater.logger.Info("starting static secret cache updater subsystem") + defer func() { + updater.logger.Info("static secret cache updater subsystem stopped") + }() + +tokenLoop: + for { + select { + case <-ctx.Done(): + return nil + default: + // Wait for the auto-auth token to be populated... + if updater.tokenSink.(sink.SinkReader).Token() != "" { + break tokenLoop + } + time.Sleep(100 * time.Millisecond) + } + } + + shouldBackoff := false + for { + select { + case <-ctx.Done(): + return nil + default: + // If we're erroring and the context isn't done, we should add + // a little backoff to make sure we don't accidentally overload + // Vault or similar. + if shouldBackoff { + time.Sleep(10 * time.Second) + } + err := updater.streamStaticSecretEvents(ctx) + if err != nil { + updater.logger.Warn("error occurred during streaming static secret cache update events:", err) + shouldBackoff = true + continue + } + } + } +} diff --git a/command/agentproxyshared/cache/static_secret_cache_updater_test.go b/command/agentproxyshared/cache/static_secret_cache_updater_test.go new file mode 100644 index 0000000000..aabd644178 --- /dev/null +++ b/command/agentproxyshared/cache/static_secret_cache_updater_test.go @@ -0,0 +1,581 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/helper/testhelpers/minimal" + + "github.com/hashicorp/go-hclog" + kv "github.com/hashicorp/vault-plugin-secrets-kv" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/command/agentproxyshared/cache/cachememdb" + "github.com/hashicorp/vault/command/agentproxyshared/sink" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/logging" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/vault" + "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "nhooyr.io/websocket" +) + +// Avoiding a circular dependency in the test. +type mockSink struct { + token *atomic.String +} + +func (m *mockSink) Token() string { + return m.token.Load() +} + +func (m *mockSink) WriteToken(token string) error { + m.token.Store(token) + return nil +} + +func newMockSink(t *testing.T) sink.Sink { + t.Helper() + + return &mockSink{ + token: atomic.NewString(""), + } +} + +// testNewStaticSecretCacheUpdater returns a new StaticSecretCacheUpdater +// for use in tests. +func testNewStaticSecretCacheUpdater(t *testing.T, client *api.Client) *StaticSecretCacheUpdater { + t.Helper() + + lc := testNewLeaseCache(t, []*SendResponse{}) + tokenSink := newMockSink(t) + tokenSink.WriteToken(client.Token()) + + updater, err := NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ + Client: client, + LeaseCache: lc, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), + TokenSink: tokenSink, + }) + if err != nil { + t.Fatal(err) + } + return updater +} + +// TestNewStaticSecretCacheUpdater tests the NewStaticSecretCacheUpdater method, +// to ensure it errors out when appropriate. +func TestNewStaticSecretCacheUpdater(t *testing.T) { + t.Parallel() + + lc := testNewLeaseCache(t, []*SendResponse{}) + config := api.DefaultConfig() + logger := logging.NewVaultLogger(hclog.Trace).Named("cache.updater") + client, err := api.NewClient(config) + if err != nil { + t.Fatal(err) + } + tokenSink := newMockSink(t) + + // Expect an error if any of the arguments are nil: + updater, err := NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ + Client: nil, + LeaseCache: lc, + Logger: logger, + TokenSink: tokenSink, + }) + require.Error(t, err) + require.Nil(t, updater) + + updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ + Client: client, + LeaseCache: nil, + Logger: logger, + TokenSink: tokenSink, + }) + require.Error(t, err) + require.Nil(t, updater) + + updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ + Client: client, + LeaseCache: lc, + Logger: nil, + TokenSink: tokenSink, + }) + require.Error(t, err) + require.Nil(t, updater) + + updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ + Client: client, + LeaseCache: lc, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), + TokenSink: nil, + }) + require.Error(t, err) + require.Nil(t, updater) + + // Don't expect an error if the arguments are as expected + updater, err = NewStaticSecretCacheUpdater(&StaticSecretCacheUpdaterConfig{ + Client: client, + LeaseCache: lc, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.updater"), + TokenSink: tokenSink, + }) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, updater) +} + +// TestOpenWebSocketConnection tests that the openWebSocketConnection function +// works as expected. This uses a TLS enabled (wss) WebSocket connection. +func TestOpenWebSocketConnection(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + cluster := minimal.NewTestSoloCluster(t, nil) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + updater.tokenSink.WriteToken(client.Token()) + + conn, err := updater.openWebSocketConnection(context.Background()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, conn) +} + +// TestOpenWebSocketConnectionReceivesEventsDefaultMount tests that the openWebSocketConnection function +// works as expected with the default KVV1 mount, and then the connection can be used to receive an event. +// This acts as more of an event system sanity check than a test of the updater +// logic. It's still important coverage, though. +// As of right now, it does not pass since the default kv mount is LeasedPassthroughBackend. +// If that is changed, this test will be unskipped. +func TestOpenWebSocketConnectionReceivesEventsDefaultMount(t *testing.T) { + t.Parallel() + t.Skip("This test won't finish, as the default KV mount is LeasedPassthroughBackend in tests, and therefore does not send events") + // We need a valid cluster for the connection to succeed. + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + + conn, err := updater.openWebSocketConnection(context.Background()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, conn) + + t.Cleanup(func() { + conn.Close(websocket.StatusNormalClosure, "") + }) + + makeData := func(i int) map[string]interface{} { + return map[string]interface{}{ + "foo": fmt.Sprintf("bar%d", i), + } + } + // Put a secret, which should trigger an event + err = client.KVv1("secret").Put(context.Background(), "foo", makeData(100)) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 5; i++ { + // Do a fresh PUT just to refresh the secret and send a new message + err = client.KVv1("secret").Put(context.Background(), "foo", makeData(i)) + if err != nil { + t.Fatal(err) + } + + // This method blocks until it gets a secret, so this test + // will only pass if we're receiving events correctly. + _, message, err := conn.Read(context.Background()) + if err != nil { + t.Fatal(err) + } + t.Log(string(message)) + } +} + +// TestOpenWebSocketConnectionReceivesEventsKVV1 tests that the openWebSocketConnection function +// works as expected with KVV1, and then the connection can be used to receive an event. +// This acts as more of an event system sanity check than a test of the updater +// logic. It's still important coverage, though. +func TestOpenWebSocketConnectionReceivesEventsKVV1(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.Factory, + }, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + + conn, err := updater.openWebSocketConnection(context.Background()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, conn) + + t.Cleanup(func() { + conn.Close(websocket.StatusNormalClosure, "") + }) + + err = client.Sys().Mount("secret-v1", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + + makeData := func(i int) map[string]interface{} { + return map[string]interface{}{ + "foo": fmt.Sprintf("bar%d", i), + } + } + // Put a secret, which should trigger an event + err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(100)) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 5; i++ { + // Do a fresh PUT just to refresh the secret and send a new message + err = client.KVv1("secret-v1").Put(context.Background(), "foo", makeData(i)) + if err != nil { + t.Fatal(err) + } + + // This method blocks until it gets a secret, so this test + // will only pass if we're receiving events correctly. + _, _, err := conn.Read(context.Background()) + if err != nil { + t.Fatal(err) + } + } +} + +// TestOpenWebSocketConnectionReceivesEvents tests that the openWebSocketConnection function +// works as expected with KVV2, and then the connection can be used to receive an event. +// This acts as more of an event system sanity check than a test of the updater +// logic. It's still important coverage, though. +func TestOpenWebSocketConnectionReceivesEventsKVV2(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + + conn, err := updater.openWebSocketConnection(context.Background()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, conn) + + t.Cleanup(func() { + conn.Close(websocket.StatusNormalClosure, "") + }) + + makeData := func(i int) map[string]interface{} { + return map[string]interface{}{ + "foo": fmt.Sprintf("bar%d", i), + } + } + + err = client.Sys().Mount("secret-v2", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + // Put a secret, which should trigger an event + _, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(100)) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 5; i++ { + // Do a fresh PUT just to refresh the secret and send a new message + _, err = client.KVv2("secret-v2").Put(context.Background(), "foo", makeData(i)) + if err != nil { + t.Fatal(err) + } + + // This method blocks until it gets a secret, so this test + // will only pass if we're receiving events correctly. + _, _, err := conn.Read(context.Background()) + if err != nil { + t.Fatal(err) + } + } +} + +// TestOpenWebSocketConnectionTestServer tests that the openWebSocketConnection function +// works as expected using vaulthttp.TestServer. This server isn't TLS enabled, so tests +// the ws path (as opposed to the wss) path. +func TestOpenWebSocketConnectionTestServer(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + core := vault.TestCoreWithConfig(t, &vault.CoreConfig{}) + ln, addr := vaulthttp.TestServer(t, core) + defer ln.Close() + + keys, rootToken := vault.TestCoreInit(t, core) + for _, key := range keys { + _, err := core.Unseal(key) + if err != nil { + t.Fatal(err) + } + } + + config := api.DefaultConfig() + config.Address = addr + client, err := api.NewClient(config) + if err != nil { + t.Fatal(err) + } + client.SetToken(rootToken) + updater := testNewStaticSecretCacheUpdater(t, client) + + conn, err := updater.openWebSocketConnection(context.Background()) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, conn) +} + +// Test_StreamStaticSecretEvents_UpdatesCacheWithNewSecrets tests that an event will +// properly update the corresponding secret in Proxy's cache. This is a little more end-to-end-y +// than TestUpdateStaticSecret, and essentially is testing a similar thing, though is +// ensuring that updateStaticSecret gets called by the event arriving +// (as part of streamStaticSecretEvents) instead of testing calling it explicitly. +func Test_StreamStaticSecretEvents_UpdatesCacheWithNewSecrets(t *testing.T) { + t.Parallel() + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": kv.VersionedKVFactory, + }, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + leaseCache := updater.leaseCache + + wg := &sync.WaitGroup{} + runStreamStaticSecretEvents := func() { + wg.Add(1) + err := updater.streamStaticSecretEvents(context.Background()) + if err != nil { + t.Fatal(err) + } + } + go runStreamStaticSecretEvents() + + // First, create the secret in the cache that we expect to be updated: + path := "secret-v2/data/foo" + indexId := hashStaticSecretIndex(path) + initialTime := time.Now().UTC() + // pre-populate the leaseCache with a secret to update + index := &cachememdb.Index{ + Namespace: "root/", + RequestPath: path, + LastRenewed: initialTime, + ID: indexId, + // Valid token provided, so update should work. + Tokens: map[string]struct{}{client.Token(): {}}, + Response: []byte{}, + } + err := leaseCache.db.Set(index) + if err != nil { + t.Fatal(err) + } + + secretData := map[string]interface{}{ + "foo": "bar", + } + + err = client.Sys().Mount("secret-v2", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + // Put a secret, which should trigger an event + _, err = client.KVv2("secret-v2").Put(context.Background(), "foo", secretData) + if err != nil { + t.Fatal(err) + } + + // Wait for the event to arrive. Events are usually much, much faster + // than this, but we make it five seconds to protect against CI flakiness. + time.Sleep(5 * time.Second) + + // Then, do a GET to see if the event got updated + newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, newIndex) + require.NotEqual(t, []byte{}, newIndex.Response) + require.Truef(t, initialTime.Before(newIndex.LastRenewed), "last updated time not updated on index") + require.Equal(t, index.RequestPath, newIndex.RequestPath) + require.Equal(t, index.Tokens, newIndex.Tokens) + + wg.Done() +} + +// TestUpdateStaticSecret tests that updateStaticSecret works as expected, reaching out +// to Vault to get an updated secret when called. +func TestUpdateStaticSecret(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + leaseCache := updater.leaseCache + + path := "secret/foo" + indexId := hashStaticSecretIndex(path) + initialTime := time.Now().UTC() + // pre-populate the leaseCache with a secret to update + index := &cachememdb.Index{ + Namespace: "root/", + RequestPath: "secret/foo", + LastRenewed: initialTime, + ID: indexId, + // Valid token provided, so update should work. + Tokens: map[string]struct{}{client.Token(): {}}, + Response: []byte{}, + } + err := leaseCache.db.Set(index) + if err != nil { + t.Fatal(err) + } + + secretData := map[string]interface{}{ + "foo": "bar", + } + + // create the secret in Vault. n.b. the test cluster has already mounted the KVv1 backend at "secret" + err = client.KVv1("secret").Put(context.Background(), "foo", secretData) + if err != nil { + t.Fatal(err) + } + + // attempt the update + err = updater.updateStaticSecret(context.Background(), path) + if err != nil { + t.Fatal(err) + } + + newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) + if err != nil { + t.Fatal(err) + } + require.NotNil(t, newIndex) + require.Truef(t, initialTime.Before(newIndex.LastRenewed), "last updated time not updated on index") + require.NotEqual(t, []byte{}, newIndex.Response) + require.Equal(t, index.RequestPath, newIndex.RequestPath) + require.Equal(t, index.Tokens, newIndex.Tokens) +} + +// TestUpdateStaticSecret_EvictsIfInvalidTokens tests that updateStaticSecret will +// evict secrets from the cache if no valid tokens are left. +func TestUpdateStaticSecret_EvictsIfInvalidTokens(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + leaseCache := updater.leaseCache + + path := "secret/foo" + indexId := hashStaticSecretIndex(path) + renewTime := time.Now().UTC() + + // pre-populate the leaseCache with a secret to update + index := &cachememdb.Index{ + Namespace: "root/", + RequestPath: "secret/foo", + LastRenewed: renewTime, + ID: indexId, + // Note: invalid Tokens value provided, so this secret cannot be updated, and must be evicted + Tokens: map[string]struct{}{"invalid token": {}}, + } + err := leaseCache.db.Set(index) + if err != nil { + t.Fatal(err) + } + + secretData := map[string]interface{}{ + "foo": "bar", + } + + // create the secret in Vault. n.b. the test cluster has already mounted the KVv1 backend at "secret" + err = client.KVv1("secret").Put(context.Background(), "foo", secretData) + if err != nil { + t.Fatal(err) + } + + // attempt the update + err = updater.updateStaticSecret(context.Background(), path) + if err != nil { + t.Fatal(err) + } + + newIndex, err := leaseCache.db.Get(cachememdb.IndexNameID, indexId) + require.Equal(t, cachememdb.ErrCacheItemNotFound, err) + require.Nil(t, newIndex) +} + +// TestUpdateStaticSecret_HandlesNonCachedPaths tests that updateStaticSecret +// doesn't fail or error if we try and give it an update to a path that isn't cached. +func TestUpdateStaticSecret_HandlesNonCachedPaths(t *testing.T) { + t.Parallel() + // We need a valid cluster for the connection to succeed. + cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + client := cluster.Cores[0].Client + + updater := testNewStaticSecretCacheUpdater(t, client) + + path := "secret/foo" + + // attempt the update + err := updater.updateStaticSecret(context.Background(), path) + if err != nil { + t.Fatal(err) + } + require.Nil(t, err) +} diff --git a/command/proxy.go b/command/proxy.go index fdcee532dc..ec5daab603 100644 --- a/command/proxy.go +++ b/command/proxy.go @@ -433,6 +433,8 @@ func (c *ProxyCommand) Run(args []string) int { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() + var updater *cache.StaticSecretCacheUpdater + // Parse proxy cache configurations if config.Cache != nil { cacheLogger := c.logger.Named("cache") @@ -463,6 +465,33 @@ func (c *ProxyCommand) Run(args []string) int { defer deferFunc() } } + + // If we're caching static secrets, we need to start the updater, too + if config.Cache.CacheStaticSecrets { + staticSecretCacheUpdaterLogger := c.logger.Named("cache.staticsecretcacheupdater") + inmemSink, err := inmem.New(&sink.SinkConfig{ + Logger: staticSecretCacheUpdaterLogger, + }, leaseCache) + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating inmem sink for static secret updater susbsystem: %v", err)) + return 1 + } + sinks = append(sinks, &sink.SinkConfig{ + Logger: staticSecretCacheUpdaterLogger, + Sink: inmemSink, + }) + + updater, err = cache.NewStaticSecretCacheUpdater(&cache.StaticSecretCacheUpdaterConfig{ + Client: client, + LeaseCache: leaseCache, + Logger: staticSecretCacheUpdaterLogger, + TokenSink: inmemSink, + }) + if err != nil { + c.UI.Error(fmt.Sprintf("Error creating static secret cache updater: %v", err)) + return 1 + } + } } var listeners []net.Listener @@ -500,7 +529,7 @@ func (c *ProxyCommand) Run(args []string) int { var inmemSink sink.Sink if config.APIProxy != nil { if config.APIProxy.UseAutoAuthToken { - apiProxyLogger.Debug("auto-auth token is allowed to be used; configuring inmem sink") + apiProxyLogger.Debug("configuring inmem auto-auth sink") inmemSink, err = inmem.New(&sink.SinkConfig{ Logger: apiProxyLogger, }, leaseCache) @@ -699,6 +728,16 @@ func (c *ProxyCommand) Run(args []string) int { }) } + // Add the static secret cache updater, if appropriate + if updater != nil { + g.Add(func() error { + err := updater.Run(ctx) + return err + }, func(error) { + cancelFunc() + }) + } + // Server configuration output padding := 24 sort.Strings(infoKeys) diff --git a/command/proxy/config/config.go b/command/proxy/config/config.go index 1881f08633..c0afd50d6b 100644 --- a/command/proxy/config/config.go +++ b/command/proxy/config/config.go @@ -247,12 +247,17 @@ func (c *Config) ValidateConfig() error { } if c.AutoAuth != nil { + cacheStaticSecrets := c.Cache != nil && c.Cache.CacheStaticSecrets if len(c.AutoAuth.Sinks) == 0 && - (c.APIProxy == nil || !c.APIProxy.UseAutoAuthToken) { - return fmt.Errorf("auto_auth requires at least one sink or api_proxy.use_auto_auth_token=true") + (c.APIProxy == nil || !c.APIProxy.UseAutoAuthToken) && !cacheStaticSecrets { + return fmt.Errorf("auto_auth requires at least one sink, api_proxy.use_auto_auth_token=true, or cache.cache_static_secrets=true") } } + if c.Cache != nil && c.Cache.CacheStaticSecrets && c.AutoAuth == nil { + return fmt.Errorf("cache.cache_static_secrets=true requires an auto-auth block configured, to use the token to connect with Vault's event system") + } + if c.AutoAuth == nil && c.Cache == nil && len(c.Listeners) == 0 { return fmt.Errorf("no auto_auth, cache, or listener block found in config") } diff --git a/command/proxy/config/config_test.go b/command/proxy/config/config_test.go index c6b631df2f..c92e1e1579 100644 --- a/command/proxy/config/config_test.go +++ b/command/proxy/config/config_test.go @@ -117,3 +117,16 @@ func TestLoadConfigFile_ProxyCache(t *testing.T) { t.Fatal(diff) } } + +// TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth tests that loading +// a config file with static secret caching enabled but no auto auth will fail. +func TestLoadConfigFile_StaticSecretCachingWithoutAutoAuth(t *testing.T) { + cfg, err := LoadConfigFile("./test-fixtures/config-cache-static-no-auto-auth.hcl") + if err != nil { + t.Fatal(err) + } + + if err := cfg.ValidateConfig(); err == nil { + t.Fatalf("expected error, as static secret caching requires auto-auth") + } +} diff --git a/command/proxy/config/test-fixtures/config-cache-static-no-auto-auth.hcl b/command/proxy/config/test-fixtures/config-cache-static-no-auto-auth.hcl new file mode 100644 index 0000000000..815d7fd8e6 --- /dev/null +++ b/command/proxy/config/test-fixtures/config-cache-static-no-auto-auth.hcl @@ -0,0 +1,18 @@ +# Copyright (c) HashiCorp, Inc. +# SPDX-License-Identifier: BUSL-1.1 + +pid_file = "./pidfile" + +cache { + cache_static_secrets = true +} + +listener "tcp" { + address = "127.0.0.1:8300" + tls_disable = true +} + +vault { + address = "http://127.0.0.1:1111" + tls_skip_verify = "true" +} diff --git a/command/proxy_test.go b/command/proxy_test.go index ecfe910803..0ee60e1fa1 100644 --- a/command/proxy_test.go +++ b/command/proxy_test.go @@ -703,6 +703,20 @@ func TestProxy_Cache_StaticSecret(t *testing.T) { defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) os.Unsetenv(api.EnvVaultAddress) + tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) + defer os.Remove(tokenFileName) + // We need auto-auth so that the event system can run. + // For ease, we use the token file path with the root token. + autoAuthConfig := fmt.Sprintf(` +auto_auth { + method { + type = "token_file" + config = { + token_file_path = "%s" + } + } +}`, tokenFileName) + cacheConfig := ` cache { cache_static_secrets = true @@ -723,13 +737,14 @@ vault { } %s %s +%s log_level = "trace" -`, serverClient.Address(), cacheConfig, listenConfig) +`, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) configPath := makeTempFile(t, "config.hcl", config) defer os.Remove(configPath) // Start proxy - _, cmd := testProxyCommand(t, logger) + ui, cmd := testProxyCommand(t, logger) cmd.startedCh = make(chan struct{}) wg := &sync.WaitGroup{} @@ -743,6 +758,8 @@ log_level = "trace" case <-cmd.startedCh: case <-time.After(5 * time.Second): t.Errorf("timeout") + t.Errorf("stdout: %s", ui.OutputWriter.String()) + t.Errorf("stderr: %s", ui.ErrorWriter.String()) } proxyClient, err := api.NewClient(api.DefaultConfig()) @@ -804,15 +821,18 @@ log_level = "trace" wg.Wait() } -// TestProxy_Cache_StaticSecretInvalidation Tests that the cache successfully caches a static secret -// going through the Proxy, and that it gets invalidated by a POST. -func TestProxy_Cache_StaticSecretInvalidation(t *testing.T) { +// TestProxy_Cache_EventSystemUpdatesCacheKVV1 Tests that the cache successfully caches a static secret +// going through the Proxy, and then the cache gets updated on a POST to the KVV1 secret due to an +// event. +func TestProxy_Cache_EventSystemUpdatesCacheKVV1(t *testing.T) { logger := logging.NewVaultLogger(hclog.Trace) - cluster := vault.NewTestCluster(t, nil, &vault.TestClusterOptions{ + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": logicalKv.Factory, + }, + }, &vault.TestClusterOptions{ HandlerFunc: vaulthttp.Handler, }) - cluster.Start() - defer cluster.Cleanup() serverClient := cluster.Cores[0].Client @@ -821,6 +841,20 @@ func TestProxy_Cache_StaticSecretInvalidation(t *testing.T) { defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) os.Unsetenv(api.EnvVaultAddress) + tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) + defer os.Remove(tokenFileName) + // We need auto-auth so that the event system can run. + // For ease, we use the token file path with the root token. + autoAuthConfig := fmt.Sprintf(` +auto_auth { + method { + type = "token_file" + config = { + token_file_path = "%s" + } + } +}`, tokenFileName) + cacheConfig := ` cache { cache_static_secrets = true @@ -841,13 +875,14 @@ vault { } %s %s +%s log_level = "trace" -`, serverClient.Address(), cacheConfig, listenConfig) +`, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) configPath := makeTempFile(t, "config.hcl", config) defer os.Remove(configPath) // Start proxy - _, cmd := testProxyCommand(t, logger) + ui, cmd := testProxyCommand(t, logger) cmd.startedCh = make(chan struct{}) wg := &sync.WaitGroup{} @@ -861,6 +896,8 @@ log_level = "trace" case <-cmd.startedCh: case <-time.After(5 * time.Second): t.Errorf("timeout") + t.Errorf("stdout: %s", ui.OutputWriter.String()) + t.Errorf("stderr: %s", ui.ErrorWriter.String()) } proxyClient, err := api.NewClient(api.DefaultConfig()) @@ -882,14 +919,27 @@ log_level = "trace" "bar": "baz", } + // Wait for the event system to successfully connect. + // This is longer than it needs to be to account for unnatural slowness/avoiding + // flakiness. + time.Sleep(5 * time.Second) + + // Mount the KVV2 engine + err = serverClient.Sys().Mount("secret-v1", &api.MountInput{ + Type: "kv", + }) + if err != nil { + t.Fatal(err) + } + // Create kvv1 secret - err = serverClient.KVv1("secret").Put(context.Background(), "my-secret", secretData) + err = serverClient.KVv1("secret-v1").Put(context.Background(), "my-secret", secretData) if err != nil { t.Fatal(err) } // We use raw requests so we can check the headers for cache hit/miss. - req := proxyClient.NewRequest(http.MethodGet, "/v1/secret/my-secret") + req := proxyClient.NewRequest(http.MethodGet, "/v1/secret-v1/my-secret") resp1, err := proxyClient.RawRequest(req) if err != nil { t.Fatal(err) @@ -899,27 +949,23 @@ log_level = "trace" require.Equal(t, "MISS", cacheValue) // Update the secret using the proxy client - err = proxyClient.KVv1("secret").Put(context.Background(), "my-secret", secretData2) + err = proxyClient.KVv1("secret-v1").Put(context.Background(), "my-secret", secretData2) if err != nil { t.Fatal(err) } + // Give some time for the event to actually get sent and the cache to be updated. + // This is longer than it needs to be to account for unnatural slowness/avoiding + // flakiness. + time.Sleep(5 * time.Second) + + // We expect this to be a cache hit, with the new value resp2, err := proxyClient.RawRequest(req) if err != nil { t.Fatal(err) } cacheValue = resp2.Header.Get("X-Cache") - // This should miss too, as we just updated it - require.Equal(t, "MISS", cacheValue) - - resp3, err := proxyClient.RawRequest(req) - if err != nil { - t.Fatal(err) - } - - cacheValue = resp3.Header.Get("X-Cache") - // This should hit, as the third request should get the cached value require.Equal(t, "HIT", cacheValue) // Lastly, we check to make sure the actual data we received is @@ -936,11 +982,175 @@ log_level = "trace" } require.Equal(t, secretData2, secret2.Data) - secret3, err := api.ParseSecret(resp3.Body) + close(cmd.ShutdownCh) + wg.Wait() +} + +// TestProxy_Cache_EventSystemUpdatesCacheKVV2 Tests that the cache successfully caches a static secret +// going through the Proxy for a KVV2 secret, and then the cache gets updated on a POST to the secret due to an +// event. +func TestProxy_Cache_EventSystemUpdatesCacheKVV2(t *testing.T) { + logger := logging.NewVaultLogger(hclog.Trace) + cluster := vault.NewTestCluster(t, &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "kv": logicalKv.VersionedKVFactory, + }, + }, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + + serverClient := cluster.Cores[0].Client + + // Unset the environment variable so that proxy picks up the right test + // cluster address + defer os.Setenv(api.EnvVaultAddress, os.Getenv(api.EnvVaultAddress)) + os.Unsetenv(api.EnvVaultAddress) + + tokenFileName := makeTempFile(t, "token-file", serverClient.Token()) + defer os.Remove(tokenFileName) + // We need auto-auth so that the event system can run. + // For ease, we use the token file path with the root token. + autoAuthConfig := fmt.Sprintf(` +auto_auth { + method { + type = "token_file" + config = { + token_file_path = "%s" + } + } +}`, tokenFileName) + + cacheConfig := ` +cache { + cache_static_secrets = true +} +` + listenAddr := generateListenerAddress(t) + listenConfig := fmt.Sprintf(` +listener "tcp" { + address = "%s" + tls_disable = true +} +`, listenAddr) + + config := fmt.Sprintf(` +vault { + address = "%s" + tls_skip_verify = true +} +%s +%s +%s +log_level = "trace" +`, serverClient.Address(), cacheConfig, listenConfig, autoAuthConfig) + configPath := makeTempFile(t, "config.hcl", config) + defer os.Remove(configPath) + + // Start proxy + ui, cmd := testProxyCommand(t, logger) + cmd.startedCh = make(chan struct{}) + + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + cmd.Run([]string{"-config", configPath}) + wg.Done() + }() + + select { + case <-cmd.startedCh: + case <-time.After(5 * time.Second): + t.Errorf("timeout") + t.Errorf("stdout: %s", ui.OutputWriter.String()) + t.Errorf("stderr: %s", ui.ErrorWriter.String()) + } + + proxyClient, err := api.NewClient(api.DefaultConfig()) if err != nil { t.Fatal(err) } - require.Equal(t, secret2.Data, secret3.Data) + proxyClient.SetToken(serverClient.Token()) + proxyClient.SetMaxRetries(0) + err = proxyClient.SetAddress("http://" + listenAddr) + if err != nil { + t.Fatal(err) + } + + secretData := map[string]interface{}{ + "foo": "bar", + } + + secretData2 := map[string]interface{}{ + "bar": "baz", + } + + // Wait for the event system to successfully connect. + // This is longer than it needs to be to account for unnatural slowness/avoiding + // flakiness. + time.Sleep(5 * time.Second) + + // Mount the KVV2 engine + err = serverClient.Sys().Mount("secret-v2", &api.MountInput{ + Type: "kv-v2", + }) + if err != nil { + t.Fatal(err) + } + + // Create kvv2 secret + _, err = serverClient.KVv2("secret-v2").Put(context.Background(), "my-secret", secretData) + if err != nil { + t.Fatal(err) + } + + // We use raw requests so we can check the headers for cache hit/miss. + req := proxyClient.NewRequest(http.MethodGet, "/v1/secret-v2/data/my-secret") + resp1, err := proxyClient.RawRequest(req) + if err != nil { + t.Fatal(err) + } + + cacheValue := resp1.Header.Get("X-Cache") + require.Equal(t, "MISS", cacheValue) + + // Update the secret using the proxy client + _, err = proxyClient.KVv2("secret-v2").Put(context.Background(), "my-secret", secretData2) + if err != nil { + t.Fatal(err) + } + + // Give some time for the event to actually get sent and the cache to be updated. + // This is longer than it needs to be to account for unnatural slowness/avoiding + // flakiness. + time.Sleep(5 * time.Second) + + // We expect this to be a cache hit, with the new value + resp2, err := proxyClient.RawRequest(req) + if err != nil { + t.Fatal(err) + } + + cacheValue = resp2.Header.Get("X-Cache") + require.Equal(t, "HIT", cacheValue) + + // Lastly, we check to make sure the actual data we received is + // as we expect. We must use ParseSecret due to the raw requests. + secret1, err := api.ParseSecret(resp1.Body) + if err != nil { + t.Fatal(err) + } + data, ok := secret1.Data["data"] + require.True(t, ok) + require.Equal(t, secretData, data) + + secret2, err := api.ParseSecret(resp2.Body) + if err != nil { + t.Fatal(err) + } + data2, ok := secret2.Data["data"] + require.True(t, ok) + // We expect that the cached value got updated by the event system. + require.Equal(t, secretData2, data2) close(cmd.ShutdownCh) wg.Wait()