Backport Add override_pinned_version support on configure connection for database into ce/main (#10860)

* Add override_pinned_version support on configure connection for database (#10517)

* add DatabaseConfigEnt and split ce-ent impl for connectionWriteHandler() and selectPluginVersion()

* add override_pinned_version handling in connectionWriteHandler() and selectPluginVersion()

* split ce-ent impl for connectionReadHandler() to support override_pinned_version

* split ce-ent impl for databaseBackend.GetConnectionWithConfig() to support override_pinned_version

* split TestBackend_* units related to databased connection config CRUD into ce and ent

* remove EntDatabaseConfig from response

---------

Co-authored-by: Thy Ton <maithytonn@gmail.com>
This commit is contained in:
Vault Automation 2025-12-01 18:18:26 -05:00 committed by GitHub
parent 004d6da92c
commit ff96dceedd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 917 additions and 858 deletions

View File

@ -17,7 +17,6 @@ import (
"github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/strutil"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/builtin/logical/database/schedule"
"github.com/hashicorp/vault/helper/metricsutil"
"github.com/hashicorp/vault/helper/syncmap"
@ -343,73 +342,6 @@ func (b *databaseBackend) GetConnectionSkipVerify(ctx context.Context, s logical
return b.GetConnectionWithConfig(ctx, name, config)
}
func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) {
// fast path, reuse the existing connection
dbi := b.connections.Get(name)
if dbi != nil {
return dbi, nil
}
// slow path, create a new connection
// if we don't lock the rest of the operation, there is a race condition for multiple callers of this function
b.createConnectionLock.Lock()
defer b.createConnectionLock.Unlock()
// check again in case we lost the race
dbi = b.connections.Get(name)
if dbi != nil {
return dbi, nil
}
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Override the configured version if there is a pinned version.
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return nil, err
}
pluginVersion := config.PluginVersion
if pinnedVersion != "" {
pluginVersion = pinnedVersion
}
dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err)
}
initReq := v5.InitializeRequest{
Config: config.ConnectionDetails,
VerifyConnection: config.VerifyConnection,
}
_, err = dbw.Initialize(ctx, initReq)
if err != nil {
dbw.Close()
return nil, err
}
dbi = &dbPluginInstance{
database: dbw,
id: id,
name: name,
runningPluginVersion: pluginVersion,
}
conn, ok := b.connections.PutIfEmpty(name, dbi)
if !ok {
// this is a bug
b.Logger().Warn("BUG: there was a race condition adding to the database connection map")
// There was already an existing connection, so we will use that and close our new one to avoid a race condition.
err := dbi.Close()
if err != nil {
b.Logger().Warn("Error closing new database connection", "error", err)
}
}
return conn, nil
}
// ClearConnection closes the database connection and
// removes it from the b.connections map.
func (b *databaseBackend) ClearConnection(name string) error {

View File

@ -0,0 +1,82 @@
// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package database
import (
"context"
"fmt"
"github.com/hashicorp/go-uuid"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
)
// GetConnectionWithConfig gets or creates a database connection with the given config for community edition
func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name string, config *DatabaseConfig) (*dbPluginInstance, error) {
// fast path, reuse the existing connection
dbi := b.connections.Get(name)
if dbi != nil {
return dbi, nil
}
// slow path, create a new connection
// if we don't lock the rest of the operation, there is a race condition for multiple callers of this function
b.createConnectionLock.Lock()
defer b.createConnectionLock.Unlock()
// check again in case we lost the race
dbi = b.connections.Get(name)
if dbi != nil {
return dbi, nil
}
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// Override the configured version if there is a pinned version.
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return nil, err
}
pluginVersion := config.PluginVersion
if pinnedVersion != "" {
pluginVersion = pinnedVersion
}
dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil {
return nil, fmt.Errorf("unable to create database instance: %w", err)
}
initReq := v5.InitializeRequest{
Config: config.ConnectionDetails,
VerifyConnection: config.VerifyConnection,
}
_, err = dbw.Initialize(ctx, initReq)
if err != nil {
dbw.Close()
return nil, err
}
dbi = &dbPluginInstance{
database: dbw,
id: id,
name: name,
runningPluginVersion: pluginVersion,
}
conn, ok := b.connections.PutIfEmpty(name, dbi)
if !ok {
// this is a bug
b.Logger().Warn("BUG: there was a race condition adding to the database connection map")
// There was already an existing connection, so we will use that and close our new one to avoid a race condition.
err := dbi.Close()
if err != nil {
b.Logger().Warn("Error closing new database connection", "error", err)
}
}
return conn, nil
}

View File

