diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index a861470002..eb14d78443 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -345,6 +345,8 @@ func TestBackend_config_connection(t *testing.T) { assert.Equal(t, "plugin-test", eventSender.Events[2].Event.Metadata.AsMap()["name"]) } +// TestBackend_BadConnectionString tests that an error response resulting from +// a failed connection does not expose the URL. The middleware should sanitize it. func TestBackend_BadConnectionString(t *testing.T) { cluster, sys := getClusterPostgresDB(t) defer cluster.Cleanup() diff --git a/changelog/28024.txt b/changelog/28024.txt new file mode 100644 index 0000000000..8d1fbaa0e2 --- /dev/null +++ b/changelog/28024.txt @@ -0,0 +1,3 @@ +```release-note:improvement +database/postgres: Add new fields to the plugin's config endpoint for client certificate authentication. +``` diff --git a/helper/testhelpers/postgresql/postgresqlhelper.go b/helper/testhelpers/postgresql/postgresqlhelper.go index cf144b192a..7229d2127b 100644 --- a/helper/testhelpers/postgresql/postgresqlhelper.go +++ b/helper/testhelpers/postgresql/postgresqlhelper.go @@ -9,11 +9,13 @@ import ( "fmt" "net/url" "os" + "strconv" "testing" + "time" "github.com/hashicorp/vault/helper/testhelpers/certhelpers" - "github.com/hashicorp/vault/sdk/database/helper/connutil" "github.com/hashicorp/vault/sdk/helper/docker" + "github.com/hashicorp/vault/sdk/helper/pluginutil" ) const ( @@ -68,7 +70,13 @@ func PrepareTestContainerWithVaultUser(t *testing.T, ctx context.Context) (func( // PrepareTestContainerWithSSL will setup a test container with SSL enabled so // that we can test client certificate authentication. -func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode string, useFallback bool) (func(), string) { +func PrepareTestContainerWithSSL( + t *testing.T, + sslMode string, + caCert certhelpers.Certificate, + clientCert certhelpers.Certificate, + useFallback bool, +) (func(), string) { runOpts := defaultRunOpts(t) runner, err := docker.NewServiceRunner(runOpts) if err != nil { @@ -82,21 +90,11 @@ func PrepareTestContainerWithSSL(t *testing.T, ctx context.Context, sslMode stri } // Create certificates for postgres authentication - caCert := certhelpers.NewCert(t, - certhelpers.CommonName("ca"), - certhelpers.IsCA(true), - certhelpers.SelfSign(), - ) serverCert := certhelpers.NewCert(t, certhelpers.CommonName("server"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert), ) - clientCert := certhelpers.NewCert(t, - certhelpers.CommonName("postgres"), - certhelpers.DNS("localhost"), - certhelpers.Parent(caCert), - ) bCtx := docker.NewBuildContext() bCtx["ca.crt"] = docker.PathContentsFromBytes(caCert.CombinedPEM()) @@ -133,6 +131,9 @@ EOF t.Fatalf("failed to copy to container: %v", err) } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // overwrite the postgresql.conf config file with our ssl settings mustRunCommand(t, ctx, runner, id, []string{"bash", "/var/lib/postgresql/pg-conf.sh"}) @@ -150,7 +151,7 @@ EOF return svc.Cleanup, svc.Config.URL().String() } - sslConfig, err := connectPostgresSSL( + sslConfig := getPostgresSSLConfig( t, svc.Config.URL().Host, sslMode, @@ -197,42 +198,40 @@ func prepareTestContainer(t *testing.T, runOpts docker.RunOptions, password stri return runner, svc.Cleanup, svc.Config.URL().String(), containerID } -// connectPostgresSSL is used to verify the connection of our test container -// and construct the connection string that is used in tests. -// -// NOTE: The RawQuery component of the url sets the custom sslinline field and -// inlines the certificate material in the sslrootcert, sslcert, and sslkey -// fields. This feature will be removed in a future version of the SDK. -func connectPostgresSSL(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) (docker.ServiceConfig, error) { +func getPostgresSSLConfig(t *testing.T, host, sslMode, caCert, clientCert, clientKey string, useFallback bool) docker.ServiceConfig { if useFallback { // set the first host to a bad address so we can test the fallback logic host = "localhost:55," + host } - u := url.URL{ - Scheme: "postgres", - User: url.User("postgres"), - Host: host, - Path: "postgres", - RawQuery: url.Values{ - "sslmode": {sslMode}, - "sslinline": {"true"}, - "sslrootcert": {caCert}, - "sslcert": {clientCert}, - "sslkey": {clientKey}, - }.Encode(), + + u := url.URL{} + + if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); ok { + // TODO: remove this when we remove the underlying feature in a future SDK version + u = url.URL{ + Scheme: "postgres", + User: url.User("postgres"), + Host: host, + Path: "postgres", + RawQuery: url.Values{ + "sslmode": {sslMode}, + "sslinline": {"true"}, + "sslrootcert": {caCert}, + "sslcert": {clientCert}, + "sslkey": {clientKey}, + }.Encode(), + } + } else { + u = url.URL{ + Scheme: "postgres", + User: url.User("postgres"), + Host: host, + Path: "postgres", + RawQuery: url.Values{"sslmode": {sslMode}}.Encode(), + } } - // TODO: remove this deprecated function call in a future SDK version - db, err := connutil.OpenPostgres("pgx", u.String()) - if err != nil { - return nil, err - } - defer db.Close() - - if err = db.Ping(); err != nil { - return nil, err - } - return docker.NewServiceURL(u), nil + return docker.NewServiceURL(u) } func connectPostgres(password, repo string, useFallback bool) docker.ServiceAdapter { diff --git a/plugins/database/mysql/connection_producer.go b/plugins/database/mysql/connection_producer.go index 778626d65e..f35bfaf522 100644 --- a/plugins/database/mysql/connection_producer.go +++ b/plugins/database/mysql/connection_producer.go @@ -123,11 +123,8 @@ func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]inte } // validate auth_type if provided - authType := c.AuthType - if authType != "" { - if ok := connutil.ValidateAuthType(authType); !ok { - return nil, fmt.Errorf("invalid auth_type %s provided", authType) - } + if ok := connutil.ValidateAuthType(c.AuthType); !ok { + return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType) } if c.AuthType == connutil.AuthTypeGCPIAM { diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 0bdd916412..a9279a2867 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -5,7 +5,11 @@ package postgresql import ( "context" + "crypto/tls" + "crypto/x509" "database/sql" + "encoding/pem" + "errors" "fmt" "regexp" "strings" @@ -79,11 +83,65 @@ func new() *PostgreSQL { type PostgreSQL struct { *connutil.SQLConnectionProducer + TLSCertificateData []byte `json:"tls_certificate" structs:"-" mapstructure:"tls_certificate"` + TLSPrivateKey []byte `json:"tls_private_key" structs:"-" mapstructure:"tls_private_key"` + TLSCAData []byte `json:"tls_ca" structs:"-" mapstructure:"tls_ca"` + usernameProducer template.StringTemplate passwordAuthentication passwordAuthentication } func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) { + sslcert, err := strutil.GetString(req.Config, "tls_certificate") + if err != nil { + return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_certificate: %w", err) + } + + sslkey, err := strutil.GetString(req.Config, "tls_private_key") + if err != nil { + return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_private_key: %w", err) + } + + sslrootcert, err := strutil.GetString(req.Config, "tls_ca") + if err != nil { + return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve tls_ca: %w", err) + } + + useTLS := false + tlsConfig := &tls.Config{} + if sslrootcert != "" { + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) { + return dbplugin.InitializeResponse{}, errors.New("unable to add CA to cert pool") + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + p.TLSConfig = tlsConfig + useTLS = true + } + + if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { + return dbplugin.InitializeResponse{}, errors.New(`both "sslcert" and "sslkey" are required`) + } + + if sslcert != "" && sslkey != "" { + block, _ := pem.Decode([]byte(sslkey)) + + cert, err := tls.X509KeyPair([]byte(sslcert), pem.EncodeToMemory(block)) + if err != nil { + return dbplugin.InitializeResponse{}, fmt.Errorf("unable to load cert: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + p.TLSConfig = tlsConfig + useTLS = true + } + + if !useTLS { + // set to nil to flag that this connection does not use a custom TLS config + p.TLSConfig = nil + } + newConf, err := p.SQLConnectionProducer.Init(ctx, req.Config, req.VerifyConnection) if err != nil { return dbplugin.InitializeResponse{}, err diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index 0ca347dab4..e9d4efd20e 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/hashicorp/vault/helper/testhelpers/certhelpers" "github.com/hashicorp/vault/helper/testhelpers/postgresql" "github.com/hashicorp/vault/sdk/database/dbplugin/v5" dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing" @@ -86,15 +87,18 @@ func TestPostgreSQL_InitializeMultiHost(t *testing.T) { } } -// TestPostgreSQL_InitializeSSLFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE +// TestPostgreSQL_InitializeSSLInlineFeatureFlag tests that the VAULT_PLUGIN_USE_POSTGRES_SSLINLINE // flag guards against unwanted usage of the deprecated SSL client authentication path. // TODO: remove this when we remove the underlying feature in a future SDK version -func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) { +func TestPostgreSQL_InitializeSSLInlineFeatureFlag(t *testing.T) { // set the flag to true so we can call PrepareTestContainerWithSSL // which does a validation check on the connection t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") - cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), "verify-ca", false) + // Create certificates for postgres authentication + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, "verify-ca", caCert, clientCert, false) t.Cleanup(cleanup) type testCase struct { @@ -166,11 +170,11 @@ func TestPostgreSQL_InitializeSSLFeatureFlag(t *testing.T) { } } -// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate +// TestPostgreSQL_InitializeSSLInline tests that we can successfully authenticate // with a postgres server via ssl with a URL connection string or DSN (key/value) // for each ssl mode. // TODO: remove this when we remove the underlying feature in a future SDK version -func TestPostgreSQL_InitializeSSL(t *testing.T) { +func TestPostgreSQL_InitializeSSLInline(t *testing.T) { // required to enable the sslinline custom parsing t.Setenv(pluginutil.PluginUsePostgresSSLInline, "true") @@ -287,7 +291,11 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { t.Parallel() - cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, context.Background(), test.sslMode, test.useFallback) + + // Create certificates for postgres authentication + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback) t.Cleanup(cleanup) if test.useDSN { @@ -326,6 +334,188 @@ func TestPostgreSQL_InitializeSSL(t *testing.T) { } } +// TestPostgreSQL_InitializeSSL tests that we can successfully authenticate +// with a postgres server via ssl with a URL connection string or DSN (key/value) +// for each ssl mode. +func TestPostgreSQL_InitializeSSL(t *testing.T) { + type testCase struct { + sslMode string + useDSN bool + useFallback bool + wantErr bool + expectedError string + } + + tests := map[string]testCase{ + "disable sslmode": { + sslMode: "disable", + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode": { + sslMode: "allow", + wantErr: false, + }, + "prefer sslmode": { + sslMode: "prefer", + wantErr: false, + }, + "require sslmode": { + sslMode: "require", + wantErr: false, + }, + "verify-ca sslmode": { + sslMode: "verify-ca", + wantErr: false, + }, + "verify-full sslmode": { + sslMode: "verify-full", + wantErr: false, + }, + "disable sslmode with DSN": { + sslMode: "disable", + useDSN: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with DSN": { + sslMode: "allow", + useDSN: true, + wantErr: false, + }, + "prefer sslmode with DSN": { + sslMode: "prefer", + useDSN: true, + wantErr: false, + }, + "require sslmode with DSN": { + sslMode: "require", + useDSN: true, + wantErr: false, + }, + "verify-ca sslmode with DSN": { + sslMode: "verify-ca", + useDSN: true, + wantErr: false, + }, + "verify-full sslmode with DSN": { + sslMode: "verify-full", + useDSN: true, + wantErr: false, + }, + "disable sslmode with fallback": { + sslMode: "disable", + useFallback: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with fallback": { + sslMode: "allow", + useFallback: true, + }, + "prefer sslmode with fallback": { + sslMode: "prefer", + useFallback: true, + }, + "require sslmode with fallback": { + sslMode: "require", + useFallback: true, + }, + "verify-ca sslmode with fallback": { + sslMode: "verify-ca", + useFallback: true, + }, + "verify-full sslmode with fallback": { + sslMode: "verify-full", + useFallback: true, + }, + "disable sslmode with DSN with fallback": { + sslMode: "disable", + useDSN: true, + useFallback: true, + wantErr: true, + expectedError: "error verifying connection", + }, + "allow sslmode with DSN with fallback": { + sslMode: "allow", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "prefer sslmode with DSN with fallback": { + sslMode: "prefer", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "require sslmode with DSN with fallback": { + sslMode: "require", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "verify-ca sslmode with DSN with fallback": { + sslMode: "verify-ca", + useDSN: true, + useFallback: true, + wantErr: false, + }, + "verify-full sslmode with DSN with fallback": { + sslMode: "verify-full", + useDSN: true, + useFallback: true, + wantErr: false, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + // Create certificates for postgres authentication + caCert := certhelpers.NewCert(t, certhelpers.CommonName("ca"), certhelpers.IsCA(true), certhelpers.SelfSign()) + clientCert := certhelpers.NewCert(t, certhelpers.CommonName("postgres"), certhelpers.DNS("localhost"), certhelpers.Parent(caCert)) + cleanup, connURL := postgresql.PrepareTestContainerWithSSL(t, test.sslMode, caCert, clientCert, test.useFallback) + t.Cleanup(cleanup) + + if test.useDSN { + var err error + connURL, err = dbutil.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + } + connectionDetails := map[string]interface{}{ + "connection_url": connURL, + "max_open_connections": 5, + "tls_certificate": string(clientCert.CombinedPEM()), + "tls_private_key": string(clientCert.PrivateKeyPEM()), + "tls_ca": string(caCert.CombinedPEM()), + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _, err := dbtesting.VerifyInitialize(t, db, req) + if test.wantErr && err == nil { + t.Fatal("expected error, got nil") + } else if test.wantErr && !strings.Contains(err.Error(), test.expectedError) { + t.Fatalf("got: %s, want: %s", err.Error(), test.expectedError) + } + + if !test.wantErr && !db.Initialized { + t.Fatal("Database should be initialized") + } + + if err := db.Close(); err != nil { + t.Fatalf("err: %s", err) + } + }) + } +} + func TestPostgreSQL_InitializeWithStringVals(t *testing.T) { db, cleanup := getPostgreSQL(t, map[string]interface{}{ "max_open_connections": "5", diff --git a/sdk/database/helper/connutil/cloudsql.go b/sdk/database/helper/connutil/cloudsql.go index 5d81440cc3..f6cbba1d24 100644 --- a/sdk/database/helper/connutil/cloudsql.go +++ b/sdk/database/helper/connutil/cloudsql.go @@ -10,10 +10,6 @@ import ( "cloud.google.com/go/cloudsqlconn/postgres/pgxv4" ) -var configurableAuthTypes = []string{ - AuthTypeGCPIAM, -} - func (c *SQLConnectionProducer) getCloudSQLDriverType() (string, error) { var driverType string // using switch case for future extensibility @@ -62,15 +58,3 @@ func GetCloudSQLAuthOptions(credentials string, usePrivateIP bool) ([]cloudsqlco return opts, nil } - -func ValidateAuthType(authType string) bool { - var valid bool - for _, typ := range configurableAuthTypes { - if authType == typ { - valid = true - break - } - } - - return valid -} diff --git a/sdk/database/helper/connutil/postgres.go b/sdk/database/helper/connutil/postgres.go index 7d96376bd2..f8ad876c5a 100644 --- a/sdk/database/helper/connutil/postgres.go +++ b/sdk/database/helper/connutil/postgres.go @@ -46,7 +46,7 @@ import ( "github.com/jackc/pgx/v4/stdlib" ) -// OpenPostgres parses the connection string and opens a connection to the database. +// openPostgres parses the connection string and opens a connection to the database. // // If sslinline is set, strips the connection string of all ssl settings and // creates a TLS config based on the settings provided, then uses the @@ -54,8 +54,8 @@ import ( // because the pgx driver does not support the sslinline parameter and instead // expects to source ssl material from the file system. // -// Deprecated: OpenPostgres will be removed in a future version of the Vault SDK. -func OpenPostgres(driverName, connString string) (*sql.DB, error) { +// Deprecated: openPostgres will be removed in a future version of the Vault SDK. +func openPostgres(driverName, connString string) (*sql.DB, error) { if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok { return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable") } diff --git a/sdk/database/helper/connutil/sql.go b/sdk/database/helper/connutil/sql.go index 548cc83d38..bd19e77f6b 100644 --- a/sdk/database/helper/connutil/sql.go +++ b/sdk/database/helper/connutil/sql.go @@ -5,6 +5,7 @@ package connutil import ( "context" + "crypto/tls" "database/sql" "fmt" "net/url" @@ -19,12 +20,18 @@ import ( "github.com/hashicorp/vault/sdk/database/dbplugin" "github.com/hashicorp/vault/sdk/database/helper/dbutil" "github.com/hashicorp/vault/sdk/helper/pluginutil" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" "github.com/mitchellh/mapstructure" ) const ( - AuthTypeGCPIAM = "gcp_iam" + AuthTypeGCPIAM = "gcp_iam" + AuthTypeCert = "cert" + AuthTypeUsernamePassword = "" +) +const ( dbTypePostgres = "pgx" cloudSQLPostgres = "cloudsql-postgres" ) @@ -37,14 +44,19 @@ type SQLConnectionProducer struct { MaxOpenConnections int `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"` MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"` - Username string `json:"username" mapstructure:"username" structs:"username"` - Password string `json:"password" mapstructure:"password" structs:"password"` - AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"` - ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"` DisableEscaping bool `json:"disable_escaping" mapstructure:"disable_escaping" structs:"disable_escaping"` usePrivateIP bool `json:"use_private_ip" mapstructure:"use_private_ip" structs:"use_private_ip"` - // cloud options here - cloudDriverName is globally unique, but only needs to be retained for the lifetime + // Username/Password is the default auth type when AuthType is not set + Username string `json:"username" mapstructure:"username" structs:"username"` + Password string `json:"password" mapstructure:"password" structs:"password"` + + // AuthType defines the type of client authenticate used for this connection + AuthType string `json:"auth_type" mapstructure:"auth_type" structs:"auth_type"` + ServiceAccountJSON string `json:"service_account_json" mapstructure:"service_account_json" structs:"service_account_json"` + TLSConfig *tls.Config + + // cloudDriverName is globally unique, but only needs to be retained for the lifetime // of driver registration, not across plugin restarts. cloudDriverName string cloudDialerCleanup func() error @@ -125,15 +137,11 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf return nil, errwrap.Wrapf("invalid max_connection_lifetime: {{err}}", err) } - // validate auth_type if provided - authType := c.AuthType - if authType != "" { - if ok := ValidateAuthType(authType); !ok { - return nil, fmt.Errorf("invalid auth_type %s provided", authType) - } + if ok := ValidateAuthType(c.AuthType); !ok { + return nil, fmt.Errorf("invalid auth_type: %s", c.AuthType) } - if authType == AuthTypeGCPIAM { + if c.AuthType == AuthTypeGCPIAM { c.cloudDriverName, err = uuid.GenerateUUID() if err != nil { return nil, fmt.Errorf("unable to generate UUID for IAM configuration: %w", err) @@ -161,7 +169,7 @@ func (c *SQLConnectionProducer) Init(ctx context.Context, conf map[string]interf } if err := c.db.PingContext(ctx); err != nil { - return nil, errwrap.Wrapf("error verifying connection: {{err}}", err) + return nil, errwrap.Wrapf("error verifying connection: ping failed: {{err}}", err) } } @@ -219,16 +227,42 @@ func (c *SQLConnectionProducer) Connection(ctx context.Context) (interface{}, er } } - var err error - if driverName == "pgx" && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" { - // TODO: remove this deprecated function call in a future SDK version - c.db, err = OpenPostgres(driverName, conn) - } else { - c.db, err = sql.Open(driverName, conn) - } + if driverName == dbTypePostgres && c.TLSConfig != nil { + config, err := pgx.ParseConfig(conn) + if err != nil { + return nil, fmt.Errorf("failed to parse config: %w", err) + } + if config.TLSConfig == nil { + // handle sslmode=disable + config.TLSConfig = &tls.Config{} + } - if err != nil { - return nil, err + config.TLSConfig.RootCAs = c.TLSConfig.RootCAs + config.TLSConfig.ClientCAs = c.TLSConfig.ClientCAs + config.TLSConfig.Certificates = c.TLSConfig.Certificates + + // Ensure there are no stale fallbacks when manually setting TLSConfig + for _, fallback := range config.Fallbacks { + fallback.TLSConfig = config.TLSConfig + } + + c.db = stdlib.OpenDB(*config) + if err != nil { + return nil, fmt.Errorf("failed to open connection: %w", err) + } + } else if driverName == dbTypePostgres && os.Getenv(pluginutil.PluginUsePostgresSSLInline) != "" { + var err error + // TODO: remove this deprecated function call in a future SDK version + c.db, err = openPostgres(driverName, conn) + if err != nil { + return nil, fmt.Errorf("failed to open connection: %w", err) + } + } else { + var err error + c.db, err = sql.Open(driverName, conn) + if err != nil { + return nil, fmt.Errorf("failed to open connection: %w", err) + } } // Set some connection pool settings. We don't need much of this, @@ -277,3 +311,13 @@ func (c *SQLConnectionProducer) Close() error { func (c *SQLConnectionProducer) SetCredentials(ctx context.Context, statements dbplugin.Statements, staticUser dbplugin.StaticUserConfig) (username, password string, err error) { return "", "", dbutil.Unimplemented() } + +var configurableAuthTypes = map[string]bool{ + AuthTypeUsernamePassword: true, + AuthTypeCert: true, + AuthTypeGCPIAM: true, +} + +func ValidateAuthType(authType string) bool { + return configurableAuthTypes[authType] +} diff --git a/sdk/go.mod b/sdk/go.mod index 4373b1bdeb..33a6225ff0 100644 --- a/sdk/go.mod +++ b/sdk/go.mod @@ -84,7 +84,7 @@ require ( github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgtype v1.14.0 // indirect - github.com/jackc/pgx/v4 v4.18.3 // indirect + github.com/jackc/pgx/v4 v4.18.3 github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531 // indirect github.com/klauspost/compress v1.16.5 // indirect github.com/mattn/go-colorable v0.1.13 // indirect