fix: move I/O calls to blocking threadpool (#1190)

Closes #1188
This commit is contained in:
Ethan Donowitz 2021-12-21 12:34:04 -05:00 committed by GitHub
parent 46d4a9ea43
commit cbeebf465a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 304 additions and 295 deletions

View File

@ -36,7 +36,7 @@ docker_stop_spanner:
docker-compose -f docker-compose.spanner.yaml down
run:
RUST_LOG=debug RUST_BACKTRACE=full cargo run -- --config config/local.toml --features tokenserver_test_mode
RUST_LOG=debug RUST_BACKTRACE=full cargo run --features tokenserver_test_mode -- --config config/local.toml
run_spanner:
GOOGLE_APPLICATION_CREDENTIALS=$(PATH_TO_SYNC_SPANNER_KEYS) GRPC_DEFAULT_SSL_ROOTS_FILE_PATH=$(PATH_TO_GRPC_CERT) make run

View File

@ -51,10 +51,6 @@ impl Db for MockDb {
Box::pin(future::ok(results::PostUser::default()))
}
fn allocate_user(&self, _params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser> {
Box::pin(future::ok(results::AllocateUser::default()))
}
fn put_user(&self, _params: params::PutUser) -> DbFuture<'_, results::PutUser> {
Box::pin(future::ok(()))
}
@ -78,6 +74,10 @@ impl Db for MockDb {
Box::pin(future::ok(()))
}
fn get_users(&self, _params: params::GetUsers) -> DbFuture<'_, results::GetUsers> {
Box::pin(future::ok(results::GetUsers::default()))
}
fn get_or_create_user(
&self,
_params: params::GetOrCreateUser,
@ -106,11 +106,6 @@ impl Db for MockDb {
Box::pin(future::ok(results::GetUser::default()))
}
#[cfg(test)]
fn get_users(&self, _params: params::GetRawUsers) -> DbFuture<'_, results::GetRawUsers> {
Box::pin(future::ok(results::GetRawUsers::default()))
}
#[cfg(test)]
fn post_node(&self, _params: params::PostNode) -> DbFuture<'_, results::PostNode> {
Box::pin(future::ok(results::PostNode::default()))

View File

