mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-16 11:37:04 +02:00
* Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License. Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUS-1.1 * Fix test that expected exact offset on hcl file --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com> Co-authored-by: Sarah Thompson <sthompson@hashicorp.com> Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
780 lines
21 KiB
Go
780 lines
21 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"math"
|
|
"net/url"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-multierror"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
mysql "github.com/go-sql-driver/mysql"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
)
|
|
|
|
// Verify MySQLBackend satisfies the correct interfaces
|
|
var (
|
|
_ physical.Backend = (*MySQLBackend)(nil)
|
|
_ physical.HABackend = (*MySQLBackend)(nil)
|
|
_ physical.Lock = (*MySQLHALock)(nil)
|
|
)
|
|
|
|
// mysqlTLSKey is the name under which setupMySQLTLSConfig registers the custom
// TLS configuration, and the value NewMySQLClient puts in the DSN's "tls"
// parameter. The values "true", "false" and "skip-verify" are reserved, so an
// unreserved name is used instead.
const (
	mysqlTLSKey = "default"
)
|
|
|
|
// MySQLBackend is a physical backend that stores data
|
|
// within MySQL database.
|
|
type MySQLBackend struct {
|
|
dbTable string
|
|
dbLockTable string
|
|
client *sql.DB
|
|
statements map[string]*sql.Stmt
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
conf map[string]string
|
|
redirectHost string
|
|
redirectPort int64
|
|
haEnabled bool
|
|
}
|
|
|
|
// NewMySQLBackend constructs a MySQL backend using the given API client and
// server address and credential for accessing mysql database.
//
// It opens a connection pool with NewMySQLClient, creates the configured
// database and table if they do not exist (plus the lock table when
// ha_enabled is true), and prepares the statements the backend uses. If any
// step after the pool has been opened fails, the pool is closed before the
// error is returned so that no connections are leaked.
func NewMySQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	db, err := NewMySQLClient(conf, logger)
	if err != nil {
		return nil, err
	}

	m, err := newMySQLBackendFromClient(db, conf, logger)
	if err != nil {
		// Best effort: closing the pool also invalidates any statements that
		// were prepared on it before the failure. The setup error is the one
		// worth reporting.
		_ = db.Close()
		return nil, err
	}
	return m, nil
}