@ -0,0 +1,448 @@
// Copyright IBM Corp. 2016, 2025
// SPDX-License-Identifier: BUSL-1.1
//go:build !enterprise
package database
import (
"context"
"encoding/json"
"reflect"
"strings"
"testing"
"time"
"github.com/go-test/deep"
"github.com/hashicorp/vault/helper/namespace"
postgreshelper "github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
_ "github.com/jackc/pgx/v4"
"github.com/stretchr/testify/assert"
)
// TestBackend_config_connection tests the configuration of a database connection
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
cluster, sys := getClusterPostgresDB(t)
defer cluster.Cleanup()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
eventSender := logical.NewMockEventSender()
config.EventsSender = eventSender
lb, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
b, ok := lb.(*databaseBackend)
if !ok {
t.Fatal("could not convert to database backend")
}
defer b.Cleanup(context.Background())
// Test creation
{
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"someotherdata": "testing",
"plugin_name": "postgresql-database-plugin",
"verify_connection": false,
"allowed_roles": []string{"*"},
"name": "plugin-test",
}
configReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{
Raw: configData,
Schema: pathConfigurePluginConnection(b).Fields,
})
if err != nil {
t.Fatal(err)
}
if exists {
t.Fatal("expected not exists")
}
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"connection_url": "sample_connection_url",
"someotherdata": "testing",
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": time.Duration(0).Seconds(),
"rotation_window": time.Duration(0).Seconds(),
"disable_automated_rotation": false,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
// Test existence check and an update to a single connection detail parameter
{
configData := map[string]interface{}{
"connection_url": "sample_convection_url",
"verify_connection": false,
"name": "plugin-test",
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{
Raw: configData,
Schema: pathConfigurePluginConnection(b).Fields,
})
if err != nil {
t.Fatal(err)
}
if !exists {
t.Fatal("expected exists")
}
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"connection_url": "sample_convection_url",
"someotherdata": "testing",
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": time.Duration(0).Seconds(),
"rotation_window": time.Duration(0).Seconds(),
"disable_automated_rotation": false,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
delete(resp.Data, "AutomatedRotationParams")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
// Test an update to a non-details value
{
configData := map[string]interface{}{
"verify_connection": false,
"allowed_roles": []string{"flu", "barre"},
"name": "plugin-test",
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"connection_url": "sample_convection_url",
"someotherdata": "testing",
},
"allowed_roles": []string{"flu", "barre"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": time.Duration(0).Seconds(),
"rotation_window": time.Duration(0).Seconds(),
"disable_automated_rotation": false,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
delete(resp.Data, "AutomatedRotationParams")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
req := &logical.Request{
Operation: logical.ListOperation,
Storage: config.StorageView,
Path: "config/",
}
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatal(err)
}
keys := resp.Data["keys"].([]string)
key := keys[0]
if key != "plugin-test" {
t.Fatalf("bad key: %q", key)
}
assert.Equal(t, 3, len(eventSender.Events))
assert.Equal(t, "database/config-write", string(eventSender.Events[0].Type))
assert.Equal(t, "config/plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["path"])
assert.Equal(t, "plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["name"])
assert.Equal(t, "database/config-write", string(eventSender.Events[1].Type))
assert.Equal(t, "config/plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["path"])
assert.Equal(t, "plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["name"])
assert.Equal(t, "database/config-write", string(eventSender.Events[2].Type))
assert.Equal(t, "config/plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["path"])
assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"])
}
// TestBackend_connectionCrud tests the full CRUD lifecycle of a database connection
func TestBackend_connectionCrud(t *testing.T) {
t.Parallel()
dbFactory := &singletonDBFactory{}
cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory)
defer cluster.Cleanup()
dbFactory.sys = sys
client := cluster.Cores[0].Client.Logical()
cleanup, connURL := postgreshelper.PrepareTestContainer(t)
defer cleanup()
// Mount the database plugin.
resp, err := client.Write("sys/mounts/database", map[string]interface{}{
"type": "database",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Configure a connection
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
"connection_url": "test",
"plugin_name": "postgresql-database-plugin",
"verify_connection": false,
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Configure a second connection to confirm below it doesn't get restarted.
resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{
"connection_url": "test",
"plugin_name": "hana-database-plugin",
"verify_connection": false,
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Create a role
resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"revocation_statements": defaultRevocationSQL,
"default_ttl": "5m",
"max_ttl": "10m",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Update the connection
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
"connection_url": connURL,
"plugin_name": "postgresql-database-plugin",
"allowed_roles": []string{"plugin-role-test"},
"username": "postgres",
"password": "secret",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if len(resp.Warnings) == 0 {
t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp)
}
resp, err = client.Read("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{})
if strings.Contains(returnedConnectionDetails["connection_url"].(string), "secret") {
t.Fatal("password should not be found in the connection url")
}
// Covered by the filled out `expected` value below, but be explicit about this requirement.
if _, exists := returnedConnectionDetails["password"]; exists {
t.Fatal("password should NOT be found in the returned config")
}
// Replace connection url with templated version
templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}")
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
"connection_url": templatedConnURL,
"plugin_name": "postgresql-database-plugin",
"allowed_roles": []string{"plugin-role-test"},
"username": "postgres",
"password": "secret",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Read connection
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"username": "postgres",
"connection_url": templatedConnURL,
},
"allowed_roles": []any{"plugin-role-test"},
"root_credentials_rotate_statements": []any{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": json.Number("0"),
"rotation_window": json.Number("0"),
"disable_automated_rotation": false,
}
resp, err = client.Read("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
delete(resp.Data, "AutomatedRotationParams")
if diff := deep.Equal(resp.Data, expected); diff != nil {
t.Fatal(strings.Join(diff, "\n"))
}
// Test endpoints for reloading plugins.
for _, reload := range []struct {
path string
data map[string]any
checkCount bool
}{
{"database/reset/plugin-test", nil, false},
{"database/reload/postgresql-database-plugin", nil, true},
{"sys/plugins/reload/backend", map[string]any{
"plugin": "postgresql-database-plugin",
}, false},
} {
getConnectionID := func(name string) string {
t.Helper()
dbi := dbFactory.db.connections.Get(name)
if dbi == nil {
t.Fatal("no plugin-test dbi")
}
return dbi.ID()
}
initialID := getConnectionID("plugin-test")
hanaID := getConnectionID("plugin-test-hana")
resp, err = client.Write(reload.path, reload.data)
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if initialID == getConnectionID("plugin-test") {
t.Fatal("ID unchanged after connection reset")
}
if hanaID != getConnectionID("plugin-test-hana") {
t.Fatal("hana plugin got restarted but shouldn't have been")
}
if reload.checkCount {
actual, err := resp.Data["count"].(json.Number).Int64()
if err != nil {
t.Fatal(err)
}
if expected := 1; expected != int(actual) {
t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int))
}
if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) {
t.Fatalf("expected %v but got %v", expected, resp.Data["connections"])
}
}
}
// Get creds
credsResp, err := client.Read("database/creds/plugin-role-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{
"username": "postgres",
"password": "secret",
})
if !testCredsExist(t, credsResp.Data, credCheckURL) {
t.Fatalf("Creds should exist")
}
// Delete Connection
resp, err = client.Delete("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Read connection
resp, err = client.Read("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Should be empty
if resp != nil {
t.Fatal("Expected response to be nil")
}
}

View File

@ -6,7 +6,6 @@ package database
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
@ -32,8 +31,6 @@ import (
"github.com/hashicorp/vault/plugins/database/postgresql"
v4 "github.com/hashicorp/vault/sdk/database/dbplugin"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
@ -157,228 +154,6 @@ func TestBackend_RoleUpgrade(t *testing.T) {
}
}
func TestBackend_config_connection(t *testing.T) {
var resp *logical.Response
var err error
cluster, sys := getClusterPostgresDB(t)
defer cluster.Cleanup()
config := logical.TestBackendConfig()
config.StorageView = &logical.InmemStorage{}
config.System = sys
eventSender := logical.NewMockEventSender()
config.EventsSender = eventSender
lb, err := Factory(context.Background(), config)
if err != nil {
t.Fatal(err)
}
b, ok := lb.(*databaseBackend)
if !ok {
t.Fatal("could not convert to database backend")
}
defer b.Cleanup(context.Background())
// Test creation
{
configData := map[string]interface{}{
"connection_url": "sample_connection_url",
"someotherdata": "testing",
"plugin_name": "postgresql-database-plugin",
"verify_connection": false,
"allowed_roles": []string{"*"},
"name": "plugin-test",
}
configReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{
Raw: configData,
Schema: pathConfigurePluginConnection(b).Fields,
})
if err != nil {
t.Fatal(err)
}
if exists {
t.Fatal("expected not exists")
}
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"connection_url": "sample_connection_url",
"someotherdata": "testing",
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": time.Duration(0).Seconds(),
"rotation_window": time.Duration(0).Seconds(),
"disable_automated_rotation": false,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
// Test existence check and an update to a single connection detail parameter
{
configData := map[string]interface{}{
"connection_url": "sample_convection_url",
"verify_connection": false,
"name": "plugin-test",
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{
Raw: configData,
Schema: pathConfigurePluginConnection(b).Fields,
})
if err != nil {
t.Fatal(err)
}
if !exists {
t.Fatal("expected exists")
}
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"connection_url": "sample_convection_url",
"someotherdata": "testing",
},
"allowed_roles": []string{"*"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": time.Duration(0).Seconds(),
"rotation_window": time.Duration(0).Seconds(),
"disable_automated_rotation": false,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
delete(resp.Data, "AutomatedRotationParams")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
// Test an update to a non-details value
{
configData := map[string]interface{}{
"verify_connection": false,
"allowed_roles": []string{"flu", "barre"},
"name": "plugin-test",
}
configReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config/plugin-test",
Storage: config.StorageView,
Data: configData,
}
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v\n", err, resp)
}
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"connection_url": "sample_convection_url",
"someotherdata": "testing",
},
"allowed_roles": []string{"flu", "barre"},
"root_credentials_rotate_statements": []string{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": time.Duration(0).Seconds(),
"rotation_window": time.Duration(0).Seconds(),
"disable_automated_rotation": false,
}
configReq.Operation = logical.ReadOperation
resp, err = b.HandleRequest(namespace.RootContext(nil), configReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
delete(resp.Data, "AutomatedRotationParams")
if !reflect.DeepEqual(expected, resp.Data) {
t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data)
}
}
req := &logical.Request{
Operation: logical.ListOperation,
Storage: config.StorageView,
Path: "config/",
}
resp, err = b.HandleRequest(namespace.RootContext(nil), req)
if err != nil {
t.Fatal(err)
}
keys := resp.Data["keys"].([]string)
key := keys[0]
if key != "plugin-test" {
t.Fatalf("bad key: %q", key)
}
assert.Equal(t, 3, len(eventSender.Events))
assert.Equal(t, "database/config-write", string(eventSender.Events[0].Type))
assert.Equal(t, "config/plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["path"])
assert.Equal(t, "plugin-test", eventSender.Events[0].Event.Metadata.AsMap()["name"])
assert.Equal(t, "database/config-write", string(eventSender.Events[1].Type))
assert.Equal(t, "config/plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["path"])
assert.Equal(t, "plugin-test", eventSender.Events[1].Event.Metadata.AsMap()["name"])
assert.Equal(t, "database/config-write", string(eventSender.Events[2].Type))
assert.Equal(t, "config/plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["path"])
assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"])
}
// TestBackend_BadConnectionString tests that an error response resulting from
// a failed connection does not expose the URL. The middleware should sanitize it.
func TestBackend_BadConnectionString(t *testing.T) {
@ -693,206 +468,6 @@ func (s *singletonDBFactory) factory(context.Context, *logical.BackendConfig) (l
return s.db, nil
}
func TestBackend_connectionCrud(t *testing.T) {
t.Parallel()
dbFactory := &singletonDBFactory{}
cluster, sys := getClusterPostgresDBWithFactory(t, dbFactory.factory)
defer cluster.Cleanup()
dbFactory.sys = sys
client := cluster.Cores[0].Client.Logical()
cleanup, connURL := postgreshelper.PrepareTestContainer(t)
defer cleanup()
// Mount the database plugin.
resp, err := client.Write("sys/mounts/database", map[string]interface{}{
"type": "database",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Configure a connection
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
"connection_url": "test",
"plugin_name": "postgresql-database-plugin",
"verify_connection": false,
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Configure a second connection to confirm below it doesn't get restarted.
resp, err = client.Write("database/config/plugin-test-hana", map[string]interface{}{
"connection_url": "test",
"plugin_name": "hana-database-plugin",
"verify_connection": false,
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Create a role
resp, err = client.Write("database/roles/plugin-role-test", map[string]interface{}{
"db_name": "plugin-test",
"creation_statements": testRole,
"revocation_statements": defaultRevocationSQL,
"default_ttl": "5m",
"max_ttl": "10m",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Update the connection
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
"connection_url": connURL,
"plugin_name": "postgresql-database-plugin",
"allowed_roles": []string{"plugin-role-test"},
"username": "postgres",
"password": "secret",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if len(resp.Warnings) == 0 {
t.Fatalf("expected warning about password in url %s, resp:%#v\n", connURL, resp)
}
resp, err = client.Read("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
returnedConnectionDetails := resp.Data["connection_details"].(map[string]interface{})
if strings.Contains(returnedConnectionDetails["connection_url"].(string), "secret") {
t.Fatal("password should not be found in the connection url")
}
// Covered by the filled out `expected` value below, but be explicit about this requirement.
if _, exists := returnedConnectionDetails["password"]; exists {
t.Fatal("password should NOT be found in the returned config")
}
// Replace connection url with templated version
templatedConnURL := strings.ReplaceAll(connURL, "postgres:secret", "{{username}}:{{password}}")
resp, err = client.Write("database/config/plugin-test", map[string]interface{}{
"connection_url": templatedConnURL,
"plugin_name": "postgresql-database-plugin",
"allowed_roles": []string{"plugin-role-test"},
"username": "postgres",
"password": "secret",
})
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Read connection
expected := map[string]interface{}{
"plugin_name": "postgresql-database-plugin",
"connection_details": map[string]interface{}{
"username": "postgres",
"connection_url": templatedConnURL,
},
"allowed_roles": []any{"plugin-role-test"},
"root_credentials_rotate_statements": []any{},
"password_policy": "",
"plugin_version": "",
"verify_connection": false,
"skip_static_role_import_rotation": false,
"rotation_schedule": "",
"rotation_period": json.Number("0"),
"rotation_window": json.Number("0"),
"disable_automated_rotation": false,
}
resp, err = client.Read("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
delete(resp.Data["connection_details"].(map[string]interface{}), "name")
delete(resp.Data, "AutomatedRotationParams")
if diff := deep.Equal(resp.Data, expected); diff != nil {
t.Fatal(strings.Join(diff, "\n"))
}
// Test endpoints for reloading plugins.
for _, reload := range []struct {
path string
data map[string]any
checkCount bool
}{
{"database/reset/plugin-test", nil, false},
{"database/reload/postgresql-database-plugin", nil, true},
{"sys/plugins/reload/backend", map[string]any{
"plugin": "postgresql-database-plugin",
}, false},
} {
getConnectionID := func(name string) string {
t.Helper()
dbi := dbFactory.db.connections.Get(name)
if dbi == nil {
t.Fatal("no plugin-test dbi")
}
return dbi.ID()
}
initialID := getConnectionID("plugin-test")
hanaID := getConnectionID("plugin-test-hana")
resp, err = client.Write(reload.path, reload.data)
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
if initialID == getConnectionID("plugin-test") {
t.Fatal("ID unchanged after connection reset")
}
if hanaID != getConnectionID("plugin-test-hana") {
t.Fatal("hana plugin got restarted but shouldn't have been")
}
if reload.checkCount {
actual, err := resp.Data["count"].(json.Number).Int64()
if err != nil {
t.Fatal(err)
}
if expected := 1; expected != int(actual) {
t.Fatalf("expected %d but got %d", expected, resp.Data["count"].(int))
}
if expected := []any{"plugin-test"}; !reflect.DeepEqual(expected, resp.Data["connections"]) {
t.Fatalf("expected %v but got %v", expected, resp.Data["connections"])
}
}
}
// Get creds
credsResp, err := client.Read("database/creds/plugin-role-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, credsResp)
}
credCheckURL := dbutil.QueryHelper(templatedConnURL, map[string]string{
"username": "postgres",
"password": "secret",
})
if !testCredsExist(t, credsResp.Data, credCheckURL) {
t.Fatalf("Creds should exist")
}
// Delete Connection
resp, err = client.Delete("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Read connection
resp, err = client.Read("database/config/plugin-test")
if err != nil {
t.Fatalf("err:%s resp:%#v\n", err, resp)
}
// Should be empty
if resp != nil {
t.Fatal("Expected response to be nil")
}
}
func TestBackend_connectionSanitizePrivateKey(t *testing.T) {
t.Parallel()
dbFactory := &singletonDBFactory{}

View File

@ -7,20 +7,13 @@ import (
"context"
"errors"
"fmt"
"net/url"
"sort"
"github.com/fatih/structs"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-version"
"github.com/hashicorp/vault/helper/versions"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/automatedrotationutil"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/rotation"
)
var (
@ -31,6 +24,8 @@ var (
// DatabaseConfig is used by the Factory function to configure a Database
// object.
type DatabaseConfig struct {
EntDatabaseConfig `mapstructure:",squash"`
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
PluginVersion string `json:"plugin_version" structs:"plugin_version" mapstructure:"plugin_version"`
RunningPluginVersion string `json:"running_plugin_version,omitempty" structs:"running_plugin_version,omitempty" mapstructure:"running_plugin_version,omitempty"`
@ -376,70 +371,6 @@ func (b *databaseBackend) connectionListHandler() framework.OperationFunc {
}
}
// connectionReadHandler reads out the connection configuration
func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse(respErrEmptyName), nil
}
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
}
if entry == nil {
return nil, nil
}
var config DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
// Ensure that we only ever include a redacted valid URL in the response.
if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok {
if p, err := url.Parse(connURLRaw.(string)); err == nil {
config.ConnectionDetails["connection_url"] = p.Redacted()
}
}
if versions.IsBuiltinVersion(config.PluginVersion) {
// This gets treated as though it's empty when mounting, and will get
// overwritten to be empty when the config is next written. See #18051.
config.PluginVersion = ""
}
delete(config.ConnectionDetails, "password")
delete(config.ConnectionDetails, "private_key")
delete(config.ConnectionDetails, "service_account_json")
resp := &logical.Response{}
if dbi, err := b.GetConnectionSkipVerify(ctx, req.Storage, name); err == nil {
config.RunningPluginVersion = dbi.runningPluginVersion
if config.PluginVersion != "" && config.PluginVersion != config.RunningPluginVersion {
warning := fmt.Sprintf("Plugin version is configured as %q, but running %q", config.PluginVersion, config.RunningPluginVersion)
if pinnedVersion, _ := b.getPinnedVersion(ctx, config.PluginName); pinnedVersion == config.RunningPluginVersion {
warning += " because that version is pinned"
} else {
warning += " either due to a pinned version or because the plugin was upgraded and not yet reloaded"
}
resp.AddWarning(warning)
}
}
resp.Data = structs.New(config).Map()
config.PopulateAutomatedRotationData(resp.Data)
// remove extra nested AutomatedRotationParams key
// before returning response
delete(resp.Data, "AutomatedRotationParams")
recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigRead)
return resp, nil
}
}
// connectionDeleteHandler deletes the connection configuration
func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
@ -463,230 +394,6 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
}
}
// connectionWriteHandler returns a handler function for creating and updating
// both builtin and plugin database types.
func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse(respErrEmptyName), nil
}
// Baseline
config := &DatabaseConfig{
VerifyConnection: true,
}
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
}
if entry != nil {
if err := entry.DecodeJSON(config); err != nil {
return nil, err
}
}
// If this value was provided as part of the request we want to set it to this value
if verifyConnectionRaw, ok := data.GetOk("verify_connection"); ok {
config.VerifyConnection = verifyConnectionRaw.(bool)
} else if req.Operation == logical.CreateOperation {
config.VerifyConnection = data.Get("verify_connection").(bool)
}
if pluginNameRaw, ok := data.GetOk("plugin_name"); ok {
config.PluginName = pluginNameRaw.(string)
} else if req.Operation == logical.CreateOperation {
config.PluginName = data.Get("plugin_name").(string)
}
if config.PluginName == "" {
return logical.ErrorResponse(respErrEmptyPluginName), nil
}
pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation)
if respErr != nil || err != nil {
return respErr, err
}
if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok {
config.AllowedRoles = allowedRolesRaw.([]string)
} else if req.Operation == logical.CreateOperation {
config.AllowedRoles = data.Get("allowed_roles").([]string)
}
if rootRotationStatementsRaw, ok := data.GetOk("root_rotation_statements"); ok {
config.RootCredentialsRotateStatements = rootRotationStatementsRaw.([]string)
} else if req.Operation == logical.CreateOperation {
config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string)
}
if passwordPolicyRaw, ok := data.GetOk("password_policy"); ok {
config.PasswordPolicy = passwordPolicyRaw.(string)
}
if skipImportRotationRaw, ok := data.GetOk("skip_static_role_import_rotation"); ok {
config.SkipStaticRoleImportRotation = skipImportRotationRaw.(bool)
}
if err := config.ParseAutomatedRotationFields(data); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Remove these entries from the data before we store it keyed under
// ConnectionDetails.
delete(data.Raw, "name")
delete(data.Raw, "plugin_name")
delete(data.Raw, "plugin_version")
delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
delete(data.Raw, "password_policy")
delete(data.Raw, "skip_static_role_import_rotation")
delete(data.Raw, "rotation_schedule")
delete(data.Raw, "rotation_window")
delete(data.Raw, "rotation_period")
delete(data.Raw, "disable_automated_rotation")
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// If this is an update, take any new values, overwrite what was there
// before, and pass that in as the "new" set of values to the plugin,
// then save what results
if req.Operation == logical.CreateOperation {
config.ConnectionDetails = data.Raw
} else {
if config.ConnectionDetails == nil {
config.ConnectionDetails = make(map[string]interface{})
}
for k, v := range data.Raw {
config.ConnectionDetails[k] = v
}
}
// Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse("error creating database object: %s", err), nil
}
initReq := v5.InitializeRequest{
Config: config.ConnectionDetails,
VerifyConnection: config.VerifyConnection,
}
initResp, err := dbw.Initialize(ctx, initReq)
if err != nil {
dbw.Close()
return logical.ErrorResponse("error creating database object: %s", err), nil
}
config.ConnectionDetails = initResp.Config
b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName)
// Close and remove the old connection
oldConn := b.connections.Put(name, &dbPluginInstance{
database: dbw,
name: name,
id: id,
runningPluginVersion: pluginVersion,
})
if oldConn != nil {
oldConn.Close()
}
var performedRotationManagerOpern string
if config.ShouldDeregisterRotationJob() {
performedRotationManagerOpern = rotation.PerformedDeregistration
// Disable Automated Rotation and Deregister credentials if required
deregisterReq := &rotation.RotationJobDeregisterRequest{
MountPoint: req.MountPoint,
ReqPath: req.Path,
}
b.Logger().Debug("Deregistering rotation job", "mount", req.MountPoint+req.Path)
if err := b.System().DeregisterRotationJob(ctx, deregisterReq); err != nil {
return logical.ErrorResponse("error deregistering rotation job: %s", err), nil
}
} else if config.ShouldRegisterRotationJob() {
performedRotationManagerOpern = rotation.PerformedRegistration
// Register the rotation job if it's required.
cfgReq := &rotation.RotationJobConfigureRequest{
MountPoint: req.MountPoint,
ReqPath: req.Path,
RotationSchedule: config.RotationSchedule,
RotationWindow: config.RotationWindow,
RotationPeriod: config.RotationPeriod,
}
b.Logger().Debug("Registering rotation job", "mount", req.MountPoint+req.Path)
if _, err = b.System().RegisterRotationJob(ctx, cfgReq); err != nil {
return logical.ErrorResponse("error registering rotation job: %s", err), nil
}
}
// 1.12.0 and 1.12.1 stored builtin plugins in storage, but 1.12.2 reverted
// that, so clean up any pre-existing stored builtin versions on write.
if versions.IsBuiltinVersion(config.PluginVersion) {
config.PluginVersion = ""
}
err = storeConfig(ctx, req.Storage, name, config)
if err != nil {
wrappedError := err
if performedRotationManagerOpern != "" {
b.Logger().Error("write to storage failed but the rotation manager still succeeded.",
"operation", performedRotationManagerOpern, "mount", req.MountPoint, "path", req.Path)
wrappedError = fmt.Errorf("write to storage failed but the rotation manager still succeeded; "+
"operation=%s, mount=%s, path=%s, storageError=%s", performedRotationManagerOpern, req.MountPoint, req.Path, err)
}
return nil, wrappedError
}
resp := &logical.Response{}
// This is a simple test to check for passwords in the connection_url parameter. If one exists,
// warn the user to use templated url string
if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok {
if connURL, err := url.Parse(connURLRaw.(string)); err == nil {
if _, ok := connURL.User.Password(); ok {
resp.AddWarning("Password found in connection_url, use a templated url to enable root rotation and prevent read access to password information.")
}
}
}
// If using a legacy DB plugin and set the `password_policy` field, send a warning to the user indicating
// the `password_policy` will not be used
if dbw.isV4() && config.PasswordPolicy != "" {
resp.AddWarning(fmt.Sprintf("%s does not support password policies - upgrade to the latest version of "+
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
}
// We can ignore the error at this point since we're simply adding a warning.
dbType, _ := dbw.Type()
if dbType == "snowflake" && config.ConnectionDetails["password"] != nil {
resp.AddWarning(`[DEPRECATED] Single-factor password authentication is deprecated in Snowflake and will
be removed by November 2025. Key pair authentication will be required after this date. Please
see the Vault documentation for details on the removal of this feature. More information is
available at https://www.snowflake.com/en/blog/blocking-single-factor-password-authentification`)
}
var rotationPeriodString string
if config.RotationPeriod != 0 {
rotationPeriodString = config.RotationPeriod.String()
}
b.dbEvent(ctx, "config-write", req.Path, name, true)
recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigWrite,
AdditionalDatabaseMetadata{key: "root_rotation_period", value: rotationPeriodString},
AdditionalDatabaseMetadata{key: "root_rotation_schedule", value: config.RotationSchedule})
if len(resp.Warnings) == 0 {
return nil, nil
}
return resp, nil
}
}
func storeConfig(ctx context.Context, storage logical.Storage, name string, config *DatabaseConfig) error {
entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config)
if err != nil {
@ -717,75 +424,6 @@ func (b *databaseBackend) getPinnedVersion(ctx context.Context, pluginName strin
return pin.Version, nil
}
func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) {
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return "", nil, err
}
pluginVersionRaw, ok := data.GetOk("plugin_version")
switch {
case ok && pinnedVersion != "":
return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (v%s)", config.PluginName, pinnedVersion), nil
case pinnedVersion != "":
return pinnedVersion, nil, nil
case ok:
config.PluginVersion = pluginVersionRaw.(string)
}
var builtinShadowed bool
if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin {
builtinShadowed = true
}
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}
// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()
if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) {
if builtinShadowed {
return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+
" overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil
}
config.PluginVersion = ""
}
case builtinShadowed:
// We'll select the unversioned plugin that's been registered.
case op == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return "", nil, err
}
var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}
if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})
config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
}
return config.PluginVersion, nil, nil
}
const pathConfigConnectionHelpSyn = `
Configure connection details to a database plugin.
`

