mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-07 07:07:05 +02:00
* add inline cert auth to postgres db plugin * handle both sslinline and new TLS plugin fields * refactor PrepareTestContainerWithSSL * add tests for postgres inline TLS fields * changelog * revert back to errwrap since the middleware sanitizing depends on it * enable only setting sslrootcert
467 lines
13 KiB
Go
467 lines
13 KiB
Go
// Copyright (c) 2019-2021 Jack Christensen
|
|
|
|
// MIT License
|
|
|
|
// Permission is hereby granted, free of charge, to any person obtaining
|
|
// a copy of this software and associated documentation files (the
|
|
// "Software"), to deal in the Software without restriction, including
|
|
// without limitation the rights to use, copy, modify, merge, publish,
|
|
// distribute, sublicense, and/or sell copies of the Software, and to
|
|
// permit persons to whom the Software is furnished to do so, subject to
|
|
// the following conditions:
|
|
|
|
// The above copyright notice and this permission notice shall be
|
|
// included in all copies or substantial portions of the Software.
|
|
|
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
|
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
|
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
|
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
|
|
// Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go
|
|
|
|
package connutil
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/hashicorp/vault/sdk/helper/pluginutil"
|
|
"github.com/jackc/pgconn"
|
|
"github.com/jackc/pgx/v4"
|
|
"github.com/jackc/pgx/v4/stdlib"
|
|
)
|
|
|
|
// openPostgres parses the connection string and opens a connection to the database.
|
|
//
|
|
// If sslinline is set, strips the connection string of all ssl settings and
|
|
// creates a TLS config based on the settings provided, then uses the
|
|
// RegisterConnConfig function to create a new connection. This is necessary
|
|
// because the pgx driver does not support the sslinline parameter and instead
|
|
// expects to source ssl material from the file system.
|
|
//
|
|
// Deprecated: openPostgres will be removed in a future version of the Vault SDK.
|
|
func openPostgres(driverName, connString string) (*sql.DB, error) {
|
|
if ok, _ := strconv.ParseBool(os.Getenv(pluginutil.PluginUsePostgresSSLInline)); !ok {
|
|
return nil, fmt.Errorf("failed to open postgres connection with deprecated funtion, set feature flag to enable")
|
|
}
|
|
|
|
var options pgconn.ParseConfigOptions
|
|
|
|
settings := make(map[string]string)
|
|
if connString != "" {
|
|
var err error
|
|
// connString may be a database URL or a DSN
|
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
|
settings, err = parsePostgresURLSettings(connString)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse as URL: %w", err)
|
|
}
|
|
} else {
|
|
settings, err = parsePostgresDSNSettings(connString)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse as DSN: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// get the inline flag
|
|
sslInline := settings["sslinline"] == "true"
|
|
|
|
// if sslinline is not set, open a regular connection
|
|
if !sslInline {
|
|
return sql.Open(driverName, connString)
|
|
}
|
|
|
|
// generate a new DSN without the ssl settings
|
|
newConnStr := []string{"sslmode=disable"}
|
|
for k, v := range settings {
|
|
switch k {
|
|
case "sslinline", "sslcert", "sslkey", "sslrootcert", "sslmode":
|
|
continue
|
|
}
|
|
|
|
newConnStr = append(newConnStr, fmt.Sprintf("%s='%s'", k, v))
|
|
}
|
|
|
|
// parse the updated config
|
|
config, err := pgx.ParseConfig(strings.Join(newConnStr, " "))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// create a TLS config
|
|
fallbacks := []*pgconn.FallbackConfig{}
|
|
|
|
hosts := strings.Split(settings["host"], ",")
|
|
ports := strings.Split(settings["port"], ",")
|
|
|
|
for i, host := range hosts {
|
|
var portStr string
|
|
if i < len(ports) {
|
|
portStr = ports[i]
|
|
} else {
|
|
portStr = ports[0]
|
|
}
|
|
|
|
port, err := parsePort(portStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid port: %w", err)
|
|
}
|
|
|
|
var tlsConfigs []*tls.Config
|
|
|
|
// Ignore TLS settings if Unix domain socket like libpq
|
|
if network, _ := pgconn.NetworkAddress(host, port); network == "unix" {
|
|
tlsConfigs = append(tlsConfigs, nil)
|
|
} else {
|
|
var err error
|
|
tlsConfigs, err = configPostgresTLS(settings, host, options)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to configure TLS: %w", err)
|
|
}
|
|
}
|
|
|
|
for _, tlsConfig := range tlsConfigs {
|
|
fallbacks = append(fallbacks, &pgconn.FallbackConfig{
|
|
Host: host,
|
|
Port: port,
|
|
TLSConfig: tlsConfig,
|
|
})
|
|
}
|
|
}
|
|
|
|
config.Host = fallbacks[0].Host
|
|
config.Port = fallbacks[0].Port
|
|
config.TLSConfig = fallbacks[0].TLSConfig
|
|
config.Fallbacks = fallbacks[1:]
|
|
|
|
return sql.Open(driverName, stdlib.RegisterConnConfig(config))
|
|
}
|
|
|
|
// configPostgresTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
|
// "prefer" allow fallback.
|
|
//
|
|
// Copied from https://github.com/jackc/pgconn/blob/1860f4e57204614f40d05a5c76a43e8d80fde9da/config.go
|
|
// and modified to read ssl material by value instead of file location.
|
|
func configPostgresTLS(settings map[string]string, thisHost string, parseConfigOptions pgconn.ParseConfigOptions) ([]*tls.Config, error) {
|
|
host := thisHost
|
|
sslmode := settings["sslmode"]
|
|
sslrootcert := settings["sslrootcert"]
|
|
sslcert := settings["sslcert"]
|
|
sslkey := settings["sslkey"]
|
|
sslpassword := settings["sslpassword"]
|
|
sslsni := settings["sslsni"]
|
|
|
|
// Match libpq default behavior
|
|
if sslmode == "" {
|
|
sslmode = "prefer"
|
|
}
|
|
if sslsni == "" {
|
|
sslsni = "1"
|
|
}
|
|
|
|
tlsConfig := &tls.Config{}
|
|
|
|
switch sslmode {
|
|
case "disable":
|
|
return []*tls.Config{nil}, nil
|
|
case "allow", "prefer":
|
|
tlsConfig.InsecureSkipVerify = true
|
|
case "require":
|
|
// According to PostgreSQL documentation, if a root CA file exists,
|
|
// the behavior of sslmode=require should be the same as that of verify-ca
|
|
//
|
|
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
|
if sslrootcert != "" {
|
|
goto nextCase
|
|
}
|
|
tlsConfig.InsecureSkipVerify = true
|
|
break
|
|
nextCase:
|
|
fallthrough
|
|
case "verify-ca":
|
|
// Don't perform the default certificate verification because it
|
|
// will verify the hostname. Instead, verify the server's
|
|
// certificate chain ourselves in VerifyPeerCertificate and
|
|
// ignore the server name. This emulates libpq's verify-ca
|
|
// behavior.
|
|
//
|
|
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
|
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
|
// for more info.
|
|
tlsConfig.InsecureSkipVerify = true
|
|
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
|
certs := make([]*x509.Certificate, len(certificates))
|
|
for i, asn1Data := range certificates {
|
|
cert, err := x509.ParseCertificate(asn1Data)
|
|
if err != nil {
|
|
return errors.New("failed to parse certificate from server: " + err.Error())
|
|
}
|
|
certs[i] = cert
|
|
}
|
|
|
|
// Leave DNSName empty to skip hostname verification.
|
|
opts := x509.VerifyOptions{
|
|
Roots: tlsConfig.RootCAs,
|
|
Intermediates: x509.NewCertPool(),
|
|
}
|
|
// Skip the first cert because it's the leaf. All others
|
|
// are intermediates.
|
|
for _, cert := range certs[1:] {
|
|
opts.Intermediates.AddCert(cert)
|
|
}
|
|
_, err := certs[0].Verify(opts)
|
|
return err
|
|
}
|
|
case "verify-full":
|
|
tlsConfig.ServerName = host
|
|
default:
|
|
return nil, errors.New("sslmode is invalid")
|
|
}
|
|
|
|
if sslrootcert != "" {
|
|
caCertPool := x509.NewCertPool()
|
|
if !caCertPool.AppendCertsFromPEM([]byte(sslrootcert)) {
|
|
return nil, errors.New("unable to add CA to cert pool")
|
|
}
|
|
|
|
tlsConfig.RootCAs = caCertPool
|
|
tlsConfig.ClientCAs = caCertPool
|
|
}
|
|
|
|
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
|
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
|
}
|
|
|
|
if sslcert != "" && sslkey != "" {
|
|
block, _ := pem.Decode([]byte(sslkey))
|
|
var pemKey []byte
|
|
var decryptedKey []byte
|
|
var decryptedError error
|
|
// If PEM is encrypted, attempt to decrypt using pass phrase
|
|
if x509.IsEncryptedPEMBlock(block) {
|
|
// Attempt decryption with pass phrase
|
|
// NOTE: only supports RSA (PKCS#1)
|
|
if sslpassword != "" {
|
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
|
}
|
|
// if sslpassword not provided or has decryption error when use it
|
|
// try to find sslpassword with callback function
|
|
if sslpassword == "" || decryptedError != nil {
|
|
if parseConfigOptions.GetSSLPassword != nil {
|
|
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
|
}
|
|
if sslpassword == "" {
|
|
return nil, fmt.Errorf("unable to find sslpassword")
|
|
}
|
|
}
|
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
|
// Should we also provide warning for PKCS#1 needed?
|
|
if decryptedError != nil {
|
|
return nil, fmt.Errorf("unable to decrypt key: %w", decryptedError)
|
|
}
|
|
|
|
pemBytes := pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: decryptedKey,
|
|
}
|
|
pemKey = pem.EncodeToMemory(&pemBytes)
|
|
} else {
|
|
pemKey = pem.EncodeToMemory(block)
|
|
}
|
|
|
|
cert, err := tls.X509KeyPair([]byte(sslcert), pemKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to load cert: %w", err)
|
|
}
|
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
|
}
|
|
|
|
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
|
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
|
// or IPv6).
|
|
if sslsni == "1" && net.ParseIP(host) == nil {
|
|
tlsConfig.ServerName = host
|
|
}
|
|
|
|
switch sslmode {
|
|
case "allow":
|
|
return []*tls.Config{nil, tlsConfig}, nil
|
|
case "prefer":
|
|
return []*tls.Config{tlsConfig, nil}, nil
|
|
case "require", "verify-ca", "verify-full":
|
|
return []*tls.Config{tlsConfig}, nil
|
|
default:
|
|
panic("BUG: bad sslmode should already have been caught")
|
|
}
|
|
}
|
|
|
|
// parsePort converts a decimal string into a TCP port number. It returns the
// strconv error when s is not an unsigned 16-bit integer and an
// "outside range" error when the value is zero.
func parsePort(s string) (uint16, error) {
	n, err := strconv.ParseUint(s, 10, 16)
	switch {
	case err != nil:
		return 0, err
	case n == 0 || n > math.MaxUint16:
		return 0, errors.New("outside range")
	}
	return uint16(n), nil
}
|
|
|
|
// asciiSpace is a lookup table indexed by byte value: the entries for tab,
// newline, vertical tab, form feed, carriage return and space are 1 and every
// other entry is 0.
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
|
|
|
func parsePostgresURLSettings(connString string) (map[string]string, error) {
|
|
settings := make(map[string]string)
|
|
|
|
url, err := url.Parse(connString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if url.User != nil {
|
|
settings["user"] = url.User.Username()
|
|
if password, present := url.User.Password(); present {
|
|
settings["password"] = password
|
|
}
|
|
}
|
|
|
|
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
|
var hosts []string
|
|
var ports []string
|
|
for _, host := range strings.Split(url.Host, ",") {
|
|
if host == "" {
|
|
continue
|
|
}
|
|
if isIPOnly(host) {
|
|
hosts = append(hosts, strings.Trim(host, "[]"))
|
|
continue
|
|
}
|
|
h, p, err := net.SplitHostPort(host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
|
}
|
|
if h != "" {
|
|
hosts = append(hosts, h)
|
|
}
|
|
if p != "" {
|
|
ports = append(ports, p)
|
|
}
|
|
}
|
|
if len(hosts) > 0 {
|
|
settings["host"] = strings.Join(hosts, ",")
|
|
}
|
|
if len(ports) > 0 {
|
|
settings["port"] = strings.Join(ports, ",")
|
|
}
|
|
|
|
database := strings.TrimLeft(url.Path, "/")
|
|
if database != "" {
|
|
settings["database"] = database
|
|
}
|
|
|
|
nameMap := map[string]string{
|
|
"dbname": "database",
|
|
}
|
|
|
|
for k, v := range url.Query() {
|
|
if k2, present := nameMap[k]; present {
|
|
k = k2
|
|
}
|
|
|
|
settings[k] = v[0]
|
|
}
|
|
|
|
return settings, nil
|
|
}
|
|
|
|
func parsePostgresDSNSettings(s string) (map[string]string, error) {
|
|
settings := make(map[string]string)
|
|
|
|
nameMap := map[string]string{
|
|
"dbname": "database",
|
|
}
|
|
|
|
for len(s) > 0 {
|
|
var key, val string
|
|
eqIdx := strings.IndexRune(s, '=')
|
|
if eqIdx < 0 {
|
|
return nil, errors.New("invalid dsn")
|
|
}
|
|
|
|
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
|
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
|
if len(s) == 0 {
|
|
} else if s[0] != '\'' {
|
|
end := 0
|
|
for ; end < len(s); end++ {
|
|
if asciiSpace[s[end]] == 1 {
|
|
break
|
|
}
|
|
if s[end] == '\\' {
|
|
end++
|
|
if end == len(s) {
|
|
return nil, errors.New("invalid backslash")
|
|
}
|
|
}
|
|
}
|
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
|
if end == len(s) {
|
|
s = ""
|
|
} else {
|
|
s = s[end+1:]
|
|
}
|
|
} else { // quoted string
|
|
s = s[1:]
|
|
end := 0
|
|
for ; end < len(s); end++ {
|
|
if s[end] == '\'' {
|
|
break
|
|
}
|
|
if s[end] == '\\' {
|
|
end++
|
|
}
|
|
}
|
|
if end == len(s) {
|
|
return nil, errors.New("unterminated quoted string in connection info string")
|
|
}
|
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
|
if end == len(s) {
|
|
s = ""
|
|
} else {
|
|
s = s[end+1:]
|
|
}
|
|
}
|
|
|
|
if k, ok := nameMap[key]; ok {
|
|
key = k
|
|
}
|
|
|
|
if key == "" {
|
|
return nil, errors.New("invalid dsn")
|
|
}
|
|
|
|
settings[key] = val
|
|
}
|
|
|
|
return settings, nil
|
|
}
|
|
|
|
// isIPOnly reports whether host carries no port component: either it contains
// no colon at all (so plain hostnames also qualify), or it is an IP literal
// once any surrounding square brackets are removed.
func isIPOnly(host string) bool {
	if !strings.Contains(host, ":") {
		return true
	}
	return net.ParseIP(strings.Trim(host, "[]")) != nil
}
|