diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 4609d27413..5a77ec85fd 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -39,7 +39,12 @@ const ( minRootCredRollbackAge = 1 * time.Minute ) -var databaseConfigNameFromRotationIDRegex = regexp.MustCompile("^.+/config/(.+$)") +var ( + databaseInitTimeout = 10 * time.Second + databaseConfigNameFromRotationIDRegex = regexp.MustCompile("^.+/config/(.+$)") +) + +var errDatabaseInitializeTimeout = errors.New("timeout exceeded during Initialize") type dbPluginInstance struct { sync.RWMutex @@ -380,6 +385,61 @@ func (b *databaseBackend) CloseIfShutdown(db *dbPluginInstance, err error) { } } +type initializeConnectionResult struct { + resp v5.InitializeResponse + err error +} + +// initializeConnection bounds how long Vault waits for plugin initialization. +// Some drivers and plugins do not reliably honor context cancellation during +// connection verification, so Initialize runs in a goroutine and is raced +// against a timeout to keep rotation and connection creation from hanging +// indefinitely while locks are held. +func (b *databaseBackend) initializeConnection(ctx context.Context, dbw databaseVersionWrapper, initReq v5.InitializeRequest) (v5.InitializeResponse, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, databaseInitTimeout) + defer cancel() + + done := make(chan initializeConnectionResult, 1) + + go func() { + resp, err := dbw.Initialize(timeoutCtx, initReq) + done <- initializeConnectionResult{resp: resp, err: err} + }() + + select { + case result := <-done: + return result.resp, result.err + case <-timeoutCtx.Done(): + // Preserve the caller's cancellation or deadline when it fired first. + if ctx.Err() != nil { + return v5.InitializeResponse{}, ctx.Err() + } + // This only bounds Vault's wait; the underlying Initialize call may still + // be blocked below the plugin boundary until that implementation notices cancellation. + return v5.InitializeResponse{}, fmt.Errorf("%w: %v", errDatabaseInitializeTimeout, timeoutCtx.Err()) + } +} + +// closeDatabaseWrapperAfterInitError treats init timeouts differently because a +// synchronous Close can block behind the same stuck initialization path. +func (b *databaseBackend) closeDatabaseWrapperAfterInitError(dbw databaseVersionWrapper, err error) { + /* + * Use async close when the init goroutine may still be running. This + * covers both Vault's own databaseInitTimeout sentinel and the case where + * the caller's context fired first (path B1 in initializeConnection), + * which returns the raw ctx.Err() rather than the wrapped sentinel. + */ + if errors.Is(err, errDatabaseInitializeTimeout) || + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) { + // Let the caller unwind immediately; best-effort cleanup continues in the background. + go dbw.Close() + return + } + + _ = dbw.Close() +} + // clean closes all connections from all database types // and cancels any rotation queue loading operation. func (b *databaseBackend) clean(_ context.Context) { diff --git a/builtin/logical/database/backend_ce.go b/builtin/logical/database/backend_ce.go index 977576429a..4462321831 100644 --- a/builtin/logical/database/backend_ce.go +++ b/builtin/logical/database/backend_ce.go @@ -56,9 +56,11 @@ func (b *databaseBackend) GetConnectionWithConfig(ctx context.Context, name stri Config: config.ConnectionDetails, VerifyConnection: config.VerifyConnection, } - _, err = dbw.Initialize(ctx, initReq) + // Bound cache-miss initialization so a blocked handshake cannot stall callers + // indefinitely while the connection creation lock is held. + _, err = b.initializeConnection(ctx, dbw, initReq) if err != nil { - dbw.Close() + b.closeDatabaseWrapperAfterInitError(dbw, err) return nil, err } diff --git a/builtin/logical/database/backend_get_test.go b/builtin/logical/database/backend_get_test.go index dc1bd4b50d..fa3743d36a 100644 --- a/builtin/logical/database/backend_get_test.go +++ b/builtin/logical/database/backend_get_test.go @@ -5,9 +5,12 @@ package database import ( "context" + "errors" "sync" "testing" + "time" + v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" "github.com/hashicorp/vault/sdk/helper/consts" "github.com/hashicorp/vault/sdk/helper/pluginutil" "github.com/hashicorp/vault/sdk/logical" @@ -16,12 +19,14 @@ import ( func newSystemViewWrapper(view logical.SystemView) logical.SystemView { return &systemViewWrapper{ - view, + SystemView: view, } } type systemViewWrapper struct { logical.SystemView + pluginName string + builtinFactory func() (interface{}, error) } var _ logical.ExtendedSystemView = (*systemViewWrapper)(nil) @@ -51,11 +56,21 @@ func (s *systemViewWrapper) GetPinnedPluginVersion(ctx context.Context, pluginTy } func (s *systemViewWrapper) LookupPluginVersion(ctx context.Context, pluginName string, pluginType consts.PluginType, version string) (*pluginutil.PluginRunner, error) { + name := s.pluginName + if name == "" { + name = mockv5 + } + + factory := s.builtinFactory + if factory == nil { + factory = New + } + return &pluginutil.PluginRunner{ - Name: mockv5, + Name: name, Type: consts.PluginTypeDatabase, Builtin: true, - BuiltinFactory: New, + BuiltinFactory: factory, }, nil } @@ -77,6 +92,96 @@ func getDbBackend(t *testing.T) (*databaseBackend, logical.Storage) { return b, config.StorageView } +type blockingInitializeDatabase struct { + initializeDone chan struct{} +} + +func newBlockingInitializeDatabase() (interface{}, error) { + return &blockingInitializeDatabase{initializeDone: make(chan struct{})}, nil +} + +func (d *blockingInitializeDatabase) Initialize(context.Context, v5.InitializeRequest) (v5.InitializeResponse, error) { + <-d.initializeDone + return v5.InitializeResponse{}, nil +} + +func (d *blockingInitializeDatabase) NewUser(context.Context, v5.NewUserRequest) (v5.NewUserResponse, error) { + return v5.NewUserResponse{}, nil +} + +func (d *blockingInitializeDatabase) UpdateUser(context.Context, v5.UpdateUserRequest) (v5.UpdateUserResponse, error) { + return v5.UpdateUserResponse{}, nil +} + +func (d *blockingInitializeDatabase) DeleteUser(context.Context, v5.DeleteUserRequest) (v5.DeleteUserResponse, error) { + return v5.DeleteUserResponse{}, nil +} + +func (d *blockingInitializeDatabase) Type() (string, error) { + return mockV5Type, nil +} + +func (d *blockingInitializeDatabase) Close() error { + close(d.initializeDone) + return nil +} + +// slowCloseDatabase blocks in Close until closeCh is closed, so a synchronous +// call to closeDatabaseWrapperAfterInitError would stall the test. +type slowCloseDatabase struct { + closeCh chan struct{} +} + +func (d *slowCloseDatabase) Initialize(context.Context, v5.InitializeRequest) (v5.InitializeResponse, error) { + return v5.InitializeResponse{}, nil +} + +func (d *slowCloseDatabase) NewUser(context.Context, v5.NewUserRequest) (v5.NewUserResponse, error) { + return v5.NewUserResponse{}, nil +} + +func (d *slowCloseDatabase) UpdateUser(context.Context, v5.UpdateUserRequest) (v5.UpdateUserResponse, error) { + return v5.UpdateUserResponse{}, nil +} + +func (d *slowCloseDatabase) DeleteUser(context.Context, v5.DeleteUserRequest) (v5.DeleteUserResponse, error) { + return v5.DeleteUserResponse{}, nil +} +func (d *slowCloseDatabase) Type() (string, error) { return "slow-close", nil } +func (d *slowCloseDatabase) Close() error { + <-d.closeCh + return nil +} + +// TestCloseDatabaseWrapperAfterInitError_ContextCanceled_IsAsync verifies that +// closeDatabaseWrapperAfterInitError does not block when the error is +// context.Canceled (parent context fired before Vault's own databaseInitTimeout). +// In that path the init goroutine may still be running, so Close must be async. +// context.DeadlineExceeded is symmetric (same code branch), so a single test for +// context.Canceled is sufficient. +func TestCloseDatabaseWrapperAfterInitError_ContextCanceled_IsAsync(t *testing.T) { + closeCh := make(chan struct{}) + dbw := databaseVersionWrapper{v5: &slowCloseDatabase{closeCh: closeCh}} + + b := &databaseBackend{} + + done := make(chan struct{}) + go func() { + defer close(done) + b.closeDatabaseWrapperAfterInitError(dbw, context.Canceled) + }() + + select { + case <-done: + // Good: function returned before Close() completed. + case <-time.After(100 * time.Millisecond): + t.Fatal("closeDatabaseWrapperAfterInitError blocked synchronously on Close() for context.Canceled") + } + + // Unblock the background goroutine so it does not leak. + close(closeCh) +} + // TestGetConnectionRaceCondition checks that GetConnection always returns the same instance, even when asked // by multiple goroutines in parallel. func TestGetConnectionRaceCondition(t *testing.T) { @@ -107,3 +212,54 @@ func TestGetConnectionRaceCondition(t *testing.T) { } } } + +// TestGetConnectionInitializeTimeout verifies GetConnection returns an initialize-timeout +// error when plugin initialization blocks longer than the configured timeout +func TestGetConnectionInitializeTimeout(t *testing.T) { + oldTimeout := databaseInitTimeout + databaseInitTimeout = 25 * time.Millisecond + defer func() { + databaseInitTimeout = oldTimeout + }() + + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: newBlockingInitializeDatabase, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + entry, err := logical.StorageEntryJSON("config/blocked", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + PluginName: mockV5Type, + VerifyConnection: true, + ConnectionDetails: map[string]interface{}{"connection_url": "unused"}, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + start := time.Now() + _, err = b.GetConnection(context.Background(), config.StorageView, "blocked") + if err == nil { + t.Fatal("expected timeout error") + } + if !errors.Is(err, errDatabaseInitializeTimeout) { + t.Fatalf("expected initialize timeout error, got: %v", err) + } + if elapsed := time.Since(start); elapsed > time.Second { + t.Fatalf("GetConnection took too long to fail: %s", elapsed) + } + if conn := b.connections.Get("blocked"); conn != nil { + t.Fatal("expected timed out connection to not be cached") + } +} diff --git a/builtin/logical/database/path_config_connection_ce.go b/builtin/logical/database/path_config_connection_ce.go index c56115bb1a..d8d1a8b484 100644 --- a/builtin/logical/database/path_config_connection_ce.go +++ b/builtin/logical/database/path_config_connection_ce.go @@ -145,9 +145,11 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { Config: config.ConnectionDetails, VerifyConnection: config.VerifyConnection, } - initResp, err := dbw.Initialize(ctx, initReq) + // verify_connection can perform a live handshake here, so bound how long + // Vault waits before failing the write and releasing the caller. + initResp, err := b.initializeConnection(ctx, dbw, initReq) if err != nil { - dbw.Close() + b.closeDatabaseWrapperAfterInitError(dbw, err) return logical.ErrorResponse("error creating database object: %s", err), nil } config.ConnectionDetails = initResp.Config diff --git a/builtin/logical/database/rollback.go b/builtin/logical/database/rollback.go index d1b3b07c40..12ea638f54 100644 --- a/builtin/logical/database/rollback.go +++ b/builtin/logical/database/rollback.go @@ -90,6 +90,16 @@ func (b *databaseBackend) walRollback(ctx context.Context, req *logical.Request, return nil } + // An initialization timeout means the database was unreachable within + // Vault's deadline, not that the stored credentials are wrong. A timeout + // is not a reliable signal of credential state: the rotation may have + // already applied the new password before the database became slow. + // Returning the error here lets the WAL framework retry later rather + // than risking a rollback that reverts a successfully rotated credential. + if errors.Is(err, errDatabaseInitializeTimeout) { + return err + } + return b.rollbackDatabaseCredentials(ctx, config, entry) } diff --git a/builtin/logical/database/rollback_test.go b/builtin/logical/database/rollback_test.go index 1b77c8e143..4320740619 100644 --- a/builtin/logical/database/rollback_test.go +++ b/builtin/logical/database/rollback_test.go @@ -5,6 +5,7 @@ package database import ( "context" + "errors" "strings" "testing" "time" @@ -405,3 +406,129 @@ func TestBackend_RotateRootCredentials_WAL_no_rollback_2(t *testing.T) { t.Fatalf("err:%s resp:%v\n", err, credResp) } } + +// failingInitializeDatabase is a v5.Database mock whose Initialize always +// returns a fixed error. Used to simulate a persistent connection failure +// (wrong credentials, network refused) without a blocking timeout. +type failingInitializeDatabase struct { + err error +} + +func (d *failingInitializeDatabase) Initialize(_ context.Context, _ v5.InitializeRequest) (v5.InitializeResponse, error) { + return v5.InitializeResponse{}, d.err +} + +func (d *failingInitializeDatabase) NewUser(_ context.Context, _ v5.NewUserRequest) (v5.NewUserResponse, error) { + return v5.NewUserResponse{}, nil +} + +func (d *failingInitializeDatabase) UpdateUser(_ context.Context, _ v5.UpdateUserRequest) (v5.UpdateUserResponse, error) { + return v5.UpdateUserResponse{}, nil +} + +func (d *failingInitializeDatabase) DeleteUser(_ context.Context, _ v5.DeleteUserRequest) (v5.DeleteUserResponse, error) { + return v5.DeleteUserResponse{}, nil +} +func (d *failingInitializeDatabase) Type() (string, error) { return mockV5Type, nil } +func (d *failingInitializeDatabase) Close() error { return nil } + +// TestWalRollback_InitializeTimeout_SkipsRollback verifies that a transient +// errDatabaseInitializeTimeout from GetConnection does not trigger +// rollbackDatabaseCredentials. The timeout only means the database was slow; +// it says nothing about whether credentials were already rotated successfully. +func TestWalRollback_InitializeTimeout_SkipsRollback(t *testing.T) { + oldTimeout := databaseInitTimeout + databaseInitTimeout = 25 * time.Millisecond + defer func() { databaseInitTimeout = oldTimeout }() + + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: newBlockingInitializeDatabase, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + VerifyConnection: true, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + "password": "original-pass", + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPassword: "new-pass", // != "original-pass": enters the credential-verification branch + OldPassword: "original-pass", + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + if !errors.Is(err, errDatabaseInitializeTimeout) { + t.Fatalf("expected errDatabaseInitializeTimeout to propagate, got: %v", err) + } +} + +// TestWalRollback_ConnectionFailed_TriggersRollback verifies that a +// non-timeout connection failure still reaches rollbackDatabaseCredentials, +// preserving the existing behavior for genuine authentication failures. +func TestWalRollback_ConnectionFailed_TriggersRollback(t *testing.T) { + connErr := errors.New("connection refused: invalid credentials") + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: func() (interface{}, error) { + return &failingInitializeDatabase{err: connErr}, nil + }, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + VerifyConnection: true, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + "password": "original-pass", + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPassword: "new-pass", + OldPassword: "original-pass", + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + // A non-timeout failure must not be swallowed: rollbackDatabaseCredentials is + // called and its error (from a second failed GetConnectionWithConfig) propagates. + if errors.Is(err, errDatabaseInitializeTimeout) { + t.Fatal("timeout sentinel must not appear for a non-timeout connection failure") + } + if err == nil { + t.Fatal("expected rollbackDatabaseCredentials to return an error") + } +} diff --git a/builtin/logical/database/rotation.go b/builtin/logical/database/rotation.go index 19e66d1447..76c425a775 100644 --- a/builtin/logical/database/rotation.go +++ b/builtin/logical/database/rotation.go @@ -29,6 +29,8 @@ const ( staticWALKey = "staticRotationKey" ) +var staticUpdateUserTimeout = 10 * time.Second + // populateQueue loads the priority queue with existing static accounts. This // occurs at initialization, after any WAL entries of failed or interrupted // rotations have been processed. It lists the roles from storage and searches @@ -538,7 +540,27 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag b.Logger().Debug("writing WAL", "role", input.RoleName, "WAL ID", output.WALID) } - _, err = dbi.database.UpdateUser(ctx, updateReq, false) + timeoutCtx, cancel := context.WithTimeout(ctx, staticUpdateUserTimeout) + defer cancel() + updateUserTimeoutErr := errors.New("timeout exceeded during UpdateUser") + + done := make(chan error, 1) + + go func() { + _, e := dbi.database.UpdateUser(timeoutCtx, updateReq, false) + done <- e + }() + + select { + case err = <-done: + case <-timeoutCtx.Done(): + if ctx.Err() != nil { + err = ctx.Err() + } else { + err = updateUserTimeoutErr + } + } + if err != nil { b.CloseIfShutdown(dbi, err) if usedCredentialFromPreviousRotation { @@ -547,11 +569,11 @@ func (b *databaseBackend) setStaticAccount(ctx context.Context, s logical.Storag b.Logger().Warn("failed to delete WAL", "error", err, "WAL ID", output.WALID) } - // Generate a new WAL entry and credential for next attempt output.WALID = "" } return output, fmt.Errorf("error setting credentials: %w", err) } + modified = true // static user password successfully updated in external system diff --git a/builtin/logical/database/rotation_test.go b/builtin/logical/database/rotation_test.go index 9e4af3d067..075f353f2c 100644 --- a/builtin/logical/database/rotation_test.go +++ b/builtin/logical/database/rotation_test.go @@ -1696,6 +1696,47 @@ func TestRotationSchedulePriorityAfterRestart(t *testing.T) { require.Equal(t, newPriority, firstPriority) // confirm that priority has not changed } +// TestRotateRole_BlockedUpdateUser_TimesOut verifies rotate-role returns quickly +// with a timeout error when UpdateUser blocks past Vault's timeout window. +func TestRotateRole_BlockedUpdateUser_TimesOut(t *testing.T) { + ctx := context.Background() + b, storage, mockDB := getBackend(t) + defer b.Cleanup(ctx) + configureDBMount(t, storage) + + oldTimeout := staticUpdateUserTimeout + staticUpdateUserTimeout = 25 * time.Millisecond + defer func() { staticUpdateUserTimeout = oldTimeout }() + + roleName := "hashicorp" + data := map[string]interface{}{ + "username": "hashicorp", + "db_name": "mockv5", + "rotation_period": "10m", + } + createRoleWithData(t, b, storage, mockDB, roleName, data) + + blockCh := make(chan struct{}) + defer close(blockCh) + mockDB.On("UpdateUser", mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + <-blockCh + }). + Return(v5.UpdateUserResponse{}, nil). + Once() + + start := time.Now() + _, err := b.HandleRequest(ctx, &logical.Request{ + Operation: logical.UpdateOperation, + Path: "rotate-role/" + roleName, + Storage: storage, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "timeout exceeded during UpdateUser") + require.Less(t, time.Since(start), time.Second, "rotate-role should return promptly on update timeout") +} + func generateWALFromFailedRotation(t *testing.T, b *databaseBackend, storage logical.Storage, mockDB *mockNewDatabase, roleName string) { t.Helper() mockDB.On("UpdateUser", mock.Anything, mock.Anything). diff --git a/changelog/_13697.txt b/changelog/_13697.txt new file mode 100644 index 0000000000..94ded70fb2 --- /dev/null +++ b/changelog/_13697.txt @@ -0,0 +1,3 @@ +```release-note:bug +database: prevent static role rotation and connection init from hanging indefinitely when database calls block by adding timeouts around UpdateUser and Initialize +``` \ No newline at end of file