diff --git a/command/audit_disable.go b/command/audit_disable.go index 31c4457287..6bb9f68ff9 100644 --- a/command/audit_disable.go +++ b/command/audit_disable.go @@ -4,68 +4,87 @@ import ( "fmt" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*AuditDisableCommand)(nil) +var _ cli.CommandAutocomplete = (*AuditDisableCommand)(nil) + // AuditDisableCommand is a Command that mounts a new mount. type AuditDisableCommand struct { - meta.Meta -} - -func (c *AuditDisableCommand) Run(args []string) int { - flags := c.Meta.FlagSet("mount", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { - return 1 - } - - args = flags.Args() - if len(args) != 1 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\naudit-disable expects one argument: the id to disable")) - return 1 - } - - id := args[0] - - client, err := c.Client() - if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) - return 2 - } - - if err := client.Sys().DisableAudit(id); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error disabling audit backend: %s", err)) - return 2 - } - - c.Ui.Output(fmt.Sprintf( - "Successfully disabled audit backend '%s' if it was enabled", id)) - - return 0 + *BaseCommand } func (c *AuditDisableCommand) Synopsis() string { - return "Disable an audit backend" + return "Disables an audit backend" } func (c *AuditDisableCommand) Help() string { helpText := ` -Usage: vault audit-disable [options] id +Usage: vault audit-disable [options] PATH - Disable an audit backend. + Disables an audit backend. Once an audit backend is disabled, no future + audit logs are dispatched to it. The data associated with the audit backend + is not affected. - Once the audit backend is disabled no more audit logs will be sent to - it. The data associated with the audit backend isn't affected. + The argument corresponds to the PATH of the mount, not the TYPE! - The "id" parameter should map to the "path" used in "audit-enable". If - no path was provided to "audit-enable" you should use the backend - type (e.g. "file"). + Disable the audit backend at file/: + + $ vault audit-disable file/ + +` + c.Flags().Help() -General Options: -` + meta.GeneralOptionsUsage() return strings.TrimSpace(helpText) } + +func (c *AuditDisableCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *AuditDisableCommand) AutocompleteArgs() complete.Predictor { + return c.PredictVaultAudits() +} + +func (c *AuditDisableCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() +} + +func (c *AuditDisableCommand) Run(args []string) int { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) + return 1 + } + + args = f.Args() + path, kvs, err := extractPath(args) + if err != nil { + c.UI.Error(err.Error()) + return 1 + } + path = ensureTrailingSlash(path) + + if len(kvs) > 0 { + c.UI.Error(fmt.Sprintf("Too many arguments (expected 1, got %d)", len(args))) + return 1 + } + + client, err := c.Client() + if err != nil { + c.UI.Error(err.Error()) + return 2 + } + + if err := client.Sys().DisableAudit(path); err != nil { + c.UI.Error(fmt.Sprintf("Error disabling audit backend: %s", err)) + return 2 + } + + c.UI.Output(fmt.Sprintf("Success! Disabled audit backend (if it was enabled) at: %s", path)) + + return 0 +} diff --git a/command/audit_disable_test.go b/command/audit_disable_test.go index 500ee9ccb1..6980179abd 100644 --- a/command/audit_disable_test.go +++ b/command/audit_disable_test.go @@ -1,86 +1,160 @@ package command import ( + "strings" "testing" "github.com/hashicorp/vault/api" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestAuditDisable(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testAuditDisableCommand(tb testing.TB) (*cli.MockUi, *AuditDisableCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &AuditDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &AuditDisableCommand{ + BaseCommand: &BaseCommand{ + UI: ui, }, } - - args := []string{ - "-address", addr, - "noop", - } - - // Run once to get the client - c.Run(args) - - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } - if err := client.Sys().EnableAudit("noop", "noop", "", nil); err != nil { - t.Fatalf("err: %#v", err) - } - - // Run again - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } } -func TestAuditDisableWithOptions(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func TestAuditDisableCommand_Run(t *testing.T) { + t.Parallel() - ui := new(cli.MockUi) - c := &AuditDisableCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + cases := []struct { + name string + args []string + out string + code int + }{ + { + "empty", + nil, + "Missing PATH!", + 1, + }, + { + "slash", + []string{"/"}, + "Missing PATH!", + 1, + }, + { + "not_real", + []string{"not_real"}, + "Success! Disabled audit backend (if it was enabled) at: not_real/", + 0, + }, + { + "default", + []string{"file"}, + "Success! Disabled audit backend (if it was enabled) at: file/", + 0, }, } - args := []string{ - "-address", addr, - "noop", + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + if err := client.Sys().EnableAuditWithOptions("file", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": "discard", + }, + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuditDisableCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) } - // Run once to get the client - c.Run(args) + t.Run("integration", func(t *testing.T) { + t.Parallel() - // Get the client - client, err := c.Client() - if err != nil { - t.Fatalf("err: %#v", err) - } - if err := client.Sys().EnableAuditWithOptions("noop", &api.EnableAuditOptions{ - Type: "noop", - Description: "noop", - }); err != nil { - t.Fatalf("err: %#v", err) - } + client, closer := testVaultServer(t) + defer closer() - // Run again - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + if err := client.Sys().EnableAuditWithOptions("integration_audit_disable", &api.EnableAuditOptions{ + Type: "file", + Options: map[string]string{ + "file_path": "discard", + }, + }); err != nil { + t.Fatal(err) + } + + ui, cmd := testAuditDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "integration_audit_disable/", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Disabled audit backend (if it was enabled) at: integration_audit_disable/" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + mounts, err := client.Sys().ListMounts() + if err != nil { + t.Fatal(err) + } + + if _, ok := mounts["integration_audit_disable"]; ok { + t.Errorf("expected mount to not exist: %#v", mounts) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testAuditDisableCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "file", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error disabling audit backend: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testAuditDisableCommand(t) + assertNoTabs(t, cmd) + }) }