Refactor to use builtin plugins from an external repo

This commit is contained in:
Brian Kassouf 2017-04-05 16:20:31 -07:00
parent 8f88452fc0
commit 8a2e29c607
20 changed files with 110 additions and 2380 deletions

View File

@ -4,16 +4,52 @@ import (
"fmt"
"strings"
"sync"
"time"
log "github.com/mgutz/logxi/v1"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
const databaseConfigPath = "database/dbs/"
// DatabaseType is the interface that all database objects must implement.
type DatabaseType interface {
Type() string
CreateUser(statements Statements, username, password, expiration string) error
RenewUser(statements Statements, username, expiration string) error
RevokeUser(statements Statements, username string) error
Initialize(map[string]interface{}) error
Close() error
GenerateUsername(displayName string) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Duration) (string, error)
}
// DatabaseConfig is used by the Factory function to configure a DatabaseType
// object.
type DatabaseConfig struct {
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
// ConnectionDetails stores the database specific connection settings needed
// by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
}
// Statements set in role creation and passed into the database type's functions.
// TODO: Add a way of setting defaults here.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
return Backend(conf).Setup(conf)
}
@ -30,7 +66,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
},
Paths: []*framework.Path{
pathConfigureBuiltinConnection(&b),
pathConfigurePluginConnection(&b),
pathListRoles(&b),
pathRoles(&b),
@ -48,12 +83,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
}
b.logger = conf.Logger
b.connections = make(map[string]dbs.DatabaseType)
b.connections = make(map[string]DatabaseType)
return &b
}
type databaseBackend struct {
connections map[string]dbs.DatabaseType
connections map[string]DatabaseType
logger log.Logger
*framework.Backend
@ -73,7 +108,7 @@ func (b *databaseBackend) closeAllDBs() {
// This function is used to retrieve a database object either from the cached
// connection map or by using the database config in storage. The caller of this
// function needs to hold the backend's lock.
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) {
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) {
// if the object already is built and cached, return it
db, ok := b.connections[name]
if ok {
@ -88,14 +123,12 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
}
var config dbs.DatabaseConfig
var config DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
factory := config.GetFactory()
db, err = factory(&config, b.System(), b.logger)
db, err = PluginFactory(&config, b.System(), b.logger)
if err != nil {
return nil, err
}

View File

@ -1,4 +1,4 @@
package dbs
package database
import (
"time"

View File

@ -1,108 +0,0 @@
package dbs
import (
"fmt"
"strings"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/strutil"
)
const (
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
defaultRollbackCQL = `DROP USER '{{username}}';`
)
type Cassandra struct {
// Session is goroutine safe, however, since we reinitialize
// it when connection info changes, we want to make sure we
// can close it and use a new connection; hence the lock
ConnectionProducer
CredentialsProducer
}
func (c *Cassandra) Type() string {
return cassandraTypeName
}
func (c *Cassandra) getConnection() (*gocql.Session, error) {
session, err := c.connection()
if err != nil {
return nil, err
}
return session.(*gocql.Session), nil
}
func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
c.Lock()
defer c.Unlock()
// Get the connection
session, err := c.getConnection()
if err != nil {
return err
}
creationCQL := statements.CreationStatements
if creationCQL == "" {
creationCQL = defaultCreationCQL
}
rollbackCQL := statements.RollbackStatements
if rollbackCQL == "" {
rollbackCQL = defaultRollbackCQL
}
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
err = session.Query(queryHelper(query, map[string]string{
"username": username,
"password": password,
})).Exec()
if err != nil {
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
session.Query(queryHelper(query, map[string]string{
"username": username,
"password": password,
})).Exec()
}
return err
}
}
return nil
}
func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error {
// NOOP
return nil
}
func (c *Cassandra) RevokeUser(statements Statements, username string) error {
// Grab the lock
c.Lock()
defer c.Unlock()
session, err := c.getConnection()
if err != nil {
return err
}
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
if err != nil {
return fmt.Errorf("error removing user %s", username)
}
return nil
}

View File

