From c42bc38c62adb6bd671d39dafbfe127934ec9c67 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Fri, 12 Jun 2015 11:41:25 +0545 Subject: [PATCH] Changes done as per feedback --- physical/mysql.go | 110 +++++++++++++++++++++-------------------- physical/mysql_test.go | 2 +- 2 files changed, 58 insertions(+), 54 deletions(-) diff --git a/physical/mysql.go b/physical/mysql.go index 377ac9dc49..ed53a5350d 100644 --- a/physical/mysql.go +++ b/physical/mysql.go @@ -3,8 +3,8 @@ package physical import ( "database/sql" "errors" + "fmt" "sort" - "strings" "time" "github.com/armon/go-metrics" @@ -12,21 +12,17 @@ import ( ) var ( - MySQLDBNameMissing = errors.New("database name is missing in the configuration") - MySQLTableNameMissing = errors.New("table name is missing in the configuration") - MySQLHandlerCreationFailure = errors.New("failed to open handler with database") - MySQLPrepareStmtFailure = errors.New("failed to prepare statement") - MySQLExecuteStmtFailure = errors.New("failed to execute statement") - MySQLGetColumnsFailure = errors.New("failed to get columns") - MySQLScanRowsFailure = errors.New("failed to scan rows") + MySQLPrepareStmtFailure = errors.New("failed to prepare statement") + MySQLExecuteStmtFailure = errors.New("failed to execute statement") ) // MySQLBackend is a physical backend that stores data // within MySQL database. type MySQLBackend struct { - table string - database string - client *sql.DB + table string + database string + client *sql.DB + statements map[string]*sql.Stmt } // newMySQLBackend constructs a MySQL backend using the given API client and @@ -45,39 +41,70 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { // Get the MySQL database and table details. database, ok := conf["database"] if !ok { - return nil, MySQLDBNameMissing + return nil, fmt.Errorf("database name is missing in the configuration") } table, ok := conf["table"] if !ok { - return nil, MySQLTableNameMissing + return nil, fmt.Errorf("table name is missing in the configuration") } // Create MySQL handle for the database. dsn := username + ":" + password + "@tcp(" + address + ")/" + database db, err := sql.Open("mysql", dsn) if err != nil { - return nil, MySQLHandlerCreationFailure + return nil, fmt.Errorf("failed to open handler with database") } defer db.Close() - // Create the required table. - create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(255), vault_value varchar(255), PRIMARY KEY (vault_key))" - stmt, err := db.Prepare(create_stmt) + // Create the required table if it doesn't exists. + create_query := "CREATE TABLE IF NOT EXISTS " + database + "." + table + " (vault_key varchar(512), vault_value mediumblob, PRIMARY KEY (vault_key))" + create_stmt, err := db.Prepare(create_query) if err != nil { return nil, MySQLPrepareStmtFailure } - defer stmt.Close() + defer create_stmt.Close() - _, err = stmt.Exec() + _, err = create_stmt.Exec() if err != nil { return nil, MySQLExecuteStmtFailure } + // Map of query type as key to prepared statement. + var statements map[string]*sql.Stmt + + // Prepare statement for put query. + insert_query := "INSERT INTO " + database + "." + table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" + insert_stmt, err := db.Prepare(insert_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["put"] = insert_stmt + defer insert_stmt.Close() + + // Prepare statement for select query. + select_query := "SELECT vault_value FROM " + database + "." + table + " WHERE vault_key = ?" + select_stmt, err := db.Prepare(select_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["get"] = select_stmt + defer select_stmt.Close() + + // Prepare statement for delete query. + delete_query := "DELETE FROM " + database + "." + table + " WHERE vault_key = ?" + delete_stmt, err := db.Prepare(delete_query) + if err != nil { + return nil, MySQLPrepareStmtFailure + } + statements["delete"] = delete_stmt + defer delete_stmt.Close() + // Setup the backend. m := &MySQLBackend{ - client: db, - table: table, - database: database, + client: db, + table: table, + database: database, + statements: statements, } return m, nil @@ -87,14 +114,7 @@ func newMySQLBackend(conf map[string]string) (Backend, error) { func (m *MySQLBackend) Put(entry *Entry) error { defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now()) - insert_stmt := "INSERT INTO " + m.database + "." + m.table + " VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)" - stmt, err := m.client.Prepare(insert_stmt) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(entry.Key, entry.Value) + _, err := m.statements["put"].Exec(entry.Key, entry.Value) if err != nil { return err } @@ -106,16 +126,9 @@ func (m *MySQLBackend) Put(entry *Entry) error { func (m *MySQLBackend) Get(key string) (*Entry, error) { defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now()) - select_stmt := "SELECT vault_value FROM " + m.database + "." + m.table + " WHERE vault_key = ?" - stmt, err := m.client.Prepare(select_stmt) - if err != nil { - return nil, MySQLPrepareStmtFailure - } - defer stmt.Close() - var result []byte - err = stmt.QueryRow(key).Scan(&result) + err := m.statements["get"].QueryRow(key).Scan(&result) if err != nil { return nil, MySQLExecuteStmtFailure } @@ -132,14 +145,7 @@ func (m *MySQLBackend) Get(key string) (*Entry, error) { func (m *MySQLBackend) Delete(key string) error { defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now()) - delete_stmt := "DELETE FROM " + m.database + "." + m.table + "WHERE vault_key = ?" - stmt, err := m.client.Prepare(delete_stmt) - if err != nil { - return err - } - defer stmt.Close() - - _, err = stmt.Exec(key) + _, err := m.statements["delete"].Exec(key) if err != nil { return err } @@ -152,15 +158,15 @@ func (m *MySQLBackend) Delete(key string) error { func (m *MySQLBackend) List(prefix string) ([]string, error) { defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now()) - list_stmt := "SELECT vault_key FROM " + m.database + "." + m.table - rows, err := m.client.Query(list_stmt) + list_query := "SELECT vault_key FROM " + m.database + "." + m.table + " WHERE vault_key LIKE '" + prefix + "%'" + rows, err := m.client.Query(list_query) if err != nil { return nil, MySQLExecuteStmtFailure } columns, err := rows.Columns() if err != nil { - return nil, MySQLGetColumnsFailure + return nil, fmt.Errorf("failed to get columns") } values := make([]sql.RawBytes, len(columns)) @@ -174,13 +180,11 @@ func (m *MySQLBackend) List(prefix string) ([]string, error) { for rows.Next() { err = rows.Scan(scanArgs...) if err != nil { - return nil, MySQLScanRowsFailure + return nil, fmt.Errorf("failed to scan rows") } for _, col := range values { - if strings.HasPrefix(string(col), prefix) { - keys = append(keys, string(col)) - } + keys = append(keys, string(col)) } } diff --git a/physical/mysql_test.go b/physical/mysql_test.go index b808a29bdb..c1c78c9fc7 100644 --- a/physical/mysql_test.go +++ b/physical/mysql_test.go @@ -37,7 +37,7 @@ func TestMySQLBackend(t *testing.T) { defer db.Close() // Prepare statement for creating table. - create_stmt := "CREATE TABLE " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" + create_stmt := "CREATE TABLE IF NOT EXISTS " + database + "." + table + "(num int, sqr int, PRIMARY KEY (num))" stmtCrt, err := db.Prepare(create_stmt) if err != nil { t.Fatalf("Failed to prepare statement: %v", err)