mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-16 11:37:04 +02:00
* Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package. This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License. Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at https://hashi.co/bsl-blog, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUS-1.1 * Fix test that expected exact offset on hcl file --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com> Co-authored-by: Sarah Thompson <sthompson@hashicorp.com> Co-authored-by: Brian Kassouf <bkassouf@hashicorp.com>
286 lines
6.5 KiB
Go
286 lines
6.5 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package mssql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
metrics "github.com/armon/go-metrics"
|
|
_ "github.com/denisenkom/go-mssqldb"
|
|
log "github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
"github.com/hashicorp/vault/sdk/physical"
|
|
)
|
|
|
|
// Verify MSSQLBackend satisfies the correct interfaces
|
|
var _ physical.Backend = (*MSSQLBackend)(nil)
|
|
var identifierRegex = regexp.MustCompile(`^[\p{L}_][\p{L}\p{Nd}@#$_]*$`)
|
|
|
|
type MSSQLBackend struct {
|
|
dbTable string
|
|
client *sql.DB
|
|
statements map[string]*sql.Stmt
|
|
logger log.Logger
|
|
permitPool *physical.PermitPool
|
|
}
|
|
|
|
func isInvalidIdentifier(name string) bool {
|
|
if !identifierRegex.MatchString(name) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func NewMSSQLBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) {
|
|
username, ok := conf["username"]
|
|
if !ok {
|
|
username = ""
|
|
}
|
|
|
|
password, ok := conf["password"]
|
|
if !ok {
|
|
password = ""
|
|
}
|
|
|
|
server, ok := conf["server"]
|
|
if !ok || server == "" {
|
|
return nil, fmt.Errorf("missing server")
|
|
}
|
|
|
|
port, ok := conf["port"]
|
|
if !ok {
|
|
port = ""
|
|
}
|
|
|
|
maxParStr, ok := conf["max_parallel"]
|
|
var maxParInt int
|
|
var err error
|
|
if ok {
|
|
maxParInt, err = strconv.Atoi(maxParStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed parsing max_parallel parameter: %w", err)
|
|
}
|
|
if logger.IsDebug() {
|
|
logger.Debug("max_parallel set", "max_parallel", maxParInt)
|
|
}
|
|
} else {
|
|
maxParInt = physical.DefaultParallelOperations
|
|
}
|
|
|
|
database, ok := conf["database"]
|
|
if !ok {
|
|
database = "Vault"
|
|
}
|
|
|
|
if isInvalidIdentifier(database) {
|
|
return nil, fmt.Errorf("invalid database name")
|
|
}
|
|
|
|
table, ok := conf["table"]
|
|
if !ok {
|
|
table = "Vault"
|
|
}
|
|
|
|
if isInvalidIdentifier(table) {
|
|
return nil, fmt.Errorf("invalid table name")
|
|
}
|
|
|
|
appname, ok := conf["appname"]
|
|
if !ok {
|
|
appname = "Vault"
|
|
}
|
|
|
|
connectionTimeout, ok := conf["connectiontimeout"]
|
|
if !ok {
|
|
connectionTimeout = "30"
|
|
}
|
|
|
|
logLevel, ok := conf["loglevel"]
|
|
if !ok {
|
|
logLevel = "0"
|
|
}
|
|
|
|
schema, ok := conf["schema"]
|
|
if !ok || schema == "" {
|
|
schema = "dbo"
|
|
}
|
|
|
|
if isInvalidIdentifier(schema) {
|
|
return nil, fmt.Errorf("invalid schema name")
|
|
}
|
|
|
|
connectionString := fmt.Sprintf("server=%s;app name=%s;connection timeout=%s;log=%s", server, appname, connectionTimeout, logLevel)
|
|
if username != "" {
|
|
connectionString += ";user id=" + username
|
|
}
|
|
|
|
if password != "" {
|
|
connectionString += ";password=" + password
|
|
}
|
|
|
|
if port != "" {
|
|
connectionString += ";port=" + port
|
|
}
|
|
|
|
db, err := sql.Open("mssql", connectionString)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to mssql: %w", err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(maxParInt)
|
|
|
|
if _, err := db.Exec("IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = ?) CREATE DATABASE "+database, database); err != nil {
|
|
return nil, fmt.Errorf("failed to create mssql database: %w", err)
|
|
}
|
|
|
|
dbTable := database + "." + schema + "." + table
|
|
createQuery := "IF NOT EXISTS(SELECT 1 FROM " + database + ".INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE' AND TABLE_NAME=? AND TABLE_SCHEMA=?) CREATE TABLE " + dbTable + " (Path VARCHAR(512) PRIMARY KEY, Value VARBINARY(MAX))"
|
|
|
|
if schema != "dbo" {
|
|
|
|
var num int
|
|
err = db.QueryRow("SELECT 1 FROM "+database+".sys.schemas WHERE name = ?", schema).Scan(&num)
|
|
|
|
switch {
|
|
case err == sql.ErrNoRows:
|
|
if _, err := db.Exec("USE " + database + "; EXEC ('CREATE SCHEMA " + schema + "')"); err != nil {
|
|
return nil, fmt.Errorf("failed to create mssql schema: %w", err)
|
|
}
|
|
|
|
case err != nil:
|
|
return nil, fmt.Errorf("failed to check if mssql schema exists: %w", err)
|
|
}
|
|
}
|
|
|
|
if _, err := db.Exec(createQuery, table, schema); err != nil {
|
|
return nil, fmt.Errorf("failed to create mssql table: %w", err)
|
|
}
|
|
|
|
m := &MSSQLBackend{
|
|
dbTable: dbTable,
|
|
client: db,
|
|
statements: make(map[string]*sql.Stmt),
|
|
logger: logger,
|
|
permitPool: physical.NewPermitPool(maxParInt),
|
|
}
|
|
|
|
statements := map[string]string{
|
|
"put": "IF EXISTS(SELECT 1 FROM " + dbTable + " WHERE Path = ?) UPDATE " + dbTable + " SET Value = ? WHERE Path = ?" +
|
|
" ELSE INSERT INTO " + dbTable + " VALUES(?, ?)",
|
|
"get": "SELECT Value FROM " + dbTable + " WHERE Path = ?",
|
|
"delete": "DELETE FROM " + dbTable + " WHERE Path = ?",
|
|
"list": "SELECT Path FROM " + dbTable + " WHERE Path LIKE ?",
|
|
}
|
|
|
|
for name, query := range statements {
|
|
if err := m.prepare(name, query); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return m, nil
|
|
}
|
|
|
|
func (m *MSSQLBackend) prepare(name, query string) error {
|
|
stmt, err := m.client.Prepare(query)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to prepare %q: %w", name, err)
|
|
}
|
|
|
|
m.statements[name] = stmt
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MSSQLBackend) Put(ctx context.Context, entry *physical.Entry) error {
|
|
defer metrics.MeasureSince([]string{"mssql", "put"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
_, err := m.statements["put"].Exec(entry.Key, entry.Value, entry.Key, entry.Key, entry.Value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MSSQLBackend) Get(ctx context.Context, key string) (*physical.Entry, error) {
|
|
defer metrics.MeasureSince([]string{"mssql", "get"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
var result []byte
|
|
err := m.statements["get"].QueryRow(key).Scan(&result)
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ent := &physical.Entry{
|
|
Key: key,
|
|
Value: result,
|
|
}
|
|
|
|
return ent, nil
|
|
}
|
|
|
|
func (m *MSSQLBackend) Delete(ctx context.Context, key string) error {
|
|
defer metrics.MeasureSince([]string{"mssql", "delete"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
_, err := m.statements["delete"].Exec(key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MSSQLBackend) List(ctx context.Context, prefix string) ([]string, error) {
|
|
defer metrics.MeasureSince([]string{"mssql", "list"}, time.Now())
|
|
|
|
m.permitPool.Acquire()
|
|
defer m.permitPool.Release()
|
|
|
|
likePrefix := prefix + "%"
|
|
rows, err := m.statements["list"].Query(likePrefix)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var keys []string
|
|
for rows.Next() {
|
|
var key string
|
|
err = rows.Scan(&key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan rows: %w", err)
|
|
}
|
|
|
|
key = strings.TrimPrefix(key, prefix)
|
|
if i := strings.Index(key, "/"); i == -1 {
|
|
keys = append(keys, key)
|
|
} else if i != -1 {
|
|
keys = strutil.AppendIfMissing(keys, string(key[:i+1]))
|
|
}
|
|
}
|
|
|
|
sort.Strings(keys)
|
|
|
|
return keys, nil
|
|
}
|