From 6b050470fdee424a8934383cd6905e484128e226 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 26 Apr 2017 15:23:14 -0700 Subject: [PATCH] Update to a RWMutex --- builtin/logical/database/backend.go | 20 +++++---- .../database/path_config_connection.go | 2 +- builtin/logical/database/path_creds_create.go | 20 ++++++--- builtin/logical/database/secret_creds.go | 42 ++++++++++++++----- 4 files changed, 58 insertions(+), 26 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index e8cf98ebbd..2aff47375d 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -50,7 +50,7 @@ type databaseBackend struct { logger log.Logger *framework.Backend - sync.Mutex + sync.RWMutex } // resetAllDBs closes all connections from all database types @@ -66,21 +66,23 @@ func (b *databaseBackend) closeAllDBs() { } // This function is used to retrieve a database object either from the cached -// connection map or by using the database config in storage. The caller of this -// function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbplugin.Database, error) { - // if the object already is built and cached, return it +// connection map. The caller of this function needs to hold the backend's read +// lock. +func (b *databaseBackend) getDBObj(name string) (dbplugin.Database, bool) { db, ok := b.connections[name] - if ok { - return db, nil - } + return db, ok +} +// This function creates a new db object from the stored configuration and +// caches it in the connections map. The caller of this function needs to hold +// the backend's write lock +func (b *databaseBackend) createDBObj(s logical.Storage, name string) (dbplugin.Database, error) { config, err := b.DatabaseConfig(s, name) if err != nil { return nil, err } - db, err = dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) + db, err := dbplugin.PluginFactory(config.PluginName, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index f52cfec59c..39eb3d0008 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -62,7 +62,7 @@ func (b *databaseBackend) pathConnectionReset() framework.OperationFunc { b.clearConnection(name) // Execute plugin again, we don't need the object so throw away. - _, err := b.getOrCreateDBObj(req.Storage, name) + _, err := b.createDBObj(req.Storage, name) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 9bbaceb54b..60f0c5e3ed 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -52,13 +52,23 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { return nil, logical.ErrPermissionDenied } - b.Lock() - defer b.Unlock() + b.RLock() // Get the Database object - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } expiration := time.Now().Add(role.DefaultTTL) diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 2704eb287c..690b41565e 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -48,13 +48,23 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { } // Grab the read lock - b.Lock() - defer b.Unlock() + b.RLock() - // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("error during renew: %s", err) + // Get the Database object + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } // Make sure we increase the VALID UNTIL endpoint for this user. @@ -94,13 +104,23 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { } // Grab the read lock - b.Lock() - defer b.Unlock() + b.RLock() // Get our connection - db, err := b.getOrCreateDBObj(req.Storage, role.DBName) - if err != nil { - return nil, fmt.Errorf("error during revoke: %s", err) + db, ok := b.getDBObj(role.DBName) + if !ok { + // Upgrade lock + b.RUnlock() + b.Lock() + defer b.Unlock() + + // Create a new DB object + db, err = b.createDBObj(req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } + } else { + defer b.RUnlock() } err = db.RevokeUser(role.Statements, username)