diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index b64248bc69..4609d27413 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -17,7 +17,6 @@ import ( "github.com/armon/go-metrics" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-secure-stdlib/strutil" - "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/builtin/logical/database/schedule" "github.com/hashicorp/vault/helper/metricsutil" "github.com/hashicorp/vault/helper/syncmap" @@ -343,73 +342,6 @@ func (b *databaseBackend) GetConnectionSkipVerify(ctx context.Context, s logical return b.GetConnectionWithConfig(ctx, name, config) } -func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) { - // fast path, reuse the existing connection - dbi := b.connections.Get(name) - if dbi != nil { - return dbi, nil - } - - // slow path, create a new connection - // if we don't lock the rest of the operation, there is a race condition for multiple callers of this function - b.createConnectionLock.Lock() - defer b.createConnectionLock.Unlock() - - // check again in case we lost the race - dbi = b.connections.Get(name) - if dbi != nil { - return dbi, nil - } - - id, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - - // Override the configured version if there is a pinned version. - pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName) - if err != nil { - return nil, err - } - pluginVersion := config.PluginVersion - if pinnedVersion != "" { - pluginVersion = pinnedVersion - } - - dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger) - if err != nil { - return nil, fmt.Errorf("unable to create database instance: %w", err) - } - - initReq := v5.InitializeRequest{ - Config: config.ConnectionDetails, - VerifyConnection: config.VerifyConnection, - } - _, err = dbw.Initialize(ctx, initReq) - if err != nil { - dbw.Close() - return nil, err - } - - dbi = &dbPluginInstance{ - database: dbw, - id: id, - name: name, - runningPluginVersion: pluginVersion, - } - conn, ok := b.connections.PutIfEmpty(name, dbi) - if !ok { - // this is a bug - b.Logger().Warn("BUG: there was a race condition adding to the database connection map") - // There was already an existing connection, so we will use that and close our new one to avoid a race condition. - err := dbi.Close() - if err != nil { - b.Logger().Warn("Error closing new database connection", "error", err) - } - } - return conn, nil -} - // ClearConnection closes the database connection and // removes it from the b.connections map. func (b *databaseBackend) ClearConnection(name string) error { diff --git a/builtin/logical/database/backend_ce.go b/builtin/logical/database/backend_ce.go new file mode 100644 index 0000000000..977576429a --- /dev/null +++ b/builtin/logical/database/backend_ce.go @@ -0,0 +1,82 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package database + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-uuid" + v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" +) + +// GetConnectionWithConfig gets or creates a database connection with the given config for community edition +func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) { + // fast path, reuse the existing connection + dbi := b.connections.Get(name) + if dbi != nil { + return dbi, nil + } + + // slow path, create a new connection + // if we don't lock the rest of the operation, there is a race condition for multiple callers of this function + b.createConnectionLock.Lock() + defer b.createConnectionLock.Unlock() + + // check again in case we lost the race + dbi = b.connections.Get(name) + if dbi != nil { + return dbi, nil + } + + id, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + // Override the configured version if there is a pinned version. + pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName) + if err != nil { + return nil, err + } + pluginVersion := config.PluginVersion + if pinnedVersion != "" { + pluginVersion = pinnedVersion + } + + dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger) + if err != nil { + return nil, fmt.Errorf("unable to create database instance: %w", err) + } + + initReq := v5.InitializeRequest{ + Config: config.ConnectionDetails, + VerifyConnection: config.VerifyConnection, + } + _, err = dbw.Initialize(ctx, initReq) + if err != nil { + dbw.Close() + return nil, err + } + + dbi = &dbPluginInstance{ + database: dbw, + id: id, + name: name, + runningPluginVersion: pluginVersion, + } + conn, ok := b.connections.PutIfEmpty(name, dbi) + if !ok { + // this is a bug + b.Logger().Warn("BUG: there was a race condition adding to the database connection map") + // There was already an existing connection, so we will use that and close our new one to avoid a race condition. + err := dbi.Close() + if err != nil { + b.Logger().Warn("Error closing new database connection", "error", err) + } + } + return conn, nil +} diff --git a/builtin/logical/database/backend_ce_test.go b/builtin/logical/database/backend_ce_test.go new file mode 100644 index 0000000000..d115e85539 --- /dev/null +++ b/builtin/logical/database/backend_ce_test.go @@ -0,0 +1,448 @@ +// Copyright IBM Corp. 2016, 2025 +// SPDX-License-Identifier: BUSL-1.1 + +//go:build !enterprise + +package database + +import ( + "context" + "encoding/json" + "reflect" + "strings" + "testing" + "time" + + "github.com/go-test/deep" + "github.com/hashicorp/vault/helper/namespace" + postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql" + "github.com/hashicorp/vault/sdk/database/helper/dbutil" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/logical" + _ "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" +) + +// TestBackend_config_connection tests the configuration of a database connection +func TestBackend_config_connection(t *testing.T) { + var resp *logical.Response + var err error + + cluster, sys := getClusterPostgresDB(t) + defer cluster.Cleanup() + + config := logical.TestBackendConfig() + config.StorageView = &logical.InmemStorage{} + config.System = sys + eventSender := logical.NewMockEventSender() + config.EventsSender = eventSender + lb, err := Factory(context.Background(), config) + if err != nil { + t.Fatal(err) + } + b, ok := lb.(*databaseBackend) + if !ok { + t.Fatal("could not convert to database backend") + } + defer b.Cleanup(context.Background()) + + // Test creation + { + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "someotherdata": "testing", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + "allowed_roles": []string{"*"}, + "name": "plugin-test", + } + + configReq := &logical.Request{ + Operation: logical.CreateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + + exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{ + Raw: configData, + Schema: pathConfigurePluginConnection(b).Fields, + }) + if err != nil { + t.Fatal(err) + } + if exists { + t.Fatal("expected not exists") + } + + resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_connection_url", + "someotherdata": "testing", + }, + "allowed_roles": []string{"*"}, + "root_credentials_rotate_statements": []string{}, + "password_policy": "", + "plugin_version": "", + "verify_connection": false, + "skip_static_role_import_rotation": false, + "rotation_schedule": "", + "rotation_period": time.Duration(0).Seconds(), + "rotation_window": time.Duration(0).Seconds(), + "disable_automated_rotation": false, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } + } + + // Test existence check and an update to a single connection detail parameter + { + configData := map[string]interface{}{ + "connection_url": "sample_convection_url", + "verify_connection": false, + "name": "plugin-test", + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + + exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{ + Raw: configData, + Schema: pathConfigurePluginConnection(b).Fields, + }) + if err != nil { + t.Fatal(err) + } + if !exists { + t.Fatal("expected exists") + } + + resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_convection_url", + "someotherdata": "testing", + }, + "allowed_roles": []string{"*"}, + "root_credentials_rotate_statements": []string{}, + "password_policy": "", + "plugin_version": "", + "verify_connection": false, + "skip_static_role_import_rotation": false, + "rotation_schedule": "", + "rotation_period": time.Duration(0).Seconds(), + "rotation_window": time.Duration(0).Seconds(), + "disable_automated_rotation": false, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + delete(resp.Data, "AutomatedRotationParams") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } + } + + // Test an update to a non-details value + { + configData := map[string]interface{}{ + "verify_connection": false, + "allowed_roles": []string{"flu", "barre"}, + "name": "plugin-test", + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + + resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_convection_url", + "someotherdata": "testing", + }, + "allowed_roles": []string{"flu", "barre"}, + "root_credentials_rotate_statements": []string{}, + "password_policy": "", + "plugin_version": "", + "verify_connection": false, + "skip_static_role_import_rotation": false, + "rotation_schedule": "", + "rotation_period": time.Duration(0).Seconds(), + "rotation_window": time.Duration(0).Seconds(), + "disable_automated_rotation": false, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + delete(resp.Data, "AutomatedRotationParams") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } + } + + req := &logical.Request{ + Operation: logical.ListOperation, + Storage: config.StorageView, + Path: "config/", + } + resp, err = b.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatal(err) + } + keys := resp.Data["keys"].([]string) + key := keys[0] + if key != "plugin-test" { + t.Fatalf("bad key: %q", key) + } + assert.Equal(t, 3, len(eventSender.Events)) + assert.Equal(t, "database/config-write", string(eventSender.Events[0].Type)) + assert.Equal(t, "config/plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["path"]) + assert.Equal(t, "plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["name"]) + assert.Equal(t, "database/config-write", string(eventSender.Events[1].Type)) + assert.Equal(t, "config/plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["path"]) + assert.Equal(t, "plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["name"]) + assert.Equal(t, "database/config-write", string(eventSender.Events[2].Type)) + assert.Equal(t, "config/plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["path"]) + assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"]) +} + +// TestBackend_connectionCrud tests the full CRUD lifecycle of a database connection +func TestBackend_connectionCrud(t *testing.T) { + t.Parallel() + dbFactory := &singletonDBFactory{} + cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory) + defer cluster.Cleanup() + + dbFactory.sys = sys + client := cluster.Cores[0].Client.Logical() + + cleanup, connURL := postgreshelper.PrepareTestContainer(t) + defer cleanup() + + // Mount the database plugin. + resp, err := client.Write("sys/mounts/database", map[string]interface{}{ + "type": "database", + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Configure a connection + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ + "connection_url": "test", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Configure a second connection to confirm below it doesn't get restarted. + resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{ + "connection_url": "test", + "plugin_name": "hana-database-plugin", + "verify_connection": false, + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Create a role + resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Update the connection + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ + "connection_url": connURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, + "username": "postgres", + "password": "secret", + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + if len(resp.Warnings) == 0 { + t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp) + } + + resp, err = client.Read("database/config/plugin-test") + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{}) + if strings.Contains(returnedConnectionDetails["connection_url"].(string), "secret") { + t.Fatal("password should not be found in the connection url") + } + // Covered by the filled out `expected` value below, but be explicit about this requirement. + if _, exists := returnedConnectionDetails["password"]; exists { + t.Fatal("password should NOT be found in the returned config") + } + + // Replace connection url with templated version + templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}") + resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ + "connection_url": templatedConnURL, + "plugin_name": "postgresql-database-plugin", + "allowed_roles": []string{"plugin-role-test"}, + "username": "postgres", + "password": "secret", + }) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read connection + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "username": "postgres", + "connection_url": templatedConnURL, + }, + "allowed_roles": []any{"plugin-role-test"}, + "root_credentials_rotate_statements": []any{}, + "password_policy": "", + "plugin_version": "", + "verify_connection": false, + "skip_static_role_import_rotation": false, + "rotation_schedule": "", + "rotation_period": json.Number("0"), + "rotation_window": json.Number("0"), + "disable_automated_rotation": false, + } + resp, err = client.Read("database/config/plugin-test") + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + delete(resp.Data, "AutomatedRotationParams") + if diff := deep.Equal(resp.Data, expected); diff != nil { + t.Fatal(strings.Join(diff, "\n")) + } + + // Test endpoints for reloading plugins. + for _, reload := range []struct { + path string + data map[string]any + checkCount bool + }{ + {"database/reset/plugin-test", nil, false}, + {"database/reload/postgresql-database-plugin", nil, true}, + {"sys/plugins/reload/backend", map[string]any{ + "plugin": "postgresql-database-plugin", + }, false}, + } { + getConnectionID := func(name string) string { + t.Helper() + dbi := dbFactory.db.connections.Get(name) + if dbi == nil { + t.Fatal("no plugin-test dbi") + } + return dbi.ID() + } + initialID := getConnectionID("plugin-test") + hanaID := getConnectionID("plugin-test-hana") + resp, err = client.Write(reload.path, reload.data) + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + if initialID == getConnectionID("plugin-test") { + t.Fatal("ID unchanged after connection reset") + } + if hanaID != getConnectionID("plugin-test-hana") { + t.Fatal("hana plugin got restarted but shouldn't have been") + } + if reload.checkCount { + actual, err := resp.Data["count"].(json.Number).Int64() + if err != nil { + t.Fatal(err) + } + if expected := 1; expected != int(actual) { + t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int)) + } + if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) { + t.Fatalf("expected %v but got %v", expected, resp.Data["connections"]) + } + } + } + + // Get creds + credsResp, err := client.Read("database/creds/plugin-role-test") + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + + credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{ + "username": "postgres", + "password": "secret", + }) + if !testCredsExist(t, credsResp.Data, credCheckURL) { + t.Fatalf("Creds should exist") + } + + // Delete Connection + resp, err = client.Delete("database/config/plugin-test") + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Read connection + resp, err = client.Read("database/config/plugin-test") + if err != nil { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Should be empty + if resp != nil { + t.Fatal("Expected response to be nil") + } +} diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index 1cb10575fa..4c088b1579 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -6,7 +6,6 @@ package database import ( "context" "database/sql" - "encoding/json" "errors" "fmt" "log" @@ -32,8 +31,6 @@ import ( "github.com/hashicorp/vault/plugins/database/postgresql" v4 "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" - "github.com/hashicorp/vault/sdk/database/helper/dbutil" - "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" @@ -157,228 +154,6 @@ func TestBackend_RoleUpgrade(t *testing.T) { } } -func TestBackend_config_connection(t *testing.T) { - var resp *logical.Response - var err error - - cluster, sys := getClusterPostgresDB(t) - defer cluster.Cleanup() - - config := logical.TestBackendConfig() - config.StorageView = &logical.InmemStorage{} - config.System = sys - eventSender := logical.NewMockEventSender() - config.EventsSender = eventSender - lb, err := Factory(context.Background(), config) - if err != nil { - t.Fatal(err) - } - b, ok := lb.(*databaseBackend) - if !ok { - t.Fatal("could not convert to database backend") - } - defer b.Cleanup(context.Background()) - - // Test creation - { - configData := map[string]interface{}{ - "connection_url": "sample_connection_url", - "someotherdata": "testing", - "plugin_name": "postgresql-database-plugin", - "verify_connection": false, - "allowed_roles": []string{"*"}, - "name": "plugin-test", - } - - configReq := &logical.Request{ - Operation: logical.CreateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: configData, - } - - exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{ - Raw: configData, - Schema: pathConfigurePluginConnection(b).Fields, - }) - if err != nil { - t.Fatal(err) - } - if exists { - t.Fatal("expected not exists") - } - - resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v\n", err, resp) - } - - expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": map[string]interface{}{ - "connection_url": "sample_connection_url", - "someotherdata": "testing", - }, - "allowed_roles": []string{"*"}, - "root_credentials_rotate_statements": []string{}, - "password_policy": "", - "plugin_version": "", - "verify_connection": false, - "skip_static_role_import_rotation": false, - "rotation_schedule": "", - "rotation_period": time.Duration(0).Seconds(), - "rotation_window": time.Duration(0).Seconds(), - "disable_automated_rotation": false, - } - configReq.Operation = logical.ReadOperation - resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - delete(resp.Data["connection_details"].(map[string]interface{}), "name") - if !reflect.DeepEqual(expected, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) - } - } - - // Test existence check and an update to a single connection detail parameter - { - configData := map[string]interface{}{ - "connection_url": "sample_convection_url", - "verify_connection": false, - "name": "plugin-test", - } - - configReq := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: configData, - } - - exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{ - Raw: configData, - Schema: pathConfigurePluginConnection(b).Fields, - }) - if err != nil { - t.Fatal(err) - } - if !exists { - t.Fatal("expected exists") - } - - resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v\n", err, resp) - } - - expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": map[string]interface{}{ - "connection_url": "sample_convection_url", - "someotherdata": "testing", - }, - "allowed_roles": []string{"*"}, - "root_credentials_rotate_statements": []string{}, - "password_policy": "", - "plugin_version": "", - "verify_connection": false, - "skip_static_role_import_rotation": false, - "rotation_schedule": "", - "rotation_period": time.Duration(0).Seconds(), - "rotation_window": time.Duration(0).Seconds(), - "disable_automated_rotation": false, - } - configReq.Operation = logical.ReadOperation - resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - delete(resp.Data["connection_details"].(map[string]interface{}), "name") - delete(resp.Data, "AutomatedRotationParams") - if !reflect.DeepEqual(expected, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) - } - } - - // Test an update to a non-details value - { - configData := map[string]interface{}{ - "verify_connection": false, - "allowed_roles": []string{"flu", "barre"}, - "name": "plugin-test", - } - - configReq := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: configData, - } - - resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%v resp:%#v\n", err, resp) - } - - expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": map[string]interface{}{ - "connection_url": "sample_convection_url", - "someotherdata": "testing", - }, - "allowed_roles": []string{"flu", "barre"}, - "root_credentials_rotate_statements": []string{}, - "password_policy": "", - "plugin_version": "", - "verify_connection": false, - "skip_static_role_import_rotation": false, - "rotation_schedule": "", - "rotation_period": time.Duration(0).Seconds(), - "rotation_window": time.Duration(0).Seconds(), - "disable_automated_rotation": false, - } - configReq.Operation = logical.ReadOperation - resp, err = b.HandleRequest(namespace.RootContext(nil), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - delete(resp.Data["connection_details"].(map[string]interface{}), "name") - delete(resp.Data, "AutomatedRotationParams") - if !reflect.DeepEqual(expected, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) - } - } - - req := &logical.Request{ - Operation: logical.ListOperation, - Storage: config.StorageView, - Path: "config/", - } - resp, err = b.HandleRequest(namespace.RootContext(nil), req) - if err != nil { - t.Fatal(err) - } - keys := resp.Data["keys"].([]string) - key := keys[0] - if key != "plugin-test" { - t.Fatalf("bad key: %q", key) - } - assert.Equal(t, 3, len(eventSender.Events)) - assert.Equal(t, "database/config-write", string(eventSender.Events[0].Type)) - assert.Equal(t, "config/plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["path"]) - assert.Equal(t, "plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["name"]) - assert.Equal(t, "database/config-write", string(eventSender.Events[1].Type)) - assert.Equal(t, "config/plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["path"]) - assert.Equal(t, "plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["name"]) - assert.Equal(t, "database/config-write", string(eventSender.Events[2].Type)) - assert.Equal(t, "config/plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["path"]) - assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"]) -} - // TestBackend_BadConnectionString tests that an error response resulting from // a failed connection does not expose the URL. The middleware should sanitize it. func TestBackend_BadConnectionString(t *testing.T) { @@ -693,206 +468,6 @@ func (s *singletonDBFactory) factory(context.Context, *logical.BackendConfig) (l return s.db, nil } -func TestBackend_connectionCrud(t *testing.T) { - t.Parallel() - dbFactory := &singletonDBFactory{} - cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory) - defer cluster.Cleanup() - - dbFactory.sys = sys - client := cluster.Cores[0].Client.Logical() - - cleanup, connURL := postgreshelper.PrepareTestContainer(t) - defer cleanup() - - // Mount the database plugin. - resp, err := client.Write("sys/mounts/database", map[string]interface{}{ - "type": "database", - }) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Configure a connection - resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ - "connection_url": "test", - "plugin_name": "postgresql-database-plugin", - "verify_connection": false, - }) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Configure a second connection to confirm below it doesn't get restarted. - resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{ - "connection_url": "test", - "plugin_name": "hana-database-plugin", - "verify_connection": false, - }) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Create a role - resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{ - "db_name": "plugin-test", - "creation_statements": testRole, - "revocation_statements": defaultRevocationSQL, - "default_ttl": "5m", - "max_ttl": "10m", - }) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Update the connection - resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ - "connection_url": connURL, - "plugin_name": "postgresql-database-plugin", - "allowed_roles": []string{"plugin-role-test"}, - "username": "postgres", - "password": "secret", - }) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - if len(resp.Warnings) == 0 { - t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp) - } - - resp, err = client.Read("database/config/plugin-test") - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{}) - if strings.Contains(returnedConnectionDetails["connection_url"].(string), "secret") { - t.Fatal("password should not be found in the connection url") - } - // Covered by the filled out `expected` value below, but be explicit about this requirement. - if _, exists := returnedConnectionDetails["password"]; exists { - t.Fatal("password should NOT be found in the returned config") - } - - // Replace connection url with templated version - templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}") - resp, err = client.Write("database/config/plugin-test", map[string]interface{}{ - "connection_url": templatedConnURL, - "plugin_name": "postgresql-database-plugin", - "allowed_roles": []string{"plugin-role-test"}, - "username": "postgres", - "password": "secret", - }) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Read connection - expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": map[string]interface{}{ - "username": "postgres", - "connection_url": templatedConnURL, - }, - "allowed_roles": []any{"plugin-role-test"}, - "root_credentials_rotate_statements": []any{}, - "password_policy": "", - "plugin_version": "", - "verify_connection": false, - "skip_static_role_import_rotation": false, - "rotation_schedule": "", - "rotation_period": json.Number("0"), - "rotation_window": json.Number("0"), - "disable_automated_rotation": false, - } - resp, err = client.Read("database/config/plugin-test") - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - delete(resp.Data["connection_details"].(map[string]interface{}), "name") - delete(resp.Data, "AutomatedRotationParams") - if diff := deep.Equal(resp.Data, expected); diff != nil { - t.Fatal(strings.Join(diff, "\n")) - } - - // Test endpoints for reloading plugins. - for _, reload := range []struct { - path string - data map[string]any - checkCount bool - }{ - {"database/reset/plugin-test", nil, false}, - {"database/reload/postgresql-database-plugin", nil, true}, - {"sys/plugins/reload/backend", map[string]any{ - "plugin": "postgresql-database-plugin", - }, false}, - } { - getConnectionID := func(name string) string { - t.Helper() - dbi := dbFactory.db.connections.Get(name) - if dbi == nil { - t.Fatal("no plugin-test dbi") - } - return dbi.ID() - } - initialID := getConnectionID("plugin-test") - hanaID := getConnectionID("plugin-test-hana") - resp, err = client.Write(reload.path, reload.data) - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - if initialID == getConnectionID("plugin-test") { - t.Fatal("ID unchanged after connection reset") - } - if hanaID != getConnectionID("plugin-test-hana") { - t.Fatal("hana plugin got restarted but shouldn't have been") - } - if reload.checkCount { - actual, err := resp.Data["count"].(json.Number).Int64() - if err != nil { - t.Fatal(err) - } - if expected := 1; expected != int(actual) { - t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int)) - } - if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) { - t.Fatalf("expected %v but got %v", expected, resp.Data["connections"]) - } - } - } - - // Get creds - credsResp, err := client.Read("database/creds/plugin-role-test") - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, credsResp) - } - - credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{ - "username": "postgres", - "password": "secret", - }) - if !testCredsExist(t, credsResp.Data, credCheckURL) { - t.Fatalf("Creds should exist") - } - - // Delete Connection - resp, err = client.Delete("database/config/plugin-test") - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Read connection - resp, err = client.Read("database/config/plugin-test") - if err != nil { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } - - // Should be empty - if resp != nil { - t.Fatal("Expected response to be nil") - } -} - func TestBackend_connectionSanitizePrivateKey(t *testing.T) { t.Parallel() dbFactory := &singletonDBFactory{} diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 8701356415..9f170023a7 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -7,20 +7,13 @@ import ( "context" "errors" "fmt" - "net/url" - "sort" - "github.com/fatih/structs" - "github.com/hashicorp/go-uuid" - "github.com/hashicorp/go-version" - "github.com/hashicorp/vault/helper/versions" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/automatedrotationutil" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/hashicorp/vault/sdk/rotation" ) var ( @@ -31,6 +24,8 @@ var ( // DatabaseConfig is used by the Factory function to configure a Database // object. type DatabaseConfig struct { + EntDatabaseConfig `mapstructure:",squash"` + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` PluginVersion string `json:"plugin_version" structs:"plugin_version" mapstructure:"plugin_version"` RunningPluginVersion string `json:"running_plugin_version,omitempty" structs:"running_plugin_version,omitempty" mapstructure:"running_plugin_version,omitempty"` @@ -376,70 +371,6 @@ func (b *databaseBackend) connectionListHandler() framework.OperationFunc { } } -// connectionReadHandler reads out the connection configuration -func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { - return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) - if name == "" { - return logical.ErrorResponse(respErrEmptyName), nil - } - - entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration: %w", err) - } - if entry == nil { - return nil, nil - } - - var config DatabaseConfig - if err := entry.DecodeJSON(&config); err != nil { - return nil, err - } - - // Ensure that we only ever include a redacted valid URL in the response. - if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok { - if p, err := url.Parse(connURLRaw.(string)); err == nil { - config.ConnectionDetails["connection_url"] = p.Redacted() - } - } - - if versions.IsBuiltinVersion(config.PluginVersion) { - // This gets treated as though it's empty when mounting, and will get - // overwritten to be empty when the config is next written. See #18051. - config.PluginVersion = "" - } - - delete(config.ConnectionDetails, "password") - delete(config.ConnectionDetails, "private_key") - delete(config.ConnectionDetails, "service_account_json") - - resp := &logical.Response{} - if dbi, err := b.GetConnectionSkipVerify(ctx, req.Storage, name); err == nil { - config.RunningPluginVersion = dbi.runningPluginVersion - if config.PluginVersion != "" && config.PluginVersion != config.RunningPluginVersion { - warning := fmt.Sprintf("Plugin version is configured as %q, but running %q", config.PluginVersion, config.RunningPluginVersion) - if pinnedVersion, _ := b.getPinnedVersion(ctx, config.PluginName); pinnedVersion == config.RunningPluginVersion { - warning += " because that version is pinned" - } else { - warning += " either due to a pinned version or because the plugin was upgraded and not yet reloaded" - } - resp.AddWarning(warning) - } - } - - resp.Data = structs.New(config).Map() - config.PopulateAutomatedRotationData(resp.Data) - // remove extra nested AutomatedRotationParams key - // before returning response - delete(resp.Data, "AutomatedRotationParams") - - recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigRead) - - return resp, nil - } -} - // connectionDeleteHandler deletes the connection configuration func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { @@ -463,230 +394,6 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { } } -// connectionWriteHandler returns a handler function for creating and updating -// both builtin and plugin database types. -func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { - return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - name := data.Get("name").(string) - if name == "" { - return logical.ErrorResponse(respErrEmptyName), nil - } - - // Baseline - config := &DatabaseConfig{ - VerifyConnection: true, - } - - entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name)) - if err != nil { - return nil, fmt.Errorf("failed to read connection configuration: %w", err) - } - if entry != nil { - if err := entry.DecodeJSON(config); err != nil { - return nil, err - } - } - - // If this value was provided as part of the request we want to set it to this value - if verifyConnectionRaw, ok := data.GetOk("verify_connection"); ok { - config.VerifyConnection = verifyConnectionRaw.(bool) - } else if req.Operation == logical.CreateOperation { - config.VerifyConnection = data.Get("verify_connection").(bool) - } - - if pluginNameRaw, ok := data.GetOk("plugin_name"); ok { - config.PluginName = pluginNameRaw.(string) - } else if req.Operation == logical.CreateOperation { - config.PluginName = data.Get("plugin_name").(string) - } - if config.PluginName == "" { - return logical.ErrorResponse(respErrEmptyPluginName), nil - } - - pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation) - if respErr != nil || err != nil { - return respErr, err - } - - if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok { - config.AllowedRoles = allowedRolesRaw.([]string) - } else if req.Operation == logical.CreateOperation { - config.AllowedRoles = data.Get("allowed_roles").([]string) - } - - if rootRotationStatementsRaw, ok := data.GetOk("root_rotation_statements"); ok { - config.RootCredentialsRotateStatements = rootRotationStatementsRaw.([]string) - } else if req.Operation == logical.CreateOperation { - config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string) - } - - if passwordPolicyRaw, ok := data.GetOk("password_policy"); ok { - config.PasswordPolicy = passwordPolicyRaw.(string) - } - - if skipImportRotationRaw, ok := data.GetOk("skip_static_role_import_rotation"); ok { - config.SkipStaticRoleImportRotation = skipImportRotationRaw.(bool) - } - - if err := config.ParseAutomatedRotationFields(data); err != nil { - return logical.ErrorResponse(err.Error()), nil - } - - // Remove these entries from the data before we store it keyed under - // ConnectionDetails. - delete(data.Raw, "name") - delete(data.Raw, "plugin_name") - delete(data.Raw, "plugin_version") - delete(data.Raw, "allowed_roles") - delete(data.Raw, "verify_connection") - delete(data.Raw, "root_rotation_statements") - delete(data.Raw, "password_policy") - delete(data.Raw, "skip_static_role_import_rotation") - delete(data.Raw, "rotation_schedule") - delete(data.Raw, "rotation_window") - delete(data.Raw, "rotation_period") - delete(data.Raw, "disable_automated_rotation") - - id, err := uuid.GenerateUUID() - if err != nil { - return nil, err - } - - // If this is an update, take any new values, overwrite what was there - // before, and pass that in as the "new" set of values to the plugin, - // then save what results - if req.Operation == logical.CreateOperation { - config.ConnectionDetails = data.Raw - } else { - if config.ConnectionDetails == nil { - config.ConnectionDetails = make(map[string]interface{}) - } - for k, v := range data.Raw { - config.ConnectionDetails[k] = v - } - } - - // Create a database plugin and initialize it. - dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger) - if err != nil { - return logical.ErrorResponse("error creating database object: %s", err), nil - } - - initReq := v5.InitializeRequest{ - Config: config.ConnectionDetails, - VerifyConnection: config.VerifyConnection, - } - initResp, err := dbw.Initialize(ctx, initReq) - if err != nil { - dbw.Close() - return logical.ErrorResponse("error creating database object: %s", err), nil - } - config.ConnectionDetails = initResp.Config - - b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) - - // Close and remove the old connection - oldConn := b.connections.Put(name, &dbPluginInstance{ - database: dbw, - name: name, - id: id, - runningPluginVersion: pluginVersion, - }) - if oldConn != nil { - oldConn.Close() - } - - var performedRotationManagerOpern string - if config.ShouldDeregisterRotationJob() { - performedRotationManagerOpern = rotation.PerformedDeregistration - // Disable Automated Rotation and Deregister credentials if required - deregisterReq := &rotation.RotationJobDeregisterRequest{ - MountPoint: req.MountPoint, - ReqPath: req.Path, - } - - b.Logger().Debug("Deregistering rotation job", "mount", req.MountPoint+req.Path) - if err := b.System().DeregisterRotationJob(ctx, deregisterReq); err != nil { - return logical.ErrorResponse("error deregistering rotation job: %s", err), nil - } - } else if config.ShouldRegisterRotationJob() { - performedRotationManagerOpern = rotation.PerformedRegistration - // Register the rotation job if it's required. - cfgReq := &rotation.RotationJobConfigureRequest{ - MountPoint: req.MountPoint, - ReqPath: req.Path, - RotationSchedule: config.RotationSchedule, - RotationWindow: config.RotationWindow, - RotationPeriod: config.RotationPeriod, - } - - b.Logger().Debug("Registering rotation job", "mount", req.MountPoint+req.Path) - if _, err = b.System().RegisterRotationJob(ctx, cfgReq); err != nil { - return logical.ErrorResponse("error registering rotation job: %s", err), nil - } - } - - // 1.12.0 and 1.12.1 stored builtin plugins in storage, but 1.12.2 reverted - // that, so clean up any pre-existing stored builtin versions on write. - if versions.IsBuiltinVersion(config.PluginVersion) { - config.PluginVersion = "" - } - err = storeConfig(ctx, req.Storage, name, config) - if err != nil { - wrappedError := err - if performedRotationManagerOpern != "" { - b.Logger().Error("write to storage failed but the rotation manager still succeeded.", - "operation", performedRotationManagerOpern, "mount", req.MountPoint, "path", req.Path) - wrappedError = fmt.Errorf("write to storage failed but the rotation manager still succeeded; "+ - "operation=%s, mount=%s, path=%s, storageError=%s", performedRotationManagerOpern, req.MountPoint, req.Path, err) - } - return nil, wrappedError - } - - resp := &logical.Response{} - - // This is a simple test to check for passwords in the connection_url parameter. If one exists, - // warn the user to use templated url string - if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok { - if connURL, err := url.Parse(connURLRaw.(string)); err == nil { - if _, ok := connURL.User.Password(); ok { - resp.AddWarning("Password found in connection_url, use a templated url to enable root rotation and prevent read access to password information.") - } - } - } - - // If using a legacy DB plugin and set the `password_policy` field, send a warning to the user indicating - // the `password_policy` will not be used - if dbw.isV4() && config.PasswordPolicy != "" { - resp.AddWarning(fmt.Sprintf("%s does not support password policies - upgrade to the latest version of "+ - "Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName)) - } - - // We can ignore the error at this point since we're simply adding a warning. - dbType, _ := dbw.Type() - if dbType == "snowflake" && config.ConnectionDetails["password"] != nil { - resp.AddWarning(`[DEPRECATED] Single-factor password authentication is deprecated in Snowflake and will -be removed by November 2025. Key pair authentication will be required after this date. Please -see the Vault documentation for details on the removal of this feature. More information is -available at https://www.snowflake.com/en/blog/blocking-single-factor-password-authentification`) - } - - var rotationPeriodString string - if config.RotationPeriod != 0 { - rotationPeriodString = config.RotationPeriod.String() - } - b.dbEvent(ctx, "config-write", req.Path, name, true) - recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigWrite, - AdditionalDatabaseMetadata{key: "root_rotation_period", value: rotationPeriodString}, - AdditionalDatabaseMetadata{key: "root_rotation_schedule", value: config.RotationSchedule}) - - if len(resp.Warnings) == 0 { - return nil, nil - } - return resp, nil - } -} - func storeConfig(ctx context.Context, storage logical.Storage, name string, config *DatabaseConfig) error { entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) if err != nil { @@ -717,75 +424,6 @@ func (b *databaseBackend) getPinnedVersion(ctx context.Context, pluginName strin return pin.Version, nil } -func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) { - pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName) - if err != nil { - return "", nil, err - } - pluginVersionRaw, ok := data.GetOk("plugin_version") - - switch { - case ok && pinnedVersion != "": - return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (v%s)", config.PluginName, pinnedVersion), nil - case pinnedVersion != "": - return pinnedVersion, nil, nil - case ok: - config.PluginVersion = pluginVersionRaw.(string) - } - - var builtinShadowed bool - if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin { - builtinShadowed = true - } - switch { - case config.PluginVersion != "": - semanticVersion, err := version.NewVersion(config.PluginVersion) - if err != nil { - return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil - } - - // Canonicalize the version. - config.PluginVersion = "v" + semanticVersion.String() - - if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) { - if builtinShadowed { - return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+ - " overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil - } - - config.PluginVersion = "" - } - case builtinShadowed: - // We'll select the unversioned plugin that's been registered. - case op == logical.CreateOperation: - // No version provided and no unversioned plugin of that name available. - // Pin to the current latest version if any versioned plugins are registered. - plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase) - if err != nil { - return "", nil, err - } - - var versionedCandidates []pluginutil.VersionedPlugin - for _, plugin := range plugins { - if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" { - versionedCandidates = append(versionedCandidates, plugin) - } - } - - if len(versionedCandidates) != 0 { - // Sort in reverse order. - sort.SliceStable(versionedCandidates, func(i, j int) bool { - return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion) - }) - - config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String() - b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates)) - } - } - - return config.PluginVersion, nil, nil -} - const pathConfigConnectionHelpSyn = ` Configure connection details to a database plugin. ` diff --git a/builtin/logical/database/path_config_connection_ce.go b/builtin/logical/database/path_config_connection_ce.go index 31a56fd557..c56115bb1a 100644 --- a/builtin/logical/database/path_config_connection_ce.go +++ b/builtin/logical/database/path_config_connection_ce.go @@ -5,9 +5,390 @@ package database -import "github.com/hashicorp/vault/sdk/framework" +import ( + "context" + "fmt" + "net/url" + "sort" + + "github.com/fatih/structs" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/go-version" + "github.com/hashicorp/vault/helper/versions" + v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" + "github.com/hashicorp/vault/sdk/framework" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/hashicorp/vault/sdk/logical" + "github.com/hashicorp/vault/sdk/rotation" +) + +// EntDatabaseConfig is an empty struct for community edition +type EntDatabaseConfig struct{} // AddConnectionFieldsEnt is a no-op for community edition func AddConnectionFieldsEnt(fields map[string]*framework.FieldSchema) { // no-op } + +// connectionWriteHandler returns a handler function for creating and updating +// both builtin and plugin database types for community edition +func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { + return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse(respErrEmptyName), nil + } + + // Baseline + config := &DatabaseConfig{ + VerifyConnection: true, + } + + entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration: %w", err) + } + if entry != nil { + if err := entry.DecodeJSON(config); err != nil { + return nil, err + } + } + + // If this value was provided as part of the request we want to set it to this value + if verifyConnectionRaw, ok := data.GetOk("verify_connection"); ok { + config.VerifyConnection = verifyConnectionRaw.(bool) + } else if req.Operation == logical.CreateOperation { + config.VerifyConnection = data.Get("verify_connection").(bool) + } + + if pluginNameRaw, ok := data.GetOk("plugin_name"); ok { + config.PluginName = pluginNameRaw.(string) + } else if req.Operation == logical.CreateOperation { + config.PluginName = data.Get("plugin_name").(string) + } + if config.PluginName == "" { + return logical.ErrorResponse(respErrEmptyPluginName), nil + } + + pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation) + if respErr != nil || err != nil { + return respErr, err + } + + if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok { + config.AllowedRoles = allowedRolesRaw.([]string) + } else if req.Operation == logical.CreateOperation { + config.AllowedRoles = data.Get("allowed_roles").([]string) + } + + if rootRotationStatementsRaw, ok := data.GetOk("root_rotation_statements"); ok { + config.RootCredentialsRotateStatements = rootRotationStatementsRaw.([]string) + } else if req.Operation == logical.CreateOperation { + config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string) + } + + if passwordPolicyRaw, ok := data.GetOk("password_policy"); ok { + config.PasswordPolicy = passwordPolicyRaw.(string) + } + + if skipImportRotationRaw, ok := data.GetOk("skip_static_role_import_rotation"); ok { + config.SkipStaticRoleImportRotation = skipImportRotationRaw.(bool) + } + + if err := config.ParseAutomatedRotationFields(data); err != nil { + return logical.ErrorResponse(err.Error()), nil + } + + // Remove these entries from the data before we store it keyed under + // ConnectionDetails. + delete(data.Raw, "name") + delete(data.Raw, "plugin_name") + delete(data.Raw, "plugin_version") + delete(data.Raw, "allowed_roles") + delete(data.Raw, "verify_connection") + delete(data.Raw, "root_rotation_statements") + delete(data.Raw, "password_policy") + delete(data.Raw, "skip_static_role_import_rotation") + delete(data.Raw, "rotation_schedule") + delete(data.Raw, "rotation_window") + delete(data.Raw, "rotation_period") + delete(data.Raw, "disable_automated_rotation") + delete(data.Raw, "EntDatabaseConfig") + + id, err := uuid.GenerateUUID() + if err != nil { + return nil, err + } + + // If this is an update, take any new values, overwrite what was there + // before, and pass that in as the "new" set of values to the plugin, + // then save what results + if req.Operation == logical.CreateOperation { + config.ConnectionDetails = data.Raw + } else { + if config.ConnectionDetails == nil { + config.ConnectionDetails = make(map[string]interface{}) + } + for k, v := range data.Raw { + config.ConnectionDetails[k] = v + } + } + + // Create a database plugin and initialize it. + dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger) + if err != nil { + return logical.ErrorResponse("error creating database object: %s", err), nil + } + + initReq := v5.InitializeRequest{ + Config: config.ConnectionDetails, + VerifyConnection: config.VerifyConnection, + } + initResp, err := dbw.Initialize(ctx, initReq) + if err != nil { + dbw.Close() + return logical.ErrorResponse("error creating database object: %s", err), nil + } + config.ConnectionDetails = initResp.Config + + b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName) + + // Close and remove the old connection + oldConn := b.connections.Put(name, &dbPluginInstance{ + database: dbw, + name: name, + id: id, + runningPluginVersion: pluginVersion, + }) + if oldConn != nil { + oldConn.Close() + } + + var performedRotationManagerOpern string + if config.ShouldDeregisterRotationJob() { + performedRotationManagerOpern = rotation.PerformedDeregistration + // Disable Automated Rotation and Deregister credentials if required + deregisterReq := &rotation.RotationJobDeregisterRequest{ + MountPoint: req.MountPoint, + ReqPath: req.Path, + } + + b.Logger().Debug("Deregistering rotation job", "mount", req.MountPoint+req.Path) + if err := b.System().DeregisterRotationJob(ctx, deregisterReq); err != nil { + return logical.ErrorResponse("error deregistering rotation job: %s", err), nil + } + } else if config.ShouldRegisterRotationJob() { + performedRotationManagerOpern = rotation.PerformedRegistration + // Register the rotation job if it's required. + cfgReq := &rotation.RotationJobConfigureRequest{ + MountPoint: req.MountPoint, + ReqPath: req.Path, + RotationSchedule: config.RotationSchedule, + RotationWindow: config.RotationWindow, + RotationPeriod: config.RotationPeriod, + } + + b.Logger().Debug("Registering rotation job", "mount", req.MountPoint+req.Path) + if _, err = b.System().RegisterRotationJob(ctx, cfgReq); err != nil { + return logical.ErrorResponse("error registering rotation job: %s", err), nil + } + } + + // 1.12.0 and 1.12.1 stored builtin plugins in storage, but 1.12.2 reverted + // that, so clean up any pre-existing stored builtin versions on write. + if versions.IsBuiltinVersion(config.PluginVersion) { + config.PluginVersion = "" + } + err = storeConfig(ctx, req.Storage, name, config) + if err != nil { + wrappedError := err + if performedRotationManagerOpern != "" { + b.Logger().Error("write to storage failed but the rotation manager still succeeded.", + "operation", performedRotationManagerOpern, "mount", req.MountPoint, "path", req.Path) + wrappedError = fmt.Errorf("write to storage failed but the rotation manager still succeeded; "+ + "operation=%s, mount=%s, path=%s, storageError=%s", performedRotationManagerOpern, req.MountPoint, req.Path, err) + } + return nil, wrappedError + } + + resp := &logical.Response{} + + // This is a simple test to check for passwords in the connection_url parameter. If one exists, + // warn the user to use templated url string + if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok { + if connURL, err := url.Parse(connURLRaw.(string)); err == nil { + if _, ok := connURL.User.Password(); ok { + resp.AddWarning("Password found in connection_url, use a templated url to enable root rotation and prevent read access to password information.") + } + } + } + + // If using a legacy DB plugin and set the `password_policy` field, send a warning to the user indicating + // the `password_policy` will not be used + if dbw.isV4() && config.PasswordPolicy != "" { + resp.AddWarning(fmt.Sprintf("%s does not support password policies - upgrade to the latest version of "+ + "Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName)) + } + + // We can ignore the error at this point since we're simply adding a warning. + dbType, _ := dbw.Type() + if dbType == "snowflake" && config.ConnectionDetails["password"] != nil { + resp.AddWarning(`[DEPRECATED] Single-factor password authentication is deprecated in Snowflake and will +be removed by November 2025. Key pair authentication will be required after this date. Please +see the Vault documentation for details on the removal of this feature. More information is +available at https://www.snowflake.com/en/blog/blocking-single-factor-password-authentification`) + } + + var rotationPeriodString string + if config.RotationPeriod != 0 { + rotationPeriodString = config.RotationPeriod.String() + } + b.dbEvent(ctx, "config-write", req.Path, name, true) + recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigWrite, + AdditionalDatabaseMetadata{key: "root_rotation_period", value: rotationPeriodString}, + AdditionalDatabaseMetadata{key: "root_rotation_schedule", value: config.RotationSchedule}) + + if len(resp.Warnings) == 0 { + return nil, nil + } + return resp, nil + } +} + +// selectPluginVersion returns the appropriate plugin version for community edition +func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) { + pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName) + if err != nil { + return "", nil, err + } + pluginVersionRaw, ok := data.GetOk("plugin_version") + + switch { + case ok && pinnedVersion != "": + return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (%s)", config.PluginName, pinnedVersion), nil + case pinnedVersion != "": + return pinnedVersion, nil, nil + case ok: + config.PluginVersion = pluginVersionRaw.(string) + } + + var builtinShadowed bool + if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin { + builtinShadowed = true + } + switch { + case config.PluginVersion != "": + semanticVersion, err := version.NewVersion(config.PluginVersion) + if err != nil { + return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil + } + + // Canonicalize the version. + config.PluginVersion = "v" + semanticVersion.String() + + if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) { + if builtinShadowed { + return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+ + " overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil + } + + config.PluginVersion = "" + } + case builtinShadowed: + // We'll select the unversioned plugin that's been registered. + case op == logical.CreateOperation: + // No version provided and no unversioned plugin of that name available. + // Pin to the current latest version if any versioned plugins are registered. + plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase) + if err != nil { + return "", nil, err + } + + var versionedCandidates []pluginutil.VersionedPlugin + for _, plugin := range plugins { + if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" { + versionedCandidates = append(versionedCandidates, plugin) + } + } + + if len(versionedCandidates) != 0 { + // Sort in reverse order. + sort.SliceStable(versionedCandidates, func(i, j int) bool { + return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion) + }) + + config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String() + b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates)) + } + } + + return config.PluginVersion, nil, nil +} + +// connectionReadHandler reads out the connection configuration +func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { + return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + name := data.Get("name").(string) + if name == "" { + return logical.ErrorResponse(respErrEmptyName), nil + } + + entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name)) + if err != nil { + return nil, fmt.Errorf("failed to read connection configuration: %w", err) + } + if entry == nil { + return nil, nil + } + + var config DatabaseConfig + if err := entry.DecodeJSON(&config); err != nil { + return nil, err + } + + // Ensure that we only ever include a redacted valid URL in the response. + if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok { + if p, err := url.Parse(connURLRaw.(string)); err == nil { + config.ConnectionDetails["connection_url"] = p.Redacted() + } + } + + if versions.IsBuiltinVersion(config.PluginVersion) { + // This gets treated as though it's empty when mounting, and will get + // overwritten to be empty when the config is next written. See #18051. + config.PluginVersion = "" + } + + delete(config.ConnectionDetails, "password") + delete(config.ConnectionDetails, "private_key") + delete(config.ConnectionDetails, "service_account_json") + + resp := &logical.Response{} + if dbi, err := b.GetConnectionSkipVerify(ctx, req.Storage, name); err == nil { + config.RunningPluginVersion = dbi.runningPluginVersion + if config.PluginVersion != "" && config.PluginVersion != config.RunningPluginVersion { + warning := fmt.Sprintf("Plugin version is configured as %q, but running %q", config.PluginVersion, config.RunningPluginVersion) + if pinnedVersion, _ := b.getPinnedVersion(ctx, config.PluginName); pinnedVersion == config.RunningPluginVersion { + warning += " because that version is pinned" + } else { + warning += " either due to a pinned version or because the plugin was upgraded and not yet reloaded" + } + resp.AddWarning(warning) + } + } + + resp.Data = structs.New(config).Map() + config.PopulateAutomatedRotationData(resp.Data) + // remove extra nested AutomatedRotationParams key + // before returning response + delete(resp.Data, "AutomatedRotationParams") + + // remove nested EntDatabaseConfig key before returning response + delete(resp.Data, "EntDatabaseConfig") + + recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigRead) + + return resp, nil + } +} diff --git a/changelog/_10517.txt b/changelog/_10517.txt new file mode 100644 index 0000000000..a084b4d362 --- /dev/null +++ b/changelog/_10517.txt @@ -0,0 +1,3 @@ +```release-note:feature +**Plugins (Enterprise)**: Allow overriding pinned version when creating and updating database engines +``` \ No newline at end of file