From cf48236a3c864d4b3cccd3ef7de21594f76a1a53 Mon Sep 17 00:00:00 2001 From: Christopher Swenson Date: Thu, 22 Jun 2023 10:20:13 -0700 Subject: [PATCH] Move database connections map out to separate package (#21207) The upcoming event main plugin will use a very similar pattern as the database plugin map, so it makes sense to refactor this and move this map out. It also cleans up the database plugin backend so that it does not have to keep track of the lock. Co-authored-by: Tom Proctor --- .../scripts/generate-test-package-lists.sh | 1 + builtin/logical/database/backend.go | 78 ++++------------- .../database/path_config_connection.go | 2 +- builtin/logical/database/rotation_test.go | 2 +- helper/syncmap/syncmap.go | 86 +++++++++++++++++++ helper/syncmap/syncmap_test.go | 75 ++++++++++++++++ 6 files changed, 179 insertions(+), 65 deletions(-) create mode 100644 helper/syncmap/syncmap.go create mode 100644 helper/syncmap/syncmap_test.go diff --git a/.github/scripts/generate-test-package-lists.sh b/.github/scripts/generate-test-package-lists.sh index b71b1d72ea..6d5e41ada5 100755 --- a/.github/scripts/generate-test-package-lists.sh +++ b/.github/scripts/generate-test-package-lists.sh @@ -120,6 +120,7 @@ test_packages[6]+=" $base/helper/namespace" test_packages[6]+=" $base/helper/osutil" test_packages[6]+=" $base/helper/parseip" test_packages[6]+=" $base/helper/policies" +test_packages[6]+=" $base/helper/syncmap" test_packages[6]+=" $base/helper/testhelpers/logical" test_packages[6]+=" $base/helper/timeutil" test_packages[6]+=" $base/helper/useragent" diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 94091e2019..f4e5ef31bd 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/go-secure-stdlib/strutil" "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/helper/metricsutil" + "github.com/hashicorp/vault/helper/syncmap" "github.com/hashicorp/vault/internalshared/configutil" v4 "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" @@ -43,6 +44,10 @@ type dbPluginInstance struct { closed bool } +func (dbi *dbPluginInstance) ID() string { + return dbi.id +} + func (dbi *dbPluginInstance) Close() error { dbi.Lock() defer dbi.Unlock() @@ -119,7 +124,7 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]*dbPluginInstance) + b.connections = syncmap.NewSyncMap[string, *dbPluginInstance]() b.queueCtx, b.cancelQueueCtx = context.WithCancel(context.Background()) b.roleLocks = locksutil.CreateLocks() return &b @@ -127,17 +132,9 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]metricsutil.GaugeLabelValues, error) { // copy the map so we can release the lock - connMapCopy := func() map[string]*dbPluginInstance { - b.connLock.RLock() - defer b.connLock.RUnlock() - mapCopy := map[string]*dbPluginInstance{} - for k, v := range b.connections { - mapCopy[k] = v - } - return mapCopy - }() + connectionsCopy := b.connections.Values() counts := map[string]int{} - for _, v := range connMapCopy { + for _, v := range connectionsCopy { dbType, err := v.database.Type() if err != nil { // there's a chance this will already be closed since we don't hold the lock @@ -156,10 +153,8 @@ func (b *databaseBackend) collectPluginInstanceGaugeValues(context.Context) ([]m } type databaseBackend struct { - // connLock is used to synchronize access to the connections map - connLock sync.RWMutex // connections holds configured database connections by config name - connections map[string]*dbPluginInstance + connections *syncmap.SyncMap[string, *dbPluginInstance] logger log.Logger *framework.Backend @@ -183,49 +178,6 @@ type databaseBackend struct { gaugeCollectionProcessStop sync.Once } -func (b *databaseBackend) connGet(name string) *dbPluginInstance { - b.connLock.RLock() - defer b.connLock.RUnlock() - return b.connections[name] -} - -func (b *databaseBackend) connPop(name string) *dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - dbi, ok := b.connections[name] - if ok { - delete(b.connections, name) - } - return dbi -} - -func (b *databaseBackend) connPopIfEqual(name, id string) *dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - dbi, ok := b.connections[name] - if ok && dbi.id == id { - delete(b.connections, name) - return dbi - } - return nil -} - -func (b *databaseBackend) connPut(name string, newDbi *dbPluginInstance) *dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - dbi := b.connections[name] - b.connections[name] = newDbi - return dbi -} - -func (b *databaseBackend) connClear() map[string]*dbPluginInstance { - b.connLock.Lock() - defer b.connLock.Unlock() - old := b.connections - b.connections = make(map[string]*dbPluginInstance) - return old -} - func (b *databaseBackend) DatabaseConfig(ctx context.Context, s logical.Storage, name string) (*DatabaseConfig, error) { entry, err := s.Get(ctx, fmt.Sprintf("config/%s", name)) if err != nil { @@ -330,7 +282,7 @@ func (b *databaseBackend) GetConnection(ctx context.Context, s logical.Storage, } func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) { - dbi := b.connGet(name) + dbi := b.connections.Get(name) if dbi != nil { return dbi, nil } @@ -360,7 +312,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri id: id, name: name, } - oldConn := b.connPut(name, dbi) + oldConn := b.connections.Put(name, dbi) if oldConn != nil { err := oldConn.Close() if err != nil { @@ -373,7 +325,7 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri // ClearConnection closes the database connection and // removes it from the b.connections map. func (b *databaseBackend) ClearConnection(name string) error { - db := b.connPop(name) + db := b.connections.Pop(name) if db != nil { // Ignore error here since the database client is always killed db.Close() @@ -384,7 +336,7 @@ func (b *databaseBackend) ClearConnection(name string) error { // ClearConnectionId closes the database connection with a specific id and // removes it from the b.connections map. func (b *databaseBackend) ClearConnectionId(name, id string) error { - db := b.connPopIfEqual(name, id) + db := b.connections.PopIfEqual(name, id) if db != nil { // Ignore error here since the database client is always killed db.Close() @@ -403,7 +355,7 @@ func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) { db.Close() // Delete the connection if it is still active. - b.connPopIfEqual(db.name, db.id) + b.connections.PopIfEqual(db.name, db.id) }() } } @@ -416,7 +368,7 @@ func (b *databaseBackend) clean(_ context.Context) { b.cancelQueueCtx() } - connections := b.connClear() + connections := b.connections.Clear() for _, db := range connections { go db.Close() } diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index b869facef0..a50499280f 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -462,7 +462,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) // Close and remove the old connection - oldConn := b.connPut(name, &dbPluginInstance{ + oldConn := b.connections.Put(name, &dbPluginInstance{ database: dbw, name: name, id: id, diff --git a/builtin/logical/database/rotation_test.go b/builtin/logical/database/rotation_test.go index e0cb96dd67..5dfa096593 100644 --- a/builtin/logical/database/rotation_test.go +++ b/builtin/logical/database/rotation_test.go @@ -1390,7 +1390,7 @@ func setupMockDB(b *databaseBackend) *mockNewDatabase { id: "foo-id", name: "mockV5", } - b.connections["mockv5"] = dbi + b.connections.Put("mockv5", dbi) return mockDB } diff --git a/helper/syncmap/syncmap.go b/helper/syncmap/syncmap.go new file mode 100644 index 0000000000..ce3d9e8ca0 --- /dev/null +++ b/helper/syncmap/syncmap.go @@ -0,0 +1,86 @@ +package syncmap + +import "sync" + +// SyncMap implements a map similar to sync.Map, but with generics and with an equality +// in the values specified by an "ID()" method. +type SyncMap[K comparable, V IDer] struct { + // lock is used to synchronize access to the map + lock sync.RWMutex + // data holds the actual data + data map[K]V +} + +// NewSyncMap returns a new, empty SyncMap. +func NewSyncMap[K comparable, V IDer]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + data: make(map[K]V), + } +} + +// Get returns the value for the given key. +func (m *SyncMap[K, V]) Get(k K) V { + m.lock.RLock() + defer m.lock.RUnlock() + return m.data[k] +} + +// Pop deletes and returns the value for the given key, if it exists. +func (m *SyncMap[K, V]) Pop(k K) V { + m.lock.Lock() + defer m.lock.Unlock() + v, ok := m.data[k] + if ok { + delete(m.data, k) + } + return v +} + +// PopIfEqual deletes and returns the value for the given key, if it exists +// and only if the ID is equal to the provided string. +func (m *SyncMap[K, V]) PopIfEqual(k K, id string) V { + m.lock.Lock() + defer m.lock.Unlock() + v, ok := m.data[k] + if ok && v.ID() == id { + delete(m.data, k) + return v + } + var zero V + return zero +} + +// Put adds the given key-value pair to the map and returns the previous value, if any. +func (m *SyncMap[K, V]) Put(k K, v V) V { + m.lock.Lock() + defer m.lock.Unlock() + oldV := m.data[k] + m.data[k] = v + return oldV +} + +// Clear deletes all entries from the map, and returns the previous map. +func (m *SyncMap[K, V]) Clear() map[K]V { + m.lock.Lock() + defer m.lock.Unlock() + old := m.data + m.data = make(map[K]V) + return old +} + +// Values returns a copy of all values in the map. +func (m *SyncMap[K, V]) Values() []V { + m.lock.RLock() + defer m.lock.RUnlock() + + values := make([]V, 0, len(m.data)) + for _, v := range m.data { + values = append(values, v) + } + return values +} + +// IDer is used to extract an ID that SyncMap uses for equality checking. +type IDer interface { + ID() string +} diff --git a/helper/syncmap/syncmap_test.go b/helper/syncmap/syncmap_test.go new file mode 100644 index 0000000000..a62de301fa --- /dev/null +++ b/helper/syncmap/syncmap_test.go @@ -0,0 +1,75 @@ +package syncmap + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +type stringID struct { + val string + id string +} + +func (s stringID) ID() string { + return s.id +} + +var _ IDer = stringID{"", ""} + +// TestSyncMap_Get tests that basic getting and putting works. +func TestSyncMap_Get(t *testing.T) { + m := NewSyncMap[string, stringID]() + m.Put("a", stringID{"b", "b"}) + assert.Equal(t, stringID{"b", "b"}, m.Get("a")) + assert.Equal(t, stringID{"", ""}, m.Get("c")) +} + +// TestSyncMap_Pop tests that basic Pop operations work. +func TestSyncMap_Pop(t *testing.T) { + m := NewSyncMap[string, stringID]() + m.Put("a", stringID{"b", "b"}) + assert.Equal(t, stringID{"b", "b"}, m.Pop("a")) + assert.Equal(t, stringID{"", ""}, m.Pop("a")) + assert.Equal(t, stringID{"", ""}, m.Pop("c")) +} + +// TestSyncMap_PopIfEqual tests that basic PopIfEqual operations pop only if the IDs are equal. +func TestSyncMap_PopIfEqual(t *testing.T) { + m := NewSyncMap[string, stringID]() + m.Put("a", stringID{"b", "c"}) + assert.Equal(t, stringID{"", ""}, m.PopIfEqual("a", "b")) + assert.Equal(t, stringID{"b", "c"}, m.PopIfEqual("a", "c")) + assert.Equal(t, stringID{"", ""}, m.PopIfEqual("a", "c")) +} + +// TestSyncMap_Clear checks that clearing works as expected and returns a copy of the original map. +func TestSyncMap_Clear(t *testing.T) { + m := NewSyncMap[string, stringID]() + assert.Equal(t, map[string]stringID{}, m.data) + oldMap := m.Clear() + assert.Equal(t, map[string]stringID{}, m.data) + assert.Equal(t, map[string]stringID{}, oldMap) + + m.Put("a", stringID{"b", "b"}) + m.Put("c", stringID{"d", "d"}) + oldMap = m.Clear() + + assert.Equal(t, map[string]stringID{"a": {"b", "b"}, "c": {"d", "d"}}, oldMap) + assert.Equal(t, map[string]stringID{}, m.data) +} + +// TestSyncMap_Values checks that the Values method returns an array of the values. +func TestSyncMap_Values(t *testing.T) { + m := NewSyncMap[string, stringID]() + assert.Equal(t, []stringID{}, m.Values()) + m.Put("a", stringID{"b", "b"}) + assert.Equal(t, []stringID{{"b", "b"}}, m.Values()) + m.Put("c", stringID{"d", "d"}) + values := m.Values() + sort.Slice(values, func(i, j int) bool { + return values[i].val < values[j].val + }) + assert.Equal(t, []stringID{{"b", "b"}, {"d", "d"}}, m.Values()) +}