Plugin catalog

This commit is contained in:
Brian Kassouf 2017-04-03 17:52:29 -07:00
parent 1d3d3b7803
commit ac519abecf
11 changed files with 310 additions and 68 deletions

View File

@ -20,8 +20,7 @@ const (
var (
ErrUnsupportedDatabaseType = errors.New("unsupported database type")
ErrEmptyCreationStatement = errors.New("empty creation statements")
ErrEmptyPluginCommand = errors.New("empty plugin command")
ErrEmptyPluginChecksum = errors.New("empty plugin checksum")
ErrEmptyPluginName = errors.New("empty plugin name")
)
// Factory function definition
@ -95,18 +94,19 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log
// PluginFactory is used to build plugin database types. It wraps the database
// object in a logging and metrics middleware.
func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) {
if conf.PluginCommand == "" {
return nil, ErrEmptyPluginCommand
if conf.PluginName == "" {
return nil, ErrEmptyPluginName
}
if conf.PluginChecksum == "" {
return nil, ErrEmptyPluginChecksum
pluginMeta, err := sys.LookupPlugin(conf.PluginName)
if err != nil {
return nil, err
}
// Make sure the database type is set to plugin
conf.DatabaseType = pluginTypeName
db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum)
db, err := newPluginClient(sys, pluginMeta)
if err != nil {
return nil, err
}
@ -149,8 +149,7 @@ type DatabaseConfig struct {
MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"`
MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"`
MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"`
PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"`
PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"`
PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"`
}
// GetFactory returns the appropriate factory method for the given database

View File

@ -1,12 +1,8 @@
package dbs
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/rpc"
"os/exec"
"strings"
"sync"
"time"
@ -55,59 +51,17 @@ func (dc *DatabasePluginClient) Close() error {
// newPluginClient returns a databaseRPCClient with a connection to a running
// plugin. The client is wrapped in a DatabasePluginClient object to ensure the
// plugin is killed on call of Close().
func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (DatabaseType, error) {
func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) {
// pluginMap is the map of plugins we can dispense.
var pluginMap = map[string]plugin.Plugin{
"database": new(DatabasePlugin),
}
// Get a CA TLS Certificate
CACertBytes, CACert, CAKey, err := pluginutil.GenerateCACert()
client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{})
if err != nil {
return nil, err
}
// Use CA to sign a client cert and return a configured TLS config
clientTLSConfig, err := pluginutil.CreateClientTLSConfig(CACert, CAKey)
if err != nil {
return nil, err
}
// Use CA to sign a server cert and wrap the values in a response wrapped
// token.
wrapToken, err := pluginutil.WrapServerConfig(sys, CACertBytes, CACert, CAKey)
if err != nil {
return nil, err
}
// Add the response wrap token to the ENV of the plugin
commandArr := strings.Split(command, " ")
var cmd *exec.Cmd
if len(commandArr) > 1 {
cmd = exec.Command(commandArr[0], commandArr[1:]...)
} else {
cmd = exec.Command(commandArr[0])
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken))
checksumDecoded, err := hex.DecodeString(checksum)
if err != nil {
return nil, err
}
secureConfig := &plugin.SecureConfig{
Checksum: checksumDecoded,
Hash: sha256.New(),
}
client := plugin.NewClient(&plugin.ClientConfig{
HandshakeConfig: handshakeConfig,
Plugins: pluginMap,
Cmd: cmd,
TLSConfig: clientTLSConfig,
SecureConfig: secureConfig,
})
// Connect via RPC
rpcClient, err := client.Client()
if err != nil {

View File

@ -112,13 +112,7 @@ reduced to the same size.`,
a zero or negative value reuses connections forever.`,
},
"plugin_command": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum amount of time a connection may be reused;
a zero or negative value reuses connections forever.`,
},
"plugin_checksum": &framework.FieldSchema{
"plugin_name": &framework.FieldSchema{
Type: framework.TypeString,
Description: `Maximum amount of time a connection may be reused;
a zero or negative value reuses connections forever.`,
@ -223,8 +217,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework.
MaxOpenConnections: maxOpenConns,
MaxIdleConnections: maxIdleConns,
MaxConnectionLifetime: maxConnLifetime,
PluginCommand: data.Get("plugin_command").(string),
PluginChecksum: data.Get("plugin_checksum").(string),
PluginName: data.Get("plugin_name").(string),
}
name := data.Get("name").(string)

View File

@ -8,6 +8,7 @@ import (
"net/url"
"os"
"os/signal"
"path/filepath"
"runtime"
"sort"
"strconv"
@ -20,6 +21,7 @@ import (
colorable "github.com/mattn/go-colorable"
log "github.com/mgutz/logxi/v1"
homedir "github.com/mitchellh/go-homedir"
"google.golang.org/grpc/grpclog"
@ -237,11 +239,22 @@ func (c *ServerCommand) Run(args []string) int {
DefaultLeaseTTL: config.DefaultLeaseTTL,
ClusterName: config.ClusterName,
CacheSize: config.CacheSize,
PluginDirectory: config.PluginDirectory,
}
if dev {
coreConfig.DevToken = devRootTokenID
}
if config.PluginDirectory == "" {
homePath, err := homedir.Dir()
if err != nil {
c.Ui.Output(fmt.Sprintf(
"Error getting user's home directory: %v", err))
return 1
}
coreConfig.PluginDirectory = filepath.Join(homePath, "/vault-plugins/")
}
var disableClustering bool
// Initialize the separate HA physical backend, if it exists

View File

@ -38,7 +38,8 @@ type Config struct {
DefaultLeaseTTL time.Duration `hcl:"-"`
DefaultLeaseTTLRaw string `hcl:"default_lease_ttl"`
ClusterName string `hcl:"cluster_name"`
ClusterName string `hcl:"cluster_name"`
PluginDirectory string `hcl:"plugin_directory"`
}
// DevConfig is a Config that is used for dev mode of Vault.
@ -339,6 +340,7 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) {
"default_lease_ttl",
"max_lease_ttl",
"cluster_name",
"plugin_directory",
}
if err := checkHCLKeys(list, valid); err != nil {
return nil, err

View File

@ -0,0 +1,61 @@
package pluginutil
import (
"crypto/sha256"
"fmt"
"os/exec"
plugin "github.com/hashicorp/go-plugin"
)
type Looker interface {
LookupPlugin(string) (*PluginRunner, error)
}
type PluginRunner struct {
Name string `json:"name"`
Command string `json:"command"`
Args []string `json:"args"`
Sha256 []byte `json:"sha256"`
}
func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) {
// Get a CA TLS Certificate
CACertBytes, CACert, CAKey, err := GenerateCACert()
if err != nil {
return nil, err
}
// Use CA to sign a client cert and return a configured TLS config
clientTLSConfig, err := CreateClientTLSConfig(CACert, CAKey)
if err != nil {
return nil, err
}
// Use CA to sign a server cert and wrap the values in a response wrapped
// token.
wrapToken, err := WrapServerConfig(wrapper, CACertBytes, CACert, CAKey)
if err != nil {
return nil, err
}
// Add the response wrap token to the ENV of the plugin
cmd := exec.Command(r.Command, r.Args...)
cmd.Env = append(cmd.Env, env...)
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken))
secureConfig := &plugin.SecureConfig{
Checksum: r.Sha256,
Hash: sha256.New(),
}
client := plugin.NewClient(&plugin.ClientConfig{
HandshakeConfig: hs,
Plugins: pluginMap,
Cmd: cmd,
TLSConfig: clientTLSConfig,
SecureConfig: secureConfig,
})
return client, nil
}

View File

@ -5,6 +5,7 @@ import (
"time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pluginutil"
)
// SystemView exposes system configuration information in a safe way
@ -42,6 +43,8 @@ type SystemView interface {
// ResponseWrapData wraps the given data in a cubbyhole and returns the
// token used to unwrap.
ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error)
LookupPlugin(string) (*pluginutil.PluginRunner, error)
}
type StaticSystemView struct {
@ -81,3 +84,7 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState {
func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) {
return "", errors.New("ResponseWrapData is not implimented in StaticSystemView")
}
func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
return nil, errors.New("LookupPlugin is not implimented in StaticSystemView")
}

View File

@ -10,6 +10,7 @@ import (
"net"
"net/http"
"net/url"
"path/filepath"
"sync"
"time"
@ -330,6 +331,12 @@ type Core struct {
// uiEnabled indicates whether Vault Web UI is enabled or not
uiEnabled bool
// pluginDirectory is the location vault will look for plugins
pluginDirectory string
// pluginCatalog is used to manage plugin configurations
pluginCatalog *PluginCatalog
}
// CoreConfig is used to parameterize a core
@ -374,6 +381,8 @@ type CoreConfig struct {
EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"`
PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"`
ReloadFuncs *map[string][]ReloadFunc
ReloadFuncsLock *sync.RWMutex
}
@ -453,8 +462,13 @@ func NewCore(conf *CoreConfig) (*Core, error) {
}
}
// Construct a new AES-GCM barrier
var err error
c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory)
if err != nil {
return nil, fmt.Errorf("core setup failed: %v", err)
}
// Construct a new AES-GCM barrier
c.barrier, err = NewAESGCMBarrier(c.physical)
if err != nil {
return nil, fmt.Errorf("barrier setup failed: %v", err)
@ -1280,6 +1294,10 @@ func (c *Core) postUnseal() (retErr error) {
if err := c.setupAuditedHeadersConfig(); err != nil {
return err
}
if err := c.setupPluginCatalog(); err != nil {
return err
}
if c.ha != nil {
if err := c.startClusterListener(); err != nil {
return err

View File

@ -4,6 +4,7 @@ import (
"time"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
)
@ -114,3 +115,7 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim
return resp.WrapInfo.Token, nil
}
func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) {
return d.core.pluginCatalog.Get(name)
}

View File

@ -63,6 +63,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
"replication/reindex",
"rotate",
"config/auditing/*",
"plugin-catalog",
"plugin-catalog/*",
},
Unauthenticated: []string{
@ -692,6 +694,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen
HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]),
HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]),
},
&framework.Path{
Pattern: "plugin-catalog/(?P<name>.+)",
Fields: map[string]*framework.FieldSchema{
"name": &framework.FieldSchema{
Type: framework.TypeString,
},
"sha_256": &framework.FieldSchema{
Type: framework.TypeString,
},
"command": &framework.FieldSchema{
Type: framework.TypeString,
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.handlePluginCatalogUpdate,
logical.DeleteOperation: b.handlePluginCatalogDelete,
logical.ReadOperation: b.handlePluginCatalogRead,
},
HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]),
HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]),
},
},
}
@ -724,6 +750,69 @@ func (b *SystemBackend) invalidate(key string) {
}
}
func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
pluginName := d.Get("name").(string)
if pluginName == "" {
return logical.ErrorResponse("missing plugin name"), nil
}
sha256 := d.Get("sha_256").(string)
if sha256 == "" {
return logical.ErrorResponse("missing SHA-256 value"), nil
}
command := d.Get("command").(string)
if command == "" {
return logical.ErrorResponse("missing command value"), nil
}
sha256Bytes, err := hex.DecodeString(sha256)
if err != nil {
return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err
}
err = b.Core.pluginCatalog.Set(pluginName, command, sha256Bytes)
if err != nil {
return nil, err
}
return nil, nil
}
func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
pluginName := d.Get("name").(string)
if pluginName == "" {
return logical.ErrorResponse("missing plugin name"), nil
}
plugin, err := b.Core.pluginCatalog.Get(pluginName)
if err != nil {
return nil, err
}
return &logical.Response{
Data: map[string]interface{}{
"plugin": plugin,
},
}, nil
}
func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
pluginName := d.Get("name").(string)
if pluginName == "" {
return logical.ErrorResponse("missing plugin name"), nil
}
plugin, err := b.Core.pluginCatalog.Get(pluginName)
if err != nil {
return nil, err
}
return &logical.Response{
Data: map[string]interface{}{
"plugin": plugin,
},
}, nil
}
// handleAuditedHeaderUpdate creates or overwrites a header entry
func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
header := d.Get("header").(string)

