diff --git a/command/agentproxyshared/cache/cache_test.go b/command/agentproxyshared/cache/cache_test.go index f1648efa37..12e1e18e3a 100644 --- a/command/agentproxyshared/cache/cache_test.go +++ b/command/agentproxyshared/cache/cache_test.go @@ -33,7 +33,7 @@ func tokenRevocationValidation(t *testing.T, sampleSpace map[string]string, expe t.Helper() for val, valType := range sampleSpace { index, err := leaseCache.db.Get(valType, val) - if err != nil { + if err != nil && err != cachememdb.ErrCacheItemNotFound { t.Fatal(err) } if expected[val] == "" && index != nil { @@ -1098,12 +1098,8 @@ func testCachingCacheClearCommon(t *testing.T, clearType string) { // Verify the entry is cleared idx, err = leaseCache.db.Get(cachememdb.IndexNameLease, gotLeaseID) - if err != nil { - t.Fatal(err) - } - - if idx != nil { - t.Fatalf("expected entry to be nil, got: %v", idx) + if err != cachememdb.ErrCacheItemNotFound { + t.Fatal("expected entry to be nil, got", err) } } diff --git a/command/agentproxyshared/cache/cacheboltdb/bolt.go b/command/agentproxyshared/cache/cacheboltdb/bolt.go index 42c9fff485..ff7ec5fdf8 100644 --- a/command/agentproxyshared/cache/cacheboltdb/bolt.go +++ b/command/agentproxyshared/cache/cacheboltdb/bolt.go @@ -42,6 +42,11 @@ const ( // StaticSecretType - Bucket/type for static secrets StaticSecretType = "static-secret" + // TokenCapabilitiesType - Bucket/type for the token capabilities that + // are used to govern access to static secrets. These will be updated + // periodically to ensure that access to the cached secret remains. + TokenCapabilitiesType = "token-capabilities" + // LeaseType - v2 Bucket/type for auth AND secret leases. // // This bucket stores keys in the same order they were created using diff --git a/command/agentproxyshared/cache/cachememdb/cache_memdb.go b/command/agentproxyshared/cache/cachememdb/cache_memdb.go index ffc7b858ef..d95555a765 100644 --- a/command/agentproxyshared/cache/cachememdb/cache_memdb.go +++ b/command/agentproxyshared/cache/cachememdb/cache_memdb.go @@ -12,9 +12,14 @@ import ( ) const ( - tableNameIndexer = "indexer" + tableNameIndexer = "indexer" + tableNameCapabilitiesIndexer = "capabilities-indexer" ) +// ErrCacheItemNotFound is returned on Get and GetCapabilitiesIndex calls +// when the entry is not found in the cache. +var ErrCacheItemNotFound = errors.New("cache item not found") + // CacheMemDB is the underlying cache database for storing indexes. type CacheMemDB struct { db *atomic.Value @@ -120,6 +125,20 @@ func newDB() (*memdb.MemDB, error) { }, }, }, + tableNameCapabilitiesIndexer: { + Name: tableNameCapabilitiesIndexer, + Indexes: map[string]*memdb.IndexSchema{ + // This index enables fetching the cached item based on the + // identifier of the index. + CapabilitiesIndexNameID: { + Name: CapabilitiesIndexNameID, + Unique: true, + Indexer: &memdb.StringFieldIndex{ + Field: "ID", + }, + }, + }, + }, }, } @@ -131,6 +150,7 @@ func newDB() (*memdb.MemDB, error) { } // Get returns the index based on the indexer and the index values provided. +// If the capabilities index isn't present, it will return nil, ErrCacheItemNotFound func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, error) { if !validIndexName(indexName) { return nil, fmt.Errorf("invalid index name %q", indexName) @@ -144,7 +164,7 @@ func (c *CacheMemDB) Get(indexName string, indexValues ...interface{}) (*Index, } if raw == nil { - return nil, nil + return nil, ErrCacheItemNotFound } index, ok := raw.(*Index) @@ -173,6 +193,50 @@ func (c *CacheMemDB) Set(index *Index) error { return nil } +// GetCapabilitiesIndex returns the CapabilitiesIndex from the cache. +// If the capabilities index isn't present, it will return nil, ErrCacheItemNotFound +func (c *CacheMemDB) GetCapabilitiesIndex(indexName string, indexValues ...interface{}) (*CapabilitiesIndex, error) { + if !validCapabilitiesIndexName(indexName) { + return nil, fmt.Errorf("invalid index name %q", indexName) + } + + txn := c.db.Load().(*memdb.MemDB).Txn(false) + + raw, err := txn.First(tableNameCapabilitiesIndexer, indexName, indexValues...) + if err != nil { + return nil, err + } + + if raw == nil { + return nil, ErrCacheItemNotFound + } + + index, ok := raw.(*CapabilitiesIndex) + if !ok { + return nil, errors.New("unable to parse capabilities index value from the cache") + } + + return index, nil +} + +// SetCapabilitiesIndex stores the CapabilitiesIndex index into the cache. +func (c *CacheMemDB) SetCapabilitiesIndex(index *CapabilitiesIndex) error { + if index == nil { + return errors.New("nil capabilities index provided") + } + + txn := c.db.Load().(*memdb.MemDB).Txn(true) + defer txn.Abort() + + if err := txn.Insert(tableNameCapabilitiesIndexer, index); err != nil { + return fmt.Errorf("unable to insert index into cache: %v", err) + } + + txn.Commit() + + return nil +} + // GetByPrefix returns all the cached indexes based on the index name and the // value prefix. func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ([]*Index, error) { @@ -210,14 +274,13 @@ func (c *CacheMemDB) GetByPrefix(indexName string, indexValues ...interface{}) ( // Evict removes an index from the cache based on index name and value. func (c *CacheMemDB) Evict(indexName string, indexValues ...interface{}) error { index, err := c.Get(indexName, indexValues...) + if err == ErrCacheItemNotFound { + return nil + } if err != nil { return fmt.Errorf("unable to fetch index on cache deletion: %v", err) } - if index == nil { - return nil - } - txn := c.db.Load().(*memdb.MemDB).Txn(true) defer txn.Abort() diff --git a/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go b/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go index 0a08e9a701..47fa75ee54 100644 --- a/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go +++ b/command/agentproxyshared/cache/cachememdb/cache_memdb_test.go @@ -40,8 +40,8 @@ func TestCacheMemDB_Get(t *testing.T) { // Test on empty cache index, err := cache.Get(IndexNameID, "foo") - if err != nil { - t.Fatal(err) + if err != ErrCacheItemNotFound { + t.Fatal("expected cache item to be not found", err) } if index != nil { t.Fatalf("expected nil index, got: %v", index) @@ -56,6 +56,7 @@ func TestCacheMemDB_Get(t *testing.T) { TokenAccessor: "test_accessor", Lease: "test_lease", Response: []byte("hello world"), + Tokens: map[string]struct{}{}, } if err := cache.Set(in); err != nil { @@ -97,7 +98,7 @@ func TestCacheMemDB_Get(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { out, err := cache.Get(tc.indexName, tc.indexValues...) - if err != nil { + if err != nil && err != ErrCacheItemNotFound { t.Fatal(err) } if diff := deep.Equal(in, out); diff != nil { @@ -169,22 +170,22 @@ func TestCacheMemDB_GetByPrefix(t *testing.T) { }{ { "by_request_path", - "request_path", + IndexNameRequestPath, []interface{}{"test_ns/", "/v1/request/path"}, }, { "by_lease", - "lease", + IndexNameLease, []interface{}{"path/to/test_lease"}, }, { "by_token_parent", - "token_parent", + IndexNameTokenParent, []interface{}{"test_token_parent"}, }, { "by_lease_token", - "lease_token", + IndexNameLeaseToken, []interface{}{"test_lease_token"}, }, } @@ -348,10 +349,9 @@ func TestCacheMemDB_Evict(t *testing.T) { // Verify that the cache doesn't contain the entry any more index, err := cache.Get(tc.indexName, tc.indexValues...) - if (err != nil) != tc.wantErr { - t.Fatal(err) + if err != ErrCacheItemNotFound && !tc.wantErr { + t.Fatal("expected cache item to be not found", err) } - if index != nil { t.Fatalf("expected nil entry, got = %#v", index) } @@ -386,8 +386,8 @@ func TestCacheMemDB_Flush(t *testing.T) { // Check the cache doesn't contain inserted index out, err := cache.Get(IndexNameID, "test_id") - if err != nil { - t.Fatal(err) + if err != ErrCacheItemNotFound { + t.Fatal("expected cache item to be not found", err) } if out != nil { t.Fatalf("expected cache to be empty, got = %v", out) diff --git a/command/agentproxyshared/cache/cachememdb/index.go b/command/agentproxyshared/cache/cachememdb/index.go index af80c0907e..3a602cab6c 100644 --- a/command/agentproxyshared/cache/cachememdb/index.go +++ b/command/agentproxyshared/cache/cachememdb/index.go @@ -23,11 +23,12 @@ type Index struct { // Required: true, Unique: true Token string - // Tokens is a list of tokens that can access this cached response, + // Tokens is a set of tokens that can access this cached response, // which is used for static secret caching, and enabling multiple // tokens to be able to access the same cache entry for static secrets. + // Implemented as a map so that all values are unique. // Required: false, Unique: false - Tokens []string + Tokens map[string]struct{} // TokenParent is the parent token of the token held by this index // Required: false, Unique: false @@ -76,12 +77,35 @@ type Index struct { // LastRenewed is the timestamp of last renewal LastRenewed time.Time - // Type is the index type (token, auth-lease, secret-lease) + // Type is the index type (token, auth-lease, secret-lease, static-secret) Type string // IndexLock is a lock held for some indexes to prevent data // races upon update. - IndexLock sync.Mutex + IndexLock sync.RWMutex +} + +// CapabilitiesIndex holds the capabilities for cached static secrets. +// This type of index does not represent a response. +type CapabilitiesIndex struct { + // ID is a value that uniquely represents the request held by this + // index. This is computed by hashing the token that this capabilities + // index represents the capabilities of. + // Required: true, Unique: true + ID string + + // Token is the token that fetched the response held by this index + // Required: true, Unique: true + Token string + + // ReadablePaths is a set of paths with read capabilities for the given token. + // Implemented as a map for uniqueness. The key to the map is a path (such as + // `foo/bar` that we've demonstrated we can read. + ReadablePaths map[string]struct{} + + // IndexLock is a lock held for some indexes to prevent data + // races upon update. + IndexLock sync.RWMutex } type IndexName uint32 @@ -107,17 +131,29 @@ const ( // IndexNameLeaseToken is the token that created the lease. IndexNameLeaseToken = "lease_token" + + // CapabilitiesIndexNameID is the ID of the capabilities index. + CapabilitiesIndexNameID = "id" ) func validIndexName(indexName string) bool { switch indexName { - case "id": - case "lease": - case "request_path": - case "token": - case "token_accessor": - case "token_parent": - case "lease_token": + case IndexNameID: + case IndexNameLease: + case IndexNameRequestPath: + case IndexNameToken: + case IndexNameTokenAccessor: + case IndexNameTokenParent: + case IndexNameLeaseToken: + default: + return false + } + return true +} + +func validCapabilitiesIndexName(indexName string) bool { + switch indexName { + case CapabilitiesIndexNameID: default: return false } diff --git a/command/agentproxyshared/cache/cachememdb/index_test.go b/command/agentproxyshared/cache/cachememdb/index_test.go index a218b4433a..7b348e3402 100644 --- a/command/agentproxyshared/cache/cachememdb/index_test.go +++ b/command/agentproxyshared/cache/cachememdb/index_test.go @@ -17,7 +17,7 @@ func TestSerializeDeserialize(t *testing.T) { testIndex := &Index{ ID: "testid", Token: "testtoken", - Tokens: []string{"token1", "token2"}, + Tokens: map[string]struct{}{"token1": {}, "token2": {}}, TokenParent: "parent token", TokenAccessor: "test accessor", Namespace: "test namespace", diff --git a/command/agentproxyshared/cache/lease_cache.go b/command/agentproxyshared/cache/lease_cache.go index be5ee90bff..95a5d5ee78 100644 --- a/command/agentproxyshared/cache/lease_cache.go +++ b/command/agentproxyshared/cache/lease_cache.go @@ -14,7 +14,6 @@ import ( "io" "net/http" "net/url" - "slices" "strings" "sync" "time" @@ -224,18 +223,22 @@ func (c *LeaseCache) checkCacheForRequest(id string, req *SendRequest) (*SendRes } index, err := c.db.Get(cachememdb.IndexNameID, id) + if err == cachememdb.ErrCacheItemNotFound { + return nil, nil + } if err != nil { return nil, err } - if index == nil { - return nil, nil - } + index.IndexLock.RLock() + defer index.IndexLock.RUnlock() if token != "" { // This is a static secret check. We need to ensure that this token // has previously demonstrated access to this static secret. - if !slices.Contains(index.Tokens, token) { + // We could check the capabilities cache here, but since these + // indexes should be in sync, this saves us an extra cache get. + if _, ok := index.Tokens[token]; !ok { // We don't have access to this static secret, so // we do not return the cached response. return nil, nil @@ -293,7 +296,9 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, // This is the last step, so we defer the call first if inflight != nil && inflight.remaining.Load() == 0 { c.inflightCache.Delete(dynamicSecretCacheId) - c.inflightCache.Delete(staticSecretCacheId) + if staticSecretCacheId != "" { + c.inflightCache.Delete(staticSecretCacheId) + } } }() @@ -330,37 +335,39 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, idLockDynamicSecret.Unlock() } - idLockStaticSecret := locksutil.LockForKey(c.idLocks, staticSecretCacheId) + if staticSecretCacheId != "" { + idLockStaticSecret := locksutil.LockForKey(c.idLocks, staticSecretCacheId) - // Briefly grab an ID-based lock in here to emulate a load-or-store behavior - // and prevent concurrent cacheable requests from being proxied twice if - // they both miss the cache due to it being clean when peeking the cache - // entry. - idLockStaticSecret.Lock() - inflightRaw, found = c.inflightCache.Get(staticSecretCacheId) - if found { - idLockStaticSecret.Unlock() - inflight = inflightRaw.(*inflightRequest) - inflight.remaining.Inc() - defer inflight.remaining.Dec() - - // If found it means that there's an inflight request being processed. - // We wait until that's finished before proceeding further. - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-inflight.ch: - } - } else { - if inflight == nil { - inflight = newInflightRequest() + // Briefly grab an ID-based lock in here to emulate a load-or-store behavior + // and prevent concurrent cacheable requests from being proxied twice if + // they both miss the cache due to it being clean when peeking the cache + // entry. + idLockStaticSecret.Lock() + inflightRaw, found = c.inflightCache.Get(staticSecretCacheId) + if found { + idLockStaticSecret.Unlock() + inflight = inflightRaw.(*inflightRequest) inflight.remaining.Inc() defer inflight.remaining.Dec() - defer close(inflight.ch) - } - c.inflightCache.Set(staticSecretCacheId, inflight, gocache.NoExpiration) - idLockStaticSecret.Unlock() + // If found it means that there's an inflight request being processed. + // We wait until that's finished before proceeding further. + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-inflight.ch: + } + } else { + if inflight == nil { + inflight = newInflightRequest() + inflight.remaining.Inc() + defer inflight.remaining.Dec() + defer close(inflight.ch) + } + + c.inflightCache.Set(staticSecretCacheId, inflight, gocache.NoExpiration) + idLockStaticSecret.Unlock() + } } // Check if the response for this request is already in the dynamic secret cache @@ -374,13 +381,15 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, } // Check if the response for this request is already in the static secret cache - cachedResp, err = c.checkCacheForStaticSecretRequest(staticSecretCacheId, req) - if err != nil { - return nil, err - } - if cachedResp != nil { - c.logger.Debug("returning cached response", "id", staticSecretCacheId, "path", req.Request.URL.Path) - return cachedResp, nil + if staticSecretCacheId != "" { + cachedResp, err = c.checkCacheForStaticSecretRequest(staticSecretCacheId, req) + if err != nil { + return nil, err + } + if cachedResp != nil { + c.logger.Debug("returning cached response", "id", staticSecretCacheId, "path", req.Request.URL.Path) + return cachedResp, nil + } } c.logger.Debug("forwarding request from cache", "method", req.Request.Method, "path", req.Request.URL.Path) @@ -435,8 +444,9 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, return resp, nil } - // TODO: if secret.MountType == "kvv1" || secret.MountType == "kvv2" - if c.cacheStaticSecrets && secret != nil { + // There shouldn't be a situation where secret.MountType == "kv" and + // staticSecretCacheId == "", but just in case. + if c.cacheStaticSecrets && secret.MountType == "kv" && staticSecretCacheId != "" { index.Type = cacheboltdb.StaticSecretType index.ID = staticSecretCacheId err := c.cacheStaticSecret(ctx, req, resp, index) @@ -465,15 +475,15 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, case secret.LeaseID != "": c.logger.Debug("processing lease response", "method", req.Request.Method, "path", req.Request.URL.Path) entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) + if err == cachememdb.ErrCacheItemNotFound { + // If the lease belongs to a token that is not managed by the lease cache, + // return the response without caching it. + c.logger.Debug("pass-through lease response; token not managed by lease cache", "method", req.Request.Method, "path", req.Request.URL.Path) + return resp, nil + } if err != nil { return nil, err } - // If the lease belongs to a token that is not managed by the agent, - // return the response without caching it. - if entry == nil { - c.logger.Debug("pass-through lease response; token not managed by agent", "method", req.Request.Method, "path", req.Request.URL.Path) - return resp, nil - } // Derive a context for renewal using the token's context renewCtxInfo = cachememdb.NewContextInfo(entry.RenewCtxInfo.Ctx) @@ -491,15 +501,15 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, var parentCtx context.Context if !secret.Auth.Orphan { entry, err := c.db.Get(cachememdb.IndexNameToken, req.Token) + if err == cachememdb.ErrCacheItemNotFound { + // If the lease belongs to a token that is not managed by the lease cache, + // return the response without caching it. + c.logger.Debug("pass-through lease response; parent token not managed by lease cache", "method", req.Request.Method, "path", req.Request.URL.Path) + return resp, nil + } if err != nil { return nil, err } - // If parent token is not managed by the agent, child shouldn't be - // either. - if entry == nil { - c.logger.Debug("pass-through auth response; parent token not managed by agent", "method", req.Request.Method, "path", req.Request.URL.Path) - return resp, nil - } c.logger.Debug("setting parent context", "method", req.Request.Method, "path", req.Request.URL.Path) parentCtx = entry.RenewCtxInfo.Ctx @@ -572,20 +582,19 @@ func (c *LeaseCache) cacheStaticSecret(ctx context.Context, req *SendRequest, re // If a cached version of this secret exists, we now have access, so // we don't need to re-cache, just update index.Tokens indexFromCache, err := c.db.Get(cachememdb.IndexNameID, index.ID) - if err != nil { + if err != nil && err != cachememdb.ErrCacheItemNotFound { return err } - // We must hold a lock for the index while it's being updated. - // We keep the two locking mechanisms distinct, so that it's only writes - // that have to be serial. - index.IndexLock.Lock() - defer index.IndexLock.Unlock() - // The index already exists, so all we need to do is add our token // to the index's allowed token list, then re-store it if indexFromCache != nil { - indexFromCache.Tokens = append(indexFromCache.Tokens, req.Token) + // We must hold a lock for the index while it's being updated. + // We keep the two locking mechanisms distinct, so that it's only writes + // that have to be serial. + indexFromCache.IndexLock.Lock() + defer indexFromCache.IndexLock.Unlock() + indexFromCache.Tokens[req.Token] = struct{}{} return c.storeStaticSecretIndex(ctx, req, indexFromCache) } @@ -607,8 +616,8 @@ func (c *LeaseCache) cacheStaticSecret(ctx context.Context, req *SendRequest, re // Set the index's Response index.Response = respBytes.Bytes() - // Set the index's tokens - index.Tokens = []string{req.Token} + // Initialize the token map and add this token to it. + index.Tokens = map[string]struct{}{req.Token: {}} // Set the index type index.Type = cacheboltdb.StaticSecretType @@ -625,13 +634,55 @@ func (c *LeaseCache) storeStaticSecretIndex(ctx context.Context, req *SendReques return err } - // TODO: We need to also update the cache for the token's permission capabilities. - // TODO: for this we'll need: req.Token, req.URL.Path - // TODO: we need to build a NEW index, with a hash of the token as the ID + capabilitiesIndex, err := c.retrieveOrCreateTokenCapabilitiesEntry(req.Token) + if err != nil { + c.logger.Error("failed to cache the proxied response", "error", err) + return err + } + + path := getStaticSecretPathFromRequest(req) + + // Extra caution -- avoid potential nil + if capabilitiesIndex.ReadablePaths == nil { + capabilitiesIndex.ReadablePaths = make(map[string]struct{}) + } + + // update the index with the new capability: + capabilitiesIndex.ReadablePaths[path] = struct{}{} + + err = c.db.SetCapabilitiesIndex(capabilitiesIndex) + if err != nil { + c.logger.Error("failed to cache token capabilities as part of caching the proxied response", "error", err) + return err + } return nil } +// retrieveOrCreateTokenCapabilitiesEntry will either retrieve the token +// capabilities entry from the cache, or create a new, empty one. +func (c *LeaseCache) retrieveOrCreateTokenCapabilitiesEntry(token string) (*cachememdb.CapabilitiesIndex, error) { + // The index ID is a hash of the token. + indexId := hex.EncodeToString(cryptoutil.Blake2b256Hash(token)) + indexFromCache, err := c.db.GetCapabilitiesIndex(cachememdb.IndexNameID, indexId) + if err != nil && err != cachememdb.ErrCacheItemNotFound { + return nil, err + } + + if indexFromCache != nil { + return indexFromCache, nil + } + + // Build the index to cache based on the response received + index := &cachememdb.CapabilitiesIndex{ + ID: indexId, + Token: token, + ReadablePaths: make(map[string]struct{}), + } + + return index, nil +} + func (c *LeaseCache) createCtxInfo(ctx context.Context) *cachememdb.ContextInfo { if ctx == nil { c.l.RLock() @@ -727,7 +778,7 @@ func (c *LeaseCache) updateLastRenewed(ctx context.Context, index *cachememdb.In defer idLock.Unlock() getIndex, err := c.db.Get(cachememdb.IndexNameID, index.ID) - if err != nil { + if err != nil && err != cachememdb.ErrCacheItemNotFound { return err } index.LastRenewed = t @@ -764,12 +815,63 @@ func computeIndexID(req *SendRequest) (string, error) { return hex.EncodeToString(cryptoutil.Blake2b256Hash(string(b.Bytes()))), nil } +// canonicalizeStaticSecretPath takes an API request path such as +// /v1/foo/bar and a namespace, and turns it into a canonical representation +// of the secret's path in Vault. +// We opt for this form as namespace.Canonicalize returns a namespace in the +// form of "ns1/", so we keep consistent with path canonicalization. +func canonicalizeStaticSecretPath(requestPath string, ns string) string { + // /sys/capabilities accepts both requests that look like foo/bar + // and /foo/bar but not /v1/foo/bar. + // We trim the /v1/ from the start of the URL to get the foo/bar form. + // This means that we can use the paths we retrieve from the + // /sys/capabilities endpoint to access this index + // without having to re-add the /v1/ + path := strings.TrimPrefix(requestPath, "/v1/") + // Trim any leading slashes, as we never want those. + // This ensures /foo/bar gets turned to foo/bar + path = strings.TrimPrefix(path, "/") + + // If a namespace was provided in a way that wasn't directly in the path, + // it must be added to the path. + path = namespace.Canonicalize(ns) + path + + return path +} + +// getStaticSecretPathFromRequest gets the canonical path for a +// request, taking into account intricacies relating to /v1/ and namespaces +// in the header. +// Returns a path like foo/bar or ns1/foo/bar. +// We opt for this form as namespace.Canonicalize returns a namespace in the +// form of "ns1/", so we keep consistent with path canonicalization. +func getStaticSecretPathFromRequest(req *SendRequest) string { + path := req.Request.URL.Path + // Static secrets always have /v1 as a prefix. This enables us to + // enable a pass-through and never attempt to cache or view-from-cache + // any request without the /v1 prefix. + if !strings.HasPrefix(path, "/v1") { + return "" + } + var namespace string + if header := req.Request.Header; header != nil { + namespace = header.Get(api.NamespaceHeaderName) + } + return canonicalizeStaticSecretPath(path, namespace) +} + // computeStaticSecretCacheIndex results in a value that uniquely identifies a static // secret's cached ID. Notably, we intentionally ignore headers (for example, // the X-Vault-Token header) to remain agnostic to which token is being // used in the request. We care only about the path. +// This will return "" if the index does not have a /v1 prefix, and therefore +// cannot be a static secret. func computeStaticSecretCacheIndex(req *SendRequest) string { - return hex.EncodeToString(cryptoutil.Blake2b256Hash(req.Request.URL.Path)) + path := getStaticSecretPathFromRequest(req) + if path == "" { + return path + } + return hex.EncodeToString(cryptoutil.Blake2b256Hash(path)) } // HandleCacheClear returns a handlerFunc that can perform cache clearing operations. @@ -848,9 +950,18 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) return err } for _, index := range indexes { - if index.RenewCtxInfo != nil { - if index.RenewCtxInfo.CancelFunc != nil { - index.RenewCtxInfo.CancelFunc() + // If it's a static secret, we must remove directly, as there + // is no renew func to cancel. + if index.Type == cacheboltdb.StaticSecretType { + err = c.db.Evict(cachememdb.IndexNameID, index.ID) + if err != nil { + return err + } + } else { + if index.RenewCtxInfo != nil { + if index.RenewCtxInfo.CancelFunc != nil { + index.RenewCtxInfo.CancelFunc() + } } } } @@ -862,12 +973,12 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) // Get the context for the given token and cancel its context index, err := c.db.Get(cachememdb.IndexNameToken, in.Token) + if err == cachememdb.ErrCacheItemNotFound { + return nil + } if err != nil { return err } - if index == nil { - return nil - } c.logger.Debug("canceling context of index attached to token") @@ -881,12 +992,12 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) // Get the cached index and cancel the corresponding lifetime watcher // context index, err := c.db.Get(cachememdb.IndexNameTokenAccessor, in.TokenAccessor) + if err == cachememdb.ErrCacheItemNotFound { + return nil + } if err != nil { return err } - if index == nil { - return nil - } c.logger.Debug("canceling context of index attached to accessor") @@ -900,12 +1011,12 @@ func (c *LeaseCache) handleCacheClear(ctx context.Context, in *cacheClearInput) // Get the cached index and cancel the corresponding lifetime watcher // context index, err := c.db.Get(cachememdb.IndexNameLease, in.Lease) + if err == cachememdb.ErrCacheItemNotFound { + return nil + } if err != nil { return err } - if index == nil { - return nil - } c.logger.Debug("canceling context of index attached to accessor") @@ -1036,12 +1147,12 @@ func (c *LeaseCache) handleRevocationRequest(ctx context.Context, req *SendReque // Kill the lifetime watchers of the revoked token index, err := c.db.Get(cachememdb.IndexNameToken, token) + if err == cachememdb.ErrCacheItemNotFound { + return true, nil + } if err != nil { return false, err } - if index == nil { - return true, nil - } // Indicate the lifetime watcher goroutine for this index to return. // This will not affect the child tokens because the context is not @@ -1284,14 +1395,13 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { switch { case secret.LeaseID != "": entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) + if err == cachememdb.ErrCacheItemNotFound { + return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) + } if err != nil { return err } - if entry == nil { - return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) - } - // Derive a context for renewal using the token's context renewCtxInfo = cachememdb.NewContextInfo(entry.RenewCtxInfo.Ctx) @@ -1299,14 +1409,16 @@ func (c *LeaseCache) restoreLeaseRenewCtx(index *cachememdb.Index) error { var parentCtx context.Context if !secret.Auth.Orphan { entry, err := c.db.Get(cachememdb.IndexNameToken, index.RequestToken) + if err == cachememdb.ErrCacheItemNotFound { + // If parent token is not managed by the cache, child shouldn't be + // either. + if entry == nil { + return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) + } + } if err != nil { return err } - // If parent token is not managed by the agent, child shouldn't be - // either. - if entry == nil { - return fmt.Errorf("could not find parent Token %s for req path %s", index.RequestToken, index.RequestPath) - } c.logger.Debug("setting parent context", "method", index.RequestMethod, "path", index.RequestPath) parentCtx = entry.RenewCtxInfo.Ctx @@ -1398,7 +1510,7 @@ func deriveNamespaceAndRevocationPath(req *SendRequest) (string, string) { func (c *LeaseCache) RegisterAutoAuthToken(token string) error { // Get the token from the cache oldIndex, err := c.db.Get(cachememdb.IndexNameToken, token) - if err != nil { + if err != nil && err != cachememdb.ErrCacheItemNotFound { return err } diff --git a/command/agentproxyshared/cache/lease_cache_test.go b/command/agentproxyshared/cache/lease_cache_test.go index cbd46bf92a..6ede59e01e 100644 --- a/command/agentproxyshared/cache/lease_cache_test.go +++ b/command/agentproxyshared/cache/lease_cache_test.go @@ -5,6 +5,7 @@ package cache import ( "context" + "encoding/hex" "fmt" "io/ioutil" "net/http" @@ -27,6 +28,7 @@ import ( "github.com/hashicorp/vault/helper/useragent" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/cryptoutil" "github.com/hashicorp/vault/sdk/helper/logging" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,10 +43,11 @@ func testNewLeaseCache(t *testing.T, responses []*SendResponse) *LeaseCache { t.Fatal(err) } lc, err := NewLeaseCache(&LeaseCacheConfig{ - Client: client, - BaseContext: context.Background(), - Proxier: NewMockProxier(responses), - Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + Client: client, + BaseContext: context.Background(), + Proxier: NewMockProxier(responses), + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + CacheStaticSecrets: true, }) if err != nil { t.Fatal(err) @@ -61,10 +64,11 @@ func testNewLeaseCacheWithDelay(t *testing.T, cacheable bool, delay int) *LeaseC } lc, err := NewLeaseCache(&LeaseCacheConfig{ - Client: client, - BaseContext: context.Background(), - Proxier: &mockDelayProxier{cacheable, delay}, - Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + Client: client, + BaseContext: context.Background(), + Proxier: &mockDelayProxier{cacheable, delay}, + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + CacheStaticSecrets: true, }) if err != nil { t.Fatal(err) @@ -80,11 +84,12 @@ func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, s require.NoError(t, err) lc, err := NewLeaseCache(&LeaseCacheConfig{ - Client: client, - BaseContext: context.Background(), - Proxier: NewMockProxier(responses), - Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), - Storage: storage, + Client: client, + BaseContext: context.Background(), + Proxier: NewMockProxier(responses), + Logger: logging.NewVaultLogger(hclog.Trace).Named("cache.leasecache"), + Storage: storage, + CacheStaticSecrets: true, }) require.NoError(t, err) @@ -92,9 +97,6 @@ func testNewLeaseCacheWithPersistence(t *testing.T, responses []*SendResponse, s } func TestCache_ComputeIndexID(t *testing.T) { - type args struct { - req *http.Request - } tests := []struct { name string req *SendRequest @@ -145,6 +147,232 @@ func TestCache_ComputeIndexID(t *testing.T) { } } +// TestCache_ComputeStaticSecretIndexID ensures that +// computeStaticSecretCacheIndex works correctly. If this test breaks, then our +// hashing algorithm has changed, and we risk breaking backwards compatibility. +func TestCache_ComputeStaticSecretIndexID(t *testing.T) { + req := &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/foo/bar", + }, + }, + } + + index := computeStaticSecretCacheIndex(req) + // We expect this to be "", as it doesn't start with /v1 + expectedIndex := "" + require.Equal(t, expectedIndex, index) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/foo/bar", + }, + }, + } + + expectedIndex = "b117a962f19f17fa372c8681cadcd6fd370d28ee6e0a7012196b780bef601b53" + index2 := computeStaticSecretCacheIndex(req) + require.Equal(t, expectedIndex, index2) +} + +// Test_GetStaticSecretPathFromRequestNoNamespaces tests that getStaticSecretPathFromRequest +// behaves as expected when no namespaces are involved. +func Test_GetStaticSecretPathFromRequestNoNamespaces(t *testing.T) { + req := &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/foo/bar", + }, + }, + } + + path := getStaticSecretPathFromRequest(req) + require.Equal(t, "foo/bar", path) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + // Paths like this are not static secrets, so we should return "" + Path: "foo/bar", + }, + }, + } + + path = getStaticSecretPathFromRequest(req) + require.Equal(t, "", path) +} + +// Test_GetStaticSecretPathFromRequestNamespaces tests that getStaticSecretPathFromRequest +// behaves as expected when namespaces are involved. +func Test_GetStaticSecretPathFromRequestNamespaces(t *testing.T) { + req := &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/foo/bar", + }, + Header: map[string][]string{api.NamespaceHeaderName: {"ns1"}}, + }, + } + + path := getStaticSecretPathFromRequest(req) + require.Equal(t, "ns1/foo/bar", path) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns1/foo/bar", + }, + }, + } + + path = getStaticSecretPathFromRequest(req) + require.Equal(t, "ns1/foo/bar", path) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + // Paths like this are not static secrets, so we should return "" + Path: "ns1/foo/bar", + }, + }, + } + + path = getStaticSecretPathFromRequest(req) + require.Equal(t, "", path) +} + +// TestCache_CanonicalizeStaticSecretPath ensures that +// canonicalizeStaticSecretPath works as expected with all kinds of inputs. +func TestCache_CanonicalizeStaticSecretPath(t *testing.T) { + expected := "foo/bar" + actual := canonicalizeStaticSecretPath("/v1/foo/bar", "") + require.Equal(t, expected, actual) + + actual = canonicalizeStaticSecretPath("foo/bar", "") + require.Equal(t, expected, actual) + actual = canonicalizeStaticSecretPath("/foo/bar", "") + require.Equal(t, expected, actual) + + expected = "ns1/foo/bar" + actual = canonicalizeStaticSecretPath("/v1/ns1/foo/bar", "") + require.Equal(t, expected, actual) + + actual = canonicalizeStaticSecretPath("ns1/foo/bar", "") + require.Equal(t, expected, actual) + actual = canonicalizeStaticSecretPath("/ns1/foo/bar", "") + require.Equal(t, expected, actual) + + expected = "ns1/foo/bar" + actual = canonicalizeStaticSecretPath("/v1/foo/bar", "ns1") + require.Equal(t, expected, actual) + + actual = canonicalizeStaticSecretPath("/foo/bar", "ns1") + require.Equal(t, expected, actual) + actual = canonicalizeStaticSecretPath("foo/bar", "ns1") + require.Equal(t, expected, actual) + + expected = "ns1/foo/bar" + actual = canonicalizeStaticSecretPath("/v1/foo/bar", "ns1/") + require.Equal(t, expected, actual) + + actual = canonicalizeStaticSecretPath("/foo/bar", "ns1/") + require.Equal(t, expected, actual) + actual = canonicalizeStaticSecretPath("foo/bar", "ns1/") + require.Equal(t, expected, actual) + + expected = "ns1/foo/bar" + actual = canonicalizeStaticSecretPath("/v1/foo/bar", "/ns1/") + require.Equal(t, expected, actual) + + actual = canonicalizeStaticSecretPath("/foo/bar", "/ns1/") + require.Equal(t, expected, actual) + actual = canonicalizeStaticSecretPath("foo/bar", "/ns1/") + require.Equal(t, expected, actual) +} + +// TestCache_ComputeStaticSecretIndexIDNamespaces ensures that +// computeStaticSecretCacheIndex correctly identifies that a request +// with a namespace header and a request specifying the namespace in the path +// are equivalent. +func TestCache_ComputeStaticSecretIndexIDNamespaces(t *testing.T) { + req := &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "foo/bar", + }, + Header: map[string][]string{api.NamespaceHeaderName: {"ns1"}}, + }, + } + + index := computeStaticSecretCacheIndex(req) + // Paths like this are not static secrets, so we should expect "" + require.Equal(t, "", index) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "ns1/foo/bar", + }, + }, + } + + // Paths like this are not static secrets, so we should expect "" + index2 := computeStaticSecretCacheIndex(req) + require.Equal(t, "", index2) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/ns1/foo/bar", + }, + }, + } + + expectedIndex := "a4605679d269aa1bebac7079a471a33403413f388f63bf0da3c771b225857932" + // We expect that computeStaticSecretCacheIndex will compute the same index + index3 := computeStaticSecretCacheIndex(req) + require.Equal(t, expectedIndex, index3) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/foo/bar", + }, + Header: map[string][]string{api.NamespaceHeaderName: {"ns1"}}, + }, + } + + index4 := computeStaticSecretCacheIndex(req) + require.Equal(t, expectedIndex, index4) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/foo/bar", + }, + Header: map[string][]string{api.NamespaceHeaderName: {"ns1/"}}, + }, + } + + // Paths like this are not static secrets, so we should expect "" + index5 := computeStaticSecretCacheIndex(req) + require.Equal(t, "", index5) + + req = &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/foo/bar", + }, + Header: map[string][]string{api.NamespaceHeaderName: {"ns1/"}}, + }, + } + + index6 := computeStaticSecretCacheIndex(req) + require.Equal(t, expectedIndex, index6) +} + func TestLeaseCache_EmptyToken(t *testing.T) { responses := []*SendResponse{ newTestSendResponse(http.StatusCreated, `{"value": "invalid", "auth": {"client_token": "testtoken"}}`), @@ -176,7 +404,7 @@ func TestLeaseCache_SendCacheable(t *testing.T) { } lc := testNewLeaseCache(t, responses) - // Register an token so that the token and lease requests are cached + // Register a token so that the token and lease requests are cached require.NoError(t, lc.RegisterAutoAuthToken("autoauthtoken")) // Make a request. A response with a new token is returned to the lease @@ -248,6 +476,216 @@ func TestLeaseCache_SendCacheable(t *testing.T) { } } +// TestLeaseCache_StoreCacheableStaticSecret tests that cacheStaticSecret works +// as expected, creating the two expected cache entries, and also ensures +// that we can evict the cache entry with the cache clear API afterwards. +func TestLeaseCache_StoreCacheableStaticSecret(t *testing.T) { + request := &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/secrets/foo/bar", + }, + }, + Token: "token", + } + response := newTestSendResponse(http.StatusCreated, `{"data": {"foo": "bar"}, "mount_type": "kvv2"}`) + responses := []*SendResponse{ + response, + } + index := &cachememdb.Index{ + Type: cacheboltdb.StaticSecretType, + RequestPath: request.Request.URL.Path, + Namespace: "root/", + Token: "token", + ID: computeStaticSecretCacheIndex(request), + } + + lc := testNewLeaseCache(t, responses) + + // We expect two entries to be stored by this: + // 1. The actual static secret + // 2. The capabilities index + err := lc.cacheStaticSecret(context.Background(), request, response, index) + if err != nil { + return + } + + indexFromDB, err := lc.db.Get(cachememdb.IndexNameID, index.ID) + if err != nil { + return + } + + require.NotNil(t, indexFromDB) + require.Equal(t, "token", indexFromDB.Token) + require.Equal(t, map[string]struct{}{"token": {}}, indexFromDB.Tokens) + require.Equal(t, cacheboltdb.StaticSecretType, indexFromDB.Type) + require.Equal(t, request.Request.URL.Path, indexFromDB.RequestPath) + require.Equal(t, "root/", indexFromDB.Namespace) + + capabilitiesIndexFromDB, err := lc.db.GetCapabilitiesIndex(cachememdb.IndexNameID, hex.EncodeToString(cryptoutil.Blake2b256Hash(index.Token))) + if err != nil { + return + } + + require.NotNil(t, capabilitiesIndexFromDB) + require.Equal(t, "token", capabilitiesIndexFromDB.Token) + require.Equal(t, map[string]struct{}{"secrets/foo/bar": {}}, capabilitiesIndexFromDB.ReadablePaths) + + err = lc.handleCacheClear(context.Background(), &cacheClearInput{ + Type: "request_path", + RequestPath: request.Request.URL.Path, + }) + require.NoError(t, err) + + expectedClearedIndex, err := lc.db.Get(cachememdb.IndexNameID, index.ID) + require.Equal(t, cachememdb.ErrCacheItemNotFound, err) + require.Nil(t, expectedClearedIndex) +} + +// TestLeaseCache_StaticSecret_CacheClear_All tests that static secrets are +// stored correctly, as well as removed from the cache by a cache clear with +// "all" specified as the type. +func TestLeaseCache_StaticSecret_CacheClear_All(t *testing.T) { + request := &SendRequest{ + Request: &http.Request{ + URL: &url.URL{ + Path: "/v1/secrets/foo/bar", + }, + }, + Token: "token", + } + response := newTestSendResponse(http.StatusCreated, `{"data": {"foo": "bar"}, "mount_type": "kvv2"}`) + responses := []*SendResponse{ + response, + } + index := &cachememdb.Index{ + Type: cacheboltdb.StaticSecretType, + RequestPath: request.Request.URL.Path, + Namespace: "root/", + Token: "token", + ID: computeStaticSecretCacheIndex(request), + } + + lc := testNewLeaseCache(t, responses) + + // We expect two entries to be stored by this: + // 1. The actual static secret + // 2. The capabilities index + err := lc.cacheStaticSecret(context.Background(), request, response, index) + if err != nil { + return + } + + indexFromDB, err := lc.db.Get(cachememdb.IndexNameID, index.ID) + if err != nil { + return + } + + require.NotNil(t, indexFromDB) + require.Equal(t, "token", indexFromDB.Token) + require.Equal(t, map[string]struct{}{"token": {}}, indexFromDB.Tokens) + require.Equal(t, cacheboltdb.StaticSecretType, indexFromDB.Type) + require.Equal(t, request.Request.URL.Path, indexFromDB.RequestPath) + require.Equal(t, "root/", indexFromDB.Namespace) + + capabilitiesIndexFromDB, err := lc.db.GetCapabilitiesIndex(cachememdb.IndexNameID, hex.EncodeToString(cryptoutil.Blake2b256Hash(index.Token))) + if err != nil { + t.Fatal(err) + } + + require.NotNil(t, capabilitiesIndexFromDB) + require.Equal(t, "token", capabilitiesIndexFromDB.Token) + require.Equal(t, map[string]struct{}{"secrets/foo/bar": {}}, capabilitiesIndexFromDB.ReadablePaths) + + err = lc.handleCacheClear(context.Background(), &cacheClearInput{ + Type: "all", + }) + require.NoError(t, err) + + expectedClearedIndex, err := lc.db.Get(cachememdb.IndexNameID, index.ID) + require.Equal(t, cachememdb.ErrCacheItemNotFound, err) + require.Nil(t, expectedClearedIndex) + + expectedClearedCapabilitiesIndex, err := lc.db.GetCapabilitiesIndex(cachememdb.IndexNameID, capabilitiesIndexFromDB.ID) + require.Equal(t, cachememdb.ErrCacheItemNotFound, err) + require.Nil(t, expectedClearedCapabilitiesIndex) +} + +// TestLeaseCache_SendCacheableStaticSecret tests that the cache has no issue returning +// static secret style responses. It's similar to TestLeaseCache_SendCacheable in that it +// only tests the surface level of the functionality, but there are other tests that +// test the rest. +func TestLeaseCache_SendCacheableStaticSecret(t *testing.T) { + response := newTestSendResponse(http.StatusCreated, `{"data": {"foo": "bar"}, "mount_type": "kvv2"}`) + responses := []*SendResponse{ + response, + response, + response, + response, + } + + lc := testNewLeaseCache(t, responses) + + // Register a token + require.NoError(t, lc.RegisterAutoAuthToken("autoauthtoken")) + + // Make a request. A response with a new token is returned to the lease + // cache and that will be cached. + urlPath := "http://example.com/v1/sample/api" + sendReq := &SendRequest{ + Token: "autoauthtoken", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err := lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, response.Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Send the same request again to get the cached response + sendReq = &SendRequest{ + Token: "autoauthtoken", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, responses[0].Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Modify the request a little to ensure the second response is + // returned to the lease cache. + sendReq = &SendRequest{ + Token: "autoauthtoken", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, response.Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } + + // Make the same request again and ensure that the same response is returned + // again. + sendReq = &SendRequest{ + Token: "autoauthtoken", + Request: httptest.NewRequest("GET", urlPath, strings.NewReader(`{"value": "input_changed"}`)), + } + resp, err = lc.Send(context.Background(), sendReq) + if err != nil { + t.Fatal(err) + } + if diff := deep.Equal(resp.Response.StatusCode, response.Response.StatusCode); diff != nil { + t.Fatalf("expected getting proxied response: got %v", diff) + } +} + func TestLeaseCache_SendNonCacheable(t *testing.T) { responses := []*SendResponse{ newTestSendResponse(http.StatusOK, `{"value": "output"}`), @@ -338,12 +776,9 @@ func TestLeaseCache_SendNonCacheableNonTokenLease(t *testing.T) { t.Fatalf("expected getting proxied response: got %v", diff) } - idx, err := lc.db.Get(cachememdb.IndexNameRequestPath, "root/", urlPath) - if err != nil { - t.Fatal(err) - } - if idx != nil { - t.Fatalf("expected nil entry, got: %#v", idx) + _, err = lc.db.Get(cachememdb.IndexNameRequestPath, "root/", urlPath) + if err != cachememdb.ErrCacheItemNotFound { + t.Fatal("expected entry to be nil, got", err) } // Verify that the response is not cached by sending the same request and @@ -360,12 +795,9 @@ func TestLeaseCache_SendNonCacheableNonTokenLease(t *testing.T) { t.Fatalf("expected getting proxied response: got %v", diff) } - idx, err = lc.db.Get(cachememdb.IndexNameRequestPath, "root/", urlPath) - if err != nil { - t.Fatal(err) - } - if idx != nil { - t.Fatalf("expected nil entry, got: %#v", idx) + _, err = lc.db.Get(cachememdb.IndexNameRequestPath, "root/", urlPath) + if err != cachememdb.ErrCacheItemNotFound { + t.Fatal("expected entry to be nil, got", err) } } @@ -773,6 +1205,7 @@ func TestLeaseCache_PersistAndRestore(t *testing.T) { // 204 No content gets special handling - avoid. newTestSendResponse(250, `{"auth": {"client_token": "testtoken3", "renewable": true, "orphan": true, "lease_duration": 600}}`), newTestSendResponse(251, `{"lease_id": "secret3-lease", "renewable": true, "data": {"number": "three"}, "lease_duration": 600}`), + newTestSendResponse(http.StatusCreated, `{"data": {"foo": "bar"}, "mount_type": "kvv2"}`), } tempDir, boltStorage := setupBoltStorage(t) diff --git a/command/agentproxyshared/cache/testing.go b/command/agentproxyshared/cache/testing.go index 8bc2239cdf..4bc2e1d025 100644 --- a/command/agentproxyshared/cache/testing.go +++ b/command/agentproxyshared/cache/testing.go @@ -7,7 +7,7 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "io" "math/rand" "net/http" "strings" @@ -62,7 +62,7 @@ func newTestSendResponse(status int, body string) *SendResponse { resp.Response.Header.Set("Date", time.Now().Format(http.TimeFormat)) if body != "" { - resp.Response.Body = ioutil.NopCloser(strings.NewReader(body)) + resp.Response.Body = io.NopCloser(strings.NewReader(body)) resp.ResponseBody = []byte(body) }