diff --git a/api/client.go b/api/client.go index 6255bb4460..a77ceb863a 100644 --- a/api/client.go +++ b/api/client.go @@ -423,7 +423,7 @@ func NewClient(c *Config) (*Client, error) { } if namespace := os.Getenv(EnvVaultNamespace); namespace != "" { - client.SetNamespace(namespace) + client.setNamespace(namespace) } return client, nil @@ -535,7 +535,10 @@ func (c *Client) SetMFACreds(creds []string) { func (c *Client) SetNamespace(namespace string) { c.modifyLock.Lock() defer c.modifyLock.Unlock() + c.setNamespace(namespace) +} +func (c *Client) setNamespace(namespace string) { if c.headers == nil { c.headers = make(http.Header) } diff --git a/vault/barrier_aes_gcm.go b/vault/barrier_aes_gcm.go index ddddb42ca9..2ad4b77050 100644 --- a/vault/barrier_aes_gcm.go +++ b/vault/barrier_aes_gcm.go @@ -308,17 +308,26 @@ func (b *AESGCMBarrier) ReloadMasterKey(ctx context.Context) error { return nil } - defer memzero(out.Value) + // Grab write lock and refetch + b.l.Lock() + defer b.l.Unlock() + + out, err = b.lockSwitchedGet(ctx, masterKeyPath, false) + if err != nil { + return errwrap.Wrapf("failed to read master key path: {{err}}", err) + } + + if out == nil { + return nil + } // Deserialize the master key key, err := DeserializeKey(out.Value) + memzero(out.Value) if err != nil { return errwrap.Wrapf("failed to deserialize key: {{err}}", err) } - b.l.Lock() - defer b.l.Unlock() - // Check if the master key is the same if subtle.ConstantTimeCompare(b.keyring.MasterKey(), key.Value) == 1 { return nil @@ -499,8 +508,8 @@ func (b *AESGCMBarrier) Rotate(ctx context.Context) (uint32, error) { // CreateUpgrade creates an upgrade path key to the given term from the previous term func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error { b.l.RLock() - defer b.l.RUnlock() if b.sealed { + b.l.RUnlock() return ErrBarrierSealed } @@ -509,6 +518,7 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error { buf, err := termKey.Serialize() defer memzero(buf) if err != nil { + b.l.RUnlock() return err } @@ -516,11 +526,13 @@ func (b *AESGCMBarrier) CreateUpgrade(ctx context.Context, term uint32) error { prevTerm := term - 1 primary, err := b.aeadForTerm(prevTerm) if err != nil { + b.l.RUnlock() return err } key := fmt.Sprintf("%s%d", keyringUpgradePrefix, prevTerm) value, err := b.encrypt(key, prevTerm, primary, buf) + b.l.RUnlock() if err != nil { return err } @@ -541,8 +553,8 @@ func (b *AESGCMBarrier) DestroyUpgrade(ctx context.Context, term uint32) error { // CheckUpgrade looks for an upgrade to the current term and installs it func (b *AESGCMBarrier) CheckUpgrade(ctx context.Context) (bool, uint32, error) { b.l.RLock() - defer b.l.RUnlock() if b.sealed { + b.l.RUnlock() return false, 0, ErrBarrierSealed } @@ -551,30 +563,48 @@ func (b *AESGCMBarrier) CheckUpgrade(ctx context.Context) (bool, uint32, error) // Check for an upgrade key upgrade := fmt.Sprintf("%s%d", keyringUpgradePrefix, activeTerm) - entry, err := b.Get(ctx, upgrade) + entry, err := b.lockSwitchedGet(ctx, upgrade, false) if err != nil { + b.l.RUnlock() return false, 0, err } // Nothing to do if no upgrade if entry == nil { + b.l.RUnlock() return false, 0, nil } - defer memzero(entry.Value) - - // Deserialize the key - key, err := DeserializeKey(entry.Value) - if err != nil { - return false, 0, err - } - // Upgrade from read lock to write lock b.l.RUnlock() - defer b.l.RLock() b.l.Lock() defer b.l.Unlock() + // Validate base cases and refetch values again + + if b.sealed { + return false, 0, ErrBarrierSealed + } + + activeTerm = b.keyring.ActiveTerm() + + upgrade = fmt.Sprintf("%s%d", keyringUpgradePrefix, activeTerm) + entry, err = b.lockSwitchedGet(ctx, upgrade, false) + if err != nil { + return false, 0, err + } + + if entry == nil { + return false, 0, nil + } + + // Deserialize the key + key, err := DeserializeKey(entry.Value) + memzero(entry.Value) + if err != nil { + return false, 0, err + } + // Update the keyring newKeyring, err := b.keyring.AddKey(key) if err != nil { @@ -692,25 +722,39 @@ func (b *AESGCMBarrier) Put(ctx context.Context, entry *logical.StorageEntry) er // Get is used to fetch an entry func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*logical.StorageEntry, error) { + return b.lockSwitchedGet(ctx, key, true) +} + +func (b *AESGCMBarrier) lockSwitchedGet(ctx context.Context, key string, getLock bool) (*logical.StorageEntry, error) { defer metrics.MeasureSince([]string{"barrier", "get"}, time.Now()) - b.l.RLock() + if getLock { + b.l.RLock() + } if b.sealed { - b.l.RUnlock() + if getLock { + b.l.RUnlock() + } return nil, ErrBarrierSealed } // Read the key from the backend pe, err := b.backend.Get(ctx, key) if err != nil { - b.l.RUnlock() + if getLock { + b.l.RUnlock() + } return nil, err } else if pe == nil { - b.l.RUnlock() + if getLock { + b.l.RUnlock() + } return nil, nil } if len(pe.Value) < 4 { - b.l.RUnlock() + if getLock { + b.l.RUnlock() + } return nil, errors.New("invalid value") } @@ -721,7 +765,9 @@ func (b *AESGCMBarrier) Get(ctx context.Context, key string) (*logical.StorageEn // It is expensive to do this first but it is not a // normal case that this won't match gcm, err := b.aeadForTerm(term) - b.l.RUnlock() + if getLock { + b.l.RUnlock() + } if err != nil { return nil, err }