From d4da61fc4e987de07ca9b60604fddfb83314a3b2 Mon Sep 17 00:00:00 2001 From: davidadeleon <56207066+davidadeleon@users.noreply.github.com> Date: Thu, 20 Jun 2024 14:46:01 -0400 Subject: [PATCH] CE side change for vault-24636 (#26675) --- vault/login_mfa.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/vault/login_mfa.go b/vault/login_mfa.go index abe3f70192..2a14b81107 100644 --- a/vault/login_mfa.go +++ b/vault/login_mfa.go @@ -2653,6 +2653,31 @@ func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByNameAndNamespace(name, return eConfig.Clone() } +func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigByID(id string) (*mfa.MFAEnforcementConfig, error) { + if id == "" { + return nil, fmt.Errorf("missing config id") + } + + txn := b.db.Txn(false) + defer txn.Abort() + + eConfigRaw, err := txn.First(memDBMFALoginEnforcementsTable, "id", id) + if err != nil { + return nil, fmt.Errorf("failed to fetch MFA login enforcement config from memdb using id: %w", err) + } + + if eConfigRaw == nil { + return nil, nil + } + + eConfig, ok := eConfigRaw.(*mfa.MFAEnforcementConfig) + if !ok { + return nil, fmt.Errorf("invalid type for MFA login enforcement config in memdb") + } + + return eConfig.Clone() +} + func (b *LoginMFABackend) MemDBMFALoginEnforcementConfigIterator() (memdb.ResultIterator, error) { txn := b.db.Txn(false) defer txn.Abort() @@ -2710,6 +2735,32 @@ func (b *LoginMFABackend) deleteMFALoginEnforcementConfigByNameAndNamespace(ctx return nil } +func (b *LoginMFABackend) MemDBDeleteMFALoginEnforcementConfigByID(id string) error { + if id == "" { + return nil + } + + txn := b.db.Txn(true) + defer txn.Abort() + + eConfig, err := b.MemDBMFALoginEnforcementConfigByID(id) + if err != nil { + return err + } + + if eConfig == nil { + return nil + } + + err = txn.Delete(memDBMFALoginEnforcementsTable, eConfig) + if err != nil { + return err + } + + txn.Commit() + return nil +} + func (b *LoginMFABackend) MemDBDeleteMFALoginEnforcementConfigByNameAndNamespace(name, namespaceId, tableName string) error { if name == "" || namespaceId == "" { return nil