mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-23 07:31:09 +02:00
Changes done as per feedback
This commit is contained in:
parent
9c3881442e
commit
c42bc38c62
@ -3,8 +3,8 @@ package physical
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/armon/go-metrics"
|
"github.com/armon/go-metrics"
|
||||||
@ -12,13 +12,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
MySQLDBNameMissing = errors.New("database name is missing in the configuration")
|
|
||||||
MySQLTableNameMissing = errors.New("table name is missing in the configuration")
|
|
||||||
MySQLHandlerCreationFailure = errors.New("failed to open handler with database")
|
|
||||||
MySQLPrepareStmtFailure = errors.New("failed to prepare statement")
|
MySQLPrepareStmtFailure = errors.New("failed to prepare statement")
|
||||||
MySQLExecuteStmtFailure = errors.New("failed to execute statement")
|
MySQLExecuteStmtFailure = errors.New("failed to execute statement")
|
||||||
MySQLGetColumnsFailure = errors.New("failed to get columns")
|
|
||||||
MySQLScanRowsFailure = errors.New("failed to scan rows")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MySQLBackend is a physical backend that stores data
|
// MySQLBackend is a physical backend that stores data
|
||||||
@ -27,6 +22,7 @@ type MySQLBackend struct {
|
|||||||
table string
|
table string
|
||||||
database string
|
database string
|
||||||
client *sql.DB
|
client *sql.DB
|
||||||
|
statements map[string]*sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMySQLBackend constructs a MySQL backend using the given API client and
|
// newMySQLBackend constructs a MySQL backend using the given API client and
|
||||||
@ -45,39 +41,70 @@ func newMySQLBackend(conf map[string]string) (Backend, error) {
|
|||||||
// Get the MySQL database and table details.
|
// Get the MySQL database and table details.
|
||||||
database, ok := conf["database"]
|
database, ok := conf["database"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, MySQLDBNameMissing
|
return nil, fmt.Errorf("database name is missing in the configuration")
|
||||||
}
|
}
|
||||||
table, ok := conf["table"]
|
table, ok := conf["table"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, MySQLTableNameMissing
|
return nil, fmt.Errorf("table name is missing in the configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create MySQL handle for the database.
|
// Create MySQL handle for the database.
|
||||||
dsn := username + ":" + password + "@tcp(" + address + ")/" + database
|
dsn := username + ":" + password + "@tcp(" + address + ")/" + database
|
||||||
db, err := sql.Open("mysql", dsn)
|
db, err := sql.Open("mysql", dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLHandlerCreationFailure
|
return nil, fmt.Errorf("failed to open handler with database")
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Create the required table.
|
// Create the required table if it doesn't exists.
|
||||||
create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))"
|
create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
|
||||||
stmt, err := db.Prepare(create_stmt)
|
create_stmt, err := db.Prepare(create_query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLPrepareStmtFailure
|
return nil, MySQLPrepareStmtFailure
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
defer create_stmt.Close()
|
||||||
|
|
||||||
_, err = stmt.Exec()
|
_, err = create_stmt.Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLExecuteStmtFailure
|
return nil, MySQLExecuteStmtFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Map of query type as key to prepared statement.
|
||||||
|
var statements map[string]*sql.Stmt
|
||||||
|
|
||||||
|
// Prepare statement for put query.
|
||||||
|
insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)"
|
||||||
|
insert_stmt, err := db.Prepare(insert_query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, MySQLPrepareStmtFailure
|
||||||
|
}
|
||||||
|
statements["put"] = insert_stmt
|
||||||
|
defer insert_stmt.Close()
|
||||||
|
|
||||||
|
// Prepare statement for select query.
|
||||||
|
select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?"
|
||||||
|
select_stmt, err := db.Prepare(select_query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, MySQLPrepareStmtFailure
|
||||||
|
}
|
||||||
|
statements["get"] = select_stmt
|
||||||
|
defer select_stmt.Close()
|
||||||
|
|
||||||
|
// Prepare statement for delete query.
|
||||||
|
delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?"
|
||||||
|
delete_stmt, err := db.Prepare(delete_query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, MySQLPrepareStmtFailure
|
||||||
|
}
|
||||||
|
statements["delete"] = delete_stmt
|
||||||
|
defer delete_stmt.Close()
|
||||||
|
|
||||||
// Setup the backend.
|
// Setup the backend.
|
||||||
m := &MySQLBackend{
|
m := &MySQLBackend{
|
||||||
client: db,
|
client: db,
|
||||||
table: table,
|
table: table,
|
||||||
database: database,
|
database: database,
|
||||||
|
statements: statements,
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
@ -87,14 +114,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) {
|
|||||||
func (m *MySQLBackend) Put(entry *Entry) error {
|
func (m *MySQLBackend) Put(entry *Entry) error {
|
||||||
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
|
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
|
||||||
|
|
||||||
insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)"
|
_, err := m.statements["put"].Exec(entry.Key, entry.Value)
|
||||||
stmt, err := m.client.Prepare(insert_stmt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
_, err = stmt.Exec(entry.Key, entry.Value)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -106,16 +126,9 @@ func (m *MySQLBackend) Put(entry *Entry) error {
|
|||||||
func (m *MySQLBackend) Get(key string) (*Entry, error) {
|
func (m *MySQLBackend) Get(key string) (*Entry, error) {
|
||||||
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
|
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
|
||||||
|
|
||||||
select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?"
|
|
||||||
stmt, err := m.client.Prepare(select_stmt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, MySQLPrepareStmtFailure
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
var result []byte
|
var result []byte
|
||||||
|
|
||||||
err = stmt.QueryRow(key).Scan(&result)
|
err := m.statements["get"].QueryRow(key).Scan(&result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLExecuteStmtFailure
|
return nil, MySQLExecuteStmtFailure
|
||||||
}
|
}
|
||||||
@ -132,14 +145,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) {
|
|||||||
func (m *MySQLBackend) Delete(key string) error {
|
func (m *MySQLBackend) Delete(key string) error {
|
||||||
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
|
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
|
||||||
|
|
||||||
delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?"
|
_, err := m.statements["delete"].Exec(key)
|
||||||
stmt, err := m.client.Prepare(delete_stmt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
_, err = stmt.Exec(key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -152,15 +158,15 @@ func (m *MySQLBackend) Delete(key string) error {
|
|||||||
func (m *MySQLBackend) List(prefix string) ([]string, error) {
|
func (m *MySQLBackend) List(prefix string) ([]string, error) {
|
||||||
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
|
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
|
||||||
|
|
||||||
list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table
|
list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'"
|
||||||
rows, err := m.client.Query(list_stmt)
|
rows, err := m.client.Query(list_query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLExecuteStmtFailure
|
return nil, MySQLExecuteStmtFailure
|
||||||
}
|
}
|
||||||
|
|
||||||
columns, err := rows.Columns()
|
columns, err := rows.Columns()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLGetColumnsFailure
|
return nil, fmt.Errorf("failed to get columns")
|
||||||
}
|
}
|
||||||
|
|
||||||
values := make([]sql.RawBytes, len(columns))
|
values := make([]sql.RawBytes, len(columns))
|
||||||
@ -174,15 +180,13 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) {
|
|||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = rows.Scan(scanArgs...)
|
err = rows.Scan(scanArgs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, MySQLScanRowsFailure
|
return nil, fmt.Errorf("failed to scan rows")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, col := range values {
|
for _, col := range values {
|
||||||
if strings.HasPrefix(string(col), prefix) {
|
|
||||||
keys = append(keys, string(col))
|
keys = append(keys, string(col))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) {
|
|||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
// Prepare statement for creating table.
|
// Prepare statement for creating table.
|
||||||
create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))"
|
create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))"
|
||||||
stmtCrt, err := db.Prepare(create_stmt)
|
stmtCrt, err := db.Prepare(create_stmt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to prepare statement: %v", err)
|
t.Fatalf("Failed to prepare statement: %v", err)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user