diff --git a/builtin/logical/database/dbs/db.go b/builtin/logical/database/dbs/db.go index 2637a73d10..8d44a474e4 100644 --- a/builtin/logical/database/dbs/db.go +++ b/builtin/logical/database/dbs/db.go @@ -20,8 +20,7 @@ const ( var ( ErrUnsupportedDatabaseType = errors.New("unsupported database type") ErrEmptyCreationStatement = errors.New("empty creation statements") - ErrEmptyPluginCommand = errors.New("empty plugin command") - ErrEmptyPluginChecksum = errors.New("empty plugin checksum") + ErrEmptyPluginName = errors.New("empty plugin name") ) // Factory function definition @@ -95,18 +94,19 @@ func BuiltinFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Log // PluginFactory is used to build plugin database types. It wraps the database // object in a logging and metrics middleware. func PluginFactory(conf *DatabaseConfig, sys logical.SystemView, logger log.Logger) (DatabaseType, error) { - if conf.PluginCommand == "" { - return nil, ErrEmptyPluginCommand + if conf.PluginName == "" { + return nil, ErrEmptyPluginName } - if conf.PluginChecksum == "" { - return nil, ErrEmptyPluginChecksum + pluginMeta, err := sys.LookupPlugin(conf.PluginName) + if err != nil { + return nil, err } // Make sure the database type is set to plugin conf.DatabaseType = pluginTypeName - db, err := newPluginClient(sys, conf.PluginCommand, conf.PluginChecksum) + db, err := newPluginClient(sys, pluginMeta) if err != nil { return nil, err } @@ -149,8 +149,7 @@ type DatabaseConfig struct { MaxOpenConnections int `json:"max_open_connections" structs:"max_open_connections" mapstructure:"max_open_connections"` MaxIdleConnections int `json:"max_idle_connections" structs:"max_idle_connections" mapstructure:"max_idle_connections"` MaxConnectionLifetime time.Duration `json:"max_connection_lifetime" structs:"max_connection_lifetime" mapstructure:"max_connection_lifetime"` - PluginCommand string `json:"plugin_command" structs:"plugin_command" mapstructure:"plugin_command"` - PluginChecksum string `json:"plugin_checksum" structs:"plugin_checksum" mapstructure:"plugin_checksum"` + PluginName string `json:"plugin_name" structs:"plugin_name" mapstructure:"plugin_name"` } // GetFactory returns the appropriate factory method for the given database diff --git a/builtin/logical/database/dbs/plugin.go b/builtin/logical/database/dbs/plugin.go index 4bac0d16e1..791f3b4651 100644 --- a/builtin/logical/database/dbs/plugin.go +++ b/builtin/logical/database/dbs/plugin.go @@ -1,12 +1,8 @@ package dbs import ( - "crypto/sha256" - "encoding/hex" "fmt" "net/rpc" - "os/exec" - "strings" "sync" "time" @@ -55,59 +51,17 @@ func (dc *DatabasePluginClient) Close() error { // newPluginClient returns a databaseRPCClient with a connection to a running // plugin. The client is wrapped in a DatabasePluginClient object to ensure the // plugin is killed on call of Close(). -func newPluginClient(sys pluginutil.Wrapper, command, checksum string) (DatabaseType, error) { +func newPluginClient(sys pluginutil.Wrapper, pluginRunner *pluginutil.PluginRunner) (DatabaseType, error) { // pluginMap is the map of plugins we can dispense. var pluginMap = map[string]plugin.Plugin{ "database": new(DatabasePlugin), } - // Get a CA TLS Certificate - CACertBytes, CACert, CAKey, err := pluginutil.GenerateCACert() + client, err := pluginRunner.Run(sys, pluginMap, handshakeConfig, []string{}) if err != nil { return nil, err } - // Use CA to sign a client cert and return a configured TLS config - clientTLSConfig, err := pluginutil.CreateClientTLSConfig(CACert, CAKey) - if err != nil { - return nil, err - } - - // Use CA to sign a server cert and wrap the values in a response wrapped - // token. - wrapToken, err := pluginutil.WrapServerConfig(sys, CACertBytes, CACert, CAKey) - if err != nil { - return nil, err - } - - // Add the response wrap token to the ENV of the plugin - commandArr := strings.Split(command, " ") - var cmd *exec.Cmd - if len(commandArr) > 1 { - cmd = exec.Command(commandArr[0], commandArr[1:]...) - } else { - cmd = exec.Command(commandArr[0]) - } - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", pluginutil.PluginUnwrapTokenEnv, wrapToken)) - - checksumDecoded, err := hex.DecodeString(checksum) - if err != nil { - return nil, err - } - - secureConfig := &plugin.SecureConfig{ - Checksum: checksumDecoded, - Hash: sha256.New(), - } - - client := plugin.NewClient(&plugin.ClientConfig{ - HandshakeConfig: handshakeConfig, - Plugins: pluginMap, - Cmd: cmd, - TLSConfig: clientTLSConfig, - SecureConfig: secureConfig, - }) - // Connect via RPC rpcClient, err := client.Client() if err != nil { diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index b4c699750d..a0494d71ed 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -112,13 +112,7 @@ reduced to the same size.`, a zero or negative value reuses connections forever.`, }, - "plugin_command": &framework.FieldSchema{ - Type: framework.TypeString, - Description: `Maximum amount of time a connection may be reused; - a zero or negative value reuses connections forever.`, - }, - - "plugin_checksum": &framework.FieldSchema{ + "plugin_name": &framework.FieldSchema{ Type: framework.TypeString, Description: `Maximum amount of time a connection may be reused; a zero or negative value reuses connections forever.`, @@ -223,8 +217,7 @@ func (b *databaseBackend) connectionWriteHandler(factory dbs.Factory) framework. MaxOpenConnections: maxOpenConns, MaxIdleConnections: maxIdleConns, MaxConnectionLifetime: maxConnLifetime, - PluginCommand: data.Get("plugin_command").(string), - PluginChecksum: data.Get("plugin_checksum").(string), + PluginName: data.Get("plugin_name").(string), } name := data.Get("name").(string) diff --git a/command/server.go b/command/server.go index 09658b9494..d6eb0d76d9 100644 --- a/command/server.go +++ b/command/server.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "os/signal" + "path/filepath" "runtime" "sort" "strconv" @@ -20,6 +21,7 @@ import ( colorable "github.com/mattn/go-colorable" log "github.com/mgutz/logxi/v1" + homedir "github.com/mitchellh/go-homedir" "google.golang.org/grpc/grpclog" @@ -237,11 +239,22 @@ func (c *ServerCommand) Run(args []string) int { DefaultLeaseTTL: config.DefaultLeaseTTL, ClusterName: config.ClusterName, CacheSize: config.CacheSize, + PluginDirectory: config.PluginDirectory, } if dev { coreConfig.DevToken = devRootTokenID } + if config.PluginDirectory == "" { + homePath, err := homedir.Dir() + if err != nil { + c.Ui.Output(fmt.Sprintf( + "Error getting user's home directory: %v", err)) + return 1 + } + coreConfig.PluginDirectory = filepath.Join(homePath, "/vault-plugins/") + } + var disableClustering bool // Initialize the separate HA physical backend, if it exists diff --git a/command/server/config.go b/command/server/config.go index 00edd5de93..a57fdad13b 100644 --- a/command/server/config.go +++ b/command/server/config.go @@ -38,7 +38,8 @@ type Config struct { DefaultLeaseTTL time.Duration `hcl:"-"` DefaultLeaseTTLRaw string `hcl:"default_lease_ttl"` - ClusterName string `hcl:"cluster_name"` + ClusterName string `hcl:"cluster_name"` + PluginDirectory string `hcl:"plugin_directory"` } // DevConfig is a Config that is used for dev mode of Vault. @@ -339,6 +340,7 @@ func ParseConfig(d string, logger log.Logger) (*Config, error) { "default_lease_ttl", "max_lease_ttl", "cluster_name", + "plugin_directory", } if err := checkHCLKeys(list, valid); err != nil { return nil, err diff --git a/helper/pluginutil/runner.go b/helper/pluginutil/runner.go new file mode 100644 index 0000000000..143a4c8391 --- /dev/null +++ b/helper/pluginutil/runner.go @@ -0,0 +1,61 @@ +package pluginutil + +import ( + "crypto/sha256" + "fmt" + "os/exec" + + plugin "github.com/hashicorp/go-plugin" +) + +type Looker interface { + LookupPlugin(string) (*PluginRunner, error) +} + +type PluginRunner struct { + Name string `json:"name"` + Command string `json:"command"` + Args []string `json:"args"` + Sha256 []byte `json:"sha256"` +} + +func (r *PluginRunner) Run(wrapper Wrapper, pluginMap map[string]plugin.Plugin, hs plugin.HandshakeConfig, env []string) (*plugin.Client, error) { + // Get a CA TLS Certificate + CACertBytes, CACert, CAKey, err := GenerateCACert() + if err != nil { + return nil, err + } + + // Use CA to sign a client cert and return a configured TLS config + clientTLSConfig, err := CreateClientTLSConfig(CACert, CAKey) + if err != nil { + return nil, err + } + + // Use CA to sign a server cert and wrap the values in a response wrapped + // token. + wrapToken, err := WrapServerConfig(wrapper, CACertBytes, CACert, CAKey) + if err != nil { + return nil, err + } + + // Add the response wrap token to the ENV of the plugin + cmd := exec.Command(r.Command, r.Args...) + cmd.Env = append(cmd.Env, env...) + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", PluginUnwrapTokenEnv, wrapToken)) + + secureConfig := &plugin.SecureConfig{ + Checksum: r.Sha256, + Hash: sha256.New(), + } + + client := plugin.NewClient(&plugin.ClientConfig{ + HandshakeConfig: hs, + Plugins: pluginMap, + Cmd: cmd, + TLSConfig: clientTLSConfig, + SecureConfig: secureConfig, + }) + + return client, nil +} diff --git a/logical/system_view.go b/logical/system_view.go index 56254b33a1..a9626bc50e 100644 --- a/logical/system_view.go +++ b/logical/system_view.go @@ -5,6 +5,7 @@ import ( "time" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/pluginutil" ) // SystemView exposes system configuration information in a safe way @@ -42,6 +43,8 @@ type SystemView interface { // ResponseWrapData wraps the given data in a cubbyhole and returns the // token used to unwrap. ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) + + LookupPlugin(string) (*pluginutil.PluginRunner, error) } type StaticSystemView struct { @@ -81,3 +84,7 @@ func (d StaticSystemView) ReplicationState() consts.ReplicationState { func (d StaticSystemView) ResponseWrapData(data map[string]interface{}, ttl time.Duration, jwt bool) (string, error) { return "", errors.New("ResponseWrapData is not implimented in StaticSystemView") } + +func (d StaticSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + return nil, errors.New("LookupPlugin is not implimented in StaticSystemView") +} diff --git a/vault/core.go b/vault/core.go index ea378fa8ad..08a828643a 100644 --- a/vault/core.go +++ b/vault/core.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/url" + "path/filepath" "sync" "time" @@ -330,6 +331,12 @@ type Core struct { // uiEnabled indicates whether Vault Web UI is enabled or not uiEnabled bool + + // pluginDirectory is the location vault will look for plugins + pluginDirectory string + + // pluginCatalog is used to manage plugin configurations + pluginCatalog *PluginCatalog } // CoreConfig is used to parameterize a core @@ -374,6 +381,8 @@ type CoreConfig struct { EnableUI bool `json:"ui" structs:"ui" mapstructure:"ui"` + PluginDirectory string `json:"plugin_directory" structs:"plugin_directory" mapstructure:"plugin_directory"` + ReloadFuncs *map[string][]ReloadFunc ReloadFuncsLock *sync.RWMutex } @@ -453,8 +462,13 @@ func NewCore(conf *CoreConfig) (*Core, error) { } } - // Construct a new AES-GCM barrier var err error + c.pluginDirectory, err = filepath.Abs(conf.PluginDirectory) + if err != nil { + return nil, fmt.Errorf("core setup failed: %v", err) + } + + // Construct a new AES-GCM barrier c.barrier, err = NewAESGCMBarrier(c.physical) if err != nil { return nil, fmt.Errorf("barrier setup failed: %v", err) @@ -1280,6 +1294,10 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupAuditedHeadersConfig(); err != nil { return err } + if err := c.setupPluginCatalog(); err != nil { + return err + } + if c.ha != nil { if err := c.startClusterListener(); err != nil { return err diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 4c6807ace9..f318f3ab13 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -4,6 +4,7 @@ import ( "time" "github.com/hashicorp/vault/helper/consts" + "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" ) @@ -114,3 +115,7 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim return resp.WrapInfo.Token, nil } + +func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + return d.core.pluginCatalog.Get(name) +} diff --git a/vault/logical_system.go b/vault/logical_system.go index 1c439506ca..f5dbe2affa 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -63,6 +63,8 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen "replication/reindex", "rotate", "config/auditing/*", + "plugin-catalog", + "plugin-catalog/*", }, Unauthenticated: []string{ @@ -692,6 +694,30 @@ func NewSystemBackend(core *Core, config *logical.BackendConfig) (logical.Backen HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers"][0]), HelpDescription: strings.TrimSpace(sysHelp["audited-headers"][1]), }, + &framework.Path{ + Pattern: "plugin-catalog/(?P.+)", + + Fields: map[string]*framework.FieldSchema{ + "name": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "sha_256": &framework.FieldSchema{ + Type: framework.TypeString, + }, + "command": &framework.FieldSchema{ + Type: framework.TypeString, + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.handlePluginCatalogUpdate, + logical.DeleteOperation: b.handlePluginCatalogDelete, + logical.ReadOperation: b.handlePluginCatalogRead, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["audited-headers-name"][0]), + HelpDescription: strings.TrimSpace(sysHelp["audited-headers-name"][1]), + }, }, } @@ -724,6 +750,69 @@ func (b *SystemBackend) invalidate(key string) { } } +func (b *SystemBackend) handlePluginCatalogUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + + sha256 := d.Get("sha_256").(string) + if sha256 == "" { + return logical.ErrorResponse("missing SHA-256 value"), nil + } + + command := d.Get("command").(string) + if command == "" { + return logical.ErrorResponse("missing command value"), nil + } + + sha256Bytes, err := hex.DecodeString(sha256) + if err != nil { + return logical.ErrorResponse("Could not decode SHA-256 value from Hex"), err + } + + err = b.Core.pluginCatalog.Set(pluginName, command, sha256Bytes) + if err != nil { + return nil, err + } + + return nil, nil +} + +func (b *SystemBackend) handlePluginCatalogRead(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + plugin, err := b.Core.pluginCatalog.Get(pluginName) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "plugin": plugin, + }, + }, nil +} + +func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("name").(string) + if pluginName == "" { + return logical.ErrorResponse("missing plugin name"), nil + } + plugin, err := b.Core.pluginCatalog.Get(pluginName) + if err != nil { + return nil, err + } + + return &logical.Response{ + Data: map[string]interface{}{ + "plugin": plugin, + }, + }, nil +} + // handleAuditedHeaderUpdate creates or overwrites a header entry func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { header := d.Get("header").(string) diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go new file mode 100644 index 0000000000..c1f504d2c2 --- /dev/null +++ b/vault/plugin_catalog.go @@ -0,0 +1,101 @@ +package vault + +import ( + "encoding/json" + "errors" + "fmt" + "path/filepath" + "strings" + "sync" + + "github.com/hashicorp/vault/helper/jsonutil" + "github.com/hashicorp/vault/helper/pluginutil" + "github.com/hashicorp/vault/logical" +) + +var ( + pluginCatalogPrefix = "plugin-catalog/" +) + +type PluginCatalog struct { + catalogView *BarrierView + directory string + + lock sync.RWMutex + builtin map[string]*pluginutil.PluginRunner +} + +func NewPluginCatalog(view *BarrierView, directory string) *PluginCatalog { + return &PluginCatalog{ + catalogView: view.SubView(pluginCatalogPrefix), + directory: directory, + } +} + +func (c *Core) setupPluginCatalog() error { + catalog := NewPluginCatalog(c.systemBarrierView, c.pluginDirectory) + c.pluginCatalog = catalog + + return nil +} + +func (c *PluginCatalog) Get(name string) (*pluginutil.PluginRunner, error) { + out, err := c.catalogView.Get(name) + if err != nil { + return nil, fmt.Errorf("failed to retrieve plugin \"%s\": %v", name, err) + } + if out == nil { + return nil, fmt.Errorf("no plugin found with name: %s", name) + } + + entry := new(pluginutil.PluginRunner) + if err := jsonutil.DecodeJSON(out.Value, entry); err != nil { + return nil, fmt.Errorf("failed to decode plugin entry: %v", err) + } + + return entry, nil +} + +func (c *PluginCatalog) Set(name, command string, sha256 []byte) error { + parts := strings.Split(command, " ") + command = parts[0] + args := parts[1:] + + command = filepath.Join(c.directory, command) + + // Best effort check to make sure the command isn't breaking out of the + // configured plugin directory. + sym, err := filepath.EvalSymlinks(command) + if err != nil { + return fmt.Errorf("error while validating the command path: %v", err) + } + symAbs, err := filepath.Abs(filepath.Dir(sym)) + if err != nil { + return fmt.Errorf("error while validating the command path: %v", err) + } + + if symAbs != c.directory { + return errors.New("can not execute files outside of configured plugin directory") + } + + entry := &pluginutil.PluginRunner{ + Name: name, + Command: command, + Args: args, + Sha256: sha256, + } + + buf, err := json.Marshal(entry) + if err != nil { + return fmt.Errorf("failed to encode plugin entry: %v", err) + } + + logicalEntry := logical.StorageEntry{ + Key: name, + Value: buf, + } + if err := c.catalogView.Put(&logicalEntry); err != nil { + return fmt.Errorf("failed to persist plugin entry: %v", err) + } + return nil +}