diff --git a/vault/core.go b/vault/core.go index f39a773ec4..9d3bf06cbf 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2399,7 +2399,7 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c if err := c.entPostUnseal(false); err != nil { return err } - if !c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary | consts.ReplicationDRSecondary) { + if c.isPrimary() { // Only perf primarys should write feature flags, but we do it by // excluding other states so that we don't have to change it when // a non-replicated cluster becomes a primary. @@ -2414,89 +2414,14 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c go c.autoRotateBarrierLoop(autoRotateCtx) } + // Run setup-like functions + if err := runUnsealSetupFunctions(ctx, buildUnsealSetupFunctionSlice(c)); err != nil { + return err + } + if !c.IsDRSecondary() { - if err := c.ensureWrappingKey(ctx); err != nil { - return err - } - } - if pluginRuntimeCatalog, err := plugincatalog.SetupPluginRuntimeCatalog(ctx, c.logger, NewBarrierView(c.barrier, pluginRuntimeCatalogPath)); err != nil { - return err - } else { - c.pluginRuntimeCatalog = pluginRuntimeCatalog - } - if pluginCatalog, err := plugincatalog.SetupPluginCatalog(ctx, c.logger, c.builtinRegistry, NewBarrierView(c.barrier, pluginCatalogPath), c.pluginDirectory, c.enableMlock, c.pluginRuntimeCatalog); err != nil { - return err - } else { - c.pluginCatalog = pluginCatalog - } - if err := c.loadMounts(ctx); err != nil { - return err - } - if err := c.entSetupFilteredPaths(); err != nil { - return err - } - if err := c.setupMounts(ctx); err != nil { - return err - } - if err := c.entSetupAPILock(ctx); err != nil { - return err - } - if err := c.setupPolicyStore(ctx); err != nil { - return err - } - if err := c.setupManagedKeyRegistry(); err != nil { - return err - } - if err := c.loadCORSConfig(ctx); err != nil { - return err - } - if err := c.loadCredentials(ctx); err != nil { - return err - } - if err := c.entSetupFilteredPaths(); err != nil { - return err - } - if err := c.setupCredentials(ctx); err != nil { - return err - } - if err := c.setupQuotas(ctx, false); err != nil { - return err - } - if err := c.setupHeaderHMACKey(ctx, false); err != nil { - return err - } - if !c.IsDRSecondary() { - c.updateLockedUserEntries() - - if err := c.startRollback(); err != nil { - return err - } - if err := c.setupExpiration(expireLeaseStrategyFairsharing); err != nil { - return err - } - if err := c.loadAudits(ctx); err != nil { - return err - } - if err := c.setupAuditedHeadersConfig(ctx); err != nil { - return err - } - - if err := c.setupAudits(ctx); err != nil { - return err - } - if err := c.loadIdentityStoreArtifacts(ctx); err != nil { - return err - } - if err := loadPolicyMFAConfigs(ctx, c); err != nil { - return err - } - c.setupCachedMFAResponseAuth() - if err := c.loadLoginMFAConfigs(ctx); err != nil { - return err - } - if err := c.setupCensusAgent(); err != nil { - c.logger.Error("skipping reporting for nil agent", "error", err) + logger.Error("skipping reporting for nil agent", "error", err) } // not waiting on wg to avoid changing existing behavior @@ -2510,59 +2435,16 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c if err != nil { return fmt.Errorf("unable to parse feature flag: %q: %w", featureFlagDisableEventLogger, err) } - c.auditBroker, err = NewAuditBroker(c.logger, !disableEventLogger) + c.auditBroker, err = NewAuditBroker(logger, !disableEventLogger) if err != nil { return err } } - if !c.ReplicationState().HasState(consts.ReplicationPerformanceSecondary | consts.ReplicationDRSecondary) { - // Cannot do this above, as we need other resources like mounts to be setup - if err := c.setupPluginReload(); err != nil { + if c.isPrimary() { + if err := c.runUnsealSetupForPrimary(ctx, logger); err != nil { return err } - - // Retrieve the seal generation information from storage - existingGenerationInfo, err := PhysicalSealGenInfo(ctx, c.physical) - if err != nil { - c.logger.Error("cannot read existing seal generation info from storage", "error", err) - return err - } - - sealGenerationInfo := c.seal.GetAccess().GetSealGenerationInfo() - - switch { - case existingGenerationInfo == nil: - // This is the first time we store seal generation information - fallthrough - case existingGenerationInfo.Generation < sealGenerationInfo.Generation: - // We have incremented the seal generation - if err := c.SetPhysicalSealGenInfo(ctx, sealGenerationInfo); err != nil { - c.logger.Error("failed to store seal generation info", "error", err) - return err - } - - case existingGenerationInfo.Generation == sealGenerationInfo.Generation: - // Same generation, update the rewrapped flag in case the previous active node - // changed its value. In other words, a rewrap may have happened, or a rewrap may have been - // started but not completed. - c.seal.GetAccess().GetSealGenerationInfo().SetRewrapped(existingGenerationInfo.IsRewrapped()) - - case existingGenerationInfo.Generation > sealGenerationInfo.Generation: - // Our seal information is out of date. The previous active node used a newer generation. - c.logger.Error("A newer seal generation was found in storage. The seal configuration in this node should be updated to match that of the previous active node, and this node should be restarted.") - return errors.New("newer seal generation found in storage, in memory seal configuration is out of date") - } - - if server.IsMultisealSupported() && !sealGenerationInfo.IsRewrapped() { - // Set the migration done flag so that a seal-rewrap gets triggered later. - // Note that in the case where multi seal is not supported, Core.migrateSeal() takes care of - // triggering the rewrap when necessary. - c.logger.Trace("seal generation information indicates that a seal-rewrap is needed", "generation", sealGenerationInfo.Generation) - atomic.StoreUint32(c.sealMigrationDone, 1) - } - - startPartialSealRewrapping(c) } if c.getClusterListener() != nil && (c.ha != nil || shouldStartClusterListener(c)) { @@ -2593,6 +2475,170 @@ func (s standardUnsealStrategy) unseal(ctx context.Context, logger log.Logger, c return nil } +// setupPluginRuntimeCatalog wraps the plugincatalog.SetupPluginRuntimeCatalog +// in way where this method can be included in the slice of functions returned +// by the buildUnsealSetupFunctionsSlice function. +func (c *Core) setupPluginRuntimeCatalog(ctx context.Context) error { + pluginRuntimeCatalog, err := plugincatalog.SetupPluginRuntimeCatalog(ctx, c.logger, NewBarrierView(c.barrier, pluginRuntimeCatalogPath)) + if err != nil { + return err + } + + c.pluginRuntimeCatalog = pluginRuntimeCatalog + + return nil +} + +// setupPluginCatalog wraps the plugincatalog.SetupPluginCatalog in way where +// this method can be included in the slice of functions returned by the +// buildUnsealSetupFunctionsSlice function. +func (c *Core) setupPluginCatalog(ctx context.Context) error { + pluginCatalog, err := plugincatalog.SetupPluginCatalog(ctx, c.logger, c.builtinRegistry, NewBarrierView(c.barrier, pluginCatalogPath), c.pluginDirectory, c.enableMlock, c.pluginRuntimeCatalog) + if err != nil { + return err + } + + c.pluginCatalog = pluginCatalog + + return nil +} + +// buildUnsealSetupFunctionSlice returns a slice of functions, tailored for this +// Core's replication state, that can be passed to the runUnsealSetupFunctions +// function. +func buildUnsealSetupFunctionSlice(c *Core) []func(context.Context) error { + // setupFunctions is a slice of functions that need to be called in order, + // that if any return an error, processing should immediately cease. + setupFunctions := []func(context.Context) error{ + c.setupPluginRuntimeCatalog, + c.setupPluginCatalog, + c.loadMounts, + func(_ context.Context) error { + return c.entSetupFilteredPaths() + }, + c.setupMounts, + c.entSetupAPILock, + c.setupPolicyStore, + func(_ context.Context) error { + return c.setupManagedKeyRegistry() + }, + c.loadCORSConfig, + c.loadCredentials, + func(_ context.Context) error { + return c.entSetupFilteredPaths() + }, + c.setupCredentials, + func(ctx context.Context) error { + return c.setupQuotas(ctx, false) + }, + func(ctx context.Context) error { + return c.setupHeaderHMACKey(ctx, false) + }, + } + + // If this server is not part of a Disaster Recovery secondary cluster, + // the following additional setupFunctions also apply. + if !c.IsDRSecondary() { + // This first setupFunction must be inserted at the beginning of the + // slice. The remainder should be appended at the end. + temp := []func(context.Context) error{ + c.ensureWrappingKey, + } + + setupFunctions = append(temp, setupFunctions...) + setupFunctions = append(setupFunctions, func(_ context.Context) error { + c.updateLockedUserEntries() + return nil + }) + setupFunctions = append(setupFunctions, func(_ context.Context) error { + return c.startRollback() + }) + setupFunctions = append(setupFunctions, func(_ context.Context) error { + return c.setupExpiration(expireLeaseStrategyFairsharing) + }) + setupFunctions = append(setupFunctions, c.loadAudits) + setupFunctions = append(setupFunctions, c.setupAuditedHeadersConfig) + setupFunctions = append(setupFunctions, c.setupAudits) + setupFunctions = append(setupFunctions, c.loadIdentityStoreArtifacts) + setupFunctions = append(setupFunctions, func(ctx context.Context) error { + return loadPolicyMFAConfigs(ctx, c) + }) + setupFunctions = append(setupFunctions, func(_ context.Context) error { + c.setupCachedMFAResponseAuth() + return nil + }) + setupFunctions = append(setupFunctions, c.loadLoginMFAConfigs) + } + + return setupFunctions +} + +// runUnsealSetupFunctions iterates through the provided slice of functions and +// calls each one, passing the provided context.Context as the sole argument. If +// any of the functions returns an error, this function returns it immediately. +func runUnsealSetupFunctions(ctx context.Context, setupFunctions []func(context.Context) error) error { + // call the setupFunctions sequentially + for _, fn := range setupFunctions { + if err := fn(ctx); err != nil { + return err + } + } + + return nil +} + +// runUnsealSetupForPrimary runs some setup code specific to clusters that are +// in the primary role (as defined by the (*Core).isPrimary method). +func (c *Core) runUnsealSetupForPrimary(ctx context.Context, logger log.Logger) error { + if err := c.setupPluginReload(); err != nil { + return err + } + + // Retrieve the seal generation information from storage + existingGenerationInfo, err := PhysicalSealGenInfo(ctx, c.physical) + if err != nil { + logger.Error("cannot read existing seal generation info from storage", "error", err) + return err + } + + sealGenerationInfo := c.seal.GetAccess().GetSealGenerationInfo() + + switch { + case existingGenerationInfo == nil: + // This is the first time we store seal generation information + fallthrough + case existingGenerationInfo.Generation < sealGenerationInfo.Generation: + // We have incremented the seal generation + if err := c.SetPhysicalSealGenInfo(ctx, sealGenerationInfo); err != nil { + logger.Error("failed to store seal generation info", "error", err) + return err + } + + case existingGenerationInfo.Generation == sealGenerationInfo.Generation: + // Same generation, update the rewrapped flag in case the previous active node + // changed its value. In other words, a rewrap may have happened, or a rewrap may have been + // started but not completed. + c.seal.GetAccess().GetSealGenerationInfo().SetRewrapped(existingGenerationInfo.IsRewrapped()) + + case existingGenerationInfo.Generation > sealGenerationInfo.Generation: + // Our seal information is out of date. The previous active node used a newer generation. + logger.Error("A newer seal generation was found in storage. The seal configuration in this node should be updated to match that of the previous active node, and this node should be restarted.") + return errors.New("newer seal generation found in storage, in memory seal configuration is out of date") + } + + if server.IsMultisealSupported() && !sealGenerationInfo.IsRewrapped() { + // Set the migration done flag so that a seal-rewrap gets triggered later. + // Note that in the case where multi seal is not supported, Core.migrateSeal() takes care of + // triggering the rewrap when necessary. + logger.Trace("seal generation information indicates that a seal-rewrap is needed", "generation", sealGenerationInfo.Generation) + atomic.StoreUint32(c.sealMigrationDone, 1) + } + + startPartialSealRewrapping(c) + + return nil +} + // postUnseal is invoked on the active node, and performance standby nodes, // after the barrier is unsealed, but before // allowing any user operations. This allows us to setup any state that diff --git a/vault/core_test.go b/vault/core_test.go index 6ecc40af31..097ea51b4a 100644 --- a/vault/core_test.go +++ b/vault/core_test.go @@ -5,6 +5,7 @@ package vault import ( "context" + "errors" "fmt" "reflect" "strings" @@ -24,6 +25,7 @@ import ( "github.com/hashicorp/vault/builtin/audit/file" "github.com/hashicorp/vault/builtin/audit/socket" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/go-test/deep" @@ -3451,3 +3453,85 @@ func TestStatelock_DeadlockDetection(t *testing.T) { t.Fatal("statelock doesn't have deadlock detection enabled, it should") } } + +// TestRunUnsealSetupFunctions verifies the correct behaviour of the +// runUnsealSetupFunctions function. This function's job is to run each of the +// function elements it is given with the context.Context that it's provided +// as the sole argument. +func TestRunUnsealSetupFunctions(t *testing.T) { + // First, check that the context.Context provided to runUnsealSetupFunctions + // is actually used to call the function elements, by running a method that + // records the context.Context used each time it's called. + checker := contextChecker{} + setupFunctions := []func(context.Context) error{ + checker.setupFunction, + checker.setupFunction, + checker.setupFunction, + } + + testContext := context.WithValue(context.Background(), "test", "pass") + assert.NoError(t, runUnsealSetupFunctions(testContext, setupFunctions)) + for _, v := range checker.values { + assert.Equal(t, "pass", v.(string)) + } + + // Finally, check that when an error is returned by a function element, the + // runUnsealSetupFunctions function immediately returns it, by using the + // same test as above but the second function element is one that returns + // an error, so the checker.values slice should only contain 1 element. + setupFunctions[1] = func(_ context.Context) error { + return errors.New("error") + } + checker = contextChecker{} + + assert.Error(t, runUnsealSetupFunctions(testContext, setupFunctions)) + assert.NotNil(t, checker.values) + assert.Equal(t, 1, len(checker.values)) +} + +// contextChecker is testing struct used to verify that the correct +// context.Context is passed to the setupFunctions by the +// runUnsealSetupFunctions function. +type contextChecker struct { + values []any +} + +func (c *contextChecker) setupFunction(ctx context.Context) error { + value := ctx.Value("test") + c.values = append(c.values, value) + + return nil +} + +// TestBuildUnsealSetupFunctionSlice verifies that the +// buildUnsealSetupFunctionSlice function returns the correct slice of functions +// for the provided Core instance. +func TestBuildUnsealSetupFunctionSlice(t *testing.T) { + uint32Ptr := func(value uint32) *uint32 { + return &value + } + + for _, testcase := range []struct { + name string + core *Core + expectedLength int + }{ + { + name: "primary core", + core: &Core{ + replicationState: uint32Ptr(uint32(0)), + }, + expectedLength: 25, + }, + { + name: "dr secondary core", + core: &Core{ + replicationState: uint32Ptr(uint32(consts.ReplicationDRSecondary)), + }, + expectedLength: 14, + }, + } { + funcs := buildUnsealSetupFunctionSlice(testcase.core) + assert.Equal(t, testcase.expectedLength, len(funcs), testcase.name) + } +}