Changes done as per feedback

This commit is contained in:
Pradeep Chhetri 2015-06-12 11:41:25 +05:45
parent 9c3881442e
commit c42bc38c62
2 changed files with 58 additions and 54 deletions

View File

@ -3,8 +3,8 @@ package physical
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"sort" "sort"
"strings"
"time" "time"
"github.com/armon/go-metrics" "github.com/armon/go-metrics"
@ -12,13 +12,8 @@ import (
) )
var ( var (
MySQLDBNameMissing = errors.New("database name is missing in the configuration")
MySQLTableNameMissing = errors.New("table name is missing in the configuration")
MySQLHandlerCreationFailure = errors.New("failed to open handler with database")
MySQLPrepareStmtFailure = errors.New("failed to prepare statement") MySQLPrepareStmtFailure = errors.New("failed to prepare statement")
MySQLExecuteStmtFailure = errors.New("failed to execute statement") MySQLExecuteStmtFailure = errors.New("failed to execute statement")
MySQLGetColumnsFailure = errors.New("failed to get columns")
MySQLScanRowsFailure = errors.New("failed to scan rows")
) )
// MySQLBackend is a physical backend that stores data // MySQLBackend is a physical backend that stores data
@ -27,6 +22,7 @@ type MySQLBackend struct {
table string table string
database string database string
client *sql.DB client *sql.DB
statements map[string]*sql.Stmt
} }
// newMySQLBackend constructs a MySQL backend using the given API client and // newMySQLBackend constructs a MySQL backend using the given API client and
@ -45,39 +41,70 @@ func newMySQLBackend(conf map[string]string) (Backend, error) {
// Get the MySQL database and table details. // Get the MySQL database and table details.
database, ok := conf["database"] database, ok := conf["database"]
if !ok { if !ok {
return nil, MySQLDBNameMissing return nil, fmt.Errorf("database name is missing in the configuration")
} }
table, ok := conf["table"] table, ok := conf["table"]
if !ok { if !ok {
return nil, MySQLTableNameMissing return nil, fmt.Errorf("table name is missing in the configuration")
} }
// Create MySQL handle for the database. // Create MySQL handle for the database.
dsn := username + ":" + password + "@tcp(" + address + ")/" + database dsn := username + ":" + password + "@tcp(" + address + ")/" + database
db, err := sql.Open("mysql", dsn) db, err := sql.Open("mysql", dsn)
if err != nil { if err != nil {
return nil, MySQLHandlerCreationFailure return nil, fmt.Errorf("failed to open handler with database")
} }
defer db.Close() defer db.Close()
// Create the required table. // Create the required table if it doesn't exists.
create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))" create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))"
stmt, err := db.Prepare(create_stmt) create_stmt, err := db.Prepare(create_query)
if err != nil { if err != nil {
return nil, MySQLPrepareStmtFailure return nil, MySQLPrepareStmtFailure
} }
defer stmt.Close() defer create_stmt.Close()
_, err = stmt.Exec() _, err = create_stmt.Exec()
if err != nil { if err != nil {
return nil, MySQLExecuteStmtFailure return nil, MySQLExecuteStmtFailure
} }
// Map of query type as key to prepared statement.
var statements map[string]*sql.Stmt
// Prepare statement for put query.
insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)"
insert_stmt, err := db.Prepare(insert_query)
if err != nil {
return nil, MySQLPrepareStmtFailure
}
statements["put"] = insert_stmt
defer insert_stmt.Close()
// Prepare statement for select query.
select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?"
select_stmt, err := db.Prepare(select_query)
if err != nil {
return nil, MySQLPrepareStmtFailure
}
statements["get"] = select_stmt
defer select_stmt.Close()
// Prepare statement for delete query.
delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?"
delete_stmt, err := db.Prepare(delete_query)
if err != nil {
return nil, MySQLPrepareStmtFailure
}
statements["delete"] = delete_stmt
defer delete_stmt.Close()
// Setup the backend. // Setup the backend.
m := &MySQLBackend{ m := &MySQLBackend{
client: db, client: db,
table: table, table: table,
database: database, database: database,
statements: statements,
} }
return m, nil return m, nil
@ -87,14 +114,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) {
func (m *MySQLBackend) Put(entry *Entry) error { func (m *MySQLBackend) Put(entry *Entry) error {
defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())
insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" _, err := m.statements["put"].Exec(entry.Key, entry.Value)
stmt, err := m.client.Prepare(insert_stmt)
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(entry.Key, entry.Value)
if err != nil { if err != nil {
return err return err
} }
@ -106,16 +126,9 @@ func (m *MySQLBackend) Put(entry *Entry) error {
func (m *MySQLBackend) Get(key string) (*Entry, error) { func (m *MySQLBackend) Get(key string) (*Entry, error) {
defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())
select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?"
stmt, err := m.client.Prepare(select_stmt)
if err != nil {
return nil, MySQLPrepareStmtFailure
}
defer stmt.Close()
var result []byte var result []byte
err = stmt.QueryRow(key).Scan(&result) err := m.statements["get"].QueryRow(key).Scan(&result)
if err != nil { if err != nil {
return nil, MySQLExecuteStmtFailure return nil, MySQLExecuteStmtFailure
} }
@ -132,14 +145,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) {
func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) Delete(key string) error {
defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())
delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?" _, err := m.statements["delete"].Exec(key)
stmt, err := m.client.Prepare(delete_stmt)
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(key)
if err != nil { if err != nil {
return err return err
} }
@ -152,15 +158,15 @@ func (m *MySQLBackend) Delete(key string) error {
func (m *MySQLBackend) List(prefix string) ([]string, error) { func (m *MySQLBackend) List(prefix string) ([]string, error) {
defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())
list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'"
rows, err := m.client.Query(list_stmt) rows, err := m.client.Query(list_query)
if err != nil { if err != nil {
return nil, MySQLExecuteStmtFailure return nil, MySQLExecuteStmtFailure
} }
columns, err := rows.Columns() columns, err := rows.Columns()
if err != nil { if err != nil {
return nil, MySQLGetColumnsFailure return nil, fmt.Errorf("failed to get columns")
} }
values := make([]sql.RawBytes, len(columns)) values := make([]sql.RawBytes, len(columns))
@ -174,15 +180,13 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) {
for rows.Next() { for rows.Next() {
err = rows.Scan(scanArgs...) err = rows.Scan(scanArgs...)
if err != nil { if err != nil {
return nil, MySQLScanRowsFailure return nil, fmt.Errorf("failed to scan rows")
} }
for _, col := range values { for _, col := range values {
if strings.HasPrefix(string(col), prefix) {
keys = append(keys, string(col)) keys = append(keys, string(col))
} }
} }
}
sort.Strings(keys) sort.Strings(keys)

View File

@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) {
defer db.Close() defer db.Close()
// Prepare statement for creating table. // Prepare statement for creating table.
create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))"
stmtCrt, err := db.Prepare(create_stmt) stmtCrt, err := db.Prepare(create_stmt)
if err != nil { if err != nil {
t.Fatalf("Failed to prepare statement: %v", err) t.Fatalf("Failed to prepare statement: %v", err)