diff --git a/vault/expiration.go b/vault/expiration.go index a378e53a50..d8be7fdfab 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -270,16 +270,7 @@ func (m *ExpirationManager) Renew(leaseID string, increment time.Duration) (*log } // Update the expiration time - m.pendingLock.Lock() - if timer, ok := m.pending[leaseID]; ok { - if le.ExpireTime.IsZero() { - timer.Stop() - delete(m.pending, leaseID) - } else { - timer.Reset(resp.Secret.LeaseTotal()) - } - } - m.pendingLock.Unlock() + m.updatePending(le, resp.Secret.LeaseTotal()) // Return the response return resp, nil @@ -320,16 +311,7 @@ func (m *ExpirationManager) RenewToken(source string, token string) (*logical.Au } // Update the expiration time - m.pendingLock.Lock() - if timer, ok := m.pending[leaseID]; ok { - if le.ExpireTime.IsZero() { - timer.Stop() - delete(m.pending, leaseID) - } else { - timer.Reset(le.Auth.LeaseTotal()) - } - } - m.pendingLock.Unlock() + m.updatePending(le, le.Auth.LeaseTotal()) return le.Auth, nil } @@ -364,13 +346,7 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons } // Setup revocation timer if there is a lease - if !le.ExpireTime.IsZero() { - m.pendingLock.Lock() - m.pending[le.LeaseID] = time.AfterFunc(resp.Secret.LeaseTotal(), func() { - m.expireID(le.LeaseID) - }) - m.pendingLock.Unlock() - } + m.updatePending(&le, resp.Secret.LeaseTotal()) // Done return le.LeaseID, nil @@ -396,14 +372,38 @@ func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) erro } // Setup revocation timer - if !le.ExpireTime.IsZero() { - m.pendingLock.Lock() - m.pending[le.LeaseID] = time.AfterFunc(auth.LeaseTotal(), func() { + m.updatePending(&le, auth.LeaseTotal()) + return nil +} + +// updatePending is used to update a pending invocation for a lease +func (m *ExpirationManager) updatePending(le *leaseEntry, leaseTotal time.Duration) { + m.pendingLock.Lock() + defer m.pendingLock.Unlock() + + // Check for an existing timer + timer, ok := m.pending[le.LeaseID] + + // Create entry if it does not exist + if !ok && leaseTotal > 0 { + timer := time.AfterFunc(leaseTotal, func() { m.expireID(le.LeaseID) }) - m.pendingLock.Unlock() + m.pending[le.LeaseID] = timer + return + } + + // Delete the timer if the expiration time is zero + if ok && leaseTotal == 0 { + timer.Stop() + delete(m.pending, le.LeaseID) + return + } + + // Extend the timer by the lease total + if ok && leaseTotal > 0 { + timer.Reset(leaseTotal) } - return nil } // expireID is invoked when a given ID is expired