diff --git a/physical/postgresql/postgresql.go b/physical/postgresql/postgresql.go
index d850ec0b27..79d68d4c77 100644
--- a/physical/postgresql/postgresql.go
+++ b/physical/postgresql/postgresql.go
@@ -6,20 +6,45 @@ import (
 	"fmt"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/hashicorp/errwrap"
 	"github.com/hashicorp/vault/sdk/physical"
 
 	log "github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/go-uuid"
 
-	metrics "github.com/armon/go-metrics"
+	"github.com/armon/go-metrics"
 	"github.com/lib/pq"
 )
 
+const (
+
+	// PostgreSQLLockTTLSeconds is the lock TTL; it matches the default that the
+	// Consul API uses, 15 seconds. Used as part of SQL commands to set/extend
+	// the lock expiry time relative to the database clock.
+	PostgreSQLLockTTLSeconds = 15
+
+	// PostgreSQLLockRenewInterval is the amount of time to wait between lock renewals.
+	PostgreSQLLockRenewInterval = 5 * time.Second
+
+	// PostgreSQLLockRetryInterval is the amount of time to wait
+	// if a lock fails before trying again.
+	PostgreSQLLockRetryInterval = time.Second
+)
+
 // Verify PostgreSQLBackend satisfies the correct interfaces
 var _ physical.Backend = (*PostgreSQLBackend)(nil)
 
+//
+// The HA backend is implemented following the DynamoDB backend pattern,
+// with the distinction that it uses the central PostgreSQL clock, thereby
+// avoiding possible issues with multiple clocks.
+//
+var _ physical.HABackend = (*PostgreSQLBackend)(nil)
+var _ physical.Lock = (*PostgreSQLLock)(nil)
+
 // PostgreSQL Backend is a physical backend that stores data
 // within a PostgreSQL database.
 type PostgreSQLBackend struct {
@@ -29,8 +54,34 @@ type PostgreSQLBackend struct {
 	get_query    string
 	delete_query string
 	list_query   string
-	logger       log.Logger
-	permitPool   *physical.PermitPool
+
+	ha_table                 string
+	haGetLockValueQuery      string
+	haUpsertLockIdentityExec string
+	haDeleteLockExec         string
+
+	haEnabled  bool
+	logger     log.Logger
+	permitPool *physical.PermitPool
+}
+
+// PostgreSQLLock implements a lock using a PostgreSQL client.
+type PostgreSQLLock struct {
+	backend    *PostgreSQLBackend
+	value, key string
+	identity   string
+	lock       sync.Mutex
+
+	renewTicker *time.Ticker
+
+	// ttlSeconds is how long a lock is valid for, in seconds
+	ttlSeconds int
+
+	// renewInterval is how much time to wait between lock renewals. Must be much smaller than the TTL.
+	renewInterval time.Duration
+
+	// retryInterval is how much time to wait between attempts to grab the lock
+	retryInterval time.Duration
 }
 
 // NewPostgreSQLBackend constructs a PostgreSQL backend using the given
@@ -70,17 +121,21 @@ func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.B
 	}
 	db.SetMaxOpenConns(maxParInt)
 
