diff --git a/api/sys_plugins.go b/api/sys_plugins.go index dc279a9dd8..989c78f1d5 100644 --- a/api/sys_plugins.go +++ b/api/sys_plugins.go @@ -133,7 +133,8 @@ type GetPluginInput struct { Name string `json:"-"` // Type of the plugin. Required. - Type consts.PluginType `json:"type"` + Type consts.PluginType `json:"type"` + Version string `json:"version"` } // GetPluginResponse is the response from the GetPlugin call. @@ -144,6 +145,7 @@ type GetPluginResponse struct { Name string `json:"name"` SHA256 string `json:"sha256"` DeprecationStatus string `json:"deprecation_status,omitempty"` + Version string `json:"version,omitempty"` } // GetPlugin wraps GetPluginWithContext using context.Background. @@ -158,6 +160,9 @@ func (c *Sys) GetPluginWithContext(ctx context.Context, i *GetPluginInput) (*Get path := catalogPathByType(i.Type, i.Name) req := c.c.NewRequest(http.MethodGet, path) + if i.Version != "" { + req.Params.Set("version", i.Version) + } resp, err := c.c.rawRequestWithContext(ctx, req) if err != nil { diff --git a/api/sys_plugins_test.go b/api/sys_plugins_test.go index 14ba98b349..b3b94d7302 100644 --- a/api/sys_plugins_test.go +++ b/api/sys_plugins_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "reflect" "testing" "github.com/hashicorp/vault/sdk/helper/consts" @@ -115,6 +116,141 @@ func TestListPlugins(t *testing.T) { } } +func TestGetPlugin(t *testing.T) { + for name, tc := range map[string]struct { + version string + body string + expected GetPluginResponse + }{ + "builtin": { + body: getResponse, + expected: GetPluginResponse{ + Args: nil, + Builtin: true, + Command: "", + Name: "azure", + SHA256: "", + DeprecationStatus: "supported", + Version: "v0.14.0+builtin", + }, + }, + "external": { + version: "v1.0.0", + body: getResponseExternal, + expected: GetPluginResponse{ + Args: []string{}, + Builtin: false, + Command: "azure-plugin", + Name: "azure", + SHA256: "8ba442dba253803685b05e35ad29dcdebc48dec16774614aa7a4ebe53c1e90e1", + DeprecationStatus: "", + Version: "v1.0.0", + }, + }, + "old server": { + body: getResponseOldServerVersion, + expected: GetPluginResponse{ + Args: nil, + Builtin: true, + Command: "", + Name: "azure", + SHA256: "", + DeprecationStatus: "", + Version: "", + }, + }, + } { + t.Run(name, func(t *testing.T) { + mockVaultServer := httptest.NewServer(http.HandlerFunc(mockVaultHandlerInfo(tc.body))) + defer mockVaultServer.Close() + + cfg := DefaultConfig() + cfg.Address = mockVaultServer.URL + client, err := NewClient(cfg) + if err != nil { + t.Fatal(err) + } + + input := GetPluginInput{ + Name: "azure", + Type: consts.PluginTypeSecrets, + } + if tc.version != "" { + input.Version = tc.version + } + + info, err := client.Sys().GetPluginWithContext(context.Background(), &input) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(tc.expected, *info) { + t.Errorf("expected: %#v\ngot: %#v", tc.expected, info) + } + }) + } +} + +func mockVaultHandlerInfo(body string) func(w http.ResponseWriter, _ *http.Request) { + return func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(body)) + } +} + +const getResponse = `{ + "request_id": "e93d3f93-8e4f-8443-a803-f1c97c495241", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "args": null, + "builtin": true, + "command": "", + "deprecation_status": "supported", + "name": "azure", + "sha256": "", + "version": "v0.14.0+builtin" + }, + "wrap_info": null, + "warnings": null, + "auth": null +}` + +const getResponseExternal = `{ + "request_id": "e93d3f93-8e4f-8443-a803-f1c97c495241", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "args": [], + "builtin": false, + "command": "azure-plugin", + "name": "azure", + "sha256": "8ba442dba253803685b05e35ad29dcdebc48dec16774614aa7a4ebe53c1e90e1", + "version": "v1.0.0" + }, + "wrap_info": null, + "warnings": null, + "auth": null +}` + +const getResponseOldServerVersion = `{ + "request_id": "e93d3f93-8e4f-8443-a803-f1c97c495241", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "args": null, + "builtin": true, + "command": "", + "name": "azure", + "sha256": "" + }, + "wrap_info": null, + "warnings": null, + "auth": null +}` + func mockVaultHandlerList(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte(listUntypedResponse)) } diff --git a/command/plugin_deregister_test.go b/command/plugin_deregister_test.go index 7a6bc12d41..fc3bc5801e 100644 --- a/command/plugin_deregister_test.go +++ b/command/plugin_deregister_test.go @@ -84,7 +84,7 @@ func TestPluginDeregisterCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) + _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential, "") ui, cmd := testPluginDeregisterCommand(t) cmd.client = client diff --git a/command/plugin_info.go b/command/plugin_info.go index bb7a4a5053..8fedb98315 100644 --- a/command/plugin_info.go +++ b/command/plugin_info.go @@ -17,6 +17,8 @@ var ( type PluginInfoCommand struct { *BaseCommand + + flagVersion string } func (c *PluginInfoCommand) Synopsis() string { @@ -41,7 +43,18 @@ Usage: vault plugin info [options] TYPE NAME } func (c *PluginInfoCommand) Flags() *FlagSets { - return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + set := c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + + f := set.NewFlagSet("Command Options") + + f.StringVar(&StringVar{ + Name: "version", + Target: &c.flagVersion, + Completion: complete.PredictAnything, + Usage: "Semantic version of the plugin. Optional.", + }) + + return set } func (c *PluginInfoCommand) AutocompleteArgs() complete.Predictor { @@ -93,8 +106,9 @@ func (c *PluginInfoCommand) Run(args []string) int { pluginName := strings.TrimSpace(pluginNameRaw) resp, err := client.Sys().GetPlugin(&api.GetPluginInput{ - Name: pluginName, - Type: pluginType, + Name: pluginName, + Type: pluginType, + Version: c.flagVersion, }) if err != nil { c.UI.Error(fmt.Sprintf("Error reading plugin named %s: %s", pluginName, err)) @@ -113,6 +127,7 @@ func (c *PluginInfoCommand) Run(args []string) int { "name": resp.Name, "sha256": resp.SHA256, "deprecation_status": resp.DeprecationStatus, + "version": resp.Version, } if c.flagField != "" { diff --git a/command/plugin_info_test.go b/command/plugin_info_test.go index cfdab72ab3..714ac1e59b 100644 --- a/command/plugin_info_test.go +++ b/command/plugin_info_test.go @@ -4,6 +4,7 @@ import ( "strings" "testing" + "github.com/hashicorp/vault/helper/versions" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" @@ -81,7 +82,7 @@ func TestPluginInfoCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) + _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential, "") ui, cmd := testPluginInfoCommand(t) cmd.client = client @@ -102,6 +103,52 @@ func TestPluginInfoCommand_Run(t *testing.T) { } }) + t.Run("version flag", func(t *testing.T) { + t.Parallel() + + pluginDir, cleanup := vault.MakeTestPluginDir(t) + defer cleanup(t) + + client, _, closer := testVaultServerPluginDir(t, pluginDir) + defer closer() + + const pluginName = "azure" + _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential, "v1.0.0") + + for name, tc := range map[string]struct { + version string + expectedSHA string + }{ + "versioned": {"v1.0.0", sha256Sum}, + "builtin version": {versions.GetBuiltinVersion(consts.PluginTypeSecrets, pluginName), ""}, + } { + t.Run(name, func(t *testing.T) { + ui, cmd := testPluginInfoCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "-version=" + tc.version, + consts.PluginTypeCredential.String(), pluginName, + }) + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if exp := 0; code != exp { + t.Errorf("expected %d to be %d: %s", code, exp, combined) + } + + if !strings.Contains(combined, pluginName) { + t.Errorf("expected %q to contain %q", combined, pluginName) + } + if !strings.Contains(combined, tc.expectedSHA) { + t.Errorf("expected %q to contain %q", combined, tc.expectedSHA) + } + if !strings.Contains(combined, tc.version) { + t.Errorf("expected %q to contain %q", combined, tc.version) + } + }) + } + }) + t.Run("field", func(t *testing.T) { t.Parallel() @@ -112,7 +159,7 @@ func TestPluginInfoCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) + testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential, "") ui, cmd := testPluginInfoCommand(t) cmd.client = client diff --git a/command/plugin_reload_test.go b/command/plugin_reload_test.go index 6c4982295b..5713d1a150 100644 --- a/command/plugin_reload_test.go +++ b/command/plugin_reload_test.go @@ -90,7 +90,7 @@ func TestPluginReloadCommand_Run(t *testing.T) { defer closer() pluginName := "my-plugin" - _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential) + _, sha256Sum := testPluginCreateAndRegister(t, client, pluginDir, pluginName, consts.PluginTypeCredential, "") ui, cmd := testPluginReloadCommand(t) cmd.client = client diff --git a/command/plugin_test.go b/command/plugin_test.go index be40abef8e..cc83efc772 100644 --- a/command/plugin_test.go +++ b/command/plugin_test.go @@ -38,7 +38,7 @@ func testPluginCreate(tb testing.TB, dir, name string) (string, string) { } // testPluginCreateAndRegister creates a plugin and registers it in the catalog. -func testPluginCreateAndRegister(tb testing.TB, client *api.Client, dir, name string, pluginType consts.PluginType) (string, string) { +func testPluginCreateAndRegister(tb testing.TB, client *api.Client, dir, name string, pluginType consts.PluginType, version string) (string, string) { tb.Helper() pth, sha256Sum := testPluginCreate(tb, dir, name) @@ -48,6 +48,7 @@ func testPluginCreateAndRegister(tb testing.TB, client *api.Client, dir, name st Type: pluginType, Command: name, SHA256: sha256Sum, + Version: version, }); err != nil { tb.Fatal(err) } diff --git a/vault/logical_system.go b/vault/logical_system.go index 8f5c372fa9..1d1b7de8ba 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -473,10 +473,13 @@ func (b *SystemBackend) handlePluginCatalogUpdate(ctx context.Context, _ *logica return nil, err } - pluginVersion, err := getVersion(d) + pluginVersion, builtin, err := getVersion(d) if err != nil { return logical.ErrorResponse(err.Error()), nil } + if builtin { + return logical.ErrorResponse("version %q is not allowed because 'builtin' is a reserved metadata identifier", pluginVersion), nil + } sha256 := d.Get("sha256").(string) if sha256 == "" { @@ -546,7 +549,7 @@ func (b *SystemBackend) handlePluginCatalogRead(ctx context.Context, _ *logical. return nil, err } - pluginVersion, err := getVersion(d) + pluginVersion, _, err := getVersion(d) if err != nil { return logical.ErrorResponse(err.Error()), nil } @@ -592,10 +595,13 @@ func (b *SystemBackend) handlePluginCatalogDelete(ctx context.Context, _ *logica return logical.ErrorResponse("missing plugin name"), nil } - pluginVersion, err := getVersion(d) + pluginVersion, builtin, err := getVersion(d) if err != nil { return logical.ErrorResponse(err.Error()), nil } + if builtin { + return logical.ErrorResponse("version %q cannot be deleted", pluginVersion), nil + } var resp *logical.Response pluginTypeStr := d.Get("type").(string) @@ -620,18 +626,19 @@ func (b *SystemBackend) handlePluginCatalogDelete(ctx context.Context, _ *logica return resp, nil } -func getVersion(d *framework.FieldData) (string, error) { - version := d.Get("version").(string) +func getVersion(d *framework.FieldData) (version string, builtin bool, err error) { + version = d.Get("version").(string) if version != "" { semanticVersion, err := semver.NewSemver(version) if err != nil { - return "", fmt.Errorf("version %q is not a valid semantic version: %w", version, err) + return "", false, fmt.Errorf("version %q is not a valid semantic version: %w", version, err) } metadataIdentifiers := strings.Split(semanticVersion.Metadata(), ".") for _, identifier := range metadataIdentifiers { if identifier == "builtin" { - return "", fmt.Errorf("version %q is not allowed because 'builtin' is a reserved metadata identifier", version) + builtin = true + break } } @@ -640,7 +647,7 @@ func getVersion(d *framework.FieldData) (string, error) { version = "v" + semanticVersion.String() } - return version, nil + return version, builtin, nil } func (b *SystemBackend) handlePluginReloadUpdate(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {