From 9c3881442e1b438b9a100e277e125732a2054ba6 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Mon, 8 Jun 2015 16:17:44 +0545 Subject: [PATCH 01/21] Physical MySQL backend implementation - First Cut --- physical/mysql.go | 190 +++++++++++++++++++++++++++++++++++++++++ physical/mysql_test.go | 86 +++++++++++++++++++ physical/physical.go | 1 + 3 files changed, 277 insertions(+) create mode 100644 physical/mysql.go create mode 100644 physical/mysql_test.go diff --git a/physical/mysql.go b/physical/mysql.go new file mode 100644 index 0000000000..377ac9dc49 --- /dev/null +++ b/physical/mysql.go @@ -0,0 +1,190 @@ +package physical + +import ( + "database/sql" + "errors" + "sort" + "strings" + "time" + + "github.com/armon/go-metrics" + _ "github.com/go-sql-driver/mysql" +) + +var ( + MySQLDBNameMissing = errors.New("database name is missing in the configuration") + MySQLTableNameMissing = errors.New("table name is missing in the configuration") + MySQLHandlerCreationFailure = errors.New("failed to open handler with database") + MySQLPrepareStmtFailure = errors.New("failed to prepare statement") + MySQLExecuteStmtFailure = errors.New("failed to execute statement") + MySQLGetColumnsFailure = errors.New("failed to get columns") + MySQLScanRowsFailure = errors.New("failed to scan rows") +) + +// MySQLBackend is a physical backend that stores data +// within MySQL database. +type MySQLBackend struct { + table string + database string + client *sql.DB +} + +// newMySQLBackend constructs a MySQL backend using the given API client and +// server address and credential for accessing mysql database. +func newMySQLBackend(conf map[string]string) (Backend, error) { + // Get or set MySQL server address. Defaults to localhost and default port(3306) + address, ok := conf["address"] + if !ok { + address = "127.0.0.1:3306" + } + + // Get the MySQL credentials to perform read/write operations. + username, ok := conf["username"] + password, ok := conf["password"] + + // Get the MySQL database and table details. + database, ok := conf["database"] + if !ok { + return nil, MySQLDBNameMissing + } + table, ok := conf["table"] + if !ok { + return nil, MySQLTableNameMissing + } + + // Create MySQL handle for the database. + dsn := username + ":" + password + "@tcp(" + address + ")/" + database + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, MySQLHandlerCreationFailure + } + defer db.Close() + + // Create the required table. + create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))" + stmt, err := db.Prepare(create_stmt) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + defer stmt.Close() + + _, err = stmt.Exec() + if err != nil { + return nil, MySQLExecuteStmtFailure + } + + // Setup the backend. + m := &MySQLBackend{ + client: db, + table: table, + database: database, + } + + return m, nil +} + +// Put is used to insert or update an entry. +func (m *MySQLBackend) Put(entry *Entry) error { + defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) + + insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" + stmt, err := m.client.Prepare(insert_stmt) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(entry.Key, entry.Value) + if err != nil { + return err + } + + return nil +} + +// Get is used to fetch and entry. +func (m *MySQLBackend) Get(key string) (*Entry, error) { + defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) + + select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?" + stmt, err := m.client.Prepare(select_stmt) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + defer stmt.Close() + + var result []byte + + err = stmt.QueryRow(key).Scan(&result) + if err != nil { + return nil, MySQLExecuteStmtFailure + } + + ent := &Entry{ + Key: key, + Value: result, + } + + return ent, nil +} + +// Delete is used to permanently delete an entry +func (m *MySQLBackend) Delete(key string) error { + defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) + + delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?" + stmt, err := m.client.Prepare(delete_stmt) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(key) + if err != nil { + return err + } + + return nil +} + +// List is used to list all the keys under a given +// prefix, up to the next prefix. +func (m *MySQLBackend) List(prefix string) ([]string, error) { + defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) + + list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table + rows, err := m.client.Query(list_stmt) + if err != nil { + return nil, MySQLExecuteStmtFailure + } + + columns, err := rows.Columns() + if err != nil { + return nil, MySQLGetColumnsFailure + } + + values := make([]sql.RawBytes, len(columns)) + + scanArgs := make([]interface{}, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + keys := []string{} + for rows.Next() { + err = rows.Scan(scanArgs...) + if err != nil { + return nil, MySQLScanRowsFailure + } + + for _, col := range values { + if strings.HasPrefix(string(col), prefix) { + keys = append(keys, string(col)) + } + } + } + + sort.Strings(keys) + + return keys, nil +} diff --git a/physical/mysql_test.go b/physical/mysql_test.go new file mode 100644 index 0000000000..b808a29bdb --- /dev/null +++ b/physical/mysql_test.go @@ -0,0 +1,86 @@ +package physical + +import ( + "database/sql" + "fmt" + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" +) + +func TestMySQLBackend(t *testing.T) { + address := os.Getenv("MYSQL_ADDR") + if address == "" { + t.SkipNow() + } + + database := os.Getenv("MYSQL_DB") + if database == "" { + database = "test" + } + + table := os.Getenv("MYSQL_TABLE") + if table == "" { + table = "test" + } + + username := os.Getenv("MYSQL_USERNAME") + password := os.Getenv("MYSQL_PASSWORD") + + // Create MySQL handle for the database. + db, err := sql.Open("mysql", username+":"+password+"@tcp("+address+")/"+database) + + if err != nil { + t.Fatalf("Failed to open an handler with database: %v", err) + } + defer db.Close() + + // Prepare statement for creating table. + create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + stmtCrt, err := db.Prepare(create_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmtCrt.Close() + + // Create table + _, err = stmtCrt.Exec() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + // Prepare statement for inserting data. + insert_stmt := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" + stmtIns, err := db.Prepare(insert_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmtIns.Close() + + // Prepare statement for reading data. + select_stmt := "SELECT sqr FROM " + database + "." + table + " WHERE num = ?" + stmtOut, err := db.Prepare(select_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmtOut.Close() + + // Insert square numbers for 0-24 in the database + for i := 0; i < 25; i++ { + _, err = stmtIns.Exec(i, (i * i)) // Insert tuples (i, i^2) + if err != nil { + t.Fatalf("Failed to insert data: %v", err) + } + } + + var square int + + // Query the square-number of 13 + err = stmtOut.QueryRow(13).Scan(&square) + if err != nil { + t.Fatalf("Failed to query data: %v", err) + } + fmt.Printf("The square number of 13 is: %d", square) + +} diff --git a/physical/physical.go b/physical/physical.go index b992c10102..bd307e8dea 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -84,4 +84,5 @@ var BuiltinBackends = map[string]Factory{ "file": newFileBackend, "s3": newS3Backend, "etcd": newEtcdBackend, + "mysql": newMySQLBackend, } From c42bc38c62adb6bd671d39dafbfe127934ec9c67 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 11:41:25 +0545 Subject: [PATCH 02/21] Changes done as per feedback --- physical/mysql.go | 110 +++++++++++++++++++++-------------------- physical/mysql_test.go | 2 +- 2 files changed, 58 insertions(+), 54 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index 377ac9dc49..ed53a5350d 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -3,8 +3,8 @@ package physical import ( "database/sql" "errors" + "fmt" "sort" - "strings" "time" "github.com/armon/go-metrics" @@ -12,21 +12,17 @@ import ( ) var ( - MySQLDBNameMissing = errors.New("database name is missing in the configuration") - MySQLTableNameMissing = errors.New("table name is missing in the configuration") - MySQLHandlerCreationFailure = errors.New("failed to open handler with database") - MySQLPrepareStmtFailure = errors.New("failed to prepare statement") - MySQLExecuteStmtFailure = errors.New("failed to execute statement") - MySQLGetColumnsFailure = errors.New("failed to get columns") - MySQLScanRowsFailure = errors.New("failed to scan rows") + MySQLPrepareStmtFailure = errors.New("failed to prepare statement") + MySQLExecuteStmtFailure = errors.New("failed to execute statement") ) // MySQLBackend is a physical backend that stores data // within MySQL database. type MySQLBackend struct { - table string - database string - client *sql.DB + table string + database string + client *sql.DB + statements map[string]*sql.Stmt } // newMySQLBackend constructs a MySQL backend using the given API client and @@ -45,39 +41,70 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { // Get the MySQL database and table details. database, ok := conf["database"] if !ok { - return nil, MySQLDBNameMissing + return nil, fmt.Errorf("database name is missing in the configuration") } table, ok := conf["table"] if !ok { - return nil, MySQLTableNameMissing + return nil, fmt.Errorf("table name is missing in the configuration") } // Create MySQL handle for the database. dsn := username + ":" + password + "@tcp(" + address + ")/" + database db, err := sql.Open("mysql", dsn) if err != nil { - return nil, MySQLHandlerCreationFailure + return nil, fmt.Errorf("failed to open handler with database") } defer db.Close() - // Create the required table. - create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))" - stmt, err := db.Prepare(create_stmt) + // Create the required table if it doesn't exists. + create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" + create_stmt, err := db.Prepare(create_query) if err != nil { return nil, MySQLPrepareStmtFailure } - defer stmt.Close() + defer create_stmt.Close() - _, err = stmt.Exec() + _, err = create_stmt.Exec() if err != nil { return nil, MySQLExecuteStmtFailure } + // Map of query type as key to prepared statement. + var statements map[string]*sql.Stmt + + // Prepare statement for put query. + insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" + insert_stmt, err := db.Prepare(insert_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["put"] = insert_stmt + defer insert_stmt.Close() + + // Prepare statement for select query. + select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" + select_stmt, err := db.Prepare(select_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["get"] = select_stmt + defer select_stmt.Close() + + // Prepare statement for delete query. + delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" + delete_stmt, err := db.Prepare(delete_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["delete"] = delete_stmt + defer delete_stmt.Close() + // Setup the backend. m := &MySQLBackend{ - client: db, - table: table, - database: database, + client: db, + table: table, + database: database, + statements: statements, } return m, nil @@ -87,14 +114,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { func (m *MySQLBackend) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) - insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" - stmt, err := m.client.Prepare(insert_stmt) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(entry.Key, entry.Value) + _, err := m.statements["put"].Exec(entry.Key, entry.Value) if err != nil { return err } @@ -106,16 +126,9 @@ func (m *MySQLBackend) Put(entry *Entry) error { func (m *MySQLBackend) Get(key string) (*Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) - select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?" - stmt, err := m.client.Prepare(select_stmt) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - defer stmt.Close() - var result []byte - err = stmt.QueryRow(key).Scan(&result) + err := m.statements["get"].QueryRow(key).Scan(&result) if err != nil { return nil, MySQLExecuteStmtFailure } @@ -132,14 +145,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { func (m *MySQLBackend) Delete(key string) error { defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) - delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?" - stmt, err := m.client.Prepare(delete_stmt) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(key) + _, err := m.statements["delete"].Exec(key) if err != nil { return err } @@ -152,15 +158,15 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) - list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table - rows, err := m.client.Query(list_stmt) + list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" + rows, err := m.client.Query(list_query) if err != nil { return nil, MySQLExecuteStmtFailure } columns, err := rows.Columns() if err != nil { - return nil, MySQLGetColumnsFailure + return nil, fmt.Errorf("failed to get columns") } values := make([]sql.RawBytes, len(columns)) @@ -174,13 +180,11 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) { for rows.Next() { err = rows.Scan(scanArgs...) if err != nil { - return nil, MySQLScanRowsFailure + return nil, fmt.Errorf("failed to scan rows") } for _, col := range values { - if strings.HasPrefix(string(col), prefix) { - keys = append(keys, string(col)) - } + keys = append(keys, string(col)) } } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index b808a29bdb..c1c78c9fc7 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) { defer db.Close() // Prepare statement for creating table. - create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" stmtCrt, err := db.Prepare(create_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) From b33d707b693b9a1b858d75103feec12c4afa2440 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 15:32:45 +0545 Subject: [PATCH 03/21] Added the test as per suggestion --- physical/mysql.go | 3 ++- physical/mysql_test.go | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/physical/mysql.go b/physical/mysql.go index ed53a5350d..50d4112bc2 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -70,7 +70,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { } // Map of query type as key to prepared statement. - var statements map[string]*sql.Stmt + statements := make(map[string]*sql.Stmt) // Prepare statement for put query. insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" @@ -158,6 +158,7 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) + // Query to get all keys matching a prefix. list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" rows, err := m.client.Query(list_query) if err != nil { diff --git a/physical/mysql_test.go b/physical/mysql_test.go index c1c78c9fc7..c8b4fd470c 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -83,4 +83,18 @@ func TestMySQLBackend(t *testing.T) { } fmt.Printf("The square number of 13 is: %d", square) + b, err := NewBackend("mysql", map[string]string{ + "address": address, + "database": database, + "table": table, + "username": username, + "password": password, + }) + + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + + testBackend(t, b) + testBackend_ListPrefix(t, b) } From 3ff10a757330a05594576f0d72e07f443c2ab15a Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 23:16:46 +0545 Subject: [PATCH 04/21] Fixing List command behaviour --- physical/mysql.go | 21 +++++++++++++++------ physical/mysql_test.go | 7 ++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index 50d4112bc2..8666c249d3 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sort" + "strings" "time" "github.com/armon/go-metrics" @@ -54,7 +55,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { if err != nil { return nil, fmt.Errorf("failed to open handler with database") } - defer db.Close() // Create the required table if it doesn't exists. create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" @@ -79,7 +79,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { return nil, MySQLPrepareStmtFailure } statements["put"] = insert_stmt - defer insert_stmt.Close() // Prepare statement for select query. select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" @@ -88,7 +87,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { return nil, MySQLPrepareStmtFailure } statements["get"] = select_stmt - defer select_stmt.Close() // Prepare statement for delete query. delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" @@ -97,7 +95,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { return nil, MySQLPrepareStmtFailure } statements["delete"] = delete_stmt - defer delete_stmt.Close() // Setup the backend. m := &MySQLBackend{ @@ -133,6 +130,11 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { return nil, MySQLExecuteStmtFailure } + // Handle a non-existing value + if result == nil { + return nil, nil + } + ent := &Entry{ Key: key, Value: result, @@ -184,8 +186,15 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) { return nil, fmt.Errorf("failed to scan rows") } - for _, col := range values { - keys = append(keys, string(col)) + for _, key := range values { + key := strings.TrimPrefix(string(key), prefix) + if i := strings.Index(string(key), "/"); i == -1 { + // Add objects only from the current 'folder' + keys = append(keys, string(key)) + } else if i != -1 { + // Add truncated 'folder' paths + keys = appendIfMissing(keys, string(key[:i+1])) + } } } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index c8b4fd470c..9cbb11a5db 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) { defer db.Close() // Prepare statement for creating table. - create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + create_stmt := "CREATE TABLE IF NOT EXISTS test.square (num int, sqr int, PRIMARY KEY (num))" stmtCrt, err := db.Prepare(create_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) @@ -51,7 +51,7 @@ func TestMySQLBackend(t *testing.T) { } // Prepare statement for inserting data. - insert_stmt := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" + insert_stmt := "INSERT INTO test.square VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" stmtIns, err := db.Prepare(insert_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) @@ -59,7 +59,7 @@ func TestMySQLBackend(t *testing.T) { defer stmtIns.Close() // Prepare statement for reading data. - select_stmt := "SELECT sqr FROM " + database + "." + table + " WHERE num = ?" + select_stmt := "SELECT sqr FROM test.square WHERE num = ?" stmtOut, err := db.Prepare(select_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) @@ -83,6 +83,7 @@ func TestMySQLBackend(t *testing.T) { } fmt.Printf("The square number of 13 is: %d", square) + // Run vault tests b, err := NewBackend("mysql", map[string]string{ "address": address, "database": database, From 7c7f64fe6728bd3960bdf6e5870640c2ba32c034 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Sat, 13 Jun 2015 08:04:40 +0545 Subject: [PATCH 05/21] Fixed a failing test and drop table after running tests --- physical/mysql.go | 2 +- physical/mysql_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/physical/mysql.go b/physical/mysql.go index 8666c249d3..bb4a89c5e5 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -127,7 +127,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { err := m.statements["get"].QueryRow(key).Scan(&result) if err != nil { - return nil, MySQLExecuteStmtFailure + return nil, nil } // Handle a non-existing value diff --git a/physical/mysql_test.go b/physical/mysql_test.go index 9cbb11a5db..2b00404ffb 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -98,4 +98,17 @@ func TestMySQLBackend(t *testing.T) { testBackend(t, b) testBackend_ListPrefix(t, b) + + // Drop table after running tests + drop_stmt := "DROP TABLE " + database + "." + table + stmt, err := db.Prepare(drop_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmt.Close() + + _, err = stmt.Exec() + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } } From c60889572ea7232775ff53d979904e0a7646cb6a Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:23:59 -0700 Subject: [PATCH 06/21] vault: support core shutdown --- vault/core.go | 23 +++++++++++++++++++++++ vault/core_test.go | 11 +++++++++++ 2 files changed, 34 insertions(+) diff --git a/vault/core.go b/vault/core.go index 9791412c08..3c67f63fad 100644 --- a/vault/core.go +++ b/vault/core.go @@ -328,6 +328,21 @@ func NewCore(conf *CoreConfig) (*Core, error) { return c, nil } +// Shutdown is invoked when the Vault instance is about to be terminated. It +// should not be accessible as part of an API call as it will cause an availability +// problem. It is only used to gracefully quit in the case of HA so that failover +// happens as quickly as possible. +func (c *Core) Shutdown() error { + c.stateLock.Lock() + defer c.stateLock.Unlock() + if c.sealed { + return nil + } + + // Seal the Vault, causes a leader stepdown + return c.sealInternal() +} + // HandleRequest is used to handle a new incoming request func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) { c.stateLock.RLock() @@ -930,6 +945,14 @@ func (c *Core) Seal(token string) error { return err } + // Seal the Vault + return c.sealInternal() +} + +// sealInternal is an internal method used to seal the vault. +// It does not do any authorization checking. The stateLock must +// be held prior to calling. +func (c *Core) sealInternal() error { // Enable that we are sealed to prevent furthur transactions c.sealed = true diff --git a/vault/core_test.go b/vault/core_test.go index 25cebf1943..4b48c59acd 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -348,6 +348,17 @@ func TestCore_SealUnseal(t *testing.T) { } } +// Attempt to shutdown after unseal +func TestCore_Shutdown(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + if err := c.Shutdown(); err != nil { + t.Fatalf("err: %v", err) + } + if sealed, err := c.Sealed(); err != nil || !sealed { + t.Fatalf("err: %v", err) + } +} + // Attempt to seal bad token func TestCore_Seal_BadToken(t *testing.T) { c, _, _ := TestCoreUnsealed(t) From 70ee1866ca35f60a209fd88a48f18b02b040cc2c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:24:56 -0700 Subject: [PATCH 07/21] server: graceful shutdown for fast failover. Fixes #308 --- cli/commands.go | 20 ++++++++++++++++++++ command/server.go | 14 +++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/cli/commands.go b/cli/commands.go index 1a7916c742..074b80dd11 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -2,6 +2,8 @@ package cli import ( "os" + "os/signal" + "syscall" auditFile "github.com/hashicorp/vault/builtin/audit/file" auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" @@ -68,6 +70,7 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { "transit": transit.Factory, "mysql": mysql.Factory, }, + ShutdownCh: makeShutdownCh(), }, nil }, @@ -268,3 +271,20 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { }, } } + +// makeShutdownCh returns a channel that can be used for shutdown +// notifications for commands. This channel will send a message for every +// interrupt or SIGTERM received. +func makeShutdownCh() <-chan struct{} { + resultCh := make(chan struct{}) + + signalCh := make(chan os.Signal, 4) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + go func() { + for { + <-signalCh + resultCh <- struct{}{} + } + }() + return resultCh +} diff --git a/command/server.go b/command/server.go index 2796b22cda..ed842a26f8 100644 --- a/command/server.go +++ b/command/server.go @@ -32,6 +32,7 @@ type ServerCommand struct { CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory + ShutdownCh <-chan struct{} Meta } @@ -237,7 +238,14 @@ func (c *ServerCommand) Run(args []string) int { // Release the log gate. logGate.Flush() - <-make(chan struct{}) + // Wait for shutdown + select { + case <-c.ShutdownCh: + c.Ui.Output("==> Vault shutdown triggered") + if err := core.Shutdown(); err != nil { + c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err)) + } + } return 0 } @@ -407,8 +415,8 @@ General Options: specified multiple times. If it is a directory, all files with a ".hcl" or ".json" suffix will be loaded. - -dev Enables Dev mode. In this mode, Vault is completely - in-memory and unsealed. Do not run the Dev server in + -dev Enables Dev mode. In this mode, Vault is completely + in-memory and unsealed. Do not run the Dev server in production! -log-level=info Log verbosity. Defaults to "info", will be outputted From 0277cedc8acae3e892d654750d8bc1940c1560d4 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:33:15 -0700 Subject: [PATCH 08/21] cmomand/read: strip path prefix if necessary. Fixes #343 --- command/read.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/command/read.go b/command/read.go index 8f94bcba8f..85d23b73e3 100644 --- a/command/read.go +++ b/command/read.go @@ -27,7 +27,11 @@ func (c *ReadCommand) Run(args []string) int { flags.Usage() return 1 } + path := args[0] + if path[0] == '/' { + path = path[1:] + } client, err := c.Client() if err != nil { @@ -98,7 +102,7 @@ Read Options: -format=table The format for output. By default it is a whitespace- delimited table. This can also be json. - -field=field If included, the raw value of the specified field + -field=field If included, the raw value of the specified field will be output raw to stdout. ` From 9238c6def36195a99acf210a1b60b699a6099d92 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:42:23 -0700 Subject: [PATCH 09/21] secret/transit: Use special endpoint to get underlying keys. Fixes #219 --- builtin/logical/transit/backend.go | 2 + builtin/logical/transit/backend_test.go | 39 ++++++++++++++++++ builtin/logical/transit/path_keys.go | 1 - builtin/logical/transit/path_raw.go | 54 +++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 builtin/logical/transit/path_raw.go diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index b63f903b13..d719d42c64 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -15,11 +15,13 @@ func Backend() *framework.Backend { PathsSpecial: &logical.Paths{ Root: []string{ "keys/*", + "raw/*", }, }, Paths: []*framework.Path{ pathKeys(), + pathRaw(), pathEncrypt(), pathDecrypt(), }, diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index c4c7c3d435..cb2a5e1872 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -21,10 +21,12 @@ func TestBackend_basic(t *testing.T) { Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test"), testAccStepReadPolicy(t, "test", false), + testAccStepReadRaw(t, "test", false), testAccStepEncrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData), testAccStepDeletePolicy(t, "test"), testAccStepReadPolicy(t, "test", true), + testAccStepReadRaw(t, "test", true), }, }) } @@ -65,6 +67,43 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicalte return err } + if d.Name != name { + return fmt.Errorf("bad: %#v", d) + } + if d.CipherMode != "aes-gcm" { + return fmt.Errorf("bad: %#v", d) + } + // Should NOT get a key back + if d.Key != nil { + return fmt.Errorf("bad: %#v", d) + } + return nil + }, + } +} + +func testAccStepReadRaw(t *testing.T, name string, expectNone bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "raw/" + name, + Check: func(resp *logical.Response) error { + if resp == nil && !expectNone { + return fmt.Errorf("missing response") + } else if expectNone { + if resp != nil { + return fmt.Errorf("response when expecting none") + } + return nil + } + var d struct { + Name string `mapstructure:"name"` + Key []byte `mapstructure:"key"` + CipherMode string `mapstructure:"cipher_mode"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if d.Name != name { return fmt.Errorf("bad: %#v", d) } diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index b856964689..9b2424004a 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -124,7 +124,6 @@ func pathPolicyRead( resp := &logical.Response{ Data: map[string]interface{}{ "name": p.Name, - "key": p.Key, "cipher_mode": p.CipherMode, }, } diff --git a/builtin/logical/transit/path_raw.go b/builtin/logical/transit/path_raw.go new file mode 100644 index 0000000000..ebe411a571 --- /dev/null +++ b/builtin/logical/transit/path_raw.go @@ -0,0 +1,54 @@ +package transit + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRaw() *framework.Path { + return &framework.Path{ + Pattern: `raw/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: pathRawRead, + }, + + HelpSynopsis: pathPolicyHelpSyn, + HelpDescription: pathPolicyHelpDesc, + } +} + +func pathRawRead( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if p == nil { + return nil, nil + } + + // Return the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "name": p.Name, + "key": p.Key, + "cipher_mode": p.CipherMode, + }, + } + return resp, nil +} + +const pathRawHelpSyn = `Fetch raw keys for named encrption keys` + +const pathRawHelpDesc = ` +This path is used to get the underlying encryption keys used for the +named keys that are available. +` From 7c31e29295aa493a4c08b0e2e20865a709d1c98f Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:45:29 -0700 Subject: [PATCH 10/21] website: update the transit documentation --- .../source/docs/secrets/transit/index.html.md | 61 +++++++++++++++---- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/website/source/docs/secrets/transit/index.html.md b/website/source/docs/secrets/transit/index.html.md index 3506eca03a..ae04b2794b 100644 --- a/website/source/docs/secrets/transit/index.html.md +++ b/website/source/docs/secrets/transit/index.html.md @@ -54,6 +54,15 @@ $ vault read transit/keys/foo Key Value name foo cipher_mode aes-gcm +```` + +We can read from the `raw/` endpoint to see the encryption key itself: + +``` +$ vault read transit/raw/foo +Key Value +name foo +cipher_mode aes-gcm key PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8= ```` @@ -114,17 +123,7 @@ only encrypt or decrypt using the named keys they need access to.
Returns
- - ```javascript - { - "data": { - "name": "foo", - "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" - } - } - ``` - + A `204` response code.
@@ -156,7 +155,6 @@ only encrypt or decrypt using the named keys they need access to. "data": { "name": "foo", "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" } } ``` @@ -269,3 +267,42 @@ only encrypt or decrypt using the named keys they need access to. + +### /transit/raw/ +#### GET + +
+
Description
+
+ Returns raw information about a named encryption key, + Including the underlying encryption key. This is a root protected endpoint. +
+ +
Method
+
GET
+ +
URL
+
`/transit/raw/`
+ +
Parameters
+
+ None +
+ +
Returns
+
+ + ```javascript + { + "data": { + "name": "foo", + "cipher_mode": "aes-gcm", + "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" + } + } + ``` + +
+
+ + From 96119946f348e3d6ab7c6acedf9c00e46cf59086 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:51:05 -0700 Subject: [PATCH 11/21] secret/transit: allow policies to be upserted --- builtin/logical/transit/backend_test.go | 15 ++++++ builtin/logical/transit/path_encrypt.go | 6 ++- builtin/logical/transit/path_keys.go | 66 ++++++++++++++----------- 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index cb2a5e1872..c03ab0e326 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -31,6 +31,21 @@ func TestBackend_basic(t *testing.T) { }) } +func TestBackend_upsert(t *testing.T) { + decryptData := make(map[string]interface{}) + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testAccStepReadPolicy(t, "test", true), + testAccStepEncrypt(t, "test", testPlaintext, decryptData), + testAccStepReadPolicy(t, "test", false), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepDeletePolicy(t, "test"), + testAccStepReadPolicy(t, "test", true), + }, + }) +} + func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.WriteOperation, diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index d30d72f3b4..761af7e352 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" + "fmt" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -56,7 +57,10 @@ func pathEncryptWrite( // Error if invalid policy if p == nil { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + p, err = generatePolicy(req.Storage, name) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest + } } // Guard against a potentially invalid cipher-mode diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 9b2424004a..f83ae4037e 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -45,6 +45,41 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) { return p, nil } +// generatePolicy is used to create a new named policy with +// a randomly generated key +func generatePolicy(storage logical.Storage, name string) (*Policy, error) { + // Create the policy object + p := &Policy{ + Name: name, + CipherMode: "aes-gcm", + } + + // Generate a 256bit key + p.Key = make([]byte, 32) + _, err := rand.Read(p.Key) + if err != nil { + return nil, err + } + + // Encode the policy + buf, err := p.Serialize() + if err != nil { + return nil, err + } + + // Write the policy into storage + err = storage.Put(&logical.StorageEntry{ + Key: "policy/" + name, + Value: buf, + }) + if err != nil { + return nil, err + } + + // Return the policy + return p, nil +} + func pathKeys() *framework.Path { return &framework.Path{ Pattern: `keys/(?P\w+)`, @@ -79,34 +114,9 @@ func pathPolicyWrite( return nil, nil } - // Create the policy object - p := &Policy{ - Name: name, - CipherMode: "aes-gcm", - } - - // Generate a 256bit key - p.Key = make([]byte, 32) - _, err = rand.Read(p.Key) - if err != nil { - return nil, err - } - - // Encode the policy - buf, err := p.Serialize() - if err != nil { - return nil, err - } - - // Write the policy into storage - err = req.Storage.Put(&logical.StorageEntry{ - Key: "policy/" + name, - Value: buf, - }) - if err != nil { - return nil, err - } - return nil, nil + // Generate the policy + _, err = generatePolicy(req.Storage, name) + return nil, err } func pathPolicyRead( From ba24d891fd907fff44ec86fb36b5d0c345840c50 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:51:58 -0700 Subject: [PATCH 12/21] website: document transit upsert behavior --- website/source/docs/secrets/transit/index.html.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/website/source/docs/secrets/transit/index.html.md b/website/source/docs/secrets/transit/index.html.md index ae04b2794b..743ba0662e 100644 --- a/website/source/docs/secrets/transit/index.html.md +++ b/website/source/docs/secrets/transit/index.html.md @@ -194,7 +194,9 @@ only encrypt or decrypt using the named keys they need access to.
Description
- Encrypts the provided plaintext using the named key. + Encrypts the provided plaintext using the named key. If the named key + does not already exist, it will be automatically generated for the given + name with the default parameters.
Method
From ee176b2f5de65274deba8ec525a90755f7c59605 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 19:19:02 -0700 Subject: [PATCH 13/21] command/auth: warn about the VAULT_TOKEN env var. Fixes #195 --- command/auth.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/command/auth.go b/command/auth.go index 130dd7d278..46ea4657a5 100644 --- a/command/auth.go +++ b/command/auth.go @@ -173,6 +173,15 @@ func (c *AuthCommand) Run(args []string) int { return 1 } + // Warn if the VAULT_TOKEN environment variable is set, as that will take + // precedence + if os.Getenv("VAULT_TOKEN") != "" { + c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") + c.Ui.Output(" The environment variable takes precedence over the value") + c.Ui.Output(" set by the auth command. Either update the value of the") + c.Ui.Output(" environment variable or unset it to use the new token.\n") + } + // Get the policies we have policiesRaw, ok := secret.Data["policies"] if !ok { From 0696bc47e0c2e128b8cf89dea2826a1ac2de5ebc Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 13:48:04 -0700 Subject: [PATCH 14/21] command/auth: warn earlier about VAULT_TOKEN --- command/auth.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/command/auth.go b/command/auth.go index 46ea4657a5..f31b2a358a 100644 --- a/command/auth.go +++ b/command/auth.go @@ -114,6 +114,15 @@ func (c *AuthCommand) Run(args []string) int { return 0 } + // Warn if the VAULT_TOKEN environment variable is set, as that will take + // precedence + if os.Getenv("VAULT_TOKEN") != "" { + c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") + c.Ui.Output(" The environment variable takes precedence over the value") + c.Ui.Output(" set by the auth command. Either update the value of the") + c.Ui.Output(" environment variable or unset it to use the new token.\n") + } + var vars map[string]string if len(args) > 0 { builder := kvbuilder.Builder{Stdin: os.Stdin} @@ -173,15 +182,6 @@ func (c *AuthCommand) Run(args []string) int { return 1 } - // Warn if the VAULT_TOKEN environment variable is set, as that will take - // precedence - if os.Getenv("VAULT_TOKEN") != "" { - c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") - c.Ui.Output(" The environment variable takes precedence over the value") - c.Ui.Output(" set by the auth command. Either update the value of the") - c.Ui.Output(" environment variable or unset it to use the new token.\n") - } - // Get the policies we have policiesRaw, ok := secret.Data["policies"] if !ok { From 57d1230e6c38fba51d8d29904ef25c07eed7bfba Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 13:48:18 -0700 Subject: [PATCH 15/21] command/server: fixing output weirdness --- command/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/command/server.go b/command/server.go index ed842a26f8..f3db676d76 100644 --- a/command/server.go +++ b/command/server.go @@ -155,7 +155,7 @@ func (c *ServerCommand) Run(args []string) int { "immediately begin using the Vault CLI.\n\n"+ "The only step you need to take is to set the following\n"+ "environment variables:\n\n"+ - " export VAULT_ADDR='http://127.0.0.1:8200'\n"+ + " export VAULT_ADDR='http://127.0.0.1:8200'\n\n"+ "The unseal key and root token are reproduced below in case you\n"+ "want to seal/unseal the Vault or play with authentication.\n\n"+ "Unseal Key: %s\nRoot Token: %s\n", From 27728075478ec48070fb9b0f1670b6e108cbeedb Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 13:51:06 -0700 Subject: [PATCH 16/21] command/write: adding force flag for when no data fields are necessary. Fixes #357 --- command/write.go | 11 ++++++++++- command/write_test.go | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/command/write.go b/command/write.go index fefa23e72d..ba0fdd823d 100644 --- a/command/write.go +++ b/command/write.go @@ -19,15 +19,18 @@ type WriteCommand struct { func (c *WriteCommand) Run(args []string) int { var format string + var force bool flags := c.Meta.FlagSet("write", FlagSetDefault) flags.StringVar(&format, "format", "table", "") + flags.BoolVar(&force, "force", false, "") + flags.BoolVar(&force, "f", false, "") flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { return 1 } args = flags.Args() - if len(args) < 2 { + if len(args) < 2 && !force { c.Ui.Error("write expects at least two arguments") flags.Usage() return 1 @@ -117,6 +120,12 @@ General Options: not recommended. This is especially not recommended for unsealing a vault. +Write Options: + + -f | -force Force the write to continue without any data values + specified. This allows writing to keys that do not + need or expect any fields to be specified. + ` return strings.TrimSpace(helpText) } diff --git a/command/write_test.go b/command/write_test.go index ce570f3730..51774e3c0b 100644 --- a/command/write_test.go +++ b/command/write_test.go @@ -246,3 +246,26 @@ func TestWrite_Output(t *testing.T) { t.Fatalf("bad: %s", string(ui.OutputWriter.Bytes())) } } + +func TestWrite_force(t *testing.T) { + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := http.TestServer(t, core) + defer ln.Close() + + ui := new(cli.MockUi) + c := &WriteCommand{ + Meta: Meta{ + ClientToken: token, + Ui: ui, + }, + } + + args := []string{ + "-address", addr, + "-force", + "sys/rotate", + } + if code := c.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } +} From 46ba8d10a555b8d5287fd0b47812574b036e2e06 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 14:31:00 -0700 Subject: [PATCH 17/21] physical/mysql: cleanup and documentation --- physical/mysql.go | 163 ++++++++++------------- physical/mysql_test.go | 77 ++--------- website/source/docs/config/index.html.md | 17 +++ 3 files changed, 92 insertions(+), 165 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index bb4a89c5e5..5db23b9384 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -2,7 +2,6 @@ package physical import ( "database/sql" - "errors" "fmt" "sort" "strings" @@ -12,16 +11,10 @@ import ( _ "github.com/go-sql-driver/mysql" ) -var ( - MySQLPrepareStmtFailure = errors.New("failed to prepare statement") - MySQLExecuteStmtFailure = errors.New("failed to execute statement") -) - // MySQLBackend is a physical backend that stores data // within MySQL database. type MySQLBackend struct { - table string - database string + dbTable string client *sql.DB statements map[string]*sql.Stmt } @@ -29,84 +22,85 @@ type MySQLBackend struct { // newMySQLBackend constructs a MySQL backend using the given API client and // server address and credential for accessing mysql database. func newMySQLBackend(conf map[string]string) (Backend, error) { + // Get the MySQL credentials to perform read/write operations. + username, ok := conf["username"] + if !ok || username == "" { + return nil, fmt.Errorf("missing username") + } + password, ok := conf["password"] + if !ok || username == "" { + return nil, fmt.Errorf("missing password") + } + // Get or set MySQL server address. Defaults to localhost and default port(3306) address, ok := conf["address"] if !ok { address = "127.0.0.1:3306" } - // Get the MySQL credentials to perform read/write operations. - username, ok := conf["username"] - password, ok := conf["password"] - // Get the MySQL database and table details. database, ok := conf["database"] if !ok { - return nil, fmt.Errorf("database name is missing in the configuration") + database = "vault" } table, ok := conf["table"] if !ok { - return nil, fmt.Errorf("table name is missing in the configuration") + table = "vault" } + dbTable := database + "." + table // Create MySQL handle for the database. - dsn := username + ":" + password + "@tcp(" + address + ")/" + database + dsn := username + ":" + password + "@tcp(" + address + ")/" db, err := sql.Open("mysql", dsn) if err != nil { - return nil, fmt.Errorf("failed to open handler with database") + return nil, fmt.Errorf("failed to connect to mysql: %v", err) + } + + // Create the required database if it doesn't exists. + if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + database); err != nil { + return nil, fmt.Errorf("failed to create mysql database: %v", err) } // Create the required table if it doesn't exists. - create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" - create_stmt, err := db.Prepare(create_query) - if err != nil { - return nil, MySQLPrepareStmtFailure + create_query := "CREATE TABLE IF NOT EXISTS " + dbTable + + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" + if _, err := db.Exec(create_query); err != nil { + return nil, fmt.Errorf("failed to create mysql table: %v", err) } - defer create_stmt.Close() - - _, err = create_stmt.Exec() - if err != nil { - return nil, MySQLExecuteStmtFailure - } - - // Map of query type as key to prepared statement. - statements := make(map[string]*sql.Stmt) - - // Prepare statement for put query. - insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" - insert_stmt, err := db.Prepare(insert_query) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - statements["put"] = insert_stmt - - // Prepare statement for select query. - select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" - select_stmt, err := db.Prepare(select_query) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - statements["get"] = select_stmt - - // Prepare statement for delete query. - delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" - delete_stmt, err := db.Prepare(delete_query) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - statements["delete"] = delete_stmt // Setup the backend. m := &MySQLBackend{ + dbTable: dbTable, client: db, - table: table, - database: database, - statements: statements, + statements: make(map[string]*sql.Stmt), } + // Prepare all the statements required + statements := map[string]string{ + "put": "INSERT INTO " + dbTable + + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)", + "get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?", + "delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?", + "list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?", + } + for name, query := range statements { + if err := m.prepare(name, query); err != nil { + return nil, err + } + } return m, nil } +// prepare is a helper to prepare a query for future execution +func (m *MySQLBackend) prepare(name, query string) error { + stmt, err := m.client.Prepare(query) + if err != nil { + return fmt.Errorf("failed to prepare '%s': %v", name, err) + } + m.statements[name] = stmt + return nil +} + // Put is used to insert or update an entry. func (m *MySQLBackend) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) @@ -115,7 +109,6 @@ func (m *MySQLBackend) Put(entry *Entry) error { if err != nil { return err } - return nil } @@ -124,22 +117,18 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) var result []byte - err := m.statements["get"].QueryRow(key).Scan(&result) - if err != nil { + if err == sql.ErrNoRows { return nil, nil } - - // Handle a non-existing value - if result == nil { - return nil, nil + if err != nil { + return nil, err } ent := &Entry{ Key: key, Value: result, } - return ent, nil } @@ -151,7 +140,6 @@ func (m *MySQLBackend) Delete(key string) error { if err != nil { return err } - return nil } @@ -160,45 +148,28 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) - // Query to get all keys matching a prefix. - list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" - rows, err := m.client.Query(list_query) - if err != nil { - return nil, MySQLExecuteStmtFailure - } + // Add the % wildcard to the prefix to do the prefix search + likePrefix := prefix + "%" + rows, err := m.statements["list"].Query(likePrefix) - columns, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("failed to get columns") - } - - values := make([]sql.RawBytes, len(columns)) - - scanArgs := make([]interface{}, len(values)) - for i := range values { - scanArgs[i] = &values[i] - } - - keys := []string{} + var keys []string for rows.Next() { - err = rows.Scan(scanArgs...) + var key string + err = rows.Scan(&key) if err != nil { - return nil, fmt.Errorf("failed to scan rows") + return nil, fmt.Errorf("failed to scan rows: %v", err) } - for _, key := range values { - key := strings.TrimPrefix(string(key), prefix) - if i := strings.Index(string(key), "/"); i == -1 { - // Add objects only from the current 'folder' - keys = append(keys, string(key)) - } else if i != -1 { - // Add truncated 'folder' paths - keys = appendIfMissing(keys, string(key[:i+1])) - } + key = strings.TrimPrefix(key, prefix) + if i := strings.Index(key, "/"); i == -1 { + // Add objects only from the current 'folder' + keys = append(keys, key) + } else if i != -1 { + // Add truncated 'folder' paths + keys = appendIfMissing(keys, string(key[:i+1])) } } sort.Strings(keys) - return keys, nil } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index 2b00404ffb..a28fb1441f 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -1,8 +1,6 @@ package physical import ( - "database/sql" - "fmt" "os" "testing" @@ -28,61 +26,6 @@ func TestMySQLBackend(t *testing.T) { username := os.Getenv("MYSQL_USERNAME") password := os.Getenv("MYSQL_PASSWORD") - // Create MySQL handle for the database. - db, err := sql.Open("mysql", username+":"+password+"@tcp("+address+")/"+database) - - if err != nil { - t.Fatalf("Failed to open an handler with database: %v", err) - } - defer db.Close() - - // Prepare statement for creating table. - create_stmt := "CREATE TABLE IF NOT EXISTS test.square (num int, sqr int, PRIMARY KEY (num))" - stmtCrt, err := db.Prepare(create_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmtCrt.Close() - - // Create table - _, err = stmtCrt.Exec() - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // Prepare statement for inserting data. - insert_stmt := "INSERT INTO test.square VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" - stmtIns, err := db.Prepare(insert_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmtIns.Close() - - // Prepare statement for reading data. - select_stmt := "SELECT sqr FROM test.square WHERE num = ?" - stmtOut, err := db.Prepare(select_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmtOut.Close() - - // Insert square numbers for 0-24 in the database - for i := 0; i < 25; i++ { - _, err = stmtIns.Exec(i, (i * i)) // Insert tuples (i, i^2) - if err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - } - - var square int - - // Query the square-number of 13 - err = stmtOut.QueryRow(13).Scan(&square) - if err != nil { - t.Fatalf("Failed to query data: %v", err) - } - fmt.Printf("The square number of 13 is: %d", square) - // Run vault tests b, err := NewBackend("mysql", map[string]string{ "address": address, @@ -96,19 +39,15 @@ func TestMySQLBackend(t *testing.T) { t.Fatalf("Failed to create new backend: %v", err) } + defer func() { + mysql := b.(*MySQLBackend) + _, err := mysql.client.Exec("DROP TABLE " + mysql.dbTable) + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } + }() + testBackend(t, b) testBackend_ListPrefix(t, b) - // Drop table after running tests - drop_stmt := "DROP TABLE " + database + "." + table - stmt, err := db.Prepare(drop_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmt.Close() - - _, err = stmt.Exec() - if err != nil { - t.Fatalf("Failed to drop table: %v", err) - } } diff --git a/website/source/docs/config/index.html.md b/website/source/docs/config/index.html.md index 278b4d9549..8cae4df7c6 100644 --- a/website/source/docs/config/index.html.md +++ b/website/source/docs/config/index.html.md @@ -76,6 +76,8 @@ durability, etc. * `s3` - Store data within an S3 bucket [S3](http://aws.amazon.com/s3/). This backend does not support HA. + * `mysql` - Store data within MySQL. This backend does not support HA. + * `inmem` - Store data in-memory. This is only really useful for development and experimentation. Data is lost whenever Vault is restarted. @@ -143,6 +145,21 @@ For S3, the following options are supported: * `region` (optional) - The AWS region. It can be sourced from the AWS_DEFAULT_REGION environment variable and will default to "us-east-1" if not specified. +#### Backend Reference: MySQL + +The MySQL backend has the following options: + + * `username` (required) - The MySQL username to connect with. + + * `password` (required) - The MySQL password to connect with. + + * `address` (optional) - The address of the MySQL host. Defaults to + "127.0.0.1:3306. + + * `database` (optional) - The name of the database to use. Defaults to "vault". + + * `table` (optional) - The name of the table to use. Defaults to "vault". + #### Backend Reference: Inmem The in-memory backend has no configuration options. From 2d0cde4ccc0323591d9414342cb15f5cb70271d7 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 15:37:08 -0700 Subject: [PATCH 18/21] vault: improve lease error message. Fixes #338 --- vault/expiration.go | 2 +- vault/expiration_test.go | 2 +- vault/logical_system_test.go | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vault/expiration.go b/vault/expiration.go index 294624d6e0..3df1738ed2 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -635,7 +635,7 @@ func (l *leaseEntry) encode() ([]byte, error) { func (le *leaseEntry) renewable() error { // If there is no entry, cannot review if le == nil || le.ExpireTime.IsZero() { - return fmt.Errorf("lease not found") + return fmt.Errorf("lease not found or lease is not renewable") } // Determine if the lease is expired diff --git a/vault/expiration_test.go b/vault/expiration_test.go index f7b581be25..0fc4484bd8 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -152,7 +152,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { // Should not be able to renew, no expiration _, err = exp.RenewToken("auth/github/login", root.ID, 0) - if err.Error() != "lease not found" { + if err.Error() != "lease not found or lease is not renewable" { t.Fatalf("err: %v", err) } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index fd23f59aa4..33749ae14f 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -202,7 +202,7 @@ func TestSystemBackend_renew_invalidID(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp.Data["error"] != "lease not found" { + if resp.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } @@ -250,7 +250,7 @@ func TestSystemBackend_revoke(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp3.Data["error"] != "lease not found" { + if resp3.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } @@ -312,7 +312,7 @@ func TestSystemBackend_revokePrefix(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp3.Data["error"] != "lease not found" { + if resp3.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } From 48e7531f7954a789c44948a42a78bfa237ad638c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 15:56:42 -0700 Subject: [PATCH 19/21] command/path-help: rename command, better error if sealed. Fixes #234 --- cli/commands.go | 4 ++-- command/{help.go => path_help.go} | 26 ++++++++++++++------- command/{help_test.go => path_help_test.go} | 2 +- 3 files changed, 21 insertions(+), 11 deletions(-) rename command/{help.go => path_help.go} (69%) rename command/{help_test.go => path_help_test.go} (95%) diff --git a/cli/commands.go b/cli/commands.go index 074b80dd11..d9cc3a8624 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -74,8 +74,8 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { }, nil }, - "help": func() (cli.Command, error) { - return &command.HelpCommand{ + "path-help": func() (cli.Command, error) { + return &command.PathHelpCommand{ Meta: meta, }, nil }, diff --git a/command/help.go b/command/path_help.go similarity index 69% rename from command/help.go rename to command/path_help.go index f832f07bc6..792ea9e3a4 100644 --- a/command/help.go +++ b/command/path_help.go @@ -5,12 +5,12 @@ import ( "strings" ) -// HelpCommand is a Command that lists the mounts. -type HelpCommand struct { +// PathHelpCommand is a Command that lists the mounts. +type PathHelpCommand struct { Meta } -func (c *HelpCommand) Run(args []string) int { +func (c *PathHelpCommand) Run(args []string) int { flags := c.Meta.FlagSet("help", FlagSetDefault) flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { @@ -35,8 +35,15 @@ func (c *HelpCommand) Run(args []string) int { help, err := client.Help(path) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading help: %s", err)) + if strings.Contains(err.Error(), "Vault is sealed") { + c.Ui.Error(`Error: Vault is sealed. + +The path-help command requires the Vault to be unsealed so that +mount points of secret backends are known.`) + } else { + c.Ui.Error(fmt.Sprintf( + "Error reading help: %s", err)) + } return 1 } @@ -44,13 +51,13 @@ func (c *HelpCommand) Run(args []string) int { return 0 } -func (c *HelpCommand) Synopsis() string { +func (c *PathHelpCommand) Synopsis() string { return "Look up the help for a path" } -func (c *HelpCommand) Help() string { +func (c *PathHelpCommand) Help() string { helpText := ` -Usage: vault help [options] path +Usage: vault path-help [options] path Look up the help for a path. @@ -58,6 +65,9 @@ Usage: vault help [options] path providers provide built-in help. This command looks up and outputs that help. + The command requires that the Vault be unsealed, because otherwise + the mount points of the backends are unknown. + General Options: -address=addr The address of the Vault server. diff --git a/command/help_test.go b/command/path_help_test.go similarity index 95% rename from command/help_test.go rename to command/path_help_test.go index c4facc0ca8..faec9723d9 100644 --- a/command/help_test.go +++ b/command/path_help_test.go @@ -14,7 +14,7 @@ func TestHelp(t *testing.T) { defer ln.Close() ui := new(cli.MockUi) - c := &HelpCommand{ + c := &PathHelpCommand{ Meta: Meta{ ClientToken: token, Ui: ui, From f91b91289ce18aab351068d55b49c0b0b2542cf6 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 16:00:41 -0700 Subject: [PATCH 20/21] command/read: Ensure only a single argument. Fixes #304 --- command/read.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/command/read.go b/command/read.go index 85d23b73e3..983ae56bb3 100644 --- a/command/read.go +++ b/command/read.go @@ -22,8 +22,8 @@ func (c *ReadCommand) Run(args []string) int { } args = flags.Args() - if len(args) < 1 || len(args) > 2 { - c.Ui.Error("read expects one or two arguments") + if len(args) != 1 { + c.Ui.Error("read expects one argument") flags.Usage() return 1 } From 8c970cf000e9a7f28f9aa94c0cdcf4c4fc874ab8 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 17:12:21 -0700 Subject: [PATCH 21/21] cli: adding path-help to common commands list --- cli/help.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cli/help.go b/cli/help.go index 6c3e63d8b2..b614212c4f 100644 --- a/cli/help.go +++ b/cli/help.go @@ -13,14 +13,14 @@ import ( // HelpFunc is a cli.HelpFunc that can is used to output the help for Vault. func HelpFunc(commands map[string]cli.CommandFactory) string { commonNames := map[string]struct{}{ - "delete": struct{}{}, - "help": struct{}{}, - "read": struct{}{}, - "renew": struct{}{}, - "revoke": struct{}{}, - "write": struct{}{}, - "server": struct{}{}, - "status": struct{}{}, + "delete": struct{}{}, + "path-help": struct{}{}, + "read": struct{}{}, + "renew": struct{}{}, + "revoke": struct{}{}, + "write": struct{}{}, + "server": struct{}{}, + "status": struct{}{}, } // Determine the maximum key length, and classify based on type