mirror of
https://github.com/hashicorp/vault.git
synced 2025-11-23 19:51:09 +01:00
Add plugin backend reload capability (#3112)
* Add plugin reload capability on all mounts for a specific plugin type * Comments cleanup * Add per-mount plugin backend reload, add tests * Fix typos * Remove old comment * Reuse existing storage view in reloadPluginCommon * Correctly handle reloading auth plugin backends * Update path to plugin/backend/reload * Use multierrors on reloadMatchingPluginMounts, attempt to reload all mounts provided * Use internal value as check to ensure plugin backend reload * Remove connection state from request for plugins at the moment * Minor cleanup * Refactor tests
This commit is contained in:
parent
f2f0082ba5
commit
01d1c20e4c
@ -92,6 +92,14 @@ func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Resp
|
||||
}
|
||||
var reply HandleRequestReply
|
||||
|
||||
if req.Connection != nil {
|
||||
oldConnState := req.Connection.ConnState
|
||||
req.Connection.ConnState = nil
|
||||
defer func() {
|
||||
req.Connection.ConnState = oldConnState
|
||||
}()
|
||||
}
|
||||
|
||||
err := b.client.Call("Plugin.HandleRequest", args, &reply)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -137,6 +145,14 @@ func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool,
|
||||
}
|
||||
var reply HandleExistenceCheckReply
|
||||
|
||||
if req.Connection != nil {
|
||||
oldConnState := req.Connection.ConnState
|
||||
req.Connection.ConnState = nil
|
||||
defer func() {
|
||||
req.Connection.ConnState = oldConnState
|
||||
}()
|
||||
}
|
||||
|
||||
err := b.client.Call("Plugin.HandleExistenceCheck", args, &reply)
|
||||
if err != nil {
|
||||
return false, false, err
|
||||
|
||||
@ -9,15 +9,26 @@ import (
|
||||
// it is used to test the invalidate func.
|
||||
func pathInternal(b *backend) *framework.Path {
|
||||
return &framework.Path{
|
||||
Pattern: "internal",
|
||||
Fields: map[string]*framework.FieldSchema{},
|
||||
ExistenceCheck: b.pathExistenceCheck,
|
||||
Pattern: "internal",
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"value": &framework.FieldSchema{Type: framework.TypeString},
|
||||
},
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.ReadOperation: b.pathInternalRead,
|
||||
logical.UpdateOperation: b.pathInternalUpdate,
|
||||
logical.ReadOperation: b.pathInternalRead,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *backend) pathInternalUpdate(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
value := data.Get("value").(string)
|
||||
b.internal = value
|
||||
// Return the secret
|
||||
return nil, nil
|
||||
|
||||
}
|
||||
|
||||
func (b *backend) pathInternalRead(
|
||||
req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
|
||||
// Return the secret
|
||||
|
||||
@ -825,6 +825,27 @@ func NewSystemBackend(core *Core) *SystemBackend {
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]),
|
||||
},
|
||||
&framework.Path{
|
||||
Pattern: "plugins/backend/reload$",
|
||||
|
||||
Fields: map[string]*framework.FieldSchema{
|
||||
"plugin": &framework.FieldSchema{
|
||||
Type: framework.TypeString,
|
||||
Description: strings.TrimSpace(sysHelp["plugin-backend-reload-plugin"][0]),
|
||||
},
|
||||
"mounts": &framework.FieldSchema{
|
||||
Type: framework.TypeCommaStringSlice,
|
||||
Description: strings.TrimSpace(sysHelp["plugin-backend-reload-mounts"][0]),
|
||||
},
|
||||
},
|
||||
|
||||
Callbacks: map[logical.Operation]framework.OperationFunc{
|
||||
logical.UpdateOperation: b.handlePluginReloadUpdate,
|
||||
},
|
||||
|
||||
HelpSynopsis: strings.TrimSpace(sysHelp["plugin-reload"][0]),
|
||||
HelpDescription: strings.TrimSpace(sysHelp["plugin-reload"][1]),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -975,6 +996,32 @@ func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *frame
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (b *SystemBackend) handlePluginReloadUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
pluginName := d.Get("plugin").(string)
|
||||
pluginMounts := d.Get("mounts").([]string)
|
||||
|
||||
if pluginName != "" && len(pluginMounts) > 0 {
|
||||
return logical.ErrorResponse("plugin and mounts cannot be set at the same time"), nil
|
||||
}
|
||||
if pluginName == "" && len(pluginMounts) == 0 {
|
||||
return logical.ErrorResponse("plugin or mounts must be provided"), nil
|
||||
}
|
||||
|
||||
if pluginName != "" {
|
||||
err := b.Core.reloadMatchingPlugin(pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if len(pluginMounts) > 0 {
|
||||
err := b.Core.reloadMatchingPluginMounts(pluginMounts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// handleAuditedHeaderUpdate creates or overwrites a header entry
|
||||
func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
|
||||
header := d.Get("header").(string)
|
||||
@ -2855,4 +2902,19 @@ This path responds to the following HTTP methods.
|
||||
`The path to list leases under. Example: "aws/creds/deploy"`,
|
||||
"",
|
||||
},
|
||||
"plugin-reload": {
|
||||
"Reload mounts that use a particular backend plugin.",
|
||||
`Reload mounts that use a particular backend plugin. Either the plugin name
|
||||
or the desired plugin backend mounts must be provided, but not both. In the
|
||||
case that the plugin name is provided, all mounted paths that use that plugin
|
||||
backend will be reloaded.`,
|
||||
},
|
||||
"plugin-backend-reload-plugin": {
|
||||
`The name of the plugin to reload, as registered in the plugin catalog.`,
|
||||
"",
|
||||
},
|
||||
"plugin-backend-reload-mounts": {
|
||||
`The mount paths of the plugin backends to reload.`,
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package vault_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@ -28,11 +29,10 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
|
||||
})
|
||||
cluster.Start()
|
||||
defer cluster.Cleanup()
|
||||
cores := cluster.Cores
|
||||
core := cluster.Cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
core := cores[0]
|
||||
|
||||
b := vault.NewSystemBackend(core.Core)
|
||||
b := vault.NewSystemBackend(core)
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
bc := &logical.BackendConfig{
|
||||
Logger: logger,
|
||||
@ -49,7 +49,7 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
|
||||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMain")
|
||||
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainCredentials")
|
||||
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, "auth/mock-plugin")
|
||||
req.Data["type"] = "plugin"
|
||||
@ -64,7 +64,151 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMain(t *testing.T) {
|
||||
func TestSystemBackend_PluginReload(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"plugin": "mock-plugin",
|
||||
}
|
||||
t.Run("plugin", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
|
||||
|
||||
data = map[string]interface{}{
|
||||
"mounts": "mock-0/,mock-1/",
|
||||
}
|
||||
t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) })
|
||||
}
|
||||
|
||||
func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) {
|
||||
cluster, b := testSystemBackendMock(t, 2)
|
||||
defer cluster.Cleanup()
|
||||
|
||||
core := cluster.Cores[0]
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
// Update internal value in the backend
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mock-%d/internal", i))
|
||||
req.ClientToken = core.Client.Token()
|
||||
req.Data["value"] = "baz"
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// Perform plugin reload
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, "plugins/backend/reload")
|
||||
req.ClientToken = core.Client.Token()
|
||||
req.Data = reqData
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
// Ensure internal backed value is reset
|
||||
req := logical.TestRequest(t, logical.ReadOperation, "mock-1/internal")
|
||||
req.ClientToken = core.Client.Token()
|
||||
resp, err := core.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatalf("bad: response should not be nil")
|
||||
}
|
||||
if resp.Data["value"].(string) == "baz" {
|
||||
t.Fatal("did not expect backend internal value to be 'baz'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testSystemBackendMock returns a systemBackend with the desired number
|
||||
// of mounted mock plugin backends
|
||||
func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *vault.SystemBackend) {
|
||||
coreConfig := &vault.CoreConfig{
|
||||
LogicalBackends: map[string]logical.Factory{
|
||||
"plugin": plugin.Factory,
|
||||
},
|
||||
}
|
||||
|
||||
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
|
||||
HandlerFunc: vaulthttp.Handler,
|
||||
})
|
||||
cluster.Start()
|
||||
|
||||
core := cluster.Cores[0].Core
|
||||
vault.TestWaitActive(t, core)
|
||||
|
||||
b := vault.NewSystemBackend(core)
|
||||
logger := logformat.NewVaultLogger(log.LevelTrace)
|
||||
bc := &logical.BackendConfig{
|
||||
Logger: logger,
|
||||
System: logical.StaticSystemView{
|
||||
DefaultLeaseTTLVal: time.Hour * 24,
|
||||
MaxLeaseTTLVal: time.Hour * 24 * 32,
|
||||
},
|
||||
}
|
||||
|
||||
err := b.Backend.Setup(bc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile)
|
||||
|
||||
vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainLogical")
|
||||
|
||||
for i := 0; i < numMounts; i++ {
|
||||
req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mounts/mock-%d/", i))
|
||||
req.Data["type"] = "plugin"
|
||||
req.Data["config"] = map[string]interface{}{
|
||||
"plugin_name": "mock-plugin",
|
||||
}
|
||||
|
||||
resp, err := b.HandleRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("bad: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
return cluster, b
|
||||
}
|
||||
|
||||
func TestBackend_PluginMainLogical(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv)
|
||||
if caPEM == "" {
|
||||
t.Fatal("CA cert not passed in")
|
||||
}
|
||||
|
||||
factoryFunc := mock.FactoryType(logical.TypeLogical)
|
||||
|
||||
args := []string{"--ca-cert=" + caPEM}
|
||||
|
||||
apiClientMeta := &pluginutil.APIClientMeta{}
|
||||
flags := apiClientMeta.FlagSet()
|
||||
flags.Parse(args)
|
||||
tlsConfig := apiClientMeta.GetTLSConfig()
|
||||
tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig)
|
||||
err := lplugin.Serve(&lplugin.ServeOpts{
|
||||
BackendFactoryFunc: factoryFunc,
|
||||
TLSProviderFunc: tlsProviderFunc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackend_PluginMainCredentials(t *testing.T) {
|
||||
if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
@ -230,7 +230,6 @@ func (c *Core) mount(entry *MountEntry) error {
|
||||
conf["plugin_name"] = entry.Config.PluginName
|
||||
}
|
||||
|
||||
// Consider having plugin name under entry.Options
|
||||
backend, err := c.newLogicalBackend(entry.Type, sysView, view, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
125
vault/plugin_reload.go
Normal file
125
vault/plugin_reload.go
Normal file
@ -0,0 +1,125 @@
|
||||
package vault
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/hashicorp/vault/logical"
|
||||
)
|
||||
|
||||
// reloadPluginMounts reloads provided mounts, regardless of
|
||||
// plugin name, as long as the backend type is plugin.
|
||||
func (c *Core) reloadMatchingPluginMounts(mounts []string) error {
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
var errors error
|
||||
for _, mount := range mounts {
|
||||
entry := c.router.MatchingMountEntry(mount)
|
||||
if entry == nil {
|
||||
errors = multierror.Append(errors, fmt.Errorf("cannot fetch mount entry on %s", mount))
|
||||
continue
|
||||
// return fmt.Errorf("cannot fetch mount entry on %s", mount)
|
||||
}
|
||||
|
||||
var isAuth bool
|
||||
fullPath := c.router.MatchingMount(mount)
|
||||
if strings.HasPrefix(fullPath, credentialRoutePrefix) {
|
||||
isAuth = true
|
||||
}
|
||||
|
||||
if entry.Type == "plugin" {
|
||||
err := c.reloadPluginCommon(entry, isAuth)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, fmt.Errorf("cannot reload plugin on %s: %v", mount, err))
|
||||
continue
|
||||
}
|
||||
c.logger.Info("core: successfully reloaded plugin", "plugin", entry.Config.PluginName, "path", entry.Path)
|
||||
}
|
||||
}
|
||||
return errors
|
||||
}
|
||||
|
||||
// reloadPlugin reloads all mounted backends that are of
|
||||
// plugin pluginName (name of the plugin as registered in
|
||||
// the plugin catalog).
|
||||
func (c *Core) reloadMatchingPlugin(pluginName string) error {
|
||||
c.mountsLock.Lock()
|
||||
defer c.mountsLock.Unlock()
|
||||
|
||||
// Filter mount entries that only matches the plugin name
|
||||
for _, entry := range c.mounts.Entries {
|
||||
if entry.Config.PluginName == pluginName && entry.Type == "plugin" {
|
||||
err := c.reloadPluginCommon(entry, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Info("core: successfully reloaded plugin", "plugin", pluginName, "path", entry.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Filter auth mount entries that ony matches the plugin name
|
||||
for _, entry := range c.auth.Entries {
|
||||
if entry.Config.PluginName == pluginName && entry.Type == "plugin" {
|
||||
err := c.reloadPluginCommon(entry, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.logger.Info("core: successfully reloaded plugin", "plugin", pluginName, "path", entry.Path)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadPluginCommon is a generic method to reload a backend provided a
|
||||
// MountEntry. entry.Type should be checked by the caller to ensure that
|
||||
// it's a "plugin" type.
|
||||
func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error {
|
||||
path := entry.Path
|
||||
|
||||
// Fast-path out if the backend doesn't exist
|
||||
raw, ok := c.router.root.Get(path)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call backend's Cleanup routine
|
||||
re := raw.(*routeEntry)
|
||||
re.backend.Cleanup()
|
||||
|
||||
view := re.storageView
|
||||
|
||||
sysView := c.mountEntrySysView(entry)
|
||||
conf := make(map[string]string)
|
||||
if entry.Config.PluginName != "" {
|
||||
conf["plugin_name"] = entry.Config.PluginName
|
||||
}
|
||||
|
||||
var backend logical.Backend
|
||||
var err error
|
||||
if !isAuth {
|
||||
// Dispense a new backend
|
||||
backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf)
|
||||
} else {
|
||||
backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if backend == nil {
|
||||
return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type)
|
||||
}
|
||||
|
||||
// Call initialize; this takes care of init tasks that must be run after
|
||||
// the ignore paths are collected.
|
||||
if err := backend.Initialize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the backend back
|
||||
re.backend = backend
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -178,6 +178,7 @@ func (r *Router) MatchingMountByUUID(mountID string) *MountEntry {
|
||||
return raw.(*MountEntry)
|
||||
}
|
||||
|
||||
// MatchingMountByAccessor returns the MountEntry by accessor lookup
|
||||
func (r *Router) MatchingMountByAccessor(mountAccessor string) *MountEntry {
|
||||
if mountAccessor == "" {
|
||||
return nil
|
||||
@ -205,7 +206,7 @@ func (r *Router) MatchingMount(path string) string {
|
||||
return mount
|
||||
}
|
||||
|
||||
// MatchingView returns the view used for a path
|
||||
// MatchingStorageView returns the storageView used for a path
|
||||
func (r *Router) MatchingStorageView(path string) *BarrierView {
|
||||
r.l.RLock()
|
||||
_, raw, ok := r.root.LongestPrefix(path)
|
||||
@ -227,7 +228,7 @@ func (r *Router) MatchingMountEntry(path string) *MountEntry {
|
||||
return raw.(*routeEntry).mountEntry
|
||||
}
|
||||
|
||||
// MatchingMountEntry returns the MountEntry used for a path
|
||||
// MatchingBackend returns the backend used for a path
|
||||
func (r *Router) MatchingBackend(path string) logical.Backend {
|
||||
r.l.RLock()
|
||||
_, raw, ok := r.root.LongestPrefix(path)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user