From 8a2e29c607664360ecff7c2f356d830c6b7746f0 Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Wed, 5 Apr 2017 16:20:31 -0700 Subject: [PATCH] Refactor to use builtin plugins from an external repo --- builtin/logical/database/backend.go | 51 ++- .../database/{dbs => }/databasemiddleware.go | 2 +- builtin/logical/database/dbs/cassandra.go | 108 ----- .../database/dbs/connectionproducer.go | 280 ------------ .../database/dbs/credentialsproducer.go | 83 ---- builtin/logical/database/dbs/db.go | 196 --------- builtin/logical/database/dbs/mssql.go | 219 --------- builtin/logical/database/dbs/mssql_test.go | 221 ---------- builtin/logical/database/dbs/mysql.go | 135 ------ builtin/logical/database/dbs/mysql_test.go | 346 --------------- builtin/logical/database/dbs/postgresql.go | 279 ------------ .../logical/database/dbs/postgresql_test.go | 414 ------------------ .../database/path_config_connection.go | 78 +--- builtin/logical/database/path_roles.go | 11 +- builtin/logical/database/{dbs => }/plugin.go | 44 +- .../logical/database/{dbs => }/plugin_test.go | 2 +- command/plugin-exec.go | 4 +- helper/builtinplugins/builtin.go | 8 + helper/pluginutil/builtin.go | 6 - vault/plugin_catalog.go | 3 +- 20 files changed, 110 insertions(+), 2380 deletions(-) rename builtin/logical/database/{dbs => }/databasemiddleware.go (99%) delete mode 100644 builtin/logical/database/dbs/cassandra.go delete mode 100644 builtin/logical/database/dbs/connectionproducer.go delete mode 100644 builtin/logical/database/dbs/credentialsproducer.go delete mode 100644 builtin/logical/database/dbs/db.go delete mode 100644 builtin/logical/database/dbs/mssql.go delete mode 100644 builtin/logical/database/dbs/mssql_test.go delete mode 100644 builtin/logical/database/dbs/mysql.go delete mode 100644 builtin/logical/database/dbs/mysql_test.go delete mode 100644 builtin/logical/database/dbs/postgresql.go delete mode 100644 builtin/logical/database/dbs/postgresql_test.go rename builtin/logical/database/{dbs => }/plugin.go (88%) rename builtin/logical/database/{dbs => }/plugin_test.go (99%) create mode 100644 helper/builtinplugins/builtin.go delete mode 100644 helper/pluginutil/builtin.go diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4d069a4328..a2fff4ba86 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -4,16 +4,52 @@ import ( "fmt" "strings" "sync" + "time" log "github.com/mgutz/logxi/v1" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) const databaseConfigPath = "database/dbs/" +// DatabaseType is the interface that all database objects must implement. +type DatabaseType interface { + Type() string + CreateUser(statements Statements, username, password, expiration string) error + RenewUser(statements Statements, username, expiration string) error + RevokeUser(statements Statements, username string) error + + Initialize(map[string]interface{}) error + Close() error + + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Duration) (string, error) +} + +// DatabaseConfig is used by the Factory function to configure a DatabaseType +// object. +type DatabaseConfig struct { + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` + // ConnectionDetails stores the database specific connection settings needed + // by each database type. + ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` +} + +// Statements set in role creation and passed into the database type's functions. +// TODO: Add a way of setting defaults here. +type Statements struct { + CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` + RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` + RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` + RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` +} + func Factory(conf *logical.BackendConfig) (logical.Backend, error) { return Backend(conf).Setup(conf) } @@ -30,7 +66,6 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { }, Paths: []*framework.Path{ - pathConfigureBuiltinConnection(&b), pathConfigurePluginConnection(&b), pathListRoles(&b), pathRoles(&b), @@ -48,12 +83,12 @@ func Backend(conf *logical.BackendConfig) *databaseBackend { } b.logger = conf.Logger - b.connections = make(map[string]dbs.DatabaseType) + b.connections = make(map[string]DatabaseType) return &b } type databaseBackend struct { - connections map[string]dbs.DatabaseType + connections map[string]DatabaseType logger log.Logger *framework.Backend @@ -73,7 +108,7 @@ func (b *databaseBackend) closeAllDBs() { // This function is used to retrieve a database object either from the cached // connection map or by using the database config in storage. The caller of this // function needs to hold the backend's lock. -func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs.DatabaseType, error) { +func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (DatabaseType, error) { // if the object already is built and cached, return it db, ok := b.connections[name] if ok { @@ -88,14 +123,12 @@ func (b *databaseBackend) getOrCreateDBObj(s logical.Storage, name string) (dbs. return nil, fmt.Errorf("failed to find entry for connection with name: %s", name) } - var config dbs.DatabaseConfig + var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } - factory := config.GetFactory() - - db, err = factory(&config, b.System(), b.logger) + db, err = PluginFactory(&config, b.System(), b.logger) if err != nil { return nil, err } diff --git a/builtin/logical/database/dbs/databasemiddleware.go b/builtin/logical/database/databasemiddleware.go similarity index 99% rename from builtin/logical/database/dbs/databasemiddleware.go rename to builtin/logical/database/databasemiddleware.go index d3f037ecbc..5892e8064a 100644 --- a/builtin/logical/database/dbs/databasemiddleware.go +++ b/builtin/logical/database/databasemiddleware.go @@ -1,4 +1,4 @@ -package dbs +package database import ( "time" diff --git a/builtin/logical/database/dbs/cassandra.go b/builtin/logical/database/dbs/cassandra.go deleted file mode 100644 index 1be26766bd..0000000000 --- a/builtin/logical/database/dbs/cassandra.go +++ /dev/null @@ -1,108 +0,0 @@ -package dbs - -import ( - "fmt" - "strings" - - "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/strutil" -) - -const ( - defaultCreationCQL = `CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;` - defaultRollbackCQL = `DROP USER '{{username}}';` -) - -type Cassandra struct { - // Session is goroutine safe, however, since we reinitialize - // it when connection info changes, we want to make sure we - // can close it and use a new connection; hence the lock - ConnectionProducer - CredentialsProducer -} - -func (c *Cassandra) Type() string { - return cassandraTypeName -} - -func (c *Cassandra) getConnection() (*gocql.Session, error) { - session, err := c.connection() - if err != nil { - return nil, err - } - - return session.(*gocql.Session), nil -} - -func (c *Cassandra) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - c.Lock() - defer c.Unlock() - - // Get the connection - session, err := c.getConnection() - if err != nil { - return err - } - - creationCQL := statements.CreationStatements - if creationCQL == "" { - creationCQL = defaultCreationCQL - } - rollbackCQL := statements.RollbackStatements - if rollbackCQL == "" { - rollbackCQL = defaultRollbackCQL - } - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(creationCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - err = session.Query(queryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - if err != nil { - for _, query := range strutil.ParseArbitraryStringSlice(rollbackCQL, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - session.Query(queryHelper(query, map[string]string{ - "username": username, - "password": password, - })).Exec() - } - return err - } - } - - return nil -} - -func (c *Cassandra) RenewUser(statements Statements, username, expiration string) error { - // NOOP - return nil -} - -func (c *Cassandra) RevokeUser(statements Statements, username string) error { - // Grab the lock - c.Lock() - defer c.Unlock() - - session, err := c.getConnection() - if err != nil { - return err - } - - err = session.Query(fmt.Sprintf("DROP USER '%s'", username)).Exec() - if err != nil { - return fmt.Errorf("error removing user %s", username) - } - - return nil -} diff --git a/builtin/logical/database/dbs/connectionproducer.go b/builtin/logical/database/dbs/connectionproducer.go deleted file mode 100644 index 31ef2853b7..0000000000 --- a/builtin/logical/database/dbs/connectionproducer.go +++ /dev/null @@ -1,280 +0,0 @@ -package dbs - -import ( - "crypto/tls" - "database/sql" - "errors" - "fmt" - "strings" - "sync" - "time" - - // Import sql drivers - _ "github.com/denisenkom/go-mssqldb" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/mitchellh/mapstructure" - - "github.com/gocql/gocql" - "github.com/hashicorp/vault/helper/certutil" - "github.com/hashicorp/vault/helper/tlsutil" -) - -var ( - errNotInitalized = errors.New("connection has not been initalized") -) - -// ConnectionProducer can be used as an embeded interface in the DatabaseType -// definition. It implements the methods dealing with individual database -// connections and is used in all the builtin database types. -type ConnectionProducer interface { - Close() error - Initialize(map[string]interface{}) error - - sync.Locker - connection() (interface{}, error) -} - -// sqlConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases -type sqlConnectionProducer struct { - ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` - - config *DatabaseConfig - - initalized bool - db *sql.DB - sync.Mutex -} - -func (c *sqlConnectionProducer) Initialize(conf map[string]interface{}) error { - c.Lock() - defer c.Unlock() - - err := mapstructure.Decode(conf, c) - if err != nil { - return err - } - - if _, err := c.connection(); err != nil { - return fmt.Errorf("error initalizing connection: %s", err) - } - - c.initalized = true - - return nil -} - -func (c *sqlConnectionProducer) connection() (interface{}, error) { - // If we already have a DB, test it and return - if c.db != nil { - if err := c.db.Ping(); err == nil { - return c.db, nil - } - // If the ping was unsuccessful, close it and ignore errors as we'll be - // reestablishing anyways - c.db.Close() - } - - // For mssql backend, switch to sqlserver instead - dbType := c.config.DatabaseType - if c.config.DatabaseType == "mssql" { - dbType = "sqlserver" - } - - // Otherwise, attempt to make connection - conn := c.ConnectionURL - - // Ensure timezone is set to UTC for all the conenctions - if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { - if strings.Contains(conn, "?") { - conn += "&timezone=utc" - } else { - conn += "?timezone=utc" - } - } - - var err error - c.db, err = sql.Open(dbType, conn) - if err != nil { - return nil, err - } - - // Set some connection pool settings. We don't need much of this, - // since the request rate shouldn't be high. - c.db.SetMaxOpenConns(c.config.MaxOpenConnections) - c.db.SetMaxIdleConns(c.config.MaxIdleConnections) - c.db.SetConnMaxLifetime(c.config.MaxConnectionLifetime) - - return c.db, nil -} - -func (c *sqlConnectionProducer) Close() error { - // Grab the write lock - c.Lock() - defer c.Unlock() - - if c.db != nil { - c.db.Close() - } - - c.db = nil - - return nil -} - -// cassandraConnectionProducer implements ConnectionProducer and provides an -// interface for cassandra databases to make connections. -type cassandraConnectionProducer struct { - Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` - Username string `json:"username" structs:"username" mapstructure:"username"` - Password string `json:"password" structs:"password" mapstructure:"password"` - TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` - InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` - Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` - PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` - IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` - ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` - ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` - TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` - Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` - - config *DatabaseConfig - initalized bool - session *gocql.Session - sync.Mutex -} - -func (c *cassandraConnectionProducer) Initialize(conf map[string]interface{}) error { - c.Lock() - defer c.Unlock() - - err := mapstructure.Decode(conf, c) - if err != nil { - return err - } - c.initalized = true - - if _, err := c.connection(); err != nil { - return fmt.Errorf("error Initalizing Connection: %s", err) - } - - return nil -} - -func (c *cassandraConnectionProducer) connection() (interface{}, error) { - if !c.initalized { - return nil, errNotInitalized - } - - // If we already have a DB, return it - if c.session != nil { - return c.session, nil - } - - session, err := c.createSession() - if err != nil { - return nil, err - } - - // Store the session in backend for reuse - c.session = session - - return session, nil -} - -func (c *cassandraConnectionProducer) Close() error { - // Grab the write lock - c.Lock() - defer c.Unlock() - - if c.session != nil { - c.session.Close() - } - - c.session = nil - - return nil -} - -func (c *cassandraConnectionProducer) createSession() (*gocql.Session, error) { - clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) - clusterConfig.Authenticator = gocql.PasswordAuthenticator{ - Username: c.Username, - Password: c.Password, - } - - clusterConfig.ProtoVersion = c.ProtocolVersion - if clusterConfig.ProtoVersion == 0 { - clusterConfig.ProtoVersion = 2 - } - - clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second - - if c.TLS { - var tlsConfig *tls.Config - if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { - if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { - return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") - } - - certBundle := &certutil.CertBundle{} - if len(c.Certificate) > 0 { - certBundle.Certificate = c.Certificate - certBundle.PrivateKey = c.PrivateKey - } - if len(c.IssuingCA) > 0 { - certBundle.IssuingCA = c.IssuingCA - } - - parsedCertBundle, err := certBundle.ToParsedCertBundle() - if err != nil { - return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) - } - - tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) - if err != nil || tlsConfig == nil { - return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) - } - tlsConfig.InsecureSkipVerify = c.InsecureTLS - - if c.TLSMinVersion != "" { - var ok bool - tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] - if !ok { - return nil, fmt.Errorf("invalid 'tls_min_version' in config") - } - } else { - // MinVersion was not being set earlier. Reset it to - // zero to gracefully handle upgrades. - tlsConfig.MinVersion = 0 - } - } - - clusterConfig.SslOpts = &gocql.SslOptions{ - Config: *tlsConfig, - } - } - - session, err := clusterConfig.CreateSession() - if err != nil { - return nil, fmt.Errorf("error creating session: %s", err) - } - - // Set consistency - if c.Consistency != "" { - consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) - if err != nil { - return nil, err - } - - session.SetConsistency(consistencyValue) - } - - // Verify the info - err = session.Query(`LIST USERS`).Exec() - if err != nil { - return nil, fmt.Errorf("error validating connection info: %s", err) - } - - return session, nil -} diff --git a/builtin/logical/database/dbs/credentialsproducer.go b/builtin/logical/database/dbs/credentialsproducer.go deleted file mode 100644 index 6bd543f4e1..0000000000 --- a/builtin/logical/database/dbs/credentialsproducer.go +++ /dev/null @@ -1,83 +0,0 @@ -package dbs - -import ( - "fmt" - "strings" - "time" - - uuid "github.com/hashicorp/go-uuid" -) - -// CredentialsProducer can be used as an embeded interface in the DatabaseType -// definition. It implements the methods for generating user information for a -// particular database type and is used in all the builtin database types. -type CredentialsProducer interface { - GenerateUsername(displayName string) (string, error) - GeneratePassword() (string, error) - GenerateExpiration(ttl time.Duration) (string, error) -} - -// sqlCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. -type sqlCredentialsProducer struct { - displayNameLen int - usernameLen int -} - -func (scp *sqlCredentialsProducer) GenerateUsername(displayName string) (string, error) { - if scp.displayNameLen > 0 && len(displayName) > scp.displayNameLen { - displayName = displayName[:scp.displayNameLen] - } - userUUID, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - username := fmt.Sprintf("%s-%s", displayName, userUUID) - if scp.usernameLen > 0 && len(username) > scp.usernameLen { - username = username[:scp.usernameLen] - } - - return username, nil -} - -func (scp *sqlCredentialsProducer) GeneratePassword() (string, error) { - password, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - - return password, nil -} - -func (scp *sqlCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { - return time.Now(). - Add(ttl). - Format("2006-01-02 15:04:05-0700"), nil -} - -// cassandraCredentialsProducer implements CredentialsProducer and provides an -// interface for cassandra databases to generate user information. -type cassandraCredentialsProducer struct{} - -func (ccp *cassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { - userUUID, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) - username = strings.Replace(username, "-", "_", -1) - - return username, nil -} - -func (ccp *cassandraCredentialsProducer) GeneratePassword() (string, error) { - password, err := uuid.GenerateUUID() - if err != nil { - return "", err - } - - return password, nil -} - -func (ccp *cassandraCredentialsProducer) GenerateExpiration(ttl time.Duration) (string, error) { - return "", nil -} diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go deleted file mode 100644 index 49b18b3b85..0000000000 --- a/builtin/logical/database/dbs/db.go +++ /dev/null @@ -1,196 +0,0 @@ -package dbs - -import ( - "errors" - "fmt" - "strings" - "time" - - "github.com/hashicorp/vault/logical" - log "github.com/mgutz/logxi/v1" -) - -const ( - postgreSQLTypeName = "postgres" - mySQLTypeName = "mysql" - msSQLTypeName = "mssql" - cassandraTypeName = "cassandra" - pluginTypeName = "plugin" -) - -var ( - ErrUnsupportedDatabaseType = errors.New("unsupported database type") - ErrEmptyCreationStatement = errors.New("empty creation statements") - ErrEmptyPluginName = errors.New("empty plugin name") -) - -// Factory function definition -type Factory func(*DatabaseConfig, logical.SystemView, log.Logger) (DatabaseType, error) - -// BuiltinFactory is used to build builtin database types. It wraps the database -// object in a logging and metrics middleware. -func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - var dbType DatabaseType - - switch conf.DatabaseType { - case postgreSQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 23, - usernameLen: 63, - } - - dbType = &PostgreSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case mySQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 4, - usernameLen: 16, - } - - dbType = &MySQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case msSQLTypeName: - connProducer := &sqlConnectionProducer{} - connProducer.config = conf - - credsProducer := &sqlCredentialsProducer{ - displayNameLen: 10, - usernameLen: 63, - } - - dbType = &MSSQL{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - case cassandraTypeName: - connProducer := &cassandraConnectionProducer{} - connProducer.config = conf - - credsProducer := &cassandraCredentialsProducer{} - - dbType = &Cassandra{ - ConnectionProducer: connProducer, - CredentialsProducer: credsProducer, - } - - default: - return nil, ErrUnsupportedDatabaseType - } - - // Wrap with metrics middleware - dbType = &databaseMetricsMiddleware{ - next: dbType, - typeStr: dbType.Type(), - } - - // Wrap with tracing middleware - dbType = &databaseTracingMiddleware{ - next: dbType, - typeStr: dbType.Type(), - logger: logger, - } - - return dbType, nil -} - -// PluginFactory is used to build plugin database types. It wraps the database -// object in a logging and metrics middleware. -func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginName == "" { - return nil, ErrEmptyPluginName - } - - pluginMeta, err := sys.LookupPlugin(conf.PluginName) - if err != nil { - return nil, err - } - - // Make sure the database type is set to plugin - conf.DatabaseType = pluginTypeName - - db, err := newPluginClient(sys, pluginMeta) - if err != nil { - return nil, err - } - - // Wrap with metrics middleware - db = &databaseMetricsMiddleware{ - next: db, - typeStr: db.Type(), - } - - // Wrap with tracing middleware - db = &databaseTracingMiddleware{ - next: db, - typeStr: db.Type(), - logger: logger, - } - - return db, nil -} - -// DatabaseType is the interface that all database objects must implement. -type DatabaseType interface { - Type() string - CreateUser(statements Statements, username, password, expiration string) error - RenewUser(statements Statements, username, expiration string) error - RevokeUser(statements Statements, username string) error - - Initialize(map[string]interface{}) error - Close() error - CredentialsProducer -} - -// DatabaseConfig is used by the Factory function to configure a DatabaseType -// object. -type DatabaseConfig struct { - DatabaseType string `json:"type" structs:"type" mapstructure:"type"` - // ConnectionDetails stores the database specific connection settings needed - // by each database type. - ConnectionDetails map[string]interface{} `json:"connection_details" structs:"connection_details" mapstructure:"connection_details"` - MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` - MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` - MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` - PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` -} - -// GetFactory returns the appropriate factory method for the given database -// type. -func (dc *DatabaseConfig) GetFactory() Factory { - if dc.DatabaseType == pluginTypeName { - return PluginFactory - } - - return BuiltinFactory -} - -// Statements set in role creation and passed into the database type's functions. -// TODO: Add a way of setting defaults here. -type Statements struct { - CreationStatements string `json:"creation_statments" mapstructure:"creation_statements" structs:"creation_statments"` - RevocationStatements string `json:"revocation_statements" mapstructure:"revocation_statements" structs:"revocation_statements"` - RollbackStatements string `json:"rollback_statements" mapstructure:"rollback_statements" structs:"rollback_statements"` - RenewStatements string `json:"renew_statements" mapstructure:"renew_statements" structs:"renew_statements"` -} - -// Query templates a query for us. -func queryHelper(tpl string, data map[string]string) string { - for k, v := range data { - tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) - } - - return tpl -} diff --git a/builtin/logical/database/dbs/mssql.go b/builtin/logical/database/dbs/mssql.go deleted file mode 100644 index b7439b0a82..0000000000 --- a/builtin/logical/database/dbs/mssql.go +++ /dev/null @@ -1,219 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/strutil" -) - -// MSSQL is an implementation of DatabaseType interface -type MSSQL struct { - ConnectionProducer - CredentialsProducer -} - -// Type returns the TypeName for this backend -func (m *MSSQL) Type() string { - return msSQLTypeName -} - -func (m *MSSQL) getConnection() (*sql.DB, error) { - db, err := m.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by -// the CreationStatement provided. -func (m *MSSQL) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -// RenewUser is not supported on MSSQL, so this is a no-op. -func (m *MSSQL) RenewUser(statements Statements, username, expiration string) error { - // NOOP - return nil -} - -// RevokeUser attempts to drop the specified user. It will first attempt to disable login, -// then kill pending connections from that user, and finally drop the user and login from the -// database instance. -func (m *MSSQL) RevokeUser(statements Statements, username string) error { - // Get connection - db, err := m.getConnection() - if err != nil { - return err - } - - // First disable server login - disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) - if err != nil { - return err - } - defer disableStmt.Close() - if _, err := disableStmt.Exec(); err != nil { - return err - } - - // Query for sessions for the login so that we can kill any outstanding - // sessions. There cannot be any active sessions before we drop the logins - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - sessionStmt, err := db.Prepare(fmt.Sprintf( - "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) - if err != nil { - return err - } - defer sessionStmt.Close() - - sessionRows, err := sessionStmt.Query() - if err != nil { - return err - } - defer sessionRows.Close() - - var revokeStmts []string - for sessionRows.Next() { - var sessionID int - err = sessionRows.Scan(&sessionID) - if err != nil { - return err - } - revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) - } - - // Query for database users using undocumented stored procedure for now since - // it is the easiest way to get this information; - // we need to drop the database users before we can drop the login and the role - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) - if err != nil { - return err - } - defer stmt.Close() - - rows, err := stmt.Query() - if err != nil { - return err - } - defer rows.Close() - - for rows.Next() { - var loginName, dbName, qUsername string - var aliasName sql.NullString - err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) - if err != nil { - return err - } - revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) - } - - // we do not stop on error, as we want to remove as - // many permissions as possible right now - var lastStmtError error - for _, query := range revokeStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } - } - - // can't drop if not all database users are dropped - if rows.Err() != nil { - return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) - } - if lastStmtError != nil { - return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) - } - - // Drop this login - stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} - -const dropUserSQL = ` -USE [%s] -IF EXISTS - (SELECT name - FROM sys.database_principals - WHERE name = N'%s') -BEGIN - DROP USER [%s] -END -` - -const dropLoginSQL = ` -IF EXISTS - (SELECT name - FROM master.sys.server_principals - WHERE name = N'%s') -BEGIN - DROP LOGIN [%s] -END -` diff --git a/builtin/logical/database/dbs/mssql_test.go b/builtin/logical/database/dbs/mssql_test.go deleted file mode 100644 index f2169299fa..0000000000 --- a/builtin/logical/database/dbs/mssql_test.go +++ /dev/null @@ -1,221 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "os" - "sync" - "testing" - "time" - - _ "github.com/denisenkom/go-mssqldb" - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v3" -) - -var ( - testMSQLImagePull sync.Once -) - -func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { - if os.Getenv("MSSQL_URL") != "" { - return func() {}, os.Getenv("MSSQL_URL") - } - - pool, err := dockertest.NewPool("") - if err != nil { - t.Fatalf("Failed to connect to docker: %s", err) - } - - resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) - if err != nil { - t.Fatalf("Could not start local MSSQL docker container: %s", err) - } - - cleanup = func() { - err := pool.Purge(resource) - if err != nil { - t.Fatalf("Failed to cleanup local DynamoDB: %s", err) - } - } - - retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) - - // exponential backoff-retry, because the mssql container may not be able to accept connections yet - if err = pool.Retry(func() error { - var err error - var db *sql.DB - db, err = sql.Open("mssql", retURL) - if err != nil { - return err - } - return db.Ping() - }); err != nil { - t.Fatalf("Could not connect to MSSQL docker container: %s", err) - } - - return -} - -func TestMSSQL_Initialize(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying mssql object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*MSSQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestMSSQL_CreateUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testMSSQLRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMSSQL_RevokeUser(t *testing.T) { - cleanup, connURL := prepareMSSQLTestContainer(t) - defer cleanup() - - conf := &DatabaseConfig{ - DatabaseType: msSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMSSQLRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -const testMSSQLRole = ` -CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; -CREATE USER [{{name}}] FOR LOGIN [{{name}}]; -GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` diff --git a/builtin/logical/database/dbs/mysql.go b/builtin/logical/database/dbs/mysql.go deleted file mode 100644 index 54940d8f65..0000000000 --- a/builtin/logical/database/dbs/mysql.go +++ /dev/null @@ -1,135 +0,0 @@ -package dbs - -import ( - "database/sql" - "strings" - - "github.com/hashicorp/vault/helper/strutil" -) - -const defaultMysqlRevocationStmts = ` - REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; - DROP USER '{{name}}'@'%' -` - -type MySQL struct { - ConnectionProducer - CredentialsProducer -} - -func (m *MySQL) Type() string { - return mySQLTypeName -} - -func (m *MySQL) getConnection() (*sql.DB, error) { - db, err := m.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -func (m *MySQL) CreateUser(statements Statements, username, password, expiration string) error { - // Grab the lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -// NOOP -func (m *MySQL) RenewUser(statements Statements, username, expiration string) error { - return nil -} - -func (m *MySQL) RevokeUser(statements Statements, username string) error { - // Grab the read lock - m.Lock() - defer m.Unlock() - - // Get the connection - db, err := m.getConnection() - if err != nil { - return err - } - - revocationStmts := statements.RevocationStatements - // Use a default SQL statement for revocation if one cannot be fetched from the role - if revocationStmts == "" { - revocationStmts = defaultMysqlRevocationStmts - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - // This is not a prepared statement because not all commands are supported - // 1295: This command is not supported in the prepared statement protocol yet - // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ - query = strings.Replace(query, "{{name}}", username, -1) - _, err = tx.Exec(query) - if err != nil { - return err - } - - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} diff --git a/builtin/logical/database/dbs/mysql_test.go b/builtin/logical/database/dbs/mysql_test.go deleted file mode 100644 index 553acc8ffd..0000000000 --- a/builtin/logical/database/dbs/mysql_test.go +++ /dev/null @@ -1,346 +0,0 @@ -package dbs - -import ( - "database/sql" - "os" - "sync" - "testing" - "time" - - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v2" -) - -var ( - testMySQLImagePull sync.Once -) - -func prepareMySQLTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("MYSQL_URL") != "" { - return "", os.Getenv("MYSQL_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testMySQLImagePull.Do(func() { - dockertest.Pull("mysql") - }) - - cid, connErr := dockertest.ConnectToMySQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - connProducer := &sqlConnectionProducer{} - connProducer.ConnectionURL = connURL - connProducer.config = &DatabaseConfig{ - DatabaseType: mySQLTypeName, - } - - conn, err := connProducer.connection() - if err != nil { - return false - } - if err := conn.(*sql.DB).Ping(); err != nil { - return false - } - - connProducer.Close() - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func TestMySQL_Initialize(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying mysql object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*MySQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestMySQL_CreateUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - statements.CreationStatements = testMySQLRoleHost - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMySQL_RenewUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(statements, username, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestMySQL_RevokeUser(t *testing.T) { - cid, connURL := prepareMySQLTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: mySQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testMySQLRoleWildCard, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements.CreationStatements = testMySQLRoleHost - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test custom revoke statements - statements.RevocationStatements = testMySQLRevocationSQL - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - -} - -const testMySQLRoleWildCard = ` -CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'%'; -` -const testMySQLRoleHost = ` -CREATE USER '{{name}}'@'10.1.1.2' IDENTIFIED BY '{{password}}'; -GRANT SELECT ON *.* TO '{{name}}'@'10.1.1.2'; -` -const testMySQLRevocationSQL = ` -REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'10.1.1.2'; -DROP USER '{{name}}'@'10.1.1.2'; -` diff --git a/builtin/logical/database/dbs/postgresql.go b/builtin/logical/database/dbs/postgresql.go deleted file mode 100644 index c8ba110cf7..0000000000 --- a/builtin/logical/database/dbs/postgresql.go +++ /dev/null @@ -1,279 +0,0 @@ -package dbs - -import ( - "database/sql" - "fmt" - "strings" - - "github.com/hashicorp/vault/helper/strutil" - "github.com/lib/pq" -) - -type PostgreSQL struct { - ConnectionProducer - CredentialsProducer -} - -func (p *PostgreSQL) Type() string { - return postgreSQLTypeName -} - -func (p *PostgreSQL) getConnection() (*sql.DB, error) { - db, err := p.connection() - if err != nil { - return nil, err - } - - return db.(*sql.DB), nil -} - -func (p *PostgreSQL) CreateUser(statements Statements, username, password, expiration string) error { - if statements.CreationStatements == "" { - return ErrEmptyCreationStatement - } - - // Grab the lock - p.Lock() - defer p.Unlock() - - // Get the connection - db, err := p.getConnection() - if err != nil { - return err - } - - // Start a transaction - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - tx.Rollback() - }() - // Return the secret - - // Execute each query - for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - "password": password, - "expiration": expiration, - })) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - } - - // Commit the transaction - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) RenewUser(statements Statements, username, expiration string) error { - // Grab the lock - p.Lock() - defer p.Unlock() - - db, err := p.getConnection() - if err != nil { - return err - } - - query := fmt.Sprintf( - "ALTER ROLE %s VALID UNTIL '%s';", - pq.QuoteIdentifier(username), - expiration) - - stmt, err := db.Prepare(query) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) RevokeUser(statements Statements, username string) error { - // Grab the lock - p.Lock() - defer p.Unlock() - - if statements.RevocationStatements == "" { - return p.defaultRevokeUser(username) - } - - return p.customRevokeUser(username, statements.RevocationStatements) -} - -func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { - db, err := p.getConnection() - if err != nil { - return err - } - - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - tx.Rollback() - }() - - for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { - query = strings.TrimSpace(query) - if len(query) == 0 { - continue - } - - stmt, err := tx.Prepare(queryHelper(query, map[string]string{ - "name": username, - })) - if err != nil { - return err - } - defer stmt.Close() - - if _, err := stmt.Exec(); err != nil { - return err - } - } - - if err := tx.Commit(); err != nil { - return err - } - - return nil -} - -func (p *PostgreSQL) defaultRevokeUser(username string) error { - db, err := p.getConnection() - if err != nil { - return err - } - - // Check if the role exists - var exists bool - err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) - if err != nil && err != sql.ErrNoRows { - return err - } - - if exists == false { - return nil - } - - // Query for permissions; we need to revoke permissions before we can drop - // the role - // This isn't done in a transaction because even if we fail along the way, - // we want to remove as much access as possible - stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") - if err != nil { - return err - } - defer stmt.Close() - - rows, err := stmt.Query(username) - if err != nil { - return err - } - defer rows.Close() - - const initialNumRevocations = 16 - revocationStmts := make([]string, 0, initialNumRevocations) - for rows.Next() { - var schema string - err = rows.Scan(&schema) - if err != nil { - // keep going; remove as many permissions as possible right now - continue - } - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, - pq.QuoteIdentifier(schema), - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE USAGE ON SCHEMA %s FROM %s;`, - pq.QuoteIdentifier(schema), - pq.QuoteIdentifier(username))) - } - - // for good measure, revoke all privileges and usage on schema public - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", - pq.QuoteIdentifier(username))) - - revocationStmts = append(revocationStmts, fmt.Sprintf( - "REVOKE USAGE ON SCHEMA public FROM %s;", - pq.QuoteIdentifier(username))) - - // get the current database name so we can issue a REVOKE CONNECT for - // this username - var dbname sql.NullString - if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { - return err - } - - if dbname.Valid { - revocationStmts = append(revocationStmts, fmt.Sprintf( - `REVOKE CONNECT ON DATABASE %s FROM %s;`, - pq.QuoteIdentifier(dbname.String), - pq.QuoteIdentifier(username))) - } - - // again, here, we do not stop on error, as we want to remove as - // many permissions as possible right now - var lastStmtError error - for _, query := range revocationStmts { - stmt, err := db.Prepare(query) - if err != nil { - lastStmtError = err - continue - } - defer stmt.Close() - _, err = stmt.Exec() - if err != nil { - lastStmtError = err - } - } - - // can't drop if not all privileges are revoked - if rows.Err() != nil { - return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) - } - if lastStmtError != nil { - return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) - } - - // Drop this user - stmt, err = db.Prepare(fmt.Sprintf( - `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) - if err != nil { - return err - } - defer stmt.Close() - if _, err := stmt.Exec(); err != nil { - return err - } - - return nil -} diff --git a/builtin/logical/database/dbs/postgresql_test.go b/builtin/logical/database/dbs/postgresql_test.go deleted file mode 100644 index 83aed50ba9..0000000000 --- a/builtin/logical/database/dbs/postgresql_test.go +++ /dev/null @@ -1,414 +0,0 @@ -package dbs - -import ( - "database/sql" - "os" - "sync" - "testing" - "time" - - log "github.com/mgutz/logxi/v1" - dockertest "gopkg.in/ory-am/dockertest.v2" -) - -var ( - testPostgresImagePull sync.Once -) - -func preparePostgresTestContainer(t *testing.T) (cid dockertest.ContainerID, retURL string) { - if os.Getenv("PG_URL") != "" { - return "", os.Getenv("PG_URL") - } - - // Without this the checks for whether the container has started seem to - // never actually pass. There's really no reason to expose the test - // containers, so don't. - dockertest.BindDockerToLocalhost = "yep" - - testPostgresImagePull.Do(func() { - dockertest.Pull("postgres") - }) - - cid, connErr := dockertest.ConnectToPostgreSQL(60, 500*time.Millisecond, func(connURL string) bool { - // This will cause a validation to run - connProducer := &sqlConnectionProducer{} - connProducer.ConnectionURL = connURL - connProducer.config = &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - } - - conn, err := connProducer.connection() - if err != nil { - return false - } - if err := conn.(*sql.DB).Ping(); err != nil { - return false - } - - connProducer.Close() - - retURL = connURL - return true - }) - - if connErr != nil { - t.Fatalf("could not connect to database: %v", connErr) - } - - return -} - -func cleanupTestContainer(t *testing.T, cid dockertest.ContainerID) { - err := cid.KillRemove() - if err != nil { - t.Fatal(err) - } -} - -func TestPostgreSQL_Initialize(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - dbRaw, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Deconsturct the middleware chain to get the underlying postgres object - dbTracer := dbRaw.(*databaseTracingMiddleware) - dbMetrics := dbTracer.next.(*databaseMetricsMiddleware) - db := dbMetrics.next.(*PostgreSQL) - connProducer := db.ConnectionProducer.(*sqlConnectionProducer) - - err = dbRaw.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - if !connProducer.initalized { - t.Fatal("Database should be initalized") - } - - err = dbRaw.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - if connProducer.db != nil { - t.Fatal("db object should be nil") - } -} - -func TestPostgreSQL_CreateUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test with no configured Creation Statememt - err = db.CreateUser(Statements{}, username, password, expiration) - if err == nil { - t.Fatal("Expected error when no creation statement is provided") - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - statements.CreationStatements = testPostgresReadOnlyRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - /* statements.CreationStatements = testBlockStatementRole - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - }*/ -} - -func TestPostgreSQL_RenewUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.RenewUser(statements, username, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } -} - -func TestPostgreSQL_RevokeUser(t *testing.T) { - cid, connURL := preparePostgresTestContainer(t) - if cid != "" { - defer cleanupTestContainer(t, cid) - } - - conf := &DatabaseConfig{ - DatabaseType: postgreSQLTypeName, - ConnectionDetails: map[string]interface{}{ - "connection_url": connURL, - }, - } - - db, err := BuiltinFactory(conf, nil, &log.NullLogger{}) - if err != nil { - t.Fatalf("err: %s", err) - } - err = db.Initialize(conf.ConnectionDetails) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err := db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err := db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err := db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - statements := Statements{ - CreationStatements: testPostgresRole, - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test default revoke statememts - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - - username, err = db.GenerateUsername("test") - if err != nil { - t.Fatalf("err: %s", err) - } - - password, err = db.GeneratePassword() - if err != nil { - t.Fatalf("err: %s", err) - } - - expiration, err = db.GenerateExpiration(time.Minute) - if err != nil { - t.Fatalf("err: %s", err) - } - - err = db.CreateUser(statements, username, password, expiration) - if err != nil { - t.Fatalf("err: %s", err) - } - - // Test custom revoke statements - statements.RevocationStatements = defaultPostgresRevocationSQL - err = db.RevokeUser(statements, username) - if err != nil { - t.Fatalf("err: %s", err) - } - -} - -const testPostgresRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresReadOnlyRole = ` -CREATE ROLE "{{name}}" WITH - LOGIN - PASSWORD '{{password}}' - VALID UNTIL '{{expiration}}'; -GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; -GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; -` - -const testPostgresBlockStatementRole = ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ - -CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; -GRANT "foo-role" TO "{{name}}"; -ALTER ROLE "{{name}}" SET search_path = foo; -GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; -` - -var testPostgresBlockStatementRoleSlice = []string{ - ` -DO $$ -BEGIN - IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN - CREATE ROLE "foo-role"; - CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; - ALTER ROLE "foo-role" SET search_path = foo; - GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; - GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; - GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; - END IF; -END -$$ -`, - `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, - `GRANT "foo-role" TO "{{name}}";`, - `ALTER ROLE "{{name}}" SET search_path = foo;`, - `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, -} - -const defaultPostgresRevocationSQL = ` -REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; -REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; -REVOKE USAGE ON SCHEMA public FROM "{{name}}"; - -DROP ROLE IF EXISTS "{{name}}"; -` diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index be2038c31c..48d9b88803 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -3,10 +3,8 @@ package database import ( "fmt" "strings" - "time" "github.com/fatih/structs" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -50,16 +48,10 @@ func (b *databaseBackend) pathConnectionReset(req *logical.Request, data *framew return nil, nil } -// pathConfigureBuiltinConnection returns a configured framework.Path setup to -// operate on builtin databases. -func pathConfigureBuiltinConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/%s", b.connectionWriteHandler(dbs.BuiltinFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) -} - // pathConfigurePluginConnection returns a configured framework.Path setup to // operate on plugins. func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { - return buildConfigConnectionPath("dbs/plugin/%s", b.connectionWriteHandler(dbs.PluginFactory), b.connectionReadHandler(), b.connectionDeleteHandler()) + return buildConfigConnectionPath("config/%s", b.connectionWriteHandler(), b.connectionReadHandler(), b.connectionDeleteHandler()) } // buildConfigConnectionPath reutns a configured framework.Path using the passed @@ -74,40 +66,12 @@ func buildConfigConnectionPath(path string, updateOp, readOp, deleteOp framework Description: "Name of this DB type", }, - "connection_type": &framework.FieldSchema{ - Type: framework.TypeString, - Description: "DB type (e.g. postgres)", - }, - "verify_connection": &framework.FieldSchema{ Type: framework.TypeBool, Default: true, Description: `If set, connection_url is verified by actually connecting to the database`, }, - "max_open_connections": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Maximum number of open connections to the database; -a zero uses the default value of two and a -negative value means unlimited`, - }, - - "max_idle_connections": &framework.FieldSchema{ - Type: framework.TypeInt, - Description: `Maximum number of idle connections to the database; -a zero uses the value of max_open_connections -and a negative value disables idle connections. -If larger than max_open_connections it will be -reduced to the same size.`, - }, - - "max_connection_lifetime": &framework.FieldSchema{ - Type: framework.TypeString, - Default: "0s", - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, - }, - "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `Maximum amount of time a connection may be reused; @@ -139,7 +103,7 @@ func (b *databaseBackend) connectionReadHandler() framework.OperationFunc { return nil, nil } - var config dbs.DatabaseConfig + var config DatabaseConfig if err := entry.DecodeJSON(&config); err != nil { return nil, err } @@ -180,40 +144,12 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { // connectionWriteHandler returns a handler function for creating and updating // both builtin and plugin database types. -func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.OperationFunc { +func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - connType := data.Get("connection_type").(string) - if connType == "" { - return logical.ErrorResponse("connection_type not set"), nil - } - maxOpenConns := data.Get("max_open_connections").(int) - if maxOpenConns == 0 { - maxOpenConns = 2 - } - - maxIdleConns := data.Get("max_idle_connections").(int) - if maxIdleConns == 0 { - maxIdleConns = maxOpenConns - } - if maxIdleConns > maxOpenConns { - maxIdleConns = maxOpenConns - } - - maxConnLifetimeRaw := data.Get("max_connection_lifetime").(string) - maxConnLifetime, err := time.ParseDuration(maxConnLifetimeRaw) - if err != nil { - return logical.ErrorResponse(fmt.Sprintf( - "Invalid max_connection_lifetime: %s", err)), nil - } - - config := &dbs.DatabaseConfig{ - DatabaseType: connType, - ConnectionDetails: data.Raw, - MaxOpenConnections: maxOpenConns, - MaxIdleConnections: maxIdleConns, - MaxConnectionLifetime: maxConnLifetime, - PluginName: data.Get("plugin_name").(string), + config := &DatabaseConfig{ + ConnectionDetails: data.Raw, + PluginName: data.Get("plugin_name").(string), } name := data.Get("name").(string) @@ -227,7 +163,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. b.Lock() defer b.Unlock() - db, err := factory(config, b.System(), b.logger) + db, err := PluginFactory(config, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("Error creating database object: %s", err)), nil } diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 6f62c79d98..d099ef1787 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/hashicorp/vault/builtin/logical/database/dbs" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -156,7 +155,7 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F "Invalid max_ttl: %s", err)), nil } - statements := dbs.Statements{ + statements := Statements{ CreationStatements: creationStmts, RevocationStatements: revocationStmts, RollbackStatements: rollbackStmts, @@ -183,10 +182,10 @@ func (b *databaseBackend) pathRoleCreate(req *logical.Request, data *framework.F } type roleEntry struct { - DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` - Statements dbs.Statements `json:"statments" mapstructure:"statements" structs:"statments"` - DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` - MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` + DBName string `json:"db_name" mapstructure:"db_name" structs:"db_name"` + Statements Statements `json:"statments" mapstructure:"statements" structs:"statments"` + DefaultTTL time.Duration `json:"default_ttl" mapstructure:"default_ttl" structs:"default_ttl"` + MaxTTL time.Duration `json:"max_ttl" mapstructure:"max_ttl" structs:"max_ttl"` } const pathRoleHelpSyn = ` diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/plugin.go similarity index 88% rename from builtin/logical/database/dbs/plugin.go rename to builtin/logical/database/plugin.go index 441f97ca0f..5a6a8e3285 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/plugin.go @@ -1,6 +1,7 @@ -package dbs +package database import ( + "errors" "fmt" "net/rpc" "sync" @@ -8,8 +9,47 @@ import ( "github.com/hashicorp/go-plugin" "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" + log "github.com/mgutz/logxi/v1" ) +var ( + ErrEmptyPluginName = errors.New("empty plugin name") +) + +// PluginFactory is used to build plugin database types. It wraps the database +// object in a logging and metrics middleware. +func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { + if conf.PluginName == "" { + return nil, ErrEmptyPluginName + } + + pluginMeta, err := sys.LookupPlugin(conf.PluginName) + if err != nil { + return nil, err + } + + db, err := newPluginClient(sys, pluginMeta) + if err != nil { + return nil, err + } + + // Wrap with metrics middleware + db = &databaseMetricsMiddleware{ + next: db, + typeStr: db.Type(), + } + + // Wrap with tracing middleware + db = &databaseTracingMiddleware{ + next: db, + typeStr: db.Type(), + logger: logger, + } + + return db, nil +} + // handshakeConfigs are used to just do a basic handshake between // a plugin and host. If the handshake fails, a user friendly error is shown. // This prevents users from executing bad plugins or executing a plugin @@ -33,7 +73,7 @@ func (DatabasePlugin) Client(b *plugin.MuxBroker, c *rpc.Client) (interface{}, e } // DatabasePluginClient embeds a databasePluginRPCClient and wraps it's close -// method to also call Close() on the plugin.Client. +// method to also call Kill() on the plugin.Client. type DatabasePluginClient struct { client *plugin.Client sync.Mutex diff --git a/builtin/logical/database/dbs/plugin_test.go b/builtin/logical/database/plugin_test.go similarity index 99% rename from builtin/logical/database/dbs/plugin_test.go rename to builtin/logical/database/plugin_test.go index 60cb6814dd..2ec01c9556 100644 --- a/builtin/logical/database/dbs/plugin_test.go +++ b/builtin/logical/database/plugin_test.go @@ -1,4 +1,4 @@ -package dbs +package database import ( "crypto/sha256" diff --git a/command/plugin-exec.go b/command/plugin-exec.go index f0d6a8d51a..70bc8ae1d4 100644 --- a/command/plugin-exec.go +++ b/command/plugin-exec.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/meta" ) @@ -29,7 +29,7 @@ func (c *PluginExec) Run(args []string) int { pluginName := args[0] - runner, ok := pluginutil.BuiltinPlugins[pluginName] + runner, ok := builtinplugins.BuiltinPlugins[pluginName] if !ok { c.Ui.Error(fmt.Sprintf( "No plugin with the name %s found", pluginName)) diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go new file mode 100644 index 0000000000..6880640d15 --- /dev/null +++ b/helper/builtinplugins/builtin.go @@ -0,0 +1,8 @@ +package builtinplugins + +import "github.com/hashicorp/vault-plugins/database/mysql" + +var BuiltinPlugins = map[string]func() error{ + "mysql-database-plugin": mysql.Run, + // "postgres-database-plugin": postgres.Run, +} diff --git a/helper/pluginutil/builtin.go b/helper/pluginutil/builtin.go deleted file mode 100644 index 6a464bb824..0000000000 --- a/helper/pluginutil/builtin.go +++ /dev/null @@ -1,6 +0,0 @@ -package pluginutil - -var BuiltinPlugins = map[string]func() error{ -// "mysql-database-plugin": mysql.Run, -// "postgres-database-plugin": postgres.Run, -} diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index c6e4e4059b..b9c15db22a 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/hashicorp/vault/helper/builtinplugins" "github.com/hashicorp/vault/helper/jsonutil" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" @@ -53,7 +54,7 @@ func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { } // Look for builtin plugins - if _, ok := pluginutil.BuiltinPlugins[name]; !ok { + if _, ok := builtinplugins.BuiltinPlugins[name]; !ok { return nil, fmt.Errorf("no plugin found with name: %s", name) }