Add plugin backend reload capability (#3112)

* Add plugin reload capability on all mounts for a specific plugin type

* Comments cleanup

* Add per-mount plugin backend reload, add tests

* Fix typos

* Remove old comment

* Reuse existing storage view in reloadPluginCommon

* Correctly handle reloading auth plugin backends

* Update path to plugin/backend/reload

* Use multierrors on reloadMatchingPluginMounts, attempt to reload all mounts provided

* Use internal value as check to ensure plugin backend reload

* Remove connection state from request for plugins at the moment

* Minor cleanup

* Refactor tests
This commit is contained in:
Calvin Leung Huang 2017-08-08 00:18:59 -04:00 committed by GitHub
parent f2f0082ba5
commit 01d1c20e4c
7 changed files with 371 additions and 13 deletions

View File

@ -92,6 +92,14 @@ func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Resp
}
var reply HandleRequestReply
if req.Connection != nil {
oldConnState := req.Connection.ConnState
req.Connection.ConnState = nil
defer func() {
req.Connection.ConnState = oldConnState
}()
}
err := b.client.Call("Plugin.HandleRequest", args, &reply)
if err != nil {
return nil, err
@ -137,6 +145,14 @@ func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool,
}
var reply HandleExistenceCheckReply
if req.Connection != nil {
oldConnState := req.Connection.ConnState
req.Connection.ConnState = nil
defer func() {
req.Connection.ConnState = oldConnState
}()
}
err := b.client.Call("Plugin.HandleExistenceCheck", args, &reply)
if err != nil {
return false, false, err

View File

@ -10,14 +10,25 @@ import (
func pathInternal(b *backend) *framework.Path {
return &framework.Path{
Pattern: "internal",
Fields: map[string]*framework.FieldSchema{},
ExistenceCheck: b.pathExistenceCheck,
Fields: map[string]*framework.FieldSchema{
"value": &framework.FieldSchema{Type: framework.TypeString},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.pathInternalUpdate,
logical.ReadOperation: b.pathInternalRead,
},
}
}
func (b *backend) pathInternalUpdate(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
value := data.Get("value").(string)
b.internal = value
// Return the secret
return nil, nil
}
func (b *backend) pathInternalRead(
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
// Return the secret

View File

@ -825,6 +825,27 @@ func NewSystemBackend(core *Core) *SystemBackend {
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]),
HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]),
},
&framework.Path{
Pattern: "plugins/backend/reload$",
Fields: map[string]*framework.FieldSchema{
"plugin": &framework.FieldSchema{
Type: framework.TypeString,
Description: strings.TrimSpace(sysHelp["plugin-backend-reload-plugin"][0]),
},
"mounts": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: strings.TrimSpace(sysHelp["plugin-backend-reload-mounts"][0]),
},
},
Callbacks: map[logical.Operation]framework.OperationFunc{
logical.UpdateOperation: b.handlePluginReloadUpdate,
},
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-reload"][0]),
HelpDescription: strings.TrimSpace(sysHelp["plugin-reload"][1]),
},
},
}
@ -975,6 +996,32 @@ func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *frame
return nil, nil
}
func (b *SystemBackend) handlePluginReloadUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
pluginName := d.Get("plugin").(string)
pluginMounts := d.Get("mounts").([]string)
if pluginName != "" && len(pluginMounts) > 0 {
return logical.ErrorResponse("plugin and mounts cannot be set at the same time"), nil
}
if pluginName == "" && len(pluginMounts) == 0 {
return logical.ErrorResponse("plugin or mounts must be provided"), nil
}
if pluginName != "" {
err := b.Core.reloadMatchingPlugin(pluginName)
if err != nil {
return nil, err
}
} else if len(pluginMounts) > 0 {
err := b.Core.reloadMatchingPluginMounts(pluginMounts)
if err != nil {
return nil, err
}
}
return nil, nil
}
// handleAuditedHeaderUpdate creates or overwrites a header entry
func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
header := d.Get("header").(string)
@ -2855,4 +2902,19 @@ This path responds to the following HTTP methods.
`The path to list leases under. Example: "aws/creds/deploy"`,
"",
},
"plugin-reload": {
"Reload mounts that use a particular backend plugin.",
`Reload mounts that use a particular backend plugin. Either the plugin name
or the desired plugin backend mounts must be provided, but not both. In the
case that the plugin name is provided, all mounted paths that use that plugin
backend will be reloaded.`,
},
"plugin-backend-reload-plugin": {
`The name of the plugin to reload, as registered in the plugin catalog.`,
"",
},
"plugin-backend-reload-mounts": {
`The mount paths of the plugin backends to reload.`,
"",
},
}