@ -1,280 +0,0 @@
package dbs
import (
"crypto/tls"
"database/sql"
"errors"
"fmt"
"strings"
"sync"
"time"
// Import sql drivers
_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
"github.com/gocql/gocql"
"github.com/hashicorp/vault/helper/certutil"
"github.com/hashicorp/vault/helper/tlsutil"
)
var (
errNotInitalized = errors.New("connection has not been initalized")
)
// ConnectionProducer can be used as an embeded interface in the DatabaseType
// definition. It implements the methods dealing with individual database
// connections and is used in all the builtin database types.
type ConnectionProducer interface {
Close() error
Initialize(map[string]interface{}) error
sync.Locker
connection() (interface{}, error)
}
// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
type sqlConnectionProducer struct {
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
config *DatabaseConfig
initalized bool
db *sql.DB
sync.Mutex
}
func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error {
c.Lock()
defer c.Unlock()
err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
if _, err := c.connection(); err != nil {
return fmt.Errorf("error initalizing connection: %s", err)
}
c.initalized = true
return nil
}
func (c *sqlConnectionProducer) connection() (interface{}, error) {
// If we already have a DB, test it and return
if c.db != nil {
if err := c.db.Ping(); err == nil {
return c.db, nil
}
// If the ping was unsuccessful, close it and ignore errors as we'll be
// reestablishing anyways
c.db.Close()
}
// For mssql backend, switch to sqlserver instead
dbType := c.config.DatabaseType
if c.config.DatabaseType == "mssql" {
dbType = "sqlserver"
}
// Otherwise, attempt to make connection
conn := c.ConnectionURL
// Ensure timezone is set to UTC for all the conenctions
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
if strings.Contains(conn, "?") {
conn += "&timezone=utc"
} else {
conn += "?timezone=utc"
}
}
var err error
c.db, err = sql.Open(dbType, conn)
if err != nil {
return nil, err
}
// Set some connection pool settings. We don't need much of this,
// since the request rate shouldn't be high.
c.db.SetMaxOpenConns(c.config.MaxOpenConnections)
c.db.SetMaxIdleConns(c.config.MaxIdleConnections)
c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime)
return c.db, nil
}
func (c *sqlConnectionProducer) Close() error {
// Grab the write lock
c.Lock()
defer c.Unlock()
if c.db != nil {
c.db.Close()
}
c.db = nil
return nil
}
// cassandraConnectionProducer implements ConnectionProducer and provides an
// interface for cassandra databases to make connections.
type cassandraConnectionProducer struct {
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
Username string `json:"username" structs:"username" mapstructure:"username"`
Password string `json:"password" structs:"password" mapstructure:"password"`
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
config *DatabaseConfig
initalized bool
session *gocql.Session
sync.Mutex
}
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error {
c.Lock()
defer c.Unlock()
err := mapstructure.Decode(conf, c)
if err != nil {
return err
}
c.initalized = true
if _, err := c.connection(); err != nil {
return fmt.Errorf("error Initalizing Connection: %s", err)
}
return nil
}
func (c *cassandraConnectionProducer) connection() (interface{}, error) {
if !c.initalized {
return nil, errNotInitalized
}
// If we already have a DB, return it
if c.session != nil {
return c.session, nil
}
session, err := c.createSession()
if err != nil {
return nil, err
}
// Store the session in backend for reuse
c.session = session
return session, nil
}
func (c *cassandraConnectionProducer) Close() error {
// Grab the write lock
c.Lock()
defer c.Unlock()
if c.session != nil {
c.session.Close()
}
c.session = nil
return nil
}
func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
Username: c.Username,
Password: c.Password,
}
clusterConfig.ProtoVersion = c.ProtocolVersion
if clusterConfig.ProtoVersion == 0 {
clusterConfig.ProtoVersion = 2
}
clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second
if c.TLS {
var tlsConfig *tls.Config
if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 {
if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 {
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
}
certBundle := &certutil.CertBundle{}
if len(c.Certificate) > 0 {
certBundle.Certificate = c.Certificate
certBundle.PrivateKey = c.PrivateKey
}
if len(c.IssuingCA) > 0 {
certBundle.IssuingCA = c.IssuingCA
}
parsedCertBundle, err := certBundle.ToParsedCertBundle()
if err != nil {
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
}
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
if err != nil || tlsConfig == nil {
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
}
tlsConfig.InsecureSkipVerify = c.InsecureTLS
if c.TLSMinVersion != "" {
var ok bool
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
} else {
// MinVersion was not being set earlier. Reset it to
// zero to gracefully handle upgrades.
tlsConfig.MinVersion = 0
}
}
clusterConfig.SslOpts = &gocql.SslOptions{
Config: *tlsConfig,
}
}
session, err := clusterConfig.CreateSession()
if err != nil {
return nil, fmt.Errorf("error creating session: %s", err)
}
// Set consistency
if c.Consistency != "" {
consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
if err != nil {
return nil, err
}
session.SetConsistency(consistencyValue)
}
// Verify the info
err = session.Query(`LIST USERS`).Exec()
if err != nil {
return nil, fmt.Errorf("error validating connection info: %s", err)
}
return session, nil
}

View File

