VAULT-36198: Add API/CLI support for reading, listing, recovering from a snapshot (#30701)

This commit is contained in:
miagilepner 2025-05-21 15:10:20 +02:00 committed by GitHub
parent 1aff56bfc5
commit 2c1d8b6fb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 132 additions and 4 deletions

View File

@ -64,6 +64,12 @@ func (c *Logical) ReadWithData(path string, data map[string][]string) (*Secret,
return c.ReadWithDataWithContext(context.Background(), path, data)
}
// ReadFromSnapshot reads the data at the given Vault path from a previously
// loaded snapshot. The snapshotID parameter is the ID of the loaded snapshot
func (c *Logical) ReadFromSnapshot(path string, snapshotID string) (*Secret, error) {
return c.ReadWithData(path, map[string][]string{"read_snapshot_id": {snapshotID}})
}
func (c *Logical) ReadWithDataWithContext(ctx context.Context, path string, data map[string][]string) (*Secret, error) {
ctx, cancelFunc := c.c.withConfiguredTimeout(ctx)
defer cancelFunc()
@ -104,6 +110,10 @@ func (c *Logical) ReadRawWithData(path string, data map[string][]string) (*Respo
return c.ReadRawWithDataWithContext(context.Background(), path, data)
}
func (c *Logical) ReadRawFromSnapshot(path string, snapshotID string) (*Response, error) {
return c.ReadRawWithDataWithContext(context.Background(), path, map[string][]string{"read_snapshot_id": {snapshotID}})
}
// ReadRawWithDataWithContext attempts to read the value stored at the given
// Vault path (without '/v1/' prefix) and returns a raw *http.Response. The 'data'
// map is added as query parameters to the request.
@ -160,15 +170,26 @@ func (c *Logical) readRawWithDataWithContext(ctx context.Context, path string, d
return c.c.RawRequestWithContext(ctx, r)
}
// ListFromSnapshot lists from the Vault path using a previously loaded
// snapshot. The snapshotID parameter is the ID of the loaded snapshot
func (c *Logical) ListFromSnapshot(path string, snapshotID string) (*Secret, error) {
r := c.c.NewRequest("LIST", "/v1/"+path)
r.Params.Set("read_snapshot_id", snapshotID)
return c.list(context.Background(), r)
}
func (c *Logical) List(path string) (*Secret, error) {
return c.ListWithContext(context.Background(), path)
}
func (c *Logical) ListWithContext(ctx context.Context, path string) (*Secret, error) {
return c.list(ctx, c.c.NewRequest("LIST", "/v1/"+path))
}
func (c *Logical) list(ctx context.Context, r *Request) (*Secret, error) {
ctx, cancelFunc := c.c.withConfiguredTimeout(ctx)
defer cancelFunc()
r := c.c.NewRequest("LIST", "/v1/"+path)
// Set this for broader compatibility, but we use LIST above to be able to
// handle the wrapping lookup function
r.Method = http.MethodGet
@ -223,6 +244,14 @@ func (c *Logical) WriteRawWithContext(ctx context.Context, path string, data []b
return c.writeRaw(ctx, r)
}
// Recover recovers the data at the given Vault path from a loaded snapshot.
// The snapshotID parameter is the ID of the loaded snapshot
func (c *Logical) Recover(ctx context.Context, path string, snapshotID string) (*Secret, error) {
r := c.c.NewRequest(http.MethodPut, "/v1/"+path)
r.Params.Set("recover_snapshot_id", snapshotID)
return c.write(ctx, path, r)
}
func (c *Logical) JSONMergePatch(ctx context.Context, path string, data map[string]interface{}) (*Secret, error) {
r := c.c.NewRequest(http.MethodPatch, "/v1/"+path)
r.Headers.Set("Content-Type", "application/merge-patch+json")

View File

@ -77,6 +77,8 @@ type BaseCommand struct {
tokenHelper tokenhelper.TokenHelper
hcpTokenHelper hcpvlib.HCPTokenHelper
flagSnapshotID string
client *api.Client
}
@ -371,6 +373,7 @@ const (
FlagSetOutputField
FlagSetOutputFormat
FlagSetOutputDetailed
FlagSetSnapshot
)
// flagSet creates the flags for this command. The result is cached on the
@ -614,6 +617,16 @@ func (c *BaseCommand) flagSet(bit FlagSetBit) *FlagSets {
Usage: "Enables additional metadata during some operations",
})
}
if bit&FlagSetSnapshot != 0 {
outputSet.StringVar(&StringVar{
Name: "snapshot-id",
Target: &c.flagSnapshotID,
Default: "",
Completion: complete.PredictAnything,
Usage: "ID of the loaded snapshot that this command will use",
})
}
}
c.flags = set

View File

@ -8,6 +8,7 @@ import (
"strings"
"github.com/hashicorp/cli"
"github.com/hashicorp/vault/api"
"github.com/posener/complete"
)
@ -45,7 +46,7 @@ Usage: vault list [options] PATH
}
func (c *ListCommand) Flags() *FlagSets {
set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat | FlagSetOutputDetailed)
set := c.flagSet(FlagSetHTTP | FlagSetOutputFormat | FlagSetOutputDetailed | FlagSetSnapshot)
return set
}
@ -82,7 +83,12 @@ func (c *ListCommand) Run(args []string) int {
}
path := sanitizePath(args[0])
secret, err := client.Logical().List(path)
var secret *api.Secret
if c.flagSnapshotID != "" {
secret, err = client.Logical().ListFromSnapshot(path, c.flagSnapshotID)
} else {
secret, err = client.Logical().List(path)
}
if err != nil {
c.UI.Error(fmt.Sprintf("Error listing %s: %s", path, err))
return 2

View File

@ -4,10 +4,14 @@
package command
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/hashicorp/cli"
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/require"
)
func testListCommand(tb testing.TB) (*cli.MockUi, *ListCommand) {
@ -133,3 +137,32 @@ func TestListCommand_Run(t *testing.T) {
assertNoTabs(t, cmd)
})
}
// TestList_Snapshot tests that the read_snapshot_id query parameter is added
// to the request when the -snapshot-id flag is used.
func TestList_Snapshot(t *testing.T) {
t.Parallel()
mockVaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
snapID := r.URL.Query().Get("read_snapshot_id")
if snapID != "abcd" {
w.WriteHeader(http.StatusBadRequest)
}
w.Write([]byte(`{"data":{"keys":["foo","bar"]}}`))
}))
defer mockVaultServer.Close()
cfg := api.DefaultConfig()
cfg.Address = mockVaultServer.URL
client, err := api.NewClient(cfg)
require.NoError(t, err)
ui, cmd := testListCommand(t)
cmd.client = client
// a list command with a snapshot id shouldn't error
code := cmd.Run([]string{
"-snapshot-id", "abcd", "path/",
})
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
require.Equal(t, 0, code, combined)
}

