vault/physical/postgresql/postgresql_test.go
hashicorp-copywrite[bot] 0b12cdcfd1
[COMPLIANCE] License changes (#22290)
* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package.

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License.

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUS-1.1

* Fix test that expected exact offset on hcl file

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
Co-authored-by: Sarah Thompson <sthompson@hashicorp.com>
Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
2023-08-10 18:14:03 -07:00

427 lines
11 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package postgresql
import (
"fmt"
"os"
"testing"
"time"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/physical"
_ "github.com/jackc/pgx/v4/stdlib"
)
// TestPostgreSQLBackend runs the shared physical backend exercise suites
// against a PostgreSQL server. The server comes from the PGURL environment
// variable when it is non-empty; otherwise a Docker test container is started.
// PGTABLE and PGHAENABLED override the table name and the ha_enabled setting.
// When both backends report HA as enabled, the HA exercise suite and the
// PostgreSQL-specific lock tests run as well.
func TestPostgreSQLBackend(t *testing.T) {
	logger := logging.NewVaultLogger(log.Debug)

	// envOr reads an environment variable, substituting fallback when the
	// variable is empty or missing.
	envOr := func(name, fallback string) string {
		if v := os.Getenv(name); v != "" {
			return v
		}
		return fallback
	}

	// Fall back to a Docker-hosted Postgres when no URL is supplied through
	// the environment.
	connURL := os.Getenv("PGURL")
	if connURL == "" {
		stop, containerURL := postgresql.PrepareTestContainer(t, "11.1")
		defer stop()
		connURL = containerURL
	}
	tableName := envOr("PGTABLE", "vault_kv_store")
	haSetting := envOr("PGHAENABLED", "true")

	logger.Info(fmt.Sprintf("Connection URL: %v", connURL))

	// Two backends are built from identical settings; every call hands the
	// constructor its own configuration map.
	openBackend := func() physical.Backend {
		backend, err := NewPostgreSQLBackend(map[string]string{
			"connection_url": connURL,
			"table":          tableName,
			"ha_enabled":     haSetting,
		}, logger)
		if err != nil {
			t.Fatalf("Failed to create new backend: %v", err)
		}
		return backend
	}
	b1 := openBackend()
	b2 := openBackend()

	pg := b1.(*PostgreSQLBackend)

	// A trivial query proves that the connection is actually usable.
	var serverVersion string
	versionRow := pg.client.QueryRow("SELECT current_setting('server_version_num')")
	if err := versionRow.Scan(&serverVersion); err != nil {
		t.Fatalf("Failed to check for Postgres version: %v", err)
	}
	logger.Info(fmt.Sprintf("Postgres Version: %v", serverVersion))

	setupDatabaseObjects(t, logger, pg)

	// Empty the key/value table once the test is over.
	defer func() {
		truncateSQL := fmt.Sprintf(" TRUNCATE TABLE %v ", pg.table)
		if _, err := pg.client.Exec(truncateSQL); err != nil {
			t.Fatalf("Failed to truncate table: %v", err)
		}
	}()

	logger.Info("Running basic backend tests")
	physical.ExerciseBackend(t, b1)

	logger.Info("Running list prefix backend tests")
	physical.ExerciseBackend_ListPrefix(t, b1)

	ha1, ok := b1.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}
	ha2, ok := b2.(physical.HABackend)
	if !ok {
		t.Fatalf("PostgreSQLDB does not implement HABackend")
	}

	// The HA suites only make sense when both backends have HA switched on.
	if !ha1.HAEnabled() || !ha2.HAEnabled() {
		return
	}
	logger.Info("Running ha backend tests")
	physical.ExerciseHABackend(t, ha1, ha2)
	testPostgresSQLLockTTL(t, ha1)
	testPostgresSQLLockRenewal(t, ha1)
}
// TestPostgreSQLBackendMaxIdleConnectionsParameter checks that a non-numeric
// max_idle_connections value makes NewPostgreSQLBackend fail with the expected
// parse error message.
func TestPostgreSQLBackendMaxIdleConnectionsParameter(t *testing.T) {
	_, err := NewPostgreSQLBackend(map[string]string{
		"connection_url":       "some connection url",
		"max_idle_connections": "bad param",
	}, logging.NewVaultLogger(log.Debug))
	if err == nil {
		// Stop here: the message comparison below calls err.Error(), which
		// would panic on a nil error instead of reporting a test failure.
		t.Fatal("Expected invalid max_idle_connections param to return error")
	}

	const expectedErrStr = "failed parsing max_idle_connections parameter: strconv.Atoi: parsing \"bad param\": invalid syntax"
	if got := err.Error(); got != expectedErrStr {
		t.Errorf("Expected: %q but found %q", expectedErrStr, got)
	}
}
// TestConnectionURL checks which value connectionURL returns for different
// combinations of the VAULT_PG_CONNECTION_URL environment variable and the
// "connection_url" configuration key.
//
// Each subtest puts the environment variable into exactly the state the case
// describes (set to a value, or unset) and restores whatever value the process
// had before, so the outcome does not depend on the ambient environment and
// the variable is not clobbered for later tests.
func TestConnectionURL(t *testing.T) {
	const envKey = "VAULT_PG_CONNECTION_URL"

	type input struct {
		envar string
		conf  map[string]string
	}

	cases := map[string]struct {
		want  string
		input input
	}{
		"environment_variable_not_set_use_config_value": {
			want: "abc",
			input: input{
				envar: "",
				conf:  map[string]string{"connection_url": "abc"},
			},
		},
		"no_value_connection_url_set_key_exists": {
			want: "",
			input: input{
				envar: "",
				conf:  map[string]string{"connection_url": ""},
			},
		},
		"no_value_connection_url_set_key_doesnt_exist": {
			want: "",
			input: input{
				envar: "",
				conf:  map[string]string{},
			},
		},
		"environment_variable_set": {
			want: "abc",
			input: input{
				envar: "abc",
				conf:  map[string]string{"connection_url": "def"},
			},
		},
	}

	for name, tt := range cases {
		t.Run(name, func(t *testing.T) {
			// Remember the pre-existing state of the variable and put it back
			// when the subtest finishes.
			prev, wasSet := os.LookupEnv(envKey)
			defer func() {
				var err error
				if wasSet {
					err = os.Setenv(envKey, prev)
				} else {
					err = os.Unsetenv(envKey)
				}
				if err != nil {
					t.Errorf("restoring %s: %v", envKey, err)
				}
			}()

			// An empty envar means "variable not set". It must really be
			// unset rather than set to "": as long as the env is set --- even
			// if the value is "" --- a lookup reports it as present, and the
			// test would always exercise the branch where the env is set.
			if tt.input.envar != "" {
				if err := os.Setenv(envKey, tt.input.envar); err != nil {
					t.Fatalf("setting %s: %v", envKey, err)
				}
			} else if err := os.Unsetenv(envKey); err != nil {
				t.Fatalf("unsetting %s: %v", envKey, err)
			}

			got := connectionURL(tt.input.conf)
			if got != tt.want {
				t.Errorf("connectionURL(%s): want %q, got %q", tt.input, tt.want, got)
			}
		})
	}
}
// maxTries is the number of attempts testPostgresSQLLockTTL gives
// attemptLockTTLTest; attemptLockTTLTest also consults it to decide whether a
// retry is still allowed.
const maxTries = 3