// newMySQLBackendFromClient performs schema setup and statement preparation
// for NewMySQLBackend on an already opened pool. The caller owns db and must
// close it if an error is returned.
func newMySQLBackendFromClient(db *sql.DB, conf map[string]string, logger log.Logger) (*MySQLBackend, error) {
	database := conf["database"]
	if database == "" {
		database = "vault"
	}
	table := conf["table"]
	if table == "" {
		table = "vault"
	}

	// The names are interpolated into SQL below, so they must be validated.
	if err := validateDBTable(database, table); err != nil {
		return nil, err
	}

	dbTable := fmt.Sprintf("`%s`.`%s`", database, table)

	maxParInt := physical.DefaultParallelOperations
	if maxParStr, ok := conf["max_parallel"]; ok {
		parsed, err := strconv.Atoi(maxParStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
		}
		maxParInt = parsed
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	}

	// Check schema exists
	schemaExist, err := mysqlRowExists(db,
		"SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", database)
	if err != nil {
		return nil, fmt.Errorf("failed to check mysql schema exist: %w", err)
	}

	// Check table exists
	tableExist, err := mysqlRowExists(db,
		"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", table, database)
	if err != nil {
		return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
	}

	// Create the required database if it doesn't exist.
	if !schemaExist {
		if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS `" + database + "`"); err != nil {
			return nil, fmt.Errorf("failed to create mysql database: %w", err)
		}
	}

	// Create the required table if it doesn't exist.
	if !tableExist {
		createQuery := "CREATE TABLE IF NOT EXISTS " + dbTable +
			" (vault_key varbinary(3072), vault_value mediumblob, PRIMARY KEY (vault_key))"
		if _, err := db.Exec(createQuery); err != nil {
			return nil, fmt.Errorf("failed to create mysql table: %w", err)
		}
	}

	// Default value for ha_enabled
	haEnabledStr, ok := conf["ha_enabled"]
	if !ok {
		haEnabledStr = "false"
	}
	haEnabled, err := strconv.ParseBool(haEnabledStr)
	if err != nil {
		return nil, fmt.Errorf("value [%v] of 'ha_enabled' could not be understood", haEnabledStr)
	}

	locktable, ok := conf["lock_table"]
	if !ok {
		locktable = table + "_lock"
	}

	dbLockTable := "`" + database + "`.`" + locktable + "`"

	// Only create lock table if ha_enabled is true
	if haEnabled {
		// lock_table is interpolated into SQL (here and in NewMySQLLock)
		// exactly like table is, so it needs the same validation.
		if err := validate(locktable); err != nil {
			return nil, fmt.Errorf("invalid lock table: %w", err)
		}

		lockTableExist, err := mysqlRowExists(db,
			"SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME = ? AND TABLE_SCHEMA = ?", locktable, database)
		if err != nil {
			return nil, fmt.Errorf("failed to check mysql table exist: %w", err)
		}

		// Create the required table if it doesn't exist.
		if !lockTableExist {
			createQuery := "CREATE TABLE IF NOT EXISTS " + dbLockTable +
				" (node_job varbinary(512), current_leader varbinary(512), PRIMARY KEY (node_job))"
			if _, err := db.Exec(createQuery); err != nil {
				return nil, fmt.Errorf("failed to create mysql table: %w", err)
			}
		}
	}

	// Setup the backend.
	m := &MySQLBackend{
		dbTable:     dbTable,
		dbLockTable: dbLockTable,
		client:      db,
		statements:  make(map[string]*sql.Stmt),
		logger:      logger,
		permitPool:  physical.NewPermitPool(maxParInt),
		conf:        conf,
		haEnabled:   haEnabled,
	}

	// Prepare all the statements required
	statements := map[string]string{
		"put": "INSERT INTO " + dbTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE vault_value=VALUES(vault_value)",
		"get":    "SELECT vault_value FROM " + dbTable + " WHERE vault_key = ?",
		"delete": "DELETE FROM " + dbTable + " WHERE vault_key = ?",
		"list":   "SELECT vault_key FROM " + dbTable + " WHERE vault_key LIKE ?",
	}

	// Only prepare ha-related statements if we need them
	if haEnabled {
		statements["get_lock"] = "SELECT current_leader FROM " + dbLockTable + " WHERE node_job = ?"
		statements["used_lock"] = "SELECT IS_USED_LOCK(?)"
	}

	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			return nil, err
		}
	}

	return m, nil
}

