diff --git a/command/agent/cache/lease_cache.go b/command/agent/cache/lease_cache.go index 156cba0a5a..c88c105ff1 100644 --- a/command/agent/cache/lease_cache.go +++ b/command/agent/cache/lease_cache.go @@ -22,6 +22,7 @@ import ( "github.com/hashicorp/vault/helper/consts" "github.com/hashicorp/vault/helper/cryptoutil" "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/locksutil" "github.com/hashicorp/vault/helper/namespace" nshelper "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/logical" @@ -72,6 +73,10 @@ type LeaseCache struct { db *cachememdb.CacheMemDB baseCtxInfo *cachememdb.ContextInfo l *sync.RWMutex + + // idLocks is used during cache lookup to ensure that identical requests made + // in parallel won't trigger multiple renewal goroutines. + idLocks []*locksutil.LockEntry } // LeaseCacheConfig is the configuration for initializing a new @@ -112,6 +117,35 @@ func NewLeaseCache(conf *LeaseCacheConfig) (*LeaseCache, error) { db: db, baseCtxInfo: baseCtxInfo, l: &sync.RWMutex{}, + idLocks: locksutil.CreateLocks(), + }, nil +} + +// checkCacheForRequest checks the cache for a particular request based on its +// computed ID. It returns a non-nil *SendResponse if an entry is found. +func (c *LeaseCache) checkCacheForRequest(id string) (*SendResponse, error) { + index, err := c.db.Get(cachememdb.IndexNameID, id) + if err != nil { + return nil, err + } + + if index == nil { + return nil, nil + } + + // Cached request is found, deserialize the response + reader := bufio.NewReader(bytes.NewReader(index.Response)) + resp, err := http.ReadResponse(reader, nil) + if err != nil { + c.logger.Error("failed to deserialize response", "error", err) + return nil, err + } + + return &SendResponse{ + Response: &api.Response{ + Response: resp, + }, + ResponseBody: index.Response, }, nil } @@ -126,29 +160,40 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, return nil, err } + // Grab a read lock for this particular request + idLock := locksutil.LockForKey(c.idLocks, id) + + idLock.RLock() + unlockFunc := idLock.RUnlock + defer func() { unlockFunc() }() + // Check if the response for this request is already in the cache - index, err := c.db.Get(cachememdb.IndexNameID, id) + sendResp, err := c.checkCacheForRequest(id) + if err != nil { + return nil, err + } + if sendResp != nil { + c.logger.Debug("returning cached response", "path", req.Request.URL.Path) + return sendResp, nil + } + + // Perform a lock upgrade + idLock.RUnlock() + idLock.Lock() + unlockFunc = idLock.Unlock + + // Check cache once more after upgrade + sendResp, err = c.checkCacheForRequest(id) if err != nil { return nil, err } - // Cached request is found, deserialize the response and return early - if index != nil { + // If found, it means that some other parallel request already cached this response + // in between this upgrade so we can simply return that. Otherwise, this request + // will be the one performing the cache write. + if sendResp != nil { c.logger.Debug("returning cached response", "path", req.Request.URL.Path) - - reader := bufio.NewReader(bytes.NewReader(index.Response)) - resp, err := http.ReadResponse(reader, nil) - if err != nil { - c.logger.Error("failed to deserialize response", "error", err) - return nil, err - } - - return &SendResponse{ - Response: &api.Response{ - Response: resp, - }, - ResponseBody: index.Response, - }, nil + return sendResp, nil } c.logger.Debug("forwarding request", "path", req.Request.URL.Path, "method", req.Request.Method) @@ -174,7 +219,7 @@ func (c *LeaseCache) Send(ctx context.Context, req *SendRequest) (*SendResponse, } // Build the index to cache based on the response received - index = &cachememdb.Index{ + index := &cachememdb.Index{ ID: id, Namespace: namespace, RequestPath: req.Request.URL.Path,