Token revocation refactor (#4512)

* Hand off lease expiration to expiration manager via timers

* Use sync.Map as the cache to track token deletion state (a sketch of this pattern follows the commit message)

* Add CreateOrFetchRevocationLeaseByToken to hand off token revocation to exp manager

* Update revoke and revoke-self handlers

* Fix tests

* revokeSalted: Move token entry deletion into the deferred func

* Fix test race

* Add blocking lease revocation test

* Remove test log

* Add HandlerFunc on NoopBackend, adjust locks, and add test

* Add sleep to allow for revocations to settle

* Various updates

* Rename some functions and variables to be more clear
* Change step-down and seal to use expmgr for revoke functionality like during request handling
* Attempt to WAL the token as being invalid as soon as possible so that further usage will fail even if revocation does not fully complete

* Address feedback

* Return invalid lease on negative TTL

* Revert "Return invalid lease on negative TTL"

This reverts commit a39597ecdc23cf7fc69fe003eef9f10d533551d8.

* Extend sleep on tests
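
The sync.Map mentioned above is the heart of the new deletion-state tracking: LoadOrStore lets exactly one caller claim a salted token ID, and a failed attempt leaves a retryable marker behind. Below is a minimal, self-contained Go sketch of that pattern under assumed names (the revoker type, its revoke method, and the work callback are illustrative stand-ins, not Vault's API); the real logic lives in revokeSalted in the token store diff further down.

package main

import (
    "errors"
    "fmt"
    "sync"
)

// revoker is a stand-in for the token store. Its pending map tracks salted
// token IDs: true means a deletion is in progress, false means a previous
// attempt failed and may be retried; no entry means nothing is pending.
type revoker struct {
    pending sync.Map
}

// revoke mirrors the guard used by revokeSalted: LoadOrStore atomically marks
// the ID as "deletion in progress" and reports whether another caller already
// holds that marker.
func (r *revoker) revoke(saltedID string, work func() error) error {
    state, loaded := r.pending.LoadOrStore(saltedID, true)
    if loaded && state == true {
        // A revocation for this ID is already in flight; short-circuit.
        return nil
    }

    if err := work(); err != nil {
        // Record the failure so the next call retries instead of short-circuiting.
        r.pending.Store(saltedID, false)
        return err
    }

    // Success: drop the marker entirely.
    r.pending.Delete(saltedID)
    return nil
}

func main() {
    r := &revoker{}

    // First attempt fails and leaves a false marker behind.
    if err := r.revoke("salted-id", func() error { return errors.New("storage error") }); err != nil {
        fmt.Println("first attempt:", err)
    }

    // Second attempt sees the false marker, retries, and clears it on success.
    if err := r.revoke("salted-id", func() error { fmt.Println("revoked"); return nil }); err != nil {
        fmt.Println("second attempt:", err)
    }
}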
Calvin Leung Huang 2018-05-10 15:50:02 -04:00 committed by GitHub
parent 2ef3635858
commit 0678d6ba4b
10 changed files with 359 additions and 138 deletions

View File

@@ -1429,10 +1429,13 @@ func (c *Core) sealInitCommon(ctx context.Context, req *logical.Request) (retErr
         return retErr
     }
-    if te != nil && te.NumUses == -1 {
+    if te != nil && te.NumUses == tokenRevocationPending {
         // Token needs to be revoked. We do this immediately here because
         // we won't have a token store after sealing.
-        err = c.tokenStore.Revoke(c.activeContext, te.ID)
+        leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te)
+        if err == nil {
+            err = c.expiration.Revoke(leaseID)
+        }
         if err != nil {
             c.logger.Error("token needed revocation before seal but failed to revoke", "error", err)
             retErr = multierror.Append(retErr, ErrInternalError)
@@ -1540,10 +1543,13 @@ func (c *Core) StepDown(req *logical.Request) (retErr error) {
         return retErr
     }
-    if te != nil && te.NumUses == -1 {
+    if te != nil && te.NumUses == tokenRevocationPending {
         // Token needs to be revoked. We do this immediately here because
         // we won't have a token store after sealing.
-        err = c.tokenStore.Revoke(c.activeContext, te.ID)
+        leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te)
+        if err == nil {
+            err = c.expiration.Revoke(leaseID)
+        }
         if err != nil {
             c.logger.Error("token needed revocation before step-down but failed to revoke", "error", err)
             retErr = multierror.Append(retErr, ErrInternalError)

View File

@@ -561,18 +561,34 @@ func (m *ExpirationManager) RevokeByToken(te *TokenEntry) error {
     defer metrics.MeasureSince([]string{"expire", "revoke-by-token"}, time.Now())
     // Lookup the leases
-    existing, err := m.lookupByToken(te.ID)
+    existing, err := m.lookupLeasesByToken(te.ID)
     if err != nil {
         return errwrap.Wrapf("failed to scan for leases: {{err}}", err)
     }
     // Revoke all the keys
-    for idx, leaseID := range existing {
-        if err := m.revokeCommon(leaseID, false, false); err != nil {
-            return errwrap.Wrapf(fmt.Sprintf("failed to revoke %q (%d / %d): {{err}}", leaseID, idx+1, len(existing)), err)
+    for _, leaseID := range existing {
+        // Load the entry
+        le, err := m.loadEntry(leaseID)
+        if err != nil {
+            return err
+        }
+
+        // If there's a lease, set expiration to now, persist, and call
+        // updatePending to hand off revocation to the expiration manager's pending
+        // timer map
+        if le != nil {
+            le.ExpireTime = time.Now()
+            if err := m.persistEntry(le); err != nil {
+                return err
+            }
+
+            m.updatePending(le, 0)
         }
     }
+    // te.Path should never be empty, but we check just in case
     if te.Path != "" {
         saltedID, err := m.tokenStore.SaltID(m.quitContext, te.ID)
         if err != nil {
@@ -1054,7 +1070,7 @@ func (m *ExpirationManager) revokeEntry(le *leaseEntry) error {
     // Revocation of login tokens is special since we can by-pass the
     // backend and directly interact with the token store
     if le.Auth != nil {
-        if err := m.tokenStore.RevokeTree(m.quitContext, le.ClientToken); err != nil {
+        if err := m.tokenStore.revokeTree(m.quitContext, le.ClientToken); err != nil {
             return errwrap.Wrapf("failed to revoke token: {{err}}", err)
         }
@@ -1247,8 +1263,58 @@ func (m *ExpirationManager) removeIndexByToken(token, leaseID string) error {
     return nil
 }
-// lookupByToken is used to lookup all the leaseID's via the
-func (m *ExpirationManager) lookupByToken(token string) ([]string, error) {
+// CreateOrFetchRevocationLeaseByToken is used to create or fetch the matching
+// leaseID for a particular token. The lease is set to expire immediately after
+// it's created.
+func (m *ExpirationManager) CreateOrFetchRevocationLeaseByToken(te *TokenEntry) (string, error) {
+    // Fetch the saltedID of the token and construct the leaseID
+    saltedID, err := m.tokenStore.SaltID(m.quitContext, te.ID)
+    if err != nil {
+        return "", err
+    }
+    leaseID := path.Join(te.Path, saltedID)
+
+    // Load the entry
+    le, err := m.loadEntry(leaseID)
+    if err != nil {
+        return "", err
+    }
+
+    // If there's no associated leaseEntry for the token, we create one
+    if le == nil {
+        auth := &logical.Auth{
+            ClientToken: te.ID,
+            LeaseOptions: logical.LeaseOptions{
+                TTL: time.Nanosecond,
+            },
+        }
+
+        if strings.Contains(te.Path, "..") {
+            return "", consts.ErrPathContainsParentReferences
+        }
+
+        // Create a lease entry
+        now := time.Now()
+        le = &leaseEntry{
+            LeaseID:     leaseID,
+            ClientToken: auth.ClientToken,
+            Auth:        auth,
+            Path:        te.Path,
+            IssueTime:   now,
+            ExpireTime:  now.Add(time.Nanosecond),
+        }
+
+        // Encode the entry
+        if err := m.persistEntry(le); err != nil {
+            return "", err
+        }
+    }
+
+    return le.LeaseID, nil
+}
+
+// lookupLeasesByToken is used to lookup all the leaseID's via the tokenID
+func (m *ExpirationManager) lookupLeasesByToken(token string) ([]string, error) {
     saltedID, err := m.tokenStore.SaltID(m.quitContext, token)
     if err != nil {
         return nil, err

View File

@@ -742,6 +742,108 @@ func TestExpiration_RevokeByToken(t *testing.T) {
         t.Fatalf("err: %v", err)
     }
+    time.Sleep(300 * time.Millisecond)
+
+    noop.Lock()
+    defer noop.Unlock()
+
+    if len(noop.Requests) != 3 {
+        t.Fatalf("Bad: %v", noop.Requests)
+    }
+    for _, req := range noop.Requests {
+        if req.Operation != logical.RevokeOperation {
+            t.Fatalf("Bad: %v", req)
+        }
+    }
+
+    expect := []string{
+        "foo",
+        "sub/bar",
+        "zip",
+    }
+    sort.Strings(noop.Paths)
+    sort.Strings(expect)
+    if !reflect.DeepEqual(noop.Paths, expect) {
+        t.Fatalf("bad: %v", noop.Paths)
+    }
+}
+
+func TestExpiration_RevokeByToken_Blocking(t *testing.T) {
+    exp := mockExpiration(t)
+    noop := &NoopBackend{}
+    // Request handle with a timeout context that simulates blocking lease revocation.
+    noop.RequestHandler = func(ctx context.Context, req *logical.Request) (*logical.Response, error) {
+        ctx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
+        defer cancel()
+        select {
+        case <-ctx.Done():
+            return noop.Response, nil
+        }
+    }
+
+    _, barrier, _ := mockBarrier(t)
+    view := NewBarrierView(barrier, "logical/")
+    meUUID, err := uuid.GenerateUUID()
+    if err != nil {
+        t.Fatal(err)
+    }
+    err = exp.router.Mount(noop, "prod/aws/", &MountEntry{Path: "prod/aws/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor"}, view)
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    paths := []string{
+        "prod/aws/foo",
+        "prod/aws/sub/bar",
+        "prod/aws/zip",
+    }
+    for _, path := range paths {
+        req := &logical.Request{
+            Operation:   logical.ReadOperation,
+            Path:        path,
+            ClientToken: "foobarbaz",
+        }
+        resp := &logical.Response{
+            Secret: &logical.Secret{
+                LeaseOptions: logical.LeaseOptions{
+                    TTL: 1 * time.Minute,
+                },
+            },
+            Data: map[string]interface{}{
+                "access_key": "xyz",
+                "secret_key": "abcd",
+            },
+        }
+        _, err := exp.Register(req, resp)
+        if err != nil {
+            t.Fatalf("err: %v", err)
+        }
+    }
+
+    // Should nuke all the keys
+    te := &TokenEntry{
+        ID: "foobarbaz",
+    }
+    if err := exp.RevokeByToken(te); err != nil {
+        t.Fatalf("err: %v", err)
+    }
+
+    // Lock and check that no requests has gone through yet
+    noop.Lock()
+    if len(noop.Requests) != 0 {
+        t.Fatalf("Bad: %v", noop.Requests)
+    }
+    noop.Unlock()
+
+    // Wait for a bit for timeouts to trigger and pending revocations to go
+    // through and then we relock
+    time.Sleep(300 * time.Millisecond)
+
+    noop.Lock()
+    defer noop.Unlock()
+
+    // Now make sure that all requests have gone through
     if len(noop.Requests) != 3 {
         t.Fatalf("Bad: %v", noop.Requests)
     }
@@ -1239,6 +1341,8 @@ func TestExpiration_revokeEntry_token(t *testing.T) {
         t.Fatalf("err: %v", err)
     }
+    time.Sleep(300 * time.Millisecond)
+
     out, err := exp.tokenStore.Lookup(context.Background(), le.ClientToken)
     if err != nil {
         t.Fatalf("err: %v", err)

View File

@@ -44,7 +44,7 @@ func (g generateStandardRootToken) generate(ctx context.Context, c *Core) (strin
     }
     cleanupFunc := func() {
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
     }
     return te.ID, cleanupFunc, nil

View File

@@ -3184,7 +3184,7 @@ func (b *SystemBackend) responseWrappingUnwrap(ctx context.Context, token string
             return "", errwrap.Wrapf("error decrementing wrapping token's use-count: {{err}}", err)
         }
-        defer b.Core.tokenStore.Revoke(ctx, token)
+        defer b.Core.tokenStore.revokeOrphan(ctx, token)
     }
     cubbyReq := &logical.Request{
@@ -3294,7 +3294,7 @@ func (b *SystemBackend) handleWrappingRewrap(ctx context.Context, req *logical.R
         if err != nil {
             return nil, errwrap.Wrapf("error decrementing wrapping token's use-count: {{err}}", err)
         }
-        defer b.Core.tokenStore.Revoke(ctx, token)
+        defer b.Core.tokenStore.revokeOrphan(ctx, token)
     }
     // Fetch the original TTL

View File

@@ -182,12 +182,15 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
             retErr = multierror.Append(retErr, logical.ErrPermissionDenied)
             return nil, nil, retErr
         }
-        if te.NumUses == -1 {
+        if te.NumUses == tokenRevocationPending {
             // We defer a revocation until after logic has run, since this is a
             // valid request (this is the token's final use). We pass the ID in
             // directly just to be safe in case something else modifies te later.
             defer func(id string) {
-                err = c.tokenStore.Revoke(ctx, id)
+                leaseID, err := c.expiration.CreateOrFetchRevocationLeaseByToken(te)
+                if err == nil {
+                    err = c.expiration.Revoke(leaseID)
+                }
                 if err != nil {
                     c.logger.Error("failed to revoke token", "error", err)
                     retResp = nil
@@ -398,7 +401,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp
     }
     if err := c.expiration.RegisterAuth(te.Path, resp.Auth); err != nil {
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to register token lease", "request_path", req.Path, "error", err)
         retErr = multierror.Append(retErr, ErrInternalError)
         return nil, auth, retErr
@@ -604,7 +607,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
         // Register with the expiration manager
         if err := c.expiration.RegisterAuth(te.Path, auth); err != nil {
-            c.tokenStore.Revoke(ctx, te.ID)
+            c.tokenStore.revokeOrphan(ctx, te.ID)
             c.logger.Error("failed to register token lease", "request_path", req.Path, "error", err)
             return nil, auth, ErrInternalError
         }

View File

@@ -14,6 +14,8 @@ import (
     "github.com/hashicorp/vault/logical"
 )
+type HandlerFunc func(context.Context, *logical.Request) (*logical.Response, error)
+
 type NoopBackend struct {
     sync.Mutex
@@ -22,12 +24,19 @@ type NoopBackend struct {
     Paths           []string
     Requests        []*logical.Request
     Response        *logical.Response
+    RequestHandler  HandlerFunc
     Invalidations   []string
     DefaultLeaseTTL time.Duration
     MaxLeaseTTL     time.Duration
 }
 func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) {
+    var err error
+    resp := n.Response
+    if n.RequestHandler != nil {
+        resp, err = n.RequestHandler(ctx, req)
+    }
+
     n.Lock()
     defer n.Unlock()
@@ -38,7 +47,7 @@ func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (
         return nil, fmt.Errorf("missing view")
     }
-    return n.Response, nil
+    return resp, err
 }
 func (n *NoopBackend) HandleExistenceCheck(ctx context.Context, req *logical.Request) (bool, bool, error) {

View File

@@ -51,19 +51,11 @@ const (
     // rolesPrefix is the prefix used to store role information
     rolesPrefix = "roles/"
-    // tokenRevocationDeferred indicates that the token should not be used
-    // again but is currently fulfilling its final use
-    tokenRevocationDeferred = -1
-
-    // tokenRevocationInProgress indicates that revocation of that token/its
-    // leases is ongoing
-    tokenRevocationInProgress = -2
-
-    // tokenRevocationFailed indicates that revocation failed; the entry is
-    // kept around so that when the tidy function is run it can be tried
-    // again (or when the revocation function is run again), but all other uses
-    // will report the token invalid
-    tokenRevocationFailed = -3
+    // tokenRevocationPending indicates that the token should not be used
+    // again. If this is encountered during an existing request flow, it means
+    // that the token is but is currently fulfilling its final use; after this
+    // request it will not be able to be looked up as being valid.
+    tokenRevocationPending = -1
 )
 var (
@@ -98,6 +90,12 @@ type TokenStore struct {
     tokenLocks []*locksutil.LockEntry
+    // tokenPendingDeletion stores tokens that are being revoked. If the token is
+    // not in the map, it means that there's no deletion in progress. If the value
+    // is true it means deletion is in progress, and if false it means deletion
+    // failed. Revocation needs to handle these states accordingly.
+    tokensPendingDeletion *sync.Map
+
     cubbyholeDestroyer func(context.Context, *TokenStore, string) error
     logger log.Logger
@@ -122,6 +120,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi
         cubbyholeDestroyer:          destroyCubbyhole,
         logger:                      logger,
         tokenLocks:                  locksutil.CreateLocks(),
+        tokensPendingDeletion:       &sync.Map{},
         saltLock:                    sync.RWMutex{},
         identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies,
     }
@@ -916,12 +915,12 @@ func (ts *TokenStore) UseToken(ctx context.Context, te *TokenEntry) (*TokenEntry
     // manager revoking children) attempting to acquire the same lock
     // repeatedly.
     if te.NumUses == 1 {
-        te.NumUses = -1
+        te.NumUses = tokenRevocationPending
     } else {
-        te.NumUses -= 1
+        te.NumUses--
     }
-    err = ts.storeCommon(ctx, te, false)
+    err = ts.store(ctx, te)
     if err != nil {
         return nil, err
     }
@@ -1052,7 +1051,7 @@ func (ts *TokenStore) lookupSalted(ctx context.Context, saltedID string, tainted
     // If fields are getting upgraded, store the changes
     if persistNeeded {
-        if err := ts.storeCommon(ctx, entry, false); err != nil {
+        if err := ts.store(ctx, entry); err != nil {
             return nil, errwrap.Wrapf("failed to persist token upgrade: {{err}}", err)
         }
     }
@@ -1062,7 +1061,7 @@ func (ts *TokenStore) lookupSalted(ctx context.Context, saltedID string, tainted
 // Revoke is used to invalidate a given token, any child tokens
 // will be orphaned.
-func (ts *TokenStore) Revoke(ctx context.Context, id string) error {
+func (ts *TokenStore) revokeOrphan(ctx context.Context, id string) error {
     defer metrics.MeasureSince([]string{"token", "revoke"}, time.Now())
     if id == "" {
         return fmt.Errorf("cannot revoke blank token")
@@ -1078,9 +1077,18 @@ func (ts *TokenStore) Revoke(ctx context.Context, id string) error {
 // revokeSalted is used to invalidate a given salted token,
 // any child tokens will be orphaned.
 func (ts *TokenStore) revokeSalted(ctx context.Context, saltedID string) (ret error) {
-    // Protect the entry lookup/writing with locks. The rub here is that we
-    // don't know the ID until we look it up once, so first we look it up, then
-    // do a locked lookup.
+    // Check and set the token deletion state. We only proceed with the deletion
+    // if we don't have a pending deletion (empty), or if the deletion previously
+    // failed (state is false)
+    state, loaded := ts.tokensPendingDeletion.LoadOrStore(saltedID, true)
+
+    // If the entry was loaded and its state is true, we short-circuit
+    if loaded && state == true {
+        return nil
+    }
+
+    // The map check above should protect use from any concurrent revocations, so
+    // doing a bare lookup here should be fine.
     entry, err := ts.lookupSalted(ctx, saltedID, true)
     if err != nil {
         return err
@@ -1089,61 +1097,36 @@ func (ts *TokenStore) revokeSalted(ctx context.Context, saltedID string) (ret er
         return nil
     }
-    lock := locksutil.LockForKey(ts.tokenLocks, entry.ID)
-    lock.Lock()
-
-    // Lookup the token first
-    entry, err = ts.lookupSalted(ctx, saltedID, true)
-    if err != nil {
-        lock.Unlock()
-        return err
-    }
-
-    if entry == nil {
-        lock.Unlock()
-        return nil
-    }
-
-    // On failure we write -3, so if we hit -2 here we're already running a
-    // revocation operation. This can happen due to e.g. recursion into this
-    // function via the expiration manager's RevokeByToken.
-    if entry.NumUses == tokenRevocationInProgress {
-        lock.Unlock()
-        return nil
-    }
-
-    // This acts as a WAL. lookupSalted will no longer return this entry,
-    // so the token cannot be used, but this way we can keep the entry
-    // around until after the rest of this function is attempted, and a
-    // tidy function can key off of this value to try again.
-    entry.NumUses = tokenRevocationInProgress
-    err = ts.storeCommon(ctx, entry, false)
-    lock.Unlock()
-    if err != nil {
-        return err
-    }
-
-    // If we are returning an error, mark the entry with -3 to indicate
-    // failed revocation. This way we don't try to clean up during active
-    // revocation (-2).
+    if entry.NumUses != tokenRevocationPending {
+        entry.NumUses = tokenRevocationPending
+        if err := ts.store(ctx, entry); err != nil {
+            // The only real reason for this is an underlying storage error
+            // which also means that nothing else in this func or expmgr will
+            // really work either. So we clear revocation state so the user can
+            // try again.
+            ts.logger.Error("failed to mark token as revoked")
+            ts.tokensPendingDeletion.Store(saltedID, false)
+            return err
+        }
+    }
+
     defer func() {
+        // If we succeeded in all other revocation operations after this defer and
+        // before we return, we can remove the token store entry
+        if ret == nil {
+            path := lookupPrefix + saltedID
+            if err := ts.view.Delete(ctx, path); err != nil {
+                ret = errwrap.Wrapf("failed to delete entry: {{err}}", err)
+            }
+        }
+
+        // Check on ret again and update the sync.Map accordingly
         if ret != nil {
-            lock.Lock()
-            defer lock.Unlock()
-
-            // Lookup the token again to make sure something else didn't
-            // revoke in the interim
-            entry, err := ts.lookupSalted(ctx, saltedID, true)
-            if err != nil {
-                return
-            }
-
-            // If it exists just taint to -3 rather than trying to figure
-            // out what it means if it's already -3 after the -2 above
-            if entry != nil {
-                entry.NumUses = tokenRevocationFailed
-                ts.storeCommon(ctx, entry, false)
-            }
+            // If we failed on any of the calls within, we store the state as false
+            // so that the next call to revokeSalted will retry
+            ts.tokensPendingDeletion.Store(saltedID, false)
+        } else {
+            ts.tokensPendingDeletion.Delete(saltedID)
         }
     }()
@@ -1219,18 +1202,12 @@ func (ts *TokenStore) revokeSalted(ctx context.Context, saltedID string) (ret er
         return errwrap.Wrapf("failed to delete entry: {{err}}", err)
     }
-    // Now that the entry is not usable for any revocation tasks, nuke it
-    path := lookupPrefix + saltedID
-    if err = ts.view.Delete(ctx, path); err != nil {
-        return errwrap.Wrapf("failed to delete entry: {{err}}", err)
-    }
-
     return nil
 }
-// RevokeTree is used to invalidate a given token and all
+// revokeTree is used to invalidate a given token and all
 // child tokens.
-func (ts *TokenStore) RevokeTree(ctx context.Context, id string) error {
+func (ts *TokenStore) revokeTree(ctx context.Context, id string) error {
     defer metrics.MeasureSince([]string{"token", "revoke-tree"}, time.Now())
     // Verify the token is not blank
     if id == "" {
@@ -1265,6 +1242,11 @@ func (ts *TokenStore) revokeTreeSalted(ctx context.Context, saltedID string) err
     // If the length of the children array is zero,
     // then we are at a leaf node.
     if len(children) == 0 {
+        // Whenever revokeSalted is called, the token will be removed immediately and
+        // any underlying secrets will be handed off to the expiration manager which will
+        // take care of expiring them. If Vault is restarted, any revoked tokens
+        // would have been deleted, and any pending leases for deletion will be restored
+        // by the expiration manager.
         if err := ts.revokeSalted(ctx, id); err != nil {
             return errwrap.Wrapf("failed to revoke entry: {{err}}", err)
         }
@@ -1617,9 +1599,23 @@ func (ts *TokenStore) handleUpdateRevokeAccessor(ctx context.Context, req *logic
         return nil, err
     }
-    // Revoke the token and its children
-    if err := ts.RevokeTree(ctx, aEntry.TokenID); err != nil {
-        return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
+    te, err := ts.Lookup(ctx, aEntry.TokenID)
+    if err != nil {
+        return nil, err
+    }
+
+    if te == nil {
+        return logical.ErrorResponse("token not found"), logical.ErrInvalidRequest
+    }
+
+    leaseID, err := ts.expiration.CreateOrFetchRevocationLeaseByToken(te)
+    if err != nil {
+        return nil, err
+    }
+
+    err = ts.expiration.Revoke(leaseID)
+    if err != nil {
+        return nil, err
     }
     if urlaccessor {
@@ -2054,10 +2050,25 @@ func (ts *TokenStore) handleCreateCommon(ctx context.Context, req *logical.Reque
 // in a way that revokes all child tokens. Normally, using sys/revoke/leaseID will revoke
 // the token and all children anyways, but that is only available when there is a lease.
 func (ts *TokenStore) handleRevokeSelf(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
-    // Revoke the token and its children
-    if err := ts.RevokeTree(ctx, req.ClientToken); err != nil {
-        return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
+    te, err := ts.Lookup(ctx, req.ClientToken)
+    if err != nil {
+        return nil, err
     }
+
+    if te == nil {
+        return logical.ErrorResponse("token not found"), logical.ErrInvalidRequest
+    }
+
+    leaseID, err := ts.expiration.CreateOrFetchRevocationLeaseByToken(te)
+    if err != nil {
+        return nil, err
+    }
+
+    err = ts.expiration.Revoke(leaseID)
+    if err != nil {
+        return nil, err
+    }
+
     return nil, nil
 }
@@ -2075,9 +2086,23 @@ func (ts *TokenStore) handleRevokeTree(ctx context.Context, req *logical.Request
         urltoken = true
     }
-    // Revoke the token and its children
-    if err := ts.RevokeTree(ctx, id); err != nil {
-        return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
+    te, err := ts.Lookup(ctx, id)
+    if err != nil {
+        return nil, err
+    }
+
+    if te == nil {
+        return logical.ErrorResponse("token not found"), logical.ErrInvalidRequest
+    }
+
+    leaseID, err := ts.expiration.CreateOrFetchRevocationLeaseByToken(te)
+    if err != nil {
+        return nil, err
+    }
+
+    err = ts.expiration.Revoke(leaseID)
+    if err != nil {
+        return nil, err
     }
     if urltoken {
@@ -2121,7 +2146,7 @@ func (ts *TokenStore) handleRevokeOrphan(ctx context.Context, req *logical.Reque
     }
     // Revoke and orphan
-    if err := ts.Revoke(ctx, id); err != nil {
+    if err := ts.revokeOrphan(ctx, id); err != nil {
         return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
     }

View File

@@ -399,6 +399,8 @@ func TestTokenStore_HandleRequest_RevokeAccessor(t *testing.T) {
         t.Fatalf("err: %s", err)
     }
+    time.Sleep(300 * time.Millisecond)
+
     out, err = ts.Lookup(context.Background(), "tokenid")
     if err != nil {
         t.Fatalf("err: %s", err)
@@ -636,10 +638,10 @@ func TestTokenStore_UseToken(t *testing.T) {
     if te == nil {
         t.Fatalf("token entry for use #2 was nil")
     }
-    if te.NumUses != tokenRevocationDeferred {
+    if te.NumUses != tokenRevocationPending {
        t.Fatalf("token entry after use #2 did not have revoke flag")
     }
-    ts.Revoke(context.Background(), te.ID)
+    ts.revokeOrphan(context.Background(), te.ID)
     // Lookup the token
     ent2, err = ts.Lookup(context.Background(), ent.ID)
@@ -661,11 +663,11 @@ func TestTokenStore_Revoke(t *testing.T) {
         t.Fatalf("err: %v", err)
     }
-    err := ts.Revoke(context.Background(), "")
+    err := ts.revokeOrphan(context.Background(), "")
     if err.Error() != "cannot revoke blank token" {
         t.Fatalf("err: %v", err)
     }
-    err = ts.Revoke(context.Background(), ent.ID)
+    err = ts.revokeOrphan(context.Background(), ent.ID)
     if err != nil {
         t.Fatalf("err: %v", err)
     }
@@ -719,11 +721,13 @@ func TestTokenStore_Revoke_Leases(t *testing.T) {
     }
     // Revoke the token
-    err = ts.Revoke(context.Background(), ent.ID)
+    err = ts.revokeOrphan(context.Background(), ent.ID)
     if err != nil {
         t.Fatalf("err: %v", err)
     }
+    time.Sleep(300 * time.Millisecond)
+
     // Verify the lease is gone
     out, err := ts.expiration.loadEntry(leaseID)
     if err != nil {
@@ -747,7 +751,7 @@ func TestTokenStore_Revoke_Orphan(t *testing.T) {
         t.Fatalf("err: %v", err)
     }
-    err := ts.Revoke(context.Background(), ent.ID)
+    err := ts.revokeOrphan(context.Background(), ent.ID)
     if err != nil {
         t.Fatalf("err: %v", err)
     }
@@ -778,14 +782,14 @@ func TestTokenStore_RevokeTree(t *testing.T) {
 func testTokenStore_RevokeTree_NonRecursive(t testing.TB, depth uint64) {
     _, ts, _, _ := TestCoreWithTokenStore(t)
     root, children := buildTokenTree(t, ts, depth)
-    err := ts.RevokeTree(context.Background(), "")
+    err := ts.revokeTree(context.Background(), "")
     if err.Error() != "cannot tree-revoke blank token" {
         t.Fatalf("err: %v", err)
     }
     // Nuke tree non recursively.
-    err = ts.RevokeTree(context.Background(), root.ID)
+    err = ts.revokeTree(context.Background(), root.ID)
     if err != nil {
         t.Fatalf("err: %v", err)
@@ -881,6 +885,8 @@ func TestTokenStore_RevokeSelf(t *testing.T) {
         t.Fatalf("err: %v\nresp: %#v", err, resp)
     }
+    time.Sleep(300 * time.Millisecond)
+
     lookup := []string{ent1.ID, ent2.ID, ent3.ID, ent4.ID}
     for _, id := range lookup {
         out, err := ts.Lookup(context.Background(), id)
@@ -1377,6 +1383,8 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) {
         t.Fatalf("bad: %#v", resp)
     }
+    time.Sleep(300 * time.Millisecond)
+
     out, err := ts.Lookup(context.Background(), "child")
     if err != nil {
         t.Fatalf("err: %v", err)
@@ -1413,6 +1421,8 @@ func TestTokenStore_HandleRequest_RevokeOrphan(t *testing.T) {
         t.Fatalf("bad: %#v", resp)
     }
+    time.Sleep(300 * time.Millisecond)
+
     out, err := ts.Lookup(context.Background(), "child")
     if err != nil {
         t.Fatalf("err: %v", err)
@@ -1466,6 +1476,8 @@ func TestTokenStore_HandleRequest_RevokeOrphan_NonRoot(t *testing.T) {
         t.Fatalf("did not get error when non-root revoking itself with orphan flag; resp is %#v", resp)
     }
+    time.Sleep(300 * time.Millisecond)
+
     // Should still exist
     out, err = ts.Lookup(context.Background(), "child")
     if err != nil {
@@ -3323,7 +3335,7 @@ func TestTokenStore_RevokeUseCountToken(t *testing.T) {
     if te == nil {
         t.Fatal("nil entry")
     }
-    if te.NumUses != tokenRevocationDeferred {
+    if te.NumUses != tokenRevocationPending {
         t.Fatalf("bad: %d", te.NumUses)
     }
@@ -3361,7 +3373,7 @@ func TestTokenStore_RevokeUseCountToken(t *testing.T) {
     if te == nil {
         t.Fatal("nil entry")
     }
-    if te.NumUses != tokenRevocationDeferred {
+    if te.NumUses != tokenRevocationPending {
         t.Fatalf("bad: %d", te.NumUses)
     }
@@ -3376,16 +3388,13 @@ func TestTokenStore_RevokeUseCountToken(t *testing.T) {
         t.Fatalf("expected err")
     }
-    // Since revocation failed we should see the tokenRevocationFailed canary value
+    // Since revocation failed we should still be able to get a token
     te, err = ts.lookupSalted(context.Background(), saltTut, true)
     if err != nil {
         t.Fatal(err)
     }
     if te == nil {
-        t.Fatal("nil entry")
-    }
-    if te.NumUses != tokenRevocationFailed {
-        t.Fatalf("bad: %d", te.NumUses)
+        t.Fatal("nil token entry")
     }
     // Check the race condition situation by making the process sleep
@@ -3411,10 +3420,7 @@ func TestTokenStore_RevokeUseCountToken(t *testing.T) {
         t.Fatal(err)
     }
     if te == nil {
-        t.Fatal("nil entry")
-    }
-    if te.NumUses != tokenRevocationInProgress {
-        t.Fatalf("bad: %d", te.NumUses)
+        t.Fatal("nil token entry")
     }
     // Let things catch up
@@ -3791,7 +3797,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) {
     sort.Strings(leases)
-    storedLeases, err := exp.lookupByToken(tut)
+    storedLeases, err := exp.lookupLeasesByToken(tut)
     if err != nil {
         t.Fatal(err)
     }
@@ -3827,7 +3833,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) {
     }
     // Verify leases still exist
-    storedLeases, err = exp.lookupByToken(tut)
+    storedLeases, err = exp.lookupLeasesByToken(tut)
     if err != nil {
         t.Fatal(err)
     }
@@ -3839,8 +3845,10 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) {
     // Call tidy
     ts.handleTidy(context.Background(), nil, nil)
+    time.Sleep(300 * time.Millisecond)
+
     // Verify leases are gone
-    storedLeases, err = exp.lookupByToken(tut)
+    storedLeases, err = exp.lookupLeasesByToken(tut)
     if err != nil {
         t.Fatal(err)
     }

View File

@@ -159,7 +159,7 @@ DONELISTHANDLING:
     jwt := jws.NewJWT(claims, crypto.SigningMethodES512)
     serWebToken, err := jwt.Serialize(c.wrappingJWTKey)
     if err != nil {
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to serialize JWT", "error", err)
         return nil, ErrInternalError
     }
@@ -200,7 +200,7 @@ DONELISTHANDLING:
     marshaledResponse, err := json.Marshal(httpResponse)
     if err != nil {
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to marshal wrapped response", "error", err)
         return nil, ErrInternalError
     }
@@ -213,12 +213,12 @@ DONELISTHANDLING:
     cubbyResp, err := c.router.Route(ctx, cubbyReq)
     if err != nil {
         // Revoke since it's not yet being tracked for expiration
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to store wrapped response information", "error", err)
         return nil, ErrInternalError
     }
     if cubbyResp != nil && cubbyResp.IsError() {
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to store wrapped response information", "error", cubbyResp.Data["error"])
         return cubbyResp, nil
     }
@@ -239,12 +239,12 @@ DONELISTHANDLING:
     cubbyResp, err = c.router.Route(ctx, cubbyReq)
     if err != nil {
         // Revoke since it's not yet being tracked for expiration
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to store wrapping information", "error", err)
         return nil, ErrInternalError
     }
     if cubbyResp != nil && cubbyResp.IsError() {
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to store wrapping information", "error", cubbyResp.Data["error"])
         return cubbyResp, nil
     }
@@ -261,7 +261,7 @@ DONELISTHANDLING:
     // Register the wrapped token with the expiration manager
     if err := c.expiration.RegisterAuth(te.Path, wAuth); err != nil {
         // Revoke since it's not yet being tracked for expiration
-        c.tokenStore.Revoke(ctx, te.ID)
+        c.tokenStore.revokeOrphan(ctx, te.ID)
         c.logger.Error("failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err)
         return nil, ErrInternalError
     }