101
vault/plugin_catalog.go Normal file
View File

@ -0,0 +1,101 @@
package vault
import (
"encoding/json"
"errors"
"fmt"
"path/filepath"
"strings"
"sync"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/pluginutil"
"github.com/hashicorp/vault/logical"
)
var (
pluginCatalogPrefix = "plugin-catalog/"
)
type PluginCatalog struct {
catalogView *BarrierView
directory string
lock sync.RWMutex
builtin map[string]*pluginutil.PluginRunner
}
func NewPluginCatalog(view *BarrierView, directory string) *PluginCatalog {
return &PluginCatalog{
catalogView: view.SubView(pluginCatalogPrefix),
directory: directory,
}
}
func (c *Core) setupPluginCatalog() error {
catalog := NewPluginCatalog(c.systemBarrierView, c.pluginDirectory)
c.pluginCatalog = catalog
return nil
}
func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) {
out, err := c.catalogView.Get(name)
if err != nil {
return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err)
}
if out == nil {
return nil, fmt.Errorf("no plugin found with name: %s", name)
}
entry := new(pluginutil.PluginRunner)
if err := jsonutil.DecodeJSON(out.Value, entry); err != nil {
return nil, fmt.Errorf("failed to decode plugin entry: %v", err)
}
return entry, nil
}
func (c *PluginCatalog) Set(name, command string, sha256 []byte) error {
parts := strings.Split(command, " ")
command = parts[0]
args := parts[1:]
command = filepath.Join(c.directory, command)
// Best effort check to make sure the command isn't breaking out of the
// configured plugin directory.
sym, err := filepath.EvalSymlinks(command)
if err != nil {
return fmt.Errorf("error while validating the command path: %v", err)
}
symAbs, err := filepath.Abs(filepath.Dir(sym))
if err != nil {
return fmt.Errorf("error while validating the command path: %v", err)
}
if symAbs != c.directory {
return errors.New("can not execute files outside of configured plugin directory")
}
entry := &pluginutil.PluginRunner{
Name: name,
Command: command,
Args: args,
Sha256: sha256,
}
buf, err := json.Marshal(entry)
if err != nil {
return fmt.Errorf("failed to encode plugin entry: %v", err)
}
logicalEntry := logical.StorageEntry{
Key: name,
Value: buf,
}
if err := c.catalogView.Put(&logicalEntry); err != nil {
return fmt.Errorf("failed to persist plugin entry: %v", err)
}
return nil
}