@ -165,6 +165,7 @@ impl TokenserverDb {
INSERT INTO users (service, email, generation, client_state, created_at, nodeid, keys_changed_at, replaced_at)
VALUES (?, ?, ?, ?, ?, ?, ?, NULL);
"#;
diesel::sql_query(QUERY)
.bind::<Integer, _>(user.service_id)
.bind::<Text, _>(&user.email)
@ -264,6 +265,171 @@ impl TokenserverDb {
.map_err(Into::into)
}
fn get_users_sync(&self, params: params::GetUsers) -> DbResult<results::GetUsers> {
const QUERY: &str = r#"
SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at,
replaced_at
FROM users
LEFT OUTER JOIN nodes ON users.nodeid = nodes.id
WHERE email = ?
AND users.service = ?
ORDER BY created_at DESC, uid DESC
LIMIT 20
"#;
diesel::sql_query(QUERY)
.bind::<Text, _>(&params.email)
.bind::<Integer, _>(params.service_id)
.load::<results::GetRawUser>(&self.inner.conn)
.map_err(Into::into)
}
/// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new
/// user.
fn get_or_create_user_sync(
&self,
params: params::GetOrCreateUser,
) -> DbResult<results::GetOrCreateUser> {
let mut raw_users = self.get_users_sync(params::GetUsers {
service_id: params.service_id,
email: params.email.clone(),
})?;
if raw_users.is_empty() {
// There are no users in the database with the given email and service ID, so
// allocate a new one.
let allocate_user_result =
self.allocate_user_sync(params.clone() as params::AllocateUser)?;
Ok(results::GetOrCreateUser {
uid: allocate_user_result.uid,
email: params.email,
client_state: params.client_state,
generation: params.generation,
node: allocate_user_result.node,
keys_changed_at: params.keys_changed_at,
created_at: allocate_user_result.created_at,
replaced_at: None,
old_client_states: vec![],
})
} else {
raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at));
raw_users.reverse();
// The user with the greatest `generation` and `created_at` is the current user
let raw_user = raw_users[0].clone();
// Collect any old client states that differ from the current client state
let old_client_states = {
raw_users[1..]
.iter()
.map(|user| user.client_state.clone())
.filter(|client_state| client_state != &raw_user.client_state)
.collect()
};
// Make sure every old row is marked as replaced. They might not be, due to races in row
// creation.
for old_user in &raw_users[1..] {
if old_user.replaced_at.is_none() {
let params = params::ReplaceUser {
uid: old_user.uid,
service_id: params.service_id,
replaced_at: raw_user.created_at,
};
self.replace_user_sync(params)?;
}
}
match (raw_user.replaced_at, raw_user.node) {
// If the most up-to-date user is marked as replaced or does not have a node
// assignment, allocate a new user. Note that, if the current user is marked
// as replaced, we do not want to create a new user with the account metadata
// in the parameters to this method. Rather, we want to create a duplicate of
// the replaced user assigned to a new node. This distinction is important
// because the account metadata in the parameters to this method may not match
// that currently stored on the most up-to-date user and may be invalid.
(Some(_), _) | (_, None) if raw_user.generation < MAX_GENERATION => {
let allocate_user_result = {
self.allocate_user_sync(params::AllocateUser {
service_id: params.service_id,
email: params.email.clone(),
generation: raw_user.generation,
client_state: raw_user.client_state.clone(),
keys_changed_at: raw_user.keys_changed_at,
capacity_release_rate: params.capacity_release_rate,
})?
};
Ok(results::GetOrCreateUser {
uid: allocate_user_result.uid,
email: params.email,
client_state: raw_user.client_state,
generation: raw_user.generation,
node: allocate_user_result.node,
keys_changed_at: raw_user.keys_changed_at,
created_at: allocate_user_result.created_at,
replaced_at: None,
old_client_states,
})
}
// The most up-to-date user has a node. Note that this user may be retired or
// replaced.
(_, Some(node)) => Ok(results::GetOrCreateUser {
uid: raw_user.uid,
email: params.email,
client_state: raw_user.client_state,
generation: raw_user.generation,
node,
keys_changed_at: raw_user.keys_changed_at,
created_at: raw_user.created_at,
replaced_at: None,
old_client_states,
}),
// The most up-to-date user doesn't have a node and is retired.
(_, None) => Err(DbError::from(DbErrorKind::TokenserverUserRetired)),
}
}
}
/// Creates a new user and assigns them to a node.
fn allocate_user_sync(&self, params: params::AllocateUser) -> DbResult<results::AllocateUser> {
// Get the least-loaded node
let node = self.get_best_node_sync(params::GetBestNode {
service_id: params.service_id,
capacity_release_rate: params.capacity_release_rate,
})?;
// Decrement `available` and increment `current_load` on the node assigned to the user.
self.add_user_to_node_sync(params::AddUserToNode {
service_id: params.service_id,
node: node.node.clone(),
})?;
let created_at = {
let start = SystemTime::now();
start.duration_since(UNIX_EPOCH).unwrap().as_millis() as i64
};
let uid = self
.post_user_sync(params::PostUser {
service_id: params.service_id,
email: params.email.clone(),
generation: params.generation,
client_state: params.client_state.clone(),
created_at,
node_id: node.id,
keys_changed_at: params.keys_changed_at,
})?
.id;
Ok(results::AllocateUser {
uid,
node: node.node,
created_at,
})
}
#[cfg(test)]
fn set_user_created_at_sync(
&self,
@ -314,22 +480,6 @@ impl TokenserverDb {
.map_err(Into::into)
}
#[cfg(test)]
fn get_users_sync(&self, email: String) -> DbResult<results::GetRawUsers> {
const QUERY: &str = r#"
SELECT users.uid, users.email, users.client_state, users.generation,
users.keys_changed_at, users.created_at, users.replaced_at, nodes.node
FROM users
JOIN nodes
ON nodes.id = users.nodeid
WHERE users.email = ?
"#;
diesel::sql_query(QUERY)
.bind::<Text, _>(email)
.load::<results::GetRawUser>(&self.inner.conn)
.map_err(Into::into)
}
#[cfg(test)]
fn post_node_sync(&self, params: params::PostNode) -> DbResult<results::PostNode> {
const QUERY: &str = r#"
@ -419,175 +569,12 @@ impl Db for TokenserverDb {
sync_db_method!(replace_users, replace_users_sync, ReplaceUsers);
sync_db_method!(post_user, post_user_sync, PostUser);
/// Creates a new user and assigns them to a node.
fn allocate_user(&self, params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser> {
Box::pin(async move {
// Get the least-loaded node
let node = self
.get_best_node(params::GetBestNode {
service_id: params.service_id,
capacity_release_rate: params.capacity_release_rate,
})
.await?;
// Decrement `available` and increment `current_load` on the node assigned to the user.
self.add_user_to_node(params::AddUserToNode {
service_id: params.service_id,
node: node.node.clone(),
})
.await?;
let created_at = {
let start = SystemTime::now();
start.duration_since(UNIX_EPOCH).unwrap().as_millis() as i64
};
let uid = self
.post_user(params::PostUser {
service_id: params.service_id,
email: params.email.clone(),
generation: params.generation,
client_state: params.client_state.clone(),
created_at,
node_id: node.id,
keys_changed_at: params.keys_changed_at,
})
.await?
.id;
Ok(results::AllocateUser {
uid,
node: node.node,
created_at,
})
})
}
sync_db_method!(put_user, put_user_sync, PutUser);
sync_db_method!(get_node_id, get_node_id_sync, GetNodeId);
sync_db_method!(get_best_node, get_best_node_sync, GetBestNode);
sync_db_method!(add_user_to_node, add_user_to_node_sync, AddUserToNode);
/// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new
/// user.
fn get_or_create_user(
&self,
params: params::GetOrCreateUser,
) -> DbFuture<'_, results::GetOrCreateUser> {
const QUERY: &str = r#"
SELECT uid, nodes.node, generation, keys_changed_at, client_state, created_at,
replaced_at
FROM users
LEFT OUTER JOIN nodes ON users.nodeid = nodes.id
WHERE email = ?
AND users.service = ?
ORDER BY created_at DESC, uid DESC
LIMIT 20
"#;
Box::pin(async move {
let mut raw_users = diesel::sql_query(QUERY)
.bind::<Text, _>(&params.email)
.bind::<Integer, _>(params.service_id)
.load::<results::GetRawUser>(&self.inner.conn)
.map_err(|e| ApiError::from(DbError::from(e)))?;
if raw_users.is_empty() {
// There are no users in the database with the given email and service ID, so
// allocate a new one.
let allocate_user_result = self
.allocate_user(params.clone() as params::AllocateUser)
.await?;
Ok(results::GetOrCreateUser {
uid: allocate_user_result.uid,
email: params.email.clone(),
client_state: params.client_state,
generation: params.generation,
node: allocate_user_result.node,
keys_changed_at: params.keys_changed_at,
created_at: allocate_user_result.created_at,
replaced_at: None,
old_client_states: vec![],
})
} else {
raw_users.sort_by_key(|raw_user| (raw_user.generation, raw_user.created_at));
raw_users.reverse();
// The user with the greatest `generation` and `created_at` is the current user
let raw_user = raw_users[0].clone();
// Collect any old client states that differ from the current client state
let old_client_states = raw_users[1..]
.iter()
.map(|user| user.client_state.clone())
.filter(|client_state| client_state != &raw_user.client_state)
.collect();
// Make sure every old row is marked as replaced. They might not be, due to races in row
// creation.
for old_user in &raw_users[1..] {
if old_user.replaced_at.is_none() {
let params = params::ReplaceUser {
uid: old_user.uid,
service_id: params.service_id,
replaced_at: raw_user.created_at,
};
self.replace_user(params).await?;
}
}
match (raw_user.replaced_at, raw_user.node) {
// If the most up-to-date user is marked as replaced or does not have a node
// assignment, allocate a new user. Note that, if the current user is marked
// as replaced, we do not want to create a new user with the account metadata
// in the parameters to this method. Rather, we want to create a duplicate of
// the replaced user assigned to a new node. This distinction is important
// because the account metadata in the parameters to this method may not match
// that currently stored on the most up-to-date user and may be invalid.
(Some(_), _) | (_, None) if raw_user.generation < MAX_GENERATION => {
let allocate_user_result = self
.allocate_user(params::AllocateUser {
service_id: params.service_id,
email: params.email.clone(),
generation: raw_user.generation,
client_state: raw_user.client_state.clone(),
keys_changed_at: raw_user.keys_changed_at,
capacity_release_rate: params.capacity_release_rate,
})
.await?;
Ok(results::GetOrCreateUser {
uid: allocate_user_result.uid,
email: params.email.clone(),
client_state: raw_user.client_state,
generation: raw_user.generation,
node: allocate_user_result.node,
keys_changed_at: raw_user.keys_changed_at,
created_at: allocate_user_result.created_at,
replaced_at: None,
old_client_states,
})
}
// The most up-to-date user has a node. Note that this user may be retired or
// replaced.
(_, Some(node)) => Ok(results::GetOrCreateUser {
uid: raw_user.uid,
email: params.email.clone(),
client_state: raw_user.client_state,
generation: raw_user.generation,
node,
keys_changed_at: raw_user.keys_changed_at,
created_at: raw_user.created_at,
replaced_at: None,
old_client_states,
}),
// The most up-to-date user doesn't have a node and is retired.
(_, None) => Err(DbError::from(DbErrorKind::TokenserverUserRetired).into()),
}
}
})
}
sync_db_method!(get_users, get_users_sync, GetUsers);
sync_db_method!(get_or_create_user, get_or_create_user_sync, GetOrCreateUser);
#[cfg(test)]
sync_db_method!(get_user, get_user_sync, GetUser);
@ -611,9 +598,6 @@ impl Db for TokenserverDb {
SetUserReplacedAt
);
#[cfg(test)]
sync_db_method!(get_users, get_users_sync, GetRawUsers);
#[cfg(test)]
sync_db_method!(post_node, post_node_sync, PostNode);
@ -637,8 +621,6 @@ pub trait Db {
fn post_user(&self, params: params::PostUser) -> DbFuture<'_, results::PostUser>;
fn allocate_user(&self, params: params::AllocateUser) -> DbFuture<'_, results::AllocateUser>;
fn put_user(&self, params: params::PutUser) -> DbFuture<'_, results::PutUser>;
fn check(&self) -> DbFuture<'_, results::Check>;
@ -652,6 +634,8 @@ pub trait Db {
params: params::AddUserToNode,
) -> DbFuture<'_, results::AddUserToNode>;
fn get_users(&self, params: params::GetUsers) -> DbFuture<'_, results::GetUsers>;
fn get_or_create_user(
&self,
params: params::GetOrCreateUser,
@ -672,9 +656,6 @@ pub trait Db {
#[cfg(test)]
fn get_user(&self, params: params::GetUser) -> DbFuture<'_, results::GetUser>;
#[cfg(test)]
fn get_users(&self, params: params::GetRawUsers) -> DbFuture<'_, results::GetRawUsers>;
#[cfg(test)]
fn post_node(&self, params: params::PostNode) -> DbFuture<'_, results::PostNode>;
@ -976,8 +957,18 @@ mod tests {
// Get all of the users
let users = {
let mut users1 = db.get_users(email1.to_owned()).await?;
let mut users2 = db.get_users(email2.to_owned()).await?;
let mut users1 = db
.get_users(params::GetUsers {
email: email1.to_owned(),
service_id: db::SYNC_1_5_SERVICE_ID,
})
.await?;
let mut users2 = db
.get_users(params::GetUsers {
email: email2.to_owned(),
service_id: db::SYNC_1_5_SERVICE_ID,
})
.await?;
users1.append(&mut users2);
users1
@ -1097,7 +1088,7 @@ mod tests {
#[tokio::test]
async fn test_node_allocation() -> Result<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let db = pool.get_tokenserver_db().await?;
// Add a node
let node_id = db
@ -1113,16 +1104,14 @@ mod tests {
.id;
// Allocating a user assigns it to the node
let user = db
.allocate_user(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})
.await?;
let user = db.allocate_user_sync(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})?;
assert_eq!(user.node, "https://node1");
// Getting the user from the database does not affect node assignment
@ -1135,7 +1124,7 @@ mod tests {
#[tokio::test]
async fn test_allocation_to_least_loaded_node() -> Result<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let db = pool.get_tokenserver_db().await?;
// Add two nodes
db.post_node(params::PostNode {
@ -1159,27 +1148,23 @@ mod tests {
.await?;
// Allocate two users
let user1 = db
.allocate_user(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test1@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})
.await?;
let user1 = db.allocate_user_sync(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test1@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})?;
let user2 = db
.allocate_user(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test2@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})
.await?;
let user2 = db.allocate_user_sync(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test2@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})?;
// Because users are always assigned to the least-loaded node, the users should have been
// assigned to different nodes
@ -1191,7 +1176,7 @@ mod tests {
#[tokio::test]
async fn test_allocation_is_not_allowed_to_downed_nodes() -> Result<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let db = pool.get_tokenserver_db().await?;
// Add a downed node
db.post_node(params::PostNode {
@ -1206,16 +1191,14 @@ mod tests {
.await?;
// User allocation fails because allocation is not allowed to downed nodes
let result = db
.allocate_user(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})
.await;
let result = db.allocate_user_sync(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
});
let error = result.unwrap_err();
assert_eq!(error.to_string(), "Unexpected error: unable to get a node");
@ -1225,7 +1208,7 @@ mod tests {
#[tokio::test]
async fn test_allocation_is_not_allowed_to_backoff_nodes() -> Result<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let db = pool.get_tokenserver_db().await?;
// Add a backoff node
db.post_node(params::PostNode {
@ -1240,16 +1223,14 @@ mod tests {
.await?;
// User allocation fails because allocation is not allowed to backoff nodes
let result = db
.allocate_user(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})
.await;
let result = db.allocate_user_sync(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
});
let error = result.unwrap_err();
assert_eq!(error.to_string(), "Unexpected error: unable to get a node");
@ -1259,7 +1240,7 @@ mod tests {
#[tokio::test]
async fn test_node_reassignment_when_records_are_replaced() -> Result<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let db = pool.get_tokenserver_db().await?;
// Add a node
db.post_node(params::PostNode {
@ -1273,16 +1254,14 @@ mod tests {
.await?;
// Allocate a user
let allocate_user_result = db
.allocate_user(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})
.await?;
let allocate_user_result = db.allocate_user_sync(params::AllocateUser {
service_id: db::SYNC_1_5_SERVICE_ID,
generation: 1234,
email: "test@test.com".to_owned(),
client_state: "616161".to_owned(),
keys_changed_at: Some(1234),
capacity_release_rate: None,
})?;
let user1 = db
.get_user(params::GetUser {
id: allocate_user_result.uid,

View File

@ -22,6 +22,11 @@ pub struct PostService {
pub pattern: String,
}
pub struct GetUsers {
pub service_id: i32,
pub email: String,
}
#[derive(Clone, Default)]
pub struct GetOrCreateUser {
pub service_id: i32,
@ -87,9 +92,6 @@ pub struct AddUserToNode {
pub node: String,
}
#[cfg(test)]
pub type GetRawUsers = String;
#[cfg(test)]
pub struct SetUserCreatedAt {
pub uid: i64,

View File

@ -62,6 +62,14 @@ impl TokenserverPool {
inner: builder.build(manager)?,
})
}
#[cfg(test)]
pub async fn get_tokenserver_db(&self) -> Result<Box<TokenserverDb>, DbError> {
let pool = self.clone();
let conn = block(move || pool.inner.get().map_err(DbError::from)).await?;
Ok(Box::new(TokenserverDb::new(conn)))
}
}
impl From<actix_web::error::BlockingError<DbError>> for DbError {

View File

@ -26,6 +26,8 @@ pub struct GetRawUser {
pub replaced_at: Option<i64>,
}
pub type GetUsers = Vec<GetRawUser>;
#[derive(Debug, Default, PartialEq)]
pub struct AllocateUser {
pub uid: i64,
@ -75,9 +77,6 @@ pub struct GetBestNode {
pub type AddUserToNode = ();
#[cfg(test)]
pub type GetRawUsers = Vec<GetRawUser>;
#[cfg(test)]
#[derive(Debug, Default, PartialEq, QueryableByName)]
pub struct GetUser {
@ -141,9 +140,6 @@ pub type SetUserCreatedAt = ();
#[cfg(test)]
pub type SetUserReplacedAt = ();
#[cfg(test)]
pub type GetUsers = Vec<GetUser>;
pub type Check = bool;
#[cfg(test)]

View File

@ -1,12 +1,16 @@
use std::fmt;
use actix_web::{error::ResponseError, http::StatusCode, HttpResponse};
use actix_web::{
error::{BlockingError, ResponseError},
http::StatusCode,
HttpResponse,
};
use serde::{
ser::{SerializeMap, Serializer},
Serialize,
};
#[derive(Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub struct TokenserverError {
pub status: &'static str,
pub location: ErrorLocation,
@ -99,6 +103,18 @@ impl TokenserverError {
}
}
impl From<BlockingError<TokenserverError>> for TokenserverError {
fn from(inner: BlockingError<TokenserverError>) -> Self {
match inner {
BlockingError::Error(e) => e,
BlockingError::Canceled => {
error!("Tokenserver threadpool operation canceled");
TokenserverError::internal_error()
}
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ErrorLocation {
Header,

View File

@ -8,7 +8,7 @@ use std::sync::Arc;
use actix_web::{
dev::Payload,
http::StatusCode,
web::{Data, Query},
web::{self, Data, Query},
Error, FromRequest, HttpRequest,
};
use actix_web_httpauth::extractors::bearer::BearerAuth;
@ -19,7 +19,7 @@ use regex::Regex;
use serde::Deserialize;
use sha2::Sha256;
use super::db::{self, models::Db, params, results};
use super::db::{self, models::Db, params, pool::DbPool, results};
use super::error::{ErrorLocation, TokenserverError};
use super::support::TokenData;
use super::NodeType;
@ -195,11 +195,7 @@ impl FromRequest for TokenserverRequest {
};
let email = format!("{}@{}", fxa_uid, state.fxa_email_domain);
let user = {
let db = state.db_pool.get().await.map_err(|_| {
error!("⚠️ Could not acquire database connection");
TokenserverError::internal_error()
})?;
let db = <Box<dyn Db>>::extract(&req).await?;
db.get_or_create_user(params::GetOrCreateUser {
service_id,
@ -258,6 +254,28 @@ impl FromRequest for Box<dyn Db> {
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
let req = req.clone();
Box::pin(async move {
<Box<dyn DbPool>>::extract(&req)
.await?
.get()
.await
.map_err(|_| {
error!("⚠️ Could not acquire database connection");
TokenserverError::internal_error().into()
})
})
}
}
impl FromRequest for Box<dyn DbPool> {
type Config = ();
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
fn from_request(req: &HttpRequest, _payload: &mut Payload) -> Self::Future {
let req = req.clone();
@ -265,13 +283,8 @@ impl FromRequest for Box<dyn Db> {
// XXX: Tokenserver state will no longer be an Option once the Tokenserver
// code is rolled out, so we will eventually be able to remove this unwrap().
let state = get_server_state(&req)?.as_ref().as_ref().unwrap();
let db = state.db_pool.get().await.map_err(|_| {
error!("⚠️ Could not acquire database connection");
TokenserverError::internal_error()
})?;
Ok(db)
Ok(state.db_pool.clone())
})
}
}
@ -303,20 +316,20 @@ impl FromRequest for TokenData {
let auth = BearerAuth::extract(&req)
.await
.map_err(|_| TokenserverError::invalid_credentials("Unsupported"))?;
// XXX: The Tokenserver state will no longer be an Option once the Tokenserver
// code is rolled out, so we will eventually be able to remove this unwrap().
let state = get_server_state(&req)?.as_ref().as_ref().unwrap();
let oauth_verifier = state.oauth_verifier.clone();
state
.oauth_verifier
.verify_token(auth.token())
web::block(move || oauth_verifier.verify_token(auth.token()))
.await
.map_err(TokenserverError::from)
.map_err(Into::into)
})
}
}
#[derive(Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
struct KeyId {
client_state: String,
keys_changed_at: i64,