// mysqlRowExists reports whether query returns at least one row.
//
// QueryRow returns the connection to the pool as soon as Scan completes.
// Leaving a *sql.Rows open until the caller returns would keep a connection
// checked out while the following statements need their own. On a pool
// capped at one connection (max_parallel = 1) those statements would then
// wait forever. Errors other than "no rows" are returned to the caller.
func mysqlRowExists(db *sql.DB, query string, args ...interface{}) (bool, error) {
	var discard sql.NullString
	err := db.QueryRow(query, args...).Scan(&discard)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	case err != nil:
		return false, err
	default:
		return true, nil
	}
}
|
|
|
|
// validateDBTable to prevent SQL injection attacks. This ensures that the database and table names only have valid
|
|
// characters in them. MySQL allows for more characters that this will allow, but there isn't an easy way of
|
|
// representing the full Unicode Basic Multilingual Plane to check against.
|
|
// https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
|
|
func validateDBTable(db, table string) (err error) {
|
|
merr := &multierror.Error{}
|
|
merr = multierror.Append(merr, wrapErr("invalid database: %w", validate(db)))
|
|
merr = multierror.Append(merr, wrapErr("invalid table: %w", validate(table)))
|
|
return merr.ErrorOrNil()
|
|
}
|
|
|
|
// validate checks a single MySQL identifier (database or table name) that
// will be placed inside backtick quotes.
//
// From https://dev.mysql.com/doc/refman/5.7/en/identifiers.html:
//   - Permitted characters in quoted identifiers include the full Unicode
//     Basic Multilingual Plane (BMP), except U+0000:
//     ASCII: U+0001 .. U+007F, Extended: U+0080 .. U+FFFF
//   - ASCII NUL (U+0000) and supplementary characters (U+10000 and higher)
//     are not permitted in quoted or unquoted identifiers.
//   - Identifiers may begin with a digit but unless quoted may not consist
//     solely of digits.
//   - Database, table, and column names cannot end with space characters.
//
// On top of that, this rejects backticks and quotes (which could break out
// of the quoting), every space character (simpler than only trailing
// spaces) and non-printable characters. The name is always quoted, so the
// all-digit rule does not apply.
func validate(name string) (err error) {
	if name == "" {
		return fmt.Errorf("missing name")
	}

	for _, ch := range name {
		switch {
		case ch == 0x0000:
			return fmt.Errorf("invalid character: cannot include 0x0000")
		case ch > 0xFFFF:
			return fmt.Errorf("invalid character: cannot include any characters above 0xFFFF")
		case ch == '`':
			return fmt.Errorf("invalid character: cannot include '`' character")
		case ch == '\'', ch == '"':
			return fmt.Errorf("invalid character: cannot include quotes")
		case !unicode.IsPrint(ch), unicode.IsSpace(ch):
			return fmt.Errorf("invalid character found: can only include printable, non-space characters between [0x0001-0xFFFF]")
		}
	}
	return nil
}
|
|
|
|
// wrapErr formats err into message, which is expected to contain a single %w
// verb. A nil err yields nil, so the result can be passed directly to
// multierror.Append without adding an empty entry.
func wrapErr(message string, err error) error {
	if err != nil {
		return fmt.Errorf(message, err)
	}
	return nil
}
|
|
|
|
func NewMySQLClient(conf map[string]string, logger log.Logger) (*sql.DB, error) {
|
|
var err error
|
|
|
|
// Get the MySQL credentials to perform read/write operations.
|
|
username, ok := conf["username"]
|
|
if !ok || username == "" {
|
|
return nil, fmt.Errorf("missing username")
|
|
}
|
|
password, ok := conf["password"]
|
|
if !ok || password == "" {
|
|
return nil, fmt.Errorf("missing password")
|
|
}
|
|
|
|
// Get or set MySQL server address. Defaults to localhost and default port(3306)
|
|
address, ok := conf["address"]
|
|
if !ok {
|
|
address = "127.0.0.1:3306"
|
|
}
|
|
|
|
maxIdleConnStr, ok := conf["max_idle_connections"]
|
|
var maxIdleConnInt int
|
|
if ok {
|
|
maxIdleConnInt, err = strconv.Atoi(maxIdleConnStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnInt)
|
|
}
|
|
}
|
|
|
|
maxConnLifeStr, ok := conf["max_connection_lifetime"]
|
|
var maxConnLifeInt int
|
|
if ok {
|
|
maxConnLifeInt, err = strconv.Atoi(maxConnLifeStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed parsing max_connection_lifetime parameter: %w", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_connection_lifetime set", "max_connection_lifetime", maxConnLifeInt)
|
|
}
|
|
}
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
} else {
|
|
maxParInt = physical.DefaultParallelOperations
|
|
}
|
|
|
|
dsnParams := url.Values{}
|
|
tlsCaFile, tlsOk := conf["tls_ca_file"]
|
|
if tlsOk {
|
|
if err := setupMySQLTLSConfig(tlsCaFile); err != nil {
|
|
return nil, fmt.Errorf("failed register TLS config: %w", err)
|
|
}
|
|
|
|
dsnParams.Add("tls", mysqlTLSKey)
|
|
}
|
|
ptAllowed, ptOk := conf["plaintext_connection_allowed"]
|
|
if !(ptOk && strings.ToLower(ptAllowed) == "true") && !tlsOk {
|
|
logger.Warn("No TLS specified, credentials will be sent in plaintext. To mute this warning add 'plaintext_connection_allowed' with a true value to your MySQL configuration in your config file.")
|
|
}
|
|
|
|
// Create MySQL handle for the database.
|
|
dsn := username + ":" + password + "@tcp(" + address + ")/?" + dsnParams.Encode()
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to mysql: %w", err)
|
|
}
|
|
db.SetMaxOpenConns(maxParInt)
|
|
if maxIdleConnInt != 0 {
|
|
db.SetMaxIdleConns(maxIdleConnInt)
|
|
}
|
|
if maxConnLifeInt != 0 {
|
|
db.SetConnMaxLifetime(time.Duration(maxConnLifeInt) * time.Second)
|
|
}
|
|
|
|
return db, err
|
|
}
|
|
|
|
// prepare compiles query on the backend's pool and stores the resulting
// statement in m.statements under name. The error names the statement that
// failed.
func (m *MySQLBackend) prepare(name, query string) error {
	stmt, prepErr := m.client.Prepare(query)
	if prepErr == nil {
		m.statements[name] = stmt
		return nil
	}
	return fmt.Errorf("failed to prepare %q: %w", name, prepErr)
}
|
|
|
|
// Put is used to insert or update an entry.
//
// It runs the prepared "put" statement (INSERT ... ON DUPLICATE KEY UPDATE)
// while holding a slot from the permit pool. ctx is not used.
func (m *MySQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"mysql", "put"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, execErr := m.statements["put"].Exec(entry.Key, entry.Value)
	return execErr
}
|
|
|
|
// Get is used to fetch an entry.
//
// It returns (nil, nil) when no row exists for key. ctx is not used.
func (m *MySQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"mysql", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	var value []byte
	switch err := m.statements["get"].QueryRow(key).Scan(&value); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}

	return &physical.Entry{Key: key, Value: value}, nil
}
|
|
|
|
// Delete is used to permanently delete an entry.
//
// Deleting a key that does not exist is not an error. ctx is not used.
func (m *MySQLBackend) Delete(ctx context.Context, key string) error {
	defer metrics.MeasureSince([]string{"mysql", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, execErr := m.statements["delete"].Exec(key)
	return execErr
}
|
|
|
|
// List is used to list all the keys under a given
// prefix, up to the next prefix.
//
// Keys directly under prefix are returned with the prefix removed. Deeper
// keys are collapsed to their first path segment plus a trailing "/", and
// each such folder appears once. The result is sorted. ctx is not used.
func (m *MySQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"mysql", "list"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	// Add the % wildcard to the prefix to do the prefix search
	likePrefix := prefix + "%"
	rows, err := m.statements["list"].Query(likePrefix)
	if err != nil {
		return nil, fmt.Errorf("failed to execute statement: %w", err)
	}
	// Release the connection on every return path, including a scan error.
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, fmt.Errorf("failed to scan rows: %w", err)
		}

		// "_" and "%" inside prefix are LIKE wildcards, so the query can also
		// return keys that do not literally start with prefix. Drop those.
		// NOTE(review): a backslash in prefix is treated as a LIKE escape
		// under the default sql_mode. Keys containing one may be missed.
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		key = strings.TrimPrefix(key, prefix)

		if i := strings.Index(key, "/"); i == -1 {
			// Add objects only from the current 'folder'
			keys = append(keys, key)
		} else {
			// Add truncated 'folder' paths
			keys = strutil.AppendIfMissing(keys, key[:i+1])
		}
	}
	// rows.Next returning false can mean an error rather than the end of the
	// result set. Do not return a silently truncated listing.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate rows: %w", err)
	}

	sort.Strings(keys)
	return keys, nil
}
|
|
|
|
// LockWith is used for mutual exclusion based on the given key.
|
|
func (m *MySQLBackend) LockWith(key, value string) (physical.Lock, error) {
|
|
l := &MySQLHALock{
|
|
in: m,
|
|
key: key,
|
|
value: value,
|
|
logger: m.logger,
|
|
}
|
|
return l, nil
|
|
}
|
|
|
|
func (m *MySQLBackend) HAEnabled() bool {
|
|
return m.haEnabled
|
|
}
|
|
|
|
// MySQLHALock is a MySQL Lock implementation for the HABackend
|
|
type MySQLHALock struct {
|
|
in *MySQLBackend
|
|
key string
|
|
value string
|
|
logger log.Logger
|
|
|
|
held bool
|
|
localLock sync.Mutex
|
|
leaderCh chan struct{}
|
|
stopCh <-chan struct{}
|
|
lock *MySQLLock
|
|
}
|
|
|
|
// Lock acquires the HA lock. It blocks until the lock is obtained, the
// acquisition attempt fails, or stopCh is closed.
//
// On success it returns a channel that monitorLock closes when it detects
// that the lock has been lost. If stopCh fires first, Lock returns (nil, nil)
// and tells the background attempt to release the lock should it still
// succeed later. It returns an error if this MySQLHALock already holds the
// lock or if the acquisition attempt fails.
func (i *MySQLHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	i.localLock.Lock()
	defer i.localLock.Unlock()
	if i.held {
		return nil, fmt.Errorf("lock already held")
	}

	// Attempt an async acquisition
	didLock := make(chan struct{})
	failLock := make(chan error, 1)
	// Buffered (size 1) so the single send below never blocks; attemptLock
	// only reads it after it has acquired the lock.
	releaseCh := make(chan bool, 1)
	go i.attemptLock(i.key, i.value, didLock, failLock, releaseCh)

	// Wait for lock acquisition, failure, or shutdown
	select {
	case <-didLock:
		// Keep the lock.
		releaseCh <- false
	case err := <-failLock:
		return nil, err
	case <-stopCh:
		// NOTE(review): the attempt keeps blocking in GET_LOCK until it
		// either succeeds (and then releases) or fails.
		releaseCh <- true
		return nil, nil
	}

	// Create the leader channel
	i.held = true
	i.leaderCh = make(chan struct{})

	go i.monitorLock(i.leaderCh)

	i.stopCh = stopCh

	return i.leaderCh, nil
}
|
|
|
|
func (i *MySQLHALock) attemptLock(key, value string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
|
|
lock, err := NewMySQLLock(i.in, i.logger, key, value)
|
|
if err != nil {
|
|
failLock <- err
|
|
return
|
|
}
|
|
|
|
// Set node value
|
|
i.lock = lock
|
|
|
|
err = lock.Lock()
|
|
if err != nil {
|
|
failLock <- err
|
|
return
|
|
}
|
|
|
|
// Signal that lock is held
|
|
close(didLock)
|
|
|
|
// Handle an early abort
|
|
release := <-releaseCh
|
|
if release {
|
|
lock.Unlock()
|
|
}
|
|
}
|
|
|
|
// monitorLock checks every 5 seconds that this node still holds the MySQL
// lock, and closes leaderCh once hasLock reports an error.
//
// It exits without closing leaderCh as soon as this acquisition is no longer
// current: Unlock has cleared held, or a later Lock call has replaced
// leaderCh. This keeps a polling goroutine from outliving its leadership
// term.
func (i *MySQLHALock) monitorLock(leaderCh chan struct{}) {
	for {
		i.localLock.Lock()
		current := i.held && i.leaderCh == leaderCh
		i.localLock.Unlock()
		if !current {
			return
		}

		// The only way to lose this lock is if someone is
		// logging into the DB and altering system tables or you lose a connection in
		// which case you will lose the lock anyway.
		if err := i.hasLock(i.key); err != nil {
			// Somehow we lost the lock.... likely because the connection holding
			// the lock was closed or someone was playing around with the locks in the DB.
			close(leaderCh)
			return
		}

		time.Sleep(5 * time.Second)
	}
}
|
|
|
|
func (i *MySQLHALock) Unlock() error {
|
|
i.localLock.Lock()
|
|
defer i.localLock.Unlock()
|
|
if !i.held {
|
|
return nil
|
|
}
|
|
|
|
err := i.lock.Unlock()
|
|
|
|
if err == nil {
|
|
i.held = false
|
|
return nil
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// hasLock will check if a lock is held by checking the current lock id against our known ID.
//
// It runs IS_USED_LOCK(key) on the backend's pool and returns:
//   - nil if the lock is held by the connection recorded in GlobalLockID,
//     or if no row / NULL comes back (nobody holds the lock);
//   - ErrLockHeld if a different connection holds it;
//   - the query or scan error otherwise. The error check must come before
//     the NULL check: a failed Scan leaves result invalid, which would
//     otherwise hide every query error.
func (i *MySQLHALock) hasLock(key string) error {
	var result sql.NullInt64
	err := i.in.statements["used_lock"].QueryRow(key).Scan(&result)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// This is not an error to us since it just means the lock isn't held
		return nil
	case err != nil:
		return err
	case !result.Valid:
		// NULL: the lock is not held by anyone.
		return nil
	case result.Int64 != GlobalLockID:
		// IS_USED_LOCK returns the ID of the connection that holds the lock.
		return ErrLockHeld
	default:
		return nil
	}
}
|
|
|
|
// GetLeader returns the current_leader value stored for the "leader" row of
// the lock table.
//
// Any query or scan error is returned, including sql.ErrNoRows when no leader
// has been recorded yet. Only returning ErrNoRows would let other failures be
// reported as an empty leader with a nil error.
func (i *MySQLHALock) GetLeader() (string, error) {
	defer metrics.MeasureSince([]string{"mysql", "lock_get"}, time.Now())

	var result string
	if err := i.in.statements["get_lock"].QueryRow("leader").Scan(&result); err != nil {
		return "", err
	}
	return result, nil
}
|
|
|
|
func (i *MySQLHALock) Value() (bool, string, error) {
|
|
leaderkey, err := i.GetLeader()
|
|
if err != nil {
|
|
return false, "", err
|
|
}
|
|
|
|
return true, leaderkey, err
|
|
}
|
|
|
|
// MySQLLock provides an easy way to grab and release mysql
|
|
// locks using the built in GET_LOCK function. Note that these
|
|
// locks are released when you lose connection to the server.
|
|
type MySQLLock struct {
|
|
parentConn *MySQLBackend
|
|
in *sql.DB
|
|
logger log.Logger
|
|
statements map[string]*sql.Stmt
|
|
key string
|
|
value string
|
|
}
|
|
|
|
// Errors specific to trying to grab a lock in MySQL
|
|
var (
|
|
// This is the GlobalLockID for checking if the lock we got is still the current lock
|
|
GlobalLockID int64
|
|
// ErrLockHeld is returned when another vault instance already has a lock held for the given key.
|
|
ErrLockHeld = errors.New("mysql: lock already held")
|
|
// ErrUnlockFailed
|
|
ErrUnlockFailed = errors.New("mysql: unable to release lock, already released or not held by this session")
|
|
// You were unable to update that you are the new leader in the DB
|
|
ErrClaimFailed = errors.New("mysql: unable to update DB with new leader information")
|
|
// Error to throw if between getting the lock and checking the ID of it we lost it.
|
|
ErrSettingGlobalID = errors.New("mysql: getting global lock id failed")
|
|
)
|
|
|
|
// NewMySQLLock creates a MySQLLock for key/value on a dedicated connection
// pool built from in.conf. Closing that pool, which is how Unlock releases
// the lock, then has no effect on the rest of the MySQL backend.
//
// It returns an error if the pool cannot be created, or if the lock table's
// "put" statement cannot be prepared. In the second case the pool is closed
// before returning.
func NewMySQLLock(in *MySQLBackend, l log.Logger, key, value string) (*MySQLLock, error) {
	// NewMySQLClient can fail (for example, if tls_ca_file can no longer be
	// read). Continuing with a nil *sql.DB would panic in prepare below.
	conn, err := NewMySQLClient(in.conf, in.logger)
	if err != nil {
		return nil, fmt.Errorf("failed to create mysql lock client: %w", err)
	}

	m := &MySQLLock{
		parentConn: in,
		in:         conn,
		logger:     l,
		statements: make(map[string]*sql.Stmt),
		key:        key,
		value:      value,
	}

	statements := map[string]string{
		"put": "INSERT INTO " + in.dbLockTable +
			" VALUES( ?, ? ) ON DUPLICATE KEY UPDATE current_leader=VALUES(current_leader)",
	}

	for name, query := range statements {
		if err := m.prepare(name, query); err != nil {
			// The caller never sees m, so nobody else could close the pool.
			_ = conn.Close()
			return nil, err
		}
	}

	return m, nil
}
|
|
|
|
// prepare compiles query on the lock's own pool and stores the statement in
// m.statements under name. The error names the statement that failed.
func (m *MySQLLock) prepare(name, query string) error {
	stmt, prepErr := m.in.Prepare(query)
	if prepErr == nil {
		m.statements[name] = stmt
		return nil
	}
	return fmt.Errorf("failed to prepare %q: %w", name, prepErr)
}
|
|
|
|
// update the current cluster leader in the DB. This is used so
|
|
// we can tell the servers in standby who the active leader is.
|
|
func (i *MySQLLock) becomeLeader() error {
|
|
_, err := i.statements["put"].Exec("leader", i.value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Lock will try to get a lock for an indefinite amount of time
// based on the given key that has been requested.
//
// It runs GET_LOCK(key, math.MaxInt32) together with IS_USED_LOCK(key) on the
// lock's own pool. On success it records this node as leader via
// becomeLeader and stores the owning connection ID in GlobalLockID.
//
// It returns:
//   - ErrLockHeld if GET_LOCK does not return 1, or if recording the leader
//     fails;
//   - ErrSettingGlobalID if IS_USED_LOCK returns NULL;
//   - the query or scan error otherwise.
//
// NOTE(review): the server-side lock belongs to whichever pooled connection
// ran GET_LOCK, and MySQL releases it when that session ends. If
// database/sql closes that connection (e.g. because of
// max_connection_lifetime), the lock is lost without this type noticing.
func (i *MySQLLock) Lock() error {
	defer metrics.MeasureSince([]string{"mysql", "get_lock"}, time.Now())

	var lock sql.NullInt64
	var connectionID sql.NullInt64

	// Lock timeout math.MaxInt32 instead of -1 solves compatibility issues with
	// different MySQL flavours i.e. MariaDB.
	//
	// QueryRow returns its connection to the pool as soon as Scan completes.
	// An open result set would keep that connection checked out while
	// becomeLeader runs, and becomeLeader needs a connection of its own. On a
	// pool capped at one connection (max_parallel = 1) it would wait forever.
	// Scan errors are returned rather than ignored.
	err := i.in.QueryRow("SELECT GET_LOCK(?, ?), IS_USED_LOCK(?)", i.key, math.MaxInt32, i.key).
		Scan(&lock, &connectionID)
	if err != nil {
		return err
	}

	// 1 is returned from GET_LOCK if it was able to get the lock
	// 0 if it failed and NULL if some strange error happened.
	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_get-lock
	if !lock.Valid || lock.Int64 != 1 {
		return ErrLockHeld
	}

	// Since we have the lock alert the rest of the cluster
	// that we are now the active leader.
	if err := i.becomeLeader(); err != nil {
		// GET_LOCK did succeed, so a pooled connection still holds the
		// server-side lock. The caller must close this lock's pool (Unlock)
		// to release it.
		return ErrLockHeld
	}

	// This will return the connection ID of NULL if an error happens
	// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_is-used-lock
	if !connectionID.Valid {
		return ErrSettingGlobalID
	}

	GlobalLockID = connectionID.Int64

	return nil
}
|
|
|
|
// Unlock just closes the connection. This is because closing the MySQL connection
|
|
// is a 100% reliable way to close the lock. If you just release the lock you must
|
|
// do it from the same mysql connection_id that you originally created it from. This
|
|
// is a huge hastle and I actually couldn't find a clean way to do this although one
|
|
// likely does exist. Closing the connection however ensures we don't ever get into a
|
|
// state where we try to release the lock and it hangs it is also much less code.
|
|
func (i *MySQLLock) Unlock() error {
|
|
err := i.in.Close()
|
|
if err != nil {
|
|
return ErrUnlockFailed
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// setupMySQLTLSConfig builds a TLS configuration that trusts the CA
// certificate(s) in tlsCaFile. It registers that config with the MySQL
// driver under mysqlTLSKey, so a DSN carrying tls=<mysqlTLSKey> uses it,
// e.g.:
//
//	foo:bar@tcp(127.0.0.1:3306)/dbname?tls=default
//
// It returns an error if any of these fail:
//   - the file cannot be read;
//   - the file contains no usable PEM certificates;
//   - the driver rejects the registration.
func setupMySQLTLSConfig(tlsCaFile string) error {
	rootCertPool := x509.NewCertPool()

	pem, err := ioutil.ReadFile(tlsCaFile)
	if err != nil {
		return err
	}

	// AppendCertsFromPEM reports failure only through its boolean result, and
	// err is nil at this point. Without an explicit error, an empty pool would
	// be registered and every server certificate would later fail
	// verification.
	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
		return fmt.Errorf("no valid PEM-encoded certificates found in %q", tlsCaFile)
	}

	if err := mysql.RegisterTLSConfig(mysqlTLSKey, &tls.Config{
		RootCAs: rootCertPool,
	}); err != nil {
		return err
	}

	return nil
}
|