@ -1,83 +0,0 @@
package dbs
import (
"fmt"
"strings"
"time"
uuid "github.com/hashicorp/go-uuid"
)
// CredentialsProducer can be used as an embeded interface in the DatabaseType
// definition. It implements the methods for generating user information for a
// particular database type and is used in all the builtin database types.
type CredentialsProducer interface {
GenerateUsername(displayName string) (string, error)
GeneratePassword() (string, error)
GenerateExpiration(ttl time.Duration) (string, error)
}
// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types.
type sqlCredentialsProducer struct {
displayNameLen int
usernameLen int
}
func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) {
if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen {
displayName = displayName[:scp.displayNameLen]
}
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
username := fmt.Sprintf("%s-%s", displayName, userUUID)
if scp.usernameLen > 0 && len(username) > scp.usernameLen {
username = username[:scp.usernameLen]
}
return username, nil
}
func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
return password, nil
}
func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) {
return time.Now().
Add(ttl).
Format("2006-01-02 15:04:05-0700"), nil
}
// cassandraCredentialsProducer implements CredentialsProducer and provides an
// interface for cassandra databases to generate user information.
type cassandraCredentialsProducer struct{}
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
userUUID, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
username = strings.Replace(username, "-", "_", -1)
return username, nil
}
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
password, err := uuid.GenerateUUID()
if err != nil {
return "", err
}
return password, nil
}
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) {
return "", nil
}

View File

