mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-06 06:37:02 +02:00
database/postgres: add inline certificate authentication fields (#28024)
* add inline cert auth to postres db plugin * handle both sslinline and new TLS plugin fields * refactor PrepareTestContainerWithSSL * add tests for postgres inline TLS fields * changelog * revert back to errwrap since the middleware sanitizing depends on it * enable only setting sslrootcert
This commit is contained in:
parent
a19195c901
commit
3fcb1a67c5
@ -345,6 +345,8 @@ func TestBackend_config_connection(t *testing.T) {
|
||||
assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"])
|
||||
}
|
||||
|
||||
// TestBackend_BadConnectionString tests that an error response resulting from
|
||||
// a failed connection does not expose the URL. The middleware should sanitize it.
|
||||
func TestBackend_BadConnectionString(t *testing.T) {
|
||||
cluster, sys := getClusterPostgresDB(t)
|
||||
defer cluster.Cleanup()
|
||||
|
3
changelog/28024.txt
Normal file
3
changelog/28024.txt
Normal file
@ -0,0 +1,3 @@
|
||||
```release-note:improvement
|
||||
database/postgres: Add new fields to the plugin's config endpoint for client certificate authentication.
|
||||
```
|
@ -9,11 +9,13 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/docker"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -68,7 +70,13 @@ func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func(
|
||||
|
||||
// PrepareTestContainerWithSSL will setup a test container with SSL enabled so
|
||||
// that we can test client certificate authentication.
|
||||
func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) {
|
||||
func PrepareTestContainerWithSSL(
|
||||
t *testing.T,
|
||||
sslMode string,
|
||||
caCert certhelpers.Certificate,
|
||||
clientCert certhelpers.Certificate,
|
||||
useFallback bool,
|
||||
) (func(), string) {
|
||||
runOpts := defaultRunOpts(t)
|
||||
runner, err := docker.NewServiceRunner(runOpts)
|
||||
if err != nil {
|
||||
@ -82,21 +90,11 @@ func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode stri
|
||||
}
|
||||
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t,
|
||||
certhelpers.CommonName("ca"),
|
||||
certhelpers.IsCA(true),
|
||||
certhelpers.SelfSign(),
|
||||
)
|
||||
serverCert := certhelpers.NewCert(t,
|
||||
certhelpers.CommonName("server"),
|
||||
certhelpers.DNS("localhost"),
|
||||
certhelpers.Parent(caCert),
|
||||
)
|
||||
clientCert := certhelpers.NewCert(t,
|
||||
certhelpers.CommonName("postgres"),
|
||||
certhelpers.DNS("localhost"),
|
||||
certhelpers.Parent(caCert),
|
||||
)
|
||||
|
||||
bCtx := docker.NewBuildContext()
|
||||
bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM())
|
||||
@ -133,6 +131,9 @@ EOF
|
||||
t.Fatalf("failed to copy to container: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// overwrite the postgresql.conf config file with our ssl settings
|
||||
mustRunCommand(t, ctx, runner, id,
|
||||
[]string{"bash", "/var/lib/postgresql/pg-conf.sh"})
|
||||
@ -150,7 +151,7 @@ EOF
|
||||
return svc.Cleanup, svc.Config.URL().String()
|
||||
}
|
||||
|
||||
sslConfig, err := connectPostgresSSL(
|
||||
sslConfig := getPostgresSSLConfig(
|
||||
t,
|
||||
svc.Config.URL().Host,
|
||||
sslMode,
|
||||
@ -197,42 +198,40 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri
|
||||
return runner, svc.Cleanup, svc.Config.URL().String(), containerID
|
||||
}
|
||||
|
||||
// connectPostgresSSL is used to verify the connection of our test container
|
||||
// and construct the connection string that is used in tests.
|
||||
//
|
||||
// NOTE: The RawQuery component of the url sets the custom sslinline field and
|
||||
// inlines the certificate material in the sslrootcert, sslcert, and sslkey
|
||||
// fields. This feature will be removed in a future version of the SDK.
|
||||
func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) {
|
||||
func getPostgresSSLConfig(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) docker.ServiceConfig {
|
||||
if useFallback {
|
||||
// set the first host to a bad address so we can test the fallback logic
|
||||
host = "localhost:55," + host
|
||||
}
|
||||
u := url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.User("postgres"),
|
||||
Host: host,
|
||||
Path: "postgres",
|
||||
RawQuery: url.Values{
|
||||
"sslmode": {sslMode},
|
||||
"sslinline": {"true"},
|
||||
"sslrootcert": {caCert},
|
||||
"sslcert": {clientCert},
|
||||
"sslkey": {clientKey},
|
||||
}.Encode(),
|
||||
|
||||
u := url.URL{}
|
||||
|
||||
if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok {
|
||||
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||
u = url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.User("postgres"),
|
||||
Host: host,
|
||||
Path: "postgres",
|
||||
RawQuery: url.Values{
|
||||
"sslmode": {sslMode},
|
||||
"sslinline": {"true"},
|
||||
"sslrootcert": {caCert},
|
||||
"sslcert": {clientCert},
|
||||
"sslkey": {clientKey},
|
||||
}.Encode(),
|
||||
}
|
||||
} else {
|
||||
u = url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.User("postgres"),
|
||||
Host: host,
|
||||
Path: "postgres",
|
||||
RawQuery: url.Values{"sslmode": {sslMode}}.Encode(),
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove this deprecated function call in a future SDK version
|
||||
db, err := connutil.OpenPostgres("pgx", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err = db.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return docker.NewServiceURL(u), nil
|
||||
return docker.NewServiceURL(u)
|
||||
}
|
||||
|
||||
func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter {
|
||||
|
@ -123,11 +123,8 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte
|
||||
}
|
||||
|
||||
// validate auth_type if provided
|
||||
authType := c.AuthType
|
||||
if authType != "" {
|
||||
if ok := connutil.ValidateAuthType(authType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
|
||||
}
|
||||
if ok := connutil.ValidateAuthType(c.AuthType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType)
|
||||
}
|
||||
|
||||
if c.AuthType == connutil.AuthTypeGCPIAM {
|
||||
|
@ -5,7 +5,11 @@ package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -79,11 +83,65 @@ func new() *PostgreSQL {
|
||||
type PostgreSQL struct {
|
||||
*connutil.SQLConnectionProducer
|
||||
|
||||
TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"`
|
||||
TLSPrivateKey []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"`
|
||||
TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"`
|
||||
|
||||
usernameProducer template.StringTemplate
|
||||
passwordAuthentication passwordAuthentication
|
||||
}
|
||||
|
||||
func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
||||
sslcert, err := strutil.GetString(req.Config, "tls_certificate")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err)
|
||||
}
|
||||
|
||||
sslkey, err := strutil.GetString(req.Config, "tls_private_key")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err)
|
||||
}
|
||||
|
||||
sslrootcert, err := strutil.GetString(req.Config, "tls_ca")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_ca: %w", err)
|
||||
}
|
||||
|
||||
useTLS := false
|
||||
tlsConfig := &tls.Config{}
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
|
||||
return dbplugin.InitializeResponse{}, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
p.TLSConfig = tlsConfig
|
||||
useTLS = true
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return dbplugin.InitializeResponse{}, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
block, _ := pem.Decode([]byte(sslkey))
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(sslcert), pem.EncodeToMemory(block))
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("unable to load cert: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
p.TLSConfig = tlsConfig
|
||||
useTLS = true
|
||||
}
|
||||
|
||||
if !useTLS {
|
||||
// set to nil to flag that this connection does not use a custom TLS config
|
||||
p.TLSConfig = nil
|
||||
}
|
||||
|
||||
newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection)
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, err
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/vault/helper/testhelpers/certhelpers"
|
||||
"github.com/hashicorp/vault/helper/testhelpers/postgresql"
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
||||
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
|
||||
@ -86,15 +87,18 @@ func TestPostgreSQL_InitializeMultiHost(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
|
||||
// TestPostgreSQL_InitializeSSLInlineFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE
|
||||
// flag guards against unwanted usage of the deprecated SSL client authentication path.
|
||||
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||
func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
|
||||
func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) {
|
||||
// set the flag to true so we can call PrepareTestContainerWithSSL
|
||||
// which does a validation check on the connection
|
||||
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
|
||||
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false)
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
|
||||
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
type testCase struct {
|
||||
@ -166,11 +170,11 @@ func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
|
||||
// TestPostgreSQL_InitializeSSLInline tests that we can successfully authenticate
|
||||
// with a postgres server via ssl with a URL connection string or DSN (key/value)
|
||||
// for each ssl mode.
|
||||
// TODO: remove this when we remove the underlying feature in a future SDK version
|
||||
func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
func TestPostgreSQL_InitializeSSLInline(t *testing.T) {
|
||||
// required to enable the sslinline custom parsing
|
||||
t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true")
|
||||
|
||||
@ -287,7 +291,11 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback)
|
||||
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
|
||||
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
if test.useDSN {
|
||||
@ -326,6 +334,188 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate
|
||||
// with a postgres server via ssl with a URL connection string or DSN (key/value)
|
||||
// for each ssl mode.
|
||||
func TestPostgreSQL_InitializeSSL(t *testing.T) {
|
||||
type testCase struct {
|
||||
sslMode string
|
||||
useDSN bool
|
||||
useFallback bool
|
||||
wantErr bool
|
||||
expectedError string
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"disable sslmode": {
|
||||
sslMode: "disable",
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode": {
|
||||
sslMode: "allow",
|
||||
wantErr: false,
|
||||
},
|
||||
"prefer sslmode": {
|
||||
sslMode: "prefer",
|
||||
wantErr: false,
|
||||
},
|
||||
"require sslmode": {
|
||||
sslMode: "require",
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-ca sslmode": {
|
||||
sslMode: "verify-ca",
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-full sslmode": {
|
||||
sslMode: "verify-full",
|
||||
wantErr: false,
|
||||
},
|
||||
"disable sslmode with DSN": {
|
||||
sslMode: "disable",
|
||||
useDSN: true,
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode with DSN": {
|
||||
sslMode: "allow",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"prefer sslmode with DSN": {
|
||||
sslMode: "prefer",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"require sslmode with DSN": {
|
||||
sslMode: "require",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-ca sslmode with DSN": {
|
||||
sslMode: "verify-ca",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-full sslmode with DSN": {
|
||||
sslMode: "verify-full",
|
||||
useDSN: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"disable sslmode with fallback": {
|
||||
sslMode: "disable",
|
||||
useFallback: true,
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode with fallback": {
|
||||
sslMode: "allow",
|
||||
useFallback: true,
|
||||
},
|
||||
"prefer sslmode with fallback": {
|
||||
sslMode: "prefer",
|
||||
useFallback: true,
|
||||
},
|
||||
"require sslmode with fallback": {
|
||||
sslMode: "require",
|
||||
useFallback: true,
|
||||
},
|
||||
"verify-ca sslmode with fallback": {
|
||||
sslMode: "verify-ca",
|
||||
useFallback: true,
|
||||
},
|
||||
"verify-full sslmode with fallback": {
|
||||
sslMode: "verify-full",
|
||||
useFallback: true,
|
||||
},
|
||||
"disable sslmode with DSN with fallback": {
|
||||
sslMode: "disable",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: true,
|
||||
expectedError: "error verifying connection",
|
||||
},
|
||||
"allow sslmode with DSN with fallback": {
|
||||
sslMode: "allow",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"prefer sslmode with DSN with fallback": {
|
||||
sslMode: "prefer",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"require sslmode with DSN with fallback": {
|
||||
sslMode: "require",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-ca sslmode with DSN with fallback": {
|
||||
sslMode: "verify-ca",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"verify-full sslmode with DSN with fallback": {
|
||||
sslMode: "verify-full",
|
||||
useDSN: true,
|
||||
useFallback: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create certificates for postgres authentication
|
||||
caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign())
|
||||
clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert))
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
if test.useDSN {
|
||||
var err error
|
||||
connURL, err = dbutil.ParseURL(connURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": connURL,
|
||||
"max_open_connections": 5,
|
||||
"tls_certificate": string(clientCert.CombinedPEM()),
|
||||
"tls_private_key": string(clientCert.PrivateKeyPEM()),
|
||||
"tls_ca": string(caCert.CombinedPEM()),
|
||||
}
|
||||
|
||||
req := dbplugin.InitializeRequest{
|
||||
Config: connectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
_, err := dbtesting.VerifyInitialize(t, db, req)
|
||||
if test.wantErr && err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
} else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) {
|
||||
t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError)
|
||||
}
|
||||
|
||||
if !test.wantErr && !db.Initialized {
|
||||
t.Fatal("Database should be initialized")
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
t.Fatalf("err: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_InitializeWithStringVals(t *testing.T) {
|
||||
db, cleanup := getPostgreSQL(t, map[string]interface{}{
|
||||
"max_open_connections": "5",
|
||||
|
@ -10,10 +10,6 @@ import (
|
||||
"cloud.google.com/go/cloudsqlconn/postgres/pgxv4"
|
||||
)
|
||||
|
||||
var configurableAuthTypes = []string{
|
||||
AuthTypeGCPIAM,
|
||||
}
|
||||
|
||||
func (c *SQLConnectionProducer) getCloudSQLDriverType() (string, error) {
|
||||
var driverType string
|
||||
// using switch case for future extensibility
|
||||
@ -62,15 +58,3 @@ func GetCloudSQLAuthOptions(credentials string, usePrivateIP bool) ([]cloudsqlco
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func ValidateAuthType(authType string) bool {
|
||||
var valid bool
|
||||
for _, typ := range configurableAuthTypes {
|
||||
if authType == typ {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return valid
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ import (
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
// OpenPostgres parses the connection string and opens a connection to the database.
|
||||
// openPostgres parses the connection string and opens a connection to the database.
|
||||
//
|
||||
// If sslinline is set, strips the connection string of all ssl settings and
|
||||
// creates a TLS config based on the settings provided, then uses the
|
||||
@ -54,8 +54,8 @@ import (
|
||||
// because the pgx driver does not support the sslinline parameter and instead
|
||||
// expects to source ssl material from the file system.
|
||||
//
|
||||
// Deprecated: OpenPostgres will be removed in a future version of the Vault SDK.
|
||||
func OpenPostgres(driverName, connString string) (*sql.DB, error) {
|
||||
// Deprecated: openPostgres will be removed in a future version of the Vault SDK.
|
||||
func openPostgres(driverName, connString string) (*sql.DB, error) {
|
||||
if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok {
|
||||
return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable")
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ package connutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@ -19,12 +20,18 @@ import (
|
||||
"github.com/hashicorp/vault/sdk/database/dbplugin"
|
||||
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
||||
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthTypeGCPIAM = "gcp_iam"
|
||||
AuthTypeGCPIAM = "gcp_iam"
|
||||
AuthTypeCert = "cert"
|
||||
AuthTypeUsernamePassword = ""
|
||||
)
|
||||
|
||||
const (
|
||||
dbTypePostgres = "pgx"
|
||||
cloudSQLPostgres = "cloudsql-postgres"
|
||||
)
|
||||
@ -37,14 +44,19 @@ type SQLConnectionProducer struct {
|
||||
MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
|
||||
MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
|
||||
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
|
||||
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
|
||||
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
|
||||
DisableEscaping bool `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"`
|
||||
usePrivateIP bool `json:"use_private_ip" mapstructure:"use_private_ip" structs:"use_private_ip"`
|
||||
|
||||
// cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime
|
||||
// Username/Password is the default auth type when AuthType is not set
|
||||
Username string `json:"username" mapstructure:"username" structs:"username"`
|
||||
Password string `json:"password" mapstructure:"password" structs:"password"`
|
||||
|
||||
// AuthType defines the type of client authenticate used for this connection
|
||||
AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"`
|
||||
ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"`
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// cloudDriverName is globally unique, but only needs to be retained for the lifetime
|
||||
// of driver registration, not across plugin restarts.
|
||||
cloudDriverName string
|
||||
cloudDialerCleanup func() error
|
||||
@ -125,15 +137,11 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf
|
||||
return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err)
|
||||
}
|
||||
|
||||
// validate auth_type if provided
|
||||
authType := c.AuthType
|
||||
if authType != "" {
|
||||
if ok := ValidateAuthType(authType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type %s provided", authType)
|
||||
}
|
||||
if ok := ValidateAuthType(c.AuthType); !ok {
|
||||
return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType)
|
||||
}
|
||||
|
||||
if authType == AuthTypeGCPIAM {
|
||||
if c.AuthType == AuthTypeGCPIAM {
|
||||
c.cloudDriverName, err = uuid.GenerateUUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err)
|
||||
@ -161,7 +169,7 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf
|
||||
}
|
||||
|
||||
if err := c.db.PingContext(ctx); err != nil {
|
||||
return nil, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
||||
return nil, errwrap.Wrapf("error verifying connection: ping failed: {{err}}", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -219,16 +227,42 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if driverName == "pgx" && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" {
|
||||
// TODO: remove this deprecated function call in a future SDK version
|
||||
c.db, err = OpenPostgres(driverName, conn)
|
||||
} else {
|
||||
c.db, err = sql.Open(driverName, conn)
|
||||
}
|
||||
if driverName == dbTypePostgres && c.TLSConfig != nil {
|
||||
config, err := pgx.ParseConfig(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config: %w", err)
|
||||
}
|
||||
if config.TLSConfig == nil {
|
||||
// handle sslmode=disable
|
||||
config.TLSConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
config.TLSConfig.RootCAs = c.TLSConfig.RootCAs
|
||||
config.TLSConfig.ClientCAs = c.TLSConfig.ClientCAs
|
||||
config.TLSConfig.Certificates = c.TLSConfig.Certificates
|
||||
|
||||
// Ensure there are no stale fallbacks when manually setting TLSConfig
|
||||
for _, fallback := range config.Fallbacks {
|
||||
fallback.TLSConfig = config.TLSConfig
|
||||
}
|
||||
|
||||
c.db = stdlib.OpenDB(*config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open connection: %w", err)
|
||||
}
|
||||
} else if driverName == dbTypePostgres && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" {
|
||||
var err error
|
||||
// TODO: remove this deprecated function call in a future SDK version
|
||||
c.db, err = openPostgres(driverName, conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open connection: %w", err)
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
c.db, err = sql.Open(driverName, conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set some connection pool settings. We don't need much of this,
|
||||
@ -277,3 +311,13 @@ func (c *SQLConnectionProducer) Close() error {
|
||||
func (c *SQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) {
|
||||
return "", "", dbutil.Unimplemented()
|
||||
}
|
||||
|
||||
var configurableAuthTypes = map[string]bool{
|
||||
AuthTypeUsernamePassword: true,
|
||||
AuthTypeCert: true,
|
||||
AuthTypeGCPIAM: true,
|
||||
}
|
||||
|
||||
func ValidateAuthType(authType string) bool {
|
||||
return configurableAuthTypes[authType]
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ require (
|
||||
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgtype v1.14.0 // indirect
|
||||
github.com/jackc/pgx/v4 v4.18.3 // indirect
|
||||
github.com/jackc/pgx/v4 v4.18.3
|
||||
github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531 // indirect
|
||||
github.com/klauspost/compress v1.16.5 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
|
Loading…
Reference in New Issue
Block a user