View File

@ -1,6 +1,7 @@
package vault_test
import (
"fmt"
"os"
"testing"
"time"
@ -28,11 +29,10 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
})
cluster.Start()
defer cluster.Cleanup()
cores := cluster.Cores
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)
core := cores[0]
b := vault.NewSystemBackend(core.Core)
b := vault.NewSystemBackend(core)
logger := logformat.NewVaultLogger(log.LevelTrace)
bc := &logical.BackendConfig{
Logger: logger,
@ -49,7 +49,7 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMain")
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainCredentials")
req := logical.TestRequest(t, logical.UpdateOperation, "auth/mock-plugin")
req.Data["type"] = "plugin"
@ -64,7 +64,151 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
}
}
func TestBackend_PluginMain(t *testing.T) {
func TestSystemBackend_PluginReload(t *testing.T) {
data := map[string]interface{}{
"plugin": "mock-plugin",
}
t.Run("plugin", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
data = map[string]interface{}{
"mounts": "mock-0/,mock-1/",
}
t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
}
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
cluster, b := testSystemBackendMock(t, 2)
defer cluster.Cleanup()
core := cluster.Cores[0]
for i := 0; i < 2; i++ {
// Update internal value in the backend
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mock-%d/internal", i))
req.ClientToken = core.Client.Token()
req.Data["value"] = "baz"
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
}
// Perform plugin reload
req := logical.TestRequest(t, logical.UpdateOperation, "plugins/backend/reload")
req.ClientToken = core.Client.Token()
req.Data = reqData
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
for i := 0; i < 2; i++ {
// Ensure internal backed value is reset
req := logical.TestRequest(t, logical.ReadOperation, "mock-1/internal")
req.ClientToken = core.Client.Token()
resp, err := core.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp == nil {
t.Fatalf("bad: response should not be nil")
}
if resp.Data["value"].(string) == "baz" {
t.Fatal("did not expect backend internal value to be 'baz'")
}
}
}
// testSystemBackendMock returns a systemBackend with the desired number
// of mounted mock plugin backends
func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *vault.SystemBackend) {
coreConfig := &vault.CoreConfig{
LogicalBackends: map[string]logical.Factory{
"plugin": plugin.Factory,
},
}
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
HandlerFunc: vaulthttp.Handler,
})
cluster.Start()
core := cluster.Cores[0].Core
vault.TestWaitActive(t, core)
b := vault.NewSystemBackend(core)
logger := logformat.NewVaultLogger(log.LevelTrace)
bc := &logical.BackendConfig{
Logger: logger,
System: logical.StaticSystemView{
DefaultLeaseTTLVal: time.Hour * 24,
MaxLeaseTTLVal: time.Hour * 24 * 32,
},
}
err := b.Backend.Setup(bc)
if err != nil {
t.Fatal(err)
}
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainLogical")
for i := 0; i < numMounts; i++ {
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mounts/mock-%d/", i))
req.Data["type"] = "plugin"
req.Data["config"] = map[string]interface{}{
"plugin_name": "mock-plugin",
}
resp, err := b.HandleRequest(req)
if err != nil {
t.Fatalf("err: %v", err)
}
if resp != nil {
t.Fatalf("bad: %v", resp)
}
}
return cluster, b
}
func TestBackend_PluginMainLogical(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
if caPEM == "" {
t.Fatal("CA cert not passed in")
}
factoryFunc := mock.FactoryType(logical.TypeLogical)
args := []string{"--ca-cert=" + caPEM}
apiClientMeta := &pluginutil.APIClientMeta{}
flags := apiClientMeta.FlagSet()
flags.Parse(args)
tlsConfig := apiClientMeta.GetTLSConfig()
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
err := lplugin.Serve(&lplugin.ServeOpts{
BackendFactoryFunc: factoryFunc,
TLSProviderFunc: tlsProviderFunc,
})
if err != nil {
t.Fatal(err)
}
}
func TestBackend_PluginMainCredentials(t *testing.T) {
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
return
}

View File

