From d6e5f7360eecd677cee70d5fd311c7cdf9399e1d Mon Sep 17 00:00:00 2001 From: Taddes Date: Thu, 16 Oct 2025 14:18:44 -0400 Subject: [PATCH] feat: postgres user methods (#1839) feat: postgres user methods --- tokenserver-db-common/src/params.rs | 6 +- tokenserver-db-postgres/src/models.rs | 392 +++++++++++++++++++------- 2 files changed, 291 insertions(+), 107 deletions(-) diff --git a/tokenserver-db-common/src/params.rs b/tokenserver-db-common/src/params.rs index d5f8a034..bcb314b1 100644 --- a/tokenserver-db-common/src/params.rs +++ b/tokenserver-db-common/src/params.rs @@ -23,8 +23,8 @@ pub struct PostService { } pub struct GetUsers { - pub service_id: i32, pub email: String, + pub service_id: i32, } #[derive(Clone, Default)] @@ -98,14 +98,14 @@ pub struct GetServiceId { #[cfg(debug_assertions)] pub struct SetUserCreatedAt { - pub uid: i64, pub created_at: i64, + pub uid: i64, } #[cfg(debug_assertions)] pub struct SetUserReplacedAt { - pub uid: i64, pub replaced_at: i64, + pub uid: i64, } #[cfg(debug_assertions)] diff --git a/tokenserver-db-postgres/src/models.rs b/tokenserver-db-postgres/src/models.rs index 9979e0d3..588883f3 100644 --- a/tokenserver-db-postgres/src/models.rs +++ b/tokenserver-db-postgres/src/models.rs @@ -1,9 +1,14 @@ +/// Note the addition of `#[cfg(debug_assertions)]` flags methods and +/// imports only to be added during debug builds. +/// cargo build --release will not include this code in the binary. use std::time::Duration; +#[cfg(debug_assertions)] +use std::time::{SystemTime, UNIX_EPOCH}; use super::pool::Conn; use async_trait::async_trait; use diesel::{ - sql_types::{BigInt, Float, Integer, Text}, + sql_types::{BigInt, Float, Integer, Nullable, Text}, OptionalExtension, }; use diesel_async::RunQueryDsl; @@ -48,13 +53,7 @@ impl TokenserverPgDb { // Services Table Methods - /** - Acquire service_id through passed in service string. - - SELECT id - FROM services - WHERE service = - */ + /// Acquire service_id through passed in service string. pub async fn get_service_id( &mut self, params: params::GetServiceId, @@ -76,13 +75,8 @@ impl TokenserverPgDb { } } - /** - Create a new service, given a provided service string and pattern. - Returns a service_id. - - INSERT INTO services (service, pattern) - VALUES (, ) - */ + // Create a new service, given a provided service string and pattern. + // Returns a service_id. #[cfg(debug_assertions)] pub async fn post_service( &mut self, @@ -109,21 +103,15 @@ impl TokenserverPgDb { // Nodes Table Methods - /** - Get Node with complete metadata, given a provided Node ID. - Returns a complete Node, including id, service_id, node string identifier - availability, and current load. - - SELECT * - FROM nodes - WHERE id = - */ + /// Get Node with complete metadata, given a provided Node ID. + /// Returns a complete Node, including id, service_id, node string identifier + /// availability, and current load. #[cfg(debug_assertions)] async fn get_node(&mut self, params: params::GetNode) -> DbResult { const QUERY: &str = r#" SELECT * - FROM nodes - WHERE id = $1 + FROM nodes + WHERE id = $1 "#; diesel::sql_query(QUERY) @@ -133,15 +121,8 @@ impl TokenserverPgDb { .map_err(Into::into) } - /** - Get the specific Node ID, given a provided service string and node. - Returns a node_id. - - SELECT id - FROM nodes - WHERE service = - AND node = - */ + /// Get the specific Node ID, given a provided service string and node. + /// Returns a node_id. async fn get_node_id(&mut self, params: params::GetNodeId) -> DbResult { const QUERY: &str = r#" SELECT id @@ -165,50 +146,38 @@ impl TokenserverPgDb { } } - /** - Get the best Node ID, which is the least loaded node with most available slots, - given a provided service string and node. - Returns a node_id and identifier string. - - SELECT id, node - FROM nodes - WHERE service = - AND available > 0 - AND capacity > current_load - AND downed = 0 - AND backoff = 0 - ORDER BY LOG(current_load) / LOG(capacity) - LIMIT 1 - */ + /// Get the best Node ID, which is the least loaded node with most available slots, + /// given a provided service string and node. + /// Returns a node_id and identifier string. async fn get_best_node( &mut self, params: params::GetBestNode, ) -> DbResult { const DEFAULT_CAPACITY_RELEASE_RATE: f32 = 0.1; const GET_BEST_NODE_QUERY: &str = r#" - SELECT id, node - FROM nodes - WHERE service = $1 - AND available > 0 - AND capacity > current_load - AND downed = 0 - AND backoff = 0 - ORDER BY LOG(current_load) / LOG(capacity) - LIMIT 1 + SELECT id, node + FROM nodes + WHERE service = $1 + AND available > 0 + AND capacity > current_load + AND downed = 0 + AND backoff = 0 + ORDER BY LOG(current_load) / LOG(capacity) + LIMIT 1 "#; const RELEASE_CAPACITY_QUERY: &str = r#" UPDATE nodes - SET available = LEAST(capacity * $1, capacity - current_load) - WHERE service = $2 - AND available <= 0 - AND capacity > current_load - AND downed = 0 + SET available = LEAST(capacity * $1, capacity - current_load) + WHERE service = $2 + AND available <= 0 + AND capacity > current_load + AND downed = 0 "#; const SPANNER_QUERY: &str = r#" SELECT id, node - FROM nodes - WHERE id = $1 - LIMIT 1 + FROM nodes + WHERE id = $1 + LIMIT 1 "#; let mut metrics = self.metrics.clone(); @@ -266,19 +235,13 @@ impl TokenserverPgDb { } } - /** - Create and Insert a new node. - Returns the last_insert_id of the newly created node. - - INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) - VALUES (, , , , - , , ) - */ + /// Create and Insert a new node. + /// Returns the last_insert_id of the newly created node. #[cfg(debug_assertions)] async fn post_node(&mut self, params: params::PostNode) -> DbResult { const QUERY: &str = r#" INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff) - VALUES (?, ?, ?, ?, ?, ?, ?) + VALUES ($1, $2, $3, $4, $5, $6, $7) "#; diesel::sql_query(QUERY) .bind::(params.service_id) @@ -297,17 +260,9 @@ impl TokenserverPgDb { .map_err(Into::into) } - /** - Update the current load count of a node, passing in the service string and node string. - This represents the addition of a user to a node, while not defining which user specifically. - Does not return anything. - - UPDATE nodes - SET current_load = current_load + 1, - available = GREATEST(available - 1, 0) - WHERE service = - AND node = - */ + /// Update the current load count of a node, passing in the service string and node string. + /// This represents the addition of a user to a node, while not defining which user specifically. + /// Does not return anything. async fn add_user_to_node( &mut self, params: params::AddUserToNode, @@ -317,16 +272,16 @@ impl TokenserverPgDb { const QUERY: &str = r#" UPDATE nodes - SET current_load = current_load + 1, - available = GREATEST(available - 1, 0) - WHERE service = $1 - AND node = $2 + SET current_load = current_load + 1, + available = GREATEST(available - 1, 0) + WHERE service = $1 + AND node = $2 "#; const SPANNER_QUERY: &str = r#" UPDATE nodes - SET current_load = current_load + 1 - WHERE service = $1 - AND node = $2 + SET current_load = current_load + 1 + WHERE service = $1 + AND node = $2 "#; // Use the spanner query if the instance has spanner_node_id set. @@ -346,12 +301,7 @@ impl TokenserverPgDb { .map_err(Into::into) } - /** - Remove a node given the node ID. - Does not return anything. - - DELETE FROM nodes WHERE id = - */ + /// Remove a node given the node ID. #[cfg(debug_assertions)] async fn remove_node(&mut self, params: params::RemoveNode) -> DbResult { const QUERY: &str = "DELETE FROM nodes WHERE id = $1"; @@ -363,6 +313,240 @@ impl TokenserverPgDb { .map(|_| ()) .map_err(Into::into) } + + // Users Table Methods + + /// Given a user id, return a single user (GetUser) struct. + /// Contains all data relevant to particular user. + #[cfg(debug_assertions)] + async fn get_user(&mut self, params: params::GetUser) -> DbResult { + const QUERY: &str = r#" + SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at + FROM users + WHERE uid = $1 + "#; + + diesel::sql_query(QUERY) + .bind::(params.id) + .get_result::(&mut self.conn) + .await + .map_err(Into::into) + } + + /// Given a service_id and email, return all matching users (up to 20). + /// Returns vector of matching `GetUser` structs, a type alias for `GetRawUsers` + async fn get_users(&mut self, params: params::GetUsers) -> DbResult { + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.get_users", None); + + const QUERY: &str = r#" + SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at, + replaced_at + FROM users + LEFT OUTER JOIN nodes ON users.nodeid = nodes.id + WHERE email = $1 + AND users.service = $2 + ORDER BY created_at DESC, uid DESC + LIMIT 20 + "#; + + diesel::sql_query(QUERY) + .bind::(params.email) + .bind::(params.service_id) + .load::(&mut self.conn) + .await + .map_err(Into::into) + } + + /// Method to create a new user, given a `PostUser` struct containing data regarding the user. + #[cfg(debug_assertions)] + async fn post_user(&mut self, params: params::PostUser) -> DbResult { + const QUERY: &str = r#" + INSERT INTO users (service, email, generation, client_state, created_at, + nodeid, keys_changed_at, replaced_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, NULL) + RETURNING uid; + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.post_user", None); + + diesel::sql_query(QUERY) + .bind::(params.service_id) + .bind::(params.email) + .bind::(params.generation) + .bind::(params.client_state) + .bind::(params.created_at) + .bind::(params.node_id) + .bind::, _>(params.keys_changed_at) + .get_result::(&mut self.conn) + .await + .map_err(Into::into) + } + + /// Update the user with the given email and service ID with the given `generation` and + /// `keys_changed_at`. Additionally, the other parameters ensure greater certainty to prevent + /// timestamp fields from regressing. More information below. + async fn put_user(&mut self, params: params::PutUser) -> DbResult { + // As an added layer of safety, the `WHERE` clause ensures that concurrent updates + // don't accidentally move timestamp fields backwards in time. The handling of + // `keys_changed_at`can be problematic as we want to treat the default `NULL` as zero (0). + const QUERY: &str = r#" + UPDATE users + SET generation = $1, + keys_changed_at = $2 + WHERE service = $3 + AND email = $4 + AND generation <= $5 + AND COALESCE(keys_changed_at, 0) <= COALESCE(, keys_changed_at, 0) + AND replaced_at IS NULL + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.put_user", None); + + diesel::sql_query(QUERY) + .bind::(params.generation) + .bind::, _>(params.keys_changed_at) + .bind::(params.service_id) + .bind::(params.email) + .bind::(params.generation) + .bind::, _>(params.keys_changed_at) + .execute(&mut self.conn) + .await + .map(|_| ()) + .map_err(Into::into) + } + + /// Update the user record with the given uid and service id + /// marking it as 'replaced'. This is through updating the `replaced_at` field. + async fn replace_user( + &mut self, + params: params::ReplaceUser, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = $1 + WHERE service = $2 + AND uid = $3 + "#; + + diesel::sql_query(QUERY) + .bind::(params.replaced_at) + .bind::(params.service_id) + .bind::(params.uid) + .execute(&mut self.conn) + .await + .map(|_| ()) + .map_err(Into::into) + } + + /// Update several user records with the given email and service id + /// marking them as 'replaced'. This is through updating the `replaced_at` field. + /// The `replaced_at` field should be null AND the `created_at` field should be earlier + /// than the `replaced_at`. + async fn replace_users( + &mut self, + params: params::ReplaceUsers, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = $1 + WHERE service = $2 + AND email = $3 + AND replaced_at IS NULL + AND created_at < $4 + "#; + + let mut metrics = self.metrics.clone(); + metrics.start_timer("storage.replace_users", None); + diesel::sql_query(QUERY) + .bind::(params.replaced_at) + .bind::(params.service_id) + .bind::(params.email) + .bind::(params.replaced_at) + .execute(&mut self.conn) + .await + .map(|_| ()) + .map_err(Into::into) + } + + /// Given ONLY a particular `node_id`, update the users table to indicate an unassigned + /// node by updating the `replaced_at` field with the current time since Unix Epoch. + #[cfg(debug_assertions)] + async fn unassign_node( + &mut self, + params: params::UnassignNode, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = $1 + WHERE nodeid = $2 + "#; + + let current_time: i64 = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis() as i64; + + diesel::sql_query(QUERY) + .bind::(current_time) + .bind::(params.node_id) + .execute(&mut self.conn) + .await + .map(|_| ()) + .map_err(Into::into) + } + + /// Given ONLY a particular `uid`, update the users table `created_at` value + /// with the passed parameter. + #[cfg(debug_assertions)] + async fn set_user_created_at( + &mut self, + params: params::SetUserCreatedAt, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET created_at = $1 + WHERE uid = $2 + "#; + + diesel::sql_query(QUERY) + .bind::(params.created_at) + .bind::(params.uid) + .execute(&mut self.conn) + .await + .map(|_| ()) + .map_err(Into::into) + } + + /// Given ONLY a particular `uid`, update the users table `replaced_at` value + /// with the passed parameter. + #[cfg(debug_assertions)] + async fn set_user_replaced_at( + &mut self, + params: params::SetUserReplacedAt, + ) -> DbResult { + const QUERY: &str = r#" + UPDATE users + SET replaced_at = $1 + WHERE uid = $2 + "#; + + diesel::sql_query(QUERY) + .bind::(params.replaced_at) + .bind::(params.uid) + .execute(&mut self.conn) + .await + .map(|_| ()) + .map_err(Into::into) + } + + #[allow(dead_code)] + #[cfg(debug_assertions)] + fn set_spanner_node_id(&mut self, params: params::SpannerNodeId) { + self.spanner_node_id = params; + } } #[async_trait(?Send)] @@ -441,6 +625,10 @@ impl Db for TokenserverPgDb { TokenserverPgDb::get_user(self, params).await } + async fn get_users(&mut self, params: params::GetUsers) -> Result { + TokenserverPgDb::get_users(self, params).await + } + async fn get_or_create_user( &mut self, params: params::GetOrCreateUser, @@ -448,10 +636,6 @@ impl Db for TokenserverPgDb { TokenserverPgDb::get_or_create_user(self, params).await } - async fn get_users(&mut self, params: params::GetUsers) -> Result { - TokenserverPgDb::get_users(self, params).await - } - async fn post_user(&mut self, params: params::PostUser) -> Result { TokenserverPgDb::post_user(self, params).await }