vault/plugins/database/cassandra/cassandra.go

245 lines
6.1 KiB
Go

package cassandra
import (
"context"
"fmt"
"strings"
"github.com/gocql/gocql"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/database/helper/credsutil"
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
"github.com/hashicorp/vault/sdk/database/newdbplugin"
"github.com/hashicorp/vault/sdk/helper/strutil"
)
// Default CQL statements used when a request does not supply its own
// commands, plus the plugin type name. The {{username}} and {{password}}
// placeholders are filled in by passing each statement through
// dbutil.QueryHelper together with a map of those keys.
const (
	// defaultUserCreationCQL is the fallback statement for NewUser.
	defaultUserCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`

	// defaultUserDeletionCQL is the fallback statement for DeleteUser and for
	// rolling back a failed NewUser.
	defaultUserDeletionCQL = `DROP USER '{{username}}';`

	// defaultChangePasswordCQL is the fallback statement for password changes.
	// The username is quoted, matching the creation and deletion statements,
	// so that a name which is not a plain lowercase identifier refers to the
	// same user in all three statements.
	defaultChangePasswordCQL = `ALTER USER '{{username}}' WITH PASSWORD '{{password}}';`

	// cassandraTypeName is the type name reported by Type and assigned to the
	// connection producer.
	cassandraTypeName = "cassandra"
)
// Compile-time assertion that *Cassandra implements newdbplugin.Database.
var _ newdbplugin.Database = (*Cassandra)(nil)

// Cassandra is the newdbplugin.Database implementation for Cassandra. It
// embeds a *cassandraConnectionProducer, which supplies the Lock/Unlock,
// Connection and secretValues members used by the methods in this file.
type Cassandra struct {
	*cassandraConnectionProducer
}
// New builds a Cassandra plugin instance and returns it wrapped by
// newdbplugin.NewDatabaseErrorSanitizerMiddleware, which is given the
// instance's secretValues. The returned error is always nil.
func New() (interface{}, error) {
	cass := new()
	return newdbplugin.NewDatabaseErrorSanitizerMiddleware(cass, cass.secretValues), nil
}
// new returns a Cassandra value backed by a fresh connection producer whose
// Type field is set to cassandraTypeName. Note that this package-level
// function shadows the builtin new throughout the package.
func new() *Cassandra {
	c := &Cassandra{
		cassandraConnectionProducer: &cassandraConnectionProducer{},
	}
	// Address the field through the embedded producer explicitly: c.Type would
	// resolve to the Type method declared on *Cassandra.
	c.cassandraConnectionProducer.Type = cassandraTypeName
	return c
}
// Run creates the plugin via New and hands it to newdbplugin.Serve together
// with a TLS provider derived from apiTLSConfig. It returns the error from
// New, if any, and nil once Serve has returned.
func Run(apiTLSConfig *api.TLSConfig) error {
	plugin, err := New()
	if err != nil {
		return err
	}

	database := plugin.(newdbplugin.Database)
	tlsProvider := api.VaultPluginTLSProvider(apiTLSConfig)
	newdbplugin.Serve(database, tlsProvider)

	return nil
}
// Type reports the name of this database plugin, cassandraTypeName. The
// error result is always nil.
func (c *Cassandra) Type() (string, error) {
	return cassandraTypeName, nil
}
// getConnection obtains the connection from the embedded producer's
// Connection method and returns it as a *gocql.Session.
//
// It returns the error from Connection unchanged. If the value returned by
// Connection is nil or is not a *gocql.Session, an error is returned instead
// of panicking on the type assertion.
//
// This method does not take c.Lock itself; NewUser and DeleteUser call it
// while holding the lock.
func (c *Cassandra) getConnection(ctx context.Context) (*gocql.Session, error) {
	conn, err := c.Connection(ctx)
	if err != nil {
		return nil, err
	}

	// Checked assertion: Connection returns an untyped value, so guard against
	// anything other than a gocql session rather than crashing the plugin.
	session, ok := conn.(*gocql.Session)
	if !ok {
		return nil, fmt.Errorf("unexpected connection type %T, expected *gocql.Session", conn)
	}
	return session, nil
}
// NewUser creates a user in Cassandra by executing the request's creation
// statements, falling back to defaultUserCreationCQL when none are given.
//
// The username is generated with credsutil.GenerateUsername from the
// request's display name and role name, and any "-" in it is replaced by "_".
// Each statement is split on ";", trimmed, and executed with the username and
// req.Password substituted via dbutil.QueryHelper. On the first failed
// statement the rollback statements (defaultUserDeletionCQL when the request
// has none) are run through rollbackUser; a rollback failure is appended to
// the returned error. On success the response carries the generated username.
//
// The producer lock is held for the duration of the call.
func (c *Cassandra) NewUser(ctx context.Context, req newdbplugin.NewUserRequest) (newdbplugin.NewUserResponse, error) {
	c.Lock()
	defer c.Unlock()

	var empty newdbplugin.NewUserResponse

	session, err := c.getConnection(ctx)
	if err != nil {
		return empty, err
	}

	createStmts := req.Statements.Commands
	if len(createStmts) == 0 {
		createStmts = []string{defaultUserCreationCQL}
	}

	rollbackStmts := req.RollbackStatements.Commands
	if len(rollbackStmts) == 0 {
		rollbackStmts = []string{defaultUserDeletionCQL}
	}

	username, err := credsutil.GenerateUsername(
		credsutil.DisplayName(req.UsernameConfig.DisplayName, 15),
		credsutil.RoleName(req.UsernameConfig.RoleName, 15),
		credsutil.Separator("_"),
		credsutil.MaxLength(100),
		credsutil.ToLower(),
	)
	if err != nil {
		return empty, err
	}
	username = strings.ReplaceAll(username, "-", "_")

	for _, stmt := range createStmts {
		for _, piece := range strutil.ParseArbitraryStringSlice(stmt, ";") {
			cql := strings.TrimSpace(piece)
			if cql == "" {
				continue
			}

			params := map[string]string{
				"username": username,
				"password": req.Password,
			}
			execErr := session.Query(dbutil.QueryHelper(cql, params)).WithContext(ctx).Exec()
			if execErr == nil {
				continue
			}

			// Creation failed part-way: try to undo it, and report both
			// failures if the rollback fails as well.
			if rbErr := rollbackUser(ctx, session, username, rollbackStmts); rbErr != nil {
				return empty, multierror.Append(execErr, rbErr)
			}
			return empty, execErr
		}
	}

	return newdbplugin.NewUserResponse{Username: username}, nil
}
// rollbackUser executes the given rollback statements for username on
// session. Each statement is split on ";", trimmed, skipped when empty, and
// run with the username substituted via dbutil.QueryHelper. It stops at the
// first failing query and returns that error wrapped with the username;
// otherwise it returns nil.
func rollbackUser(ctx context.Context, session *gocql.Session, username string, rollbackCQL []string) error {
	for _, stmt := range rollbackCQL {
		for _, piece := range strutil.ParseArbitraryStringSlice(stmt, ";") {
			cql := strings.TrimSpace(piece)
			if cql == "" {
				continue
			}

			params := map[string]string{"username": username}
			if err := session.Query(dbutil.QueryHelper(cql, params)).WithContext(ctx).Exec(); err != nil {
				return fmt.Errorf("failed to roll back user %s: %w", username, err)
			}
		}
	}
	return nil
}
// UpdateUser applies the changes described by req. When req.Password is set,
// the password is changed via changeUserPassword and that result is returned.
// When only req.Expiration is set, nothing is done and nil is returned. When
// neither is set, an error is returned. The response is always the zero value.
func (c *Cassandra) UpdateUser(ctx context.Context, req newdbplugin.UpdateUserRequest) (newdbplugin.UpdateUserResponse, error) {
	var resp newdbplugin.UpdateUserResponse

	switch {
	case req.Password != nil:
		return resp, c.changeUserPassword(ctx, req.Username, req.Password)
	case req.Expiration == nil:
		// Neither a password nor an expiration change was supplied.
		return resp, fmt.Errorf("no changes requested")
	default:
		// Expiration is a no-op for this plugin.
		return resp, nil
	}
}
// changeUserPassword runs the password-change statements for username,
// falling back to defaultChangePasswordCQL when changePass carries none.
// Each statement is split on ";", trimmed, skipped when empty, and executed
// with the username and changePass.NewPassword substituted via
// dbutil.QueryHelper. Execution does not stop on failure: every query's
// result is passed to multierror.Append and the combined ErrorOrNil value is
// returned.
//
// NOTE(review): unlike NewUser and DeleteUser, this path does not take
// c.Lock before calling getConnection — confirm whether that is intentional.
func (c *Cassandra) changeUserPassword(ctx context.Context, username string, changePass *newdbplugin.ChangePassword) error {
	session, err := c.getConnection(ctx)
	if err != nil {
		return err
	}

	stmts := changePass.Statements.Commands
	if len(stmts) == 0 {
		stmts = []string{defaultChangePasswordCQL}
	}

	var errs *multierror.Error
	for _, stmt := range stmts {
		for _, piece := range strutil.ParseArbitraryStringSlice(stmt, ";") {
			cql := strings.TrimSpace(piece)
			if cql == "" {
				continue
			}

			params := map[string]string{
				"username": username,
				"password": changePass.NewPassword,
			}
			execErr := session.Query(dbutil.QueryHelper(cql, params)).WithContext(ctx).Exec()
			errs = multierror.Append(errs, execErr)
		}
	}
	return errs.ErrorOrNil()
}
// DeleteUser removes req.Username by executing the request's statements,
// falling back to defaultUserDeletionCQL when none are given. Each statement
// is split on ";", trimmed, skipped when empty, and executed with the
// username substituted via dbutil.QueryHelper. Execution continues past
// failures: every query's result is passed to multierror.Append and the
// combined ErrorOrNil value is returned alongside a zero-value response.
//
// The producer lock is held for the duration of the call.
func (c *Cassandra) DeleteUser(ctx context.Context, req newdbplugin.DeleteUserRequest) (newdbplugin.DeleteUserResponse, error) {
	c.Lock()
	defer c.Unlock()

	var resp newdbplugin.DeleteUserResponse

	session, err := c.getConnection(ctx)
	if err != nil {
		return resp, err
	}

	stmts := req.Statements.Commands
	if len(stmts) == 0 {
		stmts = []string{defaultUserDeletionCQL}
	}

	var errs *multierror.Error
	for _, stmt := range stmts {
		for _, piece := range strutil.ParseArbitraryStringSlice(stmt, ";") {
			cql := strings.TrimSpace(piece)
			if cql == "" {
				continue
			}

			params := map[string]string{"username": req.Username}
			execErr := session.Query(dbutil.QueryHelper(cql, params)).WithContext(ctx).Exec()
			errs = multierror.Append(errs, execErr)
		}
	}
	return resp, errs.ErrorOrNil()
}