mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-08 15:47:00 +02:00
This is part 1 of 4 for renaming the `newdbplugin` package. This copies the existing package to the new location but keeps the current one in place so we can migrate the existing references over more easily.
283 lines
8.9 KiB
Go
283 lines
8.9 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gocql/gocql"
|
|
"github.com/hashicorp/errwrap"
|
|
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
|
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
|
"github.com/hashicorp/vault/sdk/helper/certutil"
|
|
"github.com/hashicorp/vault/sdk/helper/parseutil"
|
|
"github.com/hashicorp/vault/sdk/helper/tlsutil"
|
|
"github.com/mitchellh/mapstructure"
|
|
)
|
|
|
|
// cassandraConnectionProducer implements ConnectionProducer and provides an
|
|
// interface for cassandra databases to make connections.
|
|
type cassandraConnectionProducer struct {
|
|
Hosts string `json:"hosts" structs:"hosts" mapstructure:"hosts"`
|
|
Port int `json:"port" structs:"port" mapstructure:"port"`
|
|
Username string `json:"username" structs:"username" mapstructure:"username"`
|
|
Password string `json:"password" structs:"password" mapstructure:"password"`
|
|
TLS bool `json:"tls" structs:"tls" mapstructure:"tls"`
|
|
InsecureTLS bool `json:"insecure_tls" structs:"insecure_tls" mapstructure:"insecure_tls"`
|
|
ProtocolVersion int `json:"protocol_version" structs:"protocol_version" mapstructure:"protocol_version"`
|
|
ConnectTimeoutRaw interface{} `json:"connect_timeout" structs:"connect_timeout" mapstructure:"connect_timeout"`
|
|
SocketKeepAliveRaw interface{} `json:"socket_keep_alive" structs:"socket_keep_alive" mapstructure:"socket_keep_alive"`
|
|
TLSMinVersion string `json:"tls_min_version" structs:"tls_min_version" mapstructure:"tls_min_version"`
|
|
Consistency string `json:"consistency" structs:"consistency" mapstructure:"consistency"`
|
|
LocalDatacenter string `json:"local_datacenter" structs:"local_datacenter" mapstructure:"local_datacenter"`
|
|
PemBundle string `json:"pem_bundle" structs:"pem_bundle" mapstructure:"pem_bundle"`
|
|
PemJSON string `json:"pem_json" structs:"pem_json" mapstructure:"pem_json"`
|
|
SkipVerification bool `json:"skip_verification" structs:"skip_verification" mapstructure:"skip_verification"`
|
|
|
|
connectTimeout time.Duration
|
|
socketKeepAlive time.Duration
|
|
certificate string
|
|
privateKey string
|
|
issuingCA string
|
|
rawConfig map[string]interface{}
|
|
|
|
Initialized bool
|
|
Type string
|
|
session *gocql.Session
|
|
sync.Mutex
|
|
}
|
|
|
|
func (c *cassandraConnectionProducer) Initialize(ctx context.Context, req dbplugin.InitializeRequest) (dbplugin.InitializeResponse, error) {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
c.rawConfig = req.Config
|
|
|
|
err := mapstructure.WeakDecode(req.Config, c)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, err
|
|
}
|
|
|
|
if c.ConnectTimeoutRaw == nil {
|
|
c.ConnectTimeoutRaw = "0s"
|
|
}
|
|
c.connectTimeout, err = parseutil.ParseDurationSecond(c.ConnectTimeoutRaw)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("invalid connect_timeout: {{err}}", err)
|
|
}
|
|
|
|
if c.SocketKeepAliveRaw == nil {
|
|
c.SocketKeepAliveRaw = "0s"
|
|
}
|
|
c.socketKeepAlive, err = parseutil.ParseDurationSecond(c.SocketKeepAliveRaw)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("invalid socket_keep_alive: {{err}}", err)
|
|
}
|
|
|
|
switch {
|
|
case len(c.Hosts) == 0:
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("hosts cannot be empty")
|
|
case len(c.Username) == 0:
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("username cannot be empty")
|
|
case len(c.Password) == 0:
|
|
return dbplugin.InitializeResponse{}, fmt.Errorf("password cannot be empty")
|
|
}
|
|
|
|
var certBundle *certutil.CertBundle
|
|
var parsedCertBundle *certutil.ParsedCertBundle
|
|
switch {
|
|
case len(c.PemJSON) != 0:
|
|
parsedCertBundle, err = certutil.ParsePKIJSON([]byte(c.PemJSON))
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("could not parse given JSON; it must be in the format of the output of the PKI backend certificate issuing command: {{err}}", err)
|
|
}
|
|
certBundle, err = parsedCertBundle.ToCertBundle()
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
|
|
}
|
|
c.certificate = certBundle.Certificate
|
|
c.privateKey = certBundle.PrivateKey
|
|
c.issuingCA = certBundle.IssuingCA
|
|
c.TLS = true
|
|
|
|
case len(c.PemBundle) != 0:
|
|
parsedCertBundle, err = certutil.ParsePEMBundle(c.PemBundle)
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("Error parsing the given PEM information: {{err}}", err)
|
|
}
|
|
certBundle, err = parsedCertBundle.ToCertBundle()
|
|
if err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("Error marshaling PEM information: {{err}}", err)
|
|
}
|
|
c.certificate = certBundle.Certificate
|
|
c.privateKey = certBundle.PrivateKey
|
|
c.issuingCA = certBundle.IssuingCA
|
|
c.TLS = true
|
|
}
|
|
|
|
// Set initialized to true at this point since all fields are set,
|
|
// and the connection can be established at a later time.
|
|
c.Initialized = true
|
|
|
|
if req.VerifyConnection {
|
|
if _, err := c.Connection(ctx); err != nil {
|
|
return dbplugin.InitializeResponse{}, errwrap.Wrapf("error verifying connection: {{err}}", err)
|
|
}
|
|
}
|
|
|
|
resp := dbplugin.InitializeResponse{
|
|
Config: req.Config,
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *cassandraConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
|
if !c.Initialized {
|
|
return nil, connutil.ErrNotInitialized
|
|
}
|
|
|
|
// If we already have a DB, return it
|
|
if c.session != nil && !c.session.Closed() {
|
|
return c.session, nil
|
|
}
|
|
|
|
session, err := c.createSession(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Store the session in backend for reuse
|
|
c.session = session
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func (c *cassandraConnectionProducer) Close() error {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
if c.session != nil {
|
|
c.session.Close()
|
|
}
|
|
|
|
c.session = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *cassandraConnectionProducer) createSession(ctx context.Context) (*gocql.Session, error) {
|
|
hosts := strings.Split(c.Hosts, ",")
|
|
clusterConfig := gocql.NewCluster(hosts...)
|
|
clusterConfig.Authenticator = gocql.PasswordAuthenticator{
|
|
Username: c.Username,
|
|
Password: c.Password,
|
|
}
|
|
|
|
if c.Port != 0 {
|
|
clusterConfig.Port = c.Port
|
|
}
|
|
|
|
clusterConfig.ProtoVersion = c.ProtocolVersion
|
|
if clusterConfig.ProtoVersion == 0 {
|
|
clusterConfig.ProtoVersion = 2
|
|
}
|
|
|
|
clusterConfig.Timeout = c.connectTimeout
|
|
clusterConfig.SocketKeepalive = c.socketKeepAlive
|
|
if c.TLS {
|
|
var tlsConfig *tls.Config
|
|
if len(c.certificate) > 0 || len(c.issuingCA) > 0 {
|
|
if len(c.certificate) > 0 && len(c.privateKey) == 0 {
|
|
return nil, fmt.Errorf("found certificate for TLS authentication but no private key")
|
|
}
|
|
|
|
certBundle := &certutil.CertBundle{}
|
|
if len(c.certificate) > 0 {
|
|
certBundle.Certificate = c.certificate
|
|
certBundle.PrivateKey = c.privateKey
|
|
}
|
|
if len(c.issuingCA) > 0 {
|
|
certBundle.IssuingCA = c.issuingCA
|
|
}
|
|
|
|
parsedCertBundle, err := certBundle.ToParsedCertBundle()
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf("failed to parse certificate bundle: {{err}}", err)
|
|
}
|
|
|
|
tlsConfig, err = parsedCertBundle.GetTLSConfig(certutil.TLSClient)
|
|
if err != nil || tlsConfig == nil {
|
|
return nil, errwrap.Wrapf(fmt.Sprintf("failed to get TLS configuration: tlsConfig:%#v err:{{err}}", tlsConfig), err)
|
|
}
|
|
tlsConfig.InsecureSkipVerify = c.InsecureTLS
|
|
|
|
if c.TLSMinVersion != "" {
|
|
var ok bool
|
|
tlsConfig.MinVersion, ok = tlsutil.TLSLookup[c.TLSMinVersion]
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
|
|
}
|
|
} else {
|
|
// MinVersion was not being set earlier. Reset it to
|
|
// zero to gracefully handle upgrades.
|
|
tlsConfig.MinVersion = 0
|
|
}
|
|
}
|
|
|
|
clusterConfig.SslOpts = &gocql.SslOptions{
|
|
Config: tlsConfig,
|
|
}
|
|
}
|
|
|
|
if c.LocalDatacenter != "" {
|
|
clusterConfig.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy(c.LocalDatacenter)
|
|
}
|
|
|
|
session, err := clusterConfig.CreateSession()
|
|
if err != nil {
|
|
return nil, errwrap.Wrapf("error creating session: {{err}}", err)
|
|
}
|
|
|
|
if c.Consistency != "" {
|
|
consistencyValue, err := gocql.ParseConsistencyWrapper(c.Consistency)
|
|
if err != nil {
|
|
session.Close()
|
|
return nil, err
|
|
}
|
|
|
|
session.SetConsistency(consistencyValue)
|
|
}
|
|
|
|
if !c.SkipVerification {
|
|
err = session.Query(`LIST ALL`).WithContext(ctx).Exec()
|
|
if err != nil && len(c.Username) != 0 && strings.Contains(err.Error(), "not authorized") {
|
|
rowNum := session.Query(dbutil.QueryHelper(`LIST CREATE ON ALL ROLES OF '{{username}}';`, map[string]string{
|
|
"username": c.Username,
|
|
})).Iter().NumRows()
|
|
|
|
if rowNum < 1 {
|
|
session.Close()
|
|
return nil, errwrap.Wrapf("error validating connection info: No role create permissions found, previous error: {{err}}", err)
|
|
}
|
|
} else if err != nil {
|
|
session.Close()
|
|
return nil, errwrap.Wrapf("error validating connection info: {{err}}", err)
|
|
}
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func (c *cassandraConnectionProducer) secretValues() map[string]string {
|
|
return map[string]string{
|
|
c.Password: "[password]",
|
|
c.PemBundle: "[pem_bundle]",
|
|
c.PemJSON: "[pem_json]",
|
|
}
|
|
}
|