diff --git a/helper/builtinplugins/registry.go b/helper/builtinplugins/registry.go index 0a94d791f2..19261fa5f3 100644 --- a/helper/builtinplugins/registry.go +++ b/helper/builtinplugins/registry.go @@ -108,7 +108,7 @@ func newRegistry() *registry { "mongodbatlas-database-plugin": dbMongoAtlas.New, "mssql-database-plugin": dbMssql.New, "postgresql-database-plugin": dbPostgres.New, - "redshift-database-plugin": dbRedshift.New(true), + "redshift-database-plugin": dbRedshift.New, }, logicalBackends: map[string]logical.Factory{ "ad": logicalAd.Factory, diff --git a/plugins/database/redshift/redshift.go b/plugins/database/redshift/redshift.go index 9fe5821358..4964393b7d 100644 --- a/plugins/database/redshift/redshift.go +++ b/plugins/database/redshift/redshift.go @@ -6,11 +6,10 @@ import ( "errors" "fmt" "strings" - "time" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-multierror" - "github.com/hashicorp/vault/sdk/database/dbplugin" + dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/database/helper/credsutil" "github.com/hashicorp/vault/sdk/database/helper/dbutil" @@ -31,37 +30,28 @@ const ( ALTER USER "{{name}}" VALID UNTIL '{{expiration}}'; ` defaultRotateRootCredentialsSQL = ` -ALTER USER "{{username}}" WITH PASSWORD '{{password}}'; +ALTER USER "{{name}}" WITH PASSWORD '{{password}}'; ` ) -// lowercaseUsername is the reason we wrote this plugin. Redshift implements (mostly) -// a postgres 8 interface, and part of that is under the hood, it's lowercasing the -// usernames. -func New(lowercaseUsername bool) func() (interface{}, error) { - return func() (interface{}, error) { - db := newRedshift(lowercaseUsername) - // Wrap the plugin with middleware to sanitize errors - dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.SecretValues) - return dbType, nil - } +var _ dbplugin.Database = (*RedShift)(nil) + +// New implements builtinplugins.BuiltinFactory +// Redshift implements (mostly) a postgres 8 interface, and part of that is +// under the hood, it's lower-casing the usernames. +func New() (interface{}, error) { + db := newRedshift() + // Wrap the plugin with middleware to sanitize errors + dbType := dbplugin.NewDatabaseErrorSanitizerMiddleware(db, db.secretValues) + return dbType, nil } -func newRedshift(lowercaseUsername bool) *RedShift { +func newRedshift() *RedShift { connProducer := &connutil.SQLConnectionProducer{} connProducer.Type = sqlTypeName - credsProducer := &credsutil.SQLCredentialsProducer{ - DisplayNameLen: 8, - RoleNameLen: 8, - UsernameLen: 63, - Separator: "-", - LowercaseUsername: lowercaseUsername, - } - db := &RedShift{ SQLConnectionProducer: connProducer, - CredentialsProducer: credsProducer, } return db @@ -69,14 +59,32 @@ func newRedshift(lowercaseUsername bool) *RedShift { type RedShift struct { *connutil.SQLConnectionProducer - credsutil.CredentialsProducer +} + +func (r *RedShift) secretValues() map[string]string { + return map[string]string{ + r.Password: "[password]", + } } func (r *RedShift) Type() (string, error) { return middlewareTypeName, nil } -// getConnection accepts a context and retuns a new pointer to a sql.DB object. +// Initialize must be called on each new RedShift struct before use. +// It uses the connutil.SQLConnectionProducer's Init function to do all the lifting. +func (r *RedShift) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) { + conf, err := r.Init(ctx, req.Config, req.VerifyConnection) + if err != nil { + return dbplugin.InitializeResponse{}, fmt.Errorf("error initializing db: %w", err) + } + + return dbplugin.InitializeResponse{ + Config: conf, + }, nil +} + +// getConnection accepts a context and returns a new pointer to a sql.DB object. // It's up to the caller to close the connection or handle reuse logic. func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) { db, err := r.Connection(ctx) @@ -86,116 +94,44 @@ func (r *RedShift) getConnection(ctx context.Context) (*sql.DB, error) { return db.(*sql.DB), nil } -// SetCredentials uses provided information to set/create a user in the -// database. Unlike CreateUser, this method requires a username be provided and -// uses the name given, instead of generating a name. This is used for creating -// and setting the password of static accounts, as well as rolling back -// passwords in the database in the event an updated database fails to save in -// Vault's storage. -func (r *RedShift) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { - if len(statements.Rotation) == 0 { - statements.Rotation = []string{defaultRotateRootCredentialsSQL} - } - - username = staticUser.Username - password = staticUser.Password - if username == "" || password == "" { - return "", "", errors.New("must provide both username and password") +// NewUser creates a new user in the database. There is no default statement for +// creating users, so one must be specified in the plugin config. +// Generated usernames are of the form v-{display-name}-{role-name}-{UUID}-{timestamp} +func (r *RedShift) NewUser(ctx context.Context, req dbplugin.NewUserRequest) (dbplugin.NewUserResponse, error) { + if len(req.Statements.Commands) == 0 { + return dbplugin.NewUserResponse{}, dbutil.ErrEmptyCreationStatement } // Grab the lock r.Lock() defer r.Unlock() - // Get the connection - db, err := r.getConnection(ctx) + usernameOpts := []credsutil.UsernameOpt{ + credsutil.DisplayName(req.UsernameConfig.DisplayName, 8), + credsutil.RoleName(req.UsernameConfig.RoleName, 8), + credsutil.MaxLength(63), + credsutil.Separator("-"), + credsutil.ToLower(), + } + + username, err := credsutil.GenerateUsername(usernameOpts...) if err != nil { - return "", "", err - } - defer db.Close() - - // Check if the role exists - var exists bool - err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) - if err != nil && err != sql.ErrNoRows { - return "", "", err - } - - // Vault requires the database user already exist, and that the credentials - // used to execute the rotation statements has sufficient privileges. - stmts := statements.Rotation - - // Start a transaction - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return "", "", err - } - defer func() { - tx.Rollback() - }() - - // Execute each query - for _, stmt := range stmts { - for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - m := map[string]string{ - "name": staticUser.Username, - "password": password, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return "", "", err - } - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return "", "", err - } - return username, password, nil -} - -func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statements, usernameConfig dbplugin.UsernameConfig, expiration time.Time) (username string, password string, err error) { - statements = dbutil.StatementCompatibilityHelper(statements) - - if len(statements.Creation) == 0 { - return "", "", dbutil.ErrEmptyCreationStatement - } - - // Grab the lock - r.Lock() - defer r.Unlock() - - username, err = r.GenerateUsername(usernameConfig) - if err != nil { - return "", "", err - } - - password, err = r.GeneratePassword() - if err != nil { - return "", "", err - } - - expirationStr, err := r.GenerateExpiration(expiration) - if err != nil { - return "", "", err + return dbplugin.NewUserResponse{}, err } + password := req.Password + expirationStr := req.Expiration.Format("2006-01-02 15:04:05-0700") // Get the connection db, err := r.getConnection(ctx) if err != nil { - return "", "", err + return dbplugin.NewUserResponse{}, err } defer db.Close() // Start a transaction tx, err := db.BeginTx(ctx, nil) if err != nil { - return "", "", err + return dbplugin.NewUserResponse{}, err } defer func() { @@ -203,7 +139,7 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement }() // Execute each query - for _, stmt := range statements.Creation { + for _, stmt := range req.Statements.Commands { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { @@ -212,53 +148,81 @@ func (r *RedShift) CreateUser(ctx context.Context, statements dbplugin.Statement m := map[string]string{ "name": username, + "username": username, "password": password, "expiration": expirationStr, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return "", "", err + return dbplugin.NewUserResponse{}, err } } } // Commit the transaction if err := tx.Commit(); err != nil { - return "", "", err + return dbplugin.NewUserResponse{}, err } - return username, password, nil + return dbplugin.NewUserResponse{ + Username: username, + }, nil } -func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements, username string, expiration time.Time) error { +// UpdateUser can update the expiration or the password of a user, or both. +// The updates all happen in a single transaction, so they will either all +// succeed or all fail. +// Both updates support both default and custom statements. +func (r *RedShift) UpdateUser(ctx context.Context, req dbplugin.UpdateUserRequest) (dbplugin.UpdateUserResponse, error) { + if req.Password == nil && req.Expiration == nil { + return dbplugin.UpdateUserResponse{}, errors.New("no changes requested") + } + r.Lock() defer r.Unlock() - statements = dbutil.StatementCompatibilityHelper(statements) - - renewStmts := statements.Renewal - if len(renewStmts) == 0 { - renewStmts = []string{defaultRenewSQL} - } - db, err := r.getConnection(ctx) if err != nil { - return err + return dbplugin.UpdateUserResponse{}, err } defer db.Close() tx, err := db.BeginTx(ctx, nil) if err != nil { - return err + return dbplugin.UpdateUserResponse{}, err } defer func() { tx.Rollback() }() - expirationStr, err := r.GenerateExpiration(expiration) - if err != nil { - return err + if req.Expiration != nil { + err = updateUserExpiration(ctx, req, tx) + if err != nil { + return dbplugin.UpdateUserResponse{}, err + } } - for _, stmt := range renewStmts { + if req.Password != nil { + err = updateUserPassword(ctx, req, tx) + if err != nil { + return dbplugin.UpdateUserResponse{}, err + } + } + + err = tx.Commit() + return dbplugin.UpdateUserResponse{}, err +} + +func updateUserExpiration(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error { + if req.Username == "" { + return errors.New("must provide a username to update user expiration") + } + renewStmts := req.Expiration.Statements + if len(renewStmts.Commands) == 0 { + renewStmts.Commands = []string{defaultRenewSQL} + } + + expirationStr := req.Expiration.NewExpiration.Format("2006-01-02 15:04:05-0700") + + for _, stmt := range renewStmts.Commands { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { @@ -266,7 +230,8 @@ func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements } m := map[string]string{ - "name": username, + "name": req.Username, + "username": req.Username, "expiration": expirationStr, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { @@ -275,39 +240,36 @@ func (r *RedShift) RenewUser(ctx context.Context, statements dbplugin.Statements } } - return tx.Commit() + return nil } -func (r *RedShift) RevokeUser(ctx context.Context, statements dbplugin.Statements, username string) error { - // Grab the lock - r.Lock() - defer r.Unlock() - - statements = dbutil.StatementCompatibilityHelper(statements) - - if len(statements.Revocation) == 0 { - return r.defaultRevokeUser(ctx, username) +func updateUserPassword(ctx context.Context, req dbplugin.UpdateUserRequest, tx *sql.Tx) error { + username := req.Username + password := req.Password.NewPassword + if username == "" || password == "" { + return errors.New("must provide both username and a new password to update user password") } - return r.customRevokeUser(ctx, username, statements.Revocation) -} - -func (r *RedShift) customRevokeUser(ctx context.Context, username string, revocationStmts []string) error { - db, err := r.getConnection(ctx) - if err != nil { + // Check if the role exists + var exists bool + err := tx.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + // Server error return err } - defer db.Close() - - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return err + if err == sql.ErrNoRows || !exists { + // Most likely a user error + return fmt.Errorf("cannot update password for username %q because it does not exist", username) } - defer func() { - tx.Rollback() - }() - for _, stmt := range revocationStmts { + // Vault requires the database user already exist, and that the credentials + // used to execute the rotation statements has sufficient privileges. + statements := req.Password.Statements.Commands + if len(statements) == 0 { + statements = []string{defaultRotateRootCredentialsSQL} + } + // Execute each query + for _, stmt := range statements { for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { query = strings.TrimSpace(query) if len(query) == 0 { @@ -315,7 +277,9 @@ func (r *RedShift) customRevokeUser(ctx context.Context, username string, revoca } m := map[string]string{ - "name": username, + "name": username, + "username": username, + "password": password, } if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { return err @@ -323,25 +287,76 @@ func (r *RedShift) customRevokeUser(ctx context.Context, username string, revoca } } - return tx.Commit() + return nil } -func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error { +// DeleteUser supports both default and custom statements to delete a user. +func (r *RedShift) DeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) { + // Grab the lock + r.Lock() + defer r.Unlock() + + if len(req.Statements.Commands) == 0 { + return r.defaultDeleteUser(ctx, req) + } + + return r.customDeleteUser(ctx, req) +} + +func (r *RedShift) customDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) { db, err := r.getConnection(ctx) if err != nil { - return err + return dbplugin.DeleteUserResponse{}, err } defer db.Close() + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return dbplugin.DeleteUserResponse{}, err + } + defer func() { + tx.Rollback() + }() + + for _, stmt := range req.Statements.Commands { + for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + m := map[string]string{ + "name": req.Username, + "username": req.Username, + } + if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { + return dbplugin.DeleteUserResponse{}, err + } + } + } + + return dbplugin.DeleteUserResponse{}, tx.Commit() +} + +func (r *RedShift) defaultDeleteUser(ctx context.Context, req dbplugin.DeleteUserRequest) (dbplugin.DeleteUserResponse, error) { + db, err := r.getConnection(ctx) + if err != nil { + return dbplugin.DeleteUserResponse{}, err + } + defer db.Close() + + username := req.Username + // Check if the role exists var exists bool err = db.QueryRowContext(ctx, "SELECT exists (SELECT usename FROM pg_user WHERE usename=$1);", username).Scan(&exists) if err != nil && err != sql.ErrNoRows { - return err + return dbplugin.DeleteUserResponse{}, err } if !exists { - return nil + // No error as Redshift may have deleted the user via TTL before we got to it. + return dbplugin.DeleteUserResponse{}, nil } // Query for permissions; we need to revoke permissions before we can drop @@ -350,13 +365,13 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error // we want to remove as much access as possible stmt, err := db.PrepareContext(ctx, "SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") if err != nil { - return err + return dbplugin.DeleteUserResponse{}, err } defer stmt.Close() rows, err := stmt.QueryContext(ctx, username) if err != nil { - return err + return dbplugin.DeleteUserResponse{}, err } defer rows.Close() @@ -393,7 +408,7 @@ func (r *RedShift) defaultRevokeUser(ctx context.Context, username string) error // this username var dbname sql.NullString if err := db.QueryRowContext(ctx, "SELECT current_database();").Scan(&dbname); err != nil { - return err + return dbplugin.DeleteUserResponse{}, err } if dbname.Valid { @@ -432,78 +447,22 @@ $$;`) // can't drop if not all privileges are revoked if rows.Err() != nil { - return errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err()) + return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not generate revocation statements for all rows: {{err}}", rows.Err()) } if lastStmtError != nil { - return errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError) + return dbplugin.DeleteUserResponse{}, errwrap.Wrapf("could not perform all revocation statements: {{err}}", lastStmtError) } // Drop this user stmt, err = db.PrepareContext(ctx, fmt.Sprintf( `DROP USER IF EXISTS %s;`, pq.QuoteIdentifier(username))) if err != nil { - return err + return dbplugin.DeleteUserResponse{}, err } defer stmt.Close() if _, err := stmt.ExecContext(ctx); err != nil { - return err + return dbplugin.DeleteUserResponse{}, err } - return nil -} - -func (r *RedShift) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { - r.Lock() - defer r.Unlock() - - if len(r.Username) == 0 || len(r.Password) == 0 { - return nil, errors.New("username and password are required to rotate") - } - - rotateStatements := statements - if len(rotateStatements) == 0 { - rotateStatements = []string{defaultRotateRootCredentialsSQL} - } - - db, err := r.getConnection(ctx) - if err != nil { - return nil, err - } - defer db.Close() - - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer func() { - tx.Rollback() - }() - - password, err := r.GeneratePassword() - if err != nil { - return nil, err - } - - for _, stmt := range rotateStatements { - for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - m := map[string]string{ - "username": r.Username, - "password": password, - } - if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil { - return nil, err - } - } - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - r.RawConfig["password"] = password - return r.RawConfig, nil + return dbplugin.DeleteUserResponse{}, nil } diff --git a/plugins/database/redshift/redshift_test.go b/plugins/database/redshift/redshift_test.go index c8b6ffdd30..17aeb06cda 100644 --- a/plugins/database/redshift/redshift_test.go +++ b/plugins/database/redshift/redshift_test.go @@ -3,16 +3,19 @@ package redshift import ( "context" "database/sql" - "errors" "fmt" "os" - "strings" + "reflect" + "regexp" "testing" "time" "github.com/hashicorp/go-uuid" - "github.com/hashicorp/vault/sdk/database/dbplugin" + + dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing" "github.com/hashicorp/vault/sdk/helper/dbtxn" + + dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/lib/pq" ) @@ -25,10 +28,6 @@ as environment variables to be used to run these tests. Note that these tests will create users on your redshift cluster and currently do not clean up after themselves. -The RotateRoot test is potentially destructive in that it will rotate your root -password on your Redshift cluster to an insecure, cleartext password defined in the -test method. Because of this, you must pass TEST_ROTATE_ROOT=1 to enable it explicitly. - Do not run this test suite against a production Redshift cluster. Configuration: @@ -37,7 +36,6 @@ Configuration: REDSHIFT_USER=my-redshift-admin-user REDSHIFT_PASSWORD=my-redshift-admin-password VAULT_ACC= # This must be set to run any of the tests in this test suite - TEST_ROTATE_ROOT= # This must be set to explicitly run the rotate root test */ var ( @@ -48,281 +46,230 @@ var ( vaultACC = "VAULT_ACC" ) -func redshiftEnv() (url string, user string, password string, errEmpty error) { - errEmpty = errors.New("err: empty but required env value") +func interpolateConnectionURL(url, user, password string) string { + return fmt.Sprintf("postgres://%s:%s@%s", user, password, url) +} +func redshiftEnv() (connURL string, url string, user string, password string, errEmpty error) { if url = os.Getenv(keyRedshiftURL); url == "" { - return "", "", "", errEmpty + return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftURL) } if user = os.Getenv(keyRedshiftUser); url == "" { - return "", "", "", errEmpty + return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftUser) } if password = os.Getenv(keyRedshiftPassword); url == "" { - return "", "", "", errEmpty + return "", "", "", "", fmt.Errorf("%s environment variable required", keyRedshiftPassword) } - url = fmt.Sprintf("postgres://%s:%s@%s", user, password, url) - - return url, user, password, nil + connURL = interpolateConnectionURL(url, user, password) + return connURL, url, user, password, nil } -func TestPostgreSQL_Initialize(t *testing.T) { +func TestRedshift_Initialize(t *testing.T) { if os.Getenv(vaultACC) != "1" { t.SkipNow() } - url, _, _, err := redshiftEnv() + connURL, _, _, _, err := redshiftEnv() if err != nil { t.Fatal(err) } connectionDetails := map[string]interface{}{ - "connection_url": url, - "max_open_connections": 5, + "connection_url": connURL, + "max_open_connections": 73, } - db := newRedshift(true) - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } + db := newRedshift() + resp := dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + }) if !db.Initialized { t.Fatal("Database should be initialized") } - - err = db.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test decoding a string value for max_open_connections - connectionDetails = map[string]interface{}{ - "connection_url": url, - "max_open_connections": "5", - } - - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) + expectedConfig := make(map[string]interface{}) + for k, v := range connectionDetails { + expectedConfig[k] = v + } + if !reflect.DeepEqual(expectedConfig, resp.Config) { + t.Fatalf("Expected config %+v, but was %v", expectedConfig, resp.Config) + } + if db.MaxOpenConnections != 73 { + t.Fatalf("Expected max_open_connections to be set to 73, but was %d", db.MaxOpenConnections) } + dbtesting.AssertClose(t, db) } -func TestPostgreSQL_CreateUser(t *testing.T) { +func TestRedshift_NewUser(t *testing.T) { if os.Getenv(vaultACC) != "1" { t.SkipNow() } - url, _, _, err := redshiftEnv() + connURL, url, _, _, err := redshiftEnv() if err != nil { t.Fatal(err) } connectionDetails := map[string]interface{}{ - "connection_url": url, + "connection_url": connURL, } - db := newRedshift(true) - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } + db := newRedshift() + dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + }) - usernameConfig := dbplugin.UsernameConfig{ + usernameConfig := dbplugin.UsernameMetadata{ DisplayName: "test", RoleName: "test", } + const password = "SuperSecurePa55w0rd!" + for _, commands := range [][]string{{testRedshiftRole}, {testRedshiftReadOnlyRole}} { + resp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{ + UsernameConfig: usernameConfig, + Password: password, + Statements: dbplugin.Statements{ + Commands: commands, + }, + Expiration: time.Now().Add(5 * time.Minute), + }) + username := resp.Username + + if err = testCredsExist(t, url, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s\n%s:%s", err, username, password) + } + + usernameRegex := regexp.MustCompile("^v-test-test-[a-zA-Z0-9]{20}-[0-9]{10}$") + if !usernameRegex.Match([]byte(username)) { + t.Fatalf("Expected username %q to match regex %q", username, usernameRegex.String()) + } + } + + dbtesting.AssertClose(t, db) +} + +func TestRedshift_NewUser_NoCreationStatement_ShouldError(t *testing.T) { + if os.Getenv(vaultACC) != "1" { + t.SkipNow() + } + + connURL, _, _, _, err := redshiftEnv() + if err != nil { + t.Fatal(err) + } + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := newRedshift() + dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + }) + + usernameConfig := dbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", + } + + const password = "SuperSecurePa55w0rd!" + // Test with no configured Creation Statement - _, _, err = db.CreateUser(context.Background(), dbplugin.Statements{}, usernameConfig, time.Now().Add(time.Minute)) + _, err = db.NewUser(context.Background(), dbplugin.NewUserRequest{ + UsernameConfig: usernameConfig, + Password: password, + Statements: dbplugin.Statements{ + Commands: []string{}, // Empty commands field here should cause error. + }, + Expiration: time.Now().Add(5 * time.Minute), + }) if err == nil { t.Fatal("Expected error when no creation statement is provided") } - statements := dbplugin.Statements{ - Creation: []string{testRedshiftRole}, - } - - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s\n%s:%s", err, username, password) - } - - statements.Creation = []string{testRedshiftReadOnlyRole} - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Sleep to make sure we haven't expired if granularity is only down to the second - time.Sleep(2 * time.Second) - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + dbtesting.AssertClose(t, db) } -func TestPostgreSQL_RenewUser(t *testing.T) { +func TestRedshift_UpdateUser_Expiration(t *testing.T) { if os.Getenv(vaultACC) != "1" { t.SkipNow() } - url, _, _, err := redshiftEnv() + connURL, url, _, _, err := redshiftEnv() if err != nil { t.Fatal(err) } connectionDetails := map[string]interface{}{ - "connection_url": url, + "connection_url": connURL, } - db := newRedshift(true) - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } + db := newRedshift() + dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + }) - statements := dbplugin.Statements{ - Creation: []string{testRedshiftRole}, - } - - usernameConfig := dbplugin.UsernameConfig{ + usernameConfig := dbplugin.UsernameMetadata{ DisplayName: "test", RoleName: "test", } - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Sleep longer than the initial expiration time - time.Sleep(2 * time.Second) - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - statements.Renewal = []string{defaultRenewSQL} - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute)) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Sleep longer than the initial expiration time - time.Sleep(2 * time.Second) - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) + const password = "SuperSecurePa55w0rd!" + const initialTTL = 2 * time.Second + const longTTL = time.Minute + for _, commands := range [][]string{{}, {defaultRenewSQL}} { + newResp := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{ + UsernameConfig: usernameConfig, + Password: password, + Statements: dbplugin.Statements{Commands: []string{testRedshiftRole}}, + Expiration: time.Now().Add(initialTTL), + }) + username := newResp.Username + + if err = testCredsExist(t, url, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + dbtesting.AssertUpdateUser(t, db, dbplugin.UpdateUserRequest{ + Username: username, + Expiration: &dbplugin.ChangeExpiration{ + NewExpiration: time.Now().Add(longTTL), + Statements: dbplugin.Statements{Commands: commands}, + }, + }) + + // Sleep longer than the initial expiration time + time.Sleep(initialTTL + time.Second) + + if err = testCredsExist(t, url, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } } + dbtesting.AssertClose(t, db) } -func TestPostgreSQL_RevokeUser(t *testing.T) { +func TestRedshift_UpdateUser_Password(t *testing.T) { if os.Getenv(vaultACC) != "1" { t.SkipNow() } - url, _, _, err := redshiftEnv() + connURL, url, _, _, err := redshiftEnv() if err != nil { t.Fatal(err) } connectionDetails := map[string]interface{}{ - "connection_url": url, - } - - db := newRedshift(true) - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := dbplugin.Statements{ - Creation: []string{testRedshiftRole}, - } - - usernameConfig := dbplugin.UsernameConfig{ - DisplayName: "test", - RoleName: "test", - } - - username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - // Test default revoke statements - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, url, username, password); err == nil { - t.Fatal("Credentials were not revoked") - } - - username, password, err = db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(2*time.Second)) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err = testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - // Test custom revoke statements - statements.Revocation = []string{defaultRedshiftRevocationSQL} - err = db.RevokeUser(context.Background(), statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, url, username, password); err == nil { - t.Fatal("Credentials were not revoked") - } -} - -func TestPostgresSQL_SetCredentials(t *testing.T) { - if os.Getenv(vaultACC) != "1" { - t.SkipNow() - } - - url, _, _, err := redshiftEnv() - if err != nil { - t.Fatal(err) - } - - connectionDetails := map[string]interface{}{ - "connection_url": url, + "connection_url": connURL, } // create the database user @@ -331,121 +278,97 @@ func TestPostgresSQL_SetCredentials(t *testing.T) { t.Fatal(err) } dbUser := "vaultstatictest-" + fmt.Sprintf("%s", uid) - createTestPGUser(t, url, dbUser, "1Password", testRoleStaticCreate) + createTestPGUser(t, connURL, dbUser, "1Password", testRoleStaticCreate) - db := newRedshift(true) - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) + db := newRedshift() + dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + }) + + const password1 = "MyTemporaryUserPassword1!" + const password2 = "MyTemporaryUserPassword2!" + + for _, tc := range []struct { + password string + commands []string + }{ + {password1, []string{}}, + {password2, []string{testRedshiftStaticRoleRotate}}, + } { + dbtesting.AssertUpdateUser(t, db, dbplugin.UpdateUserRequest{ + Username: dbUser, + Password: &dbplugin.ChangePassword{ + NewPassword: tc.password, + Statements: dbplugin.Statements{Commands: tc.commands}, + }, + }) + + if err := testCredsExist(t, url, dbUser, tc.password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } } - password, err := db.GenerateCredentials(context.Background()) - if err != nil { - t.Fatal(err) - } - - usernameConfig := dbplugin.StaticUserConfig{ - Username: dbUser, - Password: password, - } - - // Test with no configured Rotation Statement - username, password, err := db.SetCredentials(context.Background(), dbplugin.Statements{}, usernameConfig) - if err == nil { - t.Fatalf("err: %s", err) - } - - statements := dbplugin.Statements{ - Rotation: []string{testRedshiftStaticRoleRotate}, - } - // User should not exist, make sure we can create - username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } - - // call SetCredentials again, password will change - newPassword, _ := db.GenerateCredentials(context.Background()) - usernameConfig.Password = newPassword - username, password, err = db.SetCredentials(context.Background(), statements, usernameConfig) - if err != nil { - t.Fatalf("err: %s", err) - } - - if password != newPassword { - t.Fatal("passwords should have changed") - } - - if err := testCredsExist(t, url, username, password); err != nil { - t.Fatalf("Could not connect with new credentials: %s", err) - } + dbtesting.AssertClose(t, db) } -func TestPostgreSQL_RotateRootCredentials(t *testing.T) { - /* - Extra precaution is taken for rotating root creds because it's assumed that this - test will run against a live redshift cluster. This test must run last because - it is destructive. - - To run this test you must pass TEST_ROTATE_ROOT=1 - */ - if os.Getenv(vaultACC) != "1" || os.Getenv("TEST_ROTATE_ROOT") != "1" { +func TestRedshift_DeleteUser(t *testing.T) { + if os.Getenv(vaultACC) != "1" { t.SkipNow() } - url, adminUser, adminPassword, err := redshiftEnv() + connURL, url, _, _, err := redshiftEnv() if err != nil { t.Fatal(err) } connectionDetails := map[string]interface{}{ - "connection_url": url, - "username": adminUser, - "password": adminPassword, + "connection_url": connURL, } - db := newRedshift(true) + db := newRedshift() + dbtesting.AssertInitialize(t, db, dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + }) - connProducer := db.SQLConnectionProducer - - _, err = db.Init(context.Background(), connectionDetails, true) - if err != nil { - t.Fatalf("err: %s", err) + usernameConfig := dbplugin.UsernameMetadata{ + DisplayName: "test", + RoleName: "test", } - if !connProducer.Initialized { - t.Fatal("Database should be initialized") + const password = "SuperSecretPa55word!" + for _, commands := range [][]string{{}, {defaultRedshiftRevocationSQL}} { + newResponse := dbtesting.AssertNewUser(t, db, dbplugin.NewUserRequest{ + UsernameConfig: usernameConfig, + Statements: dbplugin.Statements{Commands: []string{testRedshiftRole}}, + Password: password, + Expiration: time.Now().Add(2 * time.Second), + }) + username := newResponse.Username + + if err = testCredsExist(t, url, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Intentionally _not_ using dbtesting here as the call almost always takes longer than the 2s default timeout + db.DeleteUser(context.Background(), dbplugin.DeleteUserRequest{ + Username: username, + Statements: dbplugin.Statements{Commands: commands}, + }) + + if err := testCredsExist(t, url, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } } - newConf, err := db.RotateRootCredentials(context.Background(), nil) - if err != nil { - t.Fatalf("err: %v", err) - } - - fmt.Printf("rotated root credentials, new user/pass:\nusername: %s\npassword: %s\n", newConf["username"], newConf["password"]) - - if newConf["password"] == adminPassword { - t.Fatal("password was not updated") - } - - err = db.Close() - if err != nil { - t.Fatalf("err: %s", err) - } + dbtesting.AssertClose(t, db) } -func testCredsExist(t testing.TB, connURL, username, password string) error { +func testCredsExist(t testing.TB, url, username, password string) error { t.Helper() - _, adminUser, adminPassword, err := redshiftEnv() - if err != nil { - return err - } - connURL = strings.Replace(connURL, fmt.Sprintf("%s:%s", adminUser, adminPassword), fmt.Sprintf("%s:%s", username, password), 1) + connURL := interpolateConnectionURL(url, username, password) db, err := sql.Open("postgres", connURL) if err != nil { return err