From 0d598a7f1ecda9591983e09b24297ee6ef0c49da Mon Sep 17 00:00:00 2001 From: Seth Vargo Date: Tue, 5 Sep 2017 00:03:29 -0400 Subject: [PATCH] Update policy-write command --- command/policy_write.go | 137 ++++++++++++++-------- command/policy_write_test.go | 214 +++++++++++++++++++++++++++++++---- 2 files changed, 283 insertions(+), 68 deletions(-) diff --git a/command/policy_write.go b/command/policy_write.go index 59b26fb472..73979458c6 100644 --- a/command/policy_write.go +++ b/command/policy_write.go @@ -7,84 +7,125 @@ import ( "os" "strings" - "github.com/hashicorp/vault/meta" + "github.com/mitchellh/cli" + "github.com/posener/complete" ) -// PolicyWriteCommand is a Command that enables a new endpoint. +// Ensure we are implementing the right interfaces. +var _ cli.Command = (*PolicyWriteCommand)(nil) +var _ cli.CommandAutocomplete = (*PolicyWriteCommand)(nil) + +// PolicyWriteCommand is a Command uploads a policy type PolicyWriteCommand struct { - meta.Meta + *BaseCommand + + testStdin io.Reader // for tests +} + +func (c *PolicyWriteCommand) Synopsis() string { + return "Uploads a policy file" +} + +func (c *PolicyWriteCommand) Help() string { + helpText := ` +Usage: vault policy-write [options] NAME PATH + + Uploads a policy with the given name from the contents of a local file or + stdin. If the path is "-", the policy is read from stdin. Otherwise, it is + loaded from the file at the given path. + + Upload a policy named "my-policy" from /tmp/policy.hcl on the local disk: + + $ vault policy-write my-policy /tmp/policy.hcl + + Upload a policy from stdin: + + $ cat my-policy.hcl | vault policy-write my-policy - + + For a full list of examples, please see the documentation. + +` + c.Flags().Help() + + return strings.TrimSpace(helpText) +} + +func (c *PolicyWriteCommand) Flags() *FlagSets { + return c.flagSet(FlagSetHTTP) +} + +func (c *PolicyWriteCommand) AutocompleteArgs() complete.Predictor { + return complete.PredictFunc(func(args complete.Args) []string { + // Predict the LAST argument hcl files - we don't want to predict the + // name argument as a filepath. + if len(args.All) == 3 { + return complete.PredictFiles("*.hcl").Predict(args) + } + return nil + }) +} + +func (c *PolicyWriteCommand) AutocompleteFlags() complete.Flags { + return c.Flags().Completions() } func (c *PolicyWriteCommand) Run(args []string) int { - flags := c.Meta.FlagSet("policy-write", meta.FlagSetDefault) - flags.Usage = func() { c.Ui.Error(c.Help()) } - if err := flags.Parse(args); err != nil { + f := c.Flags() + + if err := f.Parse(args); err != nil { + c.UI.Error(err.Error()) return 1 } - args = flags.Args() - if len(args) != 2 { - flags.Usage() - c.Ui.Error(fmt.Sprintf( - "\npolicy-write expects exactly two arguments")) + args = f.Args() + switch { + case len(args) < 2: + c.UI.Error(fmt.Sprintf("Not enough arguments (expected 2, got %d)", len(args))) + return 1 + case len(args) > 2: + c.UI.Error(fmt.Sprintf("Too many arguments (expected 2, got %d)", len(args))) return 1 } client, err := c.Client() if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error initializing client: %s", err)) + c.UI.Error(err.Error()) return 2 } // Policies are normalized to lowercase - name := strings.ToLower(args[0]) - path := args[1] + name := strings.TrimSpace(strings.ToLower(args[0])) + path := strings.TrimSpace(args[1]) - // Read the policy - var f io.Reader = os.Stdin - if path != "-" { + // Get the policy contents, either from stdin of a file + var reader io.Reader + if path == "-" { + reader = os.Stdin + if c.testStdin != nil { + reader = c.testStdin + } + } else { file, err := os.Open(path) if err != nil { - c.Ui.Error(fmt.Sprintf( - "Error opening file: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error opening policy file: %s", err)) + return 2 } defer file.Close() - f = file + reader = file } + + // Read the policy var buf bytes.Buffer - if _, err := io.Copy(&buf, f); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error reading file: %s", err)) - return 1 + if _, err := io.Copy(&buf, reader); err != nil { + c.UI.Error(fmt.Sprintf("Error reading policy: %s", err)) + return 2 } rules := buf.String() if err := client.Sys().PutPolicy(name, rules); err != nil { - c.Ui.Error(fmt.Sprintf( - "Error: %s", err)) - return 1 + c.UI.Error(fmt.Sprintf("Error uploading policy: %s", err)) + return 2 } - c.Ui.Output(fmt.Sprintf("Policy '%s' written.", name)) + c.UI.Output(fmt.Sprintf("Success! Uploaded policy: %s", name)) return 0 } - -func (c *PolicyWriteCommand) Synopsis() string { - return "Write a policy to the server" -} - -func (c *PolicyWriteCommand) Help() string { - helpText := ` -Usage: vault policy-write [options] name path - - Write a policy with the given name from the contents of a file or stdin. - - If the path is "-", the policy is read from stdin. Otherwise, it is - loaded from the file at the given path. - -General Options: -` + meta.GeneralOptionsUsage() - return strings.TrimSpace(helpText) -} diff --git a/command/policy_write_test.go b/command/policy_write_test.go index d0deeaac69..c8db7dc9dd 100644 --- a/command/policy_write_test.go +++ b/command/policy_write_test.go @@ -1,33 +1,207 @@ package command import ( + "bytes" + "io" + "io/ioutil" + "os" + "reflect" + "strings" "testing" - "github.com/hashicorp/vault/http" - "github.com/hashicorp/vault/meta" - "github.com/hashicorp/vault/vault" "github.com/mitchellh/cli" ) -func TestPolicyWrite(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) - ln, addr := http.TestServer(t, core) - defer ln.Close() +func testPolicyWriteCommand(tb testing.TB) (*cli.MockUi, *PolicyWriteCommand) { + tb.Helper() - ui := new(cli.MockUi) - c := &PolicyWriteCommand{ - Meta: meta.Meta{ - ClientToken: token, - Ui: ui, + ui := cli.NewMockUi() + return ui, &PolicyWriteCommand{ + BaseCommand: &BaseCommand{ + UI: ui, + }, + } +} + +func testPolicyWritePolicyContents(tb testing.TB) []byte { + return bytes.TrimSpace([]byte(` +path "secret/" { + capabilities = ["read"] +} + `)) +} + +func TestPolicyWriteCommand_Run(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + args []string + out string + code int + }{ + { + "too_many_args", + []string{"foo", "bar", "baz"}, + "Too many arguments", + 1, + }, + { + "not_enough_args", + []string{"foo"}, + "Not enough arguments", + 1, + }, + { + "bad_file", + []string{"my-policy", "/not/a/real/path.hcl"}, + "Error opening policy file", + 2, }, } - args := []string{ - "-address", addr, - "foo", - "./test-fixtures/policy.hcl", - } - if code := c.Run(args); code != 0 { - t.Fatalf("bad: %d\n\n%s", code, ui.ErrorWriter.String()) - } + t.Run("validations", func(t *testing.T) { + t.Parallel() + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + + code := cmd.Run(tc.args) + if code != tc.code { + t.Errorf("expected %d to be %d", code, tc.code) + } + + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, tc.out) { + t.Errorf("expected %q to contain %q", combined, tc.out) + } + }) + } + }) + + t.Run("file", func(t *testing.T) { + t.Parallel() + + policy := testPolicyWritePolicyContents(t) + f, err := ioutil.TempFile("", "vault-policy-write") + if err != nil { + t.Fatal(err) + } + if _, err := f.Write(policy); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", f.Name(), + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Uploaded policy: my-policy" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + policies, err := client.Sys().ListPolicies() + if err != nil { + t.Fatal(err) + } + + list := []string{"default", "my-policy", "root"} + if !reflect.DeepEqual(policies, list) { + t.Errorf("expected %q to be %q", policies, list) + } + }) + + t.Run("stdin", func(t *testing.T) { + t.Parallel() + + stdinR, stdinW := io.Pipe() + go func() { + policy := testPolicyWritePolicyContents(t) + stdinW.Write(policy) + stdinW.Close() + }() + + client, closer := testVaultServer(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + cmd.testStdin = stdinR + + code := cmd.Run([]string{ + "my-policy", "-", + }) + if exp := 0; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Success! Uploaded policy: my-policy" + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + + policies, err := client.Sys().ListPolicies() + if err != nil { + t.Fatal(err) + } + + list := []string{"default", "my-policy", "root"} + if !reflect.DeepEqual(policies, list) { + t.Errorf("expected %q to be %q", policies, list) + } + }) + + t.Run("communication_failure", func(t *testing.T) { + t.Parallel() + + client, closer := testVaultServerBad(t) + defer closer() + + ui, cmd := testPolicyWriteCommand(t) + cmd.client = client + + code := cmd.Run([]string{ + "my-policy", "-", + }) + if exp := 2; code != exp { + t.Errorf("expected %d to be %d", code, exp) + } + + expected := "Error uploading policy: " + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + if !strings.Contains(combined, expected) { + t.Errorf("expected %q to contain %q", combined, expected) + } + }) + + t.Run("no_tabs", func(t *testing.T) { + t.Parallel() + + _, cmd := testPolicyWriteCommand(t) + assertNoTabs(t, cmd) + }) }