diff --git a/changelog/31442.txt b/changelog/31442.txt new file mode 100644 index 0000000000..4a79f604a7 --- /dev/null +++ b/changelog/31442.txt @@ -0,0 +1,3 @@ +```release-note:bug +secrets/database/postgresql: Support for multiline statements in the `rotation_statements` field. +``` diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index ffe460f45c..84173fc81c 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -270,6 +270,28 @@ func (p *PostgreSQL) changeUserPassword(ctx context.Context, username string, ch defer tx.Rollback() for _, stmt := range stmts { + if containsMultilineStatement(stmt) { + // Execute it as-is. + m := map[string]string{ + "name": username, + "username": username, + "password": password, + } + + if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 { + hashedPassword, err := scram.Hash(password) + if err != nil { + return fmt.Errorf("unable to scram-sha256 password: %w", err) + } + m["password"] = hashedPassword + } + + if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, stmt); err != nil { + return err + } + continue + } + // Otherwise, it's fine to split the statements on the semicolon. for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { @@ -337,6 +359,19 @@ func (p *PostgreSQL) changeUserExpiration(ctx context.Context, username string, expirationStr := changeExp.NewExpiration.Format(expirationFormat) for _, stmt := range renewStmts { + if containsMultilineStatement(stmt) { + // Execute it as-is. + m := map[string]string{ + "name": username, + "username": username, + "expiration": expirationStr, + } + if err := dbtxn.ExecuteTxQueryDirect(ctx, tx, m, stmt); err != nil { + return err + } + continue + } + // Otherwise, it's fine to split the statements on the semicolon. for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 281c42990e..f3829d3e01 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -1063,6 +1063,17 @@ func TestUpdateUser_Password(t *testing.T) { expectErr: false, credsAssertion: assertCredsExist, }, + "multi-line statements": { + statements: []string{ + `DO $$ BEGIN + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='{{name}}') + THEN CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}'; + ELSE ALTER ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}'; + END IF; END $$`, + }, + expectErr: false, + credsAssertion: assertCredsExist, + }, "bad statements": { statements: []string{`asdofyas8uf77asoiajv`}, expectErr: true, @@ -1205,6 +1216,17 @@ func TestUpdateUser_Expiration(t *testing.T) { statements: []string{`ALTER ROLE "{{username}}" VALID UNTIL '{{expiration}}';`}, expectErr: false, }, + "multi-line statements": { + initialExpiration: now.Add(1 * time.Minute), + newExpiration: now.Add(5 * time.Minute), + expectedExpiration: now.Add(5 * time.Minute), + statements: []string{ + `DO $$ BEGIN + ALTER ROLE "{{name}}" VALID UNTIL '{{expiration}}'; + END $$`, + }, + expectErr: false, + }, "bad statements": { initialExpiration: now.Add(1 * time.Minute), newExpiration: now.Add(5 * time.Minute),