diff --git a/plugins/database/mongodb/connection_producer.go b/plugins/database/mongodb/connection_producer.go index bb0d8aa7a2..fbae8c4dcb 100644 --- a/plugins/database/mongodb/connection_producer.go +++ b/plugins/database/mongodb/connection_producer.go @@ -80,7 +80,6 @@ func (c *mongoDBConnectionProducer) Init(ctx context.Context, conf map[string]in return nil, err } - c.ConnectionURL = c.getConnectionURL() c.clientOptions = options.MergeClientOptions(writeOpts, authOpts) // Set initialized to true at this point since all fields are set, @@ -117,21 +116,30 @@ func (c *mongoDBConnectionProducer) Connection(ctx context.Context) (interface{} _ = c.client.Disconnect(ctx) } - if c.clientOptions == nil { - c.clientOptions = options.Client() - } - c.clientOptions.SetSocketTimeout(1 * time.Minute) - c.clientOptions.SetConnectTimeout(1 * time.Minute) - - var err error - opts := c.clientOptions.ApplyURI(c.ConnectionURL) - c.client, err = mongo.Connect(ctx, opts) + connURL := c.getConnectionURL() + client, err := createClient(ctx, connURL, c.clientOptions) if err != nil { return nil, err } + c.client = client return c.client, nil } +func createClient(ctx context.Context, connURL string, clientOptions *options.ClientOptions) (client *mongo.Client, err error) { + if clientOptions == nil { + clientOptions = options.Client() + } + clientOptions.SetSocketTimeout(1 * time.Minute) + clientOptions.SetConnectTimeout(1 * time.Minute) + + opts := clientOptions.ApplyURI(connURL) + client, err = mongo.Connect(ctx, opts) + if err != nil { + return nil, err + } + return client, nil +} + // Close terminates the database connection. func (c *mongoDBConnectionProducer) Close() error { c.Lock() diff --git a/plugins/database/mongodb/mongodb.go b/plugins/database/mongodb/mongodb.go index aafa42ba57..8af06839bb 100644 --- a/plugins/database/mongodb/mongodb.go +++ b/plugins/database/mongodb/mongodb.go @@ -3,7 +3,6 @@ package mongodb import ( "context" "encoding/json" - "errors" "fmt" "io" "strings" @@ -155,7 +154,8 @@ func (m *MongoDB) SetCredentials(ctx context.Context, statements dbplugin.Statem Password: password, } - cs, err := connstring.Parse(m.ConnectionURL) + connURL := m.getConnectionURL() + cs, err := connstring.Parse(connURL) if err != nil { return "", "", err } @@ -212,9 +212,33 @@ func (m *MongoDB) RevokeUser(ctx context.Context, statements dbplugin.Statements return m.runCommandWithRetry(ctx, db, dropUserCmd) } -// RotateRootCredentials is not currently supported on MongoDB +// RotateRootCredentials in MongoDB func (m *MongoDB) RotateRootCredentials(ctx context.Context, statements []string) (map[string]interface{}, error) { - return nil, errors.New("root credential rotation is not currently implemented in this database secrets engine") + // Grab the lock + m.Lock() + defer m.Unlock() + + if m.Username == "" { + return m.RawConfig, fmt.Errorf("username not specified for root credentials") + } + + password, err := m.GeneratePassword() + if err != nil { + return nil, err + } + + changeUserCmd := &updateUserCommand{ + Username: m.Username, + Password: password, + } + + if err := m.runCommandWithRetry(ctx, "admin", changeUserCmd); err != nil { + return nil, err + } + + m.RawConfig["password"] = password + m.Password = password + return m.RawConfig, nil } // runCommandWithRetry runs a command and retries once more if there's a failure diff --git a/plugins/database/mongodb/mongodb_test.go b/plugins/database/mongodb/mongodb_test.go index 1fe93cce39..503279a693 100644 --- a/plugins/database/mongodb/mongodb_test.go +++ b/plugins/database/mongodb/mongodb_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "net/url" "reflect" "strings" "testing" @@ -333,3 +334,130 @@ func appendToCertPool(t *testing.T, pool *x509.CertPool, caPem []byte) *x509.Cer } return pool } + +func TestMongoDB_RotateRootCredentials(t *testing.T) { + cleanup, connURL := mongodb.PrepareTestContainer(t, "latest") + defer cleanup() + + // Test to ensure that we can't rotate the root creds if no username has been specified + testCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + db := new() + connDetailsWithoutUsername := map[string]interface{}{ + "connection_url": connURL, + } + _, err := db.Init(testCtx, connDetailsWithoutUsername, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Rotate credentials should fail because no username is specified + cfg, err := db.RotateRootCredentials(testCtx, nil) + if err == nil { + t.Fatalf("successfully rotated root credentials when no username was present") + } + if !reflect.DeepEqual(cfg, connDetailsWithoutUsername) { + t.Fatalf("expected connection details: %#v but were %#v", connDetailsWithoutUsername, cfg) + } + + db.Close() + + // Reset the database object with new connection details + username := "vault-test-admin" + initialPassword := "myreallysecurepassword" + + db = new() + connDetailsWithUsername := map[string]interface{}{ + "connection_url": connURL, + "username": username, + "password": initialPassword, + } + _, err = db.Init(testCtx, connDetailsWithUsername, true) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Create root user + createUser(t, connURL, username, initialPassword) + initialURL := setUserPassOnURL(t, connURL, username, initialPassword) + + // Ensure the initial root user can connect + err = assertConnection(testCtx, initialURL) + if err != nil { + t.Fatalf("%s", err) + } + + // Rotate credentials + newCfg, err := db.RotateRootCredentials(testCtx, nil) + if err != nil { + t.Fatalf("unexpected err rotating root credentials: %s", err) + } + + // Ensure the initial root user can no longer connect + err = assertConnection(testCtx, initialURL) + if err == nil { + t.Fatalf("connection with initial credentials succeeded when it shouldn't have") + } + + // Ensure the new password can connect + newURL := setUserPassOnURL(t, connURL, username, newCfg["password"].(string)) + err = assertConnection(testCtx, newURL) + if err != nil { + t.Fatalf("unexpected error pinging client with new credentials: %s", err) + } +} + +func createUser(t *testing.T, connURL, username, password string) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client, err := createClient(ctx, connURL, nil) + if err != nil { + t.Fatalf("Unable to make initial connection: %s", err) + } + + createUserCmd := createUserCommand{ + Username: username, + Password: password, + Roles: []interface{}{ + "userAdminAnyDatabase", + "dbAdminAnyDatabase", + "readWriteAnyDatabase", + }, + } + + result := client.Database("admin").RunCommand(ctx, createUserCmd, nil) + err = result.Err() + if err != nil { + t.Fatalf("Unable to create admin user: %s", err) + } +} + +func assertConnection(testCtx context.Context, connURL string) error { + // Connect as initial root user and ensure the connection is successful + client, err := createClient(testCtx, connURL, nil) + if err != nil { + return fmt.Errorf("unable to create client connection with initial root user: %w", err) + } + + err = client.Ping(testCtx, nil) + if err != nil { + return fmt.Errorf("failed to ping server with initial root user: %w", err) + } + client.Disconnect(testCtx) + return nil +} + +func setUserPassOnURL(t *testing.T, connURL, username, password string) string { + t.Helper() + uri, err := url.Parse(connURL) + if err != nil { + t.Fatalf("unable to parse connection URL: %s", err) + } + + uri.User = url.UserPassword(username, password) + return uri.String() +}