@ -230,7 +230,6 @@ func (c *Core) mount(entry *MountEntry) error {
conf["plugin_name"] = entry.Config.PluginName
}
// Consider having plugin name under entry.Options
backend, err := c.newLogicalBackend(entry.Type, sysView, view, conf)
if err != nil {
return err

125
vault/plugin_reload.go Normal file
View File

@ -0,0 +1,125 @@
package vault
import (
"fmt"
"strings"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/logical"
)
// reloadPluginMounts reloads provided mounts, regardless of
// plugin name, as long as the backend type is plugin.
func (c *Core) reloadMatchingPluginMounts(mounts []string) error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
var errors error
for _, mount := range mounts {
entry := c.router.MatchingMountEntry(mount)
if entry == nil {
errors = multierror.Append(errors, fmt.Errorf("cannot fetch mount entry on %s", mount))
continue
// return fmt.Errorf("cannot fetch mount entry on %s", mount)
}
var isAuth bool
fullPath := c.router.MatchingMount(mount)
if strings.HasPrefix(fullPath, credentialRoutePrefix) {
isAuth = true
}
if entry.Type == "plugin" {
err := c.reloadPluginCommon(entry, isAuth)
if err != nil {
errors = multierror.Append(errors, fmt.Errorf("cannot reload plugin on %s: %v", mount, err))
continue
}
c.logger.Info("core: successfully reloaded plugin", "plugin", entry.Config.PluginName, "path", entry.Path)
}
}
return errors
}
// reloadPlugin reloads all mounted backends that are of
// plugin pluginName (name of the plugin as registered in
// the plugin catalog).
func (c *Core) reloadMatchingPlugin(pluginName string) error {
c.mountsLock.Lock()
defer c.mountsLock.Unlock()
// Filter mount entries that only matches the plugin name
for _, entry := range c.mounts.Entries {
if entry.Config.PluginName == pluginName && entry.Type == "plugin" {
err := c.reloadPluginCommon(entry, false)
if err != nil {
return err
}
c.logger.Info("core: successfully reloaded plugin", "plugin", pluginName, "path", entry.Path)
}
}
// Filter auth mount entries that ony matches the plugin name
for _, entry := range c.auth.Entries {
if entry.Config.PluginName == pluginName && entry.Type == "plugin" {
err := c.reloadPluginCommon(entry, true)
if err != nil {
return err
}
c.logger.Info("core: successfully reloaded plugin", "plugin", pluginName, "path", entry.Path)
}
}
return nil
}
// reloadPluginCommon is a generic method to reload a backend provided a
// MountEntry. entry.Type should be checked by the caller to ensure that
// it's a "plugin" type.
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
path := entry.Path
// Fast-path out if the backend doesn't exist
raw, ok := c.router.root.Get(path)
if !ok {
return nil
}
// Call backend's Cleanup routine
re := raw.(*routeEntry)
re.backend.Cleanup()
view := re.storageView
sysView := c.mountEntrySysView(entry)
conf := make(map[string]string)
if entry.Config.PluginName != "" {
conf["plugin_name"] = entry.Config.PluginName
}
var backend logical.Backend
var err error
if !isAuth {
// Dispense a new backend
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
} else {
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
}
if err != nil {
return err
}
if backend == nil {
return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
}
// Call initialize; this takes care of init tasks that must be run after
// the ignore paths are collected.
if err := backend.Initialize(); err != nil {
return err
}
// Set the backend back
re.backend = backend
return nil
}

View File

@ -178,6 +178,7 @@ func (r *Router) MatchingMountByUUID(mountID string) *MountEntry {
return raw.(*MountEntry)
}
// MatchingMountByAccessor returns the MountEntry by accessor lookup
func (r *Router) MatchingMountByAccessor(mountAccessor string) *MountEntry {
if mountAccessor == "" {
return nil
@ -205,7 +206,7 @@ func (r *Router) MatchingMount(path string) string {
return mount
}
// MatchingView returns the view used for a path
// MatchingStorageView returns the storageView used for a path
func (r *Router) MatchingStorageView(path string) *BarrierView {
r.l.RLock()
_, raw, ok := r.root.LongestPrefix(path)
@ -227,7 +228,7 @@ func (r *Router) MatchingMountEntry(path string) *MountEntry {
return raw.(*routeEntry).mountEntry
}
// MatchingMountEntry returns the MountEntry used for a path
// MatchingBackend returns the backend used for a path
func (r *Router) MatchingBackend(path string) logical.Backend {
r.l.RLock()
_, raw, ok := r.root.LongestPrefix(path)