View File

@ -5,9 +5,390 @@
package database
import "github.com/hashicorp/vault/sdk/framework"
import (
"context"
"fmt"
"net/url"
"sort"
"github.com/fatih/structs"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/go-version"
"github.com/hashicorp/vault/helper/versions"
v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/rotation"
)
// EntDatabaseConfig is an empty struct for community edition
type EntDatabaseConfig struct{}
// AddConnectionFieldsEnt is a no-op for community edition
func AddConnectionFieldsEnt(fields map[string]*framework.FieldSchema) {
// no-op
}
// connectionWriteHandler returns a handler function for creating and updating
// both builtin and plugin database types for community edition
func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse(respErrEmptyName), nil
}
// Baseline
config := &DatabaseConfig{
VerifyConnection: true,
}
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
}
if entry != nil {
if err := entry.DecodeJSON(config); err != nil {
return nil, err
}
}
// If this value was provided as part of the request we want to set it to this value
if verifyConnectionRaw, ok := data.GetOk("verify_connection"); ok {
config.VerifyConnection = verifyConnectionRaw.(bool)
} else if req.Operation == logical.CreateOperation {
config.VerifyConnection = data.Get("verify_connection").(bool)
}
if pluginNameRaw, ok := data.GetOk("plugin_name"); ok {
config.PluginName = pluginNameRaw.(string)
} else if req.Operation == logical.CreateOperation {
config.PluginName = data.Get("plugin_name").(string)
}
if config.PluginName == "" {
return logical.ErrorResponse(respErrEmptyPluginName), nil
}
pluginVersion, respErr, err := b.selectPluginVersion(ctx, config, data, req.Operation)
if respErr != nil || err != nil {
return respErr, err
}
if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok {
config.AllowedRoles = allowedRolesRaw.([]string)
} else if req.Operation == logical.CreateOperation {
config.AllowedRoles = data.Get("allowed_roles").([]string)
}
if rootRotationStatementsRaw, ok := data.GetOk("root_rotation_statements"); ok {
config.RootCredentialsRotateStatements = rootRotationStatementsRaw.([]string)
} else if req.Operation == logical.CreateOperation {
config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string)
}
if passwordPolicyRaw, ok := data.GetOk("password_policy"); ok {
config.PasswordPolicy = passwordPolicyRaw.(string)
}
if skipImportRotationRaw, ok := data.GetOk("skip_static_role_import_rotation"); ok {
config.SkipStaticRoleImportRotation = skipImportRotationRaw.(bool)
}
if err := config.ParseAutomatedRotationFields(data); err != nil {
return logical.ErrorResponse(err.Error()), nil
}
// Remove these entries from the data before we store it keyed under
// ConnectionDetails.
delete(data.Raw, "name")
delete(data.Raw, "plugin_name")
delete(data.Raw, "plugin_version")
delete(data.Raw, "allowed_roles")
delete(data.Raw, "verify_connection")
delete(data.Raw, "root_rotation_statements")
delete(data.Raw, "password_policy")
delete(data.Raw, "skip_static_role_import_rotation")
delete(data.Raw, "rotation_schedule")
delete(data.Raw, "rotation_window")
delete(data.Raw, "rotation_period")
delete(data.Raw, "disable_automated_rotation")
delete(data.Raw, "EntDatabaseConfig")
id, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
// If this is an update, take any new values, overwrite what was there
// before, and pass that in as the "new" set of values to the plugin,
// then save what results
if req.Operation == logical.CreateOperation {
config.ConnectionDetails = data.Raw
} else {
if config.ConnectionDetails == nil {
config.ConnectionDetails = make(map[string]interface{})
}
for k, v := range data.Raw {
config.ConnectionDetails[k] = v
}
}
// Create a database plugin and initialize it.
dbw, err := newDatabaseWrapper(ctx, config.PluginName, pluginVersion, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse("error creating database object: %s", err), nil
}
initReq := v5.InitializeRequest{
Config: config.ConnectionDetails,
VerifyConnection: config.VerifyConnection,
}
initResp, err := dbw.Initialize(ctx, initReq)
if err != nil {
dbw.Close()
return logical.ErrorResponse("error creating database object: %s", err), nil
}
config.ConnectionDetails = initResp.Config
b.Logger().Debug("created database object", "name", name, "plugin_name", config.PluginName)
// Close and remove the old connection
oldConn := b.connections.Put(name, &dbPluginInstance{
database: dbw,
name: name,
id: id,
runningPluginVersion: pluginVersion,
})
if oldConn != nil {
oldConn.Close()
}
var performedRotationManagerOpern string
if config.ShouldDeregisterRotationJob() {
performedRotationManagerOpern = rotation.PerformedDeregistration
// Disable Automated Rotation and Deregister credentials if required
deregisterReq := &rotation.RotationJobDeregisterRequest{
MountPoint: req.MountPoint,
ReqPath: req.Path,
}
b.Logger().Debug("Deregistering rotation job", "mount", req.MountPoint+req.Path)
if err := b.System().DeregisterRotationJob(ctx, deregisterReq); err != nil {
return logical.ErrorResponse("error deregistering rotation job: %s", err), nil
}
} else if config.ShouldRegisterRotationJob() {
performedRotationManagerOpern = rotation.PerformedRegistration
// Register the rotation job if it's required.
cfgReq := &rotation.RotationJobConfigureRequest{
MountPoint: req.MountPoint,
ReqPath: req.Path,
RotationSchedule: config.RotationSchedule,
RotationWindow: config.RotationWindow,
RotationPeriod: config.RotationPeriod,
}
b.Logger().Debug("Registering rotation job", "mount", req.MountPoint+req.Path)
if _, err = b.System().RegisterRotationJob(ctx, cfgReq); err != nil {
return logical.ErrorResponse("error registering rotation job: %s", err), nil
}
}
// 1.12.0 and 1.12.1 stored builtin plugins in storage, but 1.12.2 reverted
// that, so clean up any pre-existing stored builtin versions on write.
if versions.IsBuiltinVersion(config.PluginVersion) {
config.PluginVersion = ""
}
err = storeConfig(ctx, req.Storage, name, config)
if err != nil {
wrappedError := err
if performedRotationManagerOpern != "" {
b.Logger().Error("write to storage failed but the rotation manager still succeeded.",
"operation", performedRotationManagerOpern, "mount", req.MountPoint, "path", req.Path)
wrappedError = fmt.Errorf("write to storage failed but the rotation manager still succeeded; "+
"operation=%s, mount=%s, path=%s, storageError=%s", performedRotationManagerOpern, req.MountPoint, req.Path, err)
}
return nil, wrappedError
}
resp := &logical.Response{}
// This is a simple test to check for passwords in the connection_url parameter. If one exists,
// warn the user to use templated url string
if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok {
if connURL, err := url.Parse(connURLRaw.(string)); err == nil {
if _, ok := connURL.User.Password(); ok {
resp.AddWarning("Password found in connection_url, use a templated url to enable root rotation and prevent read access to password information.")
}
}
}
// If using a legacy DB plugin and set the `password_policy` field, send a warning to the user indicating
// the `password_policy` will not be used
if dbw.isV4() && config.PasswordPolicy != "" {
resp.AddWarning(fmt.Sprintf("%s does not support password policies - upgrade to the latest version of "+
"Vault (or the sdk if using a custom plugin) to gain password policy support", config.PluginName))
}
// We can ignore the error at this point since we're simply adding a warning.
dbType, _ := dbw.Type()
if dbType == "snowflake" && config.ConnectionDetails["password"] != nil {
resp.AddWarning(`[DEPRECATED] Single-factor password authentication is deprecated in Snowflake and will
be removed by November 2025. Key pair authentication will be required after this date. Please
see the Vault documentation for details on the removal of this feature. More information is
available at https://www.snowflake.com/en/blog/blocking-single-factor-password-authentification`)
}
var rotationPeriodString string
if config.RotationPeriod != 0 {
rotationPeriodString = config.RotationPeriod.String()
}
b.dbEvent(ctx, "config-write", req.Path, name, true)
recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigWrite,
AdditionalDatabaseMetadata{key: "root_rotation_period", value: rotationPeriodString},
AdditionalDatabaseMetadata{key: "root_rotation_schedule", value: config.RotationSchedule})
if len(resp.Warnings) == 0 {
return nil, nil
}
return resp, nil
}
}
// selectPluginVersion returns the appropriate plugin version for community edition
func (b *databaseBackend) selectPluginVersion(ctx context.Context, config *DatabaseConfig, data *framework.FieldData, op logical.Operation) (string, *logical.Response, error) {
pinnedVersion, err := b.getPinnedVersion(ctx, config.PluginName)
if err != nil {
return "", nil, err
}
pluginVersionRaw, ok := data.GetOk("plugin_version")
switch {
case ok && pinnedVersion != "":
return "", logical.ErrorResponse("cannot specify plugin_version for plugin %q as it is pinned (%s)", config.PluginName, pinnedVersion), nil
case pinnedVersion != "":
return pinnedVersion, nil, nil
case ok:
config.PluginVersion = pluginVersionRaw.(string)
}
var builtinShadowed bool
if unversionedPlugin, err := b.System().LookupPlugin(ctx, config.PluginName, consts.PluginTypeDatabase); err == nil && !unversionedPlugin.Builtin {
builtinShadowed = true
}
switch {
case config.PluginVersion != "":
semanticVersion, err := version.NewVersion(config.PluginVersion)
if err != nil {
return "", logical.ErrorResponse("version %q is not a valid semantic version: %s", config.PluginVersion, err), nil
}
// Canonicalize the version.
config.PluginVersion = "v" + semanticVersion.String()
if config.PluginVersion == versions.GetBuiltinVersion(consts.PluginTypeDatabase, config.PluginName) {
if builtinShadowed {
return "", logical.ErrorResponse("database plugin %q, version %s not found, as it is"+
" overridden by an unversioned plugin of the same name. Omit `plugin_version` to use the unversioned plugin", config.PluginName, config.PluginVersion), nil
}
config.PluginVersion = ""
}
case builtinShadowed:
// We'll select the unversioned plugin that's been registered.
case op == logical.CreateOperation:
// No version provided and no unversioned plugin of that name available.
// Pin to the current latest version if any versioned plugins are registered.
plugins, err := b.System().ListVersionedPlugins(ctx, consts.PluginTypeDatabase)
if err != nil {
return "", nil, err
}
var versionedCandidates []pluginutil.VersionedPlugin
for _, plugin := range plugins {
if !plugin.Builtin && plugin.Name == config.PluginName && plugin.Version != "" {
versionedCandidates = append(versionedCandidates, plugin)
}
}
if len(versionedCandidates) != 0 {
// Sort in reverse order.
sort.SliceStable(versionedCandidates, func(i, j int) bool {
return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion)
})
config.PluginVersion = "v" + versionedCandidates[0].SemanticVersion.String()
b.logger.Debug(fmt.Sprintf("pinning %q database plugin version %q from candidates %v", config.PluginName, config.PluginVersion, versionedCandidates))
}
}
return config.PluginVersion, nil, nil
}
// connectionReadHandler reads out the connection configuration
func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
name := data.Get("name").(string)
if name == "" {
return logical.ErrorResponse(respErrEmptyName), nil
}
entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name))
if err != nil {
return nil, fmt.Errorf("failed to read connection configuration: %w", err)
}
if entry == nil {
return nil, nil
}
var config DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
// Ensure that we only ever include a redacted valid URL in the response.
if connURLRaw, ok := config.ConnectionDetails["connection_url"]; ok {
if p, err := url.Parse(connURLRaw.(string)); err == nil {
config.ConnectionDetails["connection_url"] = p.Redacted()
}
}
if versions.IsBuiltinVersion(config.PluginVersion) {
// This gets treated as though it's empty when mounting, and will get
// overwritten to be empty when the config is next written. See #18051.
config.PluginVersion = ""
}
delete(config.ConnectionDetails, "password")
delete(config.ConnectionDetails, "private_key")
delete(config.ConnectionDetails, "service_account_json")
resp := &logical.Response{}
if dbi, err := b.GetConnectionSkipVerify(ctx, req.Storage, name); err == nil {
config.RunningPluginVersion = dbi.runningPluginVersion
if config.PluginVersion != "" && config.PluginVersion != config.RunningPluginVersion {
warning := fmt.Sprintf("Plugin version is configured as %q, but running %q", config.PluginVersion, config.RunningPluginVersion)
if pinnedVersion, _ := b.getPinnedVersion(ctx, config.PluginName); pinnedVersion == config.RunningPluginVersion {
warning += " because that version is pinned"
} else {
warning += " either due to a pinned version or because the plugin was upgraded and not yet reloaded"
}
resp.AddWarning(warning)
}
}
resp.Data = structs.New(config).Map()
config.PopulateAutomatedRotationData(resp.Data)
// remove extra nested AutomatedRotationParams key
// before returning response
delete(resp.Data, "AutomatedRotationParams")
// remove nested EntDatabaseConfig key before returning response
delete(resp.Data, "EntDatabaseConfig")
recordDatabaseObservation(ctx, b, req, name, ObservationTypeDatabaseConfigRead)
return resp, nil
}
}

3
changelog/_10517.txt Normal file
View File

@ -0,0 +1,3 @@
```release-note:feature
**Plugins (Enterprise)**: Allow overriding pinned version when creating and updating database engines
```