From be50cbae91a3af52f72d2ead17222bb300cfe8ea Mon Sep 17 00:00:00 2001 From: Brian Kassouf Date: Thu, 13 Apr 2017 13:48:32 -0700 Subject: [PATCH] Move plugins into main vault repo --- helper/builtinplugins/builtin.go | 4 +- .../mssql/mssql-database-plugin/main.go | 16 + plugins/database/mssql/mssql.go | 268 ++++++++++++++ plugins/database/mssql/mssql_test.go | 173 +++++++++ .../mysql/mysql-database-plugin/main.go | 16 + plugins/database/mysql/mysql.go | 183 ++++++++++ plugins/database/mysql/mysql_test.go | 200 +++++++++++ .../postgresql-database-plugin/main.go | 16 + plugins/database/postgresql/postgresql.go | 337 ++++++++++++++++++ .../database/postgresql/postgresql_test.go | 308 ++++++++++++++++ plugins/helper/database/connutil/cassandra.go | 172 +++++++++ plugins/helper/database/connutil/connutil.go | 21 ++ plugins/helper/database/connutil/sql.go | 131 +++++++ .../helper/database/credsutil/cassandra.go | 37 ++ .../helper/database/credsutil/credsutil.go | 12 + plugins/helper/database/credsutil/sql.go | 43 +++ plugins/helper/database/dbutil/dbutil.go | 20 ++ 17 files changed, 1955 insertions(+), 2 deletions(-) create mode 100644 plugins/database/mssql/mssql-database-plugin/main.go create mode 100644 plugins/database/mssql/mssql.go create mode 100644 plugins/database/mssql/mssql_test.go create mode 100644 plugins/database/mysql/mysql-database-plugin/main.go create mode 100644 plugins/database/mysql/mysql.go create mode 100644 plugins/database/mysql/mysql_test.go create mode 100644 plugins/database/postgresql/postgresql-database-plugin/main.go create mode 100644 plugins/database/postgresql/postgresql.go create mode 100644 plugins/database/postgresql/postgresql_test.go create mode 100644 plugins/helper/database/connutil/cassandra.go create mode 100644 plugins/helper/database/connutil/connutil.go create mode 100644 plugins/helper/database/connutil/sql.go create mode 100644 plugins/helper/database/credsutil/cassandra.go create mode 100644 plugins/helper/database/credsutil/credsutil.go create mode 100644 plugins/helper/database/credsutil/sql.go create mode 100644 plugins/helper/database/dbutil/dbutil.go diff --git a/helper/builtinplugins/builtin.go b/helper/builtinplugins/builtin.go index 55da9a97f3..beedbb15b8 100644 --- a/helper/builtinplugins/builtin.go +++ b/helper/builtinplugins/builtin.go @@ -1,8 +1,8 @@ package builtinplugins import ( - "github.com/hashicorp/vault-plugins/database/mysql" - "github.com/hashicorp/vault-plugins/database/postgresql" + "github.com/hashicorp/vault/plugins/database/mysql" + "github.com/hashicorp/vault/plugins/database/postgresql" ) var BuiltinPlugins *builtinPlugins = &builtinPlugins{ diff --git a/plugins/database/mssql/mssql-database-plugin/main.go b/plugins/database/mssql/mssql-database-plugin/main.go new file mode 100644 index 0000000000..ead1cf8423 --- /dev/null +++ b/plugins/database/mssql/mssql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/mssql" +) + +func main() { + err := mssql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/mssql/mssql.go b/plugins/database/mssql/mssql.go new file mode 100644 index 0000000000..567a095b66 --- /dev/null +++ b/plugins/database/mssql/mssql.go @@ -0,0 +1,268 @@ +package mssql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const msSQLTypeName = "mssql" + +// MSSQL is an implementation of DatabaseType interface +type MSSQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *MSSQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = msSQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &MSSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MSSQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +// Type returns the TypeName for this backend +func (m *MSSQL) Type() (string, error) { + return msSQLTypeName, nil +} + +func (m *MSSQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +// CreateUser generates the username/password on the underlying MSSQL secret backend as instructed by +// the CreationStatement provided. +func (m *MSSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return "", "", err + } + + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + username, err = m.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = m.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := m.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + } + + return username, password, nil +} + +// RenewUser is not supported on MSSQL, so this is a no-op. +func (m *MSSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // NOOP + return nil +} + +// RevokeUser attempts to drop the specified user. It will first attempt to disable login, +// then kill pending connections from that user, and finally drop the user and login from the +// database instance. +func (m *MSSQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Get connection + db, err := m.getConnection() + if err != nil { + return err + } + + // First disable server login + disableStmt, err := db.Prepare(fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username)) + if err != nil { + return err + } + defer disableStmt.Close() + if _, err := disableStmt.Exec(); err != nil { + return err + } + + // Query for sessions for the login so that we can kill any outstanding + // sessions. There cannot be any active sessions before we drop the logins + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + sessionStmt, err := db.Prepare(fmt.Sprintf( + "SELECT session_id FROM sys.dm_exec_sessions WHERE login_name = '%s';", username)) + if err != nil { + return err + } + defer sessionStmt.Close() + + sessionRows, err := sessionStmt.Query() + if err != nil { + return err + } + defer sessionRows.Close() + + var revokeStmts []string + for sessionRows.Next() { + var sessionID int + err = sessionRows.Scan(&sessionID) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf("KILL %d;", sessionID)) + } + + // Query for database users using undocumented stored procedure for now since + // it is the easiest way to get this information; + // we need to drop the database users before we can drop the login and the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare(fmt.Sprintf("EXEC sp_msloginmappings '%s';", username)) + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query() + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var loginName, dbName, qUsername string + var aliasName sql.NullString + err = rows.Scan(&loginName, &dbName, &qUsername, &aliasName) + if err != nil { + return err + } + revokeStmts = append(revokeStmts, fmt.Sprintf(dropUserSQL, dbName, username, username)) + } + + // we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revokeStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all database users are dropped + if rows.Err() != nil { + return fmt.Errorf("cound not generate sql statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all sql statements: %s", lastStmtError) + } + + // Drop this login + stmt, err = db.Prepare(fmt.Sprintf(dropLoginSQL, username, username)) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +const dropUserSQL = ` +USE [%s] +IF EXISTS + (SELECT name + FROM sys.database_principals + WHERE name = N'%s') +BEGIN + DROP USER [%s] +END +` + +const dropLoginSQL = ` +IF EXISTS + (SELECT name + FROM master.sys.server_principals + WHERE name = N'%s') +BEGIN + DROP LOGIN [%s] +END +` diff --git a/plugins/database/mssql/mssql_test.go b/plugins/database/mssql/mssql_test.go new file mode 100644 index 0000000000..bc182f26fd --- /dev/null +++ b/plugins/database/mssql/mssql_test.go @@ -0,0 +1,173 @@ +package mssql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMSQLImagePull sync.Once +) + +func prepareMSSQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MSSQL_URL") != "" { + return func() {}, os.Getenv("MSSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("microsoft/mssql-server-linux", "latest", []string{"ACCEPT_EULA=Y", "SA_PASSWORD=yourStrong(!)Password"}) + if err != nil { + t.Fatalf("Could not start local MSSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("sqlserver://sa:yourStrong(!)Password@localhost:%s", resource.GetPort("1433/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mssql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MSSQL docker container: %s", err) + } + + return +} + +func TestMSSQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMSSQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testMSSQLRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMSSQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMSSQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testMSSQLRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "sa:yourStrong(!)Password", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("mssql", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testMSSQLRole = ` +CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}'; +CREATE USER [{{name}}] FOR LOGIN [{{name}}]; +GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];` diff --git a/plugins/database/mysql/mysql-database-plugin/main.go b/plugins/database/mysql/mysql-database-plugin/main.go new file mode 100644 index 0000000000..c0ec75c9cd --- /dev/null +++ b/plugins/database/mysql/mysql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/mysql" +) + +func main() { + err := mysql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/mysql/mysql.go b/plugins/database/mysql/mysql.go new file mode 100644 index 0000000000..ea14a6782b --- /dev/null +++ b/plugins/database/mysql/mysql.go @@ -0,0 +1,183 @@ +package mysql + +import ( + "database/sql" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" +) + +const defaultMysqlRevocationStmts = ` + REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; + DROP USER '{{name}}'@'%' +` +const mySQLTypeName = "mysql" + +type MySQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func New() *MySQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = mySQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &MySQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instantiates a MySQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +func (m *MySQL) Type() (string, error) { + return mySQLTypeName, nil +} + +func (m *MySQL) getConnection() (*sql.DB, error) { + db, err := m.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (m *MySQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + // Grab the lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return "", "", err + } + + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + username, err = m.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = m.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := m.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + } + defer tx.Rollback() + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + } + + return username, password, nil +} + +// NOOP +func (m *MySQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + return nil +} + +func (m *MySQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the read lock + m.Lock() + defer m.Unlock() + + // Get the connection + db, err := m.getConnection() + if err != nil { + return err + } + + revocationStmts := statements.RevocationStatements + // Use a default SQL statement for revocation if one cannot be fetched from the role + if revocationStmts == "" { + revocationStmts = defaultMysqlRevocationStmts + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + // This is not a prepared statement because not all commands are supported + // 1295: This command is not supported in the prepared statement protocol yet + // Reference https://mariadb.com/kb/en/mariadb/prepare-statement/ + query = strings.Replace(query, "{{name}}", username, -1) + _, err = tx.Exec(query) + if err != nil { + return err + } + + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/plugins/database/mysql/mysql_test.go b/plugins/database/mysql/mysql_test.go new file mode 100644 index 0000000000..2b1f272918 --- /dev/null +++ b/plugins/database/mysql/mysql_test.go @@ -0,0 +1,200 @@ +package mysql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testMySQLImagePull sync.Once +) + +func prepareMySQLTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("MYSQL_URL") != "" { + return func() {}, os.Getenv("MYSQL_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("mysql", "latest", []string{"MYSQL_ROOT_PASSWORD=secret"}) + if err != nil { + t.Fatalf("Could not start local MySQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("root:secret@(localhost:%s)/mysql?parseTime=true", resource.GetPort("3306/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("mysql", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to MySQL docker container: %s", err) + } + + return +} + +func TestMySQL_Initialize(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestMySQL_CreateUser(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestMySQL_RevokeUser(t *testing.T) { + cleanup, connURL := prepareMySQLTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testMySQLRoleWildCard, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + + statements.CreationStatements = testMySQLRoleWildCard + username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = testMySQLRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "root:secret", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("mysql", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testMySQLRoleWildCard = ` +CREATE USER '{{name}}'@'%' IDENTIFIED BY '{{password}}'; +GRANT SELECT ON *.* TO '{{name}}'@'%'; +` +const testMySQLRevocationSQL = ` +REVOKE ALL PRIVILEGES, GRANT OPTION FROM '{{name}}'@'%'; +DROP USER '{{name}}'@'%'; +` diff --git a/plugins/database/postgresql/postgresql-database-plugin/main.go b/plugins/database/postgresql/postgresql-database-plugin/main.go new file mode 100644 index 0000000000..9b9b813c4c --- /dev/null +++ b/plugins/database/postgresql/postgresql-database-plugin/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "os" + + "github.com/hashicorp/vault/plugins/database/postgresql" +) + +func main() { + err := postgresql.Run() + if err != nil { + fmt.Println(err) + os.Exit(1) + } +} diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go new file mode 100644 index 0000000000..b8449f5498 --- /dev/null +++ b/plugins/database/postgresql/postgresql.go @@ -0,0 +1,337 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/helper/strutil" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + "github.com/hashicorp/vault/plugins/helper/database/credsutil" + "github.com/hashicorp/vault/plugins/helper/database/dbutil" + "github.com/lib/pq" +) + +const postgreSQLTypeName string = "postgres" + +func New() *PostgreSQL { + connProducer := &connutil.SQLConnectionProducer{} + connProducer.Type = postgreSQLTypeName + + credsProducer := &credsutil.SQLCredentialsProducer{ + DisplayNameLen: 4, + UsernameLen: 16, + } + + dbType := &PostgreSQL{ + ConnectionProducer: connProducer, + CredentialsProducer: credsProducer, + } + + return dbType +} + +// Run instatiates a PostgreSQL object, and runs the RPC server for the plugin +func Run() error { + dbType := New() + + dbplugin.NewPluginServer(dbType) + + return nil +} + +type PostgreSQL struct { + connutil.ConnectionProducer + credsutil.CredentialsProducer +} + +func (p *PostgreSQL) Type() (string, error) { + return postgreSQLTypeName, nil +} + +func (p *PostgreSQL) getConnection() (*sql.DB, error) { + db, err := p.Connection() + if err != nil { + return nil, err + } + + return db.(*sql.DB), nil +} + +func (p *PostgreSQL) CreateUser(statements dbplugin.Statements, usernamePrefix string, expiration time.Time) (username string, password string, err error) { + if statements.CreationStatements == "" { + return "", "", dbutil.ErrEmptyCreationStatement + } + + // Grab the lock + p.Lock() + defer p.Unlock() + + username, err = p.GenerateUsername(usernamePrefix) + if err != nil { + return "", "", err + } + + password, err = p.GeneratePassword() + if err != nil { + return "", "", err + } + + expirationStr, err := p.GenerateExpiration(expiration) + if err != nil { + return "", "", err + } + + // Get the connection + db, err := p.getConnection() + if err != nil { + return "", "", err + + } + + // Start a transaction + tx, err := db.Begin() + if err != nil { + return "", "", err + + } + defer func() { + tx.Rollback() + }() + // Return the secret + + // Execute each query + for _, query := range strutil.ParseArbitraryStringSlice(statements.CreationStatements, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + "password": password, + "expiration": expirationStr, + })) + if err != nil { + return "", "", err + + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return "", "", err + + } + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return "", "", err + + } + + return username, password, nil +} + +func (p *PostgreSQL) RenewUser(statements dbplugin.Statements, username string, expiration time.Time) error { + // Grab the lock + p.Lock() + defer p.Unlock() + + db, err := p.getConnection() + if err != nil { + return err + } + + expirationStr, err := p.GenerateExpiration(expiration) + if err != nil { + return err + } + + query := fmt.Sprintf( + "ALTER ROLE %s VALID UNTIL '%s';", + pq.QuoteIdentifier(username), + expirationStr) + + stmt, err := db.Prepare(query) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) RevokeUser(statements dbplugin.Statements, username string) error { + // Grab the lock + p.Lock() + defer p.Unlock() + + if statements.RevocationStatements == "" { + return p.defaultRevokeUser(username) + } + + return p.customRevokeUser(username, statements.RevocationStatements) +} + +func (p *PostgreSQL) customRevokeUser(username, revocationStmts string) error { + db, err := p.getConnection() + if err != nil { + return err + } + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + tx.Rollback() + }() + + for _, query := range strutil.ParseArbitraryStringSlice(revocationStmts, ";") { + query = strings.TrimSpace(query) + if len(query) == 0 { + continue + } + + stmt, err := tx.Prepare(dbutil.QueryHelper(query, map[string]string{ + "name": username, + })) + if err != nil { + return err + } + defer stmt.Close() + + if _, err := stmt.Exec(); err != nil { + return err + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +func (p *PostgreSQL) defaultRevokeUser(username string) error { + db, err := p.getConnection() + if err != nil { + return err + } + + // Check if the role exists + var exists bool + err = db.QueryRow("SELECT exists (SELECT rolname FROM pg_roles WHERE rolname=$1);", username).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return err + } + + if exists == false { + return nil + } + + // Query for permissions; we need to revoke permissions before we can drop + // the role + // This isn't done in a transaction because even if we fail along the way, + // we want to remove as much access as possible + stmt, err := db.Prepare("SELECT DISTINCT table_schema FROM information_schema.role_column_grants WHERE grantee=$1;") + if err != nil { + return err + } + defer stmt.Close() + + rows, err := stmt.Query(username) + if err != nil { + return err + } + defer rows.Close() + + const initialNumRevocations = 16 + revocationStmts := make([]string, 0, initialNumRevocations) + for rows.Next() { + var schema string + err = rows.Scan(&schema) + if err != nil { + // keep going; remove as many permissions as possible right now + continue + } + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE USAGE ON SCHEMA %s FROM %s;`, + pq.QuoteIdentifier(schema), + pq.QuoteIdentifier(username))) + } + + // for good measure, revoke all privileges and usage on schema public + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM %s;`, + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + revocationStmts = append(revocationStmts, fmt.Sprintf( + "REVOKE USAGE ON SCHEMA public FROM %s;", + pq.QuoteIdentifier(username))) + + // get the current database name so we can issue a REVOKE CONNECT for + // this username + var dbname sql.NullString + if err := db.QueryRow("SELECT current_database();").Scan(&dbname); err != nil { + return err + } + + if dbname.Valid { + revocationStmts = append(revocationStmts, fmt.Sprintf( + `REVOKE CONNECT ON DATABASE %s FROM %s;`, + pq.QuoteIdentifier(dbname.String), + pq.QuoteIdentifier(username))) + } + + // again, here, we do not stop on error, as we want to remove as + // many permissions as possible right now + var lastStmtError error + for _, query := range revocationStmts { + stmt, err := db.Prepare(query) + if err != nil { + lastStmtError = err + continue + } + defer stmt.Close() + _, err = stmt.Exec() + if err != nil { + lastStmtError = err + } + } + + // can't drop if not all privileges are revoked + if rows.Err() != nil { + return fmt.Errorf("could not generate revocation statements for all rows: %s", rows.Err()) + } + if lastStmtError != nil { + return fmt.Errorf("could not perform all revocation statements: %s", lastStmtError) + } + + // Drop this user + stmt, err = db.Prepare(fmt.Sprintf( + `DROP ROLE IF EXISTS %s;`, pq.QuoteIdentifier(username))) + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(); err != nil { + return err + } + + return nil +} diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go new file mode 100644 index 0000000000..c7ccc8ee8f --- /dev/null +++ b/plugins/database/postgresql/postgresql_test.go @@ -0,0 +1,308 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" + "github.com/hashicorp/vault/plugins/helper/database/connutil" + dockertest "gopkg.in/ory-am/dockertest.v3" +) + +var ( + testPostgresImagePull sync.Once +) + +func preparePostgresTestContainer(t *testing.T) (cleanup func(), retURL string) { + if os.Getenv("PG_URL") != "" { + return func() {}, os.Getenv("PG_URL") + } + + pool, err := dockertest.NewPool("") + if err != nil { + t.Fatalf("Failed to connect to docker: %s", err) + } + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=database"}) + if err != nil { + t.Fatalf("Could not start local PostgreSQL docker container: %s", err) + } + + cleanup = func() { + err := pool.Purge(resource) + if err != nil { + t.Fatalf("Failed to cleanup local container: %s", err) + } + } + + retURL = fmt.Sprintf("postgres://postgres:secret@localhost:%s/database?sslmode=disable", resource.GetPort("5432/tcp")) + + // exponential backoff-retry + if err = pool.Retry(func() error { + var err error + var db *sql.DB + db, err = sql.Open("postgres", retURL) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + t.Fatalf("Could not connect to PostgreSQL docker container: %s", err) + } + + return +} + +func TestPostgreSQL_Initialize(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + connProducer := db.ConnectionProducer.(*connutil.SQLConnectionProducer) + + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + if !connProducer.Initialized { + t.Fatal("Database should be initalized") + } + + err = db.Close() + if err != nil { + t.Fatalf("err: %s", err) + } +} + +func TestPostgreSQL_CreateUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Test with no configured Creation Statememt + _, _, err = db.CreateUser(dbplugin.Statements{}, "test", time.Now().Add(time.Minute)) + if err == nil { + t.Fatal("Expected error when no creation statement is provided") + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + statements.CreationStatements = testPostgresReadOnlyRole + username, password, err = db.CreateUser(statements, "test", time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestPostgreSQL_RenewUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + err = db.RenewUser(statements, username, time.Now().Add(time.Minute)) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Sleep longer than the inital expiration time + time.Sleep(2 * time.Second) + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } +} + +func TestPostgreSQL_RevokeUser(t *testing.T) { + cleanup, connURL := preparePostgresTestContainer(t) + defer cleanup() + + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + } + + db := New() + err := db.Initialize(connectionDetails, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + statements := dbplugin.Statements{ + CreationStatements: testPostgresRole, + } + + username, password, err := db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test default revoke statememts + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } + + username, password, err = db.CreateUser(statements, "test", time.Now().Add(2*time.Second)) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err = testCredsExist(t, connURL, username, password); err != nil { + t.Fatalf("Could not connect with new credentials: %s", err) + } + + // Test custom revoke statements + statements.RevocationStatements = defaultPostgresRevocationSQL + err = db.RevokeUser(statements, username) + if err != nil { + t.Fatalf("err: %s", err) + } + + if err := testCredsExist(t, connURL, username, password); err == nil { + t.Fatal("Credentials were not revoked") + } +} + +func testCredsExist(t testing.TB, connURL, username, password string) error { + // Log in with the new creds + connURL = strings.Replace(connURL, "postgres:secret", fmt.Sprintf("%s:%s", username, password), 1) + db, err := sql.Open("postgres", connURL) + if err != nil { + return err + } + defer db.Close() + return db.Ping() +} + +const testPostgresRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}"; +` + +const testPostgresReadOnlyRole = ` +CREATE ROLE "{{name}}" WITH + LOGIN + PASSWORD '{{password}}' + VALID UNTIL '{{expiration}}'; +GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}"; +GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}"; +` + +const testPostgresBlockStatementRole = ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ + +CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; +GRANT "foo-role" TO "{{name}}"; +ALTER ROLE "{{name}}" SET search_path = foo; +GRANT CONNECT ON DATABASE "postgres" TO "{{name}}"; +` + +var testPostgresBlockStatementRoleSlice = []string{ + ` +DO $$ +BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_roles WHERE rolname='foo-role') THEN + CREATE ROLE "foo-role"; + CREATE SCHEMA IF NOT EXISTS foo AUTHORIZATION "foo-role"; + ALTER ROLE "foo-role" SET search_path = foo; + GRANT TEMPORARY ON DATABASE "postgres" TO "foo-role"; + GRANT ALL PRIVILEGES ON SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA foo TO "foo-role"; + GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA foo TO "foo-role"; + END IF; +END +$$ +`, + `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}';`, + `GRANT "foo-role" TO "{{name}}";`, + `ALTER ROLE "{{name}}" SET search_path = foo;`, + `GRANT CONNECT ON DATABASE "postgres" TO "{{name}}";`, +} + +const defaultPostgresRevocationSQL = ` +REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA public FROM "{{name}}"; +REVOKE ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public FROM "{{name}}"; +REVOKE USAGE ON SCHEMA public FROM "{{name}}"; + +DROP ROLE IF EXISTS "{{name}}"; +` diff --git a/plugins/helper/database/connutil/cassandra.go b/plugins/helper/database/connutil/cassandra.go new file mode 100644 index 0000000000..305bc6e3d0 --- /dev/null +++ b/plugins/helper/database/connutil/cassandra.go @@ -0,0 +1,172 @@ +package connutil + +import ( + "crypto/tls" + "fmt" + "strings" + "sync" + "time" + + "github.com/mitchellh/mapstructure" + + "github.com/gocql/gocql" + "github.com/hashicorp/vault/helper/certutil" + "github.com/hashicorp/vault/helper/tlsutil" +) + +// CassandraConnectionProducer implements ConnectionProducer and provides an +// interface for cassandra databases to make connections. +type CassandraConnectionProducer struct { + Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"` + Username string `json:"username" structs:"username" mapstructure:"username"` + Password string `json:"password" structs:"password" mapstructure:"password"` + TLS bool `json:"tls" structs:"tls" mapstructure:"tls"` + InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"` + Certificate string `json:"certificate" structs:"certificate" mapstructure:"certificate"` + PrivateKey string `json:"private_key" structs:"private_key" mapstructure:"private_key"` + IssuingCA string `json:"issuing_ca" structs:"issuing_ca" mapstructure:"issuing_ca"` + ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"` + ConnectTimeout int `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"` + TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"` + Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"` + + Initialized bool + session *gocql.Session + sync.Mutex +} + +func (c *CassandraConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + c.Initialized = true + + if verifyConnection { + if _, err := c.connection(); err != nil { + return fmt.Errorf("error Initalizing Connection: %s", err) + } + } + return nil +} + +func (c *CassandraConnectionProducer) connection() (interface{}, error) { + if !c.Initialized { + return nil, errNotInitialized + } + + // If we already have a DB, return it + if c.session != nil { + return c.session, nil + } + + session, err := c.createSession() + if err != nil { + return nil, err + } + + // Store the session in backend for reuse + c.session = session + + return session, nil +} + +func (c *CassandraConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.session != nil { + c.session.Close() + } + + c.session = nil + + return nil +} + +func (c *CassandraConnectionProducer) createSession() (*gocql.Session, error) { + clusterConfig := gocql.NewCluster(strings.Split(c.Hosts, ",")...) + clusterConfig.Authenticator = gocql.PasswordAuthenticator{ + Username: c.Username, + Password: c.Password, + } + + clusterConfig.ProtoVersion = c.ProtocolVersion + if clusterConfig.ProtoVersion == 0 { + clusterConfig.ProtoVersion = 2 + } + + clusterConfig.Timeout = time.Duration(c.ConnectTimeout) * time.Second + + if c.TLS { + var tlsConfig *tls.Config + if len(c.Certificate) > 0 || len(c.IssuingCA) > 0 { + if len(c.Certificate) > 0 && len(c.PrivateKey) == 0 { + return nil, fmt.Errorf("Found certificate for TLS authentication but no private key") + } + + certBundle := &certutil.CertBundle{} + if len(c.Certificate) > 0 { + certBundle.Certificate = c.Certificate + certBundle.PrivateKey = c.PrivateKey + } + if len(c.IssuingCA) > 0 { + certBundle.IssuingCA = c.IssuingCA + } + + parsedCertBundle, err := certBundle.ToParsedCertBundle() + if err != nil { + return nil, fmt.Errorf("failed to parse certificate bundle: %s", err) + } + + tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient) + if err != nil || tlsConfig == nil { + return nil, fmt.Errorf("failed to get TLS configuration: tlsConfig:%#v err:%v", tlsConfig, err) + } + tlsConfig.InsecureSkipVerify = c.InsecureTLS + + if c.TLSMinVersion != "" { + var ok bool + tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion] + if !ok { + return nil, fmt.Errorf("invalid 'tls_min_version' in config") + } + } else { + // MinVersion was not being set earlier. Reset it to + // zero to gracefully handle upgrades. + tlsConfig.MinVersion = 0 + } + } + + clusterConfig.SslOpts = &gocql.SslOptions{ + Config: *tlsConfig, + } + } + + session, err := clusterConfig.CreateSession() + if err != nil { + return nil, fmt.Errorf("error creating session: %s", err) + } + + // Set consistency + if c.Consistency != "" { + consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency) + if err != nil { + return nil, err + } + + session.SetConsistency(consistencyValue) + } + + // Verify the info + err = session.Query(`LIST USERS`).Exec() + if err != nil { + return nil, fmt.Errorf("error validating connection info: %s", err) + } + + return session, nil +} diff --git a/plugins/helper/database/connutil/connutil.go b/plugins/helper/database/connutil/connutil.go new file mode 100644 index 0000000000..6de3299e38 --- /dev/null +++ b/plugins/helper/database/connutil/connutil.go @@ -0,0 +1,21 @@ +package connutil + +import ( + "errors" + "sync" +) + +var ( + errNotInitialized = errors.New("connection has not been initalized") +) + +// ConnectionProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods dealing with individual database +// connections and is used in all the builtin database types. +type ConnectionProducer interface { + Close() error + Initialize(map[string]interface{}, bool) error + Connection() (interface{}, error) + + sync.Locker +} diff --git a/plugins/helper/database/connutil/sql.go b/plugins/helper/database/connutil/sql.go new file mode 100644 index 0000000000..0bfc5f9f68 --- /dev/null +++ b/plugins/helper/database/connutil/sql.go @@ -0,0 +1,131 @@ +package connutil + +import ( + "database/sql" + "fmt" + "strings" + "sync" + "time" + + // Import sql drivers + _ "github.com/denisenkom/go-mssqldb" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "github.com/mitchellh/mapstructure" +) + +// SQLConnectionProducer implements ConnectionProducer and provides a generic producer for most sql databases +type SQLConnectionProducer struct { + ConnectionURL string `json:"connection_url" structs:"connection_url" mapstructure:"connection_url"` + MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` + MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` + MaxConnectionLifetimeRaw string `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` + + Type string + MaxConnectionLifetime time.Duration + Initialized bool + db *sql.DB + sync.Mutex +} + +func (c *SQLConnectionProducer) Initialize(conf map[string]interface{}, verifyConnection bool) error { + c.Lock() + defer c.Unlock() + + err := mapstructure.Decode(conf, c) + if err != nil { + return err + } + + if c.MaxOpenConnections == 0 { + c.MaxOpenConnections = 2 + } + + if c.MaxIdleConnections == 0 { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxIdleConnections > c.MaxOpenConnections { + c.MaxIdleConnections = c.MaxOpenConnections + } + if c.MaxConnectionLifetimeRaw == "" { + c.MaxConnectionLifetimeRaw = "0s" + } + + c.MaxConnectionLifetime, err = time.ParseDuration(c.MaxConnectionLifetimeRaw) + if err != nil { + return fmt.Errorf("invalid max_connection_lifetime: %s", err) + } + + if verifyConnection { + if _, err := c.Connection(); err != nil { + return fmt.Errorf("error initalizing connection: %s", err) + } + + if err := c.db.Ping(); err != nil { + return fmt.Errorf("error initalizing connection: %s", err) + } + } + + c.Initialized = true + + return nil +} + +func (c *SQLConnectionProducer) Connection() (interface{}, error) { + // If we already have a DB, test it and return + if c.db != nil { + if err := c.db.Ping(); err == nil { + return c.db, nil + } + // If the ping was unsuccessful, close it and ignore errors as we'll be + // reestablishing anyways + c.db.Close() + } + + // For mssql backend, switch to sqlserver instead + dbType := c.Type + if c.Type == "mssql" { + dbType = "sqlserver" + } + + // Otherwise, attempt to make connection + conn := c.ConnectionURL + + // Ensure timezone is set to UTC for all the conenctions + if strings.HasPrefix(conn, "postgres://") || strings.HasPrefix(conn, "postgresql://") { + if strings.Contains(conn, "?") { + conn += "&timezone=utc" + } else { + conn += "?timezone=utc" + } + } + + var err error + c.db, err = sql.Open(dbType, conn) + if err != nil { + return nil, err + } + + // Set some connection pool settings. We don't need much of this, + // since the request rate shouldn't be high. + c.db.SetMaxOpenConns(c.MaxOpenConnections) + c.db.SetMaxIdleConns(c.MaxIdleConnections) + c.db.SetConnMaxLifetime(c.MaxConnectionLifetime) + + return c.db, nil +} + +// Close attempts to close the connection +func (c *SQLConnectionProducer) Close() error { + // Grab the write lock + c.Lock() + defer c.Unlock() + + if c.db != nil { + c.db.Close() + } + + c.db = nil + + return nil +} diff --git a/plugins/helper/database/credsutil/cassandra.go b/plugins/helper/database/credsutil/cassandra.go new file mode 100644 index 0000000000..7ab5630b58 --- /dev/null +++ b/plugins/helper/database/credsutil/cassandra.go @@ -0,0 +1,37 @@ +package credsutil + +import ( + "fmt" + "strings" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +// CassandraCredentialsProducer implements CredentialsProducer and provides an +// interface for cassandra databases to generate user information. +type CassandraCredentialsProducer struct{} + +func (ccp *CassandraCredentialsProducer) GenerateUsername(displayName string) (string, error) { + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("vault_%s_%s_%d", displayName, userUUID, time.Now().Unix()) + username = strings.Replace(username, "-", "_", -1) + + return username, nil +} + +func (ccp *CassandraCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (ccp *CassandraCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) { + return "", nil +} diff --git a/plugins/helper/database/credsutil/credsutil.go b/plugins/helper/database/credsutil/credsutil.go new file mode 100644 index 0000000000..7f388a0f76 --- /dev/null +++ b/plugins/helper/database/credsutil/credsutil.go @@ -0,0 +1,12 @@ +package credsutil + +import "time" + +// CredentialsProducer can be used as an embeded interface in the DatabaseType +// definition. It implements the methods for generating user information for a +// particular database type and is used in all the builtin database types. +type CredentialsProducer interface { + GenerateUsername(displayName string) (string, error) + GeneratePassword() (string, error) + GenerateExpiration(ttl time.Time) (string, error) +} diff --git a/plugins/helper/database/credsutil/sql.go b/plugins/helper/database/credsutil/sql.go new file mode 100644 index 0000000000..23e98102f3 --- /dev/null +++ b/plugins/helper/database/credsutil/sql.go @@ -0,0 +1,43 @@ +package credsutil + +import ( + "fmt" + "time" + + uuid "github.com/hashicorp/go-uuid" +) + +// SQLCredentialsProducer implements CredentialsProducer and provides a generic credentials producer for most sql database types. +type SQLCredentialsProducer struct { + DisplayNameLen int + UsernameLen int +} + +func (scp *SQLCredentialsProducer) GenerateUsername(displayName string) (string, error) { + if scp.DisplayNameLen > 0 && len(displayName) > scp.DisplayNameLen { + displayName = displayName[:scp.DisplayNameLen] + } + userUUID, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + username := fmt.Sprintf("%s-%s", displayName, userUUID) + if scp.UsernameLen > 0 && len(username) > scp.UsernameLen { + username = username[:scp.UsernameLen] + } + + return username, nil +} + +func (scp *SQLCredentialsProducer) GeneratePassword() (string, error) { + password, err := uuid.GenerateUUID() + if err != nil { + return "", err + } + + return password, nil +} + +func (scp *SQLCredentialsProducer) GenerateExpiration(ttl time.Time) (string, error) { + return ttl.Format("2006-01-02 15:04:05-0700"), nil +} diff --git a/plugins/helper/database/dbutil/dbutil.go b/plugins/helper/database/dbutil/dbutil.go new file mode 100644 index 0000000000..e80273b7fb --- /dev/null +++ b/plugins/helper/database/dbutil/dbutil.go @@ -0,0 +1,20 @@ +package dbutil + +import ( + "errors" + "fmt" + "strings" +) + +var ( + ErrEmptyCreationStatement = errors.New("empty creation statements") +) + +// Query templates a query for us. +func QueryHelper(tpl string, data map[string]string) string { + for k, v := range data { + tpl = strings.Replace(tpl, fmt.Sprintf("{{%s}}", k), v, -1) + } + + return tpl +}