Mirror of https://github.com/hashicorp/vault.git (synced 2025-08-16 11:37:04 +02:00)
* Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.
* Updating the license from MPL to Business Source License. Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.
* add missing license headers
* Update copyright file headers to BUS-1.1
* Fix test that expected exact offset on hcl file
---------
Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
475 lines
13 KiB
Go
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package postgresql

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/armon/go-metrics"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/sdk/database/helper/dbutil"
	"github.com/hashicorp/vault/sdk/physical"
	_ "github.com/jackc/pgx/v4/stdlib"
)

const (

	// The lock TTL matches the default that Consul API uses, 15 seconds.
	// Used as part of SQL commands to set/extend lock expiry time relative to
	// database clock.
	PostgreSQLLockTTLSeconds = 15

	// The amount of time to wait between the lock renewals
	PostgreSQLLockRenewInterval = 5 * time.Second

	// PostgreSQLLockRetryInterval is the amount of time to wait
	// after a failed lock acquisition attempt before trying again.
	PostgreSQLLockRetryInterval = time.Second
)

// Verify PostgreSQLBackend satisfies the correct interfaces
var _ physical.Backend = (*PostgreSQLBackend)(nil)

// The HA backend was implemented based on the DynamoDB backend pattern, with the
// distinction that it uses the central Postgres clock for lock expiry, thereby
// avoiding possible issues with clock skew between multiple nodes.
var (
	_ physical.HABackend = (*PostgreSQLBackend)(nil)
	_ physical.Lock      = (*PostgreSQLLock)(nil)
)

// PostgreSQLBackend is a physical backend that stores data
// within a PostgreSQL database.
type PostgreSQLBackend struct {
	table        string
	client       *sql.DB
	put_query    string
	get_query    string
	delete_query string
	list_query   string

	ha_table                 string
	haGetLockValueQuery      string
	haUpsertLockIdentityExec string
	haDeleteLockExec         string

	haEnabled  bool
	logger     log.Logger
	permitPool *physical.PermitPool
}

// PostgreSQLLock implements a lock using a PostgreSQL client.
type PostgreSQLLock struct {
	backend    *PostgreSQLBackend
	value, key string
	identity   string
	lock       sync.Mutex

	renewTicker *time.Ticker

	// ttlSeconds is how long a lock is valid for
	ttlSeconds int

	// renewInterval is how much time to wait between lock renewals. must be << ttl
	renewInterval time.Duration

	// retryInterval is how much time to wait between attempts to grab the lock
	retryInterval time.Duration
}

// NewPostgreSQLBackend constructs a PostgreSQL backend using the given
// configuration map (connection URL, table names, tuning options) and logger.
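//
// For illustration only (not part of the original source): a minimal sketch of
// constructing this backend directly. The connection URL and values below are
// hypothetical; in a real Vault deployment the configuration comes from the
// `storage "postgresql"` stanza of the server configuration file.
//
//	conf := map[string]string{
//		"connection_url": "postgres://vault:vault@localhost:5432/vault?sslmode=disable",
//		"table":          "vault_kv_store",
//		"ha_enabled":     "true",
//	}
//	backend, err := NewPostgreSQLBackend(conf, log.Default())
//	if err != nil {
//		// handle configuration or connectivity errors
//	}
//	_ = backend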
func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	// Get the PostgreSQL credentials to perform read/write operations.
	connURL := connectionURL(conf)
	if connURL == "" {
		return nil, fmt.Errorf("missing connection_url")
	}

	unquoted_table, ok := conf["table"]
	if !ok {
		unquoted_table = "vault_kv_store"
	}
	quoted_table := dbutil.QuoteIdentifier(unquoted_table)

	maxParStr, ok := conf["max_parallel"]
	var maxParInt int
	var err error
	if ok {
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	} else {
		maxParInt = physical.DefaultParallelOperations
	}

	maxIdleConnsStr, maxIdleConnsIsSet := conf["max_idle_connections"]
	var maxIdleConns int
	if maxIdleConnsIsSet {
		maxIdleConns, err = strconv.Atoi(maxIdleConnsStr)
		if err != nil {
			return nil, fmt.Errorf("failed parsing max_idle_connections parameter: %w", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_idle_connections set", "max_idle_connections", maxIdleConnsStr)
		}
	}

	// Create PostgreSQL handle for the database.
	db, err := sql.Open("pgx", connURL)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to postgres: %w", err)
	}
	db.SetMaxOpenConns(maxParInt)

	if maxIdleConnsIsSet {
		db.SetMaxIdleConns(maxIdleConns)
	}

	// Determine if we should use a function to work around lack of upsert (versions < 9.5)
	var upsertAvailable bool
	upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
	if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
		return nil, fmt.Errorf("failed to check for native upsert: %w", err)
	}

	if !upsertAvailable && conf["ha_enabled"] == "true" {
		return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
	}

	// Setup our put strategy based on the presence or absence of a native
	// upsert.
	var put_query string
	if !upsertAvailable {
		put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
	} else {
		put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
			" ON CONFLICT (path, key) DO " +
			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
	}

	unquoted_ha_table, ok := conf["ha_table"]
	if !ok {
		unquoted_ha_table = "vault_ha_locks"
	}
	quoted_ha_table := dbutil.QuoteIdentifier(unquoted_ha_table)

	// Setup the backend.
	m := &PostgreSQLBackend{
		table:        quoted_table,
		client:       db,
		put_query:    put_query,
		get_query:    "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
		delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
		list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
			" UNION ALL SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
			" WHERE parent_path LIKE $1 || '%'",
		haGetLockValueQuery:
		// only read non expired data
		" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
		haUpsertLockIdentityExec:
		// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
		// update either steal expired lock OR update expiry for lock owned by me
		" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " +
			" ON CONFLICT (ha_key) DO " +
			" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
			" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
			" (t.ha_identity = $1 AND t.ha_key = $2) ",
		haDeleteLockExec:
		// $1=ha_identity $2=ha_key
		" DELETE FROM " + quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
		logger:     logger,
		permitPool: physical.NewPermitPool(maxParInt),
		haEnabled:  conf["ha_enabled"] == "true",
	}

	return m, nil
}

// connectionURL first checks the environment variables for a connection URL. If
// no connection URL exists in the environment variable, the Vault config file is
// checked. If neither the environment variable nor the config file sets the
// connection URL for the Postgres backend, an empty string is returned and the
// caller reports an error, because it is a required field.
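//
// For illustration only (not part of the original source): the environment
// variable takes precedence over the value from the configuration map, e.g.
//
//	conf := map[string]string{"connection_url": "postgres://file-value"}
//	os.Setenv("VAULT_PG_CONNECTION_URL", "postgres://env-value")
//	connectionURL(conf) // returns "postgres://env-value"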
func connectionURL(conf map[string]string) string {
	connURL := conf["connection_url"]
	if envURL := os.Getenv("VAULT_PG_CONNECTION_URL"); envURL != "" {
		connURL = envURL
	}

	return connURL
}

// splitKey is a helper to split a full path key into individual
// parts: parentPath, path, key
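//
// For illustration only (not part of the original source), given the logic
// below:
//
//	splitKey("foo")         // -> parentPath "",      path "/",         key "foo"
//	splitKey("foo/bar")     // -> parentPath "/",     path "/foo/",     key "bar"
//	splitKey("foo/bar/baz") // -> parentPath "/foo/", path "/foo/bar/", key "baz"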
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
	var parentPath string
	var path string

	pieces := strings.Split(fullPath, "/")
	depth := len(pieces)
	key := pieces[depth-1]

	if depth == 1 {
		parentPath = ""
		path = "/"
	} else if depth == 2 {
		parentPath = "/"
		path = "/" + pieces[0] + "/"
	} else {
		parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
		path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
	}

	return parentPath, path, key
}

// Put is used to insert or update an entry.
func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	parentPath, path, key := m.splitKey(entry.Key)

	_, err := m.client.ExecContext(ctx, m.put_query, parentPath, path, key, entry.Value)
	if err != nil {
		return err
	}
	return nil
}

// Get is used to fetch an entry.
func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, path, key := m.splitKey(fullPath)

	var result []byte
	err := m.client.QueryRowContext(ctx, m.get_query, path, key).Scan(&result)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	ent := &physical.Entry{
		Key:   fullPath,
		Value: result,
	}
	return ent, nil
}

// Delete is used to permanently delete an entry
func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, path, key := m.splitKey(fullPath)

	_, err := m.client.ExecContext(ctx, m.delete_query, path, key)
	if err != nil {
		return err
	}
	return nil
}

// List is used to list all the keys under a given
// prefix, up to the next prefix.
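//
// For illustration only (not part of the original source): with stored keys
// "foo/bar" and "foo/baz/qux", a call such as List(ctx, "foo/") would be
// expected to return the immediate children, e.g. []string{"bar", "baz/"},
// where the trailing slash marks a sub-prefix rather than a leaf key.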
func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	rows, err := m.client.QueryContext(ctx, m.list_query, "/"+prefix)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		err = rows.Scan(&key)
		if err != nil {
			return nil, fmt.Errorf("failed to scan rows: %w", err)
		}

		keys = append(keys, key)
	}

	return keys, nil
}

// LockWith is used for mutual exclusion based on the given key.
func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
	identity, err := uuid.GenerateUUID()
	if err != nil {
		return nil, err
	}
	return &PostgreSQLLock{
		backend:       p,
		key:           key,
		value:         value,
		identity:      identity,
		ttlSeconds:    PostgreSQLLockTTLSeconds,
		renewInterval: PostgreSQLLockRenewInterval,
		retryInterval: PostgreSQLLockRetryInterval,
	}, nil
}

func (p *PostgreSQLBackend) HAEnabled() bool {
	return p.haEnabled
}

// Lock tries to acquire the lock by repeatedly trying to create a record in the
// PostgreSQL table. It will block until either the stop channel is closed or
// the lock could be acquired successfully. The returned channel will be closed
// once the lock in the PostgreSQL table cannot be renewed, either due to an
// error speaking to PostgreSQL or because someone else has taken it.
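//
// For illustration only (not part of the original source): a minimal sketch of
// how a caller might acquire and hold the HA lock, assuming the hypothetical
// names `backend` (a *PostgreSQLBackend) and `stopCh` (closed on shutdown).
//
//	lock, err := backend.LockWith("core/lock", "node-a")
//	if err != nil {
//		// handle error
//	}
//	leaderCh, err := lock.Lock(stopCh)
//	if err != nil || leaderCh == nil {
//		// failed to acquire, or we were asked to stop
//	}
//	<-leaderCh        // closed when the lock can no longer be renewed
//	_ = lock.Unlock() // release the lock when stepping down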
func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
	l.lock.Lock()
	defer l.lock.Unlock()

	var (
		success = make(chan struct{})
		errors  = make(chan error)
		leader  = make(chan struct{})
	)
	// try to acquire the lock asynchronously
	go l.tryToLock(stopCh, success, errors)

	select {
	case <-success:
		// after acquiring it successfully, we must renew the lock periodically
		l.renewTicker = time.NewTicker(l.renewInterval)
		go l.periodicallyRenewLock(leader)
	case err := <-errors:
		return nil, err
	case <-stopCh:
		return nil, nil
	}

	return leader, nil
}

// Unlock releases the lock by deleting the lock record from the
// PostgreSQL table.
func (l *PostgreSQLLock) Unlock() error {
	pg := l.backend
	pg.permitPool.Acquire()
	defer pg.permitPool.Release()

	if l.renewTicker != nil {
		l.renewTicker.Stop()
	}

	// Delete lock owned by me
	_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
	return err
}

// Value checks whether or not the lock is held by any instance of PostgreSQLLock,
// including this one, and returns the current value.
func (l *PostgreSQLLock) Value() (bool, string, error) {
	pg := l.backend
	pg.permitPool.Acquire()
	defer pg.permitPool.Release()
	var result string
	err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)

	switch err {
	case nil:
		return true, result, nil
	case sql.ErrNoRows:
		return false, "", nil
	default:
		return false, "", err
	}
}

// tryToLock tries to create a new item in PostgreSQL every `retryInterval`.
// As long as the item cannot be created (because it already exists), it will
// be retried. If the operation fails due to an error, it is sent to the errors
// channel. When the lock could be acquired successfully, the success channel
// is closed.
func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
	ticker := time.NewTicker(l.retryInterval)
	defer ticker.Stop()

	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			gotlock, err := l.writeItem()
			switch {
			case err != nil:
				errors <- err
				return
			case gotlock:
				close(success)
				return
			}
		}
	}
}

func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
	for range l.renewTicker.C {
		gotlock, err := l.writeItem()
		if err != nil || !gotlock {
			close(done)
			l.renewTicker.Stop()
			return
		}
	}
}

// writeItem attempts to put/update the PostgreSQL item, using the upsert's WHERE
// conditions to evaluate the TTL. Returns true if the lock was obtained, false if not.
// If false, the error may be nil or non-nil: nil indicates simply that someone
// else has the lock, whereas non-nil means that something unexpected happened.
func (l *PostgreSQLLock) writeItem() (bool, error) {
	pg := l.backend
	pg.permitPool.Acquire()
	defer pg.permitPool.Release()

	// Try to steal the lock, or update the expiry on a lock we already own

	sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds)
	if err != nil {
		return false, err
	}
	if sqlResult == nil {
		return false, fmt.Errorf("empty SQL response received")
	}

	ar, err := sqlResult.RowsAffected()
	if err != nil {
		return false, err
	}
	return ar == 1, nil
}