View File

@ -57,7 +57,7 @@ Usage: vault read [options] PATH
}
func (c *ReadCommand) Flags() *FlagSets {
return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat)
return c.flagSet(FlagSetHTTP | FlagSetOutputField | FlagSetOutputFormat | FlagSetSnapshot)
}
func (c *ReadCommand) AutocompleteArgs() complete.Predictor {
@ -107,6 +107,13 @@ func (c *ReadCommand) Run(args []string) int {
return 1
}
if c.flagSnapshotID != "" {
if data == nil {
data = make(map[string][]string)
}
data["read_snapshot_id"] = []string{c.flagSnapshotID}
}
if Format(c.UI) != "raw" {
secret, err := client.Logical().ReadWithDataWithContext(ctx, path, data)
if err != nil {

View File

@ -4,10 +4,14 @@
package command
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/hashicorp/cli"
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/require"
)
func testReadCommand(tb testing.TB) (*cli.MockUi, *ReadCommand) {
@ -165,3 +169,39 @@ func TestReadCommand_Run(t *testing.T) {
assertNoTabs(t, cmd)
})
}
// TestRead_Snapshot tests that the read_snapshot_id query parameter is added
// to the request when the -snapshot-id flag is used.
func TestRead_Snapshot(t *testing.T) {
t.Parallel()
mockVaultServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
snapID := r.URL.Query().Get("read_snapshot_id")
if snapID != "abcd" {
w.WriteHeader(http.StatusBadRequest)
}
w.Write([]byte(`{"secret":{"data":{"foo":"bar"}}}`))
}))
defer mockVaultServer.Close()
cfg := api.DefaultConfig()
cfg.Address = mockVaultServer.URL
client, err := api.NewClient(cfg)
require.NoError(t, err)
ui, cmd := testReadCommand(t)
cmd.client = client
// a read command with a snapshot id shouldn't error
code := cmd.Run([]string{
"-snapshot-id", "abcd", "path/to/item",
})
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
require.Equal(t, 0, code, combined)
// check that the raw flag also works with a snapshot id
code = cmd.Run([]string{
"-format", "raw", "-snapshot-id", "abcd", "path/to/item",
})
combined = ui.OutputWriter.String() + ui.ErrorWriter.String()
require.Equal(t, 0, code, combined)
}