[VAULT-35190] Implement logic for writing snapshot data to an FSM (#30416)

This commit is contained in:
Kuba Wieczorek 2025-04-29 14:43:53 +01:00 committed by GitHub
parent 733e757c67
commit dbc2f06fbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 131 additions and 43 deletions

View File

@ -19,6 +19,7 @@ import (
"github.com/golang/protobuf/proto"
log "github.com/hashicorp/go-hclog"
iradix "github.com/hashicorp/go-immutable-radix"
"github.com/hashicorp/raft"
"github.com/hashicorp/vault/sdk/plugin/pb"
"github.com/rboyer/safeio"
@ -364,49 +365,10 @@ func (s *BoltSnapshotSink) writeBoltDBFile() error {
defer close(s.doneWritingCh)
defer boltDB.Close()
// The delimted reader will parse full proto messages from the snapshot
// data.
protoReader := NewDelimitedReader(reader, math.MaxInt32)
defer protoReader.Close()
var done bool
var keys int
entry := new(pb.StorageEntry)
for !done {
err := boltDB.Update(func(tx *bolt.Tx) error {
b, err := tx.CreateBucketIfNotExists(dataBucketName)
if err != nil {
return err
}
// Commit in batches of 50k. Bolt holds all the data in memory and
// doesn't split the pages until commit so we do incremental writes.
for i := 0; i < 50000; i++ {
err := protoReader.ReadMsg(entry)
if err != nil {
if err == io.EOF {
done = true
return nil
}
return err
}
err = b.Put([]byte(entry.Key), entry.Value)
if err != nil {
return err
}
keys += 1
}
return nil
})
if err != nil {
s.logger.Error("snapshot write: failed to write transaction", "error", err)
s.writeError = err
return
}
s.logger.Trace("snapshot write: writing keys", "num_written", keys)
err := loadSnapshot(boltDB, s.logger, reader, nil, false)
if err != nil {
s.writeError = err
return
}
}()
@ -535,3 +497,76 @@ func snapshotName(term, index uint64) string {
msec := now.UnixNano() / int64(time.Millisecond)
return fmt.Sprintf("%d-%d-%d", term, index, msec)
}
// LoadReadOnlySnapshot reads every entry from snapshotFile and writes it into
// the boltDB database backing the supplied FSM. It delegates to loadSnapshot
// with readOnly enabled, which sets the fill percent of the data bucket to
// 1.0. The call blocks until the whole snapshot has been written or an error
// occurs.
//
// When pathsToFilter is non-nil, any snapshot key for which the tree holds a
// prefix entry is skipped and is not written to the FSM.
//
// The caller is responsible for closing the reader.
// NOTE(review): loadSnapshot defers a Close on the delimited reader wrapping
// snapshotFile; confirm whether that also closes snapshotFile itself.
func LoadReadOnlySnapshot(fsm *FSM, snapshotFile io.ReadCloser, pathsToFilter *iradix.Tree, logger log.Logger) error {
	// Snapshots loaded through this entry point are always treated as
	// read-only.
	const readOnly = true

	return loadSnapshot(fsm.db, logger, snapshotFile, pathsToFilter, readOnly)
}
// loadSnapshot loads a snapshot from a file into the supplied boltDB database.
// This is a blocking call and will not return until the snapshot has been
// written to the database or an error has occurred.
//
// Entries are read from snapshotFile as length-delimited pb.StorageEntry
// messages and stored in the data bucket, which is created if it does not
// exist yet.
//
// If readOnly is true, the fill percent of the data bucket is set to 1.0 in
// every write transaction.
//
// If pathsToFilter is not nil, any key for which the tree contains a prefix
// entry is skipped and is not written to the database.
//
// Any transaction failure is logged and returned; entries committed by earlier
// batches stay in the database.
//
// The caller is responsible for closing the reader.
// NOTE(review): the deferred protoReader.Close() below may also close
// snapshotFile, depending on NewDelimitedReader's implementation; confirm.
func loadSnapshot(db *bolt.DB, logger log.Logger, snapshotFile io.ReadCloser, pathsToFilter *iradix.Tree, readOnly bool) error {
	// Commit in batches of 50k. Bolt holds all the data in memory and
	// doesn't split the pages until commit so we do incremental writes.
	const batchSize = 50000

	// The delimited reader will parse full proto messages from the snapshot data.
	protoReader := NewDelimitedReader(snapshotFile, math.MaxInt32)
	defer protoReader.Close()

	// The filter tree is not modified while loading, so resolve its root once
	// instead of once per entry.
	var filterRoot *iradix.Node
	if pathsToFilter != nil {
		filterRoot = pathsToFilter.Root()
	}

	var done bool
	var keys int
	entry := new(pb.StorageEntry)
	for !done {
		err := db.Update(func(tx *bolt.Tx) error {
			b, err := tx.CreateBucketIfNotExists(dataBucketName)
			// Check the error before touching the bucket: on failure b is nil
			// and setting FillPercent on it would panic.
			if err != nil {
				return err
			}
			if readOnly {
				b.FillPercent = 1.0
			}

			for i := 0; i < batchSize; i++ {
				err := protoReader.ReadMsg(entry)
				if err != nil {
					if err == io.EOF {
						// End of snapshot: commit this final (possibly
						// partial) batch and stop the outer loop.
						done = true
						return nil
					}
					return err
				}

				key := []byte(entry.Key)
				if filterRoot != nil {
					// Skip the entry when any filtered path is a prefix of
					// its key.
					if _, _, ok := filterRoot.LongestPrefix(key); ok {
						continue
					}
				}

				if err := b.Put(key, entry.Value); err != nil {
					return err
				}
				keys += 1
			}

			return nil
		})
		if err != nil {
			logger.Error("snapshot write: failed to write transaction", "error", err)
			return err
		}

		logger.Trace("snapshot write: writing keys", "num_written", keys)
	}

	return nil
}

View File

@ -19,10 +19,13 @@ import (
"time"
"github.com/hashicorp/go-hclog"
iradix "github.com/hashicorp/go-immutable-radix"
"github.com/hashicorp/raft"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/physical"
"github.com/hashicorp/vault/sdk/plugin/pb"
"github.com/stretchr/testify/require"
)
type idAddr struct {
@ -958,3 +961,48 @@ func TestBoltSnapshotStore_CloseFailure(t *testing.T) {
t.Fatal("expected write to fail")
}
}
// TestLoadReadOnlySnapshot loads a test snapshot file and verifies that there
// are no errors, and that the expected paths are excluded from the FSM.
func TestLoadReadOnlySnapshot(t *testing.T) {
	t.Parallel()

	// Load a test snapshot file from the testdata directory.
	// The snapshot contains the following paths:
	// * /different/path/to/exclude
	// * /path/to/exclude/1
	// * /path/to/exclude/2
	// * /path/to/keep
	testSnapshotFilePath := "testdata/TestLoadReadOnlySnapshot.bin"

	dir := t.TempDir()
	logger := corehelpers.NewTestLogger(t)

	snapshotFile, err := os.Open(testSnapshotFilePath)
	require.NoError(t, err)
	defer snapshotFile.Close()

	// Create a radix tree containing paths to exclude. The results of Insert
	// are discarded on purpose; they are not needed here.
	pathsToExclude := iradix.New()
	txn := pathsToExclude.Txn()
	_, _ = txn.Insert([]byte("/path/to/exclude"), []byte("value"))
	_, _ = txn.Insert([]byte("/path/to/exclude/1"), []byte("value"))
	_, _ = txn.Insert([]byte("/different/path/to/exclude"), []byte("value"))
	pathsToExclude = txn.Commit()

	// Create an FSM to load the snapshot data into. Fail right away if it
	// cannot be created, rather than using an unusable FSM below.
	fsm, err := NewFSM(dir, "test-fsm", logger)
	require.NoError(t, err)

	err = LoadReadOnlySnapshot(fsm, snapshotFile, pathsToExclude, logger)
	require.NoError(t, err)

	// Explicitly listed path: must have been filtered out.
	value, err := fsm.Get(context.Background(), "/path/to/exclude/1")
	require.NoError(t, err)
	require.Nil(t, value)

	// Not listed itself, but covered by the "/path/to/exclude" prefix entry.
	value, err = fsm.Get(context.Background(), "/path/to/exclude/2")
	require.NoError(t, err)
	require.Nil(t, value)

	value, err = fsm.Get(context.Background(), "/different/path/to/exclude")
	require.NoError(t, err)
	require.Nil(t, value)

	// Matches no excluded prefix: must still be present.
	value, err = fsm.Get(context.Background(), "/path/to/keep")
	require.NoError(t, err)
	require.NotNil(t, value)
}

View File

@ -0,0 +1,5 @@

/different/path/to/exclude
/path/to/exclude/1
/path/to/exclude/2
/path/to/keep