diff --git a/api/logical.go b/api/logical.go index 068e9068f3..bddb8b0765 100644 --- a/api/logical.go +++ b/api/logical.go @@ -64,6 +64,12 @@ func (c *Logical) ReadWithData(path string, data map[string][]string) (*Secret, return c.ReadWithDataWithContext(context.Background(), path, data) } +// ReadFromSnapshot reads the data at the given Vault path from a previously +// loaded snapshot. The snapshotID parameter is the ID of the loaded snapshot +func (c *Logical) ReadFromSnapshot(path string, snapshotID string) (*Secret, error) { + return c.ReadWithData(path, map[string][]string{"read_snapshot_id": {snapshotID}}) +} + func (c *Logical) ReadWithDataWithContext(ctx context.Context, path string, data map[string][]string) (*Secret, error) { ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) defer cancelFunc() @@ -104,6 +110,10 @@ func (c *Logical) ReadRawWithData(path string, data map[string][]string) (*Respo return c.ReadRawWithDataWithContext(context.Background(), path, data) } +func (c *Logical) ReadRawFromSnapshot(path string, snapshotID string) (*Response, error) { + return c.ReadRawWithDataWithContext(context.Background(), path, map[string][]string{"read_snapshot_id": {snapshotID}}) +} + // ReadRawWithDataWithContext attempts to read the value stored at the given // Vault path (without '/v1/' prefix) and returns a raw *http.Response. The 'data' // map is added as query parameters to the request. @@ -160,15 +170,26 @@ func (c *Logical) readRawWithDataWithContext(ctx context.Context, path string, d return c.c.RawRequestWithContext(ctx, r) } +// ListFromSnapshot lists from the Vault path using a previously loaded +// snapshot. The snapshotID parameter is the ID of the loaded snapshot +func (c *Logical) ListFromSnapshot(path string, snapshotID string) (*Secret, error) { + r := c.c.NewRequest("LIST", "/v1/"+path) + r.Params.Set("read_snapshot_id", snapshotID) + return c.list(context.Background(), r) +} + func (c *Logical) List(path string) (*Secret, error) { return c.ListWithContext(context.Background(), path) } func (c *Logical) ListWithContext(ctx context.Context, path string) (*Secret, error) { + return c.list(ctx, c.c.NewRequest("LIST", "/v1/"+path)) +} + +func (c *Logical) list(ctx context.Context, r *Request) (*Secret, error) { ctx, cancelFunc := c.c.withConfiguredTimeout(ctx) defer cancelFunc() - r := c.c.NewRequest("LIST", "/v1/"+path) // Set this for broader compatibility, but we use LIST above to be able to // handle the wrapping lookup function r.Method = http.MethodGet @@ -223,6 +244,14 @@ func (c *Logical) WriteRawWithContext(ctx context.Context, path string, data []b return c.writeRaw(ctx, r) } +// Recover recovers the data at the given Vault path from a loaded snapshot. +// The snapshotID parameter is the ID of the loaded snapshot +func (c *Logical) Recover(ctx context.Context, path string, snapshotID string) (*Secret, error) { + r := c.c.NewRequest(http.MethodPut, "/v1/"+path) + r.Params.Set("recover_snapshot_id", snapshotID) + return c.write(ctx, path, r) +} + func (c *Logical) JSONMergePatch(ctx context.Context, path string, data map[string]interface{}) (*Secret, error) { r := c.c.NewRequest(http.MethodPatch, "/v1/"+path) r.Headers.Set("Content-Type", "application/merge-patch+json") diff --git a/command/base.go b/command/base.go index f8c2a8a920..1100dcff2a 100644 --- a/command/base.go +++ b/command/base.go @@ -77,6 +77,8 @@ type BaseCommand struct { tokenHelper tokenhelper.TokenHelper hcpTokenHelper hcpvlib.HCPTokenHelper + flagSnapshotID string + client *api.Client } @@ -371,6 +373,7 @@ const ( FlagSetOutputField FlagSetOutputFormat FlagSetOutputDetailed + FlagSetSnapshot ) // flagSet creates the flags for this command. The result is cached on the @@ -614,6 +617,16 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets { Usage: "Enables additional metadata during some operations", }) } + + if bit&FlagSetSnapshot != 0 { + outputSet.StringVar(&StringVar{ + Name: "snapshot-id", + Target: &c.flagSnapshotID, + Default: "", + Completion: complete.PredictAnything, + Usage: "ID of the loaded snapshot that this command will use", + }) + } } c.flags = set diff --git a/command/list.go b/command/list.go index 6505f76af8..befdd6fbd8 100644 --- a/command/list.go +++ b/command/list.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/hashicorp/cli" + "github.com/hashicorp/vault/api" "github.com/posener/complete" ) @@ -45,7 +46,7 @@ Usage: vault list [options] PATH } func (c *ListCommand) Flags() *FlagSets { - set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat | FlagSetOutputDetailed) + set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat | FlagSetOutputDetailed | FlagSetSnapshot) return set } @@ -82,7 +83,12 @@ func (c *ListCommand) Run(args []string) int { } path := sanitizePath(args[0]) - secret, err := client.Logical().List(path) + var secret *api.Secret + if c.flagSnapshotID != "" { + secret, err = client.Logical().ListFromSnapshot(path, c.flagSnapshotID) + } else { + secret, err = client.Logical().List(path) + } if err != nil { c.UI.Error(fmt.Sprintf("Error listing %s: %s", path, err)) return 2 diff --git a/command/list_test.go b/command/list_test.go index e7a870d7ff..621c3f9817 100644 --- a/command/list_test.go +++ b/command/list_test.go @@ -4,10 +4,14 @@ package command import ( + "net/http" + "net/http/httptest" "strings" "testing" "github.com/hashicorp/cli" + "github.com/hashicorp/vault/api" + "github.com/stretchr/testify/require" ) func testListCommand(tb testing.TB) (*cli.MockUi, *ListCommand) { @@ -133,3 +137,32 @@ func TestListCommand_Run(t *testing.T) { assertNoTabs(t, cmd) }) } + +// TestList_Snapshot tests that the read_snapshot_id query parameter is added +// to the request when the -snapshot-id flag is used. +func TestList_Snapshot(t *testing.T) { + t.Parallel() + mockVaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + snapID := r.URL.Query().Get("read_snapshot_id") + if snapID != "abcd" { + w.WriteHeader(http.StatusBadRequest) + } + w.Write([]byte(`{"data":{"keys":["foo","bar"]}}`)) + })) + defer mockVaultServer.Close() + + cfg := api.DefaultConfig() + cfg.Address = mockVaultServer.URL + client, err := api.NewClient(cfg) + require.NoError(t, err) + + ui, cmd := testListCommand(t) + cmd.client = client + + // a list command with a snapshot id shouldn't error + code := cmd.Run([]string{ + "-snapshot-id", "abcd", "path/", + }) + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + require.Equal(t, 0, code, combined) +} diff --git a/command/read.go b/command/read.go index 67ee2d6d7c..ad14ebc4aa 100644 --- a/command/read.go +++ b/command/read.go @@ -57,7 +57,7 @@ Usage: vault read [options] PATH } func (c *ReadCommand) Flags() *FlagSets { - return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat) + return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat | FlagSetSnapshot) } func (c *ReadCommand) AutocompleteArgs() complete.Predictor { @@ -107,6 +107,13 @@ func (c *ReadCommand) Run(args []string) int { return 1 } + if c.flagSnapshotID != "" { + if data == nil { + data = make(map[string][]string) + } + data["read_snapshot_id"] = []string{c.flagSnapshotID} + } + if Format(c.UI) != "raw" { secret, err := client.Logical().ReadWithDataWithContext(ctx, path, data) if err != nil { diff --git a/command/read_test.go b/command/read_test.go index fe8961afb6..acad42b993 100644 --- a/command/read_test.go +++ b/command/read_test.go @@ -4,10 +4,14 @@ package command import ( + "net/http" + "net/http/httptest" "strings" "testing" "github.com/hashicorp/cli" + "github.com/hashicorp/vault/api" + "github.com/stretchr/testify/require" ) func testReadCommand(tb testing.TB) (*cli.MockUi, *ReadCommand) { @@ -165,3 +169,39 @@ func TestReadCommand_Run(t *testing.T) { assertNoTabs(t, cmd) }) } + +// TestRead_Snapshot tests that the read_snapshot_id query parameter is added +// to the request when the -snapshot-id flag is used. +func TestRead_Snapshot(t *testing.T) { + t.Parallel() + mockVaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + snapID := r.URL.Query().Get("read_snapshot_id") + if snapID != "abcd" { + w.WriteHeader(http.StatusBadRequest) + } + w.Write([]byte(`{"secret":{"data":{"foo":"bar"}}}`)) + })) + defer mockVaultServer.Close() + + cfg := api.DefaultConfig() + cfg.Address = mockVaultServer.URL + client, err := api.NewClient(cfg) + require.NoError(t, err) + + ui, cmd := testReadCommand(t) + cmd.client = client + + // a read command with a snapshot id shouldn't error + code := cmd.Run([]string{ + "-snapshot-id", "abcd", "path/to/item", + }) + combined := ui.OutputWriter.String() + ui.ErrorWriter.String() + require.Equal(t, 0, code, combined) + + // check that the raw flag also works with a snapshot id + code = cmd.Run([]string{ + "-format", "raw", "-snapshot-id", "abcd", "path/to/item", + }) + combined = ui.OutputWriter.String() + ui.ErrorWriter.String() + require.Equal(t, 0, code, combined) +}