diff --git a/api/sys_mounts.go b/api/sys_mounts.go index ddaddaf475..75173b4d8f 100644 --- a/api/sys_mounts.go +++ b/api/sys_mounts.go @@ -247,7 +247,6 @@ type MountInput struct { SealWrap bool `json:"seal_wrap" mapstructure:"seal_wrap"` ExternalEntropyAccess bool `json:"external_entropy_access" mapstructure:"external_entropy_access"` Options map[string]string `json:"options"` - PluginVersion string `json:"plugin_version,omitempty"` // Deprecated: Newer server responses should be returning this information in the // Type field (json: "type") instead. @@ -267,6 +266,7 @@ type MountConfigInput struct { AllowedResponseHeaders []string `json:"allowed_response_headers,omitempty" mapstructure:"allowed_response_headers"` TokenType string `json:"token_type,omitempty" mapstructure:"token_type"` AllowedManagedKeys []string `json:"allowed_managed_keys,omitempty" mapstructure:"allowed_managed_keys"` + PluginVersion string `json:"plugin_version,omitempty"` // Deprecated: This field will always be blank for newer server responses. PluginName string `json:"plugin_name,omitempty" mapstructure:"plugin_name"` diff --git a/command/auth_enable.go b/command/auth_enable.go index 4214bf3f70..33ec9291ba 100644 --- a/command/auth_enable.go +++ b/command/auth_enable.go @@ -201,7 +201,7 @@ func (c *AuthEnableCommand) Flags() *FlagSets { }) f.StringVar(&StringVar{ - Name: "plugin-version", + Name: flagNamePluginVersion, Target: &c.flagPluginVersion, Default: "", Usage: "Select the semantic version of the plugin to enable.", @@ -270,7 +270,6 @@ func (c *AuthEnableCommand) Run(args []string) int { authOpts := &api.EnableAuthOptions{ Type: authType, - PluginVersion: c.flagPluginVersion, Description: c.flagDescription, Local: c.flagLocal, SealWrap: c.flagSealWrap, @@ -307,6 +306,10 @@ func (c *AuthEnableCommand) Run(args []string) int { if fl.Name == flagNameTokenType { authOpts.Config.TokenType = c.flagTokenType } + + if fl.Name == flagNamePluginVersion { + authOpts.Config.PluginVersion = c.flagPluginVersion + } }) if err := client.Sys().EnableAuthWithOptions(authPath, authOpts); err != nil { diff --git a/command/auth_tune.go b/command/auth_tune.go index 9c3a963efc..de4c198273 100644 --- a/command/auth_tune.go +++ b/command/auth_tune.go @@ -31,6 +31,7 @@ type AuthTuneCommand struct { flagOptions map[string]string flagTokenType string flagVersion int + flagPluginVersion string } func (c *AuthTuneCommand) Synopsis() string { @@ -144,6 +145,14 @@ func (c *AuthTuneCommand) Flags() *FlagSets { Usage: "Select the version of the auth method to run. Not supported by all auth methods.", }) + f.StringVar(&StringVar{ + Name: flagNamePluginVersion, + Target: &c.flagPluginVersion, + Default: "", + Usage: "Select the semantic version of the plugin to run. The new version must be registered in " + + "the plugin catalog, and will not start running until the plugin is reloaded.", + }) + return set } @@ -221,6 +230,10 @@ func (c *AuthTuneCommand) Run(args []string) int { if fl.Name == flagNameTokenType { mountConfigInput.TokenType = c.flagTokenType } + + if fl.Name == flagNamePluginVersion { + mountConfigInput.PluginVersion = c.flagPluginVersion + } }) // Append /auth (since that's where auths live) and a trailing slash to diff --git a/command/auth_tune_test.go b/command/auth_tune_test.go index 227330ea77..635a70f44b 100644 --- a/command/auth_tune_test.go +++ b/command/auth_tune_test.go @@ -6,6 +6,8 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) @@ -74,7 +76,10 @@ func TestAuthTuneCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Run("flags_all", func(t *testing.T) { t.Parallel() - client, closer := testVaultServer(t) + pluginDir, cleanup := vault.MakeTestPluginDir(t) + defer cleanup(t) + + client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() ui, cmd := testAuthTuneCommand(t) @@ -87,6 +92,21 @@ func TestAuthTuneCommand_Run(t *testing.T) { t.Fatal(err) } + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + mountInfo, ok := auths["my-auth/"] + if !ok { + t.Fatalf("expected mount to exist: %#v", auths) + } + + if exp := ""; mountInfo.PluginVersion != exp { + t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp) + } + + _, _, version := testPluginCreateAndRegisterVersioned(t, client, pluginDir, "userpass", consts.PluginTypeCredential) + code := cmd.Run([]string{ "-description", "new description", "-default-lease-ttl", "30m", @@ -97,6 +117,7 @@ func TestAuthTuneCommand_Run(t *testing.T) { "-passthrough-request-headers", "www-authentication", "-allowed-response-headers", "authorization,www-authentication", "-listing-visibility", "unauth", + "-plugin-version", version, "my-auth/", }) if exp := 0; code != exp { @@ -109,12 +130,12 @@ func TestAuthTuneCommand_Run(t *testing.T) { t.Errorf("expected %q to contain %q", combined, expected) } - auths, err := client.Sys().ListAuth() + auths, err = client.Sys().ListAuth() if err != nil { t.Fatal(err) } - mountInfo, ok := auths["my-auth/"] + mountInfo, ok = auths["my-auth/"] if !ok { t.Fatalf("expected auth to exist") } @@ -124,6 +145,9 @@ func TestAuthTuneCommand_Run(t *testing.T) { if exp := "userpass"; mountInfo.Type != exp { t.Errorf("expected %q to be %q", mountInfo.Type, exp) } + if exp := version; mountInfo.PluginVersion != exp { + t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp) + } if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp { t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp) } diff --git a/command/commands.go b/command/commands.go index f6aad476d1..d4ce6b6ca3 100644 --- a/command/commands.go +++ b/command/commands.go @@ -124,6 +124,8 @@ const ( flagNameTokenType = "token-type" // flagNameAllowedManagedKeys is the flag name used for auth/secrets enable flagNameAllowedManagedKeys = "allowed-managed-keys" + // flagNamePluginVersion selects what version of a plugin should be used. + flagNamePluginVersion = "plugin-version" ) var ( diff --git a/command/plugin_register_test.go b/command/plugin_register_test.go index 69031c4691..c2047d070f 100644 --- a/command/plugin_register_test.go +++ b/command/plugin_register_test.go @@ -1,6 +1,8 @@ package command import ( + "reflect" + "sort" "strings" "testing" @@ -124,6 +126,75 @@ func TestPluginRegisterCommand_Run(t *testing.T) { } }) + t.Run("integration with version", func(t *testing.T) { + t.Parallel() + + pluginDir, cleanup := vault.MakeTestPluginDir(t) + defer cleanup(t) + + client, _, closer := testVaultServerPluginDir(t, pluginDir) + defer closer() + + const pluginName = "my-plugin" + versions := []string{"v1.0.0", "v2.0.1"} + _, sha256Sum := testPluginCreate(t, pluginDir, pluginName) + types := []consts.PluginType{consts.PluginTypeCredential, consts.PluginTypeDatabase, consts.PluginTypeSecrets} + + for _, typ := range types { + for _, version := range versions { + ui, cmd := testPluginRegisterCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-version=" + version, + "-sha256=" + sha256Sum, + typ.String(), + pluginName, + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Registered plugin: my-plugin" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + } + } + + resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{ + Type: consts.PluginTypeUnknown, + }) + if err != nil { + t.Fatal(err) + } + + found := make(map[consts.PluginType]int) + versionsFound := make(map[consts.PluginType][]string) + for _, p := range resp.Details { + if p.Name == pluginName { + typ, err := consts.ParsePluginType(p.Type) + if err != nil { + t.Fatal(err) + } + found[typ]++ + versionsFound[typ] = append(versionsFound[typ], p.Version) + } + } + + for _, typ := range types { + if found[typ] != 2 { + t.Fatalf("expected %q to be found 2 times, but found it %d times for %s type in %#v", pluginName, found[typ], typ.String(), resp.Details) + } + sort.Strings(versions) + sort.Strings(versionsFound[typ]) + if !reflect.DeepEqual(versions, versionsFound[typ]) { + t.Fatalf("expected %v versions but got %v", versions, versionsFound[typ]) + } + } + }) + t.Run("communication_failure", func(t *testing.T) { t.Parallel() diff --git a/command/secrets_enable.go b/command/secrets_enable.go index 91144c1566..4b8d31378e 100644 --- a/command/secrets_enable.go +++ b/command/secrets_enable.go @@ -32,6 +32,7 @@ type SecretsEnableCommand struct { flagAllowedResponseHeaders []string flagForceNoCache bool flagPluginName string + flagPluginVersion string flagOptions map[string]string flagLocal bool flagSealWrap bool @@ -173,6 +174,13 @@ func (c *SecretsEnableCommand) Flags() *FlagSets { "exist in Vault's plugin catalog.", }) + f.StringVar(&StringVar{ + Name: flagNamePluginVersion, + Target: &c.flagPluginVersion, + Default: "", + Usage: "Select the semantic version of the plugin to enable.", + }) + f.StringMapVar(&StringMapVar{ Name: "options", Target: &c.flagOptions, @@ -320,6 +328,10 @@ func (c *SecretsEnableCommand) Run(args []string) int { if fl.Name == flagNameAllowedManagedKeys { mountInput.Config.AllowedManagedKeys = c.flagAllowedManagedKeys } + + if fl.Name == flagNamePluginVersion { + mountInput.Config.PluginVersion = c.flagPluginVersion + } }) if err := client.Sys().Mount(mountPath, mountInput); err != nil { diff --git a/command/secrets_tune.go b/command/secrets_tune.go index 002bdd9ea0..bf8fa3d593 100644 --- a/command/secrets_tune.go +++ b/command/secrets_tune.go @@ -30,6 +30,7 @@ type SecretsTuneCommand struct { flagAllowedResponseHeaders []string flagOptions map[string]string flagVersion int + flagPluginVersion string flagAllowedManagedKeys []string } @@ -146,6 +147,14 @@ func (c *SecretsTuneCommand) Flags() *FlagSets { "each time with 1 key.", }) + f.StringVar(&StringVar{ + Name: flagNamePluginVersion, + Target: &c.flagPluginVersion, + Default: "", + Usage: "Select the semantic version of the plugin to run. The new version must be registered in " + + "the plugin catalog, and will not start running until the plugin is reloaded.", + }) + return set } @@ -226,6 +235,10 @@ func (c *SecretsTuneCommand) Run(args []string) int { if fl.Name == flagNameAllowedManagedKeys { mountConfigInput.AllowedManagedKeys = c.flagAllowedManagedKeys } + + if fl.Name == flagNamePluginVersion { + mountConfigInput.PluginVersion = c.flagPluginVersion + } }) if err := client.Sys().TuneMount(mountPath, mountConfigInput); err != nil { diff --git a/command/secrets_tune_test.go b/command/secrets_tune_test.go index f51b8fb34b..41c6bd2f6f 100644 --- a/command/secrets_tune_test.go +++ b/command/secrets_tune_test.go @@ -6,6 +6,8 @@ import ( "github.com/go-test/deep" "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/sdk/helper/consts" + "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) @@ -148,7 +150,10 @@ func TestSecretsTuneCommand_Run(t *testing.T) { t.Run("integration", func(t *testing.T) { t.Run("flags_all", func(t *testing.T) { t.Parallel() - client, closer := testVaultServer(t) + pluginDir, cleanup := vault.MakeTestPluginDir(t) + defer cleanup(t) + + client, _, closer := testVaultServerPluginDir(t, pluginDir) defer closer() ui, cmd := testSecretsTuneCommand(t) @@ -161,6 +166,21 @@ func TestSecretsTuneCommand_Run(t *testing.T) { t.Fatal(err) } + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + mountInfo, ok := mounts["mount_tune_integration/"] + if !ok { + t.Fatalf("expected mount to exist") + } + + if exp := ""; mountInfo.PluginVersion != exp { + t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp) + } + + _, _, version := testPluginCreateAndRegisterVersioned(t, client, pluginDir, "pki", consts.PluginTypeSecrets) + code := cmd.Run([]string{ "-description", "new description", "-default-lease-ttl", "30m", @@ -172,6 +192,7 @@ func TestSecretsTuneCommand_Run(t *testing.T) { "-allowed-response-headers", "authorization,www-authentication", "-allowed-managed-keys", "key1,key2", "-listing-visibility", "unauth", + "-plugin-version", version, "mount_tune_integration/", }) if exp := 0; code != exp { @@ -184,12 +205,12 @@ func TestSecretsTuneCommand_Run(t *testing.T) { t.Errorf("expected %q to contain %q", combined, expected) } - mounts, err := client.Sys().ListMounts() + mounts, err = client.Sys().ListMounts() if err != nil { t.Fatal(err) } - mountInfo, ok := mounts["mount_tune_integration/"] + mountInfo, ok = mounts["mount_tune_integration/"] if !ok { t.Fatalf("expected mount to exist") } @@ -199,6 +220,9 @@ func TestSecretsTuneCommand_Run(t *testing.T) { if exp := "pki"; mountInfo.Type != exp { t.Errorf("expected %q to be %q", mountInfo.Type, exp) } + if exp := version; mountInfo.PluginVersion != exp { + t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp) + } if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp { t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp) } diff --git a/vault/external_plugin_test.go b/vault/external_plugin_test.go index 3316c7475c..66f46e6841 100644 --- a/vault/external_plugin_test.go +++ b/vault/external_plugin_test.go @@ -318,8 +318,10 @@ func TestCore_EnableExternalCredentialPlugin_NoVersionOnRegister(t *testing.T) { req := logical.TestRequest(t, logical.UpdateOperation, mountTable(tc.pluginType)) req.Data = map[string]interface{}{ - "type": pluginName, - "plugin_version": "v1.0.0", + "type": pluginName, + "config": map[string]interface{}{ + "plugin_version": "v1.0.0", + }, } resp, _ := c.systemBackend.HandleRequest(namespace.RootContext(nil), req) if resp == nil || !resp.IsError() || !strings.Contains(resp.Error().Error(), ErrPluginNotFound.Error()) { @@ -379,22 +381,7 @@ func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { } { t.Run(name, func(t *testing.T) { c, pluginName, pluginSHA256 := testCoreWithPlugin(t, tc.pluginType, tc.setRunningVersion) - d := &framework.FieldData{ - Raw: map[string]interface{}{ - "name": pluginName, - "sha256": pluginSHA256, - "version": tc.setRunningVersion, - "command": pluginName, - }, - Schema: c.systemBackend.pluginsCatalogCRUDPath().Fields, - } - resp, err := c.systemBackend.handlePluginCatalogUpdate(context.Background(), nil, d) - if err != nil { - t.Fatal(err) - } - if resp.Error() != nil { - t.Fatalf("%#v", resp) - } + registerPlugin(t, c.systemBackend, pluginName, tc.pluginType.String(), tc.setRunningVersion, pluginSHA256) shaBytes, _ := hex.DecodeString(pluginSHA256) commandFull := filepath.Join(c.pluginCatalog.directory, pluginName) @@ -407,6 +394,7 @@ func TestExternalPlugin_getBackendTypeVersion(t *testing.T) { } var version logical.PluginVersion + var err error if tc.pluginType == consts.PluginTypeDatabase { version, err = c.pluginCatalog.getDatabaseRunningVersion(context.Background(), entry) } else { @@ -447,7 +435,9 @@ func mountPlugin(t *testing.T, sys *SystemBackend, pluginName string, pluginType "type": pluginName, } if version != "" { - req.Data["plugin_version"] = version + req.Data["config"] = map[string]interface{}{ + "plugin_version": version, + } } resp, err := sys.HandleRequest(namespace.RootContext(nil), req) if err != nil { diff --git a/vault/logical_system.go b/vault/logical_system.go index 3a32a664e9..45669404e7 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1001,10 +1001,6 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d sealWrap := data.Get("seal_wrap").(bool) externalEntropyAccess := data.Get("external_entropy_access").(bool) options := data.Get("options").(map[string]string) - var version string - if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok { - version = pluginVersionRaw.(string) - } var config MountConfig var apiConfig APIMountConfig @@ -1110,6 +1106,7 @@ func (b *SystemBackend) handleMount(ctx context.Context, req *logical.Request, d } } + version := apiConfig.PluginVersion switch version { case "": var err error @@ -2349,10 +2346,6 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque sealWrap := data.Get("seal_wrap").(bool) externalEntropyAccess := data.Get("external_entropy_access").(bool) options := data.Get("options").(map[string]string) - var version string - if pluginVersionRaw, ok := data.GetOk("plugin_version"); ok { - version = pluginVersionRaw.(string) - } var config MountConfig var apiConfig APIMountConfig @@ -2446,6 +2439,7 @@ func (b *SystemBackend) handleEnableAuth(ctx context.Context, req *logical.Reque } } + version := apiConfig.PluginVersion switch version { case "": var err error diff --git a/vault/mount.go b/vault/mount.go index 69c8ae5a64..b369716022 100644 --- a/vault/mount.go +++ b/vault/mount.go @@ -368,6 +368,7 @@ type APIMountConfig struct { AllowedResponseHeaders []string `json:"allowed_response_headers,omitempty" structs:"allowed_response_headers" mapstructure:"allowed_response_headers"` TokenType string `json:"token_type" structs:"token_type" mapstructure:"token_type"` AllowedManagedKeys []string `json:"allowed_managed_keys,omitempty" mapstructure:"allowed_managed_keys"` + PluginVersion string `json:"plugin_version,omitempty" mapstructure:"plugin_version"` // PluginName is the name of the plugin registered in the catalog. // diff --git a/vault/plugin_catalog.go b/vault/plugin_catalog.go index 329a02b8e4..14c6e39460 100644 --- a/vault/plugin_catalog.go +++ b/vault/plugin_catalog.go @@ -922,6 +922,7 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug // Users don't expect to see the plugin type, so we need to strip that here. var normalizedName, version string var semanticVersion *semver.Version + storedType := consts.PluginTypeUnknown parts := strings.Split(plugin, "/") switch len(parts) { @@ -933,7 +934,7 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug return nil, err } case 2: // Unversioned - if isPluginType(parts[0]) { + if storedType, err = consts.ParsePluginType(parts[0]); err == nil { normalizedName = parts[1] // Use 0.0.0 to ensure unversioned is sorted as the oldest version. semanticVersion, err = semver.NewVersion("0.0.0") @@ -941,13 +942,17 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug return nil, err } } else { - return nil, fmt.Errorf("unknown plugin type in plugin catalog: %s", plugin) + return nil, fmt.Errorf("unknown plugin type in plugin catalog: %s: %w", plugin, err) } case 3: // Versioned, with type if !includeVersioned { continue } + storedType, err = consts.ParsePluginType(parts[0]) + if err != nil { + return nil, fmt.Errorf("unexpected error parsing plugin type from plugin catalog entry %q: %w", plugin, err) + } normalizedName, version = parts[1], parts[2] semanticVersion, err = semver.NewVersion(version) if err != nil { @@ -958,18 +963,24 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug } // Only list user-added plugins if they're of the given type. - if entry, err := c.get(ctx, normalizedName, pluginType, version); err == nil && entry != nil { - result = append(result, pluginutil.VersionedPlugin{ - Name: normalizedName, - Type: pluginType.String(), - Version: version, - SHA256: hex.EncodeToString(entry.Sha256), - SemanticVersion: semanticVersion, - }) + if storedType != consts.PluginTypeUnknown && storedType != pluginType { + continue + } + entry, err := c.get(ctx, normalizedName, pluginType, version) + if err != nil || entry == nil { + continue + } - if version == "" { - unversionedPlugins[normalizedName] = struct{}{} - } + result = append(result, pluginutil.VersionedPlugin{ + Name: normalizedName, + Type: pluginType.String(), + Version: version, + SHA256: hex.EncodeToString(entry.Sha256), + SemanticVersion: semanticVersion, + }) + + if version == "" { + unversionedPlugins[normalizedName] = struct{}{} } } @@ -999,8 +1010,3 @@ func (c *PluginCatalog) listInternal(ctx context.Context, pluginType consts.Plug return result, nil } - -func isPluginType(s string) bool { - _, err := consts.ParsePluginType(s) - return err == nil -}