diff --git a/builtin/plugin/backend.go b/builtin/plugin/backend.go index 1b7105d791..3945f78d2d 100644 --- a/builtin/plugin/backend.go +++ b/builtin/plugin/backend.go @@ -2,7 +2,10 @@ package plugin import ( "fmt" + "net/rpc" + "sync" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/logical" bplugin "github.com/hashicorp/vault/logical/plugin" ) @@ -27,13 +30,111 @@ func Factory(conf *logical.BackendConfig) (logical.Backend, error) { // Backend returns an instance of the backend, either as a plugin if external // or as a concrete implementation if builtin, casted as logical.Backend. func Backend(conf *logical.BackendConfig) (logical.Backend, error) { + var b backend name := conf.Config["plugin_name"] sys := conf.System - b, err := bplugin.NewBackend(name, sys, conf.Logger) + raw, err := bplugin.NewBackend(name, sys, conf.Logger) if err != nil { return nil, err } + b.Backend = raw + b.config = conf - return b, nil + return &b, nil +} + +// backend is a thin wrapper around plugin.BackendPluginClient +type backend struct { + logical.Backend + sync.RWMutex + + config *logical.BackendConfig + + // Used to detect if we already reloaded + canary string +} + +func (b *backend) reloadBackend() error { + pluginName := b.config.Config["plugin_name"] + b.Logger().Trace("plugin: reloading plugin backend", "plugin", pluginName) + + // Ensure proper cleanup of the backend (i.e. call client.Kill()) + b.Backend.Cleanup() + + nb, err := bplugin.NewBackend(pluginName, b.config.System, b.config.Logger) + if err != nil { + return err + } + err = nb.Setup(b.config) + if err != nil { + return err + } + b.Backend = nb + + return nil +} + +// HandleRequest is a thin wrapper implementation of HandleRequest that includes automatic plugin reload. +func (b *backend) HandleRequest(req *logical.Request) (*logical.Response, error) { + b.RLock() + canary := b.canary + resp, err := b.Backend.HandleRequest(req) + b.RUnlock() + // Need to compare string value for case were err comes from plugin RPC + // and is returned as plugin.BasicError type. + if err != nil && err.Error() == rpc.ErrShutdown.Error() { + // Reload plugin if it's an rpc.ErrShutdown + b.Lock() + if b.canary == canary { + err := b.reloadBackend() + if err != nil { + b.Unlock() + return nil, err + } + b.canary, err = uuid.GenerateUUID() + if err != nil { + b.Unlock() + return nil, err + } + } + b.Unlock() + + // Try request once more + b.RLock() + defer b.RUnlock() + return b.Backend.HandleRequest(req) + } + return resp, err +} + +// HandleExistenceCheck is a thin wrapper implementation of HandleRequest that includes automatic plugin reload. +func (b *backend) HandleExistenceCheck(req *logical.Request) (bool, bool, error) { + b.RLock() + canary := b.canary + checkFound, exists, err := b.Backend.HandleExistenceCheck(req) + b.RUnlock() + if err != nil && err.Error() == rpc.ErrShutdown.Error() { + // Reload plugin if it's an rpc.ErrShutdown + b.Lock() + if b.canary == canary { + err := b.reloadBackend() + if err != nil { + b.Unlock() + return false, false, err + } + b.canary, err = uuid.GenerateUUID() + if err != nil { + b.Unlock() + return false, false, err + } + } + b.Unlock() + + // Try request once more + b.RLock() + defer b.RUnlock() + return b.Backend.HandleExistenceCheck(req) + } + return checkFound, exists, err } diff --git a/builtin/plugin/backend_test.go b/builtin/plugin/backend_test.go index d9f182cb78..0a37691d63 100644 --- a/builtin/plugin/backend_test.go +++ b/builtin/plugin/backend_test.go @@ -14,6 +14,10 @@ import ( log "github.com/mgutz/logxi/v1" ) +func TestBackend_impl(t *testing.T) { + var _ logical.Backend = &backend{} +} + func TestBackend(t *testing.T) { config, cleanup := testConfig(t) defer cleanup() diff --git a/logical/plugin/mock/backend.go b/logical/plugin/mock/backend.go index fc49f2078e..5f4c977496 100644 --- a/logical/plugin/mock/backend.go +++ b/logical/plugin/mock/backend.go @@ -38,10 +38,13 @@ func Backend() *backend { var b backend b.Backend = &framework.Backend{ Help: "", - Paths: []*framework.Path{ - pathKV(&b), - pathInternal(&b), - }, + Paths: framework.PathAppend( + errorPaths(&b), + kvPaths(&b), + []*framework.Path{ + pathInternal(&b), + }, + ), PathsSpecial: &logical.Paths{ Unauthenticated: []string{ "special", diff --git a/logical/plugin/mock/path_errors.go b/logical/plugin/mock/path_errors.go new file mode 100644 index 0000000000..00c4e3df36 --- /dev/null +++ b/logical/plugin/mock/path_errors.go @@ -0,0 +1,32 @@ +package mock + +import ( + "net/rpc" + + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" +) + +// pathInternal is used to test viewing internal backend values. In this case, +// it is used to test the invalidate func. +func errorPaths(b *backend) []*framework.Path { + return []*framework.Path{ + &framework.Path{ + Pattern: "errors/rpc", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathErrorRPCRead, + }, + }, + &framework.Path{ + Pattern: "errors/kill", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathErrorRPCRead, + }, + }, + } +} + +func (b *backend) pathErrorRPCRead( + req *logical.Request, data *framework.FieldData) (*logical.Response, error) { + return nil, rpc.ErrShutdown +} diff --git a/logical/plugin/mock/path_kv.go b/logical/plugin/mock/path_kv.go index 47883d8ebb..badede2900 100644 --- a/logical/plugin/mock/path_kv.go +++ b/logical/plugin/mock/path_kv.go @@ -7,22 +7,29 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -// pathKV is used to test CRUD and List operations. It is a simplified +// kvPaths is used to test CRUD and List operations. It is a simplified // version of the passthrough backend that only accepts string values. -func pathKV(b *backend) *framework.Path { - return &framework.Path{ - Pattern: "kv/" + framework.GenericNameRegex("key"), - Fields: map[string]*framework.FieldSchema{ - "key": &framework.FieldSchema{Type: framework.TypeString}, - "value": &framework.FieldSchema{Type: framework.TypeString}, +func kvPaths(b *backend) []*framework.Path { + return []*framework.Path{ + &framework.Path{ + Pattern: "kv/?", + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ListOperation: b.pathKVList, + }, }, - ExistenceCheck: b.pathExistenceCheck, - Callbacks: map[logical.Operation]framework.OperationFunc{ - logical.ReadOperation: b.pathKVRead, - logical.CreateOperation: b.pathKVCreateUpdate, - logical.UpdateOperation: b.pathKVCreateUpdate, - logical.DeleteOperation: b.pathKVDelete, - logical.ListOperation: b.pathKVList, + &framework.Path{ + Pattern: "kv/" + framework.GenericNameRegex("key"), + Fields: map[string]*framework.FieldSchema{ + "key": &framework.FieldSchema{Type: framework.TypeString}, + "value": &framework.FieldSchema{Type: framework.TypeString}, + }, + ExistenceCheck: b.pathExistenceCheck, + Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.ReadOperation: b.pathKVRead, + logical.CreateOperation: b.pathKVCreateUpdate, + logical.UpdateOperation: b.pathKVCreateUpdate, + logical.DeleteOperation: b.pathKVDelete, + }, }, } } diff --git a/vault/core.go b/vault/core.go index 4756ccdf8a..26453f7e01 100644 --- a/vault/core.go +++ b/vault/core.go @@ -1347,6 +1347,9 @@ func (c *Core) postUnseal() (retErr error) { if err := c.ensureWrappingKey(); err != nil { return err } + if err := c.setupPluginCatalog(); err != nil { + return err + } if err := c.loadMounts(); err != nil { return err } @@ -1380,9 +1383,6 @@ func (c *Core) postUnseal() (retErr error) { if err := c.setupAuditedHeadersConfig(); err != nil { return err } - if err := c.setupPluginCatalog(); err != nil { - return err - } if c.ha != nil { if err := c.startClusterListener(); err != nil { diff --git a/vault/dynamic_system_view.go b/vault/dynamic_system_view.go index 3844b46bfa..af6cb7fbe3 100644 --- a/vault/dynamic_system_view.go +++ b/vault/dynamic_system_view.go @@ -121,6 +121,12 @@ func (d dynamicSystemView) ResponseWrapData(data map[string]interface{}, ttl tim // LookupPlugin looks for a plugin with the given name in the plugin catalog. It // returns a PluginRunner or an error if no plugin was found. func (d dynamicSystemView) LookupPlugin(name string) (*pluginutil.PluginRunner, error) { + if d.core == nil { + return nil, fmt.Errorf("system view core is nil") + } + if d.core.pluginCatalog == nil { + return nil, fmt.Errorf("system view core plugin catalog is nil") + } r, err := d.core.pluginCatalog.Get(name) if err != nil { return nil, err diff --git a/vault/logical_system_integ_test.go b/vault/logical_system_integ_test.go index 0570bc6b9b..8ddd8eb75c 100644 --- a/vault/logical_system_integ_test.go +++ b/vault/logical_system_integ_test.go @@ -4,67 +4,68 @@ import ( "fmt" "os" "testing" - "time" "github.com/hashicorp/vault/builtin/plugin" - "github.com/hashicorp/vault/helper/logformat" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" lplugin "github.com/hashicorp/vault/logical/plugin" "github.com/hashicorp/vault/logical/plugin/mock" "github.com/hashicorp/vault/vault" - log "github.com/mgutz/logxi/v1" ) -func TestSystemBackend_enableAuth_plugin(t *testing.T) { - coreConfig := &vault.CoreConfig{ - CredentialBackends: map[string]logical.Factory{ - "plugin": plugin.Factory, - }, - } - - cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ - HandlerFunc: vaulthttp.Handler, - }) - cluster.Start() +func TestSystemBackend_Plugin_secret(t *testing.T) { + cluster := testSystemBackendMock(t, 1, logical.TypeLogical) defer cluster.Cleanup() - core := cluster.Cores[0].Core - vault.TestWaitActive(t, core) +} - b := vault.NewSystemBackend(core) - logger := logformat.NewVaultLogger(log.LevelTrace) - bc := &logical.BackendConfig{ - Logger: logger, - System: logical.StaticSystemView{ - DefaultLeaseTTLVal: time.Hour * 24, - MaxLeaseTTLVal: time.Hour * 24 * 32, - }, - } +func TestSystemBackend_Plugin_auth(t *testing.T) { + cluster := testSystemBackendMock(t, 1, logical.TypeCredential) + defer cluster.Cleanup() +} - err := b.Backend.Setup(bc) - if err != nil { - t.Fatal(err) - } +func TestSystemBackend_Plugin_autoReload(t *testing.T) { + cluster := testSystemBackendMock(t, 1, logical.TypeLogical) + defer cluster.Cleanup() - os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) + core := cluster.Cores[0] - vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainCredentials") - - req := logical.TestRequest(t, logical.UpdateOperation, "auth/mock-plugin") - req.Data["type"] = "plugin" - req.Data["plugin_name"] = "mock-plugin" - - resp, err := b.HandleRequest(req) + // Update internal value + req := logical.TestRequest(t, logical.UpdateOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + req.Data["value"] = "baz" + resp, err := core.HandleRequest(req) if err != nil { t.Fatalf("err: %v", err) } if resp != nil { t.Fatalf("bad: %v", resp) } + + // Call errors/rpc endpoint to trigger reload + req = logical.TestRequest(t, logical.ReadOperation, "mock-0/errors/rpc") + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(req) + if err == nil { + t.Fatalf("expected error from error/rpc request") + } + + // Check internal value to make sure it's reset + req = logical.TestRequest(t, logical.ReadOperation, "mock-0/internal") + req.ClientToken = core.Client.Token() + resp, err = core.HandleRequest(req) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatalf("bad: response should not be nil") + } + if resp.Data["value"].(string) == "baz" { + t.Fatal("did not expect backend internal value to be 'baz'") + } } -func TestSystemBackend_PluginReload(t *testing.T) { +func TestSystemBackend_Plugin_reload(t *testing.T) { data := map[string]interface{}{ "plugin": "mock-plugin", } @@ -77,17 +78,17 @@ func TestSystemBackend_PluginReload(t *testing.T) { } func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{}) { - cluster, b := testSystemBackendMock(t, 2) + cluster := testSystemBackendMock(t, 2, logical.TypeLogical) defer cluster.Cleanup() core := cluster.Cores[0] + client := core.Client for i := 0; i < 2; i++ { // Update internal value in the backend - req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mock-%d/internal", i)) - req.ClientToken = core.Client.Token() - req.Data["value"] = "baz" - resp, err := core.HandleRequest(req) + resp, err := client.Logical().Write(fmt.Sprintf("mock-%d/internal", i), map[string]interface{}{ + "value": "baz", + }) if err != nil { t.Fatalf("err: %v", err) } @@ -97,10 +98,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} } // Perform plugin reload - req := logical.TestRequest(t, logical.UpdateOperation, "plugins/backend/reload") - req.ClientToken = core.Client.Token() - req.Data = reqData - resp, err := b.HandleRequest(req) + resp, err := client.Logical().Write("sys/plugins/backend/reload", reqData) if err != nil { t.Fatalf("err: %v", err) } @@ -110,9 +108,7 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} for i := 0; i < 2; i++ { // Ensure internal backed value is reset - req := logical.TestRequest(t, logical.ReadOperation, "mock-1/internal") - req.ClientToken = core.Client.Token() - resp, err := core.HandleRequest(req) + resp, err := client.Logical().Read(fmt.Sprintf("mock-%d/internal", i)) if err != nil { t.Fatalf("err: %v", err) } @@ -127,11 +123,14 @@ func testSystemBackend_PluginReload(t *testing.T, reqData map[string]interface{} // testSystemBackendMock returns a systemBackend with the desired number // of mounted mock plugin backends -func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *vault.SystemBackend) { +func testSystemBackendMock(t *testing.T, numMounts int, backendType logical.BackendType) *vault.TestCluster { coreConfig := &vault.CoreConfig{ LogicalBackends: map[string]logical.Factory{ "plugin": plugin.Factory, }, + CredentialBackends: map[string]logical.Factory{ + "plugin": plugin.Factory, + }, } cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ @@ -139,45 +138,48 @@ func testSystemBackendMock(t *testing.T, numMounts int) (*vault.TestCluster, *va }) cluster.Start() - core := cluster.Cores[0].Core - vault.TestWaitActive(t, core) - - b := vault.NewSystemBackend(core) - logger := logformat.NewVaultLogger(log.LevelTrace) - bc := &logical.BackendConfig{ - Logger: logger, - System: logical.StaticSystemView{ - DefaultLeaseTTLVal: time.Hour * 24, - MaxLeaseTTLVal: time.Hour * 24 * 32, - }, - } - - err := b.Backend.Setup(bc) - if err != nil { - t.Fatal(err) - } + core := cluster.Cores[0] + vault.TestWaitActive(t, core.Core) + client := core.Client os.Setenv(pluginutil.PluginCACertPEMEnv, cluster.CACertPEMFile) - vault.TestAddTestPlugin(t, core, "mock-plugin", "TestBackend_PluginMainLogical") - - for i := 0; i < numMounts; i++ { - req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("mounts/mock-%d/", i)) - req.Data["type"] = "plugin" - req.Data["config"] = map[string]interface{}{ - "plugin_name": "mock-plugin", + switch backendType { + case logical.TypeLogical: + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainLogical") + for i := 0; i < numMounts; i++ { + resp, err := client.Logical().Write(fmt.Sprintf("sys/mounts/mock-%d", i), map[string]interface{}{ + "type": "plugin", + "config": map[string]interface{}{ + "plugin_name": "mock-plugin", + }, + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } } - - resp, err := b.HandleRequest(req) - if err != nil { - t.Fatalf("err: %v", err) - } - if resp != nil { - t.Fatalf("bad: %v", resp) + case logical.TypeCredential: + vault.TestAddTestPlugin(t, core.Core, "mock-plugin", "TestBackend_PluginMainCredentials") + for i := 0; i < numMounts; i++ { + resp, err := client.Logical().Write(fmt.Sprintf("sys/auth/mock-%d", i), map[string]interface{}{ + "type": "plugin", + "plugin_name": "mock-plugin", + }) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp != nil { + t.Fatalf("bad: %v", resp) + } } + default: + t.Fatal("unknown backend type provided") } - return cluster, b + return cluster } func TestBackend_PluginMainLogical(t *testing.T) {