From c9430538b3f76ee5cfc6371aeb7788971e31141b Mon Sep 17 00:00:00 2001 From: Vault Automation Date: Wed, 29 Apr 2026 03:59:02 -0600 Subject: [PATCH] VAULT-44064 - Add rollback support to the snowflake key pair root credentials rotation (#14046) (#14400) * Add rollback support to the snowflake key pair root rotation flow * Added changelog * Updated changelog * Updated changelog * Updated rollback logic * Updated rollback logic * Updated rollback logic * Updated rollback logic * Updated tests * Addressed PR comments * Updated tests * Addressing PR Review Comments --------- Co-authored-by: santoshhashicorp Co-authored-by: John-Michael Faircloth --- .../database/path_rotate_credentials.go | 6 +- builtin/logical/database/rollback.go | 109 +++++++ builtin/logical/database/rollback_test.go | 291 ++++++++++++++++++ changelog/_14046.txt | 3 + 4 files changed, 405 insertions(+), 4 deletions(-) create mode 100644 changelog/_14046.txt diff --git a/builtin/logical/database/path_rotate_credentials.go b/builtin/logical/database/path_rotate_credentials.go index c3242c0aee..f7caafe92f 100644 --- a/builtin/logical/database/path_rotate_credentials.go +++ b/builtin/logical/database/path_rotate_credentials.go @@ -189,9 +189,8 @@ func (b *databaseBackend) performRootRotation(ctx context.Context, req *logical. if err != nil { return nil, err } - config.ConnectionDetails["private_key"] = string(newPrivateKey) - oldPrivateKey := config.ConnectionDetails["private_key"].(string) + config.ConnectionDetails["private_key"] = string(newPrivateKey) walEntry = NewRotateRootCredentialsWALPrivateKeyEntry(name, rootUsername, string(newPublicKey), string(newPrivateKey), oldPrivateKey) updateReq = v5.UpdateUserRequest{ Username: rootUsername, @@ -210,9 +209,8 @@ func (b *databaseBackend) performRootRotation(ctx context.Context, req *logical. if err != nil { return nil, err } - config.ConnectionDetails["password"] = newPassword - oldPassword := config.ConnectionDetails["password"].(string) + config.ConnectionDetails["password"] = newPassword walEntry = NewRotateRootCredentialsWALPasswordEntry(name, rootUsername, newPassword, oldPassword) updateReq = v5.UpdateUserRequest{ Username: rootUsername, diff --git a/builtin/logical/database/rollback.go b/builtin/logical/database/rollback.go index 12ea638f54..bb16407756 100644 --- a/builtin/logical/database/rollback.go +++ b/builtin/logical/database/rollback.go @@ -5,7 +5,12 @@ package database import ( "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "errors" + "fmt" + "strings" "github.com/hashicorp/vault/sdk/database/dbplugin" v5 "github.com/hashicorp/vault/sdk/database/dbplugin/v5" @@ -18,6 +23,10 @@ import ( // WAL storage key used for the rollback of root database credentials const rotateRootWALKey = "rotateRootWALKey" +// snowflakeErrJWTTokenInvalid is the Snowflake server-side error code for JWT +// authentication failure. +const snowflakeErrJWTTokenInvalid = "390144" + // WAL entry used for the rollback of root database credentials type rotateRootCredentialsWAL struct { ConnectionName string @@ -71,6 +80,25 @@ func (b *databaseBackend) walRollback(ctx context.Context, req *logical.Request, return err } + // Route based on credential type in the WAL entry. + if entry.NewPrivateKey != "" { + // Stored key matches WAL new key: rotation completed, WAL not yet deleted. + if config.ConnectionDetails["private_key"] == entry.NewPrivateKey { + b.Logger().Info("WAL rollback: private key already rotated, nothing to roll back", + "connection", entry.ConnectionName) + return nil + } + + b.Logger().Warn("WAL rollback: private key out of sync, starting rollback", + "connection", entry.ConnectionName, "username", entry.UserName) + + if err := b.ClearConnection(entry.ConnectionName); err != nil { + return err + } + + return b.rollbackDatabasePrivateKey(ctx, config, entry) + } + // The password in storage doesn't match the new password // in the WAL entry. This means there was a partial failure // to update either the database or storage. @@ -148,3 +176,84 @@ func (b *databaseBackend) rollbackDatabaseCredentials(ctx context.Context, confi } return err } + +// rollbackDatabasePrivateKey restores the old public key on Snowflake for key-pair +// auth connections by connecting with the new private key and issuing ALTER USER. +func (b *databaseBackend) rollbackDatabasePrivateKey(ctx context.Context, config *DatabaseConfig, entry rotateRootCredentialsWAL) error { + oldPublicKey, err := derivePublicKeyFromPrivateKeyPEM(entry.OldPrivateKey) + if err != nil { + return fmt.Errorf("failed to derive old public key for rollback: %w", err) + } + + config.ConnectionDetails["private_key"] = entry.NewPrivateKey + dbi, err := b.GetConnectionWithConfig(ctx, entry.ConnectionName, config) + if err != nil { + b.Logger().Error("WAL rollback: failed to connect using new private key", "connection", entry.ConnectionName, "error", err.Error()) + return err + } + + defer func() { + if err := b.ClearConnection(entry.ConnectionName); err != nil { + b.Logger().Error("error closing database plugin connection", "error", err) + } + }() + + b.Logger().Info("WAL rollback: restoring old public key on Snowflake", "connection", entry.ConnectionName, "username", entry.UserName) + + updateReq := v5.UpdateUserRequest{ + Username: entry.UserName, + CredentialType: v5.CredentialTypeRSAPrivateKey, + PublicKey: &v5.ChangePublicKey{ + NewPublicKey: oldPublicKey, + Statements: v5.Statements{ + Commands: config.RootCredentialsRotateStatements, + }, + }, + } + + _, err = dbi.database.UpdateUser(ctx, updateReq, false) + if status.Code(err) == codes.Unimplemented || err == dbplugin.ErrPluginStaticUnsupported { + return nil + } + if err != nil { + // Snowflake error 390144 means JWT authentication failed. This occurs when + // the new private key was never registered with Snowflake (crash before UpdateUser), + // so the system is already consistent with the old key — delete the WAL cleanly. + if strings.Contains(err.Error(), snowflakeErrJWTTokenInvalid) { + b.Logger().Info("WAL rollback: new private key rejected by Snowflake (crash before UpdateUser), system already consistent", + "connection", entry.ConnectionName) + return nil + } + b.Logger().Error("WAL rollback: failed to restore old public key", "connection", entry.ConnectionName, "error", err.Error()) + return err + } + b.Logger().Info("WAL rollback: successfully restored old public key", "connection", entry.ConnectionName, "username", entry.UserName) + return nil +} + +func derivePublicKeyFromPrivateKeyPEM(privateKeyPEM string) ([]byte, error) { + block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block from private key") + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is not an RSA key") + } + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal public key: %w", err) + } + + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: pubKeyBytes, + }), nil +} diff --git a/builtin/logical/database/rollback_test.go b/builtin/logical/database/rollback_test.go index 4320740619..eecf8715e3 100644 --- a/builtin/logical/database/rollback_test.go +++ b/builtin/logical/database/rollback_test.go @@ -5,6 +5,10 @@ package database import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "errors" "strings" "testing" @@ -20,6 +24,7 @@ import ( const ( databaseUser = "postgres" defaultPassword = "secret" + newPrivateKey = "new-private-key-pem" ) // Tests that the WAL rollback function rolls back the database password. @@ -532,3 +537,289 @@ func TestWalRollback_ConnectionFailed_TriggersRollback(t *testing.T) { t.Fatal("expected rollbackDatabaseCredentials to return an error") } } + +// configuredUpdateUserDatabase is a v5.Database mock whose Initialize always +// succeeds and whose UpdateUser returns a configurable error (nil for success). +type configuredUpdateUserDatabase struct { + updateUserErr error +} + +func (d *configuredUpdateUserDatabase) Initialize(_ context.Context, _ v5.InitializeRequest) (v5.InitializeResponse, error) { + return v5.InitializeResponse{}, nil +} + +func (d *configuredUpdateUserDatabase) NewUser(_ context.Context, _ v5.NewUserRequest) (v5.NewUserResponse, error) { + return v5.NewUserResponse{}, nil +} + +func (d *configuredUpdateUserDatabase) UpdateUser(_ context.Context, _ v5.UpdateUserRequest) (v5.UpdateUserResponse, error) { + return v5.UpdateUserResponse{}, d.updateUserErr +} + +func (d *configuredUpdateUserDatabase) DeleteUser(_ context.Context, _ v5.DeleteUserRequest) (v5.DeleteUserResponse, error) { + return v5.DeleteUserResponse{}, nil +} + +func (d *configuredUpdateUserDatabase) Type() (string, error) { return mockV5Type, nil } +func (d *configuredUpdateUserDatabase) Close() error { return nil } + +// generateTestRSAPrivateKeyPEM generates a 2048-bit RSA private key in +// PKCS#8 PEM format suitable for use with derivePublicKeyFromPrivateKeyPEM. +func generateTestRSAPrivateKeyPEM(t *testing.T) string { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + keyBytes, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + t.Fatal(err) + } + return string(pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: keyBytes, + })) +} + +// TestWalRollback_PrivateKey_RotationCompleted_NoRollback verifies that when +// the private key already stored matches the WAL new key, walRollback returns +// nil immediately — the rotation completed and the WAL simply wasn't deleted. +func TestWalRollback_PrivateKey_RotationCompleted_NoRollback(t *testing.T) { + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{SystemView: config.System} + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + // Stored key matches WAL new key: rotation completed, WAL not yet GC'd. + "private_key": newPrivateKey, + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPrivateKey: newPrivateKey, + OldPrivateKey: "old-private-key-pem", + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + if err != nil { + t.Fatalf("expected no error when rotation already completed, got: %v", err) + } +} + +// TestWalRollback_PrivateKey_ConnectionFails_ReturnsError verifies that when +// connecting with the WAL new private key fails, the error propagates. +func TestWalRollback_PrivateKey_ConnectionFails_ReturnsError(t *testing.T) { + connErr := errors.New("JWT token rejected: invalid key pair") + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: func() (interface{}, error) { + return &failingInitializeDatabase{err: connErr}, nil + }, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + oldPrivateKey := generateTestRSAPrivateKeyPEM(t) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + VerifyConnection: true, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + // Stored key does not match new key: out-of-sync, triggers rollback path. + "private_key": "old-private-key-pem", + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPrivateKey: "new-private-key-pem", + OldPrivateKey: oldPrivateKey, + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + if err == nil { + t.Fatal("expected connection error to propagate, got nil") + } +} + +// TestWalRollback_PrivateKey_UpdateUserFails_ReturnsError verifies that a +// UpdateUser error propagates so the WAL framework can retry. +func TestWalRollback_PrivateKey_UpdateUserFails_ReturnsError(t *testing.T) { + updateErr := errors.New("internal server error") + + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: func() (interface{}, error) { + return &configuredUpdateUserDatabase{updateUserErr: updateErr}, nil + }, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + oldPrivateKey := generateTestRSAPrivateKeyPEM(t) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + VerifyConnection: true, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + "private_key": "old-private-key-pem", + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPrivateKey: "new-private-key-pem", + OldPrivateKey: oldPrivateKey, + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + if err == nil { + t.Fatal("expected UpdateUser error to propagate, got nil") + } +} + +// TestWalRollback_PrivateKey_RollbackSucceeds verifies the happy path: when the +// stored key is out-of-sync with the WAL new key and UpdateUser succeeds, +// walRollback returns nil indicating the rollback completed successfully. +func TestWalRollback_PrivateKey_RollbackSucceeds(t *testing.T) { + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: func() (interface{}, error) { + return &configuredUpdateUserDatabase{updateUserErr: nil}, nil + }, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + oldPrivateKey := generateTestRSAPrivateKeyPEM(t) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + VerifyConnection: true, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + "private_key": "old-private-key-pem", + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPrivateKey: "new-private-key-pem", + OldPrivateKey: oldPrivateKey, + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + if err != nil { + t.Fatalf("expected successful rollback, got: %v", err) + } +} + +// TestWalRollback_PrivateKey_SnowflakeJWTError_TreatsAsNoOp verifies the +// crash-before-UpdateUser safety path: when UpdateUser returns a Snowflake +// 390144 JWT error, the new private key was never registered with Snowflake, +// so the system is already consistent with the old key. walRollback must +// return nil to cleanly delete the WAL rather than retrying indefinitely. +func TestWalRollback_PrivateKey_SnowflakeJWTError_TreatsAsNoOp(t *testing.T) { + // Simulate the error Snowflake returns when the JWT is signed with a key + // it has never seen. The error crosses the gRPC plugin boundary as a plain + // string, so it is matched via strings.Contains against the error code. + jwtErr := errors.New("390144 (08001): JWT token is invalid") + + config := logical.TestBackendConfig() + config.System = &systemViewWrapper{ + SystemView: config.System, + builtinFactory: func() (interface{}, error) { + return &configuredUpdateUserDatabase{updateUserErr: jwtErr}, nil + }, + } + config.StorageView = &logical.InmemStorage{} + + b := Backend(config) + if err := b.Setup(context.Background(), config); err != nil { + t.Fatal(err) + } + defer b.Cleanup(context.Background()) + + oldPrivateKey := generateTestRSAPrivateKeyPEM(t) + + entry, err := logical.StorageEntryJSON("config/mydb", &DatabaseConfig{ + AllowedRoles: []string{"*"}, + VerifyConnection: true, + PluginName: mockV5Type, + ConnectionDetails: map[string]interface{}{ + // Stored key does not match new key: rollback path is entered. + "private_key": "old-private-key-pem", + }, + }) + if err != nil { + t.Fatal(err) + } + if err := config.StorageView.Put(context.Background(), entry); err != nil { + t.Fatal(err) + } + + walEntry := &rotateRootCredentialsWAL{ + ConnectionName: "mydb", + UserName: "root", + NewPrivateKey: "new-private-key-pem", + OldPrivateKey: oldPrivateKey, + } + err = b.walRollback(context.Background(), &logical.Request{Storage: config.StorageView}, rotateRootWALKey, walEntry) + if err != nil { + t.Fatalf("expected 390144 JWT error to be treated as no-op (nil), got: %v", err) + } +} diff --git a/changelog/_14046.txt b/changelog/_14046.txt new file mode 100644 index 0000000000..59217b75cc --- /dev/null +++ b/changelog/_14046.txt @@ -0,0 +1,3 @@ +```release-note:bug +database/snowflake: Fix WAL rollback issue for key-pair root credential rotation. +``` \ No newline at end of file