diff --git a/builtin/credential/approle/backend.go b/builtin/credential/approle/backend.go index 0191144129..d57a3ec8a6 100644 --- a/builtin/credential/approle/backend.go +++ b/builtin/credential/approle/backend.go @@ -30,7 +30,7 @@ type backend struct { view logical.Storage // Guard to clean-up the expired SecretID entries - tidySecretIDCASGuard uint32 + tidySecretIDCASGuard *uint32 // Locks to make changes to role entries. These will be initialized to a // predefined number of locks when the backend is created, and will be @@ -85,6 +85,8 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { // Create locks to modify the generated SecretIDAccessors secretIDAccessorLocks: locksutil.CreateLocks(), + + tidySecretIDCASGuard: new(uint32), } // Attach the paths and secrets that are to be handled by the backend diff --git a/builtin/credential/approle/path_tidy_user_id.go b/builtin/credential/approle/path_tidy_user_id.go index 1a385efd2a..590cb7284d 100644 --- a/builtin/credential/approle/path_tidy_user_id.go +++ b/builtin/credential/approle/path_tidy_user_id.go @@ -27,9 +27,9 @@ func pathTidySecretID(b *backend) *framework.Path { // tidySecretID is used to delete entries in the whitelist that are expired. func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error { - grabbed := atomic.CompareAndSwapUint32(&b.tidySecretIDCASGuard, 0, 1) + grabbed := atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) if grabbed { - defer atomic.StoreUint32(&b.tidySecretIDCASGuard, 0) + defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0) } else { return fmt.Errorf("SecretID tidy operation already running") } diff --git a/builtin/credential/aws/backend.go b/builtin/credential/aws/backend.go index 1466257118..992b295987 100644 --- a/builtin/credential/aws/backend.go +++ b/builtin/credential/aws/backend.go @@ -39,8 +39,8 @@ type backend struct { blacklistMutex sync.RWMutex // Guards the blacklist/whitelist tidy functions - tidyBlacklistCASGuard uint32 - tidyWhitelistCASGuard uint32 + tidyBlacklistCASGuard *uint32 + tidyWhitelistCASGuard *uint32 // Duration after which the periodic function of the backend needs to // tidy the blacklist and whitelist entries. @@ -82,10 +82,12 @@ func Backend(conf *logical.BackendConfig) (*backend, error) { b := &backend{ // Setting the periodic func to be run once in an hour. // If there is a real need, this can be made configurable. - tidyCooldownPeriod: time.Hour, - EC2ClientsMap: make(map[string]map[string]*ec2.EC2), - IAMClientsMap: make(map[string]map[string]*iam.IAM), - iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour), + tidyCooldownPeriod: time.Hour, + EC2ClientsMap: make(map[string]map[string]*ec2.EC2), + IAMClientsMap: make(map[string]map[string]*iam.IAM), + iamUserIdToArnCache: cache.New(7*24*time.Hour, 24*time.Hour), + tidyBlacklistCASGuard: new(uint32), + tidyWhitelistCASGuard: new(uint32), } b.resolveArnToUniqueIDFunc = b.resolveArnToRealUniqueId diff --git a/builtin/credential/aws/path_tidy_identity_whitelist.go b/builtin/credential/aws/path_tidy_identity_whitelist.go index fa0e8d82da..f1abe23086 100644 --- a/builtin/credential/aws/path_tidy_identity_whitelist.go +++ b/builtin/credential/aws/path_tidy_identity_whitelist.go @@ -34,9 +34,9 @@ expiration, before it is removed from the backend storage.`, // tidyWhitelistIdentity is used to delete entries in the whitelist that are expired. func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error { - grabbed := atomic.CompareAndSwapUint32(&b.tidyWhitelistCASGuard, 0, 1) + grabbed := atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) if grabbed { - defer atomic.StoreUint32(&b.tidyWhitelistCASGuard, 0) + defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0) } else { return fmt.Errorf("identity whitelist tidy operation already running") } diff --git a/builtin/credential/aws/path_tidy_roletag_blacklist.go b/builtin/credential/aws/path_tidy_roletag_blacklist.go index dfb420653e..a29837110d 100644 --- a/builtin/credential/aws/path_tidy_roletag_blacklist.go +++ b/builtin/credential/aws/path_tidy_roletag_blacklist.go @@ -34,9 +34,9 @@ expiration, before it is removed from the backend storage.`, // tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist. func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error { - grabbed := atomic.CompareAndSwapUint32(&b.tidyBlacklistCASGuard, 0, 1) + grabbed := atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) if grabbed { - defer atomic.StoreUint32(&b.tidyBlacklistCASGuard, 0) + defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0) } else { return fmt.Errorf("roletag blacklist tidy operation already running") } diff --git a/http/forwarding_test.go b/http/forwarding_test.go index a55b4f66bd..bbf1c15442 100644 --- a/http/forwarding_test.go +++ b/http/forwarding_test.go @@ -191,24 +191,26 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) var key1ver int64 = 1 var key2ver int64 = 1 var key3ver int64 = 1 - var numWorkers uint64 = 50 - var numWorkersStarted uint64 + var numWorkers *uint32 = new(uint32) + *numWorkers = 50 + var numWorkersStarted *uint32 = new(uint32) var waitLock sync.Mutex waitCond := sync.NewCond(&waitLock) // This is the goroutine loop doFuzzy := func(id int, parallel bool) { - var myTotalOps uint64 - var mySuccessfulOps uint64 - var keyVer int64 = 1 + var myTotalOps *uint32 = new(uint32) + var mySuccessfulOps *uint32 = new(uint32) + var keyVer *int32 = new(int32) + *keyVer = 1 // Check for panics, otherwise notify we're done defer func() { if err := recover(); err != nil { core.Logger().Error("got a panic: %v", err) t.Fail() } - atomic.AddUint64(&totalOps, myTotalOps) - atomic.AddUint64(&successfulOps, mySuccessfulOps) + atomic.AddUint32(totalOps, myTotalOps) + atomic.AddUint32(successfulOps, mySuccessfulOps) wg.Done() }() @@ -281,10 +283,10 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) } } - atomic.AddUint64(&numWorkersStarted, 1) + atomic.AddUint32(numWorkersStarted, 1) waitCond.L.Lock() - for atomic.LoadUint64(&numWorkersStarted) != numWorkers { + for atomic.LoadUint32(numWorkersStarted) != atomic.LoadUint32(numWorkers) { waitCond.Wait() } waitCond.L.Unlock() @@ -375,11 +377,11 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) if parallel { switch chosenKey { case "test1": - atomic.AddInt64(&key1ver, 1) + atomic.AddInt32(key1ver, 1) case "test2": - atomic.AddInt64(&key2ver, 1) + atomic.AddInt32(key2ver, 1) case "test3": - atomic.AddInt64(&key3ver, 1) + atomic.AddInt32(key3ver, 1) } } else { keyVer++ @@ -393,11 +395,11 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) if parallel { switch chosenKey { case "test1": - latestVersion = atomic.LoadInt64(&key1ver) + latestVersion = atomic.LoadInt32(key1ver) case "test2": - latestVersion = atomic.LoadInt64(&key2ver) + latestVersion = atomic.LoadInt32(key2ver) case "test3": - latestVersion = atomic.LoadInt64(&key3ver) + latestVersion = atomic.LoadInt32(key3ver) } } @@ -415,10 +417,10 @@ func testHTTP_Forwarding_Stress_Common(t *testing.T, parallel bool, num uint64) } } - atomic.StoreUint64(&numWorkers, num) + atomic.StoreUint32(numWorkers, num) // Spawn some of these workers for 10 seconds - for i := 0; i < int(atomic.LoadUint64(&numWorkers)); i++ { + for i := 0; i < int(atomic.LoadUint32(numWorkers)); i++ { wg.Add(1) //core.Logger().Printf("[TRACE] spawning %d", i) go doFuzzy(i+1, parallel) diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index eb80889c1a..fa050ac60c 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -203,9 +203,9 @@ func TestBackendHandleRequest_renewAuth(t *testing.T) { } func TestBackendHandleRequest_renewAuthCallback(t *testing.T) { - var called uint32 + called := new(uint32) callback := func(context.Context, *logical.Request, *FieldData) (*logical.Response, error) { - atomic.AddUint32(&called, 1) + atomic.AddUint32(called, 1) return nil, nil } @@ -217,14 +217,14 @@ func TestBackendHandleRequest_renewAuthCallback(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if v := atomic.LoadUint32(&called); v != 1 { + if v := atomic.LoadUint32(called); v != 1 { t.Fatalf("bad: %#v", v) } } func TestBackendHandleRequest_renew(t *testing.T) { - var called uint32 + called := new(uint32) callback := func(context.Context, *logical.Request, *FieldData) (*logical.Response, error) { - atomic.AddUint32(&called, 1) + atomic.AddUint32(called, 1) return nil, nil } @@ -240,15 +240,15 @@ func TestBackendHandleRequest_renew(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if v := atomic.LoadUint32(&called); v != 1 { + if v := atomic.LoadUint32(called); v != 1 { t.Fatalf("bad: %#v", v) } } func TestBackendHandleRequest_revoke(t *testing.T) { - var called uint32 + called := new(uint32) callback := func(context.Context, *logical.Request, *FieldData) (*logical.Response, error) { - atomic.AddUint32(&called, 1) + atomic.AddUint32(called, 1) return nil, nil } @@ -264,16 +264,16 @@ func TestBackendHandleRequest_revoke(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if v := atomic.LoadUint32(&called); v != 1 { + if v := atomic.LoadUint32(called); v != 1 { t.Fatalf("bad: %#v", v) } } func TestBackendHandleRequest_rollback(t *testing.T) { - var called uint32 + called := new(uint32) callback := func(_ context.Context, req *logical.Request, kind string, data interface{}) error { if data == "foo" { - atomic.AddUint32(&called, 1) + atomic.AddUint32(called, 1) } return nil } @@ -298,16 +298,16 @@ func TestBackendHandleRequest_rollback(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if v := atomic.LoadUint32(&called); v != 1 { + if v := atomic.LoadUint32(called); v != 1 { t.Fatalf("bad: %#v", v) } } func TestBackendHandleRequest_rollbackMinAge(t *testing.T) { - var called uint32 + called := new(uint32) callback := func(_ context.Context, req *logical.Request, kind string, data interface{}) error { if data == "foo" { - atomic.AddUint32(&called, 1) + atomic.AddUint32(called, 1) } return nil } @@ -330,7 +330,7 @@ func TestBackendHandleRequest_rollbackMinAge(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if v := atomic.LoadUint32(&called); v != 0 { + if v := atomic.LoadUint32(called); v != 0 { t.Fatalf("bad: %#v", v) } } diff --git a/physical/consul/consul.go b/physical/consul/consul.go index 9f0beb3641..982fefca7c 100644 --- a/physical/consul/consul.go +++ b/physical/consul/consul.go @@ -636,9 +636,9 @@ func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ph // and end of a handler's life (or after a handler wakes up from // sleeping during a back-off/retry). var shutdown bool - var checkLock int64 var registeredServiceID string - var serviceRegLock int64 + checkLock := new(int32) + serviceRegLock := new(int32) for !shutdown { select { @@ -654,10 +654,10 @@ func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ph // Abort if service discovery is disabled or a // reconcile handler is already active - if !c.disableRegistration && atomic.CompareAndSwapInt64(&serviceRegLock, 0, 1) { + if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) { // Enter handler with serviceRegLock held go func() { - defer atomic.CompareAndSwapInt64(&serviceRegLock, 1, 0) + defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0) for !shutdown { serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc) if err != nil { @@ -680,10 +680,10 @@ func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh ph checkTimer.Reset(c.checkDuration()) // Abort if service discovery is disabled or a // reconcile handler is active - if !c.disableRegistration && atomic.CompareAndSwapInt64(&checkLock, 0, 1) { + if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) { // Enter handler with checkLock held go func() { - defer atomic.CompareAndSwapInt64(&checkLock, 1, 0) + defer atomic.CompareAndSwapInt32(checkLock, 1, 0) for !shutdown { sealed := sealedFunc() if err := c.runCheck(sealed); err != nil { diff --git a/physical/inmem/inmem.go b/physical/inmem/inmem.go index 139671ce6a..345c1feb9e 100644 --- a/physical/inmem/inmem.go +++ b/physical/inmem/inmem.go @@ -36,10 +36,10 @@ type InmemBackend struct { root *radix.Tree permitPool *physical.PermitPool logger log.Logger - failGet uint32 - failPut uint32 - failDelete uint32 - failList uint32 + failGet *uint32 + failPut *uint32 + failDelete *uint32 + failList *uint32 } type TransactionalInmemBackend struct { @@ -52,6 +52,10 @@ func NewInmem(_ map[string]string, logger log.Logger) (physical.Backend, error) root: radix.New(), permitPool: physical.NewPermitPool(physical.DefaultParallelOperations), logger: logger, + failGet: new(uint32), + failPut: new(uint32), + failDelete: new(uint32), + failList: new(uint32), } return in, nil } @@ -81,7 +85,7 @@ func (i *InmemBackend) Put(ctx context.Context, entry *physical.Entry) error { } func (i *InmemBackend) PutInternal(ctx context.Context, entry *physical.Entry) error { - if atomic.LoadUint32(&i.failPut) != 0 { + if atomic.LoadUint32(i.failPut) != 0 { return PutDisabledError } @@ -94,7 +98,7 @@ func (i *InmemBackend) FailPut(fail bool) { if fail { val = 1 } - atomic.StoreUint32(&i.failPut, val) + atomic.StoreUint32(i.failPut, val) } // Get is used to fetch an entry @@ -109,7 +113,7 @@ func (i *InmemBackend) Get(ctx context.Context, key string) (*physical.Entry, er } func (i *InmemBackend) GetInternal(ctx context.Context, key string) (*physical.Entry, error) { - if atomic.LoadUint32(&i.failGet) != 0 { + if atomic.LoadUint32(i.failGet) != 0 { return nil, GetDisabledError } @@ -127,7 +131,7 @@ func (i *InmemBackend) FailGet(fail bool) { if fail { val = 1 } - atomic.StoreUint32(&i.failGet, val) + atomic.StoreUint32(i.failGet, val) } // Delete is used to permanently delete an entry @@ -142,7 +146,7 @@ func (i *InmemBackend) Delete(ctx context.Context, key string) error { } func (i *InmemBackend) DeleteInternal(ctx context.Context, key string) error { - if atomic.LoadUint32(&i.failDelete) != 0 { + if atomic.LoadUint32(i.failDelete) != 0 { return DeleteDisabledError } @@ -155,7 +159,7 @@ func (i *InmemBackend) FailDelete(fail bool) { if fail { val = 1 } - atomic.StoreUint32(&i.failDelete, val) + atomic.StoreUint32(i.failDelete, val) } // List is used ot list all the keys under a given @@ -171,7 +175,7 @@ func (i *InmemBackend) List(ctx context.Context, prefix string) ([]string, error } func (i *InmemBackend) ListInternal(prefix string) ([]string, error) { - if atomic.LoadUint32(&i.failList) != 0 { + if atomic.LoadUint32(i.failList) != 0 { return nil, ListDisabledError } @@ -201,7 +205,7 @@ func (i *InmemBackend) FailList(fail bool) { if fail { val = 1 } - atomic.StoreUint32(&i.failList, val) + atomic.StoreUint32(i.failList, val) } // Implements the transaction interface diff --git a/vault/core.go b/vault/core.go index 76107f9023..0e38b7a6d2 100644 --- a/vault/core.go +++ b/vault/core.go @@ -191,7 +191,7 @@ type Core struct { standbyDoneCh chan struct{} standbyStopCh chan struct{} manualStepDownCh chan struct{} - keepHALockOnStepDown uint32 + keepHALockOnStepDown *uint32 heldHALock physical.Lock // unlockInfo has the keys provided to Unseal until the threshold number of parts is available, as well as the operation nonce @@ -500,6 +500,7 @@ func NewCore(conf *CoreConfig) (*Core, error) { localClusterCert: new(atomic.Value), localClusterParsedCert: new(atomic.Value), activeNodeReplicationState: new(uint32), + keepHALockOnStepDown: new(uint32), } atomic.StoreUint32(c.replicationState, uint32(consts.ReplicationDRDisabled|consts.ReplicationPerformanceDisabled)) @@ -1138,7 +1139,7 @@ func (c *Core) sealInternal(keepLock bool) error { } } else { if keepLock { - atomic.StoreUint32(&c.keepHALockOnStepDown, 1) + atomic.StoreUint32(c.keepHALockOnStepDown, 1) } // If we are trying to acquire the lock, force it to return with nil so // runStandby will exit @@ -1150,7 +1151,7 @@ func (c *Core) sealInternal(keepLock bool) error { // Wait for runStandby to stop <-c.standbyDoneCh - atomic.StoreUint32(&c.keepHALockOnStepDown, 0) + atomic.StoreUint32(c.keepHALockOnStepDown, 0) c.logger.Debug("runStandby done") } diff --git a/vault/cors.go b/vault/cors.go index 6b0920a73b..db2dd855b7 100644 --- a/vault/cors.go +++ b/vault/cors.go @@ -32,7 +32,7 @@ var StdAllowedHeaders = []string{ type CORSConfig struct { sync.RWMutex `json:"-"` core *Core - Enabled uint32 `json:"enabled"` + Enabled *uint32 `json:"enabled"` AllowedOrigins []string `json:"allowed_origins,omitempty"` AllowedHeaders []string `json:"allowed_headers,omitempty"` } @@ -40,8 +40,9 @@ type CORSConfig struct { func (c *Core) saveCORSConfig(ctx context.Context) error { view := c.systemBarrierView.SubView("config/") + enabled := atomic.LoadUint32(c.corsConfig.Enabled) localConfig := &CORSConfig{ - Enabled: atomic.LoadUint32(&c.corsConfig.Enabled), + Enabled: &enabled, } c.corsConfig.RLock() localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins @@ -109,19 +110,19 @@ func (c *CORSConfig) Enable(ctx context.Context, urls []string, headers []string } c.Unlock() - atomic.StoreUint32(&c.Enabled, CORSEnabled) + atomic.StoreUint32(c.Enabled, CORSEnabled) return c.core.saveCORSConfig(ctx) } // IsEnabled returns the value of CORSConfig.isEnabled func (c *CORSConfig) IsEnabled() bool { - return atomic.LoadUint32(&c.Enabled) == CORSEnabled + return atomic.LoadUint32(c.Enabled) == CORSEnabled } // Disable sets CORS to disabled and clears the allowed origins & headers. func (c *CORSConfig) Disable(ctx context.Context) error { - atomic.StoreUint32(&c.Enabled, CORSDisabled) + atomic.StoreUint32(c.Enabled, CORSDisabled) c.Lock() c.AllowedOrigins = nil diff --git a/vault/expiration.go b/vault/expiration.go index 00e922ba4c..29edadcfbb 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -66,9 +66,9 @@ type ExpirationManager struct { pending map[string]*time.Timer pendingLock sync.RWMutex - tidyLock int32 + tidyLock *int32 - restoreMode int32 + restoreMode *int32 restoreModeLock sync.RWMutex restoreRequestLock sync.RWMutex restoreLocks []*locksutil.LockEntry @@ -77,7 +77,7 @@ type ExpirationManager struct { coreStateLock *sync.RWMutex quitContext context.Context - leaseCheckCounter uint32 + leaseCheckCounter *uint32 logLeaseExpirations bool } @@ -92,19 +92,21 @@ func NewExpirationManager(c *Core, view *BarrierView, logger log.Logger) *Expira tokenStore: c.tokenStore, logger: logger, pending: make(map[string]*time.Timer), + tidyLock: new(int32), // new instances of the expiration manager will go immediately into // restore mode - restoreMode: 1, + restoreMode: new(int32), restoreLocks: locksutil.CreateLocks(), quitCh: make(chan struct{}), coreStateLock: &c.stateLock, quitContext: c.activeContext, - leaseCheckCounter: 0, + leaseCheckCounter: new(uint32), logLeaseExpirations: os.Getenv("VAULT_SKIP_LOGGING_LEASE_EXPIRATIONS") == "", } + *exp.restoreMode = 1 if exp.logger == nil { opts := log.LoggerOptions{Name: "expiration_manager"} @@ -168,7 +170,7 @@ func (m *ExpirationManager) unlockLease(leaseID string) { // inRestoreMode returns if we are currently in restore mode func (m *ExpirationManager) inRestoreMode() bool { - return atomic.LoadInt32(&m.restoreMode) == 1 + return atomic.LoadInt32(m.restoreMode) == 1 } // Tidy cleans up the dangling storage entries for leases. It scans the storage @@ -184,12 +186,12 @@ func (m *ExpirationManager) Tidy() error { var tidyErrors *multierror.Error - if !atomic.CompareAndSwapInt32(&m.tidyLock, 0, 1) { + if !atomic.CompareAndSwapInt32(m.tidyLock, 0, 1) { m.logger.Warn("tidy operation on leases is already in progress") return fmt.Errorf("tidy operation on leases is already in progress") } - defer atomic.CompareAndSwapInt32(&m.tidyLock, 1, 0) + defer atomic.CompareAndSwapInt32(m.tidyLock, 1, 0) m.logger.Info("beginning tidy operation on leases") defer m.logger.Info("finished tidy operation on leases") @@ -294,7 +296,7 @@ func (m *ExpirationManager) Restore(errorFunc func()) (retErr error) { // if restore mode finished successfully, restore mode was already // disabled with the lock. In an error state, this will allow the // Stop() function to shut everything down. - atomic.StoreInt32(&m.restoreMode, 0) + atomic.StoreInt32(m.restoreMode, 0) switch { case retErr == nil: @@ -409,7 +411,7 @@ func (m *ExpirationManager) Restore(errorFunc func()) (retErr error) { m.restoreModeLock.Lock() m.restoreLoaded = sync.Map{} m.restoreLocks = nil - atomic.StoreInt32(&m.restoreMode, 0) + atomic.StoreInt32(m.restoreMode, 0) m.restoreModeLock.Unlock() m.logger.Info("lease restore complete") @@ -1331,11 +1333,11 @@ func (m *ExpirationManager) emitMetrics() { metrics.SetGauge([]string{"expire", "num_leases"}, float32(num)) // Check if lease count is greater than the threshold if num > maxLeaseThreshold { - if atomic.LoadUint32(&m.leaseCheckCounter) > 59 { + if atomic.LoadUint32(m.leaseCheckCounter) > 59 { m.logger.Warn("lease count exceeds warning lease threshold") - atomic.StoreUint32(&m.leaseCheckCounter, 0) + atomic.StoreUint32(m.leaseCheckCounter, 0) } else { - atomic.AddUint32(&m.leaseCheckCounter, 1) + atomic.AddUint32(m.leaseCheckCounter, 1) } } } diff --git a/vault/ha.go b/vault/ha.go index 8f08982435..6850cb235a 100644 --- a/vault/ha.go +++ b/vault/ha.go @@ -29,6 +29,12 @@ func (c *Core) Standby() (bool, error) { // Leader is used to get the current active leader func (c *Core) Leader() (isLeader bool, leaderAddr, clusterAddr string, err error) { + // Check if HA enabled. We don't need the lock for this check as it's set + // on startup and never modified + if c.ha == nil { + return false, "", "", ErrHANotEnabled + } + c.stateLock.RLock() defer c.stateLock.RUnlock() @@ -37,11 +43,6 @@ func (c *Core) Leader() (isLeader bool, leaderAddr, clusterAddr string, err erro return false, "", "", consts.ErrSealed } - // Check if HA enabled - if c.ha == nil { - return false, "", "", ErrHANotEnabled - } - // Check if we are the leader if !c.standby { return true, c.redirectAddr, c.clusterAddr, nil @@ -419,7 +420,7 @@ func (c *Core) runStandby(doneCh, manualStepDownCh, stopCh chan struct{}) { case <-stopCh: // This case comes from sealInternal; we will already be having the // state lock held so we do toggle grabStateLock to false - if atomic.LoadUint32(&c.keepHALockOnStepDown) == 1 { + if atomic.LoadUint32(c.keepHALockOnStepDown) == 1 { releaseHALock = false } grabStateLock = false @@ -466,13 +467,13 @@ func (c *Core) runStandby(doneCh, manualStepDownCh, stopCh chan struct{}) { // the result. func (c *Core) periodicLeaderRefresh(doneCh, stopCh chan struct{}) { defer close(doneCh) - var opCount int32 + opCount := new(int32) for { select { case <-time.After(leaderCheckInterval): - count := atomic.AddInt32(&opCount, 1) + count := atomic.AddInt32(opCount, 1) if count > 1 { - atomic.AddInt32(&opCount, -1) + atomic.AddInt32(opCount, -1) continue } // We do this in a goroutine because otherwise if this refresh is @@ -480,7 +481,7 @@ func (c *Core) periodicLeaderRefresh(doneCh, stopCh chan struct{}) { // deadlock, which then means stopCh can never been seen and we can // block shutdown go func() { - defer atomic.AddInt32(&opCount, -1) + defer atomic.AddInt32(opCount, -1) c.Leader() }() case <-stopCh: @@ -492,18 +493,18 @@ func (c *Core) periodicLeaderRefresh(doneCh, stopCh chan struct{}) { // periodicCheckKeyUpgrade is used to watch for key rotation events as a standby func (c *Core) periodicCheckKeyUpgrade(ctx context.Context, doneCh, stopCh chan struct{}) { defer close(doneCh) - var opCount int32 + opCount := new(int32) for { select { case <-time.After(keyRotateCheckInterval): - count := atomic.AddInt32(&opCount, 1) + count := atomic.AddInt32(opCount, 1) if count > 1 { - atomic.AddInt32(&opCount, -1) + atomic.AddInt32(opCount, -1) continue } go func() { - defer atomic.AddInt32(&opCount, -1) + defer atomic.AddInt32(opCount, -1) // Only check if we are a standby c.stateLock.RLock() standby := c.standby diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index 472ed6cc40..c4c0e73cd9 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -79,7 +79,7 @@ func (c *Core) startForwarding(ctx context.Context) error { fws := &http2.Server{} // Shutdown coordination logic - var shutdown uint32 + shutdown := new(uint32) shutdownWg := &sync.WaitGroup{} for _, addr := range c.clusterListenerAddrs { @@ -120,7 +120,7 @@ func (c *Core) startForwarding(ctx context.Context) error { } for { - if atomic.LoadUint32(&shutdown) > 0 { + if atomic.LoadUint32(shutdown) > 0 { return } @@ -213,7 +213,7 @@ func (c *Core) startForwarding(ctx context.Context) error { // Set the shutdown flag. This will cause the listeners to shut down // within the deadline in clusterListenerAcceptDeadline - atomic.StoreUint32(&shutdown, 1) + atomic.StoreUint32(shutdown, 1) c.logger.Info("forwarding rpc listeners stopped") // Wait for them all to shut down diff --git a/vault/token_store.go b/vault/token_store.go index 4d50c31b7e..fb3f0e3ba4 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -130,7 +130,7 @@ type TokenStore struct { saltLock sync.RWMutex salt *salt.Salt - tidyLock int64 + tidyLock *int32 identityPoliciesDeriverFunc func(string) (*identity.Entity, []string, error) } @@ -150,6 +150,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi tokensPendingDeletion: &sync.Map{}, saltLock: sync.RWMutex{}, identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies, + tidyLock: new(int32), } if c.policyStore != nil { @@ -1284,12 +1285,12 @@ func (ts *TokenStore) lookupBySaltedAccessor(ctx context.Context, saltedAccessor func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { var tidyErrors *multierror.Error - if !atomic.CompareAndSwapInt64(&ts.tidyLock, 0, 1) { + if !atomic.CompareAndSwapInt32(ts.tidyLock, 0, 1) { ts.logger.Warn("tidy operation on tokens is already in progress") return nil, fmt.Errorf("tidy operation on tokens is already in progress") } - defer atomic.CompareAndSwapInt64(&ts.tidyLock, 1, 0) + defer atomic.CompareAndSwapInt32(ts.tidyLock, 1, 0) ts.logger.Info("beginning tidy operation on tokens") defer ts.logger.Info("finished tidy operation on tokens") diff --git a/vault/token_store_test.go b/vault/token_store_test.go index abb494675c..c20c36b0c0 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -733,7 +733,7 @@ func TestTokenStore_CreateLookup_ExpirationInRestoreMode(t *testing.T) { // Reset expiration manager to restore mode ts.expiration.restoreModeLock.Lock() - atomic.StoreInt32(&ts.expiration.restoreMode, 1) + atomic.StoreInt32(ts.expiration.restoreMode, 1) ts.expiration.restoreLocks = locksutil.CreateLocks() ts.expiration.restoreModeLock.Unlock()