mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-22 15:11:07 +02:00
Changes done as per feedback
This commit is contained in:
parent
9c3881442e
commit
c42bc38c62
@ -3,8 +3,8 @@ package physical
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
@ -12,21 +12,17 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
MySQLDBNameMissing = errors.New("database name is missing in the configuration")
|
||||
MySQLTableNameMissing = errors.New("table name is missing in the configuration")
|
||||
MySQLHandlerCreationFailure = errors.New("failed to open handler with database")
|
||||
MySQLPrepareStmtFailure = errors.New("failed to prepare statement")
|
||||
MySQLExecuteStmtFailure = errors.New("failed to execute statement")
|
||||
MySQLGetColumnsFailure = errors.New("failed to get columns")
|
||||
MySQLScanRowsFailure = errors.New("failed to scan rows")
|
||||
MySQLPrepareStmtFailure = errors.New("failed to prepare statement")
|
||||
MySQLExecuteStmtFailure = errors.New("failed to execute statement")
|
||||
)
|
||||
|
||||
// MySQLBackend is a physical backend that stores data
|
||||
// within MySQL database.
|
||||
type MySQLBackend struct {
|
||||
table string
|
||||
database string
|
||||
client *sql.DB
|
||||
table string
|
||||
database string
|
||||
client *sql.DB
|
||||
statements map[string]*sql.Stmt
|
||||
}
|
||||
|
||||
// newMySQLBackend constructs a MySQL backend using the given API client and
|
||||
@ -45,39 +41,70 @@ func newMySQLBackend(conf map[string]string) (Backend, error) {
|
||||
// Get the MySQL database and table details.
|
||||
database, ok := conf["database"]
|
||||
if !ok {
|
||||
return nil, MySQLDBNameMissing
|
||||
return nil, fmt.Errorf("database name is missing in the configuration")
|
||||
}
|
||||
table, ok := conf["table"]
|
||||
if !ok {
|
||||
return nil, MySQLTableNameMissing
|
||||
return nil, fmt.Errorf("table name is missing in the configuration")
|
||||
}
|
||||
|
||||
// Create MySQL handle for the database.
|
||||
dsn := username + ":" + password + "@tcp(" + address + ")/" + database
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, MySQLHandlerCreationFailure
|
||||
return nil, fmt.Errorf("failed to open handler with database")
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create the required table.
|
||||
create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))"
|
||||
stmt, err := db.Prepare(create_stmt)
|
||||
// Create the required table if it doesn't exists.
|
||||
create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
|
||||
create_stmt, err := db.Prepare(create_query)
|
||||
if err != nil {
|
||||
return nil, MySQLPrepareStmtFailure
|
||||
}
|
||||
defer stmt.Close()
|
||||
defer create_stmt.Close()
|
||||
|
||||
_, err = stmt.Exec()
|
||||
_, err = create_stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, MySQLExecuteStmtFailure
|
||||
}
|
||||
|
||||
// Map of query type as key to prepared statement.
|
||||
var statements map[string]*sql.Stmt
|
||||
|
||||
// Prepare statement for put query.
|
||||
insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)"
|
||||
insert_stmt, err := db.Prepare(insert_query)
|
||||
if err != nil {
|
||||
return nil, MySQLPrepareStmtFailure
|
||||
}
|
||||
statements["put"] = insert_stmt
|
||||
defer insert_stmt.Close()
|
||||
|
||||
// Prepare statement for select query.
|
||||
select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?"
|
||||
select_stmt, err := db.Prepare(select_query)
|
||||
if err != nil {
|
||||
return nil, MySQLPrepareStmtFailure
|
||||
}
|
||||
statements["get"] = select_stmt
|
||||
defer select_stmt.Close()
|
||||
|
||||
// Prepare statement for delete query.
|
||||
delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?"
|
||||
delete_stmt, err := db.Prepare(delete_query)
|
||||
if err != nil {
|
||||
return nil, MySQLPrepareStmtFailure
|
||||
}
|
||||
statements["delete"] = delete_stmt
|
||||
defer delete_stmt.Close()
|
||||
|
||||
// Setup the backend.
|
||||
m := &MySQLBackend{
|
||||
client: db,
|
||||
table: table,
|
||||
database: database,
|
||||
client: db,
|
||||
table: table,
|
||||
database: database,
|
||||
statements: statements,
|
||||
}
|
||||
|
||||
return m, nil
|
||||
@ -87,14 +114,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) {
|
||||
func (m *MySQLBackend) Put(entry *Entry) error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
|
||||
|
||||
insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)"
|
||||
stmt, err := m.client.Prepare(insert_stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(entry.Key, entry.Value)
|
||||
_, err := m.statements["put"].Exec(entry.Key, entry.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -106,16 +126,9 @@ func (m *MySQLBackend) Put(entry *Entry) error {
|
||||
func (m *MySQLBackend) Get(key string) (*Entry, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
|
||||
|
||||
select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?"
|
||||
stmt, err := m.client.Prepare(select_stmt)
|
||||
if err != nil {
|
||||
return nil, MySQLPrepareStmtFailure
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
var result []byte
|
||||
|
||||
err = stmt.QueryRow(key).Scan(&result)
|
||||
err := m.statements["get"].QueryRow(key).Scan(&result)
|
||||
if err != nil {
|
||||
return nil, MySQLExecuteStmtFailure
|
||||
}
|
||||
@ -132,14 +145,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) {
|
||||
func (m *MySQLBackend) Delete(key string) error {
|
||||
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
|
||||
|
||||
delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?"
|
||||
stmt, err := m.client.Prepare(delete_stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(key)
|
||||
_, err := m.statements["delete"].Exec(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -152,15 +158,15 @@ func (m *MySQLBackend) Delete(key string) error {
|
||||
func (m *MySQLBackend) List(prefix string) ([]string, error) {
|
||||
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
|
||||
|
||||
list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table
|
||||
rows, err := m.client.Query(list_stmt)
|
||||
list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'"
|
||||
rows, err := m.client.Query(list_query)
|
||||
if err != nil {
|
||||
return nil, MySQLExecuteStmtFailure
|
||||
}
|
||||
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, MySQLGetColumnsFailure
|
||||
return nil, fmt.Errorf("failed to get columns")
|
||||
}
|
||||
|
||||
values := make([]sql.RawBytes, len(columns))
|
||||
@ -174,13 +180,11 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) {
|
||||
for rows.Next() {
|
||||
err = rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
return nil, MySQLScanRowsFailure
|
||||
return nil, fmt.Errorf("failed to scan rows")
|
||||
}
|
||||
|
||||
for _, col := range values {
|
||||
if strings.HasPrefix(string(col), prefix) {
|
||||
keys = append(keys, string(col))
|
||||
}
|
||||
keys = append(keys, string(col))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) {
|
||||
defer db.Close()
|
||||
|
||||
// Prepare statement for creating table.
|
||||
create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))"
|
||||
create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))"
|
||||
stmtCrt, err := db.Prepare(create_stmt)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to prepare statement: %v", err)
|
||||
|
Loading…
x
Reference in New Issue
Block a user