Update to a RWMutex

This commit is contained in:
Brian Kassouf 2017-04-26 15:23:14 -07:00
parent cb13786f0a
commit 6b050470fd
4 changed files with 58 additions and 26 deletions

View File

@ -50,7 +50,7 @@ type databaseBackend struct {
logger log.Logger
*framework.Backend
sync.Mutex
sync.RWMutex
}
// resetAllDBs closes all connections from all database types
@ -66,21 +66,23 @@ func (b *databaseBackend) closeAllDBs() {
}
// This function is used to retrieve a database object either from the cached
// connection map or by using the database config in storage. The caller of this
// function needs to hold the backend's lock.
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
// if the object already is built and cached, return it
// connection map. The caller of this function needs to hold the backend's read
// lock.
func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) {
db, ok := b.connections[name]
if ok {
return db, nil
}
return db, ok
}
// This function creates a new db object from the stored configuration and
// caches it in the connections map. The caller of this function needs to hold
// the backend's write lock
func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) {
config, err := b.DatabaseConfig(s, name)
if err != nil {
return nil, err
}
db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger)
if err != nil {
return nil, err
}

View File

@ -62,7 +62,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc {
b.clearConnection(name)
// Execute plugin again, we don't need the object so throw away.
_, err := b.getOrCreateDBObj(req.Storage, name)
_, err := b.createDBObj(req.Storage, name)
if err != nil {
return nil, err
}

View File

@ -52,13 +52,23 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc {
return nil, logical.ErrPermissionDenied
}
b.Lock()
defer b.Unlock()
b.RLock()
// Get the Database object
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
db, ok := b.getDBObj(role.DBName)
if !ok {
// Upgrade lock
b.RUnlock()
b.Lock()
defer b.Unlock()
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
} else {
defer b.RUnlock()
}
expiration := time.Now().Add(role.DefaultTTL)

View File

@ -48,13 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc {
}
// Grab the read lock
b.Lock()
defer b.Unlock()
b.RLock()
// Get our connection
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("error during renew: %s", err)
// Get the Database object
db, ok := b.getDBObj(role.DBName)
if !ok {
// Upgrade lock
b.RUnlock()
b.Lock()
defer b.Unlock()
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
} else {
defer b.RUnlock()
}
// Make sure we increase the VALID UNTIL endpoint for this user.
@ -94,13 +104,23 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc {
}
// Grab the read lock
b.Lock()
defer b.Unlock()
b.RLock()
// Get our connection
db, err := b.getOrCreateDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("error during revoke: %s", err)
db, ok := b.getDBObj(role.DBName)
if !ok {
// Upgrade lock
b.RUnlock()
b.Lock()
defer b.Unlock()
// Create a new DB object
db, err = b.createDBObj(req.Storage, role.DBName)
if err != nil {
return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err)
}
} else {
defer b.RUnlock()
}
err = db.RevokeUser(role.Statements, username)