mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-10 16:47:01 +02:00
The result will still pass gofmtcheck and won't trigger additional changes if someone isn't using goimports, but it will avoid the piecemeal imports changes we've been seeing.
216 lines
5.6 KiB
Go
216 lines
5.6 KiB
Go
package postgresql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/errwrap"
|
|
"github.com/hashicorp/vault/physical"
|
|
|
|
log "github.com/hashicorp/go-hclog"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
"github.com/lib/pq"
|
|
)
|
|
|
|
// Verify PostgreSQLBackend satisfies the correct interfaces
|
|
var _ physical.Backend = (*PostgreSQLBackend)(nil)
|
|
|
|
// PostgreSQL Backend is a physical backend that stores data
|
|
// within a PostgreSQL database.
|
|
type PostgreSQLBackend struct {
|
|
table string
|
|
client *sql.DB
|
|
put_query string
|
|
get_query string
|
|
delete_query string
|
|
list_query string
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
}
|
|
|
|
// NewPostgreSQLBackend constructs a PostgreSQL physical backend from conf.
//
// Recognized configuration keys:
//   - connection_url (required): passed to sql.Open("postgres", ...).
//   - table: table name, quoted with pq.QuoteIdentifier; defaults to
//     "vault_kv_store" when the key is absent.
//   - max_parallel: integer used both as the connection-pool limit and as the
//     permit-pool size; defaults to physical.DefaultParallelOperations.
//
// The server version is queried once here to decide whether writes can use a
// native INSERT ... ON CONFLICT upsert (PostgreSQL 9.5+) or must call the
// vault_kv_put function instead. An error is returned if connection_url is
// missing, max_parallel is not an integer, or the version query fails.
func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
	// Get the PostgreSQL credentials to perform read/write operations.
	connURL, ok := conf["connection_url"]
	if !ok || connURL == "" {
		return nil, fmt.Errorf("missing connection_url")
	}

	unquotedTable, ok := conf["table"]
	if !ok {
		unquotedTable = "vault_kv_store"
	}
	quotedTable := pq.QuoteIdentifier(unquotedTable)

	maxParInt := physical.DefaultParallelOperations
	if maxParStr, ok := conf["max_parallel"]; ok {
		var err error
		maxParInt, err = strconv.Atoi(maxParStr)
		if err != nil {
			return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err)
		}
		if logger.IsDebug() {
			logger.Debug("max_parallel set", "max_parallel", maxParInt)
		}
	}

	// Create PostgreSQL handle for the database. sql.Open may only validate
	// its arguments; the version query below is what first talks to the server.
	db, err := sql.Open("postgres", connURL)
	if err != nil {
		return nil, errwrap.Wrapf("failed to connect to postgres: {{err}}", err)
	}
	db.SetMaxOpenConns(maxParInt)

	// Determine if we should use an upsert function (versions < 9.5).
	var upsertRequired bool
	upsertRequiredQuery := "SELECT current_setting('server_version_num')::int < 90500"
	if err := db.QueryRow(upsertRequiredQuery).Scan(&upsertRequired); err != nil {
		// The handle is never returned on this path, so close it here rather
		// than leaking its connection pool. Its Close error is secondary to
		// the version-check error being reported.
		db.Close()
		return nil, errwrap.Wrapf("failed to check for native upsert: {{err}}", err)
	}

	// Set up the put strategy based on the presence or absence of a native
	// upsert.
	var putQuery string
	if upsertRequired {
		putQuery = "SELECT vault_kv_put($1, $2, $3, $4)"
	} else {
		putQuery = "INSERT INTO " + quotedTable + " VALUES($1, $2, $3, $4)" +
			" ON CONFLICT (path, key) DO " +
			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
	}

	// The list query needs a space before UNION; without it the statement
	// reads "... WHERE path = $1UNION SELECT ...".
	//
	// NOTE(review): $1 is also used as a LIKE pattern, so any '%' or '_' in
	// the listed prefix acts as a wildcard — confirm whether prefixes can
	// contain these characters.
	listQuery := "SELECT key FROM " + quotedTable + " WHERE path = $1" +
		" UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
		quotedTable + " WHERE parent_path LIKE $1 || '%'"

	// Set up the backend.
	m := &PostgreSQLBackend{
		table:        quotedTable,
		client:       db,
		put_query:    putQuery,
		get_query:    "SELECT value FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		delete_query: "DELETE FROM " + quotedTable + " WHERE path = $1 AND key = $2",
		list_query:   listQuery,
		logger:       logger,
		permitPool:   physical.NewPermitPool(maxParInt),
	}

	return m, nil
}
|
|
|
|
// splitKey is a helper to split a full path key into individual
|
|
// parts: parentPath, path, key
|
|
func (m *PostgreSQLBackend) splitKey(fullPath string) (string, string, string) {
|
|
var parentPath string
|
|
var path string
|
|
|
|
pieces := strings.Split(fullPath, "/")
|
|
depth := len(pieces)
|
|
key := pieces[depth-1]
|
|
|
|
if depth == 1 {
|
|
parentPath = ""
|
|
path = "/"
|
|
} else if depth == 2 {
|
|
parentPath = "/"
|
|
path = "/" + pieces[0] + "/"
|
|
} else {
|
|
parentPath = "/" + strings.Join(pieces[:depth-2], "/") + "/"
|
|
path = "/" + strings.Join(pieces[:depth-1], "/") + "/"
|
|
}
|
|
|
|
return parentPath, path, key
|
|
}
|
|
|
|
// Put inserts or updates entry. entry.Key is split with splitKey into the
// parent_path, path and key columns, and the statement chosen at
// construction time (native upsert or vault_kv_put) is executed. ctx is
// forwarded to the driver, so cancelling it aborts the statement.
func (m *PostgreSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
	defer metrics.MeasureSince([]string{"postgres", "put"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	parentPath, path, key := m.splitKey(entry.Key)

	_, err := m.client.ExecContext(ctx, m.put_query, parentPath, path, key, entry.Value)
	return err
}
|
|
|
|
// Get fetches the entry stored at fullPath. It returns (nil, nil) when no row
// exists for that key. ctx is forwarded to the driver, so cancelling it
// aborts the query.
func (m *PostgreSQLBackend) Get(ctx context.Context, fullPath string) (*physical.Entry, error) {
	defer metrics.MeasureSince([]string{"postgres", "get"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, path, key := m.splitKey(fullPath)

	var result []byte
	err := m.client.QueryRowContext(ctx, m.get_query, path, key).Scan(&result)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	// Report the full key the caller asked for, not just its final path
	// segment; Put derives the storage location from Entry.Key, so a
	// truncated key would not round-trip.
	return &physical.Entry{
		Key:   fullPath,
		Value: result,
	}, nil
}
|
|
|
|
// Delete permanently removes the entry stored at fullPath. Deleting a key
// that does not exist is not an error. ctx is forwarded to the driver, so
// cancelling it aborts the statement.
func (m *PostgreSQLBackend) Delete(ctx context.Context, fullPath string) error {
	defer metrics.MeasureSince([]string{"postgres", "delete"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	_, path, key := m.splitKey(fullPath)

	_, err := m.client.ExecContext(ctx, m.delete_query, path, key)
	return err
}
|
|
|
|
// List returns the keys under prefix, up to the next path separator. Folder
// names are returned with a trailing "/" (per the list query built in
// NewPostgreSQLBackend). A "/" is prepended to prefix to match the
// leading-slash path format produced by splitKey. ctx is forwarded to the
// driver, so cancelling it aborts the query.
func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
	defer metrics.MeasureSince([]string{"postgres", "list"}, time.Now())

	m.permitPool.Acquire()
	defer m.permitPool.Release()

	rows, err := m.client.QueryContext(ctx, m.list_query, "/"+prefix)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var keys []string
	for rows.Next() {
		var key string
		if err := rows.Scan(&key); err != nil {
			return nil, errwrap.Wrapf("failed to scan rows: {{err}}", err)
		}
		keys = append(keys, key)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails. Surface the latter instead of returning a silently
	// truncated listing.
	if err := rows.Err(); err != nil {
		return nil, errwrap.Wrapf("failed to iterate rows: {{err}}", err)
	}

	return keys, nil
}
|