diff --git a/Makefile b/Makefile index 33214952..8737881e 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ docker_stop_spanner: docker-compose -f docker-compose.spanner.yaml down run: - RUST_LOG=debug RUST_BACKTRACE=full cargo run -- --config config/local.toml --features tokenserver_test_mode + RUST_LOG=debug RUST_BACKTRACE=full cargo run --features tokenserver_test_mode -- --config config/local.toml run_spanner: GOOGLE_APPLICATION_CREDENTIALS=$(PATH_TO_SYNC_SPANNER_KEYS) GRPC_DEFAULT_SSL_ROOTS_FILE_PATH=$(PATH_TO_GRPC_CERT) make run diff --git a/src/tokenserver/db/mock.rs b/src/tokenserver/db/mock.rs index 6eb390ae..64a8915b 100644 --- a/src/tokenserver/db/mock.rs +++ b/src/tokenserver/db/mock.rs @@ -51,10 +51,6 @@ impl Db for MockDb { Box::pin(future::ok(results::PostUser::default())) } - fn allocate_user(&self, _params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser> { - Box::pin(future::ok(results::AllocateUser::default())) - } - fn put_user(&self, _params: params::PutUser) -> DbFuture<'_, results::PutUser> { Box::pin(future::ok(())) } @@ -78,6 +74,10 @@ impl Db for MockDb { Box::pin(future::ok(())) } + fn get_users(&self, _params: params::GetUsers) -> DbFuture<'_, results::GetUsers> { + Box::pin(future::ok(results::GetUsers::default())) + } + fn get_or_create_user( &self, _params: params::GetOrCreateUser, @@ -106,11 +106,6 @@ impl Db for MockDb { Box::pin(future::ok(results::GetUser::default())) } - #[cfg(test)] - fn get_users(&self, _params: params::GetRawUsers) -> DbFuture<'_, results::GetRawUsers> { - Box::pin(future::ok(results::GetRawUsers::default())) - } - #[cfg(test)] fn post_node(&self, _params: params::PostNode) -> DbFuture<'_, results::PostNode> { Box::pin(future::ok(results::PostNode::default())) diff --git a/src/tokenserver/db/models.rs b/src/tokenserver/db/models.rs index a5cb4846..4caf6e2d 100644 --- a/src/tokenserver/db/models.rs +++ b/src/tokenserver/db/models.rs @@ -165,6 +165,7 @@ impl TokenserverDb { INSERT INTO users (service, email, generation, client_state, created_at, nodeid, keys_changed_at, replaced_at) VALUES (?, ?, ?, ?, ?, ?, ?, NULL); "#; + diesel::sql_query(QUERY) .bind::(user.service_id) .bind::(&user.email) @@ -264,6 +265,171 @@ impl TokenserverDb { .map_err(Into::into) } + fn get_users_sync(&self, params: params::GetUsers) -> DbResult { + const QUERY: &str = r#" + SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at, + replaced_at + FROM users + LEFT OUTER JOIN nodes ON users.nodeid = nodes.id + WHERE email = ? + AND users.service = ? + ORDER BY created_at DESC, uid DESC + LIMIT 20 + "#; + + diesel::sql_query(QUERY) + .bind::(¶ms.email) + .bind::(params.service_id) + .load::(&self.inner.conn) + .map_err(Into::into) + } + + /// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new + /// user. + fn get_or_create_user_sync( + &self, + params: params::GetOrCreateUser, + ) -> DbResult { + let mut raw_users = self.get_users_sync(params::GetUsers { + service_id: params.service_id, + email: params.email.clone(), + })?; + + if raw_users.is_empty() { + // There are no users in the database with the given email and service ID, so + // allocate a new one. + let allocate_user_result = + self.allocate_user_sync(params.clone() as params::AllocateUser)?; + + Ok(results::GetOrCreateUser { + uid: allocate_user_result.uid, + email: params.email, + client_state: params.client_state, + generation: params.generation, + node: allocate_user_result.node, + keys_changed_at: params.keys_changed_at, + created_at: allocate_user_result.created_at, + replaced_at: None, + old_client_states: vec![], + }) + } else { + raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at)); + raw_users.reverse(); + + // The user with the greatest `generation` and `created_at` is the current user + let raw_user = raw_users[0].clone(); + + // Collect any old client states that differ from the current client state + let old_client_states = { + raw_users[1..] + .iter() + .map(|user| user.client_state.clone()) + .filter(|client_state| client_state != &raw_user.client_state) + .collect() + }; + + // Make sure every old row is marked as replaced. They might not be, due to races in row + // creation. + for old_user in &raw_users[1..] { + if old_user.replaced_at.is_none() { + let params = params::ReplaceUser { + uid: old_user.uid, + service_id: params.service_id, + replaced_at: raw_user.created_at, + }; + + self.replace_user_sync(params)?; + } + } + + match (raw_user.replaced_at, raw_user.node) { + // If the most up-to-date user is marked as replaced or does not have a node + // assignment, allocate a new user. Note that, if the current user is marked + // as replaced, we do not want to create a new user with the account metadata + // in the parameters to this method. Rather, we want to create a duplicate of + // the replaced user assigned to a new node. This distinction is important + // because the account metadata in the parameters to this method may not match + // that currently stored on the most up-to-date user and may be invalid. + (Some(_), _) | (_, None) if raw_user.generation < MAX_GENERATION => { + let allocate_user_result = { + self.allocate_user_sync(params::AllocateUser { + service_id: params.service_id, + email: params.email.clone(), + generation: raw_user.generation, + client_state: raw_user.client_state.clone(), + keys_changed_at: raw_user.keys_changed_at, + capacity_release_rate: params.capacity_release_rate, + })? + }; + + Ok(results::GetOrCreateUser { + uid: allocate_user_result.uid, + email: params.email, + client_state: raw_user.client_state, + generation: raw_user.generation, + node: allocate_user_result.node, + keys_changed_at: raw_user.keys_changed_at, + created_at: allocate_user_result.created_at, + replaced_at: None, + old_client_states, + }) + } + // The most up-to-date user has a node. Note that this user may be retired or + // replaced. + (_, Some(node)) => Ok(results::GetOrCreateUser { + uid: raw_user.uid, + email: params.email, + client_state: raw_user.client_state, + generation: raw_user.generation, + node, + keys_changed_at: raw_user.keys_changed_at, + created_at: raw_user.created_at, + replaced_at: None, + old_client_states, + }), + // The most up-to-date user doesn't have a node and is retired. + (_, None) => Err(DbError::from(DbErrorKind::TokenserverUserRetired)), + } + } + } + + /// Creates a new user and assigns them to a node. + fn allocate_user_sync(&self, params: params::AllocateUser) -> DbResult { + // Get the least-loaded node + let node = self.get_best_node_sync(params::GetBestNode { + service_id: params.service_id, + capacity_release_rate: params.capacity_release_rate, + })?; + + // Decrement `available` and increment `current_load` on the node assigned to the user. + self.add_user_to_node_sync(params::AddUserToNode { + service_id: params.service_id, + node: node.node.clone(), + })?; + + let created_at = { + let start = SystemTime::now(); + start.duration_since(UNIX_EPOCH).unwrap().as_millis() as i64 + }; + let uid = self + .post_user_sync(params::PostUser { + service_id: params.service_id, + email: params.email.clone(), + generation: params.generation, + client_state: params.client_state.clone(), + created_at, + node_id: node.id, + keys_changed_at: params.keys_changed_at, + })? + .id; + + Ok(results::AllocateUser { + uid, + node: node.node, + created_at, + }) + } + #[cfg(test)] fn set_user_created_at_sync( &self, @@ -314,22 +480,6 @@ impl TokenserverDb { .map_err(Into::into) } - #[cfg(test)] - fn get_users_sync(&self, email: String) -> DbResult { - const QUERY: &str = r#" - SELECT users.uid, users.email, users.client_state, users.generation, - users.keys_changed_at, users.created_at, users.replaced_at, nodes.node - FROM users - JOIN nodes - ON nodes.id = users.nodeid - WHERE users.email = ? - "#; - diesel::sql_query(QUERY) - .bind::(email) - .load::(&self.inner.conn) - .map_err(Into::into) - } - #[cfg(test)] fn post_node_sync(&self, params: params::PostNode) -> DbResult { const QUERY: &str = r#" @@ -419,175 +569,12 @@ impl Db for TokenserverDb { sync_db_method!(replace_users, replace_users_sync, ReplaceUsers); sync_db_method!(post_user, post_user_sync, PostUser); - /// Creates a new user and assigns them to a node. - fn allocate_user(&self, params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser> { - Box::pin(async move { - // Get the least-loaded node - let node = self - .get_best_node(params::GetBestNode { - service_id: params.service_id, - capacity_release_rate: params.capacity_release_rate, - }) - .await?; - - // Decrement `available` and increment `current_load` on the node assigned to the user. - self.add_user_to_node(params::AddUserToNode { - service_id: params.service_id, - node: node.node.clone(), - }) - .await?; - - let created_at = { - let start = SystemTime::now(); - start.duration_since(UNIX_EPOCH).unwrap().as_millis() as i64 - }; - let uid = self - .post_user(params::PostUser { - service_id: params.service_id, - email: params.email.clone(), - generation: params.generation, - client_state: params.client_state.clone(), - created_at, - node_id: node.id, - keys_changed_at: params.keys_changed_at, - }) - .await? - .id; - - Ok(results::AllocateUser { - uid, - node: node.node, - created_at, - }) - }) - } - sync_db_method!(put_user, put_user_sync, PutUser); sync_db_method!(get_node_id, get_node_id_sync, GetNodeId); sync_db_method!(get_best_node, get_best_node_sync, GetBestNode); sync_db_method!(add_user_to_node, add_user_to_node_sync, AddUserToNode); - - /// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new - /// user. - fn get_or_create_user( - &self, - params: params::GetOrCreateUser, - ) -> DbFuture<'_, results::GetOrCreateUser> { - const QUERY: &str = r#" - SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at, - replaced_at - FROM users - LEFT OUTER JOIN nodes ON users.nodeid = nodes.id - WHERE email = ? - AND users.service = ? - ORDER BY created_at DESC, uid DESC - LIMIT 20 - "#; - - Box::pin(async move { - let mut raw_users = diesel::sql_query(QUERY) - .bind::(¶ms.email) - .bind::(params.service_id) - .load::(&self.inner.conn) - .map_err(|e| ApiError::from(DbError::from(e)))?; - - if raw_users.is_empty() { - // There are no users in the database with the given email and service ID, so - // allocate a new one. - let allocate_user_result = self - .allocate_user(params.clone() as params::AllocateUser) - .await?; - - Ok(results::GetOrCreateUser { - uid: allocate_user_result.uid, - email: params.email.clone(), - client_state: params.client_state, - generation: params.generation, - node: allocate_user_result.node, - keys_changed_at: params.keys_changed_at, - created_at: allocate_user_result.created_at, - replaced_at: None, - old_client_states: vec![], - }) - } else { - raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at)); - raw_users.reverse(); - - // The user with the greatest `generation` and `created_at` is the current user - let raw_user = raw_users[0].clone(); - - // Collect any old client states that differ from the current client state - let old_client_states = raw_users[1..] - .iter() - .map(|user| user.client_state.clone()) - .filter(|client_state| client_state != &raw_user.client_state) - .collect(); - - // Make sure every old row is marked as replaced. They might not be, due to races in row - // creation. - for old_user in &raw_users[1..] { - if old_user.replaced_at.is_none() { - let params = params::ReplaceUser { - uid: old_user.uid, - service_id: params.service_id, - replaced_at: raw_user.created_at, - }; - - self.replace_user(params).await?; - } - } - - match (raw_user.replaced_at, raw_user.node) { - // If the most up-to-date user is marked as replaced or does not have a node - // assignment, allocate a new user. Note that, if the current user is marked - // as replaced, we do not want to create a new user with the account metadata - // in the parameters to this method. Rather, we want to create a duplicate of - // the replaced user assigned to a new node. This distinction is important - // because the account metadata in the parameters to this method may not match - // that currently stored on the most up-to-date user and may be invalid. - (Some(_), _) | (_, None) if raw_user.generation < MAX_GENERATION => { - let allocate_user_result = self - .allocate_user(params::AllocateUser { - service_id: params.service_id, - email: params.email.clone(), - generation: raw_user.generation, - client_state: raw_user.client_state.clone(), - keys_changed_at: raw_user.keys_changed_at, - capacity_release_rate: params.capacity_release_rate, - }) - .await?; - - Ok(results::GetOrCreateUser { - uid: allocate_user_result.uid, - email: params.email.clone(), - client_state: raw_user.client_state, - generation: raw_user.generation, - node: allocate_user_result.node, - keys_changed_at: raw_user.keys_changed_at, - created_at: allocate_user_result.created_at, - replaced_at: None, - old_client_states, - }) - } - // The most up-to-date user has a node. Note that this user may be retired or - // replaced. - (_, Some(node)) => Ok(results::GetOrCreateUser { - uid: raw_user.uid, - email: params.email.clone(), - client_state: raw_user.client_state, - generation: raw_user.generation, - node, - keys_changed_at: raw_user.keys_changed_at, - created_at: raw_user.created_at, - replaced_at: None, - old_client_states, - }), - // The most up-to-date user doesn't have a node and is retired. - (_, None) => Err(DbError::from(DbErrorKind::TokenserverUserRetired).into()), - } - } - }) - } + sync_db_method!(get_users, get_users_sync, GetUsers); + sync_db_method!(get_or_create_user, get_or_create_user_sync, GetOrCreateUser); #[cfg(test)] sync_db_method!(get_user, get_user_sync, GetUser); @@ -611,9 +598,6 @@ impl Db for TokenserverDb { SetUserReplacedAt ); - #[cfg(test)] - sync_db_method!(get_users, get_users_sync, GetRawUsers); - #[cfg(test)] sync_db_method!(post_node, post_node_sync, PostNode); @@ -637,8 +621,6 @@ pub trait Db { fn post_user(&self, params: params::PostUser) -> DbFuture<'_, results::PostUser>; - fn allocate_user(&self, params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser>; - fn put_user(&self, params: params::PutUser) -> DbFuture<'_, results::PutUser>; fn check(&self) -> DbFuture<'_, results::Check>; @@ -652,6 +634,8 @@ pub trait Db { params: params::AddUserToNode, ) -> DbFuture<'_, results::AddUserToNode>; + fn get_users(&self, params: params::GetUsers) -> DbFuture<'_, results::GetUsers>; + fn get_or_create_user( &self, params: params::GetOrCreateUser, @@ -672,9 +656,6 @@ pub trait Db { #[cfg(test)] fn get_user(&self, params: params::GetUser) -> DbFuture<'_, results::GetUser>; - #[cfg(test)] - fn get_users(&self, params: params::GetRawUsers) -> DbFuture<'_, results::GetRawUsers>; - #[cfg(test)] fn post_node(&self, params: params::PostNode) -> DbFuture<'_, results::PostNode>; @@ -976,8 +957,18 @@ mod tests { // Get all of the users let users = { - let mut users1 = db.get_users(email1.to_owned()).await?; - let mut users2 = db.get_users(email2.to_owned()).await?; + let mut users1 = db + .get_users(params::GetUsers { + email: email1.to_owned(), + service_id: db::SYNC_1_5_SERVICE_ID, + }) + .await?; + let mut users2 = db + .get_users(params::GetUsers { + email: email2.to_owned(), + service_id: db::SYNC_1_5_SERVICE_ID, + }) + .await?; users1.append(&mut users2); users1 @@ -1097,7 +1088,7 @@ mod tests { #[tokio::test] async fn test_node_allocation() -> Result<()> { let pool = db_pool().await?; - let db = pool.get().await?; + let db = pool.get_tokenserver_db().await?; // Add a node let node_id = db @@ -1113,16 +1104,14 @@ mod tests { .id; // Allocating a user assigns it to the node - let user = db - .allocate_user(params::AllocateUser { - service_id: db::SYNC_1_5_SERVICE_ID, - generation: 1234, - email: "test@test.com".to_owned(), - client_state: "616161".to_owned(), - keys_changed_at: Some(1234), - capacity_release_rate: None, - }) - .await?; + let user = db.allocate_user_sync(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; assert_eq!(user.node, "https://node1"); // Getting the user from the database does not affect node assignment @@ -1135,7 +1124,7 @@ mod tests { #[tokio::test] async fn test_allocation_to_least_loaded_node() -> Result<()> { let pool = db_pool().await?; - let db = pool.get().await?; + let db = pool.get_tokenserver_db().await?; // Add two nodes db.post_node(params::PostNode { @@ -1159,27 +1148,23 @@ mod tests { .await?; // Allocate two users - let user1 = db - .allocate_user(params::AllocateUser { - service_id: db::SYNC_1_5_SERVICE_ID, - generation: 1234, - email: "test1@test.com".to_owned(), - client_state: "616161".to_owned(), - keys_changed_at: Some(1234), - capacity_release_rate: None, - }) - .await?; + let user1 = db.allocate_user_sync(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test1@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; - let user2 = db - .allocate_user(params::AllocateUser { - service_id: db::SYNC_1_5_SERVICE_ID, - generation: 1234, - email: "test2@test.com".to_owned(), - client_state: "616161".to_owned(), - keys_changed_at: Some(1234), - capacity_release_rate: None, - }) - .await?; + let user2 = db.allocate_user_sync(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test2@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; // Because users are always assigned to the least-loaded node, the users should have been // assigned to different nodes @@ -1191,7 +1176,7 @@ mod tests { #[tokio::test] async fn test_allocation_is_not_allowed_to_downed_nodes() -> Result<()> { let pool = db_pool().await?; - let db = pool.get().await?; + let db = pool.get_tokenserver_db().await?; // Add a downed node db.post_node(params::PostNode { @@ -1206,16 +1191,14 @@ mod tests { .await?; // User allocation fails because allocation is not allowed to downed nodes - let result = db - .allocate_user(params::AllocateUser { - service_id: db::SYNC_1_5_SERVICE_ID, - generation: 1234, - email: "test@test.com".to_owned(), - client_state: "616161".to_owned(), - keys_changed_at: Some(1234), - capacity_release_rate: None, - }) - .await; + let result = db.allocate_user_sync(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }); let error = result.unwrap_err(); assert_eq!(error.to_string(), "Unexpected error: unable to get a node"); @@ -1225,7 +1208,7 @@ mod tests { #[tokio::test] async fn test_allocation_is_not_allowed_to_backoff_nodes() -> Result<()> { let pool = db_pool().await?; - let db = pool.get().await?; + let db = pool.get_tokenserver_db().await?; // Add a backoff node db.post_node(params::PostNode { @@ -1240,16 +1223,14 @@ mod tests { .await?; // User allocation fails because allocation is not allowed to backoff nodes - let result = db - .allocate_user(params::AllocateUser { - service_id: db::SYNC_1_5_SERVICE_ID, - generation: 1234, - email: "test@test.com".to_owned(), - client_state: "616161".to_owned(), - keys_changed_at: Some(1234), - capacity_release_rate: None, - }) - .await; + let result = db.allocate_user_sync(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + }); let error = result.unwrap_err(); assert_eq!(error.to_string(), "Unexpected error: unable to get a node"); @@ -1259,7 +1240,7 @@ mod tests { #[tokio::test] async fn test_node_reassignment_when_records_are_replaced() -> Result<()> { let pool = db_pool().await?; - let db = pool.get().await?; + let db = pool.get_tokenserver_db().await?; // Add a node db.post_node(params::PostNode { @@ -1273,16 +1254,14 @@ mod tests { .await?; // Allocate a user - let allocate_user_result = db - .allocate_user(params::AllocateUser { - service_id: db::SYNC_1_5_SERVICE_ID, - generation: 1234, - email: "test@test.com".to_owned(), - client_state: "616161".to_owned(), - keys_changed_at: Some(1234), - capacity_release_rate: None, - }) - .await?; + let allocate_user_result = db.allocate_user_sync(params::AllocateUser { + service_id: db::SYNC_1_5_SERVICE_ID, + generation: 1234, + email: "test@test.com".to_owned(), + client_state: "616161".to_owned(), + keys_changed_at: Some(1234), + capacity_release_rate: None, + })?; let user1 = db .get_user(params::GetUser { id: allocate_user_result.uid, diff --git a/src/tokenserver/db/params.rs b/src/tokenserver/db/params.rs index 25d2c410..0af100e8 100644 --- a/src/tokenserver/db/params.rs +++ b/src/tokenserver/db/params.rs @@ -22,6 +22,11 @@ pub struct PostService { pub pattern: String, } +pub struct GetUsers { + pub service_id: i32, + pub email: String, +} + #[derive(Clone, Default)] pub struct GetOrCreateUser { pub service_id: i32, @@ -87,9 +92,6 @@ pub struct AddUserToNode { pub node: String, } -#[cfg(test)] -pub type GetRawUsers = String; - #[cfg(test)] pub struct SetUserCreatedAt { pub uid: i64, diff --git a/src/tokenserver/db/pool.rs b/src/tokenserver/db/pool.rs index 0aa9978e..75e4bc28 100644 --- a/src/tokenserver/db/pool.rs +++ b/src/tokenserver/db/pool.rs @@ -62,6 +62,14 @@ impl TokenserverPool { inner: builder.build(manager)?, }) } + + #[cfg(test)] + pub async fn get_tokenserver_db(&self) -> Result, DbError> { + let pool = self.clone(); + let conn = block(move || pool.inner.get().map_err(DbError::from)).await?; + + Ok(Box::new(TokenserverDb::new(conn))) + } } impl From> for DbError { diff --git a/src/tokenserver/db/results.rs b/src/tokenserver/db/results.rs index 9d3efc0f..85f8a848 100644 --- a/src/tokenserver/db/results.rs +++ b/src/tokenserver/db/results.rs @@ -26,6 +26,8 @@ pub struct GetRawUser { pub replaced_at: Option, } +pub type GetUsers = Vec; + #[derive(Debug, Default, PartialEq)] pub struct AllocateUser { pub uid: i64, @@ -75,9 +77,6 @@ pub struct GetBestNode { pub type AddUserToNode = (); -#[cfg(test)] -pub type GetRawUsers = Vec; - #[cfg(test)] #[derive(Debug, Default, PartialEq, QueryableByName)] pub struct GetUser { @@ -141,9 +140,6 @@ pub type SetUserCreatedAt = (); #[cfg(test)] pub type SetUserReplacedAt = (); -#[cfg(test)] -pub type GetUsers = Vec; - pub type Check = bool; #[cfg(test)] diff --git a/src/tokenserver/error.rs b/src/tokenserver/error.rs index d7308bb1..7ad0721e 100644 --- a/src/tokenserver/error.rs +++ b/src/tokenserver/error.rs @@ -1,12 +1,16 @@ use std::fmt; -use actix_web::{error::ResponseError, http::StatusCode, HttpResponse}; +use actix_web::{ + error::{BlockingError, ResponseError}, + http::StatusCode, + HttpResponse, +}; use serde::{ ser::{SerializeMap, Serializer}, Serialize, }; -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct TokenserverError { pub status: &'static str, pub location: ErrorLocation, @@ -99,6 +103,18 @@ impl TokenserverError { } } +impl From> for TokenserverError { + fn from(inner: BlockingError) -> Self { + match inner { + BlockingError::Error(e) => e, + BlockingError::Canceled => { + error!("Tokenserver threadpool operation canceled"); + TokenserverError::internal_error() + } + } + } +} + #[derive(Clone, Copy, Debug, PartialEq)] pub enum ErrorLocation { Header, diff --git a/src/tokenserver/extractors.rs b/src/tokenserver/extractors.rs index bfc3938e..83bffe6e 100644 --- a/src/tokenserver/extractors.rs +++ b/src/tokenserver/extractors.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use actix_web::{ dev::Payload, http::StatusCode, - web::{Data, Query}, + web::{self, Data, Query}, Error, FromRequest, HttpRequest, }; use actix_web_httpauth::extractors::bearer::BearerAuth; @@ -19,7 +19,7 @@ use regex::Regex; use serde::Deserialize; use sha2::Sha256; -use super::db::{self, models::Db, params, results}; +use super::db::{self, models::Db, params, pool::DbPool, results}; use super::error::{ErrorLocation, TokenserverError}; use super::support::TokenData; use super::NodeType; @@ -195,11 +195,7 @@ impl FromRequest for TokenserverRequest { }; let email = format!("{}@{}", fxa_uid, state.fxa_email_domain); let user = { - let db = state.db_pool.get().await.map_err(|_| { - error!("⚠️ Could not acquire database connection"); - - TokenserverError::internal_error() - })?; + let db = >::extract(&req).await?; db.get_or_create_user(params::GetOrCreateUser { service_id, @@ -258,6 +254,28 @@ impl FromRequest for Box { type Error = Error; type Future = LocalBoxFuture<'static, Result>; + fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { + let req = req.clone(); + + Box::pin(async move { + >::extract(&req) + .await? + .get() + .await + .map_err(|_| { + error!("⚠️ Could not acquire database connection"); + + TokenserverError::internal_error().into() + }) + }) + } +} + +impl FromRequest for Box { + type Config = (); + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future { let req = req.clone(); @@ -265,13 +283,8 @@ impl FromRequest for Box { // XXX: Tokenserver state will no longer be an Option once the Tokenserver // code is rolled out, so we will eventually be able to remove this unwrap(). let state = get_server_state(&req)?.as_ref().as_ref().unwrap(); - let db = state.db_pool.get().await.map_err(|_| { - error!("⚠️ Could not acquire database connection"); - TokenserverError::internal_error() - })?; - - Ok(db) + Ok(state.db_pool.clone()) }) } } @@ -303,20 +316,20 @@ impl FromRequest for TokenData { let auth = BearerAuth::extract(&req) .await .map_err(|_| TokenserverError::invalid_credentials("Unsupported"))?; - // XXX: The Tokenserver state will no longer be an Option once the Tokenserver // code is rolled out, so we will eventually be able to remove this unwrap(). let state = get_server_state(&req)?.as_ref().as_ref().unwrap(); + let oauth_verifier = state.oauth_verifier.clone(); - state - .oauth_verifier - .verify_token(auth.token()) + web::block(move || oauth_verifier.verify_token(auth.token())) + .await + .map_err(TokenserverError::from) .map_err(Into::into) }) } } -#[derive(Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] struct KeyId { client_state: String, keys_changed_at: i64,