// testPostgresSQLLockTTL is similar to testHABackend, but uses internal
// implementation details to trigger the lock failure scenario by setting the
// lock renew period for one of the locks to a higher value than the lock TTL.
//
// The scenario is currently switched off: the function only logs that it is
// being skipped. t.Skip is deliberately not used, because this helper shares
// its *testing.T with the calling test and skipping would abort that whole
// test rather than just this portion.
func testPostgresSQLLockTTL(t *testing.T, ha physical.HABackend) {
	// Flip to true to run the TTL scenario again. A constant guard keeps the
	// retry loop compiled and vetted without leaving unreachable statements
	// behind an unconditional return.
	const lockTTLScenarioEnabled = false
	if !lockTTLScenarioEnabled {
		t.Log("Skipping testPostgresSQLLockTTL portion of test.")
		return
	}

	for tries := 1; tries <= maxTries; tries++ {
		// Try this several times. If the test environment is too slow the
		// lock can naturally lapse.
		if attemptLockTTLTest(t, ha, tries) {
			break
		}
	}
}
// attemptLockTTLTest makes one pass at the lock-TTL scenario on the key
// "postgresttl": a first lock is taken with a renew interval longer than its
// TTL, a second lock is then acquired on the same key, and finally the first
// lock's leader channel is expected to be closed.
//
// tries is the 1-based attempt number supplied by the caller. The function
// returns true when the pass ran to completion. It returns false when a lock
// was found not held, more than lockTTL seconds had elapsed since the Lock
// call began, and tries is still below maxTries; the caller uses that as the
// signal to run another attempt. Any other failure aborts the test via
// t.Fatalf.
func attemptLockTTLTest(t *testing.T, ha physical.HABackend, tries int) bool {
// Set much smaller lock times to speed up the test.
// lockTTL is a number of seconds: it is stored in ttlSeconds and multiplied
// by time.Second wherever a duration is needed.
lockTTL := 3
renewInterval := time.Second * 1
retryInterval := time.Second * 1
// Twice the TTL, so the first lock is not renewed before it expires.
longRenewInterval := time.Duration(lockTTL*2) * time.Second
lockkey := "postgresttl"
// Leader channel of the first lock; checked at the very end of the function.
var leaderCh <-chan struct{}
// Get the lock
origLock, err := ha.LockWith(lockkey, "bar")
if err != nil {
t.Fatalf("err: %v", err)
}
{
// set the first lock renew period to double the expected TTL.
lock := origLock.(*PostgreSQLLock)
lock.renewInterval = longRenewInterval
lock.ttlSeconds = lockTTL
// Attempt to lock
// lockTime is captured before the Lock call so the slowness check below
// covers the acquisition itself.
lockTime := time.Now()
leaderCh, err = lock.Lock(nil)
if err != nil {
t.Fatalf("err: %v", err)
}
if leaderCh == nil {
t.Fatalf("failed to get leader ch")
}
// NOTE(review): on the first attempt this sleeps for 3 seconds, the same
// length as lockTTL, before checking the lock. The reason is not visible
// here; presumably it is meant to exercise the retry path — confirm.
if tries == 1 {
time.Sleep(3 * time.Second)
}
// Check the value
held, val, err := lock.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
// Only treat "not held" as a retryable environment problem when more
// than the TTL has elapsed and attempts remain.
if tries < maxTries && time.Since(lockTime) > (time.Second*time.Duration(lockTTL)) {
// Our test environment is slow enough that we failed this, retry
return false
}
t.Fatalf("should be held")
}
if val != "bar" {
t.Fatalf("bad value: %v", val)
}
}
// Second acquisition should succeed because the first lock should
// not renew within the 3 sec TTL.
origLock2, err := ha.LockWith(lockkey, "baz")
if err != nil {
t.Fatalf("err: %v", err)
}
{
lock2 := origLock2.(*PostgreSQLLock)
lock2.renewInterval = renewInterval
lock2.ttlSeconds = lockTTL
lock2.retryInterval = retryInterval
// Cancel attempt in 6 sec so as not to block unit tests forever
stopCh := make(chan struct{})
time.AfterFunc(time.Duration(lockTTL*2)*time.Second, func() {
close(stopCh)
})
// Attempt to lock should work
lockTime := time.Now()
leaderCh2, err := lock2.Lock(stopCh)
if err != nil {
t.Fatalf("err: %v", err)
}
if leaderCh2 == nil {
t.Fatalf("should get leader ch")
}
// Deferred to function scope: the second lock is released when
// attemptLockTTLTest returns, including on the early "return false" below.
defer lock2.Unlock()
// Check the value
held, val, err := lock2.Value()
if err != nil {
t.Fatalf("err: %v", err)
}
if !held {
// Same slowness escape hatch as for the first lock.
if tries < maxTries && time.Since(lockTime) > (time.Second*time.Duration(lockTTL)) {
// Our test environment is slow enough that we failed this, retry
return false
}
t.Fatalf("should be held")
}
if val != "baz" {
t.Fatalf("bad value: %v", val)
}
}
// The first lock should have lost the leader channel
// Wait at most twice the long renew interval for that channel to close.
// NOTE(review): this function never calls Unlock on the first lock.
select {
case <-time.After(longRenewInterval * 2):
t.Fatalf("original lock did not have its leader channel closed.")
case <-leaderCh:
}
return true
}
// testPostgresSQLLockRenewal verifies that once Unlock is called, we don't
// keep trying to renew the original lock: after the first lock on "pgrenewal"
// is released, a second lock on the same key must be acquirable and must
// report the new value.
//
// The second acquisition runs in a goroutine and is abandoned, by closing its
// stop channel, if it has not finished within the first lock's TTL plus its
// retry interval plus one second.
func testPostgresSQLLockRenewal(t *testing.T, ha physical.HABackend) {
	const lockKey = "pgrenewal"

	// Get the lock
	origLock, err := ha.LockWith(lockKey, "bar")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// The concrete type is needed to read ttlSeconds and retryInterval for the
	// timeout below. The intervals themselves are left as LockWith set them.
	lock := origLock.(*PostgreSQLLock)

	// Attempt to lock
	leaderCh, err := lock.Lock(nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if leaderCh == nil {
		t.Fatalf("failed to get leader ch")
	}

	// Check the value
	held, val, err := lock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "bar" {
		t.Fatalf("bad value: %v", val)
	}

	// Release the lock, which will delete the stored item
	if err := lock.Unlock(); err != nil {
		t.Fatalf("err: %v", err)
	}

	// Wait longer than the renewal time.
	// NOTE(review): 1.5s appears to have been chosen against a 1s
	// renewInterval that this test no longer sets — confirm it still exceeds
	// the effective renew interval.
	time.Sleep(1500 * time.Millisecond)

	// Attempt to lock with new lock
	newLock, err := ha.LockWith(lockKey, "baz")
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	stopCh := make(chan struct{})
	timeout := time.Duration(lock.ttlSeconds)*time.Second + lock.retryInterval + time.Second

	// The outcome of Lock travels over a channel instead of through variables
	// shared with the goroutine, so the test goroutine never reads a value
	// the worker may still be writing.
	type lockResult struct {
		leaderCh <-chan struct{}
		err      error
	}
	resultCh := make(chan lockResult, 1)
	go func() {
		ch, lockErr := newLock.Lock(stopCh)
		resultCh <- lockResult{leaderCh: ch, err: lockErr}
	}()

	// Cancel attempt after lock ttl + 1s so as not to block unit tests forever
	var res lockResult
	select {
	case <-time.After(timeout):
		t.Logf("giving up on lock attempt after %v", timeout)
		close(stopCh)
		// Wait for Lock to react to the stop signal before inspecting its
		// result.
		res = <-resultCh
	case res = <-resultCh:
		// pass through
	}

	// Attempt to lock should work
	if res.err != nil {
		t.Fatalf("err: %v", res.err)
	}
	if res.leaderCh == nil {
		t.Fatalf("should get leader ch")
	}

	// Check the value
	held, val, err = newLock.Value()
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if !held {
		t.Fatalf("should be held")
	}
	if val != "baz" {
		t.Fatalf("bad value: %v", val)
	}

	// Cleanup; report, rather than silently drop, a failure to release.
	if err := newLock.Unlock(); err != nil {
		t.Errorf("failed to release new lock: %v", err)
	}
}
// setupDatabaseObjects creates, if they do not already exist, the database
// objects the tests rely on: the key/value table named by pg.table, the
// parent_path_idx index on that table, and the vault_ha_locks table. The
// statements run in that order and the first failure aborts the test.
//
// logger is accepted but not used by this function.
func setupDatabaseObjects(t *testing.T, logger log.Logger, pg *PostgreSQLBackend) {
	// Each step pairs a DDL statement with the failure message to report if
	// executing it returns an error.
	steps := []struct {
		failFormat string
		statement  string
	}{
		{
			failFormat: "Failed to create table: %v",
			statement: fmt.Sprintf(
				" CREATE TABLE IF NOT EXISTS %v ( "+
					" parent_path TEXT COLLATE \"C\" NOT NULL, "+
					" path TEXT COLLATE \"C\", "+
					" key TEXT COLLATE \"C\", "+
					" value BYTEA, "+
					" CONSTRAINT pkey PRIMARY KEY (path, key) "+
					" ); ", pg.table),
		},
		{
			failFormat: "Failed to create index: %v",
			statement:  fmt.Sprintf(" CREATE INDEX IF NOT EXISTS parent_path_idx ON %v (parent_path); ", pg.table),
		},
		{
			failFormat: "Failed to create hatable: %v",
			statement: " CREATE TABLE IF NOT EXISTS vault_ha_locks ( " +
				" ha_key TEXT COLLATE \"C\" NOT NULL, " +
				" ha_identity TEXT COLLATE \"C\" NOT NULL, " +
				" ha_value TEXT COLLATE \"C\", " +
				" valid_until TIMESTAMP WITH TIME ZONE NOT NULL, " +
				" CONSTRAINT ha_key PRIMARY KEY (ha_key) " +
				" ); ",
		},
	}

	for _, step := range steps {
		if _, err := pg.client.Exec(step.statement); err != nil {
			t.Fatalf(step.failFormat, err)
		}
	}
}