-	// Determine if we should use an upsert function (versions < 9.5)
-	var upsert_required bool
-	upsert_required_query := "SELECT current_setting('server_version_num')::int < 90500"
-	if err := db.QueryRow(upsert_required_query).Scan(&upsert_required); err != nil {
+	// Determine if we must use a function to work around the lack of native upsert (versions < 9.5)
+	var upsertAvailable bool
+	upsertAvailableQuery := "SELECT current_setting('server_version_num')::int >= 90500"
+	if err := db.QueryRow(upsertAvailableQuery).Scan(&upsertAvailable); err != nil {
 		return nil, errwrap.Wrapf("failed to check for native upsert: {{err}}", err)
 	}
 
+	if !upsertAvailable && conf["ha_enabled"] == "true" {
+		return nil, fmt.Errorf("ha_enabled=true in config but PG version doesn't support HA, must be at least 9.5")
+	}
+
 	// Setup our put strategy based on the presence or absence of a native
 	// upsert.
 	var put_query string
-	if upsert_required {
+	if !upsertAvailable {
 		put_query = "SELECT vault_kv_put($1, $2, $3, $4)"
 	} else {
 		put_query = "INSERT INTO " + quoted_table + " VALUES($1, $2, $3, $4)" +
@@ -88,6 +143,12 @@ func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.B
 			" UPDATE SET (parent_path, path, key, value) = ($1, $2, $3, $4)"
 	}
 
+	unquoted_ha_table, ok := conf["ha_table"]
+	if !ok {
+		unquoted_ha_table = "vault_ha_locks"
+	}
+	quoted_ha_table := pq.QuoteIdentifier(unquoted_ha_table)
+
 	// Setup the backend.
 	m := &PostgreSQLBackend{
 		table:        quoted_table,
@@ -96,10 +157,25 @@ func NewPostgreSQLBackend(conf map[string]string, logger log.Logger) (physical.B
 		get_query:    "SELECT value FROM " + quoted_table + " WHERE path = $1 AND key = $2",
 		delete_query: "DELETE FROM " + quoted_table + " WHERE path = $1 AND key = $2",
 		list_query: "SELECT key FROM " + quoted_table + " WHERE path = $1" +
-			"UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " +
-			quoted_table + " WHERE parent_path LIKE $1 || '%'",
+			" UNION SELECT DISTINCT substring(substr(path, length($1)+1) from '^.*?/') FROM " + quoted_table +
+			" WHERE parent_path LIKE $1 || '%'",
+		haGetLockValueQuery:
+		// only read non-expired rows; expiry is compared against the database clock (NOW())
+		" SELECT ha_value FROM " + quoted_ha_table + " WHERE NOW() <= valid_until AND ha_key = $1 ",
+		haUpsertLockIdentityExec:
+		// $1=identity $2=ha_key $3=ha_value $4=TTL in seconds
+		// the update either steals an expired lock OR extends the expiry of a lock owned by me
+		" INSERT INTO " + quoted_ha_table + " as t (ha_identity, ha_key, ha_value, valid_until) VALUES ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds' ) " +
+			" ON CONFLICT (ha_key) DO " +
+			" UPDATE SET (ha_identity, ha_key, ha_value, valid_until) = ($1, $2, $3, NOW() + $4 * INTERVAL '1 seconds') " +
+			" WHERE (t.valid_until < NOW() AND t.ha_key = $2) OR " +
+			" (t.ha_identity = $1 AND t.ha_key = $2) ",
+		haDeleteLockExec:
+		// $1=ha_identity $2=ha_key; only deletes the row if it is owned by that identity
+		" DELETE FROM " +
quoted_ha_table + " WHERE ha_identity=$1 AND ha_key=$2 ",
 		logger:     logger,
 		permitPool: physical.NewPermitPool(maxParInt),
+		haEnabled:  conf["ha_enabled"] == "true",
 	}
 
 	return m, nil
@@ -213,3 +289,156 @@ func (m *PostgreSQLBackend) List(ctx context.Context, prefix string) ([]string,
 
 	return keys, nil
 }
+
+// LockWith is used for mutual exclusion based on the given key.
+func (p *PostgreSQLBackend) LockWith(key, value string) (physical.Lock, error) {
+	identity, err := uuid.GenerateUUID()
+	if err != nil {
+		return nil, err
+	}
+	return &PostgreSQLLock{
+		backend:       p,
+		key:           key,
+		value:         value,
+		identity:      identity,
+		ttlSeconds:    PostgreSQLLockTTLSeconds,
+		renewInterval: PostgreSQLLockRenewInterval,
+		retryInterval: PostgreSQLLockRetryInterval,
+	}, nil
+}
+
+// HAEnabled reports whether HA was enabled in the backend configuration (ha_enabled = "true").
+func (p *PostgreSQLBackend) HAEnabled() bool {
+	return p.haEnabled
+}
+
+// Lock tries to acquire the lock by repeatedly trying to upsert a record in the
+// PostgreSQL table. It blocks until the lock is acquired, an error occurs, or
+// stopCh is closed (in which case it returns nil, nil). The returned channel is
+// closed once the lock can no longer be renewed, either due to an error
+// speaking to PostgreSQL or because someone else has taken it.
+func (l *PostgreSQLLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+
+	var (
+		success = make(chan struct{})
+		errors  = make(chan error, 1) // buffered so tryToLock never blocks sending after Lock has returned
+		leader  = make(chan struct{})
+	)
+	// try to acquire the lock asynchronously
+	go l.tryToLock(stopCh, success, errors)
+
+	select {
+	case <-success:
+		// after acquiring it successfully, we must renew the lock periodically
+		l.renewTicker = time.NewTicker(l.renewInterval)
+		go l.periodicallyRenewLock(leader)
+	case err := <-errors:
+		return nil, err
+	case <-stopCh:
+		return nil, nil
+	}
+
+	return leader, nil
+}
+
+// Unlock releases the lock by deleting the lock record from the
+// PostgreSQL table.
+func (l *PostgreSQLLock) Unlock() error {
+	pg := l.backend
+	pg.permitPool.Acquire()
+	defer pg.permitPool.Release()
+
+	if l.renewTicker != nil {
+		l.renewTicker.Stop() // NOTE(review): Stop does not close renewTicker.C, so periodicallyRenewLock stays blocked on it
+	}
+
+	// Delete the lock row only if it is still owned by this identity
+	_, err := pg.client.Exec(pg.haDeleteLockExec, l.identity, l.key)
+	return err
+}
+
+// Value reports whether an unexpired lock row exists for the key (held by any
+// instance of PostgreSQLLock, including this one) and returns its stored value.
+func (l *PostgreSQLLock) Value() (bool, string, error) {
+	pg := l.backend
+	pg.permitPool.Acquire()
+	defer pg.permitPool.Release()
+	var result string
+	err := pg.client.QueryRow(pg.haGetLockValueQuery, l.key).Scan(&result)
+
+	switch err {
+	case nil:
+		return true, result, nil
+	case sql.ErrNoRows:
+		return false, "", nil
+	default:
+		return false, "", err
+	}
+}
+
+// tryToLock calls writeItem on every tick of a `retryInterval` ticker (the
+// first attempt happens after the first tick). As long as the lock is not
+// obtained, it retries. If writeItem fails with an error, the error is sent to
+// the errors channel and the loop ends. When the lock is acquired, the success
+// channel is closed. The loop also ends as soon as stop becomes readable.
+func (l *PostgreSQLLock) tryToLock(stop <-chan struct{}, success chan struct{}, errors chan error) {
+	ticker := time.NewTicker(l.retryInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-stop:
+			return
+		case <-ticker.C:
+			gotlock, err := l.writeItem()
+			switch {
+			case err != nil:
+				errors <- err
+				return
+			case gotlock:
+				close(success)
+				return
+			}
+		}
+	}
+}
+
+// periodicallyRenewLock calls writeItem on every renewTicker tick; on error or a lost lock it closes done and stops.
+func (l *PostgreSQLLock) periodicallyRenewLock(done chan struct{}) {
+	for range l.renewTicker.C {
+		gotlock, err := l.writeItem()
+		if err != nil || !gotlock {
+			close(done)
+			l.renewTicker.Stop()
+			return
+		}
+	}
+}
+
+// Attempts to put/update the PostgreSQL item using condition expressions to
+// evaluate the TTL. Returns true if the lock was obtained, false if not.
+// If false error may be nil or non-nil: nil indicates simply that someone +// else has the lock, whereas non-nil means that something unexpected happened. +func (l *PostgreSQLLock) writeItem() (bool, error) { + pg := l.backend + pg.permitPool.Acquire() + defer pg.permitPool.Release() + + // Try steal lock or update expiry on my lock + + sqlResult, err := pg.client.Exec(pg.haUpsertLockIdentityExec, l.identity, l.key, l.value, l.ttlSeconds) + if err != nil { + return false, err + } + if sqlResult == nil { + return false, fmt.Errorf("empty SQL response received") + } + + ar, err := sqlResult.RowsAffected() + if err != nil { + return false, err + } + return ar == 1, nil +} diff --git a/physical/postgresql/postgresql_test.go b/physical/postgresql/postgresql_test.go index 7855dc162d..54d7d0e60c 100644 --- a/physical/postgresql/postgresql_test.go +++ b/physical/postgresql/postgresql_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "testing" + "time" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/testhelpers/docker" @@ -30,56 +31,260 @@ func TestPostgreSQLBackend(t *testing.T) { table = "vault_kv_store" } + hae := os.Getenv("PGHAENABLED") + if hae == "" { + hae = "true" + } + + // Run vault tests logger.Info(fmt.Sprintf("Connection URL: %v", connURL)) - b, err := NewPostgreSQLBackend(map[string]string{ + b1, err := NewPostgreSQLBackend(map[string]string{ "connection_url": connURL, "table": table, + "ha_enabled": hae, }, logger) + if err != nil { t.Fatalf("Failed to create new backend: %v", err) } - pg := b.(*PostgreSQLBackend) - //Read postgres version to test basic connects works + b2, err := NewPostgreSQLBackend(map[string]string{ + "connection_url": connURL, + "table": table, + "ha_enabled": hae, + }, logger) + + if err != nil { + t.Fatalf("Failed to create new backend: %v", err) + } + pg := b1.(*PostgreSQLBackend) + + // Read postgres version to test basic connects works var pgversion string if err = pg.client.QueryRow("SELECT 
current_setting('server_version_num')").Scan(&pgversion); err != nil { t.Fatalf("Failed to check for Postgres version: %v", err) } logger.Info(fmt.Sprintf("Postgres Version: %v", pgversion)) - //Setup tables and indexes if not exists. - createTableSQL := fmt.Sprintf( - " CREATE TABLE IF NOT EXISTS %v ( "+ - " parent_path TEXT COLLATE \"C\" NOT NULL, "+ - " path TEXT COLLATE \"C\", "+ - " key TEXT COLLATE \"C\", "+ - " value BYTEA, "+ - " CONSTRAINT pkey PRIMARY KEY (path, key) "+ - " ); ", table) - - _, err = pg.client.Exec(createTableSQL) - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", table) - - _, err = pg.client.Exec(createIndexSQL) - if err != nil { - t.Fatalf("Failed to create index: %v", err) - } + setupDatabaseObjects(t, logger, pg) defer func() { - pg := b.(*PostgreSQLBackend) + pg := b1.(*PostgreSQLBackend) _, err := pg.client.Exec(fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table)) if err != nil { t.Fatalf("Failed to truncate table: %v", err) } }() - physical.ExerciseBackend(t, b) - physical.ExerciseBackend_ListPrefix(t, b) + logger.Info("Running basic backend tests") + physical.ExerciseBackend(t, b1) + logger.Info("Running list prefix backend tests") + physical.ExerciseBackend_ListPrefix(t, b1) + + ha1, ok := b1.(physical.HABackend) + if !ok { + t.Fatalf("PostgreSQLDB does not implement HABackend") + } + + ha2, ok := b2.(physical.HABackend) + if !ok { + t.Fatalf("PostgreSQLDB does not implement HABackend") + } + + if ha1.HAEnabled() && ha2.HAEnabled() { + logger.Info("Running ha backend tests") + physical.ExerciseHABackend(t, ha1, ha2) + testPostgresSQLLockTTL(t, ha1) + testPostgresSQLLockRenewal(t, ha1) + } +} + +// Similar to testHABackend, but using internal implementation details to +// trigger the lock failure scenario by setting the lock renew period for one +// of the locks to a higher value than the lock TTL. 
+func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
+	// Set much smaller lock times to speed up the test.
+	lockTTL := 3
+	renewInterval := time.Second * 1
+	watchInterval := time.Second * 1
+
+	// Get the lock
+	origLock, err := ha.LockWith("dynamodbttl", "bar")
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	// set the first lock renew period to double the expected TTL, so it expires before it is renewed.
+	lock := origLock.(*PostgreSQLLock)
+	lock.renewInterval = time.Duration(lockTTL*2) * time.Second
+	lock.ttlSeconds = lockTTL
+	// lock.retryInterval = watchInterval
+
+	// Attempt to lock
+	leaderCh, err := lock.Lock(nil)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if leaderCh == nil {
+		t.Fatalf("failed to get leader ch")
+	}
+
+	// Check the value
+	held, val, err := lock.Value()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if !held {
+		t.Fatalf("should be held")
+	}
+	if val != "bar" {
+		t.Fatalf("bad value: %v", val)
+	}
+
+	// Second acquisition should succeed because the first lock should
+	// not renew within the 3 sec TTL.
+	origLock2, err := ha.LockWith("dynamodbttl", "baz")
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	lock2 := origLock2.(*PostgreSQLLock)
+	lock2.renewInterval = renewInterval
+	lock2.ttlSeconds = lockTTL
+	// lock2.retryInterval = watchInterval
+
+	// Cancel attempt in 6 sec so as not to block unit tests forever
+	stopCh := make(chan struct{})
+	time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
+		close(stopCh)
+	})
+
+	// Attempt to lock should work
+	leaderCh2, err := lock2.Lock(stopCh)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if leaderCh2 == nil {
+		t.Fatalf("should get leader ch")
+	}
+
+	// Check the value
+	held, val, err = lock2.Value()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if !held {
+		t.Fatalf("should be held")
+	}
+	if val != "baz" {
+		t.Fatalf("bad value: %v", val)
+	}
+
+	// The first lock should have lost the leader channel
+	leaderChClosed := false
+	blocking := make(chan struct{})
+	// Wait for the leader channel to close or for the timeout, whichever happens
+	// first; both paths close blocking so the test can never hang on it.
+	go func() {
+		select {
+		case <-time.After(watchInterval * 3):
+			close(blocking)
+		case <-leaderCh:
+			leaderChClosed = true
+			close(blocking)
+		case <-blocking:
+			return
+		}
+	}()
+
+	<-blocking
+	if !leaderChClosed {
+		t.Fatalf("original lock did not have its leader channel closed.")
+	}
+
+	// Cleanup
+	lock2.Unlock()
+}
+
+// Verify that once Unlock is called, we don't keep trying to renew the original
+// lock.
+func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
+	// Get the lock
+	origLock, err := ha.LockWith("pgrenewal", "bar")
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// renew every second so the 1.5s sleep below would span a renewal if one still happened
+	lock := origLock.(*PostgreSQLLock)
+	lock.renewInterval = time.Second * 1
+
+	// Attempt to lock
+	leaderCh, err := lock.Lock(nil)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if leaderCh == nil {
+		t.Fatalf("failed to get leader ch")
+	}
+
+	// Check the value
+	held, val, err := lock.Value()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if !held {
+		t.Fatalf("should be held")
+	}
+	if val != "bar" {
+		t.Fatalf("bad value: %v", val)
+	}
+
+	// Release the lock, which will delete the stored item
+	if err := lock.Unlock(); err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// Wait longer than the renewal time
+	time.Sleep(1500 * time.Millisecond)
+
+	// Attempt to lock with new lock
+	newLock, err := ha.LockWith("pgrenewal", "baz")
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	// Cancel attempt after lock ttl + retry interval + 1s; the timer is stopped on return so t.Logf never runs after the test
+	stopCh := make(chan struct{})
+	timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second
+	defer time.AfterFunc(timeout, func() {
+		t.Logf("giving up on lock attempt after %v", timeout)
+		close(stopCh)
+	}).Stop()
+
+	// Attempt to lock should work
+	leaderCh2, err := newLock.Lock(stopCh)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if leaderCh2 == nil {
+		t.Fatalf("should get leader ch")
+	}
+
+	// Check the value
+	held, val, err = newLock.Value()
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	if !held {
+		t.Fatalf("should be held")
+	}
+	if val != "baz" {
+		t.Fatalf("bad value: %v", val)
+	}
+
+	// Cleanup
+	newLock.Unlock()
 }
 
 func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retConnString string) {
@@ -92,7 +297,7 @@ func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retC
 	if err != nil {
 		t.Fatalf("Failed to connect to docker: %s", err)
 	}
-	//using 11.1 which is currently latest, use hard version for stabillity of tests
+	// using 11.1 which is currently latest, use hard version for stability of tests
 	resource, err := pool.Run("postgres", "11.1", []string{})
 	if err != nil {
 		t.Fatalf("Could not start docker Postgres: %s", err)
@@ -122,3 +327,42 @@ func prepareTestContainer(t *testing.T, logger log.Logger) (cleanup func(), retC
 
 	return cleanup, retConnString
 }
+
+func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
+	var err error
+	// Create the key/value table and its index if they do not exist yet.
+	createTableSQL := fmt.Sprintf(
+		" CREATE TABLE IF NOT EXISTS %v ( "+
+			" parent_path TEXT COLLATE \"C\" NOT NULL, "+
+			" path TEXT COLLATE \"C\", "+
+			" key TEXT COLLATE \"C\", "+
+			" value BYTEA, "+
+			" CONSTRAINT pkey PRIMARY KEY (path, key) "+
+			" ); ", pg.table)
+
+	_, err = pg.client.Exec(createTableSQL)
+	if err != nil {
+		t.Fatalf("Failed to create table: %v", err)
+	}
+
+	createIndexSQL := fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table)
+
+	_, err = pg.client.Exec(createIndexSQL)
+	if err != nil {
+		t.Fatalf("Failed to create index: %v", err)
+	}
+
+	createHaTableSQL :=
+		" CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
+			" ha_key TEXT COLLATE \"C\" NOT NULL, " +
+			" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
+			" ha_value TEXT COLLATE \"C\", " +
+			" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
+			" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
+			" ); "
+
+	_, err = pg.client.Exec(createHaTableSQL)
+	if err != nil {
+		t.Fatalf("Failed to create hatable: %v", err)
+	}
+}
diff --git a/website/source/docs/configuration/storage/postgresql.html.md b/website/source/docs/configuration/storage/postgresql.html.md
index 84d5cdaa23..a81bcd8b7a 100644
--- a/website/source/docs/configuration/storage/postgresql.html.md
+++ b/website/source/docs/configuration/storage/postgresql.html.md
@@ -13,8 +13,8
@@ description: |-
 The PostgreSQL storage backend is used to persist Vault's data in a
 [PostgreSQL][postgresql] server or cluster.
 
