diff --git a/logical/plugin/backend_client.go b/logical/plugin/backend_client.go index 387c516acf..0015b3fc08 100644 --- a/logical/plugin/backend_client.go +++ b/logical/plugin/backend_client.go @@ -92,6 +92,14 @@ func (b *backendPluginClient) HandleRequest(req *logical.Request) (*logical.Resp } var reply HandleRequestReply + if req.Connection != nil { + oldConnState := req.Connection.ConnState + req.Connection.ConnState = nil + defer func() { + req.Connection.ConnState = oldConnState + }() + } + err := b.client.Call("Plugin.HandleRequest", args, &reply) if err != nil { return nil, err @@ -137,6 +145,14 @@ func (b *backendPluginClient) HandleExistenceCheck(req *logical.Request) (bool, } var reply HandleExistenceCheckReply + if req.Connection != nil { + oldConnState := req.Connection.ConnState + req.Connection.ConnState = nil + defer func() { + req.Connection.ConnState = oldConnState + }() + } + err := b.client.Call("Plugin.HandleExistenceCheck", args, &reply) if err != nil { return false, false, err diff --git a/logical/plugin/mock/path_internal.go b/logical/plugin/mock/path_internal.go index 723ecb653c..92c4f8bfa2 100644 --- a/logical/plugin/mock/path_internal.go +++ b/logical/plugin/mock/path_internal.go @@ -9,15 +9,26 @@ import ( // it is used to test the invalidate func. func pathInternal(b *backend) *framework.Path { return &framework.Path{ - Pattern: "internal", - Fields: map[string]*framework.FieldSchema{}, - ExistenceCheck: b.pathExistenceCheck, + Pattern: "internal", + Fields: map[string]*framework.FieldSchema{ + "value": &framework.FieldSchema{Type: framework.TypeString}, + }, Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathInternalRead, + logical.UpdateOperation: b.pathInternalUpdate, + logical.ReadOperation: b.pathInternalRead, }, } } +func (b *backend) pathInternalUpdate( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + value := data.Get("value").(string) + b.internal = value + // Return the secret + return nil, nil + +} + func (b *backend) pathInternalRead( req *logical.Request, data *framework.FieldData) (*logical.Response, error) { // Return the secret diff --git a/vault/logical_system.go b/vault/logical_system.go index f620e18cba..2c3e75f228 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -825,6 +825,27 @@ func NewSystemBackend(core *Core) *SystemBackend { HelpSynopsis: strings.TrimSpace(sysHelp["plugin-catalog"][0]), HelpDescription: strings.TrimSpace(sysHelp["plugin-catalog"][1]), }, + &framework.Path{ + Pattern: "plugins/backend/reload$", + + Fields: map[string]*framework.FieldSchema{ + "plugin": &framework.FieldSchema{ + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-backend-reload-plugin"][0]), + }, + "mounts": &framework.FieldSchema{ + Type: framework.TypeCommaStringSlice, + Description: strings.TrimSpace(sysHelp["plugin-backend-reload-mounts"][0]), + }, + }, + + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.UpdateOperation: b.handlePluginReloadUpdate, + }, + + HelpSynopsis: strings.TrimSpace(sysHelp["plugin-reload"][0]), + HelpDescription: strings.TrimSpace(sysHelp["plugin-reload"][1]), + }, }, } @@ -975,6 +996,32 @@ func (b *SystemBackend) handlePluginCatalogDelete(req *logical.Request, d *frame return nil, nil } +func (b *SystemBackend) handlePluginReloadUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { + pluginName := d.Get("plugin").(string) + pluginMounts := d.Get("mounts").([]string) + + if pluginName != "" && len(pluginMounts) > 0 { + return logical.ErrorResponse("plugin and mounts cannot be set at the same time"), nil + } + if pluginName == "" && len(pluginMounts) == 0 { + return logical.ErrorResponse("plugin or mounts must be provided"), nil + } + + if pluginName != "" { + err := b.Core.reloadMatchingPlugin(pluginName) + if err != nil { + return nil, err + } + } else if len(pluginMounts) > 0 { + err := b.Core.reloadMatchingPluginMounts(pluginMounts) + if err != nil { + return nil, err + } + } + + return nil, nil +} + // handleAuditedHeaderUpdate creates or overwrites a header entry func (b *SystemBackend) handleAuditedHeaderUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) { header := d.Get("header").(string) @@ -2855,4 +2902,19 @@ This path responds to the following HTTP methods. `The path to list leases under. Example: "aws/creds/deploy"`, "", }, + "plugin-reload": { + "Reload mounts that use a particular backend plugin.", + `Reload mounts that use a particular backend plugin. Either the plugin name + or the desired plugin backend mounts must be provided, but not both. In the + case that the plugin name is provided, all mounted paths that use that plugin + backend will be reloaded.`, + }, + "plugin-backend-reload-plugin": { + `The name of the plugin to reload, as registered in the plugin catalog.`, + "", + }, + "plugin-backend-reload-mounts": { + `The mount paths of the plugin backends to reload.`, + "", + }, } diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 28a420ac2f..0570bc6b9b 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -1,6 +1,7 @@ package vault_test import ( + "fmt" "os" "testing" "time" @@ -28,11 +29,10 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) { }) cluster.Start() defer cluster.Cleanup() - cores := cluster.Cores + core := cluster.Cores[0].Core + vault.TestWaitActive(t, core) - core := cores[0] - - b := vault.NewSystemBackend(core.Core) + b := vault.NewSystemBackend(core) logger := logformat.NewVaultLogger(log.LevelTrace) bc := &logical.BackendConfig{ Logger: logger, @@ -49,7 +49,7 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) { os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMain") + vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainCredentials") req := logical.TestRequest(t, logical.UpdateOperation, "auth/mock-plugin") req.Data["type"] = "plugin" @@ -64,7 +64,151 @@ func TestSystemBackend_enableAuth_plugin(t *testing.T) { } } -func TestBackend_PluginMain(t *testing.T) { +func TestSystemBackend_PluginReload(t *testing.T) { + data := map[string]interface{}{ + "plugin": "mock-plugin", + } + t.Run("plugin", func(t *testing.T) { testSystemBackend_PluginReload(t, data) }) + + data = map[string]interface{}{ + "mounts": "mock-0/,mock-1/", + } + t.Run("mounts", func(t *testing.T) { testSystemBackend_PluginReload(t, data) }) +} + +func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) { + cluster, b := testSystemBackendMock(t, 2) + defer cluster.Cleanup() + + core := cluster.Cores[0] + + for i := 0; i < 2; i++ { + // Update internal value in the backend + req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mock-%d/internal", i)) + req.ClientToken = core.Client.Token() + req.Data["value"] = "baz" + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } + } + + // Perform plugin reload + req := logical.TestRequest(t, logical.UpdateOperation, "plugins/backend/reload") + req.ClientToken = core.Client.Token() + req.Data = reqData + resp, err := b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } + + for i := 0; i < 2; i++ { + // Ensure internal backed value is reset + req := logical.TestRequest(t, logical.ReadOperation, "mock-1/internal") + req.ClientToken = core.Client.Token() + resp, err := core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + if resp.Data["value"].(string) == "baz" { + t.Fatal("did not expect backend internal value to be 'baz'") + } + } +} + +// testSystemBackendMock returns a systemBackend with the desired number +// of mounted mock plugin backends +func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *vault.SystemBackend) { + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "plugin": plugin.Factory, + }, + } + + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + + core := cluster.Cores[0].Core + vault.TestWaitActive(t, core) + + b := vault.NewSystemBackend(core) + logger := logformat.NewVaultLogger(log.LevelTrace) + bc := &logical.BackendConfig{ + Logger: logger, + System: logical.StaticSystemView{ + DefaultLeaseTTLVal: time.Hour * 24, + MaxLeaseTTLVal: time.Hour * 24 * 32, + }, + } + + err := b.Backend.Setup(bc) + if err != nil { + t.Fatal(err) + } + + os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) + + vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainLogical") + + for i := 0; i < numMounts; i++ { + req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mounts/mock-%d/", i)) + req.Data["type"] = "plugin" + req.Data["config"] = map[string]interface{}{ + "plugin_name": "mock-plugin", + } + + resp, err := b.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } + } + + return cluster, b +} + +func TestBackend_PluginMainLogical(t *testing.T) { + if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { + return + } + + caPEM := os.Getenv(pluginutil.PluginCACertPEMEnv) + if caPEM == "" { + t.Fatal("CA cert not passed in") + } + + factoryFunc := mock.FactoryType(logical.TypeLogical) + + args := []string{"--ca-cert=" + caPEM} + + apiClientMeta := &pluginutil.APIClientMeta{} + flags := apiClientMeta.FlagSet() + flags.Parse(args) + tlsConfig := apiClientMeta.GetTLSConfig() + tlsProviderFunc := pluginutil.VaultPluginTLSProvider(tlsConfig) + err := lplugin.Serve(&lplugin.ServeOpts{ + BackendFactoryFunc: factoryFunc, + TLSProviderFunc: tlsProviderFunc, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestBackend_PluginMainCredentials(t *testing.T) { if os.Getenv(pluginutil.PluginUnwrapTokenEnv) == "" { return } diff --git a/vault/mount.go b/vault/mount.go index 84eef5f8ed..16b048a69f 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -230,7 +230,6 @@ func (c *Core) mount(entry *MountEntry) error { conf["plugin_name"] = entry.Config.PluginName } - // Consider having plugin name under entry.Options backend, err := c.newLogicalBackend(entry.Type, sysView, view, conf) if err != nil { return err diff --git a/vault/plugin_reload.go b/vault/plugin_reload.go new file mode 100644 index 0000000000..eaff18b480 --- /dev/null +++ b/vault/plugin_reload.go @@ -0,0 +1,125 @@ +package vault + +import ( + "fmt" + "strings" + + multierror "github.com/hashicorp/go-multierror" + "github.com/hashicorp/vault/logical" +) + +// reloadPluginMounts reloads provided mounts, regardless of +// plugin name, as long as the backend type is plugin. +func (c *Core) reloadMatchingPluginMounts(mounts []string) error { + c.mountsLock.Lock() + defer c.mountsLock.Unlock() + + var errors error + for _, mount := range mounts { + entry := c.router.MatchingMountEntry(mount) + if entry == nil { + errors = multierror.Append(errors, fmt.Errorf("cannot fetch mount entry on %s", mount)) + continue + // return fmt.Errorf("cannot fetch mount entry on %s", mount) + } + + var isAuth bool + fullPath := c.router.MatchingMount(mount) + if strings.HasPrefix(fullPath, credentialRoutePrefix) { + isAuth = true + } + + if entry.Type == "plugin" { + err := c.reloadPluginCommon(entry, isAuth) + if err != nil { + errors = multierror.Append(errors, fmt.Errorf("cannot reload plugin on %s: %v", mount, err)) + continue + } + c.logger.Info("core: successfully reloaded plugin", "plugin", entry.Config.PluginName, "path", entry.Path) + } + } + return errors +} + +// reloadPlugin reloads all mounted backends that are of +// plugin pluginName (name of the plugin as registered in +// the plugin catalog). +func (c *Core) reloadMatchingPlugin(pluginName string) error { + c.mountsLock.Lock() + defer c.mountsLock.Unlock() + + // Filter mount entries that only matches the plugin name + for _, entry := range c.mounts.Entries { + if entry.Config.PluginName == pluginName && entry.Type == "plugin" { + err := c.reloadPluginCommon(entry, false) + if err != nil { + return err + } + c.logger.Info("core: successfully reloaded plugin", "plugin", pluginName, "path", entry.Path) + } + } + + // Filter auth mount entries that ony matches the plugin name + for _, entry := range c.auth.Entries { + if entry.Config.PluginName == pluginName && entry.Type == "plugin" { + err := c.reloadPluginCommon(entry, true) + if err != nil { + return err + } + c.logger.Info("core: successfully reloaded plugin", "plugin", pluginName, "path", entry.Path) + } + } + + return nil +} + +// reloadPluginCommon is a generic method to reload a backend provided a +// MountEntry. entry.Type should be checked by the caller to ensure that +// it's a "plugin" type. +func (c *Core) reloadPluginCommon(entry *MountEntry, isAuth bool) error { + path := entry.Path + + // Fast-path out if the backend doesn't exist + raw, ok := c.router.root.Get(path) + if !ok { + return nil + } + + // Call backend's Cleanup routine + re := raw.(*routeEntry) + re.backend.Cleanup() + + view := re.storageView + + sysView := c.mountEntrySysView(entry) + conf := make(map[string]string) + if entry.Config.PluginName != "" { + conf["plugin_name"] = entry.Config.PluginName + } + + var backend logical.Backend + var err error + if !isAuth { + // Dispense a new backend + backend, err = c.newLogicalBackend(entry.Type, sysView, view, conf) + } else { + backend, err = c.newCredentialBackend(entry.Type, sysView, view, conf) + } + if err != nil { + return err + } + if backend == nil { + return fmt.Errorf("nil backend of type %q returned from creation function", entry.Type) + } + + // Call initialize; this takes care of init tasks that must be run after + // the ignore paths are collected. + if err := backend.Initialize(); err != nil { + return err + } + + // Set the backend back + re.backend = backend + + return nil +} diff --git a/vault/router.go b/vault/router.go index bfd92ea012..931a1b5dce 100644 --- a/vault/router.go +++ b/vault/router.go @@ -178,6 +178,7 @@ func (r *Router) MatchingMountByUUID(mountID string) *MountEntry { return raw.(*MountEntry) } +// MatchingMountByAccessor returns the MountEntry by accessor lookup func (r *Router) MatchingMountByAccessor(mountAccessor string) *MountEntry { if mountAccessor == "" { return nil @@ -205,7 +206,7 @@ func (r *Router) MatchingMount(path string) string { return mount } -// MatchingView returns the view used for a path +// MatchingStorageView returns the storageView used for a path func (r *Router) MatchingStorageView(path string) *BarrierView { r.l.RLock() _, raw, ok := r.root.LongestPrefix(path) @@ -227,7 +228,7 @@ func (r *Router) MatchingMountEntry(path string) *MountEntry { return raw.(*routeEntry).mountEntry } -// MatchingMountEntry returns the MountEntry used for a path +// MatchingBackend returns the backend used for a path func (r *Router) MatchingBackend(path string) logical.Backend { r.l.RLock() _, raw, ok := r.root.LongestPrefix(path)