@ -1,196 +0,0 @@
package dbs
import (
"errors"
"fmt"
"strings"
"time"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
)
const (
postgreSQLTypeName = "postgres"
mySQLTypeName = "mysql"
msSQLTypeName = "mssql"
cassandraTypeName = "cassandra"
pluginTypeName = "plugin"
)
var (
ErrUnsupportedDatabaseType = errors.New("unsupported database type")
ErrEmptyCreationStatement = errors.New("empty creation statements")
ErrEmptyPluginName = errors.New("empty plugin name")
)
// Factory function definition
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
// BuiltinFactory is used to build builtin database types. It wraps the database
// object in a logging and metrics middleware.
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
var dbType DatabaseType
switch conf.DatabaseType {
case postgreSQLTypeName:
connProducer := &sqlConnectionProducer{}
connProducer.config = conf
credsProducer := &sqlCredentialsProducer{
displayNameLen: 23,
usernameLen: 63,
}
dbType = &PostgreSQL{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
case mySQLTypeName:
connProducer := &sqlConnectionProducer{}
connProducer.config = conf
credsProducer := &sqlCredentialsProducer{
displayNameLen: 4,
usernameLen: 16,
}
dbType = &MySQL{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
case msSQLTypeName:
connProducer := &sqlConnectionProducer{}
connProducer.config = conf
credsProducer := &sqlCredentialsProducer{
displayNameLen: 10,
usernameLen: 63,
}
dbType = &MSSQL{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
case cassandraTypeName:
connProducer := &cassandraConnectionProducer{}
connProducer.config = conf
credsProducer := &cassandraCredentialsProducer{}
dbType = &Cassandra{
ConnectionProducer: connProducer,
CredentialsProducer: credsProducer,
}
default:
return nil, ErrUnsupportedDatabaseType
}
// Wrap with metrics middleware
dbType = &databaseMetricsMiddleware{
next: dbType,
typeStr: dbType.Type(),
}
// Wrap with tracing middleware
dbType = &databaseTracingMiddleware{
next: dbType,
typeStr: dbType.Type(),
logger: logger,
}
return dbType, nil
}
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
if conf.PluginName == "" {
return nil, ErrEmptyPluginName
}
pluginMeta, err := sys.LookupPlugin(conf.PluginName)
if err != nil {
return nil, err
}
// Make sure the database type is set to plugin
conf.DatabaseType = pluginTypeName
db, err := newPluginClient(sys, pluginMeta)
if err != nil {
return nil, err
}
// Wrap with metrics middleware
db = &databaseMetricsMiddleware{
next: db,
typeStr: db.Type(),
}
// Wrap with tracing middleware
db = &databaseTracingMiddleware{
next: db,
typeStr: db.Type(),
logger: logger,
}
return db, nil
}
// DatabaseType is the interface that all database objects must implement.
type DatabaseType interface {
Type() string
CreateUser(statements Statements, username, password, expiration string) error
RenewUser(statements Statements, username, expiration string) error
RevokeUser(statements Statements, username string) error
Initialize(map[string]interface{}) error
Close() error
CredentialsProducer
}
// DatabaseConfig is used by the Factory function to configure a DatabaseType
// object.
type DatabaseConfig struct {
DatabaseType string `json:"type" structs:"type" mapstructure:"type"`
// ConnectionDetails stores the database specific connection settings needed
// by each database type.
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
}
// GetFactory returns the appropriate factory method for the given database
// type.
func (dc *DatabaseConfig) GetFactory() Factory {
if dc.DatabaseType == pluginTypeName {
return PluginFactory
}
return BuiltinFactory
}
// Statements set in role creation and passed into the database type's functions.
// TODO: Add a way of setting defaults here.
type Statements struct {
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
}
// Query templates a query for us.
func queryHelper(tpl string, data map[string]string) string {
for k, v := range data {
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
}
return tpl
}

View File

@ -1,219 +0,0 @@
package dbs
import (
"database/sql"
"fmt"
"strings"
"github.com/hashicorp/vault/helper/strutil"
)
// MSSQL is an implementation of DatabaseType interface
type MSSQL struct {
ConnectionProducer
CredentialsProducer
}
// Type returns the TypeName for this backend
func (m *MSSQL) Type() string {
return msSQLTypeName
}
func (m *MSSQL) getConnection() (*sql.DB, error) {
db, err := m.connection()
if err != nil {
return nil, err
}
return db.(*sql.DB), nil
}
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
// the CreationStatement provided.
func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
if err != nil {
return err
}
if statements.CreationStatements == "" {
return ErrEmptyCreationStatement
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}
// RenewUser is not supported on MSSQL, so this is a no-op.
func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error {
// NOOP
return nil
}
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
// then kill pending connections from that user, and finally drop the user and login from the
// database instance.
func (m *MSSQL) RevokeUser(statements Statements, username string) error {
// Get connection
db, err := m.getConnection()
if err != nil {
return err
}
// First disable server login
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
if err != nil {
return err
}
defer disableStmt.Close()
if _, err := disableStmt.Exec(); err != nil {
return err
}
// Query for sessions for the login so that we can kill any outstanding
// sessions. There cannot be any active sessions before we drop the logins
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
sessionStmt, err := db.Prepare(fmt.Sprintf(
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
if err != nil {
return err
}
defer sessionStmt.Close()
sessionRows, err := sessionStmt.Query()
if err != nil {
return err
}
defer sessionRows.Close()
var revokeStmts []string
for sessionRows.Next() {
var sessionID int
err = sessionRows.Scan(&sessionID)
if err != nil {
return err
}
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
}
// Query for database users using undocumented stored procedure for now since
// it is the easiest way to get this information;
// we need to drop the database users before we can drop the login and the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username))
if err != nil {
return err
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var loginName, dbName, qUsername string
var aliasName sql.NullString
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
if err != nil {
return err
}
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username))
}
// we do not stop on error, as we want to remove as
// many permissions as possible right now
var lastStmtError error
for _, query := range revokeStmts {
stmt, err := db.Prepare(query)
if err != nil {
lastStmtError = err
continue
}
defer stmt.Close()
_, err = stmt.Exec()
if err != nil {
lastStmtError = err
}
}
// can't drop if not all database users are dropped
if rows.Err() != nil {
return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
}
if lastStmtError != nil {
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
}
// Drop this login
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
return nil
}
const dropUserSQL = `
USE [%s]
IF EXISTS
(SELECT name
FROM sys.database_principals
WHERE name = N'%s')
BEGIN
DROP USER [%s]
END
`
const dropLoginSQL = `
IF EXISTS
(SELECT name
FROM master.sys.server_principals
WHERE name = N'%s')
BEGIN
DROP LOGIN [%s]
END
`

View File

@ -1,221 +0,0 @@
package dbs
import (
"database/sql"
"fmt"
"os"
"sync"
"testing"
"time"
_ "github.com/denisenkom/go-mssqldb"
log "github.com/mgutz/logxi/v1"
dockertest "gopkg.in/ory-am/dockertest.v3"
)
var (
testMSQLImagePull sync.Once
)
func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
if os.Getenv("MSSQL_URL") != "" {
return func() {}, os.Getenv("MSSQL_URL")
}
pool, err := dockertest.NewPool("")
if err != nil {
t.Fatalf("Failed to connect to docker: %s", err)
}
resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"})
if err != nil {
t.Fatalf("Could not start local MSSQL docker container: %s", err)
}
cleanup = func() {
err := pool.Purge(resource)
if err != nil {
t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
}
}
retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp"))
// exponential backoff-retry, because the mssql container may not be able to accept connections yet
if err = pool.Retry(func() error {
var err error
var db *sql.DB
db, err = sql.Open("mssql", retURL)
if err != nil {
return err
}
return db.Ping()
}); err != nil {
t.Fatalf("Could not connect to MSSQL docker container: %s", err)
}
return
}
func TestMSSQL_Initialize(t *testing.T) {
cleanup, connURL := prepareMSSQLTestContainer(t)
defer cleanup()
conf := &DatabaseConfig{
DatabaseType: msSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
// Deconsturct the middleware chain to get the underlying mssql object
dbTracer := dbRaw.(*databaseTracingMiddleware)
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
db := dbMetrics.next.(*MSSQL)
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
err = dbRaw.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.initalized {
t.Fatal("Database should be initalized")
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
if connProducer.db != nil {
t.Fatal("db object should be nil")
}
}
func TestMSSQL_CreateUser(t *testing.T) {
cleanup, connURL := prepareMSSQLTestContainer(t)
defer cleanup()
conf := &DatabaseConfig{
DatabaseType: msSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test with no configured Creation Statememt
err = db.CreateUser(Statements{}, username, password, expiration)
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
statements := Statements{
CreationStatements: testMSSQLRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMSSQL_RevokeUser(t *testing.T) {
cleanup, connURL := prepareMSSQLTestContainer(t)
defer cleanup()
conf := &DatabaseConfig{
DatabaseType: msSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := Statements{
CreationStatements: testMSSQLRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
}
const testMSSQLRole = `
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`

View File

@ -1,135 +0,0 @@
package dbs
import (
"database/sql"
"strings"
"github.com/hashicorp/vault/helper/strutil"
)
const defaultMysqlRevocationStmts = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
DROP USER '{{name}}'@'%'
`
type MySQL struct {
ConnectionProducer
CredentialsProducer
}
func (m *MySQL) Type() string {
return mySQLTypeName
}
func (m *MySQL) getConnection() (*sql.DB, error) {
db, err := m.connection()
if err != nil {
return nil, err
}
return db.(*sql.DB), nil
}
func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error {
// Grab the lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
if err != nil {
return err
}
if statements.CreationStatements == "" {
return ErrEmptyCreationStatement
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}
// NOOP
func (m *MySQL) RenewUser(statements Statements, username, expiration string) error {
return nil
}
func (m *MySQL) RevokeUser(statements Statements, username string) error {
// Grab the read lock
m.Lock()
defer m.Unlock()
// Get the connection
db, err := m.getConnection()
if err != nil {
return err
}
revocationStmts := statements.RevocationStatements
// Use a default SQL statement for revocation if one cannot be fetched from the role
if revocationStmts == "" {
revocationStmts = defaultMysqlRevocationStmts
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
// This is not a prepared statement because not all commands are supported
// 1295: This command is not supported in the prepared statement protocol yet
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
query = strings.Replace(query, "{{name}}", username, -1)
_, err = tx.Exec(query)
if err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}

View File

@ -1,346 +0,0 @@
package dbs
import (
"database/sql"
"os"
"sync"
"testing"
"time"
log "github.com/mgutz/logxi/v1"
dockertest "gopkg.in/ory-am/dockertest.v2"
)
var (
testMySQLImagePull sync.Once
)
func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) {
if os.Getenv("MYSQL_URL") != "" {
return "", os.Getenv("MYSQL_URL")
}
// Without this the checks for whether the container has started seem to
// never actually pass. There's really no reason to expose the test
// containers, so don't.
dockertest.BindDockerToLocalhost = "yep"
testMySQLImagePull.Do(func() {
dockertest.Pull("mysql")
})
cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool {
// This will cause a validation to run
connProducer := &sqlConnectionProducer{}
connProducer.ConnectionURL = connURL
connProducer.config = &DatabaseConfig{
DatabaseType: mySQLTypeName,
}
conn, err := connProducer.connection()
if err != nil {
return false
}
if err := conn.(*sql.DB).Ping(); err != nil {
return false
}
connProducer.Close()
retURL = connURL
return true
})
if connErr != nil {
t.Fatalf("could not connect to database: %v", connErr)
}
return
}
func TestMySQL_Initialize(t *testing.T) {
cid, connURL := prepareMySQLTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: mySQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
// Deconsturct the middleware chain to get the underlying mysql object
dbTracer := dbRaw.(*databaseTracingMiddleware)
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
db := dbMetrics.next.(*MySQL)
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
err = dbRaw.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.initalized {
t.Fatal("Database should be initalized")
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
if connProducer.db != nil {
t.Fatal("db object should be nil")
}
}
func TestMySQL_CreateUser(t *testing.T) {
cid, connURL := prepareMySQLTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: mySQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test with no configured Creation Statememt
err = db.CreateUser(Statements{}, username, password, expiration)
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
statements := Statements{
CreationStatements: testMySQLRoleWildCard,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements.CreationStatements = testMySQLRoleHost
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMySQL_RenewUser(t *testing.T) {
cid, connURL := prepareMySQLTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: mySQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := Statements{
CreationStatements: testMySQLRoleWildCard,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(statements, username, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestMySQL_RevokeUser(t *testing.T) {
cid, connURL := prepareMySQLTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: mySQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := Statements{
CreationStatements: testMySQLRoleWildCard,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements.CreationStatements = testMySQLRoleHost
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test custom revoke statements
statements.RevocationStatements = testMySQLRevocationSQL
err = db.RevokeUser(statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
}
const testMySQLRoleWildCard = `
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
GRANT SELECT ON *.* TO '{{name}}'@'%';
`
const testMySQLRoleHost = `
CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}';
GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2';
`
const testMySQLRevocationSQL = `
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2';
DROP USER '{{name}}'@'10.1.1.2';
`

View File

@ -1,279 +0,0 @@
package dbs
import (
"database/sql"
"fmt"
"strings"
"github.com/hashicorp/vault/helper/strutil"
"github.com/lib/pq"
)
type PostgreSQL struct {
ConnectionProducer
CredentialsProducer
}
func (p *PostgreSQL) Type() string {
return postgreSQLTypeName
}
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
db, err := p.connection()
if err != nil {
return nil, err
}
return db.(*sql.DB), nil
}
func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error {
if statements.CreationStatements == "" {
return ErrEmptyCreationStatement
}
// Grab the lock
p.Lock()
defer p.Unlock()
// Get the connection
db, err := p.getConnection()
if err != nil {
return err
}
// Start a transaction
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
// Return the secret
// Execute each query
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
"password": password,
"expiration": expiration,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
// Commit the transaction
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
db, err := p.getConnection()
if err != nil {
return err
}
query := fmt.Sprintf(
"ALTER ROLE %s VALID UNTIL '%s';",
pq.QuoteIdentifier(username),
expiration)
stmt, err := db.Prepare(query)
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
return nil
}
func (p *PostgreSQL) RevokeUser(statements Statements, username string) error {
// Grab the lock
p.Lock()
defer p.Unlock()
if statements.RevocationStatements == "" {
return p.defaultRevokeUser(username)
}
return p.customRevokeUser(username, statements.RevocationStatements)
}
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
db, err := p.getConnection()
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
continue
}
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
"name": username,
}))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
func (p *PostgreSQL) defaultRevokeUser(username string) error {
db, err := p.getConnection()
if err != nil {
return err
}
// Check if the role exists
var exists bool
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
if err != nil && err != sql.ErrNoRows {
return err
}
if exists == false {
return nil
}
// Query for permissions; we need to revoke permissions before we can drop
// the role
// This isn't done in a transaction because even if we fail along the way,
// we want to remove as much access as possible
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
if err != nil {
return err
}
defer stmt.Close()
rows, err := stmt.Query(username)
if err != nil {
return err
}
defer rows.Close()
const initialNumRevocations = 16
revocationStmts := make([]string, 0, initialNumRevocations)
for rows.Next() {
var schema string
err = rows.Scan(&schema)
if err != nil {
// keep going; remove as many permissions as possible right now
continue
}
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
pq.QuoteIdentifier(schema),
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
pq.QuoteIdentifier(schema),
pq.QuoteIdentifier(username)))
}
// for good measure, revoke all privileges and usage on schema public
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
revocationStmts = append(revocationStmts, fmt.Sprintf(
"REVOKE USAGE ON SCHEMA public FROM %s;",
pq.QuoteIdentifier(username)))
// get the current database name so we can issue a REVOKE CONNECT for
// this username
var dbname sql.NullString
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
return err
}
if dbname.Valid {
revocationStmts = append(revocationStmts, fmt.Sprintf(
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
pq.QuoteIdentifier(dbname.String),
pq.QuoteIdentifier(username)))
}
// again, here, we do not stop on error, as we want to remove as
// many permissions as possible right now
var lastStmtError error
for _, query := range revocationStmts {
stmt, err := db.Prepare(query)
if err != nil {
lastStmtError = err
continue
}
defer stmt.Close()
_, err = stmt.Exec()
if err != nil {
lastStmtError = err
}
}
// can't drop if not all privileges are revoked
if rows.Err() != nil {
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
}
if lastStmtError != nil {
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
}
// Drop this user
stmt, err = db.Prepare(fmt.Sprintf(
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
if err != nil {
return err
}
defer stmt.Close()
if _, err := stmt.Exec(); err != nil {
return err
}
return nil
}

View File

@ -1,414 +0,0 @@
package dbs
import (
"database/sql"
"os"
"sync"
"testing"
"time"
log "github.com/mgutz/logxi/v1"
dockertest "gopkg.in/ory-am/dockertest.v2"
)
var (
testPostgresImagePull sync.Once
)
func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) {
if os.Getenv("PG_URL") != "" {
return "", os.Getenv("PG_URL")
}
// Without this the checks for whether the container has started seem to
// never actually pass. There's really no reason to expose the test
// containers, so don't.
dockertest.BindDockerToLocalhost = "yep"
testPostgresImagePull.Do(func() {
dockertest.Pull("postgres")
})
cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool {
// This will cause a validation to run
connProducer := &sqlConnectionProducer{}
connProducer.ConnectionURL = connURL
connProducer.config = &DatabaseConfig{
DatabaseType: postgreSQLTypeName,
}
conn, err := connProducer.connection()
if err != nil {
return false
}
if err := conn.(*sql.DB).Ping(); err != nil {
return false
}
connProducer.Close()
retURL = connURL
return true
})
if connErr != nil {
t.Fatalf("could not connect to database: %v", connErr)
}
return
}
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
err := cid.KillRemove()
if err != nil {
t.Fatal(err)
}
}
func TestPostgreSQL_Initialize(t *testing.T) {
cid, connURL := preparePostgresTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: postgreSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
// Deconsturct the middleware chain to get the underlying postgres object
dbTracer := dbRaw.(*databaseTracingMiddleware)
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
db := dbMetrics.next.(*PostgreSQL)
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
err = dbRaw.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
if !connProducer.initalized {
t.Fatal("Database should be initalized")
}
err = dbRaw.Close()
if err != nil {
t.Fatalf("err: %s", err)
}
if connProducer.db != nil {
t.Fatal("db object should be nil")
}
}
func TestPostgreSQL_CreateUser(t *testing.T) {
cid, connURL := preparePostgresTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: postgreSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test with no configured Creation Statememt
err = db.CreateUser(Statements{}, username, password, expiration)
if err == nil {
t.Fatal("Expected error when no creation statement is provided")
}
statements := Statements{
CreationStatements: testPostgresRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements.CreationStatements = testPostgresReadOnlyRole
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
/* statements.CreationStatements = testBlockStatementRole
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}*/
}
func TestPostgreSQL_RenewUser(t *testing.T) {
cid, connURL := preparePostgresTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: postgreSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := Statements{
CreationStatements: testPostgresRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.RenewUser(statements, username, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
}
func TestPostgreSQL_RevokeUser(t *testing.T) {
cid, connURL := preparePostgresTestContainer(t)
if cid != "" {
defer cleanupTestContainer(t, cid)
}
conf := &DatabaseConfig{
DatabaseType: postgreSQLTypeName,
ConnectionDetails: map[string]interface{}{
"connection_url": connURL,
},
}
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.Initialize(conf.ConnectionDetails)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err := db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err := db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err := db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
statements := Statements{
CreationStatements: testPostgresRole,
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test default revoke statememts
err = db.RevokeUser(statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
username, err = db.GenerateUsername("test")
if err != nil {
t.Fatalf("err: %s", err)
}
password, err = db.GeneratePassword()
if err != nil {
t.Fatalf("err: %s", err)
}
expiration, err = db.GenerateExpiration(time.Minute)
if err != nil {
t.Fatalf("err: %s", err)
}
err = db.CreateUser(statements, username, password, expiration)
if err != nil {
t.Fatalf("err: %s", err)
}
// Test custom revoke statements
statements.RevocationStatements = defaultPostgresRevocationSQL
err = db.RevokeUser(statements, username)
if err != nil {
t.Fatalf("err: %s", err)
}
}
const testPostgresRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
`
const testPostgresReadOnlyRole = `
CREATE ROLE "{{name}}" WITH
LOGIN
PASSWORD '{{password}}'
VALID UNTIL '{{expiration}}';
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
`
const testPostgresBlockStatementRole = `
DO $$
BEGIN
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
CREATE ROLE "foo-role";
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
ALTER ROLE "foo-role" SET search_path = foo;
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
END IF;
END
$$
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
GRANT "foo-role" TO "{{name}}";
ALTER ROLE "{{name}}" SET search_path = foo;
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
`
var testPostgresBlockStatementRoleSlice = []string{
`
DO $$
BEGIN
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
CREATE ROLE "foo-role";
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
ALTER ROLE "foo-role" SET search_path = foo;
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
END IF;
END
$$
`,
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
`GRANT "foo-role" TO "{{name}}";`,
`ALTER ROLE "{{name}}" SET search_path = foo;`,
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
}
const defaultPostgresRevocationSQL = `
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
DROP ROLE IF EXISTS "{{name}}";
`

View File

@ -3,10 +3,8 @@ package database
import (
"fmt"
"strings"
"time"
"github.com/fatih/structs"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -50,16 +48,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
return nil, nil
}
// pathConfigureBuiltinConnection returns a configured framework.Path setup to
// operate on builtin databases.
func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path {
return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler())
}
// pathConfigurePluginConnection returns a configured framework.Path setup to
// operate on plugins.
func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler())
return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler())
}
// buildConfigConnectionPath reutns a configured framework.Path using the passed
@ -74,40 +66,12 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework
Description: "Name of this DB type",
},
"connection_type": &framework.FieldSchema{
Type: framework.TypeString,
Description: "DB type (e.g. postgres)",
},
"verify_connection": &framework.FieldSchema{
Type: framework.TypeBool,
Default: true,
Description: `If set, connection_url is verified by actually connecting to the database`,
},
"max_open_connections": &framework.FieldSchema{
Type: framework.TypeInt,
Description: `Maximum number of open connections to the database;
a zero uses the default value of two and a
negative value means unlimited`,
},
"max_idle_connections": &framework.FieldSchema{
Type: framework.TypeInt,
Description: `Maximum number of idle connections to the database;
a zero uses the value of max_open_connections
and a negative value disables idle connections.
If larger than max_open_connections it will be
reduced to the same size.`,
},
"max_connection_lifetime": &framework.FieldSchema{
Type: framework.TypeString,
Default: "0s",
Description: `Maximum amount of time a connection may be reused;
a zero or negative value reuses connections forever.`,
},
"plugin_name": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum amount of time a connection may be reused;
@ -139,7 +103,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
return nil, nil
}
var config dbs.DatabaseConfig
var config DatabaseConfig
if err := entry.DecodeJSON(&config); err != nil {
return nil, err
}
@ -180,40 +144,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
// connectionWriteHandler returns a handler function for creating and updating
// both builtin and plugin database types.
func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc {
func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
connType := data.Get("connection_type").(string)
if connType == "" {
return logical.ErrorResponse("connection_type not set"), nil
}
maxOpenConns := data.Get("max_open_connections").(int)
if maxOpenConns == 0 {
maxOpenConns = 2
}
maxIdleConns := data.Get("max_idle_connections").(int)
if maxIdleConns == 0 {
maxIdleConns = maxOpenConns
}
if maxIdleConns > maxOpenConns {
maxIdleConns = maxOpenConns
}
maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string)
maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf(
"Invalid max_connection_lifetime: %s", err)), nil
}
config := &dbs.DatabaseConfig{
DatabaseType: connType,
ConnectionDetails: data.Raw,
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
MaxConnectionLifetime: maxConnLifetime,
PluginName: data.Get("plugin_name").(string),
config := &DatabaseConfig{
ConnectionDetails: data.Raw,
PluginName: data.Get("plugin_name").(string),
}
name := data.Get("name").(string)
@ -227,7 +163,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
b.Lock()
defer b.Unlock()
db, err := factory(config, b.System(), b.logger)
db, err := PluginFactory(config, b.System(), b.logger)
if err != nil {
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
}

View File

@ -4,7 +4,6 @@ import (
"fmt"
"time"
"github.com/hashicorp/vault/builtin/logical/database/dbs"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
@ -156,7 +155,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
"Invalid max_ttl: %s", err)), nil
}
statements := dbs.Statements{
statements := Statements{
CreationStatements: creationStmts,
RevocationStatements: revocationStmts,
RollbackStatements: rollbackStmts,
@ -183,10 +182,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
}
type roleEntry struct {
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"`
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
}
const pathRoleHelpSyn = `

View File

@ -1,6 +1,7 @@
package dbs
package database
import (
"errors"
"fmt"
"net/rpc"
"sync"
@ -8,8 +9,47 @@ import (
"github.com/hashicorp/go-plugin"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
log "github.com/mgutz/logxi/v1"
)
var (
ErrEmptyPluginName = errors.New("empty plugin name")
)
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
if conf.PluginName == "" {
return nil, ErrEmptyPluginName
}
pluginMeta, err := sys.LookupPlugin(conf.PluginName)
if err != nil {
return nil, err
}
db, err := newPluginClient(sys, pluginMeta)
if err != nil {
return nil, err
}
// Wrap with metrics middleware
db = &databaseMetricsMiddleware{
next: db,
typeStr: db.Type(),
}
// Wrap with tracing middleware
db = &databaseTracingMiddleware{
next: db,
typeStr: db.Type(),
logger: logger,
}
return db, nil
}
// handshakeConfigs are used to just do a basic handshake between
// a plugin and host. If the handshake fails, a user friendly error is shown.
// This prevents users from executing bad plugins or executing a plugin
@ -33,7 +73,7 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
}
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close
// method to also call Close() on the plugin.Client.
// method to also call Kill() on the plugin.Client.
type DatabasePluginClient struct {
client *plugin.Client
sync.Mutex

View File

@ -1,4 +1,4 @@
package dbs
package database
import (
"crypto/sha256"

View File

@ -4,7 +4,7 @@ import (
"fmt"
"strings"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/helper/builtinplugins"
"github.com/hashicorp/vault/meta"
)
@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int {
pluginName := args[0]
runner, ok := pluginutil.BuiltinPlugins[pluginName]
runner, ok := builtinplugins.BuiltinPlugins[pluginName]
if !ok {
c.Ui.Error(fmt.Sprintf(
"No plugin with the name %s found", pluginName))

View File

@ -0,0 +1,8 @@
package builtinplugins
import "github.com/hashicorp/vault-plugins/database/mysql"
var BuiltinPlugins = map[string]func() error{
"mysql-database-plugin": mysql.Run,
// "postgres-database-plugin": postgres.Run,
}

View File

@ -1,6 +0,0 @@
package pluginutil
var BuiltinPlugins = map[string]func() error{
// "mysql-database-plugin": mysql.Run,
// "postgres-database-plugin": postgres.Run,
}

View File

@ -8,6 +8,7 @@ import (
"strings"
"sync"
"github.com/hashicorp/vault/helper/builtinplugins"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
@ -53,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) {
}
// Look for builtin plugins
if _, ok := pluginutil.BuiltinPlugins[name]; !ok {
if _, ok := builtinplugins.BuiltinPlugins[name]; !ok {
return nil, fmt.Errorf("no plugin found with name: %s", name)
}