DBPW - Migrate Redshift database plugin to v5 interface (#10195)

This commit is contained in:
Tom Proctor 2020-10-23 14:10:57 +01:00 committed by GitHub
parent ee09e54d80
commit be0a3d28f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 405 additions and 523 deletions

View File

@ -108,7 +108,7 @@ func newRegistry() *registry {
"mongodbatlas-database-plugin": dbMongoAtlas.New, "mongodbatlas-database-plugin": dbMongoAtlas.New,
"mssql-database-plugin": dbMssql.New, "mssql-database-plugin": dbMssql.New,
"postgresql-database-plugin": dbPostgres.New, "postgresql-database-plugin": dbPostgres.New,
"redshift-database-plugin": dbRedshift.New(true), "redshift-database-plugin": dbRedshift.New,
}, },
logicalBackends: map[string]logical.Factory{ logicalBackends: map[string]logical.Factory{
"ad": logicalAd.Factory, "ad": logicalAd.Factory,

View File

@ -6,11 +6,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/hashicorp/errwrap" "github.com/hashicorp/errwrap"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/database/dbplugin" dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/connutil"
"github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil"
@ -31,37 +30,28 @@ const (
ALTER USER "{{name}}" VALID UNTIL '{{expiration}}'; ALTER USER "{{name}}" VALID UNTIL '{{expiration}}';
` `
defaultRotateRootCredentialsSQL = ` defaultRotateRootCredentialsSQL = `
ALTER USER "{{username}}" WITH PASSWORD '{{password}}'; ALTER USER "{{name}}" WITH PASSWORD '{{password}}';
` `
) )
// lowercaseUsername is the reason we wrote this plugin. Redshift implements (mostly) var _ dbplugin.Database = (*RedShift)(nil)
// a postgres 8 interface, and part of that is under the hood, it's lowercasing the
// usernames. // New implements builtinplugins.BuiltinFactory
func New(lowercaseUsername bool) func() (interface{}, error) { // Redshift implements (mostly) a postgres 8 interface, and part of that is
return func() (interface{}, error) { // under the hood, it's lower-casing the usernames.
db := newRedshift(lowercaseUsername) func New() (interface{}, error) {
db := newRedshift()
// Wrap the plugin with middleware to sanitize errors // Wrap the plugin with middleware to sanitize errors
dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues)
return dbType, nil return dbType, nil
} }
}
func newRedshift(lowercaseUsername bool) *RedShift { func newRedshift() *RedShift {
connProducer := &connutil.SQLConnectionProducer{} connProducer := &connutil.SQLConnectionProducer{}
connProducer.Type = sqlTypeName connProducer.Type = sqlTypeName
credsProducer := &credsutil.SQLCredentialsProducer{
DisplayNameLen: 8,
RoleNameLen: 8,
UsernameLen: 63,
Separator: "-",
LowercaseUsername: lowercaseUsername,
}
db := &RedShift{ db := &RedShift{
SQLConnectionProducer: connProducer, SQLConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
} }
return db return db
@ -69,14 +59,32 @@ func newRedshift(lowercaseUsername bool) *RedShift {
type RedShift struct { type RedShift struct {
*connutil.SQLConnectionProducer *connutil.SQLConnectionProducer
credsutil.CredentialsProducer }
func (r *RedShift) secretValues() map[string]string {
return map[string]string{
r.Password: "[password]",
}
} }
func (r *RedShift) Type() (string, error) { func (r *RedShift) Type() (string, error) {
return middlewareTypeName, nil return middlewareTypeName, nil
} }
// getConnection accepts a context and retuns a new pointer to a sql.DB object. // Initialize must be called on each new RedShift struct before use.
// It uses the connutil.SQLConnectionProducer's Init function to do all the lifting.
func (r *RedShift) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
conf, err := r.Init(ctx, req.Config, req.VerifyConnection)
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("error initializing db: %w", err)
}
return dbplugin.InitializeResponse{
Config: conf,
}, nil
}
// getConnection accepts a context and returns a new pointer to a sql.DB object.
// It's up to the caller to close the connection or handle reuse logic. // It's up to the caller to close the connection or handle reuse logic.
func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) { func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) {
db, err := r.Connection(ctx) db, err := r.Connection(ctx)
@ -86,56 +94,182 @@ func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) {
return db.(*sql.DB), nil return db.(*sql.DB), nil
} }
// SetCredentials uses provided information to set/create a user in the // NewUser creates a new user in the database. There is no default statement for
// database. Unlike CreateUser, this method requires a username be provided and // creating users, so one must be specified in the plugin config.
// uses the name given, instead of generating a name. This is used for creating // Generated usernames are of the form v-{display-name}-{role-name}-{UUID}-{timestamp}
// and setting the password of static accounts, as well as rolling back func (r *RedShift) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) {
// passwords in the database in the event an updated database fails to save in if len(req.Statements.Commands) == 0 {
// Vault's storage. return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement
func (r *RedShift) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
if len(statements.Rotation) == 0 {
statements.Rotation = []string{defaultRotateRootCredentialsSQL}
}
username = staticUser.Username
password = staticUser.Password
if username == "" || password == "" {
return "", "", errors.New("must provide both username and password")
} }
// Grab the lock // Grab the lock
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
usernameOpts := []credsutil.UsernameOpt{
credsutil.DisplayName(req.UsernameConfig.DisplayName, 8),
credsutil.RoleName(req.UsernameConfig.RoleName, 8),
credsutil.MaxLength(63),
credsutil.Separator("-"),
credsutil.ToLower(),
}
username, err := credsutil.GenerateUsername(usernameOpts...)
if err != nil {
return dbplugin.NewUserResponse{}, err
}
password := req.Password
expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700")
// Get the connection // Get the connection
db, err := r.getConnection(ctx) db, err := r.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return dbplugin.NewUserResponse{}, err
} }
defer db.Close() defer db.Close()
// Start a transaction
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return dbplugin.NewUserResponse{}, err
}
defer func() {
tx.Rollback()
}()
// Execute each query
for _, stmt := range req.Statements.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return dbplugin.NewUserResponse{}, err
}
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return dbplugin.NewUserResponse{}, err
}
return dbplugin.NewUserResponse{
Username: username,
}, nil
}
// UpdateUser can update the expiration or the password of a user, or both.
// The updates all happen in a single transaction, so they will either all
// succeed or all fail.
// Both updates support both default and custom statements.
func (r *RedShift) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) {
if req.Password == nil && req.Expiration == nil {
return dbplugin.UpdateUserResponse{}, errors.New("no changes requested")
}
r.Lock()
defer r.Unlock()
db, err := r.getConnection(ctx)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
defer func() {
tx.Rollback()
}()
if req.Expiration != nil {
err = updateUserExpiration(ctx, req, tx)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
}
if req.Password != nil {
err = updateUserPassword(ctx, req, tx)
if err != nil {
return dbplugin.UpdateUserResponse{}, err
}
}
err = tx.Commit()
return dbplugin.UpdateUserResponse{}, err
}
func updateUserExpiration(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error {
if req.Username == "" {
return errors.New("must provide a username to update user expiration")
}
renewStmts := req.Expiration.Statements
if len(renewStmts.Commands) == 0 {
renewStmts.Commands = []string{defaultRenewSQL}
}
expirationStr := req.Expiration.NewExpiration.Format("2006-01-02 15:04:05-0700")
for _, stmt := range renewStmts.Commands {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": req.Username,
"username": req.Username,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return err
}
}
}
return nil
}
func updateUserPassword(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error {
username := req.Username
password := req.Password.NewPassword
if username == "" || password == "" {
return errors.New("must provide both username and a new password to update user password")
}
// Check if the role exists // Check if the role exists
var exists bool var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) err := tx.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return "", "", err // Server error
return err
}
if err == sql.ErrNoRows || !exists {
// Most likely a user error
return fmt.Errorf("cannot update password for username %q because it does not exist", username)
} }
// Vault requires the database user already exist, and that the credentials // Vault requires the database user already exist, and that the credentials
// used to execute the rotation statements has sufficient privileges. // used to execute the rotation statements has sufficient privileges.
stmts := statements.Rotation statements := req.Password.Statements.Commands
if len(statements) == 0 {
// Start a transaction statements = []string{defaultRotateRootCredentialsSQL}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return "", "", err
} }
defer func() {
tx.Rollback()
}()
// Execute each query // Execute each query
for _, stmt := range stmts { for _, stmt := range statements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) == 0 { if len(query) == 0 {
@ -143,67 +277,48 @@ func (r *RedShift) SetCredentials(ctx context.Context, statements dbplugin.State
} }
m := map[string]string{ m := map[string]string{
"name": staticUser.Username, "name": username,
"username": username,
"password": password, "password": password,
} }
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err return err
} }
} }
} }
// Commit the transaction return nil
if err := tx.Commit(); err != nil {
return "", "", err
}
return username, password, nil
}
func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) {
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Creation) == 0 {
return "", "", dbutil.ErrEmptyCreationStatement
} }
// DeleteUser supports both default and custom statements to delete a user.
func (r *RedShift) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
// Grab the lock // Grab the lock
r.Lock() r.Lock()
defer r.Unlock() defer r.Unlock()
username, err = r.GenerateUsername(usernameConfig) if len(req.Statements.Commands) == 0 {
if err != nil { return r.defaultDeleteUser(ctx, req)
return "", "", err
} }
password, err = r.GeneratePassword() return r.customDeleteUser(ctx, req)
if err != nil {
return "", "", err
} }
expirationStr, err := r.GenerateExpiration(expiration) func (r *RedShift) customDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
if err != nil {
return "", "", err
}
// Get the connection
db, err := r.getConnection(ctx) db, err := r.getConnection(ctx)
if err != nil { if err != nil {
return "", "", err return dbplugin.DeleteUserResponse{}, err
} }
defer db.Close() defer db.Close()
// Start a transaction
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return "", "", err return dbplugin.DeleteUserResponse{}, err
} }
defer func() { defer func() {
tx.Rollback() tx.Rollback()
}() }()
// Execute each query for _, stmt := range req.Statements.Commands {
for _, stmt := range statements.Creation {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query) query = strings.TrimSpace(query)
if len(query) == 0 { if len(query) == 0 {
@ -211,137 +326,37 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement
} }
m := map[string]string{ m := map[string]string{
"name": username, "name": req.Username,
"password": password, "username": req.Username,
"expiration": expirationStr,
} }
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return "", "", err return dbplugin.DeleteUserResponse{}, err
} }
} }
} }
// Commit the transaction return dbplugin.DeleteUserResponse{}, tx.Commit()
if err := tx.Commit(); err != nil {
return "", "", err
}
return username, password, nil
}
func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error {
r.Lock()
defer r.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
renewStmts := statements.Renewal
if len(renewStmts) == 0 {
renewStmts = []string{defaultRenewSQL}
} }
func (r *RedShift) defaultDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) {
db, err := r.getConnection(ctx) db, err := r.getConnection(ctx)
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer db.Close() defer db.Close()
tx, err := db.BeginTx(ctx, nil) username := req.Username
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
expirationStr, err := r.GenerateExpiration(expiration)
if err != nil {
return err
}
for _, stmt := range renewStmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return err
}
}
}
return tx.Commit()
}
func (r *RedShift) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error {
// Grab the lock
r.Lock()
defer r.Unlock()
statements = dbutil.StatementCompatibilityHelper(statements)
if len(statements.Revocation) == 0 {
return r.defaultRevokeUser(ctx, username)
}
return r.customRevokeUser(ctx, username, statements.Revocation)
}
func (r *RedShift) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error {
db, err := r.getConnection(ctx)
if err != nil {
return err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
for _, stmt := range revocationStmts {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"name": username,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return err
}
}
}
return tx.Commit()
}
func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error {
db, err := r.getConnection(ctx)
if err != nil {
return err
}
defer db.Close()
// Check if the role exists // Check if the role exists
var exists bool var exists bool
err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return dbplugin.DeleteUserResponse{}, err
} }
if !exists { if !exists {
return nil // No error as Redshift may have deleted the user via TTL before we got to it.
return dbplugin.DeleteUserResponse{}, nil
} }
// Query for permissions; we need to revoke permissions before we can drop // Query for permissions; we need to revoke permissions before we can drop
@ -350,13 +365,13 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error
// we want to remove as much access as possible // we want to remove as much access as possible
stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer stmt.Close() defer stmt.Close()
rows, err := stmt.QueryContext(ctx, username) rows, err := stmt.QueryContext(ctx, username)
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer rows.Close() defer rows.Close()
@ -393,7 +408,7 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error
// this username // this username
var dbname sql.NullString var dbname sql.NullString
if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil { if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
if dbname.Valid { if dbname.Valid {
@ -432,78 +447,22 @@ $$;`)
// can't drop if not all privileges are revoked // can't drop if not all privileges are revoked
if rows.Err() != nil { if rows.Err() != nil {
return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err()) return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err())
} }
if lastStmtError != nil { if lastStmtError != nil {
return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError) return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError)
} }
// Drop this user // Drop this user
stmt, err = db.PrepareContext(ctx, fmt.Sprintf( stmt, err = db.PrepareContext(ctx, fmt.Sprintf(
`DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username))) `DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil { if err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
defer stmt.Close() defer stmt.Close()
if _, err := stmt.ExecContext(ctx); err != nil { if _, err := stmt.ExecContext(ctx); err != nil {
return err return dbplugin.DeleteUserResponse{}, err
} }
return nil return dbplugin.DeleteUserResponse{}, nil
}
func (r *RedShift) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) {
r.Lock()
defer r.Unlock()
if len(r.Username) == 0 || len(r.Password) == 0 {
return nil, errors.New("username and password are required to rotate")
}
rotateStatements := statements
if len(rotateStatements) == 0 {
rotateStatements = []string{defaultRotateRootCredentialsSQL}
}
db, err := r.getConnection(ctx)
if err != nil {
return nil, err
}
defer db.Close()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
password, err := r.GeneratePassword()
if err != nil {
return nil, err
}
for _, stmt := range rotateStatements {
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
m := map[string]string{
"username": r.Username,
"password": password,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
return nil, err
}
}
}
if err := tx.Commit(); err != nil {
return nil, err
}
r.RawConfig["password"] = password
return r.RawConfig, nil
} }

View File

@ -3,16 +3,19 @@ package redshift
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"os" "os"
"strings" "reflect"
"regexp"
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-uuid" "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/database/dbplugin"
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
"github.com/hashicorp/vault/sdk/helper/dbtxn" "github.com/hashicorp/vault/sdk/helper/dbtxn"
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
"github.com/lib/pq" "github.com/lib/pq"
) )
@ -25,10 +28,6 @@ as environment variables to be used to run these tests. Note that these tests
will create users on your redshift cluster and currently do not clean up after will create users on your redshift cluster and currently do not clean up after
themselves. themselves.
The RotateRoot test is potentially destructive in that it will rotate your root
password on your Redshift cluster to an insecure, cleartext password defined in the
test method. Because of this, you must pass TEST_ROTATE_ROOT=1 to enable it explicitly.
Do not run this test suite against a production Redshift cluster. Do not run this test suite against a production Redshift cluster.
Configuration: Configuration:
@ -37,7 +36,6 @@ Configuration:
REDSHIFT_USER=my-redshift-admin-user REDSHIFT_USER=my-redshift-admin-user
REDSHIFT_PASSWORD=my-redshift-admin-password REDSHIFT_PASSWORD=my-redshift-admin-password
VAULT_ACC=<unset || 1> # This must be set to run any of the tests in this test suite VAULT_ACC=<unset || 1> # This must be set to run any of the tests in this test suite
TEST_ROTATE_ROOT=<unset || 1> # This must be set to explicitly run the rotate root test
*/ */
var ( var (
@ -48,281 +46,230 @@ var (
vaultACC = "VAULT_ACC" vaultACC = "VAULT_ACC"
) )
func redshiftEnv() (url string, user string, password string, errEmpty error) { func interpolateConnectionURL(url, user, password string) string {
errEmpty = errors.New("err: empty but required env value") return fmt.Sprintf("postgres://%s:%s@%s", user, password, url)
}
func redshiftEnv() (connURL string, url string, user string, password string, errEmpty error) {
if url = os.Getenv(keyRedshiftURL); url == "" { if url = os.Getenv(keyRedshiftURL); url == "" {
return "", "", "", errEmpty return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftURL)
} }
if user = os.Getenv(keyRedshiftUser); url == "" { if user = os.Getenv(keyRedshiftUser); url == "" {
return "", "", "", errEmpty return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftUser)
} }
if password = os.Getenv(keyRedshiftPassword); url == "" { if password = os.Getenv(keyRedshiftPassword); url == "" {
return "", "", "", errEmpty return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftPassword)
} }
url = fmt.Sprintf("postgres://%s:%s@%s", user, password, url) connURL = interpolateConnectionURL(url, user, password)
return connURL, url, user, password, nil
return url, user, password, nil
} }
func TestPostgreSQL_Initialize(t *testing.T) { func TestRedshift_Initialize(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, _, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
"max_open_connections": 5, "max_open_connections": 73,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) resp := dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
if !db.Initialized { if !db.Initialized {
t.Fatal("Database should be initialized") t.Fatal("Database should be initialized")
} }
expectedConfig := make(map[string]interface{})
err = db.Close() for k, v := range connectionDetails {
if err != nil { expectedConfig[k] = v
t.Fatalf("err: %s", err) }
if !reflect.DeepEqual(expectedConfig, resp.Config) {
t.Fatalf("Expected config %+v, but was %v", expectedConfig, resp.Config)
}
if db.MaxOpenConnections != 73 {
t.Fatalf("Expected max_open_connections to be set to 73, but was %d", db.MaxOpenConnections)
} }
// Test decoding a string value for max_open_connections dbtesting.AssertClose(t, db)
connectionDetails = map[string]interface{}{
"connection_url": url,
"max_open_connections": "5",
} }
_, err = db.Init(context.Background(), connectionDetails, true) func TestRedshift_NewUser(t *testing.T) {
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPostgreSQL_CreateUser(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
usernameConfig := dbplugin.UsernameConfig{ usernameConfig := dbplugin.UsernameMetadata{
DisplayName: "test", DisplayName: "test",
RoleName: "test", RoleName: "test",
} }
// Test with no configured Creation Statement const password = "SuperSecurePa55w0rd!"
_, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) for _, commands := range [][]string{{testRedshiftRole}, {testRedshiftReadOnlyRole}} {
if err == nil { resp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
t.Fatal("Expected error when no creation statement is provided") UsernameConfig: usernameConfig,
} Password: password,
Statements: dbplugin.Statements{
statements := dbplugin.Statements{ Commands: commands,
Creation: []string{testRedshiftRole}, },
} Expiration: time.Now().Add(5 * time.Minute),
})
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) username := resp.Username
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, url, username, password); err != nil { if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s\n%s:%s", err, username, password) t.Fatalf("Could not connect with new credentials: %s\n%s:%s", err, username, password)
} }
statements.Creation = []string{testRedshiftReadOnlyRole} usernameRegex := regexp.MustCompile("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$")
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) if !usernameRegex.Match([]byte(username)) {
if err != nil { t.Fatalf("Expected username %q to match regex %q", username, usernameRegex.String())
t.Fatalf("err: %s", err)
}
// Sleep to make sure we haven't expired if granularity is only down to the second
time.Sleep(2 * time.Second)
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
} }
} }
func TestPostgreSQL_RenewUser(t *testing.T) { dbtesting.AssertClose(t, db)
}
func TestRedshift_NewUser_NoCreationStatement_ShouldError(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, _, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
statements := dbplugin.Statements{ usernameConfig := dbplugin.UsernameMetadata{
Creation: []string{testRedshiftRole},
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test", DisplayName: "test",
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) const password = "SuperSecurePa55w0rd!"
if err != nil {
t.Fatalf("err: %s", err) // Test with no configured Creation Statement
_, err = db.NewUser(context.Background(), dbplugin.NewUserRequest{
UsernameConfig: usernameConfig,
Password: password,
Statements: dbplugin.Statements{
Commands: []string{}, // Empty commands field here should cause error.
},
Expiration: time.Now().Add(5 * time.Minute),
})
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
} }
if err = testCredsExist(t, url, username, password); err != nil { dbtesting.AssertClose(t, db)
t.Fatalf("Could not connect with new credentials: %s", err)
} }
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) func TestRedshift_UpdateUser_Expiration(t *testing.T) {
if err != nil {
t.Fatalf("err: %s", err)
}
// Sleep longer than the initial expiration time
time.Sleep(2 * time.Second)
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
statements.Renewal = []string{defaultRenewSQL}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("err: %s", err)
}
// Sleep longer than the initial expiration time
time.Sleep(2 * time.Second)
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
}
func TestPostgreSQL_RevokeUser(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
} }
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
statements := dbplugin.Statements{ usernameConfig := dbplugin.UsernameMetadata{
Creation: []string{testRedshiftRole},
}
usernameConfig := dbplugin.UsernameConfig{
DisplayName: "test", DisplayName: "test",
RoleName: "test", RoleName: "test",
} }
username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) const password = "SuperSecurePa55w0rd!"
if err != nil { const initialTTL = 2 * time.Second
t.Fatalf("err: %s", err) const longTTL = time.Minute
} for _, commands := range [][]string{{}, {defaultRenewSQL}} {
newResp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
UsernameConfig: usernameConfig,
Password: password,
Statements: dbplugin.Statements{Commands: []string{testRedshiftRole}},
Expiration: time.Now().Add(initialTTL),
})
username := newResp.Username
if err = testCredsExist(t, url, username, password); err != nil { if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
// Test default revoke statements dbtesting.AssertUpdateUser(t, db, dbplugin.UpdateUserRequest{
err = db.RevokeUser(context.Background(), statements, username) Username: username,
if err != nil { Expiration: &dbplugin.ChangeExpiration{
t.Fatalf("err: %s", err) NewExpiration: time.Now().Add(longTTL),
} Statements: dbplugin.Statements{Commands: commands},
},
})
if err := testCredsExist(t, url, username, password); err == nil { // Sleep longer than the initial expiration time
t.Fatal("Credentials were not revoked") time.Sleep(initialTTL + time.Second)
}
username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second))
if err != nil {
t.Fatalf("err: %s", err)
}
if err = testCredsExist(t, url, username, password); err != nil { if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
// Test custom revoke statements
statements.Revocation = []string{defaultRedshiftRevocationSQL}
err = db.RevokeUser(context.Background(), statements, username)
if err != nil {
t.Fatalf("err: %s", err)
} }
if err := testCredsExist(t, url, username, password); err == nil { dbtesting.AssertClose(t, db)
t.Fatal("Credentials were not revoked")
}
} }
func TestPostgresSQL_SetCredentials(t *testing.T) { func TestRedshift_UpdateUser_Password(t *testing.T) {
if os.Getenv(vaultACC) != "1" { if os.Getenv(vaultACC) != "1" {
t.SkipNow() t.SkipNow()
} }
url, _, _, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
} }
// create the database user // create the database user
@ -331,121 +278,97 @@ func TestPostgresSQL_SetCredentials(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
dbUser := "vaultstatictest-" + fmt.Sprintf("%s", uid) dbUser := "vaultstatictest-" + fmt.Sprintf("%s", uid)
createTestPGUser(t, url, dbUser, "1Password", testRoleStaticCreate) createTestPGUser(t, connURL, dbUser, "1Password", testRoleStaticCreate)
db := newRedshift(true) db := newRedshift()
_, err = db.Init(context.Background(), connectionDetails, true) dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
if err != nil { Config: connectionDetails,
t.Fatalf("err: %s", err) VerifyConnection: true,
} })
password, err := db.GenerateCredentials(context.Background()) const password1 = "MyTemporaryUserPassword1!"
if err != nil { const password2 = "MyTemporaryUserPassword2!"
t.Fatal(err)
}
usernameConfig := dbplugin.StaticUserConfig{ for _, tc := range []struct {
password string
commands []string
}{
{password1, []string{}},
{password2, []string{testRedshiftStaticRoleRotate}},
} {
dbtesting.AssertUpdateUser(t, db, dbplugin.UpdateUserRequest{
Username: dbUser, Username: dbUser,
Password: password, Password: &dbplugin.ChangePassword{
} NewPassword: tc.password,
Statements: dbplugin.Statements{Commands: tc.commands},
},
})
// Test with no configured Rotation Statement if err := testCredsExist(t, url, dbUser, tc.password); err != nil {
username, password, err := db.SetCredentials(context.Background(), dbplugin.Statements{}, usernameConfig)
if err == nil {
t.Fatalf("err: %s", err)
}
statements := dbplugin.Statements{
Rotation: []string{testRedshiftStaticRoleRotate},
}
// User should not exist, make sure we can create
username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if err := testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
// call SetCredentials again, password will change
newPassword, _ := db.GenerateCredentials(context.Background())
usernameConfig.Password = newPassword
username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig)
if err != nil {
t.Fatalf("err: %s", err)
}
if password != newPassword {
t.Fatal("passwords should have changed")
}
if err := testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err) t.Fatalf("Could not connect with new credentials: %s", err)
} }
} }
func TestPostgreSQL_RotateRootCredentials(t *testing.T) { dbtesting.AssertClose(t, db)
/* }
Extra precaution is taken for rotating root creds because it's assumed that this
test will run against a live redshift cluster. This test must run last because
it is destructive.
To run this test you must pass TEST_ROTATE_ROOT=1 func TestRedshift_DeleteUser(t *testing.T) {
*/ if os.Getenv(vaultACC) != "1" {
if os.Getenv(vaultACC) != "1" || os.Getenv("TEST_ROTATE_ROOT") != "1" {
t.SkipNow() t.SkipNow()
} }
url, adminUser, adminPassword, err := redshiftEnv() connURL, url, _, _, err := redshiftEnv()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
connectionDetails := map[string]interface{}{ connectionDetails := map[string]interface{}{
"connection_url": url, "connection_url": connURL,
"username": adminUser,
"password": adminPassword,
} }
db := newRedshift(true) db := newRedshift()
dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
})
connProducer := db.SQLConnectionProducer usernameConfig := dbplugin.UsernameMetadata{
DisplayName: "test",
_, err = db.Init(context.Background(), connectionDetails, true) RoleName: "test",
if err != nil {
t.Fatalf("err: %s", err)
} }
if !connProducer.Initialized { const password = "SuperSecretPa55word!"
t.Fatal("Database should be initialized") for _, commands := range [][]string{{}, {defaultRedshiftRevocationSQL}} {
newResponse := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{
UsernameConfig: usernameConfig,
Statements: dbplugin.Statements{Commands: []string{testRedshiftRole}},
Password: password,
Expiration: time.Now().Add(2 * time.Second),
})
username := newResponse.Username
if err = testCredsExist(t, url, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
} }
newConf, err := db.RotateRootCredentials(context.Background(), nil) // Intentionally _not_ using dbtesting here as the call almost always takes longer than the 2s default timeout
if err != nil { db.DeleteUser(context.Background(), dbplugin.DeleteUserRequest{
t.Fatalf("err: %v", err) Username: username,
} Statements: dbplugin.Statements{Commands: commands},
})
fmt.Printf("rotated root credentials, new user/pass:\nusername: %s\npassword: %s\n", newConf["username"], newConf["password"]) if err := testCredsExist(t, url, username, password); err == nil {
t.Fatal("Credentials were not revoked")
if newConf["password"] == adminPassword {
t.Fatal("password was not updated")
}
err = db.Close()
if err != nil {
t.Fatalf("err: %s", err)
} }
} }
func testCredsExist(t testing.TB, connURL, username, password string) error { dbtesting.AssertClose(t, db)
}
func testCredsExist(t testing.TB, url, username, password string) error {
t.Helper() t.Helper()
_, adminUser, adminPassword, err := redshiftEnv()
if err != nil {
return err
}
connURL = strings.Replace(connURL, fmt.Sprintf("%s:%s", adminUser, adminPassword), fmt.Sprintf("%s:%s", username, password), 1) connURL := interpolateConnectionURL(url, username, password)
db, err := sql.Open("postgres", connURL) db, err := sql.Open("postgres", connURL)
if err != nil { if err != nil {
return err return err