From 9c3881442e1b438b9a100e277e125732a2054ba6 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Mon, 8 Jun 2015 16:17:44 +0545 Subject: [PATCH 01/33] Physical MySQL backend implementation - First Cut --- physical/mysql.go | 190 +++++++++++++++++++++++++++++++++++++++++ physical/mysql_test.go | 86 +++++++++++++++++++ physical/physical.go | 1 + 3 files changed, 277 insertions(+) create mode 100644 physical/mysql.go create mode 100644 physical/mysql_test.go diff --git a/physical/mysql.go b/physical/mysql.go new file mode 100644 index 0000000000..377ac9dc49 --- /dev/null +++ b/physical/mysql.go @@ -0,0 +1,190 @@ +package physical + +import ( + "database/sql" + "errors" + "sort" + "strings" + "time" + + "github.com/armon/go-metrics" + _ "github.com/go-sql-driver/mysql" +) + +var ( + MySQLDBNameMissing = errors.New("database name is missing in the configuration") + MySQLTableNameMissing = errors.New("table name is missing in the configuration") + MySQLHandlerCreationFailure = errors.New("failed to open handler with database") + MySQLPrepareStmtFailure = errors.New("failed to prepare statement") + MySQLExecuteStmtFailure = errors.New("failed to execute statement") + MySQLGetColumnsFailure = errors.New("failed to get columns") + MySQLScanRowsFailure = errors.New("failed to scan rows") +) + +// MySQLBackend is a physical backend that stores data +// within MySQL database. +type MySQLBackend struct { + table string + database string + client *sql.DB +} + +// newMySQLBackend constructs a MySQL backend using the given API client and +// server address and credential for accessing mysql database. +func newMySQLBackend(conf map[string]string) (Backend, error) { + // Get or set MySQL server address. Defaults to localhost and default port(3306) + address, ok := conf["address"] + if !ok { + address = "127.0.0.1:3306" + } + + // Get the MySQL credentials to perform read/write operations. + username, ok := conf["username"] + password, ok := conf["password"] + + // Get the MySQL database and table details. + database, ok := conf["database"] + if !ok { + return nil, MySQLDBNameMissing + } + table, ok := conf["table"] + if !ok { + return nil, MySQLTableNameMissing + } + + // Create MySQL handle for the database. + dsn := username + ":" + password + "@tcp(" + address + ")/" + database + db, err := sql.Open("mysql", dsn) + if err != nil { + return nil, MySQLHandlerCreationFailure + } + defer db.Close() + + // Create the required table. + create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))" + stmt, err := db.Prepare(create_stmt) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + defer stmt.Close() + + _, err = stmt.Exec() + if err != nil { + return nil, MySQLExecuteStmtFailure + } + + // Setup the backend. + m := &MySQLBackend{ + client: db, + table: table, + database: database, + } + + return m, nil +} + +// Put is used to insert or update an entry. +func (m *MySQLBackend) Put(entry *Entry) error { + defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) + + insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" + stmt, err := m.client.Prepare(insert_stmt) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(entry.Key, entry.Value) + if err != nil { + return err + } + + return nil +} + +// Get is used to fetch and entry. +func (m *MySQLBackend) Get(key string) (*Entry, error) { + defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) + + select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?" + stmt, err := m.client.Prepare(select_stmt) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + defer stmt.Close() + + var result []byte + + err = stmt.QueryRow(key).Scan(&result) + if err != nil { + return nil, MySQLExecuteStmtFailure + } + + ent := &Entry{ + Key: key, + Value: result, + } + + return ent, nil +} + +// Delete is used to permanently delete an entry +func (m *MySQLBackend) Delete(key string) error { + defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) + + delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?" + stmt, err := m.client.Prepare(delete_stmt) + if err != nil { + return err + } + defer stmt.Close() + + _, err = stmt.Exec(key) + if err != nil { + return err + } + + return nil +} + +// List is used to list all the keys under a given +// prefix, up to the next prefix. +func (m *MySQLBackend) List(prefix string) ([]string, error) { + defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) + + list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table + rows, err := m.client.Query(list_stmt) + if err != nil { + return nil, MySQLExecuteStmtFailure + } + + columns, err := rows.Columns() + if err != nil { + return nil, MySQLGetColumnsFailure + } + + values := make([]sql.RawBytes, len(columns)) + + scanArgs := make([]interface{}, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + keys := []string{} + for rows.Next() { + err = rows.Scan(scanArgs...) + if err != nil { + return nil, MySQLScanRowsFailure + } + + for _, col := range values { + if strings.HasPrefix(string(col), prefix) { + keys = append(keys, string(col)) + } + } + } + + sort.Strings(keys) + + return keys, nil +} diff --git a/physical/mysql_test.go b/physical/mysql_test.go new file mode 100644 index 0000000000..b808a29bdb --- /dev/null +++ b/physical/mysql_test.go @@ -0,0 +1,86 @@ +package physical + +import ( + "database/sql" + "fmt" + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" +) + +func TestMySQLBackend(t *testing.T) { + address := os.Getenv("MYSQL_ADDR") + if address == "" { + t.SkipNow() + } + + database := os.Getenv("MYSQL_DB") + if database == "" { + database = "test" + } + + table := os.Getenv("MYSQL_TABLE") + if table == "" { + table = "test" + } + + username := os.Getenv("MYSQL_USERNAME") + password := os.Getenv("MYSQL_PASSWORD") + + // Create MySQL handle for the database. + db, err := sql.Open("mysql", username+":"+password+"@tcp("+address+")/"+database) + + if err != nil { + t.Fatalf("Failed to open an handler with database: %v", err) + } + defer db.Close() + + // Prepare statement for creating table. + create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + stmtCrt, err := db.Prepare(create_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmtCrt.Close() + + // Create table + _, err = stmtCrt.Exec() + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + // Prepare statement for inserting data. + insert_stmt := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" + stmtIns, err := db.Prepare(insert_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmtIns.Close() + + // Prepare statement for reading data. + select_stmt := "SELECT sqr FROM " + database + "." + table + " WHERE num = ?" + stmtOut, err := db.Prepare(select_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmtOut.Close() + + // Insert square numbers for 0-24 in the database + for i := 0; i < 25; i++ { + _, err = stmtIns.Exec(i, (i * i)) // Insert tuples (i, i^2) + if err != nil { + t.Fatalf("Failed to insert data: %v", err) + } + } + + var square int + + // Query the square-number of 13 + err = stmtOut.QueryRow(13).Scan(&square) + if err != nil { + t.Fatalf("Failed to query data: %v", err) + } + fmt.Printf("The square number of 13 is: %d", square) + +} diff --git a/physical/physical.go b/physical/physical.go index b992c10102..bd307e8dea 100644 --- a/physical/physical.go +++ b/physical/physical.go @@ -84,4 +84,5 @@ var BuiltinBackends = map[string]Factory{ "file": newFileBackend, "s3": newS3Backend, "etcd": newEtcdBackend, + "mysql": newMySQLBackend, } From c42bc38c62adb6bd671d39dafbfe127934ec9c67 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 11:41:25 +0545 Subject: [PATCH 02/33] Changes done as per feedback --- physical/mysql.go | 110 +++++++++++++++++++++-------------------- physical/mysql_test.go | 2 +- 2 files changed, 58 insertions(+), 54 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index 377ac9dc49..ed53a5350d 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -3,8 +3,8 @@ package physical import ( "database/sql" "errors" + "fmt" "sort" - "strings" "time" "github.com/armon/go-metrics" @@ -12,21 +12,17 @@ import ( ) var ( - MySQLDBNameMissing = errors.New("database name is missing in the configuration") - MySQLTableNameMissing = errors.New("table name is missing in the configuration") - MySQLHandlerCreationFailure = errors.New("failed to open handler with database") - MySQLPrepareStmtFailure = errors.New("failed to prepare statement") - MySQLExecuteStmtFailure = errors.New("failed to execute statement") - MySQLGetColumnsFailure = errors.New("failed to get columns") - MySQLScanRowsFailure = errors.New("failed to scan rows") + MySQLPrepareStmtFailure = errors.New("failed to prepare statement") + MySQLExecuteStmtFailure = errors.New("failed to execute statement") ) // MySQLBackend is a physical backend that stores data // within MySQL database. type MySQLBackend struct { - table string - database string - client *sql.DB + table string + database string + client *sql.DB + statements map[string]*sql.Stmt } // newMySQLBackend constructs a MySQL backend using the given API client and @@ -45,39 +41,70 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { // Get the MySQL database and table details. database, ok := conf["database"] if !ok { - return nil, MySQLDBNameMissing + return nil, fmt.Errorf("database name is missing in the configuration") } table, ok := conf["table"] if !ok { - return nil, MySQLTableNameMissing + return nil, fmt.Errorf("table name is missing in the configuration") } // Create MySQL handle for the database. dsn := username + ":" + password + "@tcp(" + address + ")/" + database db, err := sql.Open("mysql", dsn) if err != nil { - return nil, MySQLHandlerCreationFailure + return nil, fmt.Errorf("failed to open handler with database") } defer db.Close() - // Create the required table. - create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))" - stmt, err := db.Prepare(create_stmt) + // Create the required table if it doesn't exists. + create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" + create_stmt, err := db.Prepare(create_query) if err != nil { return nil, MySQLPrepareStmtFailure } - defer stmt.Close() + defer create_stmt.Close() - _, err = stmt.Exec() + _, err = create_stmt.Exec() if err != nil { return nil, MySQLExecuteStmtFailure } + // Map of query type as key to prepared statement. + var statements map[string]*sql.Stmt + + // Prepare statement for put query. + insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" + insert_stmt, err := db.Prepare(insert_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["put"] = insert_stmt + defer insert_stmt.Close() + + // Prepare statement for select query. + select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" + select_stmt, err := db.Prepare(select_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["get"] = select_stmt + defer select_stmt.Close() + + // Prepare statement for delete query. + delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" + delete_stmt, err := db.Prepare(delete_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["delete"] = delete_stmt + defer delete_stmt.Close() + // Setup the backend. m := &MySQLBackend{ - client: db, - table: table, - database: database, + client: db, + table: table, + database: database, + statements: statements, } return m, nil @@ -87,14 +114,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { func (m *MySQLBackend) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) - insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" - stmt, err := m.client.Prepare(insert_stmt) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(entry.Key, entry.Value) + _, err := m.statements["put"].Exec(entry.Key, entry.Value) if err != nil { return err } @@ -106,16 +126,9 @@ func (m *MySQLBackend) Put(entry *Entry) error { func (m *MySQLBackend) Get(key string) (*Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) - select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?" - stmt, err := m.client.Prepare(select_stmt) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - defer stmt.Close() - var result []byte - err = stmt.QueryRow(key).Scan(&result) + err := m.statements["get"].QueryRow(key).Scan(&result) if err != nil { return nil, MySQLExecuteStmtFailure } @@ -132,14 +145,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { func (m *MySQLBackend) Delete(key string) error { defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) - delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?" - stmt, err := m.client.Prepare(delete_stmt) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(key) + _, err := m.statements["delete"].Exec(key) if err != nil { return err } @@ -152,15 +158,15 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) - list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table - rows, err := m.client.Query(list_stmt) + list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" + rows, err := m.client.Query(list_query) if err != nil { return nil, MySQLExecuteStmtFailure } columns, err := rows.Columns() if err != nil { - return nil, MySQLGetColumnsFailure + return nil, fmt.Errorf("failed to get columns") } values := make([]sql.RawBytes, len(columns)) @@ -174,13 +180,11 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) { for rows.Next() { err = rows.Scan(scanArgs...) if err != nil { - return nil, MySQLScanRowsFailure + return nil, fmt.Errorf("failed to scan rows") } for _, col := range values { - if strings.HasPrefix(string(col), prefix) { - keys = append(keys, string(col)) - } + keys = append(keys, string(col)) } } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index b808a29bdb..c1c78c9fc7 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) { defer db.Close() // Prepare statement for creating table. - create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" stmtCrt, err := db.Prepare(create_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) From b33d707b693b9a1b858d75103feec12c4afa2440 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 15:32:45 +0545 Subject: [PATCH 03/33] Added the test as per suggestion --- physical/mysql.go | 3 ++- physical/mysql_test.go | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/physical/mysql.go b/physical/mysql.go index ed53a5350d..50d4112bc2 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -70,7 +70,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { } // Map of query type as key to prepared statement. - var statements map[string]*sql.Stmt + statements := make(map[string]*sql.Stmt) // Prepare statement for put query. insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" @@ -158,6 +158,7 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) + // Query to get all keys matching a prefix. list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" rows, err := m.client.Query(list_query) if err != nil { diff --git a/physical/mysql_test.go b/physical/mysql_test.go index c1c78c9fc7..c8b4fd470c 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -83,4 +83,18 @@ func TestMySQLBackend(t *testing.T) { } fmt.Printf("The square number of 13 is: %d", square) + b, err := NewBackend("mysql", map[string]string{ + "address": address, + "database": database, + "table": table, + "username": username, + "password": password, + }) + + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + + testBackend(t, b) + testBackend_ListPrefix(t, b) } From 3ff10a757330a05594576f0d72e07f443c2ab15a Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 23:16:46 +0545 Subject: [PATCH 04/33] Fixing List command behaviour --- physical/mysql.go | 21 +++++++++++++++------ physical/mysql_test.go | 7 ++++--- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index 50d4112bc2..8666c249d3 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sort" + "strings" "time" "github.com/armon/go-metrics" @@ -54,7 +55,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { if err != nil { return nil, fmt.Errorf("failed to open handler with database") } - defer db.Close() // Create the required table if it doesn't exists. create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" @@ -79,7 +79,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { return nil, MySQLPrepareStmtFailure } statements["put"] = insert_stmt - defer insert_stmt.Close() // Prepare statement for select query. select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" @@ -88,7 +87,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { return nil, MySQLPrepareStmtFailure } statements["get"] = select_stmt - defer select_stmt.Close() // Prepare statement for delete query. delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" @@ -97,7 +95,6 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { return nil, MySQLPrepareStmtFailure } statements["delete"] = delete_stmt - defer delete_stmt.Close() // Setup the backend. m := &MySQLBackend{ @@ -133,6 +130,11 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { return nil, MySQLExecuteStmtFailure } + // Handle a non-existing value + if result == nil { + return nil, nil + } + ent := &Entry{ Key: key, Value: result, @@ -184,8 +186,15 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) { return nil, fmt.Errorf("failed to scan rows") } - for _, col := range values { - keys = append(keys, string(col)) + for _, key := range values { + key := strings.TrimPrefix(string(key), prefix) + if i := strings.Index(string(key), "/"); i == -1 { + // Add objects only from the current 'folder' + keys = append(keys, string(key)) + } else if i != -1 { + // Add truncated 'folder' paths + keys = appendIfMissing(keys, string(key[:i+1])) + } } } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index c8b4fd470c..9cbb11a5db 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) { defer db.Close() // Prepare statement for creating table. - create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + create_stmt := "CREATE TABLE IF NOT EXISTS test.square (num int, sqr int, PRIMARY KEY (num))" stmtCrt, err := db.Prepare(create_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) @@ -51,7 +51,7 @@ func TestMySQLBackend(t *testing.T) { } // Prepare statement for inserting data. - insert_stmt := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" + insert_stmt := "INSERT INTO test.square VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" stmtIns, err := db.Prepare(insert_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) @@ -59,7 +59,7 @@ func TestMySQLBackend(t *testing.T) { defer stmtIns.Close() // Prepare statement for reading data. - select_stmt := "SELECT sqr FROM " + database + "." + table + " WHERE num = ?" + select_stmt := "SELECT sqr FROM test.square WHERE num = ?" stmtOut, err := db.Prepare(select_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err) @@ -83,6 +83,7 @@ func TestMySQLBackend(t *testing.T) { } fmt.Printf("The square number of 13 is: %d", square) + // Run vault tests b, err := NewBackend("mysql", map[string]string{ "address": address, "database": database, From 7c7f64fe6728bd3960bdf6e5870640c2ba32c034 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Sat, 13 Jun 2015 08:04:40 +0545 Subject: [PATCH 05/33] Fixed a failing test and drop table after running tests --- physical/mysql.go | 2 +- physical/mysql_test.go | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/physical/mysql.go b/physical/mysql.go index 8666c249d3..bb4a89c5e5 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -127,7 +127,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { err := m.statements["get"].QueryRow(key).Scan(&result) if err != nil { - return nil, MySQLExecuteStmtFailure + return nil, nil } // Handle a non-existing value diff --git a/physical/mysql_test.go b/physical/mysql_test.go index 9cbb11a5db..2b00404ffb 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -98,4 +98,17 @@ func TestMySQLBackend(t *testing.T) { testBackend(t, b) testBackend_ListPrefix(t, b) + + // Drop table after running tests + drop_stmt := "DROP TABLE " + database + "." + table + stmt, err := db.Prepare(drop_stmt) + if err != nil { + t.Fatalf("Failed to prepare statement: %v", err) + } + defer stmt.Close() + + _, err = stmt.Exec() + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } } From 24b9ef49c16c78363ae85b4078edbe608468ae08 Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Tue, 16 Jun 2015 13:02:15 -0400 Subject: [PATCH 06/33] Accept PUT as well as post to sys/mounts --- http/sys_mount.go | 9 ++++----- http/sys_mount_test.go | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/http/sys_mount.go b/http/sys_mount.go index 36499cfabf..68eae02e41 100644 --- a/http/sys_mount.go +++ b/http/sys_mount.go @@ -13,7 +13,7 @@ func handleSysMounts(core *vault.Core) http.Handler { switch r.Method { case "GET": handleSysListMounts(core).ServeHTTP(w, r) - case "POST": + case "PUT", "POST": fallthrough case "DELETE": handleSysMountUnmount(core, w, r) @@ -27,8 +27,7 @@ func handleSysMounts(core *vault.Core) http.Handler { func handleSysRemount(core *vault.Core) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.Method { - case "POST": - case "PUT": + case "PUT", "POST": default: respondError(w, http.StatusMethodNotAllowed, nil) return @@ -80,7 +79,7 @@ func handleSysListMounts(core *vault.Core) http.Handler { func handleSysMountUnmount(core *vault.Core, w http.ResponseWriter, r *http.Request) { switch r.Method { - case "POST": + case "PUT", "POST": case "DELETE": default: respondError(w, http.StatusMethodNotAllowed, nil) @@ -100,7 +99,7 @@ func handleSysMountUnmount(core *vault.Core, w http.ResponseWriter, r *http.Requ } switch r.Method { - case "POST": + case "PUT", "POST": handleSysMount(core, w, r, path) case "DELETE": handleSysUnmount(core, w, r, path) diff --git a/http/sys_mount_test.go b/http/sys_mount_test.go index c91cd66024..9e7a943405 100644 --- a/http/sys_mount_test.go +++ b/http/sys_mount_test.go @@ -76,6 +76,22 @@ func TestSysMount(t *testing.T) { } } +func TestSysMount_put(t *testing.T) { + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := TestServer(t, core) + defer ln.Close() + TestServerAuth(t, addr, token) + + resp := testHttpPut(t, addr+"/v1/sys/mounts/foo", map[string]interface{}{ + "type": "generic", + "description": "foo", + }) + testResponseStatus(t, resp, 204) + + // The TestSysMount test tests the thing is actually created. See that test + // for more info. +} + func TestSysRemount(t *testing.T) { core, _, token := vault.TestCoreUnsealed(t) ln, addr := TestServer(t, core) From 05fa4a4a48362f815060d442596e2d5bbe0f4769 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 13:31:56 -0700 Subject: [PATCH 07/33] secret/postgres: Ensure sane username length. Fixes #326 --- builtin/logical/postgresql/path_role_create.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/builtin/logical/postgresql/path_role_create.go b/builtin/logical/postgresql/path_role_create.go index 497b9a6c63..7d74148512 100644 --- a/builtin/logical/postgresql/path_role_create.go +++ b/builtin/logical/postgresql/path_role_create.go @@ -2,7 +2,6 @@ package postgresql import ( "fmt" - "math/rand" "time" "github.com/hashicorp/vault/logical" @@ -51,10 +50,15 @@ func (b *backend) pathRoleCreateRead( lease = &configLease{Lease: 1 * time.Hour} } - // Generate the username, password and expiration - username := fmt.Sprintf( - "vault-%s-%d-%d", - req.DisplayName, time.Now().Unix(), rand.Int31n(10000)) + // Generate the username, password and expiration. PG limits user to 63 characters + displayName := req.DisplayName + if len(displayName) > 26 { + displayName = displayName[:26] + } + username := fmt.Sprintf("%s-%s", displayName, generateUUID()) + if len(username) > 63 { + username = username[:63] + } password := generateUUID() expiration := time.Now().UTC(). Add(lease.Lease + time.Duration((float64(lease.Lease) * 0.1))). From 7d05dfeb1f74b5b2116b899fd2efc090fd5ea6e4 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 13:59:09 -0700 Subject: [PATCH 08/33] logical: remove IncrementedLease, simplify ExpirationTime calculation --- logical/lease.go | 18 ++---------------- logical/lease_test.go | 30 ++++++++---------------------- 2 files changed, 10 insertions(+), 38 deletions(-) diff --git a/logical/lease.go b/logical/lease.go index 834242fd26..32ccb65978 100644 --- a/logical/lease.go +++ b/logical/lease.go @@ -46,22 +46,8 @@ func (l *LeaseOptions) LeaseTotal() time.Duration { // ExpirationTime computes the time until expiration including the grace period func (l *LeaseOptions) ExpirationTime() time.Time { var expireTime time.Time - if !l.LeaseIssue.IsZero() && l.Lease > 0 { - expireTime = l.LeaseIssue.UTC().Add(l.LeaseTotal()) + if l.LeaseEnabled() { + expireTime = time.Now().UTC().Add(l.LeaseTotal()) } - return expireTime } - -// IncrementedLease returns the lease duration that would need to set -// in order to increment the _current_ lease by the given duration -// if the auth were re-issued right now. -func (l *LeaseOptions) IncrementedLease(inc time.Duration) time.Duration { - var result time.Duration - expireTime := l.ExpirationTime() - if expireTime.IsZero() { - return result - } - - return expireTime.Add(inc).Sub(time.Now().UTC()) -} diff --git a/logical/lease_test.go b/logical/lease_test.go index ba74c06426..02916bc817 100644 --- a/logical/lease_test.go +++ b/logical/lease_test.go @@ -5,17 +5,6 @@ import ( "time" ) -func TestLeaseOptionsIncrementedLease(t *testing.T) { - var l LeaseOptions - l.Lease = 1 * time.Second - l.LeaseIssue = time.Now().UTC() - - actual := l.IncrementedLease(1 * time.Second) - if actual > 3*time.Second || actual < 1*time.Second { - t.Fatalf("bad: %s", actual) - } -} - func TestLeaseOptionsLeaseTotal(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour @@ -66,12 +55,11 @@ func TestLeaseOptionsLeaseTotal_negGrace(t *testing.T) { func TestLeaseOptionsExpirationTime(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour - l.LeaseIssue = time.Now().UTC() - actual := l.ExpirationTime() - expected := l.LeaseIssue.Add(l.Lease) - if !actual.Equal(expected) { - t.Fatalf("bad: %s", actual) + limit := time.Now().UTC().Add(time.Hour) + exp := l.ExpirationTime() + if exp.Before(limit) { + t.Fatalf("bad: %s", exp) } } @@ -79,11 +67,10 @@ func TestLeaseOptionsExpirationTime_grace(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour l.LeaseGracePeriod = 30 * time.Minute - l.LeaseIssue = time.Now().UTC() + limit := time.Now().UTC().Add(time.Hour + 30*time.Minute) actual := l.ExpirationTime() - expected := l.LeaseIssue.Add(l.Lease + l.LeaseGracePeriod) - if !actual.Equal(expected) { + if actual.Before(limit) { t.Fatalf("bad: %s", actual) } } @@ -92,11 +79,10 @@ func TestLeaseOptionsExpirationTime_graceNegative(t *testing.T) { var l LeaseOptions l.Lease = 1 * time.Hour l.LeaseGracePeriod = -1 * 30 * time.Minute - l.LeaseIssue = time.Now().UTC() + limit := time.Now().UTC().Add(time.Hour) actual := l.ExpirationTime() - expected := l.LeaseIssue.Add(l.Lease) - if !actual.Equal(expected) { + if actual.Before(limit) { t.Fatalf("bad: %s", actual) } } From 2a894171cac21fe7bc4436920fedf682658670b0 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 14:16:44 -0700 Subject: [PATCH 09/33] logical/framework: simplify calculation of lease renew --- logical/framework/lease.go | 56 +++++++++++++++----------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/logical/framework/lease.go b/logical/framework/lease.go index 7203d516da..194701b365 100644 --- a/logical/framework/lease.go +++ b/logical/framework/lease.go @@ -13,7 +13,7 @@ import ( // setting it to 2 hours forces a renewal within the next 2 hours again. // // maxSession is the maximum session length allowed since the original -// issue time. If this is zero, it is ignored,. +// issue time. If this is zero, it is ignored. func LeaseExtend(max, maxSession time.Duration) OperationFunc { return func(req *logical.Request, data *FieldData) (*logical.Response, error) { lease := detectLease(req) @@ -21,55 +21,44 @@ func LeaseExtend(max, maxSession time.Duration) OperationFunc { return nil, fmt.Errorf("no lease options for request") } + // Sanity check the desired increment + switch { + // Protect against negative leases + case lease.LeaseIncrement < 0: + return logical.ErrorResponse( + "increment must be greater than 0"), logical.ErrInvalidRequest + + // If no lease increment, or too large of an increment, use the max + case max > 0 && lease.LeaseIncrement == 0, max > 0 && lease.LeaseIncrement > max: + lease.LeaseIncrement = max + } + + // Get the current time now := time.Now().UTC() // Check if we're passed the issue limit var maxSessionTime time.Time if maxSession > 0 { maxSessionTime = lease.LeaseIssue.Add(maxSession) - if maxSessionTime.Sub(now) <= 0 { + if maxSessionTime.Before(now) { return logical.ErrorResponse(fmt.Sprintf( "lease can only be renewed up to %s past original issue", maxSession)), logical.ErrInvalidRequest } } - // Protect against negative leases - if lease.LeaseIncrement < 0 { - return logical.ErrorResponse( - "increment must be greater than 0"), logical.ErrInvalidRequest - } - - // If the lease is zero, then assume max - if lease.LeaseIncrement == 0 { - lease.LeaseIncrement = max - } - - // If the increment is greater than the amount of time we have left - // on our session, set it to that. - if !maxSessionTime.IsZero() { - diff := maxSessionTime.Sub(lease.ExpirationTime()) - if diff < lease.LeaseIncrement { - lease.LeaseIncrement = diff - } + // The new lease is the minimum of the requested LeaseIncrement + // or the maxSessionTime + requestedLease := now.Add(lease.LeaseIncrement) + if !maxSessionTime.IsZero() && requestedLease.After(maxSessionTime) { + requestedLease = maxSessionTime } // Determine the requested lease - newLease := lease.IncrementedLease(lease.LeaseIncrement) - - if max > 0 { - // Determine if the requested lease is too long - maxExpiration := now.Add(max) - newExpiration := now.Add(newLease) - if newExpiration.Sub(maxExpiration) > 0 { - // The new expiration is past the max expiration. In this - // case, admit the longest lease we can. - newLease = maxExpiration.Sub(lease.ExpirationTime()) - } - } + newLeaseDuration := requestedLease.Sub(now) // Set the lease - lease.Lease = newLease + lease.Lease = newLeaseDuration return &logical.Response{Auth: req.Auth, Secret: req.Secret}, nil } } @@ -80,6 +69,5 @@ func detectLease(req *logical.Request) *logical.LeaseOptions { } else if req.Secret != nil { return &req.Secret.LeaseOptions } - return nil } From daf94d67217466c43c0b02ceda29debdab69204a Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 14:24:12 -0700 Subject: [PATCH 10/33] logical/framework: allow the lease max to come from existing lease --- logical/framework/backend_test.go | 2 +- logical/framework/lease.go | 11 ++++++++++- logical/framework/lease_test.go | 21 ++++++++++++++------- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index dee5d9419d..093e3e5cdf 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -215,7 +215,7 @@ func TestBackendHandleRequest_renew(t *testing.T) { func TestBackendHandleRequest_renewExtend(t *testing.T) { secret := &Secret{ Type: "foo", - Renew: LeaseExtend(0, 0), + Renew: LeaseExtend(0, 0, false), DefaultDuration: 5 * time.Minute, } b := &Backend{ diff --git a/logical/framework/lease.go b/logical/framework/lease.go index 194701b365..4ba250d26f 100644 --- a/logical/framework/lease.go +++ b/logical/framework/lease.go @@ -14,13 +14,22 @@ import ( // // maxSession is the maximum session length allowed since the original // issue time. If this is zero, it is ignored. -func LeaseExtend(max, maxSession time.Duration) OperationFunc { +// +// maxFromLease controls if the maximum renewal period comes from the existing +// lease. This means the value of `max` will be replaced with the existing +// lease duration. +func LeaseExtend(max, maxSession time.Duration, maxFromLease bool) OperationFunc { return func(req *logical.Request, data *FieldData) (*logical.Response, error) { lease := detectLease(req) if lease == nil { return nil, fmt.Errorf("no lease options for request") } + // Check if we should limit max + if maxFromLease { + max = lease.Lease + } + // Sanity check the desired increment switch { // Protect against negative leases diff --git a/logical/framework/lease_test.go b/logical/framework/lease_test.go index ad2fec97da..f22ce798d6 100644 --- a/logical/framework/lease_test.go +++ b/logical/framework/lease_test.go @@ -11,11 +11,12 @@ func TestLeaseExtend(t *testing.T) { now := time.Now().UTC().Round(time.Hour) cases := map[string]struct { - Max time.Duration - MaxSession time.Duration - Request time.Duration - Result time.Duration - Error bool + Max time.Duration + MaxSession time.Duration + Request time.Duration + Result time.Duration + MaxFromLease bool + Error bool }{ "valid request, good bounds": { Max: 30 * time.Hour, @@ -62,20 +63,26 @@ func TestLeaseExtend(t *testing.T) { Request: -7 * time.Hour, Error: true, }, + + "max form lease, request too large": { + Request: 10 * time.Hour, + MaxFromLease: true, + Result: time.Hour, + }, } for name, tc := range cases { req := &logical.Request{ Auth: &logical.Auth{ LeaseOptions: logical.LeaseOptions{ - Lease: 1 * time.Second, + Lease: 1 * time.Hour, LeaseIssue: now, LeaseIncrement: tc.Request, }, }, } - callback := LeaseExtend(tc.Max, tc.MaxSession) + callback := LeaseExtend(tc.Max, tc.MaxSession, tc.MaxFromLease) resp, err := callback(req, nil) if (err != nil) != tc.Error { t.Fatalf("bad: %s\nerr: %s", name, err) From 2b04348e061c0971f390cb4d69c5ca195adad4e7 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 14:28:13 -0700 Subject: [PATCH 11/33] vault: fixing issues with token renewal --- vault/expiration.go | 11 ++--------- vault/token_store.go | 5 ++++- vault/token_store_test.go | 9 ++++++--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/vault/expiration.go b/vault/expiration.go index 601722cf80..294624d6e0 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -337,7 +337,6 @@ func (m *ExpirationManager) RenewToken(source string, token string, // Attach the ClientToken resp.Auth.ClientToken = token resp.Auth.LeaseIncrement = 0 - resp.Auth.LeaseIssue = time.Now().UTC() // Update the lease entry le.Auth = resp.Auth @@ -366,9 +365,6 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons return "", err } - // Setup some of the fields on auth - resp.Secret.LeaseIssue = time.Now().UTC() - // Create a lease entry le := leaseEntry{ LeaseID: path.Join(req.Path, generateUUID()), @@ -376,7 +372,7 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons Path: req.Path, Data: resp.Data, Secret: resp.Secret, - IssueTime: resp.Secret.LeaseIssue, + IssueTime: time.Now().UTC(), ExpireTime: resp.Secret.ExpirationTime(), } @@ -403,16 +399,13 @@ func (m *ExpirationManager) Register(req *logical.Request, resp *logical.Respons func (m *ExpirationManager) RegisterAuth(source string, auth *logical.Auth) error { defer metrics.MeasureSince([]string{"expire", "register-auth"}, time.Now()) - // Setup some of the fields on auth - auth.LeaseIssue = time.Now().UTC() - // Create a lease entry le := leaseEntry{ LeaseID: path.Join(source, m.tokenStore.SaltID(auth.ClientToken)), ClientToken: auth.ClientToken, Auth: auth, Path: source, - IssueTime: auth.LeaseIssue, + IssueTime: time.Now().UTC(), ExpireTime: auth.ExpirationTime(), } diff --git a/vault/token_store.go b/vault/token_store.go index b3ddc889c0..98845aa53c 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -85,7 +85,10 @@ func NewTokenStore(c *Core) (*TokenStore, error) { // Setup the framework endpoints t.Backend = &framework.Backend{ - AuthRenew: framework.LeaseExtend(0, 0), + // Allow a token lease to be extended indefinitely, but each time for only + // as much as the original lease allowed for. If the lease has a 1 hour expiration, + // it can only be extended up to another hour each time this means. + AuthRenew: framework.LeaseExtend(0, 0, true), PathsSpecial: &logical.Paths{ Root: []string{ diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 76b2e97b8e..3c887e233a 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -820,6 +820,7 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { // Get the original expire time to compare originalExpire := auth.ExpirationTime() + beforeRenew := time.Now().UTC() req := logical.TestRequest(t, logical.WriteOperation, "renew/"+root.ID) req.Data["increment"] = "3600" resp, err := ts.HandleRequest(req) @@ -829,9 +830,11 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { // Get the new expire time newExpire := resp.Auth.ExpirationTime() - expireDiff := newExpire.Sub(originalExpire) - if expireDiff < 30*time.Minute || expireDiff > 3*time.Hour { - t.Fatalf("bad: %#v", expireDiff) + if newExpire.Before(originalExpire) { + t.Fatalf("should expire later: %s %s", newExpire, originalExpire) + } + if newExpire.Before(beforeRenew.Add(time.Hour)) { + t.Fatalf("should have at least an hour: %s %s", newExpire, beforeRenew) } } From 28dd283c9387b66cf8d8edf2bf285392db38fe95 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 14:34:11 -0700 Subject: [PATCH 12/33] builtin: fixing API change in logical framework --- builtin/credential/cert/path_login.go | 4 ++-- builtin/credential/ldap/path_login.go | 2 +- builtin/credential/userpass/path_login.go | 2 +- builtin/logical/aws/secret_access_keys.go | 2 +- builtin/logical/consul/secret_token.go | 2 +- builtin/logical/mysql/secret_creds.go | 2 +- builtin/logical/postgresql/secret_creds.go | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 14dbf8212a..2b673271e3 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -61,7 +61,7 @@ func (b *backend) pathLogin( Policies: matched.Entry.Policies, DisplayName: matched.Entry.DisplayName, Metadata: map[string]string{ - "cert_name": matched.Entry.Name, + "cert_name": matched.Entry.Name, "common_name": connState.PeerCertificates[0].Subject.CommonName, }, LeaseOptions: logical.LeaseOptions{ @@ -187,5 +187,5 @@ func (b *backend) pathLoginRenew( return nil, nil } - return framework.LeaseExtend(cert.Lease, 0)(req, d) + return framework.LeaseExtend(cert.Lease, 0, false)(req, d) } diff --git a/builtin/credential/ldap/path_login.go b/builtin/credential/ldap/path_login.go index 5b77f5772e..ad566b771a 100644 --- a/builtin/credential/ldap/path_login.go +++ b/builtin/credential/ldap/path_login.go @@ -77,7 +77,7 @@ func (b *backend) pathLoginRenew( return logical.ErrorResponse("policies have changed, revoking login"), nil } - return framework.LeaseExtend(1*time.Hour, 0)(req, d) + return framework.LeaseExtend(1*time.Hour, 0, false)(req, d) } const pathLoginSyn = ` diff --git a/builtin/credential/userpass/path_login.go b/builtin/credential/userpass/path_login.go index 54b0df1c39..7e427d825c 100644 --- a/builtin/credential/userpass/path_login.go +++ b/builtin/credential/userpass/path_login.go @@ -68,7 +68,7 @@ func (b *backend) pathLoginRenew( return nil, nil } - return framework.LeaseExtend(1*time.Hour, 0)(req, d) + return framework.LeaseExtend(1*time.Hour, 0, false)(req, d) } const pathLoginSyn = ` diff --git a/builtin/logical/aws/secret_access_keys.go b/builtin/logical/aws/secret_access_keys.go index 9ddb4d9a36..5d03b312dc 100644 --- a/builtin/logical/aws/secret_access_keys.go +++ b/builtin/logical/aws/secret_access_keys.go @@ -115,7 +115,7 @@ func (b *backend) secretAccessKeysRenew( lease = &configLease{Lease: 1 * time.Hour} } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) return f(req, d) } diff --git a/builtin/logical/consul/secret_token.go b/builtin/logical/consul/secret_token.go index 06679b1313..e6e83de6b0 100644 --- a/builtin/logical/consul/secret_token.go +++ b/builtin/logical/consul/secret_token.go @@ -26,7 +26,7 @@ func secretToken() *framework.Secret { DefaultDuration: DefaultLeaseDuration, DefaultGracePeriod: DefaultGracePeriod, - Renew: framework.LeaseExtend(1*time.Hour, 0), + Renew: framework.LeaseExtend(0, 0, true), Revoke: secretTokenRevoke, } } diff --git a/builtin/logical/mysql/secret_creds.go b/builtin/logical/mysql/secret_creds.go index 5bed159764..c60beb0add 100644 --- a/builtin/logical/mysql/secret_creds.go +++ b/builtin/logical/mysql/secret_creds.go @@ -44,7 +44,7 @@ func (b *backend) secretCredsRenew( lease = &configLease{Lease: 1 * time.Hour} } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) return f(req, d) } diff --git a/builtin/logical/postgresql/secret_creds.go b/builtin/logical/postgresql/secret_creds.go index 00714eb6f7..204daa073c 100644 --- a/builtin/logical/postgresql/secret_creds.go +++ b/builtin/logical/postgresql/secret_creds.go @@ -58,7 +58,7 @@ func (b *backend) secretCredsRenew( lease = &configLease{Lease: 1 * time.Hour} } - f := framework.LeaseExtend(lease.Lease, lease.LeaseMax) + f := framework.LeaseExtend(lease.Lease, lease.LeaseMax, false) resp, err := f(req, d) if err != nil { return nil, err From 0bd806a5860aef787da9a0dd39c5830bbf5b0c83 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 15:22:50 -0700 Subject: [PATCH 13/33] vault: ensure token renew does not double register --- vault/core.go | 5 +++-- vault/core_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/vault/core.go b/vault/core.go index 723fe23796..9791412c08 100644 --- a/vault/core.go +++ b/vault/core.go @@ -413,8 +413,9 @@ func (c *Core) handleRequest(req *logical.Request) (*logical.Response, error) { } // Only the token store is allowed to return an auth block, for any - // other request this is an internal error - if resp != nil && resp.Auth != nil { + // other request this is an internal error. We exclude renewal of a token, + // since it does not need to be re-registered + if resp != nil && resp.Auth != nil && !strings.HasPrefix(req.Path, "auth/token/renew/") { if !strings.HasPrefix(req.Path, "auth/token/") { c.logger.Printf( "[ERR] core: unexpected Auth response for non-token backend "+ diff --git a/vault/core_test.go b/vault/core_test.go index c27f03460d..25cebf1943 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -1368,6 +1368,55 @@ func TestCore_RenewSameLease(t *testing.T) { } } +// Renew of a token should not create a new lease +func TestCore_RenewToken_SingleRegister(t *testing.T) { + c, _, root := TestCoreUnsealed(t) + + // Create a new token + req := &logical.Request{ + Operation: logical.WriteOperation, + Path: "auth/token/create", + Data: map[string]interface{}{ + "lease": "1h", + }, + ClientToken: root, + } + resp, err := c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + newClient := resp.Auth.ClientToken + + // Renew the token + req = logical.TestRequest(t, logical.WriteOperation, "auth/token/renew/"+newClient) + req.ClientToken = newClient + resp, err = c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Revoke using the renew prefix + req = logical.TestRequest(t, logical.WriteOperation, "sys/revoke-prefix/auth/token/renew/") + req.ClientToken = root + resp, err = c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Verify our token is still valid (e.g. we did not get invalided by the revoke) + req = logical.TestRequest(t, logical.ReadOperation, "auth/token/lookup/"+newClient) + req.ClientToken = newClient + resp, err = c.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Verify the token exists + if resp.Data["id"] != newClient { + t.Fatalf("bad: %#v", resp.Data) + } +} + // Based on bug GH-203, attempt to disable a credential backend with leased secrets func TestCore_EnableDisableCred_WithLease(t *testing.T) { // Create a badass credential backend that always logs in as armon From dcb45874bfd1c8042762fc33b78f3d80fe4ce905 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 15:56:26 -0700 Subject: [PATCH 14/33] logical/framework: adding a new duration type to convert to seconds --- logical/framework/backend.go | 2 ++ logical/framework/backend_test.go | 10 +++++++ logical/framework/field_data.go | 43 ++++++++++++++++++++++----- logical/framework/field_data_test.go | 44 ++++++++++++++++++++++++++++ logical/framework/field_type.go | 6 ++++ 5 files changed, 98 insertions(+), 7 deletions(-) diff --git a/logical/framework/backend.go b/logical/framework/backend.go index b53c9f21b4..042c0e70e3 100644 --- a/logical/framework/backend.go +++ b/logical/framework/backend.go @@ -373,6 +373,8 @@ func (t FieldType) Zero() interface{} { return false case TypeMap: return map[string]interface{}{} + case TypeDurationSecond: + return 0 default: panic("unknown type: " + t.String()) } diff --git a/logical/framework/backend_test.go b/logical/framework/backend_test.go index 093e3e5cdf..4058d194de 100644 --- a/logical/framework/backend_test.go +++ b/logical/framework/backend_test.go @@ -508,6 +508,16 @@ func TestFieldSchemaDefaultOrZero(t *testing.T) { &FieldSchema{Type: TypeString}, "", }, + + "default duration set": { + &FieldSchema{Type: TypeDurationSecond, Default: 60}, + 60, + }, + + "default duration not set": { + &FieldSchema{Type: TypeDurationSecond}, + 0, + }, } for name, tc := range cases { diff --git a/logical/framework/field_data.go b/logical/framework/field_data.go index e8255c8c5a..40d1ac182a 100644 --- a/logical/framework/field_data.go +++ b/logical/framework/field_data.go @@ -2,6 +2,9 @@ package framework import ( "fmt" + "strconv" + "strings" + "time" "github.com/mitchellh/mapstructure" ) @@ -64,13 +67,7 @@ func (d *FieldData) GetOkErr(k string) (interface{}, bool, error) { } switch schema.Type { - case TypeBool: - fallthrough - case TypeInt: - fallthrough - case TypeMap: - fallthrough - case TypeString: + case TypeBool, TypeInt, TypeMap, TypeDurationSecond, TypeString: return d.getPrimitive(k, schema) default: return nil, false, @@ -114,6 +111,38 @@ func (d *FieldData) getPrimitive( } return result, true, nil + + case TypeDurationSecond: + var result int + switch inp := raw.(type) { + case int: + result = inp + case float32: + result = int(inp) + case float64: + result = int(inp) + case string: + // Look for a suffix otherwise its a plain second value + if strings.HasSuffix(inp, "s") || strings.HasSuffix(inp, "m") || strings.HasSuffix(inp, "h") { + dur, err := time.ParseDuration(inp) + if err != nil { + return nil, true, err + } + result = int(dur.Seconds()) + } else { + // Plain integer + val, err := strconv.ParseInt(inp, 10, 64) + if err != nil { + return nil, true, err + } + result = int(val) + } + + default: + return nil, false, fmt.Errorf("invalid input '%v'", raw) + } + return result, true, nil + default: panic(fmt.Sprintf("Unknown type: %s", schema.Type)) } diff --git a/logical/framework/field_data_test.go b/logical/framework/field_data_test.go index 000ded72a5..e6a32f8cb1 100644 --- a/logical/framework/field_data_test.go +++ b/logical/framework/field_data_test.go @@ -91,6 +91,50 @@ func TestFieldDataGet(t *testing.T) { "child": true, }, }, + + "duration type, string value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": "42", + }, + "foo", + 42, + }, + + "duration type, string duration value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": "42m", + }, + "foo", + 2520, + }, + + "duration type, int value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": 42, + }, + "foo", + 42, + }, + + "duration type, float value": { + map[string]*FieldSchema{ + "foo": &FieldSchema{Type: TypeDurationSecond}, + }, + map[string]interface{}{ + "foo": 42.0, + }, + "foo", + 42, + }, } for name, tc := range cases { diff --git a/logical/framework/field_type.go b/logical/framework/field_type.go index a02b77bcd8..d9d0ef3d24 100644 --- a/logical/framework/field_type.go +++ b/logical/framework/field_type.go @@ -9,6 +9,10 @@ const ( TypeInt TypeBool TypeMap + + // TypeDurationSecond represent as seconds, this can be either an + // integer or go duration format string (e.g. 24h) + TypeDurationSecond ) func (t FieldType) String() string { @@ -21,6 +25,8 @@ func (t FieldType) String() string { return "bool" case TypeMap: return "map" + case TypeDurationSecond: + return "duration (sec)" default: return "unknown type" } From 81df0d6e4971bcbcb2aa74e680248dbf4aca15ad Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 15:58:20 -0700 Subject: [PATCH 15/33] vault: allow increment to be duration string. Fixes #340 --- vault/logical_system.go | 2 +- vault/logical_system_test.go | 2 +- vault/token_store.go | 2 +- vault/token_store_test.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vault/logical_system.go b/vault/logical_system.go index 9769a882bd..6673226d0a 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -110,7 +110,7 @@ func NewSystemBackend(core *Core) logical.Backend { Description: strings.TrimSpace(sysHelp["lease_id"][0]), }, "increment": &framework.FieldSchema{ - Type: framework.TypeInt, + Type: framework.TypeDurationSecond, Description: strings.TrimSpace(sysHelp["increment"][0]), }, }, diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index b3ba308b62..fd23f59aa4 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -181,7 +181,7 @@ func TestSystemBackend_renew(t *testing.T) { // Attempt renew req2 := logical.TestRequest(t, logical.WriteOperation, "renew/"+resp.Secret.LeaseID) - req2.Data["increment"] = 100 + req2.Data["increment"] = "100s" resp2, err := b.HandleRequest(req2) if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) diff --git a/vault/token_store.go b/vault/token_store.go index 98845aa53c..03f278c83c 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -211,7 +211,7 @@ func NewTokenStore(c *Core) (*TokenStore, error) { Description: "Token to renew", }, "increment": &framework.FieldSchema{ - Type: framework.TypeInt, + Type: framework.TypeDurationSecond, Description: "The desired increment in seconds to the token expiration", }, }, diff --git a/vault/token_store_test.go b/vault/token_store_test.go index 3c887e233a..2abea5cbf9 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -822,7 +822,7 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { beforeRenew := time.Now().UTC() req := logical.TestRequest(t, logical.WriteOperation, "renew/"+root.ID) - req.Data["increment"] = "3600" + req.Data["increment"] = "3600s" resp, err := ts.HandleRequest(req) if err != nil { t.Fatalf("err: %v %v", err, resp) From d19b74f78f7cf3a99c3ba562fa990aa6089fa67c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 16:59:50 -0700 Subject: [PATCH 16/33] command/token-create: provide more useful output. Fixes #337 --- command/format.go | 10 ++++++++++ command/token_create.go | 9 +++++++-- command/token_create_test.go | 7 +++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/command/format.go b/command/format.go index d3fad141a4..0b6aa7d968 100644 --- a/command/format.go +++ b/command/format.go @@ -53,6 +53,16 @@ func outputFormatTable(ui cli.Ui, s *api.Secret, whitespace bool) int { "lease_renewable %s %s", config.Delim, strconv.FormatBool(s.Renewable))) } + if s.Auth != nil { + input = append(input, fmt.Sprintf("token %s %s", config.Delim, s.Auth.ClientToken)) + input = append(input, fmt.Sprintf("token_duration %s %d", config.Delim, s.Auth.LeaseDuration)) + input = append(input, fmt.Sprintf("token_renewable %s %v", config.Delim, s.Auth.Renewable)) + input = append(input, fmt.Sprintf("token_policies %s %v", config.Delim, s.Auth.Policies)) + for k, v := range s.Auth.Metadata { + input = append(input, fmt.Sprintf("token_meta_%s %s %#v", k, config.Delim, v)) + } + } + for k, v := range s.Data { input = append(input, fmt.Sprintf("%s %s %v", k, config.Delim, v)) } diff --git a/command/token_create.go b/command/token_create.go index 85107e3fde..21d85d92b5 100644 --- a/command/token_create.go +++ b/command/token_create.go @@ -15,12 +15,14 @@ type TokenCreateCommand struct { } func (c *TokenCreateCommand) Run(args []string) int { + var format string var displayName, lease string var orphan bool var metadata map[string]string var numUses int var policies []string flags := c.Meta.FlagSet("mount", FlagSetDefault) + flags.StringVar(&format, "format", "table", "") flags.StringVar(&displayName, "display-name", "", "") flags.StringVar(&lease, "lease", "", "") flags.BoolVar(&orphan, "orphan", false, "") @@ -61,8 +63,7 @@ func (c *TokenCreateCommand) Run(args []string) int { return 2 } - c.Ui.Output(secret.Auth.ClientToken) - return 0 + return OutputSecret(c.Ui, format, secret) } func (c *TokenCreateCommand) Synopsis() string { @@ -121,6 +122,10 @@ Token Options: -use-limit=5 The number of times this token can be used until it is automatically revoked. + + -format=table The format for output. By default it is a whitespace- + delimited table. This can also be json. + ` return strings.TrimSpace(helpText) } diff --git a/command/token_create_test.go b/command/token_create_test.go index 2b659165c5..93482bad1e 100644 --- a/command/token_create_test.go +++ b/command/token_create_test.go @@ -1,6 +1,7 @@ package command import ( + "strings" "testing" "github.com/hashicorp/vault/http" @@ -27,4 +28,10 @@ func TestTokenCreate(t *testing.T) { if code := c.Run(args); code != 0 { t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) } + + // Ensure we get lease info + output := ui.OutputWriter.String() + if !strings.Contains(output, "token_duration") { + t.Fatalf("bad: %#v", output) + } } From c60889572ea7232775ff53d979904e0a7646cb6a Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:23:59 -0700 Subject: [PATCH 17/33] vault: support core shutdown --- vault/core.go | 23 +++++++++++++++++++++++ vault/core_test.go | 11 +++++++++++ 2 files changed, 34 insertions(+) diff --git a/vault/core.go b/vault/core.go index 9791412c08..3c67f63fad 100644 --- a/vault/core.go +++ b/vault/core.go @@ -328,6 +328,21 @@ func NewCore(conf *CoreConfig) (*Core, error) { return c, nil } +// Shutdown is invoked when the Vault instance is about to be terminated. It +// should not be accessible as part of an API call as it will cause an availability +// problem. It is only used to gracefully quit in the case of HA so that failover +// happens as quickly as possible. +func (c *Core) Shutdown() error { + c.stateLock.Lock() + defer c.stateLock.Unlock() + if c.sealed { + return nil + } + + // Seal the Vault, causes a leader stepdown + return c.sealInternal() +} + // HandleRequest is used to handle a new incoming request func (c *Core) HandleRequest(req *logical.Request) (resp *logical.Response, err error) { c.stateLock.RLock() @@ -930,6 +945,14 @@ func (c *Core) Seal(token string) error { return err } + // Seal the Vault + return c.sealInternal() +} + +// sealInternal is an internal method used to seal the vault. +// It does not do any authorization checking. The stateLock must +// be held prior to calling. +func (c *Core) sealInternal() error { // Enable that we are sealed to prevent furthur transactions c.sealed = true diff --git a/vault/core_test.go b/vault/core_test.go index 25cebf1943..4b48c59acd 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -348,6 +348,17 @@ func TestCore_SealUnseal(t *testing.T) { } } +// Attempt to shutdown after unseal +func TestCore_Shutdown(t *testing.T) { + c, _, _ := TestCoreUnsealed(t) + if err := c.Shutdown(); err != nil { + t.Fatalf("err: %v", err) + } + if sealed, err := c.Sealed(); err != nil || !sealed { + t.Fatalf("err: %v", err) + } +} + // Attempt to seal bad token func TestCore_Seal_BadToken(t *testing.T) { c, _, _ := TestCoreUnsealed(t) From 70ee1866ca35f60a209fd88a48f18b02b040cc2c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:24:56 -0700 Subject: [PATCH 18/33] server: graceful shutdown for fast failover. Fixes #308 --- cli/commands.go | 20 ++++++++++++++++++++ command/server.go | 14 +++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/cli/commands.go b/cli/commands.go index 1a7916c742..074b80dd11 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -2,6 +2,8 @@ package cli import ( "os" + "os/signal" + "syscall" auditFile "github.com/hashicorp/vault/builtin/audit/file" auditSyslog "github.com/hashicorp/vault/builtin/audit/syslog" @@ -68,6 +70,7 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { "transit": transit.Factory, "mysql": mysql.Factory, }, + ShutdownCh: makeShutdownCh(), }, nil }, @@ -268,3 +271,20 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { }, } } + +// makeShutdownCh returns a channel that can be used for shutdown +// notifications for commands. This channel will send a message for every +// interrupt or SIGTERM received. +func makeShutdownCh() <-chan struct{} { + resultCh := make(chan struct{}) + + signalCh := make(chan os.Signal, 4) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + go func() { + for { + <-signalCh + resultCh <- struct{}{} + } + }() + return resultCh +} diff --git a/command/server.go b/command/server.go index 2796b22cda..ed842a26f8 100644 --- a/command/server.go +++ b/command/server.go @@ -32,6 +32,7 @@ type ServerCommand struct { CredentialBackends map[string]logical.Factory LogicalBackends map[string]logical.Factory + ShutdownCh <-chan struct{} Meta } @@ -237,7 +238,14 @@ func (c *ServerCommand) Run(args []string) int { // Release the log gate. logGate.Flush() - <-make(chan struct{}) + // Wait for shutdown + select { + case <-c.ShutdownCh: + c.Ui.Output("==> Vault shutdown triggered") + if err := core.Shutdown(); err != nil { + c.Ui.Error(fmt.Sprintf("Error with core shutdown: %s", err)) + } + } return 0 } @@ -407,8 +415,8 @@ General Options: specified multiple times. If it is a directory, all files with a ".hcl" or ".json" suffix will be loaded. - -dev Enables Dev mode. In this mode, Vault is completely - in-memory and unsealed. Do not run the Dev server in + -dev Enables Dev mode. In this mode, Vault is completely + in-memory and unsealed. Do not run the Dev server in production! -log-level=info Log verbosity. Defaults to "info", will be outputted From 0277cedc8acae3e892d654750d8bc1940c1560d4 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:33:15 -0700 Subject: [PATCH 19/33] cmomand/read: strip path prefix if necessary. Fixes #343 --- command/read.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/command/read.go b/command/read.go index 8f94bcba8f..85d23b73e3 100644 --- a/command/read.go +++ b/command/read.go @@ -27,7 +27,11 @@ func (c *ReadCommand) Run(args []string) int { flags.Usage() return 1 } + path := args[0] + if path[0] == '/' { + path = path[1:] + } client, err := c.Client() if err != nil { @@ -98,7 +102,7 @@ Read Options: -format=table The format for output. By default it is a whitespace- delimited table. This can also be json. - -field=field If included, the raw value of the specified field + -field=field If included, the raw value of the specified field will be output raw to stdout. ` From 9238c6def36195a99acf210a1b60b699a6099d92 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:42:23 -0700 Subject: [PATCH 20/33] secret/transit: Use special endpoint to get underlying keys. Fixes #219 --- builtin/logical/transit/backend.go | 2 + builtin/logical/transit/backend_test.go | 39 ++++++++++++++++++ builtin/logical/transit/path_keys.go | 1 - builtin/logical/transit/path_raw.go | 54 +++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 builtin/logical/transit/path_raw.go diff --git a/builtin/logical/transit/backend.go b/builtin/logical/transit/backend.go index b63f903b13..d719d42c64 100644 --- a/builtin/logical/transit/backend.go +++ b/builtin/logical/transit/backend.go @@ -15,11 +15,13 @@ func Backend() *framework.Backend { PathsSpecial: &logical.Paths{ Root: []string{ "keys/*", + "raw/*", }, }, Paths: []*framework.Path{ pathKeys(), + pathRaw(), pathEncrypt(), pathDecrypt(), }, diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index c4c7c3d435..cb2a5e1872 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -21,10 +21,12 @@ func TestBackend_basic(t *testing.T) { Steps: []logicaltest.TestStep{ testAccStepWritePolicy(t, "test"), testAccStepReadPolicy(t, "test", false), + testAccStepReadRaw(t, "test", false), testAccStepEncrypt(t, "test", testPlaintext, decryptData), testAccStepDecrypt(t, "test", testPlaintext, decryptData), testAccStepDeletePolicy(t, "test"), testAccStepReadPolicy(t, "test", true), + testAccStepReadRaw(t, "test", true), }, }) } @@ -65,6 +67,43 @@ func testAccStepReadPolicy(t *testing.T, name string, expectNone bool) logicalte return err } + if d.Name != name { + return fmt.Errorf("bad: %#v", d) + } + if d.CipherMode != "aes-gcm" { + return fmt.Errorf("bad: %#v", d) + } + // Should NOT get a key back + if d.Key != nil { + return fmt.Errorf("bad: %#v", d) + } + return nil + }, + } +} + +func testAccStepReadRaw(t *testing.T, name string, expectNone bool) logicaltest.TestStep { + return logicaltest.TestStep{ + Operation: logical.ReadOperation, + Path: "raw/" + name, + Check: func(resp *logical.Response) error { + if resp == nil && !expectNone { + return fmt.Errorf("missing response") + } else if expectNone { + if resp != nil { + return fmt.Errorf("response when expecting none") + } + return nil + } + var d struct { + Name string `mapstructure:"name"` + Key []byte `mapstructure:"key"` + CipherMode string `mapstructure:"cipher_mode"` + } + if err := mapstructure.Decode(resp.Data, &d); err != nil { + return err + } + if d.Name != name { return fmt.Errorf("bad: %#v", d) } diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index b856964689..9b2424004a 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -124,7 +124,6 @@ func pathPolicyRead( resp := &logical.Response{ Data: map[string]interface{}{ "name": p.Name, - "key": p.Key, "cipher_mode": p.CipherMode, }, } diff --git a/builtin/logical/transit/path_raw.go b/builtin/logical/transit/path_raw.go new file mode 100644 index 0000000000..ebe411a571 --- /dev/null +++ b/builtin/logical/transit/path_raw.go @@ -0,0 +1,54 @@ +package transit + +import ( + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +func pathRaw() *framework.Path { + return &framework.Path{ + Pattern: `raw/(?P\w+)`, + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + Description: "Name of the key", + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: pathRawRead, + }, + + HelpSynopsis: pathPolicyHelpSyn, + HelpDescription: pathPolicyHelpDesc, + } +} + +func pathRawRead( + req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + name := d.Get("name").(string) + p, err := getPolicy(req, name) + if err != nil { + return nil, err + } + if p == nil { + return nil, nil + } + + // Return the response + resp := &logical.Response{ + Data: map[string]interface{}{ + "name": p.Name, + "key": p.Key, + "cipher_mode": p.CipherMode, + }, + } + return resp, nil +} + +const pathRawHelpSyn = `Fetch raw keys for named encrption keys` + +const pathRawHelpDesc = ` +This path is used to get the underlying encryption keys used for the +named keys that are available. +` From 7c31e29295aa493a4c08b0e2e20865a709d1c98f Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:45:29 -0700 Subject: [PATCH 21/33] website: update the transit documentation --- .../source/docs/secrets/transit/index.html.md | 61 +++++++++++++++---- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/website/source/docs/secrets/transit/index.html.md b/website/source/docs/secrets/transit/index.html.md index 3506eca03a..ae04b2794b 100644 --- a/website/source/docs/secrets/transit/index.html.md +++ b/website/source/docs/secrets/transit/index.html.md @@ -54,6 +54,15 @@ $ vault read transit/keys/foo Key Value name foo cipher_mode aes-gcm +```` + +We can read from the `raw/` endpoint to see the encryption key itself: + +``` +$ vault read transit/raw/foo +Key Value +name foo +cipher_mode aes-gcm key PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8= ```` @@ -114,17 +123,7 @@ only encrypt or decrypt using the named keys they need access to.
Returns
- - ```javascript - { - "data": { - "name": "foo", - "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" - } - } - ``` - + A `204` response code.
@@ -156,7 +155,6 @@ only encrypt or decrypt using the named keys they need access to. "data": { "name": "foo", "cipher_mode": "aes-gcm", - "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" } } ``` @@ -269,3 +267,42 @@ only encrypt or decrypt using the named keys they need access to. + +### /transit/raw/ +#### GET + +
+
Description
+
+ Returns raw information about a named encryption key, + Including the underlying encryption key. This is a root protected endpoint. +
+ +
Method
+
GET
+ +
URL
+
`/transit/raw/`
+ +
Parameters
+
+ None +
+ +
Returns
+
+ + ```javascript + { + "data": { + "name": "foo", + "cipher_mode": "aes-gcm", + "key": "PhKFTALCmhAhVQfMBAH4+UwJ6J2gybapUH9BsrtIgR8=" + } + } + ``` + +
+
+ + From 96119946f348e3d6ab7c6acedf9c00e46cf59086 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:51:05 -0700 Subject: [PATCH 22/33] secret/transit: allow policies to be upserted --- builtin/logical/transit/backend_test.go | 15 ++++++ builtin/logical/transit/path_encrypt.go | 6 ++- builtin/logical/transit/path_keys.go | 66 ++++++++++++++----------- 3 files changed, 58 insertions(+), 29 deletions(-) diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index cb2a5e1872..c03ab0e326 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -31,6 +31,21 @@ func TestBackend_basic(t *testing.T) { }) } +func TestBackend_upsert(t *testing.T) { + decryptData := make(map[string]interface{}) + logicaltest.Test(t, logicaltest.TestCase{ + Backend: Backend(), + Steps: []logicaltest.TestStep{ + testAccStepReadPolicy(t, "test", true), + testAccStepEncrypt(t, "test", testPlaintext, decryptData), + testAccStepReadPolicy(t, "test", false), + testAccStepDecrypt(t, "test", testPlaintext, decryptData), + testAccStepDeletePolicy(t, "test"), + testAccStepReadPolicy(t, "test", true), + }, + }) +} + func testAccStepWritePolicy(t *testing.T, name string) logicaltest.TestStep { return logicaltest.TestStep{ Operation: logical.WriteOperation, diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index d30d72f3b4..761af7e352 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -5,6 +5,7 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" + "fmt" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -56,7 +57,10 @@ func pathEncryptWrite( // Error if invalid policy if p == nil { - return logical.ErrorResponse("policy not found"), logical.ErrInvalidRequest + p, err = generatePolicy(req.Storage, name) + if err != nil { + return logical.ErrorResponse(fmt.Sprintf("failed to upsert policy: %v", err)), logical.ErrInvalidRequest + } } // Guard against a potentially invalid cipher-mode diff --git a/builtin/logical/transit/path_keys.go b/builtin/logical/transit/path_keys.go index 9b2424004a..f83ae4037e 100644 --- a/builtin/logical/transit/path_keys.go +++ b/builtin/logical/transit/path_keys.go @@ -45,6 +45,41 @@ func getPolicy(req *logical.Request, name string) (*Policy, error) { return p, nil } +// generatePolicy is used to create a new named policy with +// a randomly generated key +func generatePolicy(storage logical.Storage, name string) (*Policy, error) { + // Create the policy object + p := &Policy{ + Name: name, + CipherMode: "aes-gcm", + } + + // Generate a 256bit key + p.Key = make([]byte, 32) + _, err := rand.Read(p.Key) + if err != nil { + return nil, err + } + + // Encode the policy + buf, err := p.Serialize() + if err != nil { + return nil, err + } + + // Write the policy into storage + err = storage.Put(&logical.StorageEntry{ + Key: "policy/" + name, + Value: buf, + }) + if err != nil { + return nil, err + } + + // Return the policy + return p, nil +} + func pathKeys() *framework.Path { return &framework.Path{ Pattern: `keys/(?P\w+)`, @@ -79,34 +114,9 @@ func pathPolicyWrite( return nil, nil } - // Create the policy object - p := &Policy{ - Name: name, - CipherMode: "aes-gcm", - } - - // Generate a 256bit key - p.Key = make([]byte, 32) - _, err = rand.Read(p.Key) - if err != nil { - return nil, err - } - - // Encode the policy - buf, err := p.Serialize() - if err != nil { - return nil, err - } - - // Write the policy into storage - err = req.Storage.Put(&logical.StorageEntry{ - Key: "policy/" + name, - Value: buf, - }) - if err != nil { - return nil, err - } - return nil, nil + // Generate the policy + _, err = generatePolicy(req.Storage, name) + return nil, err } func pathPolicyRead( From ba24d891fd907fff44ec86fb36b5d0c345840c50 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 18:51:58 -0700 Subject: [PATCH 23/33] website: document transit upsert behavior --- website/source/docs/secrets/transit/index.html.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/website/source/docs/secrets/transit/index.html.md b/website/source/docs/secrets/transit/index.html.md index ae04b2794b..743ba0662e 100644 --- a/website/source/docs/secrets/transit/index.html.md +++ b/website/source/docs/secrets/transit/index.html.md @@ -194,7 +194,9 @@ only encrypt or decrypt using the named keys they need access to.
Description
- Encrypts the provided plaintext using the named key. + Encrypts the provided plaintext using the named key. If the named key + does not already exist, it will be automatically generated for the given + name with the default parameters.
Method
From ee176b2f5de65274deba8ec525a90755f7c59605 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Wed, 17 Jun 2015 19:19:02 -0700 Subject: [PATCH 24/33] command/auth: warn about the VAULT_TOKEN env var. Fixes #195 --- command/auth.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/command/auth.go b/command/auth.go index 130dd7d278..46ea4657a5 100644 --- a/command/auth.go +++ b/command/auth.go @@ -173,6 +173,15 @@ func (c *AuthCommand) Run(args []string) int { return 1 } + // Warn if the VAULT_TOKEN environment variable is set, as that will take + // precedence + if os.Getenv("VAULT_TOKEN") != "" { + c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") + c.Ui.Output(" The environment variable takes precedence over the value") + c.Ui.Output(" set by the auth command. Either update the value of the") + c.Ui.Output(" environment variable or unset it to use the new token.\n") + } + // Get the policies we have policiesRaw, ok := secret.Data["policies"] if !ok { From 0696bc47e0c2e128b8cf89dea2826a1ac2de5ebc Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 13:48:04 -0700 Subject: [PATCH 25/33] command/auth: warn earlier about VAULT_TOKEN --- command/auth.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/command/auth.go b/command/auth.go index 46ea4657a5..f31b2a358a 100644 --- a/command/auth.go +++ b/command/auth.go @@ -114,6 +114,15 @@ func (c *AuthCommand) Run(args []string) int { return 0 } + // Warn if the VAULT_TOKEN environment variable is set, as that will take + // precedence + if os.Getenv("VAULT_TOKEN") != "" { + c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") + c.Ui.Output(" The environment variable takes precedence over the value") + c.Ui.Output(" set by the auth command. Either update the value of the") + c.Ui.Output(" environment variable or unset it to use the new token.\n") + } + var vars map[string]string if len(args) > 0 { builder := kvbuilder.Builder{Stdin: os.Stdin} @@ -173,15 +182,6 @@ func (c *AuthCommand) Run(args []string) int { return 1 } - // Warn if the VAULT_TOKEN environment variable is set, as that will take - // precedence - if os.Getenv("VAULT_TOKEN") != "" { - c.Ui.Output("==> WARNING: VAULT_TOKEN environment variable set!\n") - c.Ui.Output(" The environment variable takes precedence over the value") - c.Ui.Output(" set by the auth command. Either update the value of the") - c.Ui.Output(" environment variable or unset it to use the new token.\n") - } - // Get the policies we have policiesRaw, ok := secret.Data["policies"] if !ok { From 57d1230e6c38fba51d8d29904ef25c07eed7bfba Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 13:48:18 -0700 Subject: [PATCH 26/33] command/server: fixing output weirdness --- command/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/command/server.go b/command/server.go index ed842a26f8..f3db676d76 100644 --- a/command/server.go +++ b/command/server.go @@ -155,7 +155,7 @@ func (c *ServerCommand) Run(args []string) int { "immediately begin using the Vault CLI.\n\n"+ "The only step you need to take is to set the following\n"+ "environment variables:\n\n"+ - " export VAULT_ADDR='http://127.0.0.1:8200'\n"+ + " export VAULT_ADDR='http://127.0.0.1:8200'\n\n"+ "The unseal key and root token are reproduced below in case you\n"+ "want to seal/unseal the Vault or play with authentication.\n\n"+ "Unseal Key: %s\nRoot Token: %s\n", From 27728075478ec48070fb9b0f1670b6e108cbeedb Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 13:51:06 -0700 Subject: [PATCH 27/33] command/write: adding force flag for when no data fields are necessary. Fixes #357 --- command/write.go | 11 ++++++++++- command/write_test.go | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/command/write.go b/command/write.go index fefa23e72d..ba0fdd823d 100644 --- a/command/write.go +++ b/command/write.go @@ -19,15 +19,18 @@ type WriteCommand struct { func (c *WriteCommand) Run(args []string) int { var format string + var force bool flags := c.Meta.FlagSet("write", FlagSetDefault) flags.StringVar(&format, "format", "table", "") + flags.BoolVar(&force, "force", false, "") + flags.BoolVar(&force, "f", false, "") flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { return 1 } args = flags.Args() - if len(args) < 2 { + if len(args) < 2 && !force { c.Ui.Error("write expects at least two arguments") flags.Usage() return 1 @@ -117,6 +120,12 @@ General Options: not recommended. This is especially not recommended for unsealing a vault. +Write Options: + + -f | -force Force the write to continue without any data values + specified. This allows writing to keys that do not + need or expect any fields to be specified. + ` return strings.TrimSpace(helpText) } diff --git a/command/write_test.go b/command/write_test.go index ce570f3730..51774e3c0b 100644 --- a/command/write_test.go +++ b/command/write_test.go @@ -246,3 +246,26 @@ func TestWrite_Output(t *testing.T) { t.Fatalf("bad: %s", string(ui.OutputWriter.Bytes())) } } + +func TestWrite_force(t *testing.T) { + core, _, token := vault.TestCoreUnsealed(t) + ln, addr := http.TestServer(t, core) + defer ln.Close() + + ui := new(cli.MockUi) + c := &WriteCommand{ + Meta: Meta{ + ClientToken: token, + Ui: ui, + }, + } + + args := []string{ + "-address", addr, + "-force", + "sys/rotate", + } + if code := c.Run(args); code != 0 { + t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) + } +} From 46ba8d10a555b8d5287fd0b47812574b036e2e06 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 14:31:00 -0700 Subject: [PATCH 28/33] physical/mysql: cleanup and documentation --- physical/mysql.go | 163 ++++++++++------------- physical/mysql_test.go | 77 ++--------- website/source/docs/config/index.html.md | 17 +++ 3 files changed, 92 insertions(+), 165 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index bb4a89c5e5..5db23b9384 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -2,7 +2,6 @@ package physical import ( "database/sql" - "errors" "fmt" "sort" "strings" @@ -12,16 +11,10 @@ import ( _ "github.com/go-sql-driver/mysql" ) -var ( - MySQLPrepareStmtFailure = errors.New("failed to prepare statement") - MySQLExecuteStmtFailure = errors.New("failed to execute statement") -) - // MySQLBackend is a physical backend that stores data // within MySQL database. type MySQLBackend struct { - table string - database string + dbTable string client *sql.DB statements map[string]*sql.Stmt } @@ -29,84 +22,85 @@ type MySQLBackend struct { // newMySQLBackend constructs a MySQL backend using the given API client and // server address and credential for accessing mysql database. func newMySQLBackend(conf map[string]string) (Backend, error) { + // Get the MySQL credentials to perform read/write operations. + username, ok := conf["username"] + if !ok || username == "" { + return nil, fmt.Errorf("missing username") + } + password, ok := conf["password"] + if !ok || username == "" { + return nil, fmt.Errorf("missing password") + } + // Get or set MySQL server address. Defaults to localhost and default port(3306) address, ok := conf["address"] if !ok { address = "127.0.0.1:3306" } - // Get the MySQL credentials to perform read/write operations. - username, ok := conf["username"] - password, ok := conf["password"] - // Get the MySQL database and table details. database, ok := conf["database"] if !ok { - return nil, fmt.Errorf("database name is missing in the configuration") + database = "vault" } table, ok := conf["table"] if !ok { - return nil, fmt.Errorf("table name is missing in the configuration") + table = "vault" } + dbTable := database + "." + table // Create MySQL handle for the database. - dsn := username + ":" + password + "@tcp(" + address + ")/" + database + dsn := username + ":" + password + "@tcp(" + address + ")/" db, err := sql.Open("mysql", dsn) if err != nil { - return nil, fmt.Errorf("failed to open handler with database") + return nil, fmt.Errorf("failed to connect to mysql: %v", err) + } + + // Create the required database if it doesn't exists. + if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + database); err != nil { + return nil, fmt.Errorf("failed to create mysql database: %v", err) } // Create the required table if it doesn't exists. - create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" - create_stmt, err := db.Prepare(create_query) - if err != nil { - return nil, MySQLPrepareStmtFailure + create_query := "CREATE TABLE IF NOT EXISTS " + dbTable + + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" + if _, err := db.Exec(create_query); err != nil { + return nil, fmt.Errorf("failed to create mysql table: %v", err) } - defer create_stmt.Close() - - _, err = create_stmt.Exec() - if err != nil { - return nil, MySQLExecuteStmtFailure - } - - // Map of query type as key to prepared statement. - statements := make(map[string]*sql.Stmt) - - // Prepare statement for put query. - insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" - insert_stmt, err := db.Prepare(insert_query) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - statements["put"] = insert_stmt - - // Prepare statement for select query. - select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" - select_stmt, err := db.Prepare(select_query) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - statements["get"] = select_stmt - - // Prepare statement for delete query. - delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" - delete_stmt, err := db.Prepare(delete_query) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - statements["delete"] = delete_stmt // Setup the backend. m := &MySQLBackend{ + dbTable: dbTable, client: db, - table: table, - database: database, - statements: statements, + statements: make(map[string]*sql.Stmt), } + // Prepare all the statements required + statements := map[string]string{ + "put": "INSERT INTO " + dbTable + + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)", + "get": "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?", + "delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?", + "list": "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?", + } + for name, query := range statements { + if err := m.prepare(name, query); err != nil { + return nil, err + } + } return m, nil } +// prepare is a helper to prepare a query for future execution +func (m *MySQLBackend) prepare(name, query string) error { + stmt, err := m.client.Prepare(query) + if err != nil { + return fmt.Errorf("failed to prepare '%s': %v", name, err) + } + m.statements[name] = stmt + return nil +} + // Put is used to insert or update an entry. func (m *MySQLBackend) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) @@ -115,7 +109,6 @@ func (m *MySQLBackend) Put(entry *Entry) error { if err != nil { return err } - return nil } @@ -124,22 +117,18 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) var result []byte - err := m.statements["get"].QueryRow(key).Scan(&result) - if err != nil { + if err == sql.ErrNoRows { return nil, nil } - - // Handle a non-existing value - if result == nil { - return nil, nil + if err != nil { + return nil, err } ent := &Entry{ Key: key, Value: result, } - return ent, nil } @@ -151,7 +140,6 @@ func (m *MySQLBackend) Delete(key string) error { if err != nil { return err } - return nil } @@ -160,45 +148,28 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) - // Query to get all keys matching a prefix. - list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" - rows, err := m.client.Query(list_query) - if err != nil { - return nil, MySQLExecuteStmtFailure - } + // Add the % wildcard to the prefix to do the prefix search + likePrefix := prefix + "%" + rows, err := m.statements["list"].Query(likePrefix) - columns, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("failed to get columns") - } - - values := make([]sql.RawBytes, len(columns)) - - scanArgs := make([]interface{}, len(values)) - for i := range values { - scanArgs[i] = &values[i] - } - - keys := []string{} + var keys []string for rows.Next() { - err = rows.Scan(scanArgs...) + var key string + err = rows.Scan(&key) if err != nil { - return nil, fmt.Errorf("failed to scan rows") + return nil, fmt.Errorf("failed to scan rows: %v", err) } - for _, key := range values { - key := strings.TrimPrefix(string(key), prefix) - if i := strings.Index(string(key), "/"); i == -1 { - // Add objects only from the current 'folder' - keys = append(keys, string(key)) - } else if i != -1 { - // Add truncated 'folder' paths - keys = appendIfMissing(keys, string(key[:i+1])) - } + key = strings.TrimPrefix(key, prefix) + if i := strings.Index(key, "/"); i == -1 { + // Add objects only from the current 'folder' + keys = append(keys, key) + } else if i != -1 { + // Add truncated 'folder' paths + keys = appendIfMissing(keys, string(key[:i+1])) } } sort.Strings(keys) - return keys, nil } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index 2b00404ffb..a28fb1441f 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -1,8 +1,6 @@ package physical import ( - "database/sql" - "fmt" "os" "testing" @@ -28,61 +26,6 @@ func TestMySQLBackend(t *testing.T) { username := os.Getenv("MYSQL_USERNAME") password := os.Getenv("MYSQL_PASSWORD") - // Create MySQL handle for the database. - db, err := sql.Open("mysql", username+":"+password+"@tcp("+address+")/"+database) - - if err != nil { - t.Fatalf("Failed to open an handler with database: %v", err) - } - defer db.Close() - - // Prepare statement for creating table. - create_stmt := "CREATE TABLE IF NOT EXISTS test.square (num int, sqr int, PRIMARY KEY (num))" - stmtCrt, err := db.Prepare(create_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmtCrt.Close() - - // Create table - _, err = stmtCrt.Exec() - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // Prepare statement for inserting data. - insert_stmt := "INSERT INTO test.square VALUES( ?, ? ) ON DUPLICATE KEY UPDATE sqr=VALUES(sqr)" - stmtIns, err := db.Prepare(insert_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmtIns.Close() - - // Prepare statement for reading data. - select_stmt := "SELECT sqr FROM test.square WHERE num = ?" - stmtOut, err := db.Prepare(select_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmtOut.Close() - - // Insert square numbers for 0-24 in the database - for i := 0; i < 25; i++ { - _, err = stmtIns.Exec(i, (i * i)) // Insert tuples (i, i^2) - if err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - } - - var square int - - // Query the square-number of 13 - err = stmtOut.QueryRow(13).Scan(&square) - if err != nil { - t.Fatalf("Failed to query data: %v", err) - } - fmt.Printf("The square number of 13 is: %d", square) - // Run vault tests b, err := NewBackend("mysql", map[string]string{ "address": address, @@ -96,19 +39,15 @@ func TestMySQLBackend(t *testing.T) { t.Fatalf("Failed to create new backend: %v", err) } + defer func() { + mysql := b.(*MySQLBackend) + _, err := mysql.client.Exec("DROP TABLE " + mysql.dbTable) + if err != nil { + t.Fatalf("Failed to drop table: %v", err) + } + }() + testBackend(t, b) testBackend_ListPrefix(t, b) - // Drop table after running tests - drop_stmt := "DROP TABLE " + database + "." + table - stmt, err := db.Prepare(drop_stmt) - if err != nil { - t.Fatalf("Failed to prepare statement: %v", err) - } - defer stmt.Close() - - _, err = stmt.Exec() - if err != nil { - t.Fatalf("Failed to drop table: %v", err) - } } diff --git a/website/source/docs/config/index.html.md b/website/source/docs/config/index.html.md index 278b4d9549..8cae4df7c6 100644 --- a/website/source/docs/config/index.html.md +++ b/website/source/docs/config/index.html.md @@ -76,6 +76,8 @@ durability, etc. * `s3` - Store data within an S3 bucket [S3](http://aws.amazon.com/s3/). This backend does not support HA. + * `mysql` - Store data within MySQL. This backend does not support HA. + * `inmem` - Store data in-memory. This is only really useful for development and experimentation. Data is lost whenever Vault is restarted. @@ -143,6 +145,21 @@ For S3, the following options are supported: * `region` (optional) - The AWS region. It can be sourced from the AWS_DEFAULT_REGION environment variable and will default to "us-east-1" if not specified. +#### Backend Reference: MySQL + +The MySQL backend has the following options: + + * `username` (required) - The MySQL username to connect with. + + * `password` (required) - The MySQL password to connect with. + + * `address` (optional) - The address of the MySQL host. Defaults to + "127.0.0.1:3306. + + * `database` (optional) - The name of the database to use. Defaults to "vault". + + * `table` (optional) - The name of the table to use. Defaults to "vault". + #### Backend Reference: Inmem The in-memory backend has no configuration options. From 2d0cde4ccc0323591d9414342cb15f5cb70271d7 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 15:37:08 -0700 Subject: [PATCH 29/33] vault: improve lease error message. Fixes #338 --- vault/expiration.go | 2 +- vault/expiration_test.go | 2 +- vault/logical_system_test.go | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vault/expiration.go b/vault/expiration.go index 294624d6e0..3df1738ed2 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -635,7 +635,7 @@ func (l *leaseEntry) encode() ([]byte, error) { func (le *leaseEntry) renewable() error { // If there is no entry, cannot review if le == nil || le.ExpireTime.IsZero() { - return fmt.Errorf("lease not found") + return fmt.Errorf("lease not found or lease is not renewable") } // Determine if the lease is expired diff --git a/vault/expiration_test.go b/vault/expiration_test.go index f7b581be25..0fc4484bd8 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -152,7 +152,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { // Should not be able to renew, no expiration _, err = exp.RenewToken("auth/github/login", root.ID, 0) - if err.Error() != "lease not found" { + if err.Error() != "lease not found or lease is not renewable" { t.Fatalf("err: %v", err) } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index fd23f59aa4..33749ae14f 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -202,7 +202,7 @@ func TestSystemBackend_renew_invalidID(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp.Data["error"] != "lease not found" { + if resp.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } @@ -250,7 +250,7 @@ func TestSystemBackend_revoke(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp3.Data["error"] != "lease not found" { + if resp3.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } @@ -312,7 +312,7 @@ func TestSystemBackend_revokePrefix(t *testing.T) { if err != logical.ErrInvalidRequest { t.Fatalf("err: %v", err) } - if resp3.Data["error"] != "lease not found" { + if resp3.Data["error"] != "lease not found or lease is not renewable" { t.Fatalf("bad: %v", resp) } } From 48e7531f7954a789c44948a42a78bfa237ad638c Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 15:56:42 -0700 Subject: [PATCH 30/33] command/path-help: rename command, better error if sealed. Fixes #234 --- cli/commands.go | 4 ++-- command/{help.go => path_help.go} | 26 ++++++++++++++------- command/{help_test.go => path_help_test.go} | 2 +- 3 files changed, 21 insertions(+), 11 deletions(-) rename command/{help.go => path_help.go} (69%) rename command/{help_test.go => path_help_test.go} (95%) diff --git a/cli/commands.go b/cli/commands.go index 074b80dd11..d9cc3a8624 100644 --- a/cli/commands.go +++ b/cli/commands.go @@ -74,8 +74,8 @@ func Commands(metaPtr *command.Meta) map[string]cli.CommandFactory { }, nil }, - "help": func() (cli.Command, error) { - return &command.HelpCommand{ + "path-help": func() (cli.Command, error) { + return &command.PathHelpCommand{ Meta: meta, }, nil }, diff --git a/command/help.go b/command/path_help.go similarity index 69% rename from command/help.go rename to command/path_help.go index f832f07bc6..792ea9e3a4 100644 --- a/command/help.go +++ b/command/path_help.go @@ -5,12 +5,12 @@ import ( "strings" ) -// HelpCommand is a Command that lists the mounts. -type HelpCommand struct { +// PathHelpCommand is a Command that lists the mounts. +type PathHelpCommand struct { Meta } -func (c *HelpCommand) Run(args []string) int { +func (c *PathHelpCommand) Run(args []string) int { flags := c.Meta.FlagSet("help", FlagSetDefault) flags.Usage = func() { c.Ui.Error(c.Help()) } if err := flags.Parse(args); err != nil { @@ -35,8 +35,15 @@ func (c *HelpCommand) Run(args []string) int { help, err := client.Help(path) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading help: %s", err)) + if strings.Contains(err.Error(), "Vault is sealed") { + c.Ui.Error(`Error: Vault is sealed. + +The path-help command requires the Vault to be unsealed so that +mount points of secret backends are known.`) + } else { + c.Ui.Error(fmt.Sprintf( + "Error reading help: %s", err)) + } return 1 } @@ -44,13 +51,13 @@ func (c *HelpCommand) Run(args []string) int { return 0 } -func (c *HelpCommand) Synopsis() string { +func (c *PathHelpCommand) Synopsis() string { return "Look up the help for a path" } -func (c *HelpCommand) Help() string { +func (c *PathHelpCommand) Help() string { helpText := ` -Usage: vault help [options] path +Usage: vault path-help [options] path Look up the help for a path. @@ -58,6 +65,9 @@ Usage: vault help [options] path providers provide built-in help. This command looks up and outputs that help. + The command requires that the Vault be unsealed, because otherwise + the mount points of the backends are unknown. + General Options: -address=addr The address of the Vault server. diff --git a/command/help_test.go b/command/path_help_test.go similarity index 95% rename from command/help_test.go rename to command/path_help_test.go index c4facc0ca8..faec9723d9 100644 --- a/command/help_test.go +++ b/command/path_help_test.go @@ -14,7 +14,7 @@ func TestHelp(t *testing.T) { defer ln.Close() ui := new(cli.MockUi) - c := &HelpCommand{ + c := &PathHelpCommand{ Meta: Meta{ ClientToken: token, Ui: ui, From f91b91289ce18aab351068d55b49c0b0b2542cf6 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 16:00:41 -0700 Subject: [PATCH 31/33] command/read: Ensure only a single argument. Fixes #304 --- command/read.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/command/read.go b/command/read.go index 85d23b73e3..983ae56bb3 100644 --- a/command/read.go +++ b/command/read.go @@ -22,8 +22,8 @@ func (c *ReadCommand) Run(args []string) int { } args = flags.Args() - if len(args) < 1 || len(args) > 2 { - c.Ui.Error("read expects one or two arguments") + if len(args) != 1 { + c.Ui.Error("read expects one argument") flags.Usage() return 1 } From 8c970cf000e9a7f28f9aa94c0cdcf4c4fc874ab8 Mon Sep 17 00:00:00 2001 From: Armon Dadgar Date: Thu, 18 Jun 2015 17:12:21 -0700 Subject: [PATCH 32/33] cli: adding path-help to common commands list --- cli/help.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cli/help.go b/cli/help.go index 6c3e63d8b2..b614212c4f 100644 --- a/cli/help.go +++ b/cli/help.go @@ -13,14 +13,14 @@ import ( // HelpFunc is a cli.HelpFunc that can is used to output the help for Vault. func HelpFunc(commands map[string]cli.CommandFactory) string { commonNames := map[string]struct{}{ - "delete": struct{}{}, - "help": struct{}{}, - "read": struct{}{}, - "renew": struct{}{}, - "revoke": struct{}{}, - "write": struct{}{}, - "server": struct{}{}, - "status": struct{}{}, + "delete": struct{}{}, + "path-help": struct{}{}, + "read": struct{}{}, + "renew": struct{}{}, + "revoke": struct{}{}, + "write": struct{}{}, + "server": struct{}{}, + "status": struct{}{}, } // Determine the maximum key length, and classify based on type From 943d914fec8922235934db0fe57592d470a7eeda Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Fri, 19 Jun 2015 03:31:19 -0700 Subject: [PATCH 33/33] audit: some tests --- audit/hashstructure_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/audit/hashstructure_test.go b/audit/hashstructure_test.go index cc8e339333..b827310f0e 100644 --- a/audit/hashstructure_test.go +++ b/audit/hashstructure_test.go @@ -80,6 +80,8 @@ func TestCopy_response(t *testing.T) { } func TestHash(t *testing.T) { + now := time.Now().UTC() + cases := []struct { Input interface{} Output interface{} @@ -116,6 +118,24 @@ func TestHash(t *testing.T) { "foo", "foo", }, + { + &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + Lease: 1 * time.Hour, + LeaseIssue: now, + }, + + ClientToken: "foo", + }, + &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + Lease: 1 * time.Hour, + LeaseIssue: now, + }, + + ClientToken: "sha1:0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33", + }, + }, } for _, tc := range cases {