diff --git a/command/base_helpers.go b/command/base_helpers.go index ae574f6d60..4c5bad6b88 100644 --- a/command/base_helpers.go +++ b/command/base_helpers.go @@ -2,35 +2,72 @@ package command import ( "fmt" + "io" "strings" + "time" + + "github.com/hashicorp/vault/api" + kvbuilder "github.com/hashicorp/vault/helper/kv-builder" + homedir "github.com/mitchellh/go-homedir" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" + "github.com/ryanuber/columnize" ) +var ErrMissingID = fmt.Errorf("Missing ID!") var ErrMissingPath = fmt.Errorf("Missing PATH!") +var ErrMissingThing = fmt.Errorf("Missing THING!") + +// extractListData reads the secret and returns a typed list of data and a +// boolean indicating whether the extraction was successful. +func extractListData(secret *api.Secret) ([]interface{}, bool) { + if secret == nil || secret.Data == nil { + return nil, false + } + + k, ok := secret.Data["keys"] + if !ok || k == nil { + return nil, false + } + + i, ok := k.([]interface{}) + return i, ok +} // extractPath extracts the path and list of arguments from the args. If there // are no extra arguments, the remaining args will be nil. func extractPath(args []string) (string, []string, error) { + str, remaining, err := extractThings(args) + if err == ErrMissingThing { + err = ErrMissingPath + } + return str, remaining, err +} + +// extractID extracts the path and list of arguments from the args. If there +// are no extra arguments, the remaining args will be nil. +func extractID(args []string) (string, []string, error) { + str, remaining, err := extractThings(args) + if err == ErrMissingThing { + err = ErrMissingID + } + return str, remaining, err +} + +func extractThings(args []string) (string, []string, error) { if len(args) < 1 { - return "", nil, ErrMissingPath + return "", nil, ErrMissingThing } // Path is always the first argument after all flags - path := args[0] + thing := args[0] // Strip leading and trailing slashes - for len(path) > 0 && path[0] == '/' { - path = path[1:] - } - for len(path) > 0 && path[len(path)-1] == '/' { - path = path[:len(path)-1] - } + thing = sanitizePath(thing) - // Trim any leading/trailing whitespace - path = strings.TrimSpace(path) - - // Verify we have a path - if path == "" { - return "", nil, ErrMissingPath + // Verify we have a thing + if thing == "" { + return "", nil, ErrMissingThing } // Splice remaining args @@ -39,5 +76,150 @@ func extractPath(args []string) (string, []string, error) { remaining = args[1:] } - return path, remaining, nil + return thing, remaining, nil +} + +// sanitizePath removes any leading or trailing things from a "path". +func sanitizePath(s string) string { + return ensureNoTrailingSlash(ensureNoLeadingSlash(s)) +} + +// ensureTrailingSlash ensures the given string has a trailing slash. +func ensureTrailingSlash(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + + for len(s) > 0 && s[len(s)-1] != '/' { + s = s + "/" + } + return s +} + +// ensureNoTrailingSlash ensures the given string has a trailing slash. +func ensureNoTrailingSlash(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + + for len(s) > 0 && s[len(s)-1] == '/' { + s = s[:len(s)-1] + } + return s +} + +// ensureNoLeadingSlash ensures the given string has a trailing slash. +func ensureNoLeadingSlash(s string) string { + s = strings.TrimSpace(s) + if s == "" { + return "" + } + + for len(s) > 0 && s[0] == '/' { + s = s[1:] + } + return s +} + +// columnOuput prints the list of items as a table with no headers. +func columnOutput(list []string) string { + if len(list) == 0 { + return "" + } + + return columnize.Format(list, &columnize.Config{ + Glue: " ", + Empty: "n/a", + }) +} + +// tableOutput prints the list of items as columns, where the first row is +// the list of headers. +func tableOutput(list []string) string { + if len(list) == 0 { + return "" + } + + underline := "" + headers := strings.Split(list[0], "|") + for i, h := range headers { + h = strings.TrimSpace(h) + u := strings.Repeat("-", len(h)) + + underline = underline + u + if i != len(headers)-1 { + underline = underline + " | " + } + } + + list = append(list, "") + copy(list[2:], list[1:]) + list[1] = underline + + return columnOutput(list) +} + +// parseArgsData parses the given args in the format key=value into a map of +// the provided arguments. The given reader can also supply key=value pairs. +func parseArgsData(stdin io.Reader, args []string) (map[string]interface{}, error) { + builder := &kvbuilder.Builder{Stdin: stdin} + if err := builder.Add(args...); err != nil { + return nil, err + } + + return builder.Map(), nil +} + +// parseArgsDataString parses the args data and returns the values as strings. +// If the values cannot be represented as strings, an error is returned. +func parseArgsDataString(stdin io.Reader, args []string) (map[string]string, error) { + raw, err := parseArgsData(stdin, args) + if err != nil { + return nil, err + } + + var result map[string]string + if err := mapstructure.WeakDecode(raw, &result); err != nil { + return nil, errors.Wrap(err, "failed to convert values to strings") + } + return result, nil +} + +// truncateToSeconds truncates the given duaration to the number of seconds. If +// the duration is less than 1s, it is returned as 0. The integer represents +// the whole number unit of seconds for the duration. +func truncateToSeconds(d time.Duration) int { + d = d.Truncate(1 * time.Second) + + // Handle the case where someone requested a ridiculously short increment - + // incremenents must be larger than a second. + if d < 1*time.Second { + return 0 + } + + return int(d.Seconds()) +} + +// printKeyStatus prints the KeyStatus response from the API. +func printKeyStatus(ks *api.KeyStatus) string { + return columnOutput([]string{ + fmt.Sprintf("Key Term | %d", ks.Term), + fmt.Sprintf("Install Time | %s", ks.InstallTime.UTC().Format(time.RFC822)), + }) +} + +// expandPath takes a filepath and returns the full expanded path, accounting +// for user-relative things like ~/. +func expandPath(s string) string { + if s == "" { + return "" + } + + e, err := homedir.Expand(s) + if err != nil { + return s + } + return e } diff --git a/command/base_helpers_test.go b/command/base_helpers_test.go new file mode 100644 index 0000000000..87c0bff695 --- /dev/null +++ b/command/base_helpers_test.go @@ -0,0 +1,162 @@ +package command + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "testing" + "time" +) + +func TestParseArgsData(t *testing.T) { + t.Parallel() + + t.Run("stdin_full", func(t *testing.T) { + t.Parallel() + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`{"foo":"bar"}`)) + stdinW.Close() + }() + + m, err := parseArgsData(stdinR, []string{"-"}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("stdin_value", func(t *testing.T) { + t.Parallel() + + stdinR, stdinW := io.Pipe() + go func() { + stdinW.Write([]byte(`bar`)) + stdinW.Close() + }() + + m, err := parseArgsData(stdinR, []string{"foo=-"}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("file_full", func(t *testing.T) { + t.Parallel() + + f, err := ioutil.TempFile("", "vault") + if err != nil { + t.Fatal(err) + } + f.Write([]byte(`{"foo":"bar"}`)) + f.Close() + defer os.Remove(f.Name()) + + m, err := parseArgsData(os.Stdin, []string{"@" + f.Name()}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("file_value", func(t *testing.T) { + t.Parallel() + + f, err := ioutil.TempFile("", "vault") + if err != nil { + t.Fatal(err) + } + f.Write([]byte(`bar`)) + f.Close() + defer os.Remove(f.Name()) + + m, err := parseArgsData(os.Stdin, []string{"foo=@" + f.Name()}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "bar" { + t.Errorf("expected %q to be %q", v, "bar") + } + }) + + t.Run("file_value_escaped", func(t *testing.T) { + t.Parallel() + + m, err := parseArgsData(os.Stdin, []string{`foo=\@`}) + if err != nil { + t.Fatal(err) + } + + if v, ok := m["foo"]; !ok || v != "@" { + t.Errorf("expected %q to be %q", v, "@") + } + }) +} + +func TestTruncateToSeconds(t *testing.T) { + t.Parallel() + + cases := []struct { + d time.Duration + exp int + }{ + { + 10 * time.Nanosecond, + 0, + }, + { + 10 * time.Microsecond, + 0, + }, + { + 10 * time.Millisecond, + 0, + }, + { + 1 * time.Second, + 1, + }, + { + 10 * time.Second, + 10, + }, + { + 100 * time.Second, + 100, + }, + { + 3 * time.Minute, + 180, + }, + { + 3 * time.Hour, + 10800, + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("%s", tc.d), func(t *testing.T) { + t.Parallel() + + act := truncateToSeconds(tc.d) + if act != tc.exp { + t.Errorf("expected %d to be %d", act, tc.exp) + } + }) + } +}