mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-05 12:26:34 +02:00
Update to a RWMutex
This commit is contained in:
parent
cb13786f0a
commit
6b050470fd
@ -50,7 +50,7 @@ type databaseBackend struct {
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
sync.Mutex
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// resetAllDBs closes all connections from all database types
|
||||
@ -66,21 +66,23 @@ func (b *databaseBackend) closeAllDBs() {
|
||||
}
|
||||
|
||||
// This function is used to retrieve a database object either from the cached
|
||||
// connection map or by using the database config in storage. The caller of this
|
||||
// function needs to hold the backend's lock.
|
||||
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
// if the object already is built and cached, return it
|
||||
// connection map. The caller of this function needs to hold the backend's read
|
||||
// lock.
|
||||
func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
return db, nil
|
||||
}
|
||||
return db, ok
|
||||
}
|
||||
|
||||
// This function creates a new db object from the stored configuration and
|
||||
// caches it in the connections map. The caller of this function needs to hold
|
||||
// the backend's write lock
|
||||
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
|
||||
config, err := b.DatabaseConfig(s, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
|
||||
db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
|
||||
b.clearConnection(name)
|
||||
|
||||
// Execute plugin again, we don't need the object so throw away.
|
||||
_, err := b.getOrCreateDBObj(req.Storage, name)
|
||||
_, err := b.createDBObj(req.Storage, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -52,13 +52,23 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
|
||||
return nil, logical.ErrPermissionDenied
|
||||
}
|
||||
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.RLock()
|
||||
|
||||
// Get the Database object
|
||||
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
db, ok := b.getDBObj(role.DBName)
|
||||
if !ok {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
} else {
|
||||
defer b.RUnlock()
|
||||
}
|
||||
|
||||
expiration := time.Now().Add(role.DefaultTTL)
|
||||
|
||||
@ -48,13 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.RLock()
|
||||
|
||||
// Get our connection
|
||||
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during renew: %s", err)
|
||||
// Get the Database object
|
||||
db, ok := b.getDBObj(role.DBName)
|
||||
if !ok {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
} else {
|
||||
defer b.RUnlock()
|
||||
}
|
||||
|
||||
// Make sure we increase the VALID UNTIL endpoint for this user.
|
||||
@ -94,13 +104,23 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
|
||||
}
|
||||
|
||||
// Grab the read lock
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
b.RLock()
|
||||
|
||||
// Get our connection
|
||||
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during revoke: %s", err)
|
||||
db, ok := b.getDBObj(role.DBName)
|
||||
if !ok {
|
||||
// Upgrade lock
|
||||
b.RUnlock()
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
// Create a new DB object
|
||||
db, err = b.createDBObj(req.Storage, role.DBName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
|
||||
}
|
||||
} else {
|
||||
defer b.RUnlock()
|
||||
}
|
||||
|
||||
err = db.RevokeUser(role.Statements, username)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user