ExpirationManager restoration to load in the background (#3260)

This commit is contained in:
Chris Hoffman 2017-09-05 11:09:00 -04:00 committed by GitHub
parent 051c0b0719
commit 16fbfeb5ef
5 changed files with 354 additions and 81 deletions

View File

@ -2,6 +2,7 @@ package vault
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"path" "path"
"strings" "strings"
@ -38,9 +39,6 @@ const (
// revokeRetryBase is a baseline retry time // revokeRetryBase is a baseline retry time
revokeRetryBase = 10 * time.Second revokeRetryBase = 10 * time.Second
// minRevokeDelay is used to prevent an instant revoke on restore
minRevokeDelay = 5 * time.Second
// maxLeaseDuration is the default maximum lease duration // maxLeaseDuration is the default maximum lease duration
maxLeaseTTL = 32 * 24 * time.Hour maxLeaseTTL = 32 * 24 * time.Hour
@ -60,9 +58,17 @@ type ExpirationManager struct {
logger log.Logger logger log.Logger
pending map[string]*time.Timer pending map[string]*time.Timer
pendingLock sync.Mutex pendingLock sync.RWMutex
tidyLock int64 tidyLock int64
// A set of locks to handle restoration
restoreMode int64
restoreModeLock sync.RWMutex
restoreRequestLock sync.RWMutex
restoreLocks []*locksutil.LockEntry
restoreLoaded sync.Map
quitCh chan struct{}
} }
// NewExpirationManager creates a new ExpirationManager that is backed // NewExpirationManager creates a new ExpirationManager that is backed
@ -72,6 +78,7 @@ func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, log
logger = log.New("expiration_manager") logger = log.New("expiration_manager")
} }
exp := &ExpirationManager{ exp := &ExpirationManager{
router: router, router: router,
idView: view.SubView(leaseViewPrefix), idView: view.SubView(leaseViewPrefix),
@ -79,6 +86,12 @@ func NewExpirationManager(router *Router, view *BarrierView, ts *TokenStore, log
tokenStore: ts, tokenStore: ts,
logger: logger, logger: logger,
pending: make(map[string]*time.Timer), pending: make(map[string]*time.Timer),
// new instances of the expiration manager will go immediately into
// restore mode
restoreMode: 1,
restoreLocks: locksutil.CreateLocks(),
quitCh: make(chan struct{}),
} }
return exp return exp
} }
@ -100,9 +113,14 @@ func (c *Core) setupExpiration() error {
// Restore the existing state // Restore the existing state
c.logger.Info("expiration: restoring leases") c.logger.Info("expiration: restoring leases")
if err := c.expiration.Restore(); err != nil { errorFunc := func() {
return fmt.Errorf("expiration state restore failed: %v", err) c.logger.Error("expiration: shutting down")
if err := c.Shutdown(); err != nil {
c.logger.Error("expiration: error shutting down core: %v", err)
} }
}
go c.expiration.Restore(errorFunc, 0)
return nil return nil
} }
@ -120,6 +138,21 @@ func (c *Core) stopExpiration() error {
return nil return nil
} }
// lockLease takes out a lock for a given lease ID
func (m *ExpirationManager) lockLease(leaseID string) {
locksutil.LockForKey(m.restoreLocks, leaseID).Lock()
}
// unlockLease unlocks a given lease ID
func (m *ExpirationManager) unlockLease(leaseID string) {
locksutil.LockForKey(m.restoreLocks, leaseID).Unlock()
}
// inRestoreMode returns if we are currently in restore mode
func (m *ExpirationManager) inRestoreMode() bool {
return atomic.LoadInt64(&m.restoreMode) == 1
}
// Tidy cleans up the dangling storage entries for leases. It scans the storage // Tidy cleans up the dangling storage entries for leases. It scans the storage
// view to find all the available leases, checks if the token embedded in it is // view to find all the available leases, checks if the token embedded in it is
// either empty or invalid and in both the cases, it revokes them. It also uses // either empty or invalid and in both the cases, it revokes them. It also uses
@ -127,6 +160,10 @@ func (c *Core) stopExpiration() error {
// not required to use the API that invokes this. This is only intended to // not required to use the API that invokes this. This is only intended to
// clean up the corrupt storage due to bugs. // clean up the corrupt storage due to bugs.
func (m *ExpirationManager) Tidy() error { func (m *ExpirationManager) Tidy() error {
if m.inRestoreMode() {
return errors.New("cannot run tidy while restoring leases")
}
var tidyErrors *multierror.Error var tidyErrors *multierror.Error
if !atomic.CompareAndSwapInt64(&m.tidyLock, 0, 1) { if !atomic.CompareAndSwapInt64(&m.tidyLock, 0, 1) {
@ -198,11 +235,11 @@ func (m *ExpirationManager) Tidy() error {
} else { } else {
if isValid { if isValid {
return return
} else { }
m.logger.Trace("expiration: revoking lease which contains an invalid token", "lease_id", leaseID) m.logger.Trace("expiration: revoking lease which contains an invalid token", "lease_id", leaseID)
revokeLease = true revokeLease = true
deletedCountInvalidToken++ deletedCountInvalidToken++
}
goto REVOKE_CHECK goto REVOKE_CHECK
} }
@ -233,15 +270,33 @@ func (m *ExpirationManager) Tidy() error {
// Restore is used to recover the lease states when starting. // Restore is used to recover the lease states when starting.
// This is used after starting the vault. // This is used after starting the vault.
func (m *ExpirationManager) Restore() error { func (m *ExpirationManager) Restore(errorFunc func(), loadDelay time.Duration) (retErr error) {
m.pendingLock.Lock() defer func() {
defer m.pendingLock.Unlock() // Turn off restore mode. We can do this safely without the lock because
// if restore mode finished successfully, restore mode was already
// disabled with the lock. In an error state, this will allow the
// Stop() function to shut everything down.
atomic.StoreInt64(&m.restoreMode, 0)
switch {
case retErr == nil:
case errwrap.Contains(retErr, ErrBarrierSealed.Error()):
// Don't run error func because we're likely already shutting down
m.logger.Warn("expiration: barrier sealed while restoring leases, stopping lease loading")
retErr = nil
default:
m.logger.Error("expiration: error restoring leases", "error", retErr)
if errorFunc != nil {
errorFunc()
}
}
}()
// Accumulate existing leases // Accumulate existing leases
m.logger.Debug("expiration: collecting leases") m.logger.Debug("expiration: collecting leases")
existing, err := logical.CollectKeys(m.idView) existing, err := logical.CollectKeys(m.idView)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan for leases: %v", err) return errwrap.Wrapf("failed to scan for leases: {{err}}", err)
} }
m.logger.Debug("expiration: leases collected", "num_existing", len(existing)) m.logger.Debug("expiration: leases collected", "num_existing", len(existing))
@ -250,7 +305,7 @@ func (m *ExpirationManager) Restore() error {
quit := make(chan bool) quit := make(chan bool)
// Buffer these channels to prevent deadlocks // Buffer these channels to prevent deadlocks
errs := make(chan error, len(existing)) errs := make(chan error, len(existing))
result := make(chan *leaseEntry, len(existing)) result := make(chan struct{}, len(existing))
// Use a wait group // Use a wait group
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
@ -269,18 +324,21 @@ func (m *ExpirationManager) Restore() error {
return return
} }
le, err := m.loadEntry(leaseID) err := m.processRestore(leaseID, loadDelay)
if err != nil { if err != nil {
errs <- err errs <- err
continue continue
} }
// Write results out to the result channel // Send message that lease is done
result <- le result <- struct{}{}
// quit early // quit early
case <-quit: case <-quit:
return return
case <-m.quitCh:
return
} }
} }
}() }()
@ -291,7 +349,7 @@ func (m *ExpirationManager) Restore() error {
go func() { go func() {
defer wg.Done() defer wg.Done()
for i, leaseID := range existing { for i, leaseID := range existing {
if i%500 == 0 { if i > 0 && i%500 == 0 {
m.logger.Trace("expiration: leases loading", "progress", i) m.logger.Trace("expiration: leases loading", "progress", i)
} }
@ -299,6 +357,9 @@ func (m *ExpirationManager) Restore() error {
case <-quit: case <-quit:
return return
case <-m.quitCh:
return
default: default:
broker <- leaseID broker <- leaseID
} }
@ -308,49 +369,64 @@ func (m *ExpirationManager) Restore() error {
close(broker) close(broker)
}() }()
// Restore each key by pulling from the result chan // Ensure all keys on the chan are processed
for i := 0; i < len(existing); i++ { for i := 0; i < len(existing); i++ {
select { select {
case err := <-errs: case err := <-errs:
// Close all go routines // Close all go routines
close(quit) close(quit)
return err return err
case le := <-result: case <-m.quitCh:
close(quit)
return nil
// If there is no entry, nothing to restore case <-result:
if le == nil {
continue
}
// If there is no expiry time, don't do anything
if le.ExpireTime.IsZero() {
continue
}
// Determine the remaining time to expiration
expires := le.ExpireTime.Sub(time.Now())
if expires <= 0 {
expires = minRevokeDelay
}
// Setup revocation timer
m.pending[le.LeaseID] = time.AfterFunc(expires, func() {
m.expireID(le.LeaseID)
})
} }
} }
// Let all go routines finish // Let all go routines finish
wg.Wait() wg.Wait()
if len(m.pending) > 0 { m.restoreModeLock.Lock()
if m.logger.IsInfo() { m.restoreLoaded = sync.Map{}
m.logger.Info("expire: leases restored", "restored_lease_count", len(m.pending)) m.restoreLocks = nil
} atomic.StoreInt64(&m.restoreMode, 0)
m.restoreModeLock.Unlock()
m.logger.Info("expiration: lease restore complete")
return nil
}
// processRestore takes a lease and restores it in the expiration manager if it has
// not already been seen
func (m *ExpirationManager) processRestore(leaseID string, loadDelay time.Duration) error {
m.restoreRequestLock.RLock()
defer m.restoreRequestLock.RUnlock()
// Check if the lease has been seen
if _, ok := m.restoreLoaded.Load(leaseID); ok {
return nil
} }
m.lockLease(leaseID)
defer m.unlockLease(leaseID)
// Check again with the lease locked
if _, ok := m.restoreLoaded.Load(leaseID); ok {
return nil
}
// Useful for testing to add latency to all load requests
if loadDelay > 0 {
time.Sleep(loadDelay)
}
// Load lease and restore expiration timer
_, err := m.loadEntryInternal(leaseID, true, false)
if err != nil {
return err
}
return nil return nil
} }
@ -358,12 +434,26 @@ func (m *ExpirationManager) Restore() error {
// This must be called before sealing the view. // This must be called before sealing the view.
func (m *ExpirationManager) Stop() error { func (m *ExpirationManager) Stop() error {
// Stop all the pending expiration timers // Stop all the pending expiration timers
m.logger.Debug("expiration: stop triggered")
defer m.logger.Debug("expiration: finished stopping")
m.pendingLock.Lock() m.pendingLock.Lock()
for _, timer := range m.pending { for _, timer := range m.pending {
timer.Stop() timer.Stop()
} }
m.pending = make(map[string]*time.Timer) m.pending = make(map[string]*time.Timer)
m.pendingLock.Unlock() m.pendingLock.Unlock()
close(m.quitCh)
if m.inRestoreMode() {
for {
if !m.inRestoreMode() {
break
}
time.Sleep(10 * time.Millisecond)
}
}
return nil return nil
} }
@ -378,6 +468,7 @@ func (m *ExpirationManager) Revoke(leaseID string) error {
// during revocation and still remove entries/index/lease timers // during revocation and still remove entries/index/lease timers
func (m *ExpirationManager) revokeCommon(leaseID string, force, skipToken bool) error { func (m *ExpirationManager) revokeCommon(leaseID string, force, skipToken bool) error {
defer metrics.MeasureSince([]string{"expire", "revoke-common"}, time.Now()) defer metrics.MeasureSince([]string{"expire", "revoke-common"}, time.Now())
// Load the entry // Load the entry
le, err := m.loadEntry(leaseID) le, err := m.loadEntry(leaseID)
if err != nil { if err != nil {
@ -394,13 +485,13 @@ func (m *ExpirationManager) revokeCommon(leaseID string, force, skipToken bool)
if err := m.revokeEntry(le); err != nil { if err := m.revokeEntry(le); err != nil {
if !force { if !force {
return err return err
} else { }
if m.logger.IsWarn() { if m.logger.IsWarn() {
m.logger.Warn("revocation from the backend failed, but in force mode so ignoring", "error", err) m.logger.Warn("revocation from the backend failed, but in force mode so ignoring", "error", err)
} }
} }
} }
}
// Delete the entry // Delete the entry
if err := m.deleteEntry(leaseID); err != nil { if err := m.deleteEntry(leaseID); err != nil {
@ -447,6 +538,7 @@ func (m *ExpirationManager) RevokePrefix(prefix string) error {
// token store's revokeSalted function. // token store's revokeSalted function.
func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error { func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error {
defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now()) defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now())
// Lookup the leases // Lookup the leases
existing, err := m.lookupByToken(te.ID) existing, err := m.lookupByToken(te.ID)
if err != nil { if err != nil {
@ -455,7 +547,7 @@ func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error {
// Revoke all the keys // Revoke all the keys
for idx, leaseID := range existing { for idx, leaseID := range existing {
if err := m.Revoke(leaseID); err != nil { if err := m.revokeCommon(leaseID, false, false); err != nil {
return fmt.Errorf("failed to revoke '%s' (%d / %d): %v", return fmt.Errorf("failed to revoke '%s' (%d / %d): %v",
leaseID, idx+1, len(existing), err) leaseID, idx+1, len(existing), err)
} }
@ -482,6 +574,11 @@ func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error {
} }
func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error { func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error {
if m.inRestoreMode() {
m.restoreRequestLock.Lock()
defer m.restoreRequestLock.Unlock()
}
// Ensure there is a trailing slash // Ensure there is a trailing slash
if !strings.HasSuffix(prefix, "/") { if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/" prefix = prefix + "/"
@ -509,6 +606,7 @@ func (m *ExpirationManager) revokePrefixCommon(prefix string, force bool) error
// and a renew interval. The increment may be ignored. // and a renew interval. The increment may be ignored.
func (m *ExpirationManager) Renew(leaseID string, increment time.Duration) (*logical.Response, error) { func (m *ExpirationManager) Renew(leaseID string, increment time.Duration) (*logical.Response, error) {
defer metrics.MeasureSince([]string{"expire", "renew"}, time.Now()) defer metrics.MeasureSince([]string{"expire", "renew"}, time.Now())
// Load the entry // Load the entry
le, err := m.loadEntry(leaseID) le, err := m.loadEntry(leaseID)
if err != nil { if err != nil {
@ -562,11 +660,51 @@ func (m *ExpirationManager) Renew(leaseID string, increment time.Duration) (*log
return resp, nil return resp, nil
} }
// RestoreSaltedTokenCheck verifies that the token is not expired while running
// in restore mode. If we are not in restore mode, the lease has already been
// restored or the lease still has time left, it returns true.
func (m *ExpirationManager) RestoreSaltedTokenCheck(source string, saltedID string) (bool, error) {
defer metrics.MeasureSince([]string{"expire", "restore-token-check"}, time.Now())
// Return immediately if we are not in restore mode, expiration manager is
// already loaded
if !m.inRestoreMode() {
return true, nil
}
m.restoreModeLock.RLock()
defer m.restoreModeLock.RUnlock()
// Check again after we obtain the lock
if !m.inRestoreMode() {
return true, nil
}
leaseID := path.Join(source, saltedID)
m.lockLease(leaseID)
defer m.unlockLease(leaseID)
le, err := m.loadEntryInternal(leaseID, true, true)
if err != nil {
return false, err
}
if le != nil && !le.ExpireTime.IsZero() {
expires := le.ExpireTime.Sub(time.Now())
if expires <= 0 {
return false, nil
}
}
return true, nil
}
// RenewToken is used to renew a token which does not need to // RenewToken is used to renew a token which does not need to
// invoke a logical backend. // invoke a logical backend.
func (m *ExpirationManager) RenewToken(req *logical.Request, source string, token string, func (m *ExpirationManager) RenewToken(req *logical.Request, source string, token string,
increment time.Duration) (*logical.Response, error) { increment time.Duration) (*logical.Response, error) {
defer metrics.MeasureSince([]string{"expire", "renew-token"}, time.Now()) defer metrics.MeasureSince([]string{"expire", "renew-token"}, time.Now())
// Compute the Lease ID // Compute the Lease ID
saltedID, err := m.tokenStore.SaltID(token) saltedID, err := m.tokenStore.SaltID(token)
if err != nil { if err != nil {
@ -800,8 +938,19 @@ func (m *ExpirationManager) updatePending(le *leaseEntry, leaseTotal time.Durati
// Check for an existing timer // Check for an existing timer
timer, ok := m.pending[le.LeaseID] timer, ok := m.pending[le.LeaseID]
// If there is no expiry time, don't do anything
if le.ExpireTime.IsZero() {
// if the timer happened to exist, stop the time and delete it from the
// pending timers.
if ok {
timer.Stop()
delete(m.pending, le.LeaseID)
}
return
}
// Create entry if it does not exist // Create entry if it does not exist
if !ok && leaseTotal > 0 { if !ok {
timer := time.AfterFunc(leaseTotal, func() { timer := time.AfterFunc(leaseTotal, func() {
m.expireID(le.LeaseID) m.expireID(le.LeaseID)
}) })
@ -809,17 +958,8 @@ func (m *ExpirationManager) updatePending(le *leaseEntry, leaseTotal time.Durati
return return
} }
// Delete the timer if the expiration time is zero
if ok && leaseTotal == 0 {
timer.Stop()
delete(m.pending, le.LeaseID)
return
}
// Extend the timer by the lease total // Extend the timer by the lease total
if ok && leaseTotal > 0 {
timer.Reset(leaseTotal) timer.Reset(leaseTotal)
}
} }
// expireID is invoked when a given ID is expired // expireID is invoked when a given ID is expired
@ -830,17 +970,23 @@ func (m *ExpirationManager) expireID(leaseID string) {
m.pendingLock.Unlock() m.pendingLock.Unlock()
for attempt := uint(0); attempt < maxRevokeAttempts; attempt++ { for attempt := uint(0); attempt < maxRevokeAttempts; attempt++ {
select {
case <-m.quitCh:
m.logger.Error("expiration: shutting down, not attempting further revocation of lease", "lease_id", leaseID)
return
default:
}
err := m.Revoke(leaseID) err := m.Revoke(leaseID)
if err == nil { if err == nil {
if m.logger.IsInfo() { if m.logger.IsInfo() {
m.logger.Info("expire: revoked lease", "lease_id", leaseID) m.logger.Info("expiration: revoked lease", "lease_id", leaseID)
} }
return return
} }
m.logger.Error("expire: failed to revoke lease", "lease_id", leaseID, "error", err) m.logger.Error("expiration: failed to revoke lease", "lease_id", leaseID, "error", err)
time.Sleep((1 << attempt) * revokeRetryBase) time.Sleep((1 << attempt) * revokeRetryBase)
} }
m.logger.Error("expire: maximum revoke attempts reached", "lease_id", leaseID) m.logger.Error("expiration: maximum revoke attempts reached", "lease_id", leaseID)
} }
// revokeEntry is used to attempt revocation of an internal entry // revokeEntry is used to attempt revocation of an internal entry
@ -902,6 +1048,24 @@ func (m *ExpirationManager) renewAuthEntry(req *logical.Request, le *leaseEntry,
// loadEntry is used to read a lease entry // loadEntry is used to read a lease entry
func (m *ExpirationManager) loadEntry(leaseID string) (*leaseEntry, error) { func (m *ExpirationManager) loadEntry(leaseID string) (*leaseEntry, error) {
// Take out the lease locks after we ensure we are in restore mode
restoreMode := m.inRestoreMode()
if restoreMode {
m.restoreModeLock.RLock()
defer m.restoreModeLock.RUnlock()
restoreMode = m.inRestoreMode()
if restoreMode {
m.lockLease(leaseID)
defer m.unlockLease(leaseID)
}
}
return m.loadEntryInternal(leaseID, restoreMode, true)
}
// loadEntryInternal is used when you need to load an entry but also need to
// control the lifecycle of the restoreLock
func (m *ExpirationManager) loadEntryInternal(leaseID string, restoreMode bool, checkRestored bool) (*leaseEntry, error) {
out, err := m.idView.Get(leaseID) out, err := m.idView.Get(leaseID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read lease entry: %v", err) return nil, fmt.Errorf("failed to read lease entry: %v", err)
@ -913,6 +1077,24 @@ func (m *ExpirationManager) loadEntry(leaseID string) (*leaseEntry, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode lease entry: %v", err) return nil, fmt.Errorf("failed to decode lease entry: %v", err)
} }
if restoreMode {
if checkRestored {
// If we have already loaded this lease, we don't need to update on
// load. In the case of renewal and revocation, updatePending will be
// done after making the appropriate modifications to the lease.
if _, ok := m.restoreLoaded.Load(leaseID); ok {
return le, nil
}
}
// Update the cache of restored leases, either synchronously or through
// the lazy loaded restore process
m.restoreLoaded.Store(le.LeaseID, struct{}{})
// Setup revocation timer
m.updatePending(le, le.ExpireTime.Sub(time.Now()))
}
return le, nil return le, nil
} }
@ -1035,9 +1217,9 @@ func (m *ExpirationManager) lookupByToken(token string) ([]string, error) {
// emitMetrics is invoked periodically to emit statistics // emitMetrics is invoked periodically to emit statistics
func (m *ExpirationManager) emitMetrics() { func (m *ExpirationManager) emitMetrics() {
m.pendingLock.Lock() m.pendingLock.RLock()
num := len(m.pending) num := len(m.pending)
m.pendingLock.Unlock() m.pendingLock.RUnlock()
metrics.SetGauge([]string{"expire", "num_leases"}, float32(num)) metrics.SetGauge([]string{"expire", "num_leases"}, float32(num))
} }

View File

@ -37,6 +37,9 @@ func TestExpiration_Tidy(t *testing.T) {
var err error var err error
exp := mockExpiration(t) exp := mockExpiration(t)
if err := exp.Restore(nil, 0); err != nil {
t.Fatal(err)
}
// Set up a count function to calculate number of leases // Set up a count function to calculate number of leases
count := 0 count := 0
@ -210,7 +213,7 @@ func TestExpiration_Tidy(t *testing.T) {
if !(err1 != nil && err1.Error() == "tidy operation on leases is already in progress") && if !(err1 != nil && err1.Error() == "tidy operation on leases is already in progress") &&
!(err2 != nil && err2.Error() == "tidy operation on leases is already in progress") { !(err2 != nil && err2.Error() == "tidy operation on leases is already in progress") {
t.Fatal("expected at least one of err1 or err2 to be set; err1: %#v\n err2:%#v\n", err1, err2) t.Fatalf("expected at least one of err1 or err2 to be set; err1: %#v\n err2:%#v\n", err1, err2)
} }
root, err := exp.tokenStore.rootToken() root, err := exp.tokenStore.rootToken()
@ -311,6 +314,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
req := &logical.Request{ req := &logical.Request{
Operation: logical.ReadOperation, Operation: logical.ReadOperation,
Path: "prod/aws/" + pathUUID, Path: "prod/aws/" + pathUUID,
ClientToken: "root",
} }
resp := &logical.Response{ resp := &logical.Response{
Secret: &logical.Secret{ Secret: &logical.Secret{
@ -337,7 +341,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend,
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
err = exp.Restore() err = exp.Restore(nil, 0)
// Restore // Restore
if err != nil { if err != nil {
b.Fatalf("err: %v", err) b.Fatalf("err: %v", err)
@ -395,7 +399,7 @@ func TestExpiration_Restore(t *testing.T) {
} }
// Restore // Restore
err = exp.Restore() err = exp.Restore(nil, 0)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }

View File

@ -267,17 +267,16 @@ func testTokenStore(t testing.T, c *Core) *TokenStore {
} }
ts := tokenstore.(*TokenStore) ts := tokenstore.(*TokenStore)
router := NewRouter() err = c.router.Unmount("auth/token/")
err = router.Mount(ts, "auth/token/", &MountEntry{Table: credentialTableType, UUID: "authtokenuuid", Path: "auth/token", Accessor: "authtokenaccessor"}, ts.view) if err != nil {
t.Fatal(err)
}
err = c.router.Mount(ts, "auth/token/", &MountEntry{Table: credentialTableType, UUID: "authtokenuuid", Path: "auth/token", Accessor: "authtokenaccessor"}, ts.view)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
subview := c.systemBarrierView.SubView(expirationSubPath) ts.SetExpirationManager(c.expiration)
logger := logformat.NewVaultLogger(log.LevelTrace)
exp := NewExpirationManager(router, subview, ts, logger)
ts.SetExpirationManager(exp)
return ts return ts
} }

View File

@ -894,9 +894,9 @@ func (ts *TokenStore) Lookup(id string) (*TokenEntry, error) {
// lookupSalted is used to find a token given its salted ID. If tainted is // lookupSalted is used to find a token given its salted ID. If tainted is
// true, entries that are in some revocation state (currently, indicated by num // true, entries that are in some revocation state (currently, indicated by num
// uses < 0), the entry will be returned anyways // uses < 0), the entry will be returned anyways
func (ts *TokenStore) lookupSalted(saltedId string, tainted bool) (*TokenEntry, error) { func (ts *TokenStore) lookupSalted(saltedID string, tainted bool) (*TokenEntry, error) {
// Lookup token // Lookup token
path := lookupPrefix + saltedId path := lookupPrefix + saltedID
raw, err := ts.view.Get(path) raw, err := ts.view.Get(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read entry: %v", err) return nil, fmt.Errorf("failed to read entry: %v", err)
@ -918,6 +918,16 @@ func (ts *TokenStore) lookupSalted(saltedId string, tainted bool) (*TokenEntry,
return nil, nil return nil, nil
} }
// If we are still restoring the expiration manager, we want to ensure the
// token is not expired
check, err := ts.expiration.RestoreSaltedTokenCheck(entry.Path, saltedID)
if err != nil {
return nil, fmt.Errorf("failed to check token in restore mode: %v", err)
}
if !check {
return nil, nil
}
persistNeeded := false persistNeeded := false
// Upgrade the deprecated fields // Upgrade the deprecated fields

View File

@ -3,6 +3,7 @@ package vault
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"path"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@ -11,6 +12,7 @@ import (
"time" "time"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/helper/locksutil"
"github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical"
) )
@ -448,6 +450,8 @@ func TestTokenStore_CreateLookup(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ts2.SetExpirationManager(c.expiration)
if err := ts2.Initialize(); err != nil { if err := ts2.Initialize(); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -493,6 +497,8 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
ts2.SetExpirationManager(c.expiration)
if err := ts2.Initialize(); err != nil { if err := ts2.Initialize(); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -507,6 +513,73 @@ func TestTokenStore_CreateLookup_ProvidedID(t *testing.T) {
} }
} }
func TestTokenStore_CreateLookup_ExpirationInRestoreMode(t *testing.T) {
_, ts, _, _ := TestCoreWithTokenStore(t)
ent := &TokenEntry{Path: "test", Policies: []string{"dev", "ops"}}
if err := ts.create(ent); err != nil {
t.Fatalf("err: %v", err)
}
if ent.ID == "" {
t.Fatalf("missing ID")
}
// Replace the lease with a lease with an expire time in the past
saltedID, err := ts.SaltID(ent.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
// Create a lease entry
leaseID := path.Join(ent.Path, saltedID)
le := &leaseEntry{
LeaseID: leaseID,
ClientToken: ent.ID,
Path: ent.Path,
IssueTime: time.Now(),
ExpireTime: time.Now().Add(1 * time.Hour),
}
if err := ts.expiration.persistEntry(le); err != nil {
t.Fatalf("err: %v", err)
}
out, err := ts.Lookup(ent.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if !reflect.DeepEqual(out, ent) {
t.Fatalf("bad: expected:%#v\nactual:%#v", ent, out)
}
// Set to expired lease time
le.ExpireTime = time.Now().Add(-1 * time.Hour)
if err := ts.expiration.persistEntry(le); err != nil {
t.Fatalf("err: %v", err)
}
err = ts.expiration.Stop()
if err != nil {
t.Fatal(err)
}
// Reset expiration manager to restore mode
ts.expiration.restoreModeLock.Lock()
ts.expiration.restoreMode = 1
ts.expiration.restoreLocks = locksutil.CreateLocks()
ts.expiration.quitCh = make(chan struct{})
ts.expiration.restoreModeLock.Unlock()
// Test that the token lookup does not return the token entry due to the
// expired lease
out, err = ts.Lookup(ent.ID)
if err != nil {
t.Fatalf("err: %v", err)
}
if out != nil {
t.Fatalf("lease expired, no token expected: %#v", out)
}
}
func TestTokenStore_UseToken(t *testing.T) { func TestTokenStore_UseToken(t *testing.T) {
_, ts, _, root := TestCoreWithTokenStore(t) _, ts, _, root := TestCoreWithTokenStore(t)
@ -2530,9 +2603,14 @@ func TestTokenStore_RoleExplicitMaxTTL(t *testing.T) {
t.Fatalf("expected error") t.Fatalf("expected error")
} }
time.Sleep(2 * time.Second)
req.Operation = logical.ReadOperation req.Operation = logical.ReadOperation
req.Path = "auth/token/lookup-self" req.Path = "auth/token/lookup-self"
resp, err = core.HandleRequest(req) resp, err = core.HandleRequest(req)
if resp != nil && err == nil {
t.Fatalf("expected error, response is %#v", *resp)
}
if err == nil { if err == nil {
t.Fatalf("expected error") t.Fatalf("expected error")
} }