diff --git a/command/mounts.go b/command/mounts.go index 403d9d4d91..b25bfe4dcb 100644 --- a/command/mounts.go +++ b/command/mounts.go @@ -6,93 +6,165 @@ import ( "strconv" "strings" - "github.com/hashicorp/vault/meta" - "github.com/ryanuber/columnize" + "github.com/hashicorp/vault/api" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*MountsCommand)(nil) +var _ cli.CommandAutocomplete = (*MountsCommand)(nil) + // MountsCommand is a Command that lists the mounts. type MountsCommand struct { - meta.Meta -} + *BaseCommand -func (c *MountsCommand) Run(args []string) int { - flags := c.Meta.FlagSet("mounts", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - mounts, err := client.Sys().ListMounts() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading mounts: %s", err)) - return 2 - } - - paths := make([]string, 0, len(mounts)) - for path := range mounts { - paths = append(paths, path) - } - sort.Strings(paths) - - columns := []string{"Path | Type | Accessor | Plugin | Default TTL | Max TTL | Force No Cache | Replication Behavior | Description"} - for _, path := range paths { - mount := mounts[path] - pluginName := "n/a" - if mount.Config.PluginName != "" { - pluginName = mount.Config.PluginName - } - defTTL := "system" - switch { - case mount.Type == "system", mount.Type == "cubbyhole", mount.Type == "identity": - defTTL = "n/a" - case mount.Config.DefaultLeaseTTL != 0: - defTTL = strconv.Itoa(mount.Config.DefaultLeaseTTL) - } - - maxTTL := "system" - switch { - case mount.Type == "system", mount.Type == "cubbyhole", mount.Type == "identity": - maxTTL = "n/a" - case mount.Config.MaxLeaseTTL != 0: - maxTTL = strconv.Itoa(mount.Config.MaxLeaseTTL) - } - - replicatedBehavior := "replicated" - if mount.Local { - replicatedBehavior = "local" - } - columns = append(columns, fmt.Sprintf( - "%s | %s | %s | %s | %s | %s | %v | %s | %s", path, mount.Type, mount.Accessor, pluginName, defTTL, maxTTL, - mount.Config.ForceNoCache, replicatedBehavior, mount.Description)) - } - - c.Ui.Output(columnize.SimpleFormat(columns)) - return 0 + flagDetailed bool } func (c *MountsCommand) Synopsis() string { - return "Lists mounted backends in Vault" + return "Lists mounted secret backends" } func (c *MountsCommand) Help() string { helpText := ` Usage: vault mounts [options] - Outputs information about the mounted backends. + Lists the mounted secret backends on the Vault server. This command also + outputs information about the mount point including configured TTLs and + human-friendly descriptions. A TTL of "system" indicates that the system + default is in use. - This command lists the mounted backends, their mount points, the - configured TTLs, and a human-friendly description of the mount point. - A TTL of 'system' indicates that the system default is being used. + List all mounts: + + $ vault mounts + + List all mounts with detailed output: + + $ vault mounts -detailed + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *MountsCommand) Flags() *FlagSets { + set := c.flagSet(FlagSetHTTP) + + f := set.NewFlagSet("Command Options") + + f.BoolVar(&BoolVar{ + Name: "detailed", + Target: &c.flagDetailed, + Default: false, + Usage: "Print detailed information such as TTLs and replication status " + + "about each mount.", + }) + + return set +} + +func (c *MountsCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultFiles() +} + +func (c *MountsCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *MountsCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + if len(args) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 0, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + c.UI.Error(fmt.Sprintf("Error listing mounts: %s", err)) + return 2 + } + + if c.flagDetailed { + c.UI.Output(tableOutput(c.detailedMounts(mounts))) + return 0 + } + + c.UI.Output(tableOutput(c.simpleMounts(mounts))) + return 0 +} + +func (c *MountsCommand) simpleMounts(mounts map[string]*api.MountOutput) []string { + paths := make([]string, 0, len(mounts)) + for path := range mounts { + paths = append(paths, path) + } + sort.Strings(paths) + + out := []string{"Path | Type | Description"} + for _, path := range paths { + mount := mounts[path] + out = append(out, fmt.Sprintf("%s | %s | %s", path, mount.Type, mount.Description)) + } + + return out +} + +func (c *MountsCommand) detailedMounts(mounts map[string]*api.MountOutput) []string { + paths := make([]string, 0, len(mounts)) + for path := range mounts { + paths = append(paths, path) + } + sort.Strings(paths) + + calcTTL := func(typ string, ttl int) string { + switch { + case typ == "system", typ == "cubbyhole": + return "" + case ttl != 0: + return strconv.Itoa(ttl) + default: + return "system" + } + } + + out := []string{"Path | Type | Accessor | Plugin | Default TTL | Max TTL | Force No Cache | Replication | Description"} + for _, path := range paths { + mount := mounts[path] + + defaultTTL := calcTTL(mount.Type, mount.Config.DefaultLeaseTTL) + maxTTL := calcTTL(mount.Type, mount.Config.MaxLeaseTTL) + + replication := "replicated" + if mount.Local { + replication = "local" + } + + out = append(out, fmt.Sprintf("%s | %s | %s | %s | %s | %s | %v | %s | %s", + path, + mount.Type, + mount.Accessor, + mount.Config.PluginName, + defaultTTL, + maxTTL, + mount.Config.ForceNoCache, + replication, + mount.Description, + )) + } + + return out +} diff --git a/command/mounts_test.go b/command/mounts_test.go index 55e5f679f6..8f1c8b9c85 100644 --- a/command/mounts_test.go +++ b/command/mounts_test.go @@ -1,31 +1,105 @@ package command import ( + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestMounts(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testMountsCommand(tb testing.TB) (*cli.MockUi, *MountsCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &MountsCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &MountsCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func TestMountsCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo"}, + "Too many arguments", + 1, + }, + { + "lists", + nil, + "Path", + 0, + }, + { + "detailed", + []string{"-detailed"}, + "Default TTL", + 0, }, } - args := []string{ - "-address", addr, - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testMountsCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testMountsCommand(t) + cmd.client = client + + code := cmd.Run([]string{}) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error listing mounts: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testMountsCommand(t) + assertNoTabs(t, cmd) + }) }