From e63bf375b0cd02c84ddec5e433d5c9034a39552e Mon Sep 17 00:00:00 2001 From: Tom Proctor Date: Thu, 22 Sep 2022 13:53:52 +0100 Subject: [PATCH] Plugins: Auto version selection for auth/secrets + tune version (#17167) --- vault/core_test.go | 9 +- vault/external_plugin_test.go | 197 ++++++++++++++-------------------- vault/logical_system.go | 136 ++++++++++++++++++++--- vault/logical_system_paths.go | 8 ++ vault/logical_system_test.go | 31 +++++- 5 files changed, 244 insertions(+), 137 deletions(-) diff --git a/vault/core_test.go b/vault/core_test.go index 31d427b331..1dc93c031e 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -1889,10 +1889,11 @@ func testCore_Standby_Common(t *testing.T, inm physical.Backend, inmha physical. // Create the first core and initialize it redirectOriginal := "http://127.0.0.1:8200" core, err := NewCore(&CoreConfig{ - Physical: inm, - HAPhysical: inmha, - RedirectAddr: redirectOriginal, - DisableMlock: true, + Physical: inm, + HAPhysical: inmha, + RedirectAddr: redirectOriginal, + DisableMlock: true, + BuiltinRegistry: NewMockBuiltinRegistry(), }) if err != nil { t.Fatalf("err: %v", err) diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index bf13010800..3316c7475c 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -4,7 +4,6 @@ import ( "context" "crypto/sha256" "encoding/hex" - "errors" "fmt" "os" "os/exec" @@ -155,34 +154,9 @@ func TestCore_EnableExternalPlugin(t *testing.T) { } { t.Run(name, func(t *testing.T) { c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") - d := &framework.FieldData{ - Raw: map[string]interface{}{ - "name": pluginName, - "sha256": pluginSHA256, - "version": "v1.0.0", - "command": pluginName, - }, - Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, - } - resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) - if err != nil { - t.Fatal(err) - } - if resp.Error() != nil { - t.Fatalf("%#v", resp) - } + registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), "1.0.0", pluginSHA256) - me := &MountEntry{ - Table: mountTable(tc.pluginType), - Path: "foo", - Type: pluginName, - Version: "v1.0.0", - } - enable := enableFunc(c, tc.pluginType) - err = enable(namespace.RootContext(nil), me) - if err != nil { - t.Fatalf("err: %v", err) - } + mountPlugin(t, c.systemBackend, pluginName, tc.pluginType, "v1.0.0") match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) if match != tc.expectedMatch { @@ -197,6 +171,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { pluginType consts.PluginType registerVersions []string mountVersion string + expectedVersion string routerPath string expectedMatch string }{ @@ -204,6 +179,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { pluginType: consts.PluginTypeCredential, registerVersions: []string{"v1.0.0", "v1.0.1"}, mountVersion: "v1.0.0", + expectedVersion: "v1.0.0", routerPath: "auth/foo/bar", expectedMatch: "auth/foo/", }, @@ -211,6 +187,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { pluginType: consts.PluginTypeSecrets, registerVersions: []string{"v1.0.0", "v1.0.1"}, mountVersion: "v1.0.0", + expectedVersion: "v1.0.0", routerPath: "foo/bar", expectedMatch: "foo/", }, @@ -218,6 +195,7 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { pluginType: consts.PluginTypeCredential, registerVersions: []string{"v1.0.0", "v1.0.1"}, mountVersion: "v1.0.1", + expectedVersion: "v1.0.1", routerPath: "auth/foo/bar", expectedMatch: "auth/foo/", }, @@ -225,6 +203,23 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { pluginType: consts.PluginTypeSecrets, registerVersions: []string{"v1.0.0", "v1.0.1"}, mountVersion: "v1.0.1", + expectedVersion: "v1.0.1", + routerPath: "foo/bar", + expectedMatch: "foo/", + }, + "enable external credential plugin, selects latest when version not specified": { + pluginType: consts.PluginTypeCredential, + registerVersions: []string{"v1.0.0", "v1.0.1"}, + mountVersion: "", + expectedVersion: "v1.0.1", + routerPath: "auth/foo/bar", + expectedMatch: "auth/foo/", + }, + "enable external secrets plugin, selects latest when version not specified": { + pluginType: consts.PluginTypeSecrets, + registerVersions: []string{"v1.0.0", "v1.0.1"}, + mountVersion: "", + expectedVersion: "v1.0.1", routerPath: "foo/bar", expectedMatch: "foo/", }, @@ -232,35 +227,10 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { t.Run(name, func(t *testing.T) { c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") for _, version := range tc.registerVersions { - d := &framework.FieldData{ - Raw: map[string]interface{}{ - "name": pluginName, - "sha256": pluginSHA256, - "version": version, - "command": pluginName, - }, - Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, - } - resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) - if err != nil { - t.Fatal(err) - } - if resp.Error() != nil { - t.Fatalf("%#v", resp) - } + registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), version, pluginSHA256) } - me := &MountEntry{ - Table: mountTable(tc.pluginType), - Path: "foo", - Type: pluginName, - Version: tc.mountVersion, - } - enable := enableFunc(c, tc.pluginType) - err := enable(namespace.RootContext(nil), me) - if err != nil { - t.Fatalf("err: %v", err) - } + mountPlugin(t, c.systemBackend, pluginName, tc.pluginType, tc.mountVersion) match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) if match != tc.expectedMatch { @@ -268,8 +238,8 @@ func TestCore_EnableExternalPlugin_MultipleVersions(t *testing.T) { } raw, _ := c.router.root.Get(match) - if raw.(*routeEntry).mountEntry.Version != tc.mountVersion { - t.Errorf("Expected mount to be version %s but got %s", tc.mountVersion, raw.(*routeEntry).mountEntry.Version) + if raw.(*routeEntry).mountEntry.Version != tc.expectedVersion { + t.Errorf("Expected mount to be version %s but got %s", tc.expectedVersion, raw.(*routeEntry).mountEntry.Version) } // we don't override the running version of non-builtins, and they don't have the version set explicitly (yet) @@ -303,32 +273,14 @@ func TestCore_EnableExternalPlugin_NoVersionsOkay(t *testing.T) { } { t.Run(name, func(t *testing.T) { c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") - d := &framework.FieldData{ - Raw: map[string]interface{}{ - "name": pluginName, - "sha256": pluginSHA256, - "command": pluginName, - }, - Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, - } - resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) - if err != nil { - t.Fatal(err) - } - if resp.Error() != nil { - t.Fatalf("%#v", resp) + // When an unversioned plugin is registered, mounting a plugin with no + // version specified should mount the unversioned plugin even if there + // are versioned plugins available. + for _, version := range []string{"", "v1.0.0"} { + registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), version, pluginSHA256) } - me := &MountEntry{ - Table: mountTable(tc.pluginType), - Path: "foo", - Type: pluginName, - } - enable := enableFunc(c, tc.pluginType) - err = enable(namespace.RootContext(nil), me) - if err != nil { - t.Fatalf("err: %v", err) - } + mountPlugin(t, c.systemBackend, pluginName, tc.pluginType, "") match := c.router.MatchingMount(namespace.RootContext(nil), tc.routerPath) if match != tc.expectedMatch { @@ -362,32 +314,16 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) { } { t.Run(name, func(t *testing.T) { c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, "") - d := &framework.FieldData{ - Raw: map[string]interface{}{ - "name": pluginName, - "sha256": pluginSHA256, - "command": pluginName, - }, - Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, - } - resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) - if err != nil { - t.Fatal(err) - } - if resp.Error() != nil { - t.Fatalf("%#v", resp) - } + registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), "", pluginSHA256) - me := &MountEntry{ - Table: mountTable(tc.pluginType), - Path: "foo", - Type: pluginName, - Version: "v1.0.0", + req := logical.TestRequest(t, logical.UpdateOperation, mountTable(tc.pluginType)) + req.Data = map[string]interface{}{ + "type": pluginName, + "plugin_version": "v1.0.0", } - enable := enableFunc(c, tc.pluginType) - err = enable(namespace.RootContext(nil), me) - if err == nil || !errors.Is(err, ErrPluginNotFound) { - t.Fatalf("Expected to get plugin not found but got: %v", err) + resp, _ := c.systemBackend.HandleRequest(namespace.RootContext(nil), req) + if resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), ErrPluginNotFound.Error()) { + t.Fatalf("Expected to get plugin not found but got: %v", resp.Error()) } }) } @@ -486,24 +422,49 @@ func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { } } +func registerPlugin(t *testing.T, sys *SystemBackend, pluginName, pluginType, version, sha string) { + t.Helper() + req := logical.TestRequest(t, logical.UpdateOperation, fmt.Sprintf("plugins/catalog/%s/%s", pluginType, pluginName)) + req.Data = map[string]interface{}{ + "name": pluginName, + "command": pluginName, + "sha256": sha, + "version": version, + } + resp, err := sys.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatal(err) + } + if resp.Error() != nil { + t.Fatalf("%#v", resp) + } +} + +func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType consts.PluginType, version string) { + t.Helper() + req := logical.TestRequest(t, logical.UpdateOperation, mountTable(pluginType)) + req.Data = map[string]interface{}{ + "type": pluginName, + } + if version != "" { + req.Data["plugin_version"] = version + } + resp, err := sys.HandleRequest(namespace.RootContext(nil), req) + if err != nil { + t.Fatal(err) + } + if resp.Error() != nil { + t.Fatalf("%#v", resp) + } +} + func mountTable(pluginType consts.PluginType) string { switch pluginType { case consts.PluginTypeCredential: - return credentialTableType + return "auth/foo" case consts.PluginTypeSecrets: - return mountTableType + return "mounts/foo" default: panic("test does not support plugin type yet") } } - -func enableFunc(c *Core, pluginType consts.PluginType) func(context.Context, *MountEntry) error { - switch pluginType { - case consts.PluginTypeCredential: - return c.enableCredential - case consts.PluginTypeSecrets: - return c.mount - default: - panic(pluginType.String()) - } -} diff --git a/vault/logical_system.go b/vault/logical_system.go index c635e39524..3a32a664e9 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1001,13 +1001,9 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d sealWrap := data.Get("seal_wrap").(bool) externalEntropyAccess := data.Get("external_entropy_access").(bool) options := data.Get("options").(map[string]string) - version := data.Get("plugin_version").(string) - if version != "" { - v, err := semver.NewSemver(version) - if err != nil { - return nil, err - } - version = "v" + v.String() + var version string + if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok { + version = pluginVersionRaw.(string) } var config MountConfig @@ -1114,6 +1110,27 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d } } + switch version { + case "": + var err error + version, err = selectPluginVersion(ctx, b.System(), logicalType, consts.PluginTypeSecrets) + if err != nil { + return nil, err + } + + if version != "" { + b.logger.Debug("pinning secrets plugin version", "plugin name", logicalType, "plugin version", version) + } + default: + semanticVersion, err := semver.NewVersion(version) + if err != nil { + return logical.ErrorResponse("version %q is not a valid semantic version: %s", version, err), nil + } + + // Canonicalize the version. + version = "v" + semanticVersion.String() + } + // Copy over the force no cache if set if apiConfig.ForceNoCache { config.ForceNoCache = true @@ -1169,6 +1186,39 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d return resp, nil } +func selectPluginVersion(ctx context.Context, sys logical.SystemView, pluginName string, pluginType consts.PluginType) (string, error) { + unversionedPlugin, err := sys.LookupPlugin(ctx, pluginName, pluginType) + if err == nil && !unversionedPlugin.Builtin { + // We'll select the unversioned plugin that's been registered. + return "", nil + } + + // No version provided and no unversioned plugin of that name available. + // Pin to the current latest version if any versioned plugins are registered. + plugins, err := sys.ListVersionedPlugins(ctx, pluginType) + if err != nil { + return "", err + } + + var versionedCandidates []pluginutil.VersionedPlugin + for _, plugin := range plugins { + if !plugin.Builtin && plugin.Name == pluginName && plugin.Version != "" { + versionedCandidates = append(versionedCandidates, plugin) + } + } + + if len(versionedCandidates) != 0 { + // Sort in reverse order. + sort.SliceStable(versionedCandidates, func(i, j int) bool { + return versionedCandidates[i].SemanticVersion.GreaterThan(versionedCandidates[j].SemanticVersion) + }) + + return "v" + versionedCandidates[0].SemanticVersion.String(), nil + } + + return "", nil +} + func (b *SystemBackend) handleReadMount(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { path := data.Get("path").(string) path = sanitizePath(path) @@ -1565,6 +1615,10 @@ func (b *SystemBackend) handleTuneReadCommon(ctx context.Context, path string) ( resp.Data["options"] = mountEntry.Options } + if mountEntry.Version != "" { + resp.Data["plugin_version"] = mountEntry.Version + } + return resp, nil } @@ -1699,6 +1753,43 @@ func (b *SystemBackend) handleTuneWriteCommon(ctx context.Context, path string, } } + if rawVal, ok := data.GetOk("plugin_version"); ok { + version := rawVal.(string) + semanticVersion, err := semver.NewVersion(version) + if err != nil { + return logical.ErrorResponse("version %q is not a valid semantic version: %s", version, err), nil + } + version = "v" + semanticVersion.String() + + // Lookup the version to ensure it exists in the catalog before committing. + pluginType := consts.PluginTypeSecrets + if strings.HasPrefix(path, "auth/") { + pluginType = consts.PluginTypeCredential + } + _, err = b.System().LookupPluginVersion(ctx, mountEntry.Type, pluginType, version) + if err != nil { + return handleError(err) + } + + oldVersion := mountEntry.Version + mountEntry.Version = version + + // Update the mount table + switch { + case strings.HasPrefix(path, "auth/"): + err = b.Core.persistAuth(ctx, b.Core.auth, &mountEntry.Local) + default: + err = b.Core.persistMounts(ctx, b.Core.mounts, &mountEntry.Local) + } + if err != nil { + mountEntry.Version = oldVersion + return handleError(err) + } + if b.Core.logger.IsInfo() { + b.Core.logger.Info("mount tuning of version successful", "path", path, "version", version) + } + } + if rawVal, ok := data.GetOk("audit_non_hmac_request_keys"); ok { auditNonHMACRequestKeys := rawVal.([]string) @@ -2258,13 +2349,9 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque sealWrap := data.Get("seal_wrap").(bool) externalEntropyAccess := data.Get("external_entropy_access").(bool) options := data.Get("options").(map[string]string) - version := data.Get("plugin_version").(string) - if version != "" { - v, err := semver.NewSemver(version) - if err != nil { - return nil, err - } - version = "v" + v.String() + var version string + if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok { + version = pluginVersionRaw.(string) } var config MountConfig @@ -2359,6 +2446,27 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque } } + switch version { + case "": + var err error + version, err = selectPluginVersion(ctx, b.System(), logicalType, consts.PluginTypeCredential) + if err != nil { + return nil, err + } + + if version != "" { + b.logger.Debug("pinning auth plugin version", "plugin name", logicalType, "plugin version", version) + } + default: + semanticVersion, err := semver.NewVersion(version) + if err != nil { + return logical.ErrorResponse("version %q is not a valid semantic version: %s", version, err), nil + } + + // Canonicalize the version. + version = "v" + semanticVersion.String() + } + if options != nil && options["version"] != "" { return logical.ErrorResponse(fmt.Sprintf( "auth method %q does not allow setting a version", logicalType)), diff --git a/vault/logical_system_paths.go b/vault/logical_system_paths.go index 5709a5c17b..afd4343f22 100644 --- a/vault/logical_system_paths.go +++ b/vault/logical_system_paths.go @@ -1542,6 +1542,10 @@ func (b *SystemBackend) authPaths() []*framework.Path { Type: framework.TypeString, Description: strings.TrimSpace(sysHelp["token_type"][0]), }, + "plugin_version": { + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_version"][0]), + }, }, Operations: map[logical.Operation]framework.OperationHandler{ logical.ReadOperation: &framework.PathOperation{ @@ -1921,6 +1925,10 @@ func (b *SystemBackend) mountPaths() []*framework.Path { Type: framework.TypeCommaStringSlice, Description: strings.TrimSpace(sysHelp["tune_allowed_managed_keys"][0]), }, + "plugin_version": { + Type: framework.TypeString, + Description: strings.TrimSpace(sysHelp["plugin-catalog_version"][0]), + }, }, Callbacks: map[logical.Operation]framework.OperationFunc{ diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index dd0a11d4fd..53a6e3a816 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -2009,9 +2009,35 @@ func TestSystemBackend_tuneAuth(t *testing.T) { req = logical.TestRequest(t, logical.UpdateOperation, "auth/token/tune") req.Data["description"] = "" + req.Data["plugin_version"] = "v1.0.0" + resp, err = b.HandleRequest(namespace.RootContext(nil), req) + if err == nil || resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), ErrPluginNotFound.Error()) { + t.Fatalf("expected tune request to fail, but got resp: %#v, err: %s", resp, err) + } + + // Register the plugin in the catalog, and then try the same request again. + { + tempDir, err := filepath.EvalSymlinks(t.TempDir()) + if err != nil { + t.Fatal(err) + } + c.pluginCatalog.directory = tempDir + file, err := os.Create(filepath.Join(tempDir, "foo")) + if err != nil { + t.Fatal(err) + } + if err := file.Close(); err != nil { + t.Fatal(err) + } + err = c.pluginCatalog.Set(context.Background(), "token", consts.PluginTypeCredential, "v1.0.0", "foo", []string{}, []string{}, []byte{}) + if err != nil { + t.Fatal(err) + } + } + resp, err = b.HandleRequest(namespace.RootContext(nil), req) if err != nil { - t.Fatalf("err: %v", err) + t.Fatal(resp, err) } req = logical.TestRequest(t, logical.ReadOperation, "auth/token/tune") @@ -2026,6 +2052,9 @@ func TestSystemBackend_tuneAuth(t *testing.T) { if resp.Data["description"] != "" { t.Fatalf("got: %#v expect: %#v", resp.Data["description"], "") } + if resp.Data["plugin_version"] != "v1.0.0" { + t.Fatalf("got: %#v, expected: %v", resp.Data["version"], "v1.0.0") + } } func TestSystemBackend_policyList(t *testing.T) {