mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-15 02:57:04 +02:00
* Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License. Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUS-1.1 * Fix test that expected exact offset on hcl file --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com> Co-authored-by: Sarah Thompson <sthompson@hashicorp.com> Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
233 lines
6.6 KiB
Go
233 lines
6.6 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"fmt"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/hashicorp/go-secure-stdlib/parseutil"
|
|
"github.com/hashicorp/go-uuid"
|
|
"github.com/hashicorp/vault/sdk/database/helper/connutil"
|
|
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
|
|
"github.com/mitchellh/mapstructure"
|
|
)
|
|
|
|
// mySQLConnectionProducer is the MySQL connection producer. It holds the
// plugin configuration and a lazily opened *sql.DB pool built from it.
//
// Exported fields are filled from the plugin configuration map through their
// mapstructure tags. The unexported fields are derived state computed by Init
// and Connection. The embedded mutex is taken by Init and Close; Connection
// does not take it itself (NOTE(review): presumably callers hold it — confirm).
type mySQLConnectionProducer struct {
	// Connection and pool settings. ConnectionURL may carry placeholders for
	// the username and password, which Init fills in via dbutil.QueryHelper.
	ConnectionURL            string      `json:"connection_url" mapstructure:"connection_url" structs:"connection_url"`
	MaxOpenConnections       int         `json:"max_open_connections" mapstructure:"max_open_connections" structs:"max_open_connections"`
	MaxIdleConnections       int         `json:"max_idle_connections" mapstructure:"max_idle_connections" structs:"max_idle_connections"`
	MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime" mapstructure:"max_connection_lifetime" structs:"max_connection_lifetime"`
	Username                 string      `json:"username" mapstructure:"username" structs:"username"`
	Password                 string      `json:"password" mapstructure:"password" structs:"password"`

	// TLS settings. The PEM byte fields are tagged structs:"-" so the raw
	// key/CA material is skipped by structs-based serialization.
	TLSCertificateKeyData []byte `json:"tls_certificate_key" mapstructure:"tls_certificate_key" structs:"-"`
	TLSCAData             []byte `json:"tls_ca" mapstructure:"tls_ca" structs:"-"`
	TLSServerName         string `json:"tls_server_name" mapstructure:"tls_server_name" structs:"tls_server_name"`
	TLSSkipVerify         bool   `json:"tls_skip_verify" mapstructure:"tls_skip_verify" structs:"tls_skip_verify"`

	// tlsConfigName is a globally unique name that references the TLS config
	// for this instance in the mysql driver's TLS registry.
	tlsConfigName string

	// RawConfig is the configuration map most recently passed to Init.
	RawConfig map[string]interface{}
	// maxConnectionLifetime is MaxConnectionLifetimeRaw parsed by Init.
	maxConnectionLifetime time.Duration
	// Initialized is set by Init once all settings are in place.
	Initialized bool
	// db is the current pool; nil until Connection opens one.
	db *sql.DB

	sync.Mutex
}
|
|
|
|
func (c *mySQLConnectionProducer) Initialize(ctx context.Context, conf map[string]interface{}, verifyConnection bool) error {
|
|
_, err := c.Init(ctx, conf, verifyConnection)
|
|
return err
|
|
}
|
|
|
|
func (c *mySQLConnectionProducer) Init(ctx context.Context, conf map[string]interface{}, verifyConnection bool) (map[string]interface{}, error) {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
c.RawConfig = conf
|
|
|
|
err := mapstructure.WeakDecode(conf, &c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(c.ConnectionURL) == 0 {
|
|
return nil, fmt.Errorf("connection_url cannot be empty")
|
|
}
|
|
|
|
// Don't escape special characters for MySQL password
|
|
password := c.Password
|
|
|
|
// QueryHelper doesn't do any SQL escaping, but if it starts to do so
|
|
// then maybe we won't be able to use it to do URL substitution any more.
|
|
c.ConnectionURL = dbutil.QueryHelper(c.ConnectionURL, map[string]string{
|
|
"username": url.PathEscape(c.Username),
|
|
"password": password,
|
|
})
|
|
|
|
if c.MaxOpenConnections == 0 {
|
|
c.MaxOpenConnections = 4
|
|
}
|
|
|
|
if c.MaxIdleConnections == 0 {
|
|
c.MaxIdleConnections = c.MaxOpenConnections
|
|
}
|
|
if c.MaxIdleConnections > c.MaxOpenConnections {
|
|
c.MaxIdleConnections = c.MaxOpenConnections
|
|
}
|
|
if c.MaxConnectionLifetimeRaw == nil {
|
|
c.MaxConnectionLifetimeRaw = "0s"
|
|
}
|
|
|
|
c.maxConnectionLifetime, err = parseutil.ParseDurationSecond(c.MaxConnectionLifetimeRaw)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid max_connection_lifetime: %w", err)
|
|
}
|
|
|
|
tlsConfig, err := c.getTLSAuth()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if tlsConfig != nil {
|
|
if c.tlsConfigName == "" {
|
|
c.tlsConfigName, err = uuid.GenerateUUID()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to generate UUID for TLS configuration: %w", err)
|
|
}
|
|
}
|
|
|
|
mysql.RegisterTLSConfig(c.tlsConfigName, tlsConfig)
|
|
}
|
|
|
|
// Set initialized to true at this point since all fields are set,
|
|
// and the connection can be established at a later time.
|
|
c.Initialized = true
|
|
|
|
if verifyConnection {
|
|
if _, err = c.Connection(ctx); err != nil {
|
|
return nil, fmt.Errorf("error verifying - connection: %w", err)
|
|
}
|
|
|
|
if err := c.db.PingContext(ctx); err != nil {
|
|
return nil, fmt.Errorf("error verifying - ping: %w", err)
|
|
}
|
|
}
|
|
|
|
return c.RawConfig, nil
|
|
}
|
|
|
|
func (c *mySQLConnectionProducer) Connection(ctx context.Context) (interface{}, error) {
|
|
if !c.Initialized {
|
|
return nil, connutil.ErrNotInitialized
|
|
}
|
|
|
|
// If we already have a DB, test it and return
|
|
if c.db != nil {
|
|
if err := c.db.PingContext(ctx); err == nil {
|
|
return c.db, nil
|
|
}
|
|
// If the ping was unsuccessful, close it and ignore errors as we'll be
|
|
// reestablishing anyways
|
|
c.db.Close()
|
|
}
|
|
|
|
connURL, err := c.addTLStoDSN()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.db, err = sql.Open("mysql", connURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set some connection pool settings. We don't need much of this,
|
|
// since the request rate shouldn't be high.
|
|
c.db.SetMaxOpenConns(c.MaxOpenConnections)
|
|
c.db.SetMaxIdleConns(c.MaxIdleConnections)
|
|
c.db.SetConnMaxLifetime(c.maxConnectionLifetime)
|
|
|
|
return c.db, nil
|
|
}
|
|
|
|
func (c *mySQLConnectionProducer) SecretValues() map[string]string {
|
|
return map[string]string{
|
|
c.Password: "[password]",
|
|
}
|
|
}
|
|
|
|
// Close attempts to close the connection
|
|
func (c *mySQLConnectionProducer) Close() error {
|
|
// Grab the write lock
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
if c.db != nil {
|
|
c.db.Close()
|
|
}
|
|
|
|
c.db = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *mySQLConnectionProducer) getTLSAuth() (tlsConfig *tls.Config, err error) {
|
|
if len(c.TLSCAData) == 0 &&
|
|
len(c.TLSCertificateKeyData) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
rootCertPool := x509.NewCertPool()
|
|
if len(c.TLSCAData) > 0 {
|
|
ok := rootCertPool.AppendCertsFromPEM(c.TLSCAData)
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to append CA to client options")
|
|
}
|
|
}
|
|
|
|
clientCert := make([]tls.Certificate, 0, 1)
|
|
|
|
if len(c.TLSCertificateKeyData) > 0 {
|
|
certificate, err := tls.X509KeyPair(c.TLSCertificateKeyData, c.TLSCertificateKeyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to load tls_certificate_key_data: %w", err)
|
|
}
|
|
|
|
clientCert = append(clientCert, certificate)
|
|
}
|
|
|
|
tlsConfig = &tls.Config{
|
|
RootCAs: rootCertPool,
|
|
Certificates: clientCert,
|
|
ServerName: c.TLSServerName,
|
|
InsecureSkipVerify: c.TLSSkipVerify,
|
|
}
|
|
|
|
return tlsConfig, nil
|
|
}
|
|
|
|
func (c *mySQLConnectionProducer) addTLStoDSN() (connURL string, err error) {
|
|
config, err := mysql.ParseDSN(c.ConnectionURL)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to parse connectionURL: %s", err)
|
|
}
|
|
|
|
if len(c.tlsConfigName) > 0 {
|
|
config.TLSConfig = c.tlsConfigName
|
|
}
|
|
|
|
connURL = config.FormatDSN()
|
|
return connURL, nil
|
|
}
|