-- **No High Availability** – the PostgreSQL storage backend does not support
-  high availability.
+- **High Availability** – the PostgreSQL storage backend supports
+  high availability. Requires PostgreSQL 9.5 or later.
 
 - **Community Supported** – the PostgreSQL storage backend is supported by the
   community. While it has undergone review by HashiCorp employees, they may not
@@ -42,6 +42,19 @@ CREATE TABLE vault_kv_store (
 CREATE INDEX parent_path_idx ON vault_kv_store (parent_path);
 ```
 
+To use high availability (`ha_enabled = "true"`), also create the lock table:
+
+```sql
+CREATE TABLE vault_ha_locks (
+  ha_key TEXT COLLATE "C" NOT NULL,
+  ha_identity TEXT COLLATE "C" NOT NULL,
+  ha_value TEXT COLLATE "C",
+  valid_until TIMESTAMP WITH TIME ZONE NOT NULL,
+  CONSTRAINT ha_key PRIMARY KEY (ha_key)
+);
+```
+
+
 If you're using a version of PostgreSQL prior to 9.5, create the following function:
 
 ```sql
@@ -86,6 +99,8 @@ LANGUAGE plpgsql;
 - `max_parallel` `(string: "128")` – Specifies the maximum number of concurrent
   requests to PostgreSQL.
 
+- `ha_enabled` `(string: "false")` – Specifies whether high availability mode is enabled. Set to `"true"` to enable; requires PostgreSQL 9.5 or later.
+
 ## `postgresql` Examples
 
 ### Custom SSL Verification