mirror of
https://github.com/hashicorp/vault.git
synced 2026-05-05 04:16:31 +02:00
Refactor to use builtin plugins from an external repo
This commit is contained in:
parent
8f88452fc0
commit
8a2e29c607
@ -4,16 +4,52 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
|
||||
const databaseConfigPath = "database/dbs/"
|
||||
|
||||
// DatabaseType is the interface that all database objects must implement.
|
||||
type DatabaseType interface {
|
||||
Type() string
|
||||
CreateUser(statements Statements, username, password, expiration string) error
|
||||
RenewUser(statements Statements, username, expiration string) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
|
||||
Initialize(map[string]interface{}) error
|
||||
Close() error
|
||||
|
||||
GenerateUsername(displayName string) (string, error)
|
||||
GeneratePassword() (string, error)
|
||||
GenerateExpiration(ttl time.Duration) (string, error)
|
||||
}
|
||||
|
||||
// DatabaseConfig is used by the Factory function to configure a DatabaseType
|
||||
// object.
|
||||
type DatabaseConfig struct {
|
||||
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
|
||||
// ConnectionDetails stores the database specific connection settings needed
|
||||
// by each database type.
|
||||
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
|
||||
}
|
||||
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
// TODO: Add a way of setting defaults here.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
func Factory(conf *logical.BackendConfig) (logical.Backend, error) {
|
||||
return Backend(conf).Setup(conf)
|
||||
}
|
||||
@ -30,7 +66,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
},
|
||||
|
||||
Paths: []*framework.Path{
|
||||
pathConfigureBuiltinConnection(&b),
|
||||
pathConfigurePluginConnection(&b),
|
||||
pathListRoles(&b),
|
||||
pathRoles(&b),
|
||||
@ -48,12 +83,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend {
|
||||
}
|
||||
|
||||
b.logger = conf.Logger
|
||||
b.connections = make(map[string]dbs.DatabaseType)
|
||||
b.connections = make(map[string]DatabaseType)
|
||||
return &b
|
||||
}
|
||||
|
||||
type databaseBackend struct {
|
||||
connections map[string]dbs.DatabaseType
|
||||
connections map[string]DatabaseType
|
||||
logger log.Logger
|
||||
|
||||
*framework.Backend
|
||||
@ -73,7 +108,7 @@ func (b *databaseBackend) closeAllDBs() {
|
||||
// This function is used to retrieve a database object either from the cached
|
||||
// connection map or by using the database config in storage. The caller of this
|
||||
// function needs to hold the backend's lock.
|
||||
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) {
|
||||
func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) {
|
||||
// if the object already is built and cached, return it
|
||||
db, ok := b.connections[name]
|
||||
if ok {
|
||||
@ -88,14 +123,12 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.
|
||||
return nil, fmt.Errorf("failed to find entry for connection with name: %s", name)
|
||||
}
|
||||
|
||||
var config dbs.DatabaseConfig
|
||||
var config DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
factory := config.GetFactory()
|
||||
|
||||
db, err = factory(&config, b.System(), b.logger)
|
||||
db, err = PluginFactory(&config, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package dbs
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
@ -1,108 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;`
|
||||
defaultRollbackCQL = `DROP USER '{{username}}';`
|
||||
)
|
||||
|
||||
type Cassandra struct {
|
||||
// Session is goroutine safe, however, since we reinitialize
|
||||
// it when connection info changes, we want to make sure we
|
||||
// can close it and use a new connection; hence the lock
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
func (c *Cassandra) Type() string {
|
||||
return cassandraTypeName
|
||||
}
|
||||
|
||||
func (c *Cassandra) getConnection() (*gocql.Session, error) {
|
||||
session, err := c.connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.(*gocql.Session), nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// Get the connection
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
creationCQL := statements.CreationStatements
|
||||
if creationCQL == "" {
|
||||
creationCQL = defaultCreationCQL
|
||||
}
|
||||
rollbackCQL := statements.RollbackStatements
|
||||
if rollbackCQL == "" {
|
||||
rollbackCQL = defaultRollbackCQL
|
||||
}
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = session.Query(queryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
if err != nil {
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
session.Query(queryHelper(query, map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})).Exec()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) RevokeUser(statements Statements, username string) error {
|
||||
// Grab the lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
session, err := c.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error removing user %s", username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,280 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
// Import sql drivers
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/hashicorp/vault/helper/certutil"
|
||||
"github.com/hashicorp/vault/helper/tlsutil"
|
||||
)
|
||||
|
||||
var (
|
||||
errNotInitalized = errors.New("connection has not been initalized")
|
||||
)
|
||||
|
||||
// ConnectionProducer can be used as an embeded interface in the DatabaseType
|
||||
// definition. It implements the methods dealing with individual database
|
||||
// connections and is used in all the builtin database types.
|
||||
type ConnectionProducer interface {
|
||||
Close() error
|
||||
Initialize(map[string]interface{}) error
|
||||
|
||||
sync.Locker
|
||||
connection() (interface{}, error)
|
||||
}
|
||||
|
||||
// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases
|
||||
type sqlConnectionProducer struct {
|
||||
ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"`
|
||||
|
||||
config *DatabaseConfig
|
||||
|
||||
initalized bool
|
||||
db *sql.DB
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := c.connection(); err != nil {
|
||||
return fmt.Errorf("error initalizing connection: %s", err)
|
||||
}
|
||||
|
||||
c.initalized = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sqlConnectionProducer) connection() (interface{}, error) {
|
||||
// If we already have a DB, test it and return
|
||||
if c.db != nil {
|
||||
if err := c.db.Ping(); err == nil {
|
||||
return c.db, nil
|
||||
}
|
||||
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
||||
// reestablishing anyways
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
// For mssql backend, switch to sqlserver instead
|
||||
dbType := c.config.DatabaseType
|
||||
if c.config.DatabaseType == "mssql" {
|
||||
dbType = "sqlserver"
|
||||
}
|
||||
|
||||
// Otherwise, attempt to make connection
|
||||
conn := c.ConnectionURL
|
||||
|
||||
// Ensure timezone is set to UTC for all the conenctions
|
||||
if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") {
|
||||
if strings.Contains(conn, "?") {
|
||||
conn += "&timezone=utc"
|
||||
} else {
|
||||
conn += "?timezone=utc"
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
c.db, err = sql.Open(dbType, conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
// since the request rate shouldn't be high.
|
||||
c.db.SetMaxOpenConns(c.config.MaxOpenConnections)
|
||||
c.db.SetMaxIdleConns(c.config.MaxIdleConnections)
|
||||
c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime)
|
||||
|
||||
return c.db, nil
|
||||
}
|
||||
|
||||
func (c *sqlConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.db != nil {
|
||||
c.db.Close()
|
||||
}
|
||||
|
||||
c.db = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cassandraConnectionProducer implements ConnectionProducer and provides an
|
||||
// interface for cassandra databases to make connections.
|
||||
type cassandraConnectionProducer struct {
|
||||
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
||||
Username string `json:"username" structs:"username" mapstructure:"username"`
|
||||
Password string `json:"password" structs:"password" mapstructure:"password"`
|
||||
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
||||
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
||||
Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"`
|
||||
PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"`
|
||||
IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"`
|
||||
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
||||
ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
||||
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
||||
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
||||
|
||||
config *DatabaseConfig
|
||||
initalized bool
|
||||
session *gocql.Session
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
err := mapstructure.Decode(conf, c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.initalized = true
|
||||
|
||||
if _, err := c.connection(); err != nil {
|
||||
return fmt.Errorf("error Initalizing Connection: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) connection() (interface{}, error) {
|
||||
if !c.initalized {
|
||||
return nil, errNotInitalized
|
||||
}
|
||||
|
||||
// If we already have a DB, return it
|
||||
if c.session != nil {
|
||||
return c.session, nil
|
||||
}
|
||||
|
||||
session, err := c.createSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store the session in backend for reuse
|
||||
c.session = session
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) Close() error {
|
||||
// Grab the write lock
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if c.session != nil {
|
||||
c.session.Close()
|
||||
}
|
||||
|
||||
c.session = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) {
|
||||
clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...)
|
||||
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: c.Username,
|
||||
Password: c.Password,
|
||||
}
|
||||
|
||||
clusterConfig.ProtoVersion = c.ProtocolVersion
|
||||
if clusterConfig.ProtoVersion == 0 {
|
||||
clusterConfig.ProtoVersion = 2
|
||||
}
|
||||
|
||||
clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second
|
||||
|
||||
if c.TLS {
|
||||
var tlsConfig *tls.Config
|
||||
if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 {
|
||||
if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("Found certificate for TLS authentication but no private key")
|
||||
}
|
||||
|
||||
certBundle := &certutil.CertBundle{}
|
||||
if len(c.Certificate) > 0 {
|
||||
certBundle.Certificate = c.Certificate
|
||||
certBundle.PrivateKey = c.PrivateKey
|
||||
}
|
||||
if len(c.IssuingCA) > 0 {
|
||||
certBundle.IssuingCA = c.IssuingCA
|
||||
}
|
||||
|
||||
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse certificate bundle: %s", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
||||
if err != nil || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err)
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
||||
|
||||
if c.TLSMinVersion != "" {
|
||||
var ok bool
|
||||
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
||||
}
|
||||
} else {
|
||||
// MinVersion was not being set earlier. Reset it to
|
||||
// zero to gracefully handle upgrades.
|
||||
tlsConfig.MinVersion = 0
|
||||
}
|
||||
}
|
||||
|
||||
clusterConfig.SslOpts = &gocql.SslOptions{
|
||||
Config: *tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
session, err := clusterConfig.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating session: %s", err)
|
||||
}
|
||||
|
||||
// Set consistency
|
||||
if c.Consistency != "" {
|
||||
consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session.SetConsistency(consistencyValue)
|
||||
}
|
||||
|
||||
// Verify the info
|
||||
err = session.Query(`LIST USERS`).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error validating connection info: %s", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
@ -1,83 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
uuid "github.com/hashicorp/go-uuid"
|
||||
)
|
||||
|
||||
// CredentialsProducer can be used as an embeded interface in the DatabaseType
|
||||
// definition. It implements the methods for generating user information for a
|
||||
// particular database type and is used in all the builtin database types.
|
||||
type CredentialsProducer interface {
|
||||
GenerateUsername(displayName string) (string, error)
|
||||
GeneratePassword() (string, error)
|
||||
GenerateExpiration(ttl time.Duration) (string, error)
|
||||
}
|
||||
|
||||
// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types.
|
||||
type sqlCredentialsProducer struct {
|
||||
displayNameLen int
|
||||
usernameLen int
|
||||
}
|
||||
|
||||
func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen {
|
||||
displayName = displayName[:scp.displayNameLen]
|
||||
}
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("%s-%s", displayName, userUUID)
|
||||
if scp.usernameLen > 0 && len(username) > scp.usernameLen {
|
||||
username = username[:scp.usernameLen]
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) {
|
||||
return time.Now().
|
||||
Add(ttl).
|
||||
Format("2006-01-02 15:04:05-0700"), nil
|
||||
}
|
||||
|
||||
// cassandraCredentialsProducer implements CredentialsProducer and provides an
|
||||
// interface for cassandra databases to generate user information.
|
||||
type cassandraCredentialsProducer struct{}
|
||||
|
||||
func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) {
|
||||
userUUID, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix())
|
||||
username = strings.Replace(username, "-", "_", -1)
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) {
|
||||
password, err := uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
@ -1,196 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/logical"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
const (
|
||||
postgreSQLTypeName = "postgres"
|
||||
mySQLTypeName = "mysql"
|
||||
msSQLTypeName = "mssql"
|
||||
cassandraTypeName = "cassandra"
|
||||
pluginTypeName = "plugin"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedDatabaseType = errors.New("unsupported database type")
|
||||
ErrEmptyCreationStatement = errors.New("empty creation statements")
|
||||
ErrEmptyPluginName = errors.New("empty plugin name")
|
||||
)
|
||||
|
||||
// Factory function definition
|
||||
type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error)
|
||||
|
||||
// BuiltinFactory is used to build builtin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||
var dbType DatabaseType
|
||||
|
||||
switch conf.DatabaseType {
|
||||
case postgreSQLTypeName:
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
displayNameLen: 23,
|
||||
usernameLen: 63,
|
||||
}
|
||||
|
||||
dbType = &PostgreSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
case mySQLTypeName:
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
displayNameLen: 4,
|
||||
usernameLen: 16,
|
||||
}
|
||||
|
||||
dbType = &MySQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
case msSQLTypeName:
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &sqlCredentialsProducer{
|
||||
displayNameLen: 10,
|
||||
usernameLen: 63,
|
||||
}
|
||||
|
||||
dbType = &MSSQL{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
case cassandraTypeName:
|
||||
connProducer := &cassandraConnectionProducer{}
|
||||
connProducer.config = conf
|
||||
|
||||
credsProducer := &cassandraCredentialsProducer{}
|
||||
|
||||
dbType = &Cassandra{
|
||||
ConnectionProducer: connProducer,
|
||||
CredentialsProducer: credsProducer,
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabaseType
|
||||
}
|
||||
|
||||
// Wrap with metrics middleware
|
||||
dbType = &databaseMetricsMiddleware{
|
||||
next: dbType,
|
||||
typeStr: dbType.Type(),
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
dbType = &databaseTracingMiddleware{
|
||||
next: dbType,
|
||||
typeStr: dbType.Type(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return dbType, nil
|
||||
}
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||
if conf.PluginName == "" {
|
||||
return nil, ErrEmptyPluginName
|
||||
}
|
||||
|
||||
pluginMeta, err := sys.LookupPlugin(conf.PluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Make sure the database type is set to plugin
|
||||
conf.DatabaseType = pluginTypeName
|
||||
|
||||
db, err := newPluginClient(sys, pluginMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wrap with metrics middleware
|
||||
db = &databaseMetricsMiddleware{
|
||||
next: db,
|
||||
typeStr: db.Type(),
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: db.Type(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// DatabaseType is the interface that all database objects must implement.
|
||||
type DatabaseType interface {
|
||||
Type() string
|
||||
CreateUser(statements Statements, username, password, expiration string) error
|
||||
RenewUser(statements Statements, username, expiration string) error
|
||||
RevokeUser(statements Statements, username string) error
|
||||
|
||||
Initialize(map[string]interface{}) error
|
||||
Close() error
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
// DatabaseConfig is used by the Factory function to configure a DatabaseType
|
||||
// object.
|
||||
type DatabaseConfig struct {
|
||||
DatabaseType string `json:"type" structs:"type" mapstructure:"type"`
|
||||
// ConnectionDetails stores the database specific connection settings needed
|
||||
// by each database type.
|
||||
ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"`
|
||||
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
|
||||
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
|
||||
}
|
||||
|
||||
// GetFactory returns the appropriate factory method for the given database
|
||||
// type.
|
||||
func (dc *DatabaseConfig) GetFactory() Factory {
|
||||
if dc.DatabaseType == pluginTypeName {
|
||||
return PluginFactory
|
||||
}
|
||||
|
||||
return BuiltinFactory
|
||||
}
|
||||
|
||||
// Statements set in role creation and passed into the database type's functions.
|
||||
// TODO: Add a way of setting defaults here.
|
||||
type Statements struct {
|
||||
CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"`
|
||||
RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"`
|
||||
RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"`
|
||||
RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"`
|
||||
}
|
||||
|
||||
// Query templates a query for us.
|
||||
func queryHelper(tpl string, data map[string]string) string {
|
||||
for k, v := range data {
|
||||
tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1)
|
||||
}
|
||||
|
||||
return tpl
|
||||
}
|
||||
@ -1,219 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
// MSSQL is an implementation of DatabaseType interface
|
||||
type MSSQL struct {
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
// Type returns the TypeName for this backend
|
||||
func (m *MSSQL) Type() string {
|
||||
return msSQLTypeName
|
||||
}
|
||||
|
||||
func (m *MSSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by
|
||||
// the CreationStatement provided.
|
||||
func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenewUser is not supported on MSSQL, so this is a no-op.
|
||||
func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeUser attempts to drop the specified user. It will first attempt to disable login,
|
||||
// then kill pending connections from that user, and finally drop the user and login from the
|
||||
// database instance.
|
||||
func (m *MSSQL) RevokeUser(statements Statements, username string) error {
|
||||
// Get connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// First disable server login
|
||||
disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer disableStmt.Close()
|
||||
if _, err := disableStmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Query for sessions for the login so that we can kill any outstanding
|
||||
// sessions. There cannot be any active sessions before we drop the logins
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
sessionStmt, err := db.Prepare(fmt.Sprintf(
|
||||
"SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionStmt.Close()
|
||||
|
||||
sessionRows, err := sessionStmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sessionRows.Close()
|
||||
|
||||
var revokeStmts []string
|
||||
for sessionRows.Next() {
|
||||
var sessionID int
|
||||
err = sessionRows.Scan(&sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID))
|
||||
}
|
||||
|
||||
// Query for database users using undocumented stored procedure for now since
|
||||
// it is the easiest way to get this information;
|
||||
// we need to drop the database users before we can drop the login and the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var loginName, dbName, qUsername string
|
||||
var aliasName sql.NullString
|
||||
err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username))
|
||||
}
|
||||
|
||||
// we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revokeStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all database users are dropped
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all sql statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this login
|
||||
stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const dropUserSQL = `
|
||||
USE [%s]
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM sys.database_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP USER [%s]
|
||||
END
|
||||
`
|
||||
|
||||
const dropLoginSQL = `
|
||||
IF EXISTS
|
||||
(SELECT name
|
||||
FROM master.sys.server_principals
|
||||
WHERE name = N'%s')
|
||||
BEGIN
|
||||
DROP LOGIN [%s]
|
||||
END
|
||||
`
|
||||
@ -1,221 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/denisenkom/go-mssqldb"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
testMSQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) {
|
||||
if os.Getenv("MSSQL_URL") != "" {
|
||||
return func() {}, os.Getenv("MSSQL_URL")
|
||||
}
|
||||
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect to docker: %s", err)
|
||||
}
|
||||
|
||||
resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"})
|
||||
if err != nil {
|
||||
t.Fatalf("Could not start local MSSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
cleanup = func() {
|
||||
err := pool.Purge(resource)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to cleanup local DynamoDB: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp"))
|
||||
|
||||
// exponential backoff-retry, because the mssql container may not be able to accept connections yet
|
||||
if err = pool.Retry(func() error {
|
||||
var err error
|
||||
var db *sql.DB
|
||||
db, err = sql.Open("mssql", retURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
t.Fatalf("Could not connect to MSSQL docker container: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMSSQL_Initialize(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: msSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Deconsturct the middleware chain to get the underlying mssql object
|
||||
dbTracer := dbRaw.(*databaseTracingMiddleware)
|
||||
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
|
||||
db := dbMetrics.next.(*MSSQL)
|
||||
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
|
||||
|
||||
err = dbRaw.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.initalized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if connProducer.db != nil {
|
||||
t.Fatal("db object should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_CreateUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: msSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
err = db.CreateUser(Statements{}, username, password, expiration)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSQL_RevokeUser(t *testing.T) {
|
||||
cleanup, connURL := prepareMSSQLTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: msSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMSSQLRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
const testMSSQLRole = `
|
||||
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
|
||||
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
|
||||
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`
|
||||
@ -1,135 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
)
|
||||
|
||||
const defaultMysqlRevocationStmts = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%';
|
||||
DROP USER '{{name}}'@'%'
|
||||
`
|
||||
|
||||
type MySQL struct {
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
func (m *MySQL) Type() string {
|
||||
return mySQLTypeName
|
||||
}
|
||||
|
||||
func (m *MySQL) getConnection() (*sql.DB, error) {
|
||||
db, err := m.connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
// Grab the lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if statements.CreationStatements == "" {
|
||||
return ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NOOP
|
||||
func (m *MySQL) RenewUser(statements Statements, username, expiration string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MySQL) RevokeUser(statements Statements, username string) error {
|
||||
// Grab the read lock
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := m.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
revocationStmts := statements.RevocationStatements
|
||||
// Use a default SQL statement for revocation if one cannot be fetched from the role
|
||||
if revocationStmts == "" {
|
||||
revocationStmts = defaultMysqlRevocationStmts
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// This is not a prepared statement because not all commands are supported
|
||||
// 1295: This command is not supported in the prepared statement protocol yet
|
||||
// Reference https://mariadb.com/kb/en/mariadb/prepare-statement/
|
||||
query = strings.Replace(query, "{{name}}", username, -1)
|
||||
_, err = tx.Exec(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,346 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
testMySQLImagePull sync.Once
|
||||
)
|
||||
|
||||
func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) {
|
||||
if os.Getenv("MYSQL_URL") != "" {
|
||||
return "", os.Getenv("MYSQL_URL")
|
||||
}
|
||||
|
||||
// Without this the checks for whether the container has started seem to
|
||||
// never actually pass. There's really no reason to expose the test
|
||||
// containers, so don't.
|
||||
dockertest.BindDockerToLocalhost = "yep"
|
||||
|
||||
testMySQLImagePull.Do(func() {
|
||||
dockertest.Pull("mysql")
|
||||
})
|
||||
|
||||
cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool {
|
||||
// This will cause a validation to run
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.ConnectionURL = connURL
|
||||
connProducer.config = &DatabaseConfig{
|
||||
DatabaseType: mySQLTypeName,
|
||||
}
|
||||
|
||||
conn, err := connProducer.connection()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := conn.(*sql.DB).Ping(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
connProducer.Close()
|
||||
|
||||
retURL = connURL
|
||||
return true
|
||||
})
|
||||
|
||||
if connErr != nil {
|
||||
t.Fatalf("could not connect to database: %v", connErr)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func TestMySQL_Initialize(t *testing.T) {
|
||||
cid, connURL := prepareMySQLTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: mySQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Deconsturct the middleware chain to get the underlying mysql object
|
||||
dbTracer := dbRaw.(*databaseTracingMiddleware)
|
||||
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
|
||||
db := dbMetrics.next.(*MySQL)
|
||||
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
|
||||
|
||||
err = dbRaw.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.initalized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if connProducer.db != nil {
|
||||
t.Fatal("db object should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_CreateUser(t *testing.T) {
|
||||
cid, connURL := prepareMySQLTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: mySQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
err = db.CreateUser(Statements{}, username, password, expiration)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
statements.CreationStatements = testMySQLRoleHost
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_RenewUser(t *testing.T) {
|
||||
cid, connURL := prepareMySQLTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: mySQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.RenewUser(statements, username, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMySQL_RevokeUser(t *testing.T) {
|
||||
cid, connURL := prepareMySQLTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: mySQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testMySQLRoleWildCard,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements.CreationStatements = testMySQLRoleHost
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = testMySQLRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const testMySQLRoleWildCard = `
|
||||
CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}';
|
||||
GRANT SELECT ON *.* TO '{{name}}'@'%';
|
||||
`
|
||||
const testMySQLRoleHost = `
|
||||
CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}';
|
||||
GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2';
|
||||
`
|
||||
const testMySQLRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2';
|
||||
DROP USER '{{name}}'@'10.1.1.2';
|
||||
`
|
||||
@ -1,279 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/strutil"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type PostgreSQL struct {
|
||||
ConnectionProducer
|
||||
CredentialsProducer
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Type() string {
|
||||
return postgreSQLTypeName
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) getConnection() (*sql.DB, error) {
|
||||
db, err := p.connection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.(*sql.DB), nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error {
|
||||
if statements.CreationStatements == "" {
|
||||
return ErrEmptyCreationStatement
|
||||
}
|
||||
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
// Get the connection
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start a transaction
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
// Return the secret
|
||||
|
||||
// Execute each query
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
"password": password,
|
||||
"expiration": expiration,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the transaction
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(
|
||||
"ALTER ROLE %s VALID UNTIL '%s';",
|
||||
pq.QuoteIdentifier(username),
|
||||
expiration)
|
||||
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) RevokeUser(statements Statements, username string) error {
|
||||
// Grab the lock
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if statements.RevocationStatements == "" {
|
||||
return p.defaultRevokeUser(username)
|
||||
}
|
||||
|
||||
return p.customRevokeUser(username, statements.RevocationStatements)
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") {
|
||||
query = strings.TrimSpace(query)
|
||||
if len(query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
stmt, err := tx.Prepare(queryHelper(query, map[string]string{
|
||||
"name": username,
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) defaultRevokeUser(username string) error {
|
||||
db, err := p.getConnection()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the role exists
|
||||
var exists bool
|
||||
err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists == false {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query for permissions; we need to revoke permissions before we can drop
|
||||
// the role
|
||||
// This isn't done in a transaction because even if we fail along the way,
|
||||
// we want to remove as much access as possible
|
||||
stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
const initialNumRevocations = 16
|
||||
revocationStmts := make([]string, 0, initialNumRevocations)
|
||||
for rows.Next() {
|
||||
var schema string
|
||||
err = rows.Scan(&schema)
|
||||
if err != nil {
|
||||
// keep going; remove as many permissions as possible right now
|
||||
continue
|
||||
}
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE USAGE ON SCHEMA %s FROM %s;`,
|
||||
pq.QuoteIdentifier(schema),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// for good measure, revoke all privileges and usage on schema public
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`,
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
"REVOKE USAGE ON SCHEMA public FROM %s;",
|
||||
pq.QuoteIdentifier(username)))
|
||||
|
||||
// get the current database name so we can issue a REVOKE CONNECT for
|
||||
// this username
|
||||
var dbname sql.NullString
|
||||
if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dbname.Valid {
|
||||
revocationStmts = append(revocationStmts, fmt.Sprintf(
|
||||
`REVOKE CONNECT ON DATABASE %s FROM %s;`,
|
||||
pq.QuoteIdentifier(dbname.String),
|
||||
pq.QuoteIdentifier(username)))
|
||||
}
|
||||
|
||||
// again, here, we do not stop on error, as we want to remove as
|
||||
// many permissions as possible right now
|
||||
var lastStmtError error
|
||||
for _, query := range revocationStmts {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
continue
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
lastStmtError = err
|
||||
}
|
||||
}
|
||||
|
||||
// can't drop if not all privileges are revoked
|
||||
if rows.Err() != nil {
|
||||
return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err())
|
||||
}
|
||||
if lastStmtError != nil {
|
||||
return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError)
|
||||
}
|
||||
|
||||
// Drop this user
|
||||
stmt, err = db.Prepare(fmt.Sprintf(
|
||||
`DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
if _, err := stmt.Exec(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1,414 +0,0 @@
|
||||
package dbs
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
dockertest "gopkg.in/ory-am/dockertest.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
testPostgresImagePull sync.Once
|
||||
)
|
||||
|
||||
func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) {
|
||||
if os.Getenv("PG_URL") != "" {
|
||||
return "", os.Getenv("PG_URL")
|
||||
}
|
||||
|
||||
// Without this the checks for whether the container has started seem to
|
||||
// never actually pass. There's really no reason to expose the test
|
||||
// containers, so don't.
|
||||
dockertest.BindDockerToLocalhost = "yep"
|
||||
|
||||
testPostgresImagePull.Do(func() {
|
||||
dockertest.Pull("postgres")
|
||||
})
|
||||
|
||||
cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool {
|
||||
// This will cause a validation to run
|
||||
connProducer := &sqlConnectionProducer{}
|
||||
connProducer.ConnectionURL = connURL
|
||||
connProducer.config = &DatabaseConfig{
|
||||
DatabaseType: postgreSQLTypeName,
|
||||
}
|
||||
|
||||
conn, err := connProducer.connection()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if err := conn.(*sql.DB).Ping(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
connProducer.Close()
|
||||
|
||||
retURL = connURL
|
||||
return true
|
||||
})
|
||||
|
||||
if connErr != nil {
|
||||
t.Fatalf("could not connect to database: %v", connErr)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) {
|
||||
err := cid.KillRemove()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_Initialize(t *testing.T) {
|
||||
cid, connURL := preparePostgresTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: postgreSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Deconsturct the middleware chain to get the underlying postgres object
|
||||
dbTracer := dbRaw.(*databaseTracingMiddleware)
|
||||
dbMetrics := dbTracer.next.(*databaseMetricsMiddleware)
|
||||
db := dbMetrics.next.(*PostgreSQL)
|
||||
connProducer := db.ConnectionProducer.(*sqlConnectionProducer)
|
||||
|
||||
err = dbRaw.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if !connProducer.initalized {
|
||||
t.Fatal("Database should be initalized")
|
||||
}
|
||||
|
||||
err = dbRaw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
if connProducer.db != nil {
|
||||
t.Fatal("db object should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_CreateUser(t *testing.T) {
|
||||
cid, connURL := preparePostgresTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: postgreSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test with no configured Creation Statememt
|
||||
err = db.CreateUser(Statements{}, username, password, expiration)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no creation statement is provided")
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
statements.CreationStatements = testPostgresReadOnlyRole
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
/* statements.CreationStatements = testBlockStatementRole
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}*/
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RenewUser(t *testing.T) {
|
||||
cid, connURL := preparePostgresTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: postgreSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.RenewUser(statements, username, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_RevokeUser(t *testing.T) {
|
||||
cid, connURL := preparePostgresTestContainer(t)
|
||||
if cid != "" {
|
||||
defer cleanupTestContainer(t, cid)
|
||||
}
|
||||
|
||||
conf := &DatabaseConfig{
|
||||
DatabaseType: postgreSQLTypeName,
|
||||
ConnectionDetails: map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := BuiltinFactory(conf, nil, &log.NullLogger{})
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
err = db.Initialize(conf.ConnectionDetails)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err := db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err := db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err := db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
statements := Statements{
|
||||
CreationStatements: testPostgresRole,
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test default revoke statememts
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
username, err = db.GenerateUsername("test")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
password, err = db.GeneratePassword()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
expiration, err = db.GenerateExpiration(time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
err = db.CreateUser(statements, username, password, expiration)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
// Test custom revoke statements
|
||||
statements.RevocationStatements = defaultPostgresRevocationSQL
|
||||
err = db.RevokeUser(statements, username)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const testPostgresRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testPostgresReadOnlyRole = `
|
||||
CREATE ROLE "{{name}}" WITH
|
||||
LOGIN
|
||||
PASSWORD '{{password}}'
|
||||
VALID UNTIL '{{expiration}}';
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
|
||||
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";
|
||||
`
|
||||
|
||||
const testPostgresBlockStatementRole = `
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
|
||||
CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';
|
||||
GRANT "foo-role" TO "{{name}}";
|
||||
ALTER ROLE "{{name}}" SET search_path = foo;
|
||||
GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";
|
||||
`
|
||||
|
||||
var testPostgresBlockStatementRoleSlice = []string{
|
||||
`
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN
|
||||
CREATE ROLE "foo-role";
|
||||
CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role";
|
||||
ALTER ROLE "foo-role" SET search_path = foo;
|
||||
GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role";
|
||||
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role";
|
||||
END IF;
|
||||
END
|
||||
$$
|
||||
`,
|
||||
`CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`,
|
||||
`GRANT "foo-role" TO "{{name}}";`,
|
||||
`ALTER ROLE "{{name}}" SET search_path = foo;`,
|
||||
`GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`,
|
||||
}
|
||||
|
||||
const defaultPostgresRevocationSQL = `
|
||||
REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}";
|
||||
REVOKE USAGE ON SCHEMA public FROM "{{name}}";
|
||||
|
||||
DROP ROLE IF EXISTS "{{name}}";
|
||||
`
|
||||
@ -3,10 +3,8 @@ package database
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/structs"
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@ -50,16 +48,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// pathConfigureBuiltinConnection returns a configured framework.Path setup to
|
||||
// operate on builtin databases.
|
||||
func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path {
|
||||
return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler())
|
||||
}
|
||||
|
||||
// pathConfigurePluginConnection returns a configured framework.Path setup to
|
||||
// operate on plugins.
|
||||
func pathConfigurePluginConnection(b *databaseBackend) *framework.Path {
|
||||
return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler())
|
||||
return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler())
|
||||
}
|
||||
|
||||
// buildConfigConnectionPath reutns a configured framework.Path using the passed
|
||||
@ -74,40 +66,12 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework
|
||||
Description: "Name of this DB type",
|
||||
},
|
||||
|
||||
"connection_type": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: "DB type (e.g. postgres)",
|
||||
},
|
||||
|
||||
"verify_connection": &framework.FieldSchema{
|
||||
Type: framework.TypeBool,
|
||||
Default: true,
|
||||
Description: `If set, connection_url is verified by actually connecting to the database`,
|
||||
},
|
||||
|
||||
"max_open_connections": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
Description: `Maximum number of open connections to the database;
|
||||
a zero uses the default value of two and a
|
||||
negative value means unlimited`,
|
||||
},
|
||||
|
||||
"max_idle_connections": &framework.FieldSchema{
|
||||
Type: framework.TypeInt,
|
||||
Description: `Maximum number of idle connections to the database;
|
||||
a zero uses the value of max_open_connections
|
||||
and a negative value disables idle connections.
|
||||
If larger than max_open_connections it will be
|
||||
reduced to the same size.`,
|
||||
},
|
||||
|
||||
"max_connection_lifetime": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Default: "0s",
|
||||
Description: `Maximum amount of time a connection may be reused;
|
||||
a zero or negative value reuses connections forever.`,
|
||||
},
|
||||
|
||||
"plugin_name": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: `Maximum amount of time a connection may be reused;
|
||||
@ -139,7 +103,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var config dbs.DatabaseConfig
|
||||
var config DatabaseConfig
|
||||
if err := entry.DecodeJSON(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -180,40 +144,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc {
|
||||
|
||||
// connectionWriteHandler returns a handler function for creating and updating
|
||||
// both builtin and plugin database types.
|
||||
func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc {
|
||||
func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc {
|
||||
return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
connType := data.Get("connection_type").(string)
|
||||
if connType == "" {
|
||||
return logical.ErrorResponse("connection_type not set"), nil
|
||||
}
|
||||
|
||||
maxOpenConns := data.Get("max_open_connections").(int)
|
||||
if maxOpenConns == 0 {
|
||||
maxOpenConns = 2
|
||||
}
|
||||
|
||||
maxIdleConns := data.Get("max_idle_connections").(int)
|
||||
if maxIdleConns == 0 {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
if maxIdleConns > maxOpenConns {
|
||||
maxIdleConns = maxOpenConns
|
||||
}
|
||||
|
||||
maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string)
|
||||
maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf(
|
||||
"Invalid max_connection_lifetime: %s", err)), nil
|
||||
}
|
||||
|
||||
config := &dbs.DatabaseConfig{
|
||||
DatabaseType: connType,
|
||||
ConnectionDetails: data.Raw,
|
||||
MaxOpenConnections: maxOpenConns,
|
||||
MaxIdleConnections: maxIdleConns,
|
||||
MaxConnectionLifetime: maxConnLifetime,
|
||||
PluginName: data.Get("plugin_name").(string),
|
||||
config := &DatabaseConfig{
|
||||
ConnectionDetails: data.Raw,
|
||||
PluginName: data.Get("plugin_name").(string),
|
||||
}
|
||||
|
||||
name := data.Get("name").(string)
|
||||
@ -227,7 +163,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
|
||||
b.Lock()
|
||||
defer b.Unlock()
|
||||
|
||||
db, err := factory(config, b.System(), b.logger)
|
||||
db, err := PluginFactory(config, b.System(), b.logger)
|
||||
if err != nil {
|
||||
return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil
|
||||
}
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/builtin/logical/database/dbs"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
"github.com/hashicorp/vault/logical/framework"
|
||||
)
|
||||
@ -156,7 +155,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
||||
"Invalid max_ttl: %s", err)), nil
|
||||
}
|
||||
|
||||
statements := dbs.Statements{
|
||||
statements := Statements{
|
||||
CreationStatements: creationStmts,
|
||||
RevocationStatements: revocationStmts,
|
||||
RollbackStatements: rollbackStmts,
|
||||
@ -183,10 +182,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F
|
||||
}
|
||||
|
||||
type roleEntry struct {
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"`
|
||||
Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"`
|
||||
DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"`
|
||||
MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"`
|
||||
}
|
||||
|
||||
const pathRoleHelpSyn = `
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package dbs
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
@ -8,8 +9,47 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-plugin"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
log "github.com/mgutz/logxi/v1"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyPluginName = errors.New("empty plugin name")
|
||||
)
|
||||
|
||||
// PluginFactory is used to build plugin database types. It wraps the database
|
||||
// object in a logging and metrics middleware.
|
||||
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
|
||||
if conf.PluginName == "" {
|
||||
return nil, ErrEmptyPluginName
|
||||
}
|
||||
|
||||
pluginMeta, err := sys.LookupPlugin(conf.PluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := newPluginClient(sys, pluginMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wrap with metrics middleware
|
||||
db = &databaseMetricsMiddleware{
|
||||
next: db,
|
||||
typeStr: db.Type(),
|
||||
}
|
||||
|
||||
// Wrap with tracing middleware
|
||||
db = &databaseTracingMiddleware{
|
||||
next: db,
|
||||
typeStr: db.Type(),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// handshakeConfigs are used to just do a basic handshake between
|
||||
// a plugin and host. If the handshake fails, a user friendly error is shown.
|
||||
// This prevents users from executing bad plugins or executing a plugin
|
||||
@ -33,7 +73,7 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e
|
||||
}
|
||||
|
||||
// DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close
|
||||
// method to also call Close() on the plugin.Client.
|
||||
// method to also call Kill() on the plugin.Client.
|
||||
type DatabasePluginClient struct {
|
||||
client *plugin.Client
|
||||
sync.Mutex
|
||||
@ -1,4 +1,4 @@
|
||||
package dbs
|
||||
package database
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/helper/builtinplugins"
|
||||
"github.com/hashicorp/vault/meta"
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int {
|
||||
|
||||
pluginName := args[0]
|
||||
|
||||
runner, ok := pluginutil.BuiltinPlugins[pluginName]
|
||||
runner, ok := builtinplugins.BuiltinPlugins[pluginName]
|
||||
if !ok {
|
||||
c.Ui.Error(fmt.Sprintf(
|
||||
"No plugin with the name %s found", pluginName))
|
||||
|
||||
8
helper/builtinplugins/builtin.go
Normal file
8
helper/builtinplugins/builtin.go
Normal file
@ -0,0 +1,8 @@
|
||||
package builtinplugins
|
||||
|
||||
import "github.com/hashicorp/vault-plugins/database/mysql"
|
||||
|
||||
var BuiltinPlugins = map[string]func() error{
|
||||
"mysql-database-plugin": mysql.Run,
|
||||
// "postgres-database-plugin": postgres.Run,
|
||||
}
|
||||
@ -1,6 +0,0 @@
|
||||
package pluginutil
|
||||
|
||||
var BuiltinPlugins = map[string]func() error{
|
||||
// "mysql-database-plugin": mysql.Run,
|
||||
// "postgres-database-plugin": postgres.Run,
|
||||
}
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/vault/helper/builtinplugins"
|
||||
"github.com/hashicorp/vault/helper/jsonutil"
|
||||
"github.com/hashicorp/vault/helper/pluginutil"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
@ -53,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) {
|
||||
}
|
||||
|
||||
// Look for builtin plugins
|
||||
if _, ok := pluginutil.BuiltinPlugins[name]; !ok {
|
||||
if _, ok := builtinplugins.BuiltinPlugins[name]; !ok {
|
||||
return nil, fmt.Errorf("no plugin found with name: %s", name)
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user