diff --git a/Cargo.lock b/Cargo.lock
index ab205422..78a26aec 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -824,11 +824,8 @@ dependencies = [
  "byteorder",
  "diesel_derives",
  "itoa",
- "mysqlclient-sys",
- "percent-encoding",
  "pq-sys",
  "r2d2",
- "url",
 ]
 
 [[package]]
@@ -2092,17 +2089,6 @@ dependencies = [
  "uuid",
 ]
 
-[[package]]
-name = "mysqlclient-sys"
-version = "0.4.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "86a34a2bdec189f1060343ba712983e14cad7e87515cfd9ac4653e207535b6b1"
-dependencies = [
- "pkg-config",
- "semver",
- "vcpkg",
-]
-
 [[package]]
 name = "nom"
 version = "5.1.3"
@@ -3435,7 +3421,9 @@ dependencies = [
  "async-trait",
  "backtrace",
  "base64",
+ "deadpool",
  "diesel",
+ "diesel-async",
  "diesel_logger",
  "diesel_migrations",
  "env_logger 0.11.8",
@@ -3448,6 +3436,7 @@ dependencies = [
  "syncstorage-db-common",
  "syncstorage-settings",
  "thiserror 1.0.69",
+ "tokio",
  "url",
 ]
 
@@ -3709,6 +3698,7 @@ dependencies = [
  "backtrace",
  "deadpool",
  "diesel",
+ "diesel-async",
  "diesel_migrations",
  "http 1.3.1",
  "serde 1.0.219",
diff --git a/syncserver-db-common/Cargo.toml b/syncserver-db-common/Cargo.toml
index 7bb63c30..add3da5d 100644
--- a/syncserver-db-common/Cargo.toml
+++ b/syncserver-db-common/Cargo.toml
@@ -11,8 +11,7 @@ deadpool.workspace = true
 futures.workspace = true
 http.workspace = true
 thiserror.workspace = true
-
-diesel = { workspace = true, features = ["mysql", "r2d2"] }
-diesel-async = { workspace = true }
-diesel_migrations = { workspace = true, features = ["mysql"] }
+diesel.workspace = true
+diesel-async.workspace = true
+diesel_migrations.workspace = true
 
 syncserver-common = { path = "../syncserver-common" }
diff --git a/syncserver-db-common/src/error.rs b/syncserver-db-common/src/error.rs
index 9165d11a..ee0a794e 100644
--- a/syncserver-db-common/src/error.rs
+++ b/syncserver-db-common/src/error.rs
@@ -22,9 +22,6 @@ enum SqlErrorKind {
     #[error("An error occurred while establishing a db connection: {}", _0)]
     DieselConnection(#[from] diesel::result::ConnectionError),
 
-    #[error("A database pool error occurred: {}", _0)]
-    Pool(diesel::r2d2::PoolError),
-
     #[error("Error migrating the database: {}", _0)]
     Migration(diesel_migrations::MigrationError),
 }
@@ -41,18 +38,13 @@ impl From<SqlErrorKind> for SqlError {
 impl ReportableError for SqlError {
     fn is_sentry_event(&self) -> bool {
-        #[allow(clippy::match_like_matches_macro)]
-        match &self.kind {
-            SqlErrorKind::Pool(_) => false,
-            _ => true,
-        }
+        true
     }
 
     fn metric_label(&self) -> Option<&'static str> {
         Some(match self.kind {
             SqlErrorKind::DieselQuery(_) => "storage.sql.error.diesel_query",
             SqlErrorKind::DieselConnection(_) => "storage.sql.error.diesel_connection",
-            SqlErrorKind::Pool(_) => "storage.sql.error.pool",
             SqlErrorKind::Migration(_) => "storage.sql.error.migration",
         })
     }
@@ -70,7 +62,6 @@ from_error!(
     SqlError,
     SqlErrorKind::DieselConnection
 );
-from_error!(diesel::r2d2::PoolError, SqlError, SqlErrorKind::Pool);
 from_error!(
     diesel_migrations::MigrationError,
     SqlError,
diff --git a/syncserver-db-common/src/lib.rs b/syncserver-db-common/src/lib.rs
index da3c6cad..6acbce78 100644
--- a/syncserver-db-common/src/lib.rs
+++ b/syncserver-db-common/src/lib.rs
@@ -20,14 +20,6 @@ pub struct PoolState {
     pub idle_connections: u32,
 }
 
-impl From<diesel::r2d2::State> for PoolState {
-    fn from(state: diesel::r2d2::State) -> PoolState {
-        PoolState {
-            connections: state.connections,
-            idle_connections: state.idle_connections,
-        }
-    }
-}
 impl From<deadpool::Status> for PoolState {
     fn from(status: deadpool::Status) -> PoolState {
         PoolState {
@@ -38,17 +30,13 @@ impl From<deadpool::Status> for PoolState {
 }
 
 #[macro_export]
-macro_rules! sync_db_method {
-    ($name:ident, $sync_name:ident, $type:ident) => {
-        sync_db_method!($name, $sync_name, $type, results::$type);
+macro_rules! async_db_method {
+    ($name:ident, $async_name:path, $type:ident) => {
+        async_db_method!($name, $async_name, $type, results::$type);
     };
-    ($name:ident, $sync_name:ident, $type:ident, $result:ty) => {
+    ($name:ident, $async_name:path, $type:ident, $result:ty) => {
         fn $name(&mut self, params: params::$type) -> DbFuture<'_, $result, DbError> {
-            let mut db = self.clone();
-            Box::pin(
-                self.blocking_threadpool
-                    .spawn(move || db.$sync_name(params)),
-            )
+            Box::pin($async_name(self, params))
         }
     };
 }
diff --git a/syncserver-db-common/src/test.rs b/syncserver-db-common/src/test.rs
index 68d4f079..77fb2a20 100644
--- a/syncserver-db-common/src/test.rs
+++ b/syncserver-db-common/src/test.rs
@@ -1,17 +1,6 @@
 use deadpool::managed::{HookError, HookResult};
-use diesel::{mysql::MysqlConnection, r2d2::CustomizeConnection, Connection};
 use diesel_async::{pooled_connection::PoolError, AsyncConnection};
 
-#[derive(Debug)]
-pub struct TestTransactionCustomizer;
-
-impl CustomizeConnection<MysqlConnection, diesel::r2d2::Error> for TestTransactionCustomizer {
-    fn on_acquire(&self, conn: &mut MysqlConnection) -> Result<(), diesel::r2d2::Error> {
-        conn.begin_test_transaction()
-            .map_err(diesel::r2d2::Error::QueryError)
-    }
-}
-
 pub async fn test_transaction_hook<T>(conn: &mut T) -> HookResult<PoolError>
 where
     T: AsyncConnection,
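For illustration, an invocation like `async_db_method!(get_bso, MysqlDb::get_bso, GetBso, Option<results::GetBso>)` from the models.rs changes later in this patch should expand to roughly the following trait method (a sketch of the macro output, not generated code):

    fn get_bso(
        &mut self,
        params: params::GetBso,
    ) -> DbFuture<'_, Option<results::GetBso>, DbError> {
        // The inherent async fn borrows `self` directly; no clone and no
        // blocking threadpool spawn are needed now that the query is async.
        Box::pin(MysqlDb::get_bso(self, params))
    }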
diff --git a/syncserver/src/server/mod.rs b/syncserver/src/server/mod.rs
index 57d0c7b4..3c822f18 100644
--- a/syncserver/src/server/mod.rs
+++ b/syncserver/src/server/mod.rs
@@ -268,11 +268,12 @@ impl Server {
         let blocking_threadpool = Arc::new(BlockingThreadpool::new(
             settings.worker_max_blocking_threads,
         ));
-        let db_pool = DbPoolImpl::new(
+        let mut db_pool = DbPoolImpl::new(
             &settings.syncstorage,
             &Metrics::from(&metrics),
             blocking_threadpool.clone(),
         )?;
+        db_pool.init().await?;
 
         // Spawns sweeper that calls Deadpool `retain` method, clearing unused connections.
         db_pool.spawn_sweeper(Duration::from_secs(
             settings
diff --git a/syncserver/src/server/test.rs b/syncserver/src/server/test.rs
index 67038841..6adc9bc2 100644
--- a/syncserver/src/server/test.rs
+++ b/syncserver/src/server/test.rs
@@ -93,15 +93,21 @@ async fn get_test_state(settings: &Settings) -> ServerState {
         app_channel: settings.environment.clone(),
     });
 
+    let mut db_pool = Box::new(
+        DbPoolImpl::new(
+            &settings.syncstorage,
+            &Metrics::from(&metrics),
+            blocking_threadpool,
+        )
+        .expect("Could not get db_pool in get_test_state"),
+    );
+    db_pool
+        .init()
+        .await
+        .expect("Could not init db_pool in get_test_state");
+
     ServerState {
-        db_pool: Box::new(
-            DbPoolImpl::new(
-                &settings.syncstorage,
-                &Metrics::from(&metrics),
-                blocking_threadpool,
-            )
-            .expect("Could not get db_pool in get_test_state"),
-        ),
+        db_pool,
         limits: Arc::clone(&SERVER_LIMITS),
         limits_json: serde_json::to_string(&**SERVER_LIMITS).unwrap(),
         metrics,
diff --git a/syncserver/src/tokenserver/mod.rs b/syncserver/src/tokenserver/mod.rs
index 364ad6bb..05b0d3e1 100644
--- a/syncserver/src/tokenserver/mod.rs
+++ b/syncserver/src/tokenserver/mod.rs
@@ -38,7 +38,7 @@ impl ServerState {
     pub fn from_settings(
         settings: &Settings,
         metrics: Arc,
-        blocking_threadpool: Arc<BlockingThreadpool>,
+        #[allow(unused_variables)] blocking_threadpool: Arc<BlockingThreadpool>,
     ) -> Result {
         #[cfg(not(feature = "py_verifier"))]
         let oauth_verifier = {
diff --git a/syncstorage-db-common/Cargo.toml b/syncstorage-db-common/Cargo.toml
index 3a8362cf..0ef91e83 100644
--- a/syncstorage-db-common/Cargo.toml
+++ b/syncstorage-db-common/Cargo.toml
@@ -8,6 +8,7 @@ edition.workspace = true
 [dependencies]
 backtrace.workspace = true
 chrono.workspace = true
+diesel.workspace = true
 futures.workspace = true
 lazy_static.workspace = true
 http.workspace = true
@@ -16,6 +17,5 @@ serde_json.workspace = true
 thiserror.workspace = true
 
 async-trait = "0.1.88"
-diesel = { workspace = true, features = ["mysql", "r2d2"] }
 syncserver-common = { path = "../syncserver-common" }
 syncserver-db-common = { path = "../syncserver-db-common" }
diff --git a/syncstorage-db-common/src/lib.rs b/syncstorage-db-common/src/lib.rs
index 339a4b38..059e9fc9 100644
--- a/syncstorage-db-common/src/lib.rs
+++ b/syncstorage-db-common/src/lib.rs
@@ -51,6 +51,10 @@ pub const FIRST_CUSTOM_COLLECTION_ID: i32 = 101;
 pub trait DbPool: Sync + Send + Debug + GetPoolState {
     type Error;
 
+    async fn init(&mut self) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
     async fn get(&self) -> Result<Box<dyn Db<Error = Self::Error>>, Self::Error>;
 
     fn validate_batch_id(&self, params: params::ValidateBatchId) -> Result<(), Self::Error>;
diff --git a/syncstorage-db/src/tests/support.rs b/syncstorage-db/src/tests/support.rs
index 69cec46e..50cdb904 100644
--- a/syncstorage-db/src/tests/support.rs
+++ b/syncstorage-db/src/tests/support.rs
@@ -23,7 +23,8 @@ pub async fn db_pool(settings: Option) -> Result
diff --git a/syncstorage-mysql/src/batch.rs b/syncstorage-mysql/src/batch.rs
-pub fn create(db: &mut MysqlDb, params: params::CreateBatch) -> DbResult {
+pub async fn create(
+    db: &mut MysqlDb,
+    params: params::CreateBatch,
+) -> DbResult {
     let user_id = params.user_id.legacy_id as i64;
-    let collection_id = db.get_collection_id(&params.collection)?;
+    let collection_id = db.get_collection_id(&params.collection).await?;
 
     // Careful, there's some weirdness here!
     //
     // Sync timestamps are in seconds and quantized to two decimal places, so
@@ -48,7 +52,8 @@ pub fn create(db: &mut MysqlDb, params: params::CreateBatch) -> DbResult
-        .execute(&mut *db.conn.write()?)
+        .execute(&mut db.conn)
+        .await
     {
         Ok(_) => break,
         Err(DieselError::DatabaseError(UniqueViolation, _)) => {
@@ -61,14 +66,14 @@ pub fn create(db: &mut MysqlDb, params: params::CreateBatch) -> DbResult
-pub fn validate(db: &mut MysqlDb, params: params::ValidateBatch) -> DbResult {
+pub async fn validate(db: &mut MysqlDb, params: params::ValidateBatch) -> DbResult {
     let batch_id = decode_id(&params.id)?;
     // Avoid hitting the db for batches that are obviously too old. Recall
     // that the batchid is a millisecond timestamp.
@@ -77,18 +82,19 @@ pub fn validate(db: &mut MysqlDb, params: params::ValidateBatch) -> DbResult
         .select(sql::<Integer>("1"))
         .filter(batch_uploads::batch_id.eq(&batch_id))
         .filter(batch_uploads::user_id.eq(&user_id))
         .filter(batch_uploads::collection_id.eq(&collection_id))
-        .get_result::<i32>(&mut *db.conn.write()?)
+        .get_result::<i32>(&mut db.conn)
+        .await
         .optional()?;
     Ok(exists.is_some())
 }
 
-pub fn append(db: &mut MysqlDb, params: params::AppendToBatch) -> DbResult<()> {
+pub async fn append(db: &mut MysqlDb, params: params::AppendToBatch) -> DbResult<()> {
     let exists = validate(
         db,
         params::ValidateBatch {
@@ -96,19 +102,23 @@ pub fn append(db: &mut MysqlDb, params: params::AppendToBatch) -> DbResult<()> {
             collection: params.collection.clone(),
             id: params.batch.id.clone(),
         },
-    )?;
+    )
+    .await?;
 
     if !exists {
         return Err(DbError::batch_not_found());
     }
 
     let batch_id = decode_id(&params.batch.id)?;
-    let collection_id = db.get_collection_id(&params.collection)?;
-    do_append(db, batch_id, params.user_id, collection_id, params.bsos)?;
+    let collection_id = db.get_collection_id(&params.collection).await?;
+    do_append(db, batch_id, params.user_id, collection_id, params.bsos).await?;
     Ok(())
 }
 
-pub fn get(db: &mut MysqlDb, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
+pub async fn get(
+    db: &mut MysqlDb,
+    params: params::GetBatch,
+) -> DbResult<Option<results::GetBatch>> {
     let is_valid = validate(
         db,
         params::ValidateBatch {
@@ -116,7 +126,8 @@ pub fn get(db: &mut MysqlDb, params: params::GetBatch) -> DbResult
-pub fn delete(db: &mut MysqlDb, params: params::DeleteBatch) -> DbResult<()> {
+pub async fn delete(db: &mut MysqlDb, params: params::DeleteBatch) -> DbResult<()> {
     let batch_id = decode_id(&params.id)?;
     let user_id = params.user_id.legacy_id as i64;
-    let collection_id = db.get_collection_id(&params.collection)?;
+    let collection_id = db.get_collection_id(&params.collection).await?;
     diesel::delete(batch_uploads::table)
         .filter(batch_uploads::batch_id.eq(&batch_id))
         .filter(batch_uploads::user_id.eq(&user_id))
         .filter(batch_uploads::collection_id.eq(&collection_id))
-        .execute(&mut *db.conn.write()?)?;
+        .execute(&mut db.conn)
+        .await?;
     diesel::delete(batch_upload_items::table)
         .filter(batch_upload_items::batch_id.eq(&batch_id))
         .filter(batch_upload_items::user_id.eq(&user_id))
-        .execute(&mut *db.conn.write()?)?;
+        .execute(&mut db.conn)
+        .await?;
     Ok(())
 }
 
 /// Commits a batch to the bsos table, deleting the batch when succesful
-pub fn commit(db: &mut MysqlDb, params: params::CommitBatch) -> DbResult {
+pub async fn commit(
+    db: &mut MysqlDb,
+    params: params::CommitBatch,
+) -> DbResult {
     let batch_id = decode_id(&params.batch.id)?;
     let user_id = params.user_id.legacy_id as i64;
-    let collection_id = db.get_collection_id(&params.collection)?;
+    let collection_id = db.get_collection_id(&params.collection).await?;
     let timestamp = db.timestamp();
     sql_query(include_str!("batch_commit.sql"))
         .bind::<BigInt, _>(user_id)
@@ -157,9 +173,10 @@ pub fn commit(db: &mut MysqlDb, params: params::CommitBatch) -> DbResult
         .bind::<BigInt, _>(user_id)
         .bind::<BigInt, _>(&db.timestamp().as_i64())
         .bind::<BigInt, _>(&db.timestamp().as_i64())
-        .execute(&mut *db.conn.write()?)?;
+        .execute(&mut db.conn)
+        .await?;
 
-    db.update_collection(user_id as u32, collection_id)?;
+    db.update_collection(user_id as u32, collection_id).await?;
 
     delete(
         db,
@@ -168,11 +185,12 @@ pub fn commit(db: &mut MysqlDb, params: params::CommitBatch) -> DbResult
         .bind::<BigInt, _>(user_id.legacy_id as i64)
         .bind::<BigInt, _>(batch_id)
-        .get_results::(&mut *db.conn.write()?)?
+        .get_results::(&mut db.conn).await?
     {
         existing.insert(exist_idx(
             user_id.legacy_id,
@@ -241,7 +259,8 @@ pub fn do_append(
             payload_size,
             ttl_offset: bso.ttl.map(|ttl| ttl as i32),
         })
-        .execute(&mut *db.conn.write()?)?;
+        .execute(&mut db.conn)
+        .await?;
     } else {
         diesel::insert_into(batch_upload_items::table)
             .values((
@@ -253,7 +272,8 @@
                 batch_upload_items::payload_size.eq(payload_size),
                 batch_upload_items::ttl_offset.eq(bso.ttl.map(|ttl| ttl as i32)),
             ))
-            .execute(&mut *db.conn.write()?)?;
+            .execute(&mut db.conn)
+            .await?;
         // make sure to include the key into our table check.
         existing.insert(exist_idx);
     }
@@ -282,8 +302,8 @@ fn decode_id(id: &str) -> DbResult
 macro_rules! batch_db_method {
     ($name:ident, $batch_name:ident, $type:ident) => {
-        pub fn $name(&mut self, params: params::$type) -> DbResult {
-            batch::$batch_name(self, params)
+        pub async fn $name(&mut self, params: params::$type) -> DbResult {
+            batch::$batch_name(self, params).await
         }
     };
 }
diff --git a/syncstorage-mysql/src/error.rs b/syncstorage-mysql/src/error.rs
index 3373c9f9..1b70471c 100644
--- a/syncstorage-mysql/src/error.rs
+++ b/syncstorage-mysql/src/error.rs
@@ -41,6 +41,10 @@ impl DbError {
     pub fn quota() -> Self {
         DbErrorKind::Common(SyncstorageDbError::quota()).into()
     }
+
+    pub fn pool_timeout(timeout_type: deadpool::managed::TimeoutType) -> Self {
+        DbErrorKind::PoolTimeout(timeout_type).into()
+    }
 }
 
 #[derive(Debug, Error)]
@@ -50,6 +54,9 @@ enum DbErrorKind {
 
     #[error("{}", _0)]
     Mysql(SqlError),
+
+    #[error("A database pool timeout occurred, type: {:?}", _0)]
+    PoolTimeout(deadpool::managed::TimeoutType),
 }
 
 impl From<DbErrorKind> for DbError {
@@ -96,6 +103,7 @@ impl ReportableError for DbError {
         Some(match &self.kind {
             DbErrorKind::Common(e) => e,
             DbErrorKind::Mysql(e) => e,
+            _ => return None,
         })
     }
 
@@ -103,6 +111,7 @@ impl ReportableError for DbError {
         match &self.kind {
             DbErrorKind::Common(e) => e.is_sentry_event(),
             DbErrorKind::Mysql(e) => e.is_sentry_event(),
+            DbErrorKind::PoolTimeout(_) => false,
         }
     }
 
@@ -110,6 +119,7 @@ impl ReportableError for DbError {
         match &self.kind {
             DbErrorKind::Common(e) => e.metric_label(),
             DbErrorKind::Mysql(e) => e.metric_label(),
+            DbErrorKind::PoolTimeout(_) => Some("storage.diesel.pool.timeout"),
         }
     }
 
@@ -117,6 +127,7 @@ impl ReportableError for DbError {
         match &self.kind {
             DbErrorKind::Common(e) => e.backtrace(),
             DbErrorKind::Mysql(e) => e.backtrace(),
+            _ => None,
         }
     }
 
@@ -124,6 +135,7 @@ impl ReportableError for DbError {
         match &self.kind {
             DbErrorKind::Common(e) => e.tags(),
             DbErrorKind::Mysql(e) => e.tags(),
+            _ => vec![],
         }
     }
 }
@@ -149,11 +161,6 @@ from_error!(
     error
 )))
 );
-from_error!(
-    diesel::r2d2::PoolError,
-    DbError,
-    |error: diesel::r2d2::PoolError| DbError::from(DbErrorKind::Mysql(SqlError::from(error)))
-);
 from_error!(
     diesel_migrations::MigrationError,
     DbError,
@@ -166,9 +173,3 @@ from_error!(
     DbError,
     |error: std::boxed::Box| DbError::internal_error(error.to_string())
 );
-
-impl From> for DbError {
-    fn from(inner: std::sync::PoisonError) -> DbError {
-        DbError::internal_error(inner.to_string())
-    }
-}
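The new `DbPool::init` default above splits startup into two phases: synchronous pool construction and an explicit async initialization step that runs migrations. A minimal sketch of the resulting call pattern, mirroring the server/mod.rs hunk earlier in this patch (tokio context assumed, error handling elided):

    // Sketch of the new two-phase startup sequence:
    let mut db_pool = DbPoolImpl::new(&settings.syncstorage, &metrics, blocking_threadpool)?;
    db_pool.init().await?; // runs the embedded migrations off the async runtime
    let db = db_pool.get().await?; // then checks a connection out of the deadpool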
diff --git a/syncstorage-mysql/src/models.rs b/syncstorage-mysql/src/models.rs
index 292bea4b..b2f32e5b 100644
--- a/syncstorage-mysql/src/models.rs
+++ b/syncstorage-mysql/src/models.rs
@@ -1,22 +1,18 @@
 use futures::future::TryFutureExt;
-use std::{self, cell::RefCell, collections::HashMap, fmt, ops::Deref, sync::Arc, sync::RwLock};
+use std::{collections::HashMap, fmt, sync::Arc};
 
 use diesel::{
-    connection::TransactionManager,
     delete,
     dsl::max,
     dsl::sql,
-    mysql::MysqlConnection,
-    r2d2::{ConnectionManager, PooledConnection},
     sql_query,
     sql_types::{BigInt, Integer, Nullable, Text},
-    Connection, ExpressionMethods, OptionalExtension, QueryDsl, RunQueryDsl,
+    ExpressionMethods, OptionalExtension, QueryDsl,
 };
-#[cfg(debug_assertions)]
-use diesel_logger::LoggingConnection;
-use syncserver_common::{BlockingThreadpool, Metrics};
-use syncserver_db_common::{sync_db_method, DbFuture};
+use diesel_async::{AsyncConnection, RunQueryDsl, TransactionManager};
+use syncserver_common::Metrics;
+use syncserver_db_common::{async_db_method, DbFuture};
 use syncstorage_db_common::{
     error::DbErrorIntrospect, params, results, util::SyncTimestamp, Db, Sorting, UserIdentifier,
     DEFAULT_BSO_TTL,
@@ -26,17 +22,11 @@ use syncstorage_settings::{Quota, DEFAULT_MAX_TOTAL_RECORDS};
 use super::{
     batch,
     error::DbError,
-    pool::CollectionCache,
+    pool::{CollectionCache, Conn},
     schema::{bso, collections, user_collections},
     DbResult,
 };
 
-type Conn = PooledConnection<ConnectionManager<MysqlConnection>>;
-#[cfg(not(debug_assertions))]
-type InternalConn = Conn;
-#[cfg(debug_assertions)]
-type InternalConn = LoggingConnection<Conn>; // display SQL when RUST_LOG="diesel_logger=trace"
-
 // this is the max number of records we will return.
 static DEFAULT_LIMIT: u32 = DEFAULT_MAX_TOTAL_RECORDS;
@@ -71,47 +61,23 @@ struct MysqlDbSession {
     in_write_transaction: bool,
 }
 
-#[derive(Clone, Debug)]
 pub struct MysqlDb {
-    /// Synchronous Diesel calls are executed in web::block to satisfy the Db trait's asynchronous
-    /// interface.
-    ///
-    /// Arc provides a Clone impl utilized for safely moving to
-    /// the thread pool but does not provide Send as the underlying db
-    /// conn. structs are !Sync (Arc requires both for Send). See the Send impl
-    /// below.
-    pub(super) inner: Arc<MysqlDbInner>,
-
+    pub(super) conn: Conn,
+    session: MysqlDbSession,
     /// Pool level cache of collection_ids and their names
     coll_cache: Arc<CollectionCache>,
-
-    pub metrics: Metrics,
-    pub quota: Quota,
-    blocking_threadpool: Arc<BlockingThreadpool>,
+    metrics: Metrics,
+    quota: Quota,
 }
 
-/// Despite the db conn structs being !Sync (see Arc above) we
-/// don't spawn multiple MysqlDb calls at a time in the thread pool. Calls are
-/// queued to the thread pool via Futures, naturally serialized.
-unsafe impl Send for MysqlDb {}
-
-pub struct MysqlDbInner {
-    pub(super) conn: RwLock<InternalConn>,
-
-    session: RefCell<MysqlDbSession>,
-}
-
-impl fmt::Debug for MysqlDbInner {
+impl fmt::Debug for MysqlDb {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "MysqlDbInner {{ session: {:?} }}", self.session)
-    }
-}
-
-impl Deref for MysqlDb {
-    type Target = MysqlDbInner;
-
-    fn deref(&self) -> &Self::Target {
-        &self.inner
+        f.debug_struct("MysqlDb")
+            .field("session", &self.session)
+            .field("coll_cache", &self.coll_cache)
+            .field("metrics", &self.metrics)
+            .field("quota", &self.quota)
+            .finish()
     }
 }
@@ -121,23 +87,13 @@ impl MysqlDb {
         coll_cache: Arc<CollectionCache>,
         metrics: &Metrics,
         quota: &Quota,
-        blocking_threadpool: Arc<BlockingThreadpool>,
     ) -> Self {
-        let inner = MysqlDbInner {
-            #[cfg(not(debug_assertions))]
-            conn: RwLock::new(conn),
-            #[cfg(debug_assertions)]
-            conn: RwLock::new(LoggingConnection::new(conn)),
-            session: RefCell::new(Default::default()),
-        };
-
-        // https://github.com/mozilla-services/syncstorage-rs/issues/1480
-        #[allow(clippy::arc_with_non_send_sync)]
         MysqlDb {
-            inner: Arc::new(inner),
+            conn,
+            session: Default::default(),
             coll_cache,
             metrics: metrics.clone(),
             quota: *quota,
-            blocking_threadpool,
         }
     }
@@ -150,22 +106,24 @@ impl MysqlDb {
     /// In theory it would be possible to use serializable transactions rather
     /// than explicit locking, but our ops team have expressed concerns about
     /// the efficiency of that approach at scale.
-    fn lock_for_read_sync(&mut self, params: params::LockCollection) -> DbResult<()> {
+    async fn lock_for_read(&mut self, params: params::LockCollection) -> DbResult<()> {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection).or_else(|e| {
-            if e.is_collection_not_found() {
-                // If the collection doesn't exist, we still want to start a
-                // transaction so it will continue to not exist.
-                Ok(0)
-            } else {
-                Err(e)
-            }
-        })?;
+        let collection_id = self
+            .get_collection_id(&params.collection)
+            .await
+            .or_else(|e| {
+                if e.is_collection_not_found() {
+                    // If the collection doesn't exist, we still want to start a
+                    // transaction so it will continue to not exist.
+                    Ok(0)
+                } else {
+                    Err(e)
+                }
+            })?;
         // If we already have a read or write lock then it's safe to
         // use it as-is.
         if self
             .session
-            .borrow()
             .coll_locks
             .contains_key(&(user_id as u32, collection_id))
         {
@@ -173,35 +131,33 @@ impl MysqlDb {
         }
 
         // Lock the db
-        self.begin(false)?;
+        self.begin(false).await?;
         let modified = user_collections::table
             .select(user_collections::modified)
             .filter(user_collections::user_id.eq(user_id))
             .filter(user_collections::collection_id.eq(collection_id))
             .for_share()
-            .first(&mut *self.conn.write()?)
+            .first(&mut self.conn)
+            .await
             .optional()?;
         if let Some(modified) = modified {
             let modified = SyncTimestamp::from_i64(modified)?;
             self.session
-                .borrow_mut()
                 .coll_modified_cache
                 .insert((user_id as u32, collection_id), modified); // why does it still expect a u32 int?
         }
         // XXX: who's responsible for unlocking (removing the entry)
         self.session
-            .borrow_mut()
             .coll_locks
             .insert((user_id as u32, collection_id), CollectionLock::Read);
         Ok(())
     }
 
-    fn lock_for_write_sync(&mut self, params: params::LockCollection) -> DbResult<()> {
+    async fn lock_for_write(&mut self, params: params::LockCollection) -> DbResult<()> {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_or_create_collection_id(&params.collection)?;
+        let collection_id = self.get_or_create_collection_id(&params.collection).await?;
         if let Some(CollectionLock::Read) = self
             .session
-            .borrow()
             .coll_locks
             .get(&(user_id as u32, collection_id))
         {
@@ -211,13 +167,14 @@ impl MysqlDb {
         }
 
         // Lock the db
-        self.begin(true)?;
+        self.begin(true).await?;
         let modified = user_collections::table
             .select(user_collections::modified)
             .filter(user_collections::user_id.eq(user_id))
             .filter(user_collections::collection_id.eq(collection_id))
             .for_update()
-            .first(&mut *self.conn.write()?)
+            .first(&mut self.conn)
+            .await
             .optional()?;
         if let Some(modified) = modified {
             let modified = SyncTimestamp::from_i64(modified)?;
@@ -226,51 +183,41 @@ impl MysqlDb {
                 return Err(DbError::conflict());
             }
             self.session
-                .borrow_mut()
                 .coll_modified_cache
                 .insert((user_id as u32, collection_id), modified);
         }
         self.session
-            .borrow_mut()
             .coll_locks
             .insert((user_id as u32, collection_id), CollectionLock::Write);
         Ok(())
     }
 
-    pub(super) fn begin(&mut self, for_write: bool) -> DbResult<()> {
-        ::TransactionManager::begin_transaction(
-            &mut *self.conn.write()?,
-        )?;
-        self.session.borrow_mut().in_transaction = true;
+    pub(super) async fn begin(&mut self, for_write: bool) -> DbResult<()> {
+        ::TransactionManager::begin_transaction(&mut self.conn).await?;
+        self.session.in_transaction = true;
         if for_write {
-            self.session.borrow_mut().in_write_transaction = true;
+            self.session.in_write_transaction = true;
         }
         Ok(())
     }
 
-    async fn begin_async(&mut self, for_write: bool) -> DbResult<()> {
-        self.begin(for_write)
-    }
-
-    fn commit_sync(&mut self) -> DbResult<()> {
-        if self.session.borrow().in_transaction {
-            ::TransactionManager::commit_transaction(
-                &mut *self.conn.write()?,
-            )?;
+    async fn commit(&mut self) -> DbResult<()> {
+        if self.session.in_transaction {
+            ::TransactionManager::commit_transaction(&mut self.conn)
+                .await?;
         }
         Ok(())
     }
 
-    fn rollback_sync(&mut self) -> DbResult<()> {
-        if self.session.borrow().in_transaction {
-            ::TransactionManager::rollback_transaction(
-                &mut *self.conn.write()?,
-            )?;
+    async fn rollback(&mut self) -> DbResult<()> {
+        if self.session.in_transaction {
+            ::TransactionManager::rollback_transaction(&mut self.conn)
+                .await?;
         }
         Ok(())
     }
 
-    fn erect_tombstone(&mut self, user_id: i32) -> DbResult<()> {
+    async fn erect_tombstone(&mut self, user_id: i32) -> DbResult<()> {
         sql_query(format!(
             r#"INSERT INTO user_collections ({user_id}, {collection_id}, {modified})
                VALUES (?, ?, ?)
@@ -283,70 +230,77 @@ impl MysqlDb {
         .bind::<BigInt, _>(user_id as i64)
         .bind::<Integer, _>(TOMBSTONE)
         .bind::<BigInt, _>(self.timestamp().as_i64())
-        .execute(&mut *self.conn.write()?)?;
+        .execute(&mut self.conn)
+        .await?;
         Ok(())
     }
 
-    fn delete_storage_sync(&mut self, user_id: UserIdentifier) -> DbResult<()> {
+    async fn delete_storage(&mut self, user_id: UserIdentifier) -> DbResult<()> {
         let user_id = user_id.legacy_id as i64;
         // Delete user data.
         delete(bso::table)
             .filter(bso::user_id.eq(user_id))
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
         // Delete user collections.
         delete(user_collections::table)
             .filter(user_collections::user_id.eq(user_id))
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
         Ok(())
     }
 
     // Deleting the collection should result in:
     //  - collection does not appear in /info/collections
    //  - X-Last-Modified timestamp at the storage level changing
-    fn delete_collection_sync(
+    async fn delete_collection(
         &mut self,
         params: params::DeleteCollection,
     ) -> DbResult {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         let mut count = delete(bso::table)
             .filter(bso::user_id.eq(user_id))
             .filter(bso::collection_id.eq(&collection_id))
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
         count += delete(user_collections::table)
             .filter(user_collections::user_id.eq(user_id))
             .filter(user_collections::collection_id.eq(&collection_id))
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
         if count == 0 {
             return Err(DbError::collection_not_found());
         } else {
-            self.erect_tombstone(user_id as i32)?;
+            self.erect_tombstone(user_id as i32).await?;
         }
-        self.get_storage_timestamp_sync(params.user_id)
+        self.get_storage_timestamp(params.user_id).await
     }
 
-    pub(super) fn get_or_create_collection_id(&mut self, name: &str) -> DbResult {
+    pub(super) async fn get_or_create_collection_id(&mut self, name: &str) -> DbResult {
         if let Some(id) = self.coll_cache.get_id(name)? {
             return Ok(id);
         }
 
         diesel::insert_or_ignore_into(collections::table)
             .values(collections::name.eq(name))
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
 
         let id = collections::table
             .select(collections::id)
             .filter(collections::name.eq(name))
-            .first(&mut *self.conn.write()?)?;
+            .first(&mut self.conn)
+            .await?;
 
-        if !self.session.borrow().in_write_transaction {
+        if !self.session.in_write_transaction {
             self.coll_cache.put(id, name.to_owned())?;
         }
 
         Ok(id)
     }
 
-    pub(super) fn get_collection_id(&mut self, name: &str) -> DbResult {
+    pub(super) async fn get_collection_id(&mut self, name: &str) -> DbResult {
         if let Some(id) = self.coll_cache.get_id(name)? {
             return Ok(id);
         }
@@ -357,17 +311,18 @@ impl MysqlDb {
             WHERE name = ?",
         )
         .bind::<Text, _>(name)
-        .get_result::(&mut *self.conn.write()?)
+        .get_result::(&mut self.conn)
+        .await
         .optional()?
         .ok_or_else(DbError::collection_not_found)?
         .id;
-        if !self.session.borrow().in_write_transaction {
+        if !self.session.in_write_transaction {
             self.coll_cache.put(id, name.to_owned())?;
         }
         Ok(id)
     }
 
-    fn _get_collection_name(&mut self, id: i32) -> DbResult {
+    async fn _get_collection_name(&mut self, id: i32) -> DbResult {
         let name = if let Some(name) = self.coll_cache.get_name(id)? {
             name
         } else {
@@ -377,7 +332,8 @@ impl MysqlDb {
                 WHERE id = ?",
             )
            .bind::<Integer, _>(&id)
-            .get_result::(&mut *self.conn.write()?)
+            .get_result::(&mut self.conn)
+            .await
             .optional()?
             .ok_or_else(DbError::collection_not_found)?
             .name
@@ -385,7 +341,7 @@ impl MysqlDb {
         Ok(name)
     }
 
-    fn put_bso_sync(&mut self, bso: params::PutBso) -> DbResult {
+    async fn put_bso(&mut self, bso: params::PutBso) -> DbResult {
         /*
         if bso.payload.is_none() && bso.sortindex.is_none() && bso.ttl.is_none() {
             // XXX: go returns an error here (ErrNothingToDo), and is treated
@@ -394,15 +350,17 @@ impl MysqlDb {
         }
         */
 
-        let collection_id = self.get_or_create_collection_id(&bso.collection)?;
+        let collection_id = self.get_or_create_collection_id(&bso.collection).await?;
         let user_id: u64 = bso.user_id.legacy_id;
         let timestamp = self.timestamp().as_i64();
         if self.quota.enabled {
-            let usage = self.get_quota_usage_sync(params::GetQuotaUsage {
-                user_id: bso.user_id.clone(),
-                collection: bso.collection.clone(),
-                collection_id,
-            })?;
+            let usage = self
+                .get_quota_usage(params::GetQuotaUsage {
+                    user_id: bso.user_id.clone(),
+                    collection: bso.collection.clone(),
+                    collection_id,
+                })
+                .await?;
             if usage.total_bytes >= self.quota.size {
                 let mut tags = HashMap::default();
                 tags.insert("collection".to_owned(), bso.collection.clone());
@@ -476,13 +434,14 @@ impl MysqlDb {
         .bind::<Text, _>(payload)
         .bind::<BigInt, _>(timestamp)
         .bind::<BigInt, _>(timestamp + (i64::from(ttl) * 1000)) // remember: this is in millis
-        .execute(&mut *self.conn.write()?)?;
-        self.update_collection(user_id as u32, collection_id)
+        .execute(&mut self.conn)
+        .await?;
+        self.update_collection(user_id as u32, collection_id).await
     }
 
-    fn get_bsos_sync(&mut self, params: params::GetBsos) -> DbResult {
+    async fn get_bsos(&mut self, params: params::GetBsos) -> DbResult {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         let now = self.timestamp().as_i64();
         let mut query = bso::table
             .select((
@@ -543,7 +502,7 @@ impl MysqlDb {
             // https://github.com/mozilla-services/server-syncstorage/blob/a0f8117/syncstorage/storage/sql/__init__.py#L404
             query = query.offset(numeric_offset);
         }
-        let mut bsos = query.load::(&mut *self.conn.write()?)?;
+        let mut bsos = query.load::(&mut self.conn).await?;
 
         // XXX: an additional get_collection_timestamp is done here in
         // python to trigger potential CollectionNotFoundErrors
@@ -570,9 +529,9 @@ impl MysqlDb {
         })
     }
 
-    fn get_bso_ids_sync(&mut self, params: params::GetBsos) -> DbResult {
+    async fn get_bso_ids(&mut self, params: params::GetBsos) -> DbResult {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         let mut query = bso::table
             .select(bso::id)
             .filter(bso::user_id.eq(user_id))
@@ -613,7 +572,7 @@ impl MysqlDb {
             // https://github.com/mozilla-services/server-syncstorage/blob/a0f8117/syncstorage/storage/sql/__init__.py#L404
             query = query.offset(numeric_offset);
         }
-        let mut ids = query.load::(&mut *self.conn.write()?)?;
+        let mut ids = query.load::(&mut self.conn).await?;
 
         // XXX: an additional get_collection_timestamp is done here in
         // python to trigger potential CollectionNotFoundErrors
@@ -633,9 +592,9 @@ impl MysqlDb {
         })
     }
 
-    fn get_bso_sync(&mut self, params: params::GetBso) -> DbResult<Option<results::GetBso>> {
+    async fn get_bso(&mut self, params: params::GetBso) -> DbResult<Option<results::GetBso>> {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         Ok(bso::table
             .select((
                 bso::id,
@@ -648,38 +607,41 @@ impl MysqlDb {
             .filter(bso::collection_id.eq(&collection_id))
             .filter(bso::id.eq(&params.id))
             .filter(bso::expiry.ge(self.timestamp().as_i64()))
-            .get_result::(&mut *self.conn.write()?)
+            .get_result::(&mut self.conn)
+            .await
             .optional()?)
     }
 
-    fn delete_bso_sync(&mut self, params: params::DeleteBso) -> DbResult {
+    async fn delete_bso(&mut self, params: params::DeleteBso) -> DbResult {
         let user_id = params.user_id.legacy_id;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         let affected_rows = delete(bso::table)
             .filter(bso::user_id.eq(user_id as i64))
             .filter(bso::collection_id.eq(&collection_id))
             .filter(bso::id.eq(params.id))
             .filter(bso::expiry.gt(&self.timestamp().as_i64()))
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
         if affected_rows == 0 {
             return Err(DbError::bso_not_found());
         }
-        self.update_collection(user_id as u32, collection_id)
+        self.update_collection(user_id as u32, collection_id).await
     }
 
-    fn delete_bsos_sync(&mut self, params: params::DeleteBsos) -> DbResult {
+    async fn delete_bsos(&mut self, params: params::DeleteBsos) -> DbResult {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         delete(bso::table)
             .filter(bso::user_id.eq(user_id))
             .filter(bso::collection_id.eq(&collection_id))
             .filter(bso::id.eq_any(params.ids))
-            .execute(&mut *self.conn.write()?)?;
-        self.update_collection(user_id as u32, collection_id)
+            .execute(&mut self.conn)
+            .await?;
+        self.update_collection(user_id as u32, collection_id).await
     }
 
-    fn post_bsos_sync(&mut self, input: params::PostBsos) -> DbResult {
-        let collection_id = self.get_or_create_collection_id(&input.collection)?;
+    async fn post_bsos(&mut self, input: params::PostBsos) -> DbResult {
+        let collection_id = self.get_or_create_collection_id(&input.collection).await?;
         let mut result = results::PostBsos {
             modified: self.timestamp(),
             success: Default::default(),
@@ -688,14 +650,16 @@ impl MysqlDb {
 
         for pbso in input.bsos {
             let id = pbso.id;
-            let put_result = self.put_bso_sync(params::PutBso {
-                user_id: input.user_id.clone(),
-                collection: input.collection.clone(),
-                id: id.clone(),
-                payload: pbso.payload,
-                sortindex: pbso.sortindex,
-                ttl: pbso.ttl,
-            });
+            let put_result = self
+                .put_bso(params::PutBso {
+                    user_id: input.user_id.clone(),
+                    collection: input.collection.clone(),
+                    id: id.clone(),
+                    payload: pbso.payload,
+                    sortindex: pbso.sortindex,
+                    ttl: pbso.ttl,
+                })
+                .await;
             // XXX: python version doesn't report failures from db
             // layer.. (wouldn't db failures abort the entire transaction
             // anyway?)
@@ -707,29 +671,30 @@ impl MysqlDb {
             }
         }
-        self.update_collection(input.user_id.legacy_id as u32, collection_id)?;
+        self.update_collection(input.user_id.legacy_id as u32, collection_id)
+            .await?;
         Ok(result)
     }
 
-    fn get_storage_timestamp_sync(&mut self, user_id: UserIdentifier) -> DbResult {
+    async fn get_storage_timestamp(&mut self, user_id: UserIdentifier) -> DbResult {
         let user_id = user_id.legacy_id as i64;
         let modified = user_collections::table
             .select(max(user_collections::modified))
             .filter(user_collections::user_id.eq(user_id))
-            .first::<Option<i64>>(&mut *self.conn.write()?)?
+            .first::<Option<i64>>(&mut self.conn)
+            .await?
             .unwrap_or_default();
         SyncTimestamp::from_i64(modified).map_err(Into::into)
     }
 
-    fn get_collection_timestamp_sync(
+    async fn get_collection_timestamp(
         &mut self,
         params: params::GetCollectionTimestamp,
     ) -> DbResult {
         let user_id = params.user_id.legacy_id as u32;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         if let Some(modified) = self
             .session
-            .borrow()
             .coll_modified_cache
             .get(&(user_id, collection_id))
         {
@@ -739,29 +704,31 @@ impl MysqlDb {
             .select(user_collections::modified)
             .filter(user_collections::user_id.eq(user_id as i64))
             .filter(user_collections::collection_id.eq(collection_id))
-            .first(&mut *self.conn.write()?)
+            .first(&mut self.conn)
+            .await
             .optional()?
             .ok_or_else(DbError::collection_not_found)
     }
 
-    fn get_bso_timestamp_sync(
+    async fn get_bso_timestamp(
         &mut self,
         params: params::GetBsoTimestamp,
     ) -> DbResult {
         let user_id = params.user_id.legacy_id as i64;
-        let collection_id = self.get_collection_id(&params.collection)?;
+        let collection_id = self.get_collection_id(&params.collection).await?;
         let modified = bso::table
             .select(bso::modified)
             .filter(bso::user_id.eq(user_id))
             .filter(bso::collection_id.eq(&collection_id))
             .filter(bso::id.eq(&params.id))
-            .first::(&mut *self.conn.write()?)
+            .first::(&mut self.conn)
+            .await
             .optional()?
             .unwrap_or_default();
         SyncTimestamp::from_i64(modified).map_err(Into::into)
     }
 
-    fn get_collection_timestamps_sync(
+    async fn get_collection_timestamps(
         &mut self,
         user_id: UserIdentifier,
     ) -> DbResult {
@@ -776,7 +743,8 @@ impl MysqlDb {
         ))
         .bind::<BigInt, _>(user_id.legacy_id as i64)
         .bind::<Integer, _>(TOMBSTONE)
-        .load::(&mut *self.conn.write()?)?
+        .load::(&mut self.conn)
+        .await?
         .into_iter()
         .map(|cr| {
             SyncTimestamp::from_i64(cr.last_modified)
@@ -784,16 +752,19 @@ impl MysqlDb {
                 .map_err(Into::into)
         })
         .collect::>>()?;
-        self.map_collection_names(modifieds)
+        self.map_collection_names(modifieds).await
     }
 
-    fn check_sync(&mut self) -> DbResult {
-        diesel::sql_query("SELECT 1").execute(&mut *self.conn.write()?)?;
+    async fn check(&mut self) -> DbResult {
+        sql_query("SELECT 1").execute(&mut self.conn).await?;
         Ok(true)
     }
 
-    fn map_collection_names(&mut self, by_id: HashMap) -> DbResult> {
-        let mut names = self.load_collection_names(by_id.keys())?;
+    async fn map_collection_names(
+        &mut self,
+        by_id: HashMap,
+    ) -> DbResult> {
+        let mut names = self.load_collection_names(by_id.keys()).await?;
         by_id
             .into_iter()
             .map(|(id, value)| {
@@ -804,7 +775,7 @@ impl MysqlDb {
             .collect()
     }
 
-    fn load_collection_names<'a>(
+    async fn load_collection_names<'a>(
         &mut self,
         collection_ids: impl Iterator,
     ) -> DbResult> {
@@ -822,11 +793,12 @@ impl MysqlDb {
         let result = collections::table
             .select((collections::id, collections::name))
             .filter(collections::id.eq_any(uncached))
-            .load::<(i32, String)>(&mut *self.conn.write()?)?;
+            .load::<(i32, String)>(&mut self.conn)
+            .await?;
 
         for (id, name) in result {
             names.insert(id, name.clone());
-            if !self.session.borrow().in_write_transaction {
+            if !self.session.in_write_transaction {
                 self.coll_cache.put(id, name)?;
             }
         }
@@ -835,13 +807,13 @@ impl MysqlDb {
         Ok(names)
     }
 
-    pub(super) fn update_collection(
+    pub(super) async fn update_collection(
         &mut self,
         user_id: u32,
         collection_id: i32,
     ) -> DbResult {
         let quota = if self.quota.enabled {
-            self.calc_quota_usage_sync(user_id, collection_id)?
+            self.calc_quota_usage(user_id, collection_id).await?
         } else {
             results::GetQuotaUsage {
                 count: 0,
@@ -874,12 +846,13 @@ impl MysqlDb {
             .bind::<BigInt, _>(&timestamp)
             .bind::<BigInt, _>(&total_bytes)
             .bind::<Integer, _>(&quota.count)
-            .execute(&mut *self.conn.write()?)?;
+            .execute(&mut self.conn)
+            .await?;
         Ok(self.timestamp())
     }
 
     // Perform a lighter weight "read only" storage size check
-    fn get_storage_usage_sync(
+    async fn get_storage_usage(
         &mut self,
         user_id: UserIdentifier,
     ) -> DbResult {
         let uid = user_id.legacy_id as i64;
         let total_bytes = bso::table
             .select(sql::<Nullable<BigInt>>("SUM(LENGTH(payload))"))
             .filter(bso::user_id.eq(uid))
             .filter(bso::expiry.gt(&self.timestamp().as_i64()))
-            .get_result::<Option<i64>>(&mut *self.conn.write()?)?;
+            .get_result::<Option<i64>>(&mut self.conn)
+            .await?;
         Ok(total_bytes.unwrap_or_default() as u64)
     }
 
     // Perform a lighter weight "read only" quota storage check
-    fn get_quota_usage_sync(
+    async fn get_quota_usage(
         &mut self,
         params: params::GetQuotaUsage,
     ) -> DbResult {
         let uid = params.user_id.legacy_id as i64;
         ))
         .filter(user_collections::user_id.eq(uid))
         .filter(user_collections::collection_id.eq(params.collection_id))
-        .get_result(&mut *self.conn.write()?)
+        .get_result(&mut self.conn)
+        .await
         .optional()?
         .unwrap_or_default();
         Ok(results::GetQuotaUsage {
@@ -915,7 +890,7 @@ impl MysqlDb {
         })
     }
 
     // perform a heavier weight quota calculation
-    fn calc_quota_usage_sync(
+    async fn calc_quota_usage(
         &mut self,
         user_id: u32,
         collection_id: i32,
     ) -> DbResult {
         .filter(bso::user_id.eq(user_id as i64))
         .filter(bso::expiry.gt(self.timestamp().as_i64()))
         .filter(bso::collection_id.eq(collection_id))
-        .get_result(&mut *self.conn.write()?)
+        .get_result(&mut self.conn)
+        .await
         .optional()?
         .unwrap_or_default();
         Ok(results::GetQuotaUsage {
@@ -937,7 +913,7 @@ impl MysqlDb {
         })
     }
 
-    fn get_collection_usage_sync(
+    async fn get_collection_usage(
         &mut self,
         user_id: UserIdentifier,
     ) -> DbResult {
         .filter(bso::user_id.eq(user_id.legacy_id as i64))
         .filter(bso::expiry.gt(&self.timestamp().as_i64()))
         .group_by(bso::collection_id)
-        .load(&mut *self.conn.write()?)?
+        .load(&mut self.conn)
+        .await?
         .into_iter()
         .collect();
-        self.map_collection_names(counts)
+        self.map_collection_names(counts).await
     }
 
-    fn get_collection_counts_sync(
+    async fn get_collection_counts(
         &mut self,
         user_id: UserIdentifier,
     ) -> DbResult {
@@ -967,24 +944,25 @@ impl MysqlDb {
         .filter(bso::user_id.eq(user_id.legacy_id as i64))
         .filter(bso::expiry.gt(&self.timestamp().as_i64()))
         .group_by(bso::collection_id)
-        .load(&mut *self.conn.write()?)?
+        .load(&mut self.conn)
+        .await?
         .into_iter()
         .collect();
-        self.map_collection_names(counts)
+        self.map_collection_names(counts).await
     }
 
-    batch_db_method!(create_batch_sync, create, CreateBatch);
-    batch_db_method!(validate_batch_sync, validate, ValidateBatch);
-    batch_db_method!(append_to_batch_sync, append, AppendToBatch);
-    batch_db_method!(commit_batch_sync, commit, CommitBatch);
-    batch_db_method!(delete_batch_sync, delete, DeleteBatch);
+    batch_db_method!(create_batch, create, CreateBatch);
+    batch_db_method!(validate_batch, validate, ValidateBatch);
+    batch_db_method!(append_to_batch, append, AppendToBatch);
+    batch_db_method!(commit_batch, commit, CommitBatch);
+    batch_db_method!(delete_batch, delete, DeleteBatch);
 
-    fn get_batch_sync(&mut self, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
-        batch::get(self, params)
+    async fn get_batch(&mut self, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
+        batch::get(self, params).await
     }
 
     pub(super) fn timestamp(&self) -> SyncTimestamp {
-        self.session.borrow().timestamp
+        self.session.timestamp
     }
 }
@@ -992,86 +970,86 @@ impl Db for MysqlDb {
     type Error = DbError;
 
     fn commit(&mut self) -> DbFuture<'_, (), Self::Error> {
-        let mut db = self.clone();
-        Box::pin(self.blocking_threadpool.spawn(move || db.commit_sync()))
+        Box::pin(self.commit())
     }
 
     fn rollback(&mut self) -> DbFuture<'_, (), Self::Error> {
-        let mut db = self.clone();
-        Box::pin(self.blocking_threadpool.spawn(move || db.rollback_sync()))
+        Box::pin(self.rollback())
     }
 
     fn begin(&mut self, for_write: bool) -> DbFuture<'_, (), Self::Error> {
-        let mut db = self.clone();
-        Box::pin(async move { db.begin_async(for_write).map_err(Into::into).await })
+        Box::pin(self.begin(for_write))
    }
 
     fn check(&mut self) -> DbFuture<'_, results::Check, Self::Error> {
-        let mut db = self.clone();
-        Box::pin(self.blocking_threadpool.spawn(move || db.check_sync()))
+        Box::pin(self.check())
     }
 
-    sync_db_method!(lock_for_read, lock_for_read_sync, LockCollection);
-    sync_db_method!(lock_for_write, lock_for_write_sync, LockCollection);
-    sync_db_method!(
+    async_db_method!(lock_for_read, MysqlDb::lock_for_read, LockCollection);
+    async_db_method!(lock_for_write, MysqlDb::lock_for_write, LockCollection);
+    async_db_method!(
         get_collection_timestamps,
-        get_collection_timestamps_sync,
+        MysqlDb::get_collection_timestamps,
         GetCollectionTimestamps
     );
-    sync_db_method!(
+    async_db_method!(
         get_collection_timestamp,
-        get_collection_timestamp_sync,
+        MysqlDb::get_collection_timestamp,
         GetCollectionTimestamp
     );
-    sync_db_method!(
+    async_db_method!(
         get_collection_counts,
-        get_collection_counts_sync,
+        MysqlDb::get_collection_counts,
         GetCollectionCounts
     );
-    sync_db_method!(
+    async_db_method!(
         get_collection_usage,
-        get_collection_usage_sync,
+        MysqlDb::get_collection_usage,
         GetCollectionUsage
     );
-    sync_db_method!(
+    async_db_method!(
         get_storage_timestamp,
-        get_storage_timestamp_sync,
+        MysqlDb::get_storage_timestamp,
         GetStorageTimestamp
     );
-    sync_db_method!(get_storage_usage, get_storage_usage_sync, GetStorageUsage);
-    sync_db_method!(get_quota_usage, get_quota_usage_sync, GetQuotaUsage);
-    sync_db_method!(delete_storage, delete_storage_sync, DeleteStorage);
-    sync_db_method!(delete_collection, delete_collection_sync, DeleteCollection);
-    sync_db_method!(delete_bsos, delete_bsos_sync, DeleteBsos);
-    sync_db_method!(get_bsos, get_bsos_sync, GetBsos);
-    sync_db_method!(get_bso_ids, get_bso_ids_sync, GetBsoIds);
-    sync_db_method!(post_bsos, post_bsos_sync, PostBsos);
-    sync_db_method!(delete_bso, delete_bso_sync, DeleteBso);
-    sync_db_method!(get_bso, get_bso_sync, GetBso, Option);
-    sync_db_method!(
+    async_db_method!(
+        get_storage_usage,
+        MysqlDb::get_storage_usage,
+        GetStorageUsage
+    );
+    async_db_method!(get_quota_usage, MysqlDb::get_quota_usage, GetQuotaUsage);
+    async_db_method!(delete_storage, MysqlDb::delete_storage, DeleteStorage);
+    async_db_method!(
+        delete_collection,
+        MysqlDb::delete_collection,
+        DeleteCollection
+    );
+    async_db_method!(delete_bsos, MysqlDb::delete_bsos, DeleteBsos);
+    async_db_method!(get_bsos, MysqlDb::get_bsos, GetBsos);
+    async_db_method!(get_bso_ids, MysqlDb::get_bso_ids, GetBsoIds);
+    async_db_method!(post_bsos, MysqlDb::post_bsos, PostBsos);
+    async_db_method!(delete_bso, MysqlDb::delete_bso, DeleteBso);
+    async_db_method!(get_bso, MysqlDb::get_bso, GetBso, Option);
+    async_db_method!(
         get_bso_timestamp,
-        get_bso_timestamp_sync,
+        MysqlDb::get_bso_timestamp,
         GetBsoTimestamp,
         results::GetBsoTimestamp
     );
-    sync_db_method!(put_bso, put_bso_sync, PutBso);
-    sync_db_method!(create_batch, create_batch_sync, CreateBatch);
-    sync_db_method!(validate_batch, validate_batch_sync, ValidateBatch);
-    sync_db_method!(append_to_batch, append_to_batch_sync, AppendToBatch);
-    sync_db_method!(
+    async_db_method!(put_bso, MysqlDb::put_bso, PutBso);
+    async_db_method!(create_batch, MysqlDb::create_batch, CreateBatch);
+    async_db_method!(validate_batch, MysqlDb::validate_batch, ValidateBatch);
+    async_db_method!(append_to_batch, MysqlDb::append_to_batch, AppendToBatch);
+    async_db_method!(
         get_batch,
-        get_batch_sync,
+        MysqlDb::get_batch,
         GetBatch,
         Option
     );
-    sync_db_method!(commit_batch, commit_batch_sync, CommitBatch);
+    async_db_method!(commit_batch, MysqlDb::commit_batch, CommitBatch);
 
     fn get_collection_id(&mut self, name: String) -> DbFuture<'_, i32, Self::Error> {
-        let mut db = self.clone();
-        Box::pin(
-            self.blocking_threadpool
-                .spawn(move || db.get_collection_id(&name)),
-        )
+        Box::pin(async move { self.get_collection_id(&name).map_err(Into::into).await })
     }
 
     fn get_connection_info(&self) -> results::ConnectionInfo {
@@ -1079,39 +1057,35 @@ impl Db for MysqlDb {
     }
 
     fn create_collection(&mut self, name: String) -> DbFuture<'_, i32, Self::Error> {
-        let mut db = self.clone();
-        Box::pin(
-            self.blocking_threadpool
-                .spawn(move || db.get_or_create_collection_id(&name)),
-        )
+        Box::pin(async move { self.get_or_create_collection_id(&name).await })
    }
 
     fn update_collection(
         &mut self,
         param: params::UpdateCollection,
     ) -> DbFuture<'_, SyncTimestamp, Self::Error> {
-        let mut db = self.clone();
-        Box::pin(self.blocking_threadpool.spawn(move || {
-            db.update_collection(param.user_id.legacy_id as u32, param.collection_id)
-        }))
+        Box::pin(MysqlDb::update_collection(
+            self,
+            param.user_id.legacy_id as u32,
+            param.collection_id,
+        ))
     }
 
     fn timestamp(&self) -> SyncTimestamp {
-        self.timestamp()
+        MysqlDb::timestamp(self)
     }
 
     fn set_timestamp(&mut self, timestamp: SyncTimestamp) {
-        self.session.borrow_mut().timestamp = timestamp;
+        self.session.timestamp = timestamp;
     }
 
-    sync_db_method!(delete_batch, delete_batch_sync, DeleteBatch);
+    async_db_method!(delete_batch, MysqlDb::delete_batch, DeleteBatch);
 
     fn clear_coll_cache(&mut self) -> DbFuture<'_, (), Self::Error> {
-        let db = self.clone();
-        Box::pin(self.blocking_threadpool.spawn(move || {
-            db.coll_cache.clear();
+        Box::pin(async {
+            self.coll_cache.clear();
             Ok(())
-        }))
+        })
     }
 
     fn set_quota(&mut self, enabled: bool, limit: usize, enforced: bool) {
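With `begin`, `commit`, and `rollback` now plain async methods on `MysqlDb`, the write-lock flow reads as ordinary sequential async code. A hypothetical caller sketch under that assumption (not code from this change; it would have to live inside the syncstorage-mysql crate since these methods are private):

    async fn locked_write(db: &mut MysqlDb, params: params::LockCollection) -> DbResult<()> {
        // Opens a transaction via begin(true) and takes SELECT ... FOR UPDATE.
        db.lock_for_write(params).await?;
        // ... run writes against &mut db.conn here ...
        // commit() is a no-op unless a transaction was actually begun.
        db.commit().await
    }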
diff --git a/syncstorage-mysql/src/pool.rs b/syncstorage-mysql/src/pool.rs
index fd9eced8..a9b0f066 100644
--- a/syncstorage-mysql/src/pool.rs
+++ b/syncstorage-mysql/src/pool.rs
@@ -7,31 +7,43 @@ use std::{
     time::Duration,
 };
 
-use diesel::{
-    mysql::MysqlConnection,
-    r2d2::{ConnectionManager, Pool},
-    Connection,
+use deadpool::managed::PoolError;
+use diesel::Connection;
+use diesel_async::{
+    async_connection_wrapper::AsyncConnectionWrapper,
+    pooled_connection::{
+        deadpool::{Object, Pool},
+        AsyncDieselConnectionManager,
+    },
+    AsyncMysqlConnection,
 };
 #[cfg(debug_assertions)]
 use diesel_logger::LoggingConnection;
 use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
 use syncserver_common::{BlockingThreadpool, Metrics};
 #[cfg(debug_assertions)]
-use syncserver_db_common::test::TestTransactionCustomizer;
+use syncserver_db_common::test::test_transaction_hook;
 use syncserver_db_common::{GetPoolState, PoolState};
 use syncstorage_db_common::{Db, DbPool, STD_COLLS};
 use syncstorage_settings::{Quota, Settings};
+use tokio::task::spawn_blocking;
 
 use super::{error::DbError, models::MysqlDb, DbResult};
 
 pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!();
 
+pub(crate) type Conn = Object<AsyncMysqlConnection>;
+
 /// Run the diesel embedded migrations
 ///
 /// Mysql DDL statements implicitly commit which could disrupt MysqlPool's
 /// begin_test_transaction during tests. So this runs on its own separate conn.
+///
+/// Note that this runs as a plain diesel blocking method as diesel_async
+/// doesn't support async migrations (but we utilize its connection via its
+/// [AsyncConnectionWrapper])
 fn run_embedded_migrations(database_url: &str) -> DbResult<()> {
-    let conn = MysqlConnection::establish(database_url)?;
+    let conn = AsyncConnectionWrapper::<AsyncMysqlConnection>::establish(database_url)?;
 
     // This conn2 charade is to make mut-ness the same for both cases.
     #[cfg(debug_assertions)]
@@ -46,51 +58,58 @@ fn run_embedded_migrations(database_url: &str) -> DbResult<()> {
 #[derive(Clone)]
 pub struct MysqlDbPool {
     /// Pool of db connections
-    pool: Pool<ConnectionManager<MysqlConnection>>,
+    pool: Pool<AsyncMysqlConnection>,
     /// Thread Pool for running synchronous db calls
     /// In-memory cache of collection_ids and their names
     coll_cache: Arc<CollectionCache>,
 
     metrics: Metrics,
     quota: Quota,
-    blocking_threadpool: Arc<BlockingThreadpool>,
+    database_url: String,
 }
 
 impl MysqlDbPool {
     /// Creates a new pool of Mysql db connections.
     ///
-    /// Also initializes the Mysql db, ensuring all migrations are ran.
+    /// Doesn't initialize the db (does not run migrations).
     pub fn new(
         settings: &Settings,
         metrics: &Metrics,
-        blocking_threadpool: Arc<BlockingThreadpool>,
+        _blocking_threadpool: Arc<BlockingThreadpool>,
     ) -> DbResult {
-        run_embedded_migrations(&settings.database_url)?;
-        Self::new_without_migrations(settings, metrics, blocking_threadpool)
-    }
+        let manager =
+            AsyncDieselConnectionManager::<AsyncMysqlConnection>::new(&settings.database_url);
 
-    pub fn new_without_migrations(
-        settings: &Settings,
-        metrics: &Metrics,
-        blocking_threadpool: Arc<BlockingThreadpool>,
-    ) -> DbResult {
-        let manager = ConnectionManager::<MysqlConnection>::new(settings.database_url.clone());
-        let builder = Pool::builder()
-            .max_size(settings.database_pool_max_size)
-            .connection_timeout(Duration::from_secs(
-                settings.database_pool_connection_timeout.unwrap_or(30) as u64,
-            ))
-            .min_idle(settings.database_pool_min_idle);
+        let wait = settings
+            .database_pool_connection_timeout
+            .map(|seconds| Duration::from_secs(seconds as u64));
+        let timeouts = deadpool::managed::Timeouts {
+            wait,
+            ..Default::default()
+        };
+        let config = deadpool::managed::PoolConfig {
+            max_size: settings.database_pool_max_size as usize,
+            timeouts,
+            ..Default::default()
+        };
+        let builder = Pool::builder(manager)
+            .config(config)
+            .runtime(deadpool::Runtime::Tokio1);
 
         #[cfg(debug_assertions)]
         let builder = if settings.database_use_test_transactions {
-            builder.connection_customizer(Box::new(TestTransactionCustomizer))
+            builder.post_create(deadpool::managed::Hook::async_fn(|conn, _| {
+                Box::pin(async { test_transaction_hook(conn).await })
+            }))
         } else {
             builder
         };
 
+        let pool = builder
+            .build()
+            .map_err(|e| DbError::internal(format!("Couldn't build Db Pool: {e}")))?;
+
         Ok(Self {
-            pool: builder.build(manager)?,
+            pool,
             coll_cache: Default::default(),
             metrics: metrics.clone(),
             quota: Quota {
@@ -98,7 +117,7 @@ impl MysqlDbPool {
                 enabled: settings.enable_quota,
                 enforced: settings.enforce_quota,
             },
-            blocking_threadpool,
+            database_url: settings.database_url.clone(),
         })
     }
 
@@ -109,13 +128,21 @@ impl MysqlDbPool {
         sweeper()
     }
 
-    pub fn get_sync(&self) -> DbResult<MysqlDb> {
+    pub async fn get_mysql_db(&self) -> DbResult<MysqlDb> {
+        let conn = self.pool.get().await.map_err(|e| match e {
+            PoolError::Backend(be) => match be {
+                diesel_async::pooled_connection::PoolError::ConnectionError(ce) => ce.into(),
+                diesel_async::pooled_connection::PoolError::QueryError(dbe) => dbe.into(),
+            },
+            PoolError::Timeout(timeout_type) => DbError::pool_timeout(timeout_type),
+            _ => DbError::internal(format!("deadpool PoolError: {e}")),
+        })?;
+
         Ok(MysqlDb::new(
-            self.pool.get()?,
+            conn,
             Arc::clone(&self.coll_cache),
             &self.metrics,
             &self.quota,
-            self.blocking_threadpool.clone(),
         ))
     }
 }
@@ -131,12 +158,16 @@ fn sweeper() {}
 impl DbPool for MysqlDbPool {
     type Error = DbError;
 
-    async fn get<'a>(&'a self) -> DbResult<Box<dyn Db<Error = DbError>>> {
-        let pool = self.clone();
-        self.blocking_threadpool
-            .spawn(move || pool.get_sync())
+    async fn init(&mut self) -> Result<(), Self::Error> {
+        let database_url = self.database_url.clone();
+        spawn_blocking(move || run_embedded_migrations(&database_url))
             .await
-            .map(|db| Box::new(db) as Box<dyn Db<Error = DbError>>)
+            .map_err(|e| DbError::internal(format!("Couldn't spawn migrations: {e}")))??;
+        Ok(())
+    }
+
+    async fn get<'a>(&'a self) -> DbResult<Box<dyn Db<Error = DbError>>> {
+        Ok(Box::new(self.get_mysql_db().await?) as Box<dyn Db<Error = DbError>>)
    }
 
     fn validate_batch_id(&self, id: String) -> DbResult<()> {
@@ -158,7 +189,7 @@ impl fmt::Debug for MysqlDbPool {
 
 impl GetPoolState for MysqlDbPool {
     fn state(&self) -> PoolState {
-        self.pool.state().into()
+        self.pool.status().into()
     }
 }
diff --git a/syncstorage-mysql/src/test.rs b/syncstorage-mysql/src/test.rs
index f7a712f9..9a6a3ac6 100644
--- a/syncstorage-mysql/src/test.rs
+++ b/syncstorage-mysql/src/test.rs
@@ -4,35 +4,37 @@ use diesel::{
     // expression_methods::TextExpressionMethods, // See note below about `not_like` becoming swedish
     ExpressionMethods,
     QueryDsl,
-    RunQueryDsl,
 };
+use diesel_async::RunQueryDsl;
 use syncserver_common::{BlockingThreadpool, Metrics};
 use syncserver_settings::Settings as SyncserverSettings;
+use syncstorage_db_common::DbPool;
 use syncstorage_settings::Settings as SyncstorageSettings;
 use url::Url;
 
 use crate::{models::MysqlDb, pool::MysqlDbPool, schema::collections, DbResult};
 
-pub fn db(settings: &SyncstorageSettings) -> DbResult<MysqlDb> {
+async fn db(settings: &SyncstorageSettings) -> DbResult<MysqlDb> {
     let _ = env_logger::try_init();
     // inherit SYNC_SYNCSTORAGE__DATABASE_URL from the env
 
-    let pool = MysqlDbPool::new(
+    let mut pool = MysqlDbPool::new(
         settings,
         &Metrics::noop(),
         Arc::new(BlockingThreadpool::new(512)),
     )?;
-    pool.get_sync()
+    pool.init().await?;
+    pool.get_mysql_db().await
 }
 
-#[test]
-fn static_collection_id() -> DbResult<()> {
+#[tokio::test]
+async fn static_collection_id() -> DbResult<()> {
     let settings = SyncserverSettings::test_settings().syncstorage;
     if Url::parse(&settings.database_url).unwrap().scheme() != "mysql" {
         // Skip this test if we're not using mysql
         return Ok(());
     }
-    let mut db = db(&settings)?;
+    let mut db = db(&settings).await?;
 
     // ensure DB actually has predefined common collections
     let cols: Vec<(i32, _)> = vec![
@@ -58,7 +60,8 @@ fn static_collection_id() -> DbResult<()> {
         .filter(collections::name.ne(""))
         .filter(collections::name.ne("xxx_col2")) // from server::test
         .filter(collections::name.ne("col2")) // from older intergration tests
-        .load(&mut *db.inner.conn.write()?)?
+        .load(&mut db.conn)
+        .await?
         .into_iter()
         .collect();
     assert_eq!(results.len(), cols.len(), "mismatched columns");
@@ -67,11 +70,11 @@ fn static_collection_id() -> DbResult<()> {
     }
 
     for (id, name) in &cols {
-        let result = db.get_collection_id(name)?;
+        let result = db.get_collection_id(name).await?;
         assert_eq!(result, *id);
     }
 
-    let cid = db.get_or_create_collection_id("col1")?;
+    let cid = db.get_or_create_collection_id("col1").await?;
     assert!(cid >= 100);
     Ok(())
 }
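For reference, the deadpool construction pattern used in `MysqlDbPool::new` can be exercised standalone. A minimal sketch with a placeholder URL and sizes (the `BuildError` return type is an assumption about deadpool's builder API):

    use std::time::Duration;

    use diesel_async::{
        pooled_connection::{deadpool::Pool, AsyncDieselConnectionManager},
        AsyncMysqlConnection,
    };

    // Minimal sketch of the pool construction used above; URL and sizes are
    // placeholders rather than values from this change.
    fn build_pool() -> Result<Pool<AsyncMysqlConnection>, deadpool::managed::BuildError> {
        let manager = AsyncDieselConnectionManager::<AsyncMysqlConnection>::new(
            "mysql://root@127.0.0.1/syncstorage", // placeholder URL
        );
        let config = deadpool::managed::PoolConfig {
            max_size: 10, // mirrors the database_pool_max_size default
            timeouts: deadpool::managed::Timeouts {
                wait: Some(Duration::from_secs(30)), // database_pool_connection_timeout
                ..Default::default()
            },
            ..Default::default()
        };
        Pool::builder(manager)
            .config(config)
            .runtime(deadpool::Runtime::Tokio1)
            .build()
    }

Note that deadpool has no equivalent of r2d2's `min_idle`, which is why the setting is dropped below.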
diff --git a/syncstorage-settings/src/lib.rs b/syncstorage-settings/src/lib.rs
index 8a9d57e4..1343f54e 100644
--- a/syncstorage-settings/src/lib.rs
+++ b/syncstorage-settings/src/lib.rs
@@ -72,8 +72,6 @@ impl From<&Settings> for Deadman {
 pub struct Settings {
     pub database_url: String,
     pub database_pool_max_size: u32,
-    // NOTE: Not supported by deadpool!
-    pub database_pool_min_idle: Option<u32>,
     /// Pool timeout when waiting for a slot to become available, in seconds
     pub database_pool_connection_timeout: Option<u32>,
     /// Max age a given connection should live, in seconds
@@ -116,7 +114,6 @@ impl Default for Settings {
         Settings {
             database_url: "mysql://root@127.0.0.1/syncstorage".to_string(),
             database_pool_max_size: 10,
-            database_pool_min_idle: None,
             database_pool_connection_lifespan: None,
             database_pool_connection_max_idle: None,
             database_pool_sweeper_task_interval: 30,
diff --git a/tokenserver-db-common/Cargo.toml b/tokenserver-db-common/Cargo.toml
index d760d17d..0e619ffd 100644
--- a/tokenserver-db-common/Cargo.toml
+++ b/tokenserver-db-common/Cargo.toml
@@ -10,6 +10,7 @@ async-trait.workspace = true
 backtrace.workspace = true
 deadpool.workspace = true
 diesel.workspace = true
+diesel-async.workspace = true
 diesel_migrations.workspace = true
 http.workspace = true
 serde.workspace = true
diff --git a/tokenserver-db-common/src/error.rs b/tokenserver-db-common/src/error.rs
index da50e9d9..5e0f141b 100644
--- a/tokenserver-db-common/src/error.rs
+++ b/tokenserver-db-common/src/error.rs
@@ -1,6 +1,7 @@
 use std::fmt;
 
 use backtrace::Backtrace;
+use deadpool::managed::PoolError;
 use http::StatusCode;
 use syncserver_common::{from_error, impl_fmt_display, InternalError, ReportableError};
 use syncserver_db_common::error::SqlError;
@@ -118,11 +119,6 @@ from_error!(
     DbError,
     |error: diesel::result::ConnectionError| DbError::from(DbErrorKind::Sql(SqlError::from(error)))
 );
-from_error!(
-    diesel::r2d2::PoolError,
-    DbError,
-    |error: diesel::r2d2::PoolError| DbError::from(DbErrorKind::Sql(SqlError::from(error)))
-);
 from_error!(
     diesel_migrations::MigrationError,
     DbError,
@@ -135,3 +131,16 @@ from_error!(
     DbError,
     |error: std::boxed::Box| DbError::internal_error(error.to_string())
 );
+
+impl From<PoolError<diesel_async::pooled_connection::PoolError>> for DbError {
+    fn from(pe: PoolError<diesel_async::pooled_connection::PoolError>) -> DbError {
+        match pe {
+            PoolError::Backend(be) => match be {
+                diesel_async::pooled_connection::PoolError::ConnectionError(ce) => ce.into(),
+                diesel_async::pooled_connection::PoolError::QueryError(dbe) => dbe.into(),
+            },
+            PoolError::Timeout(timeout_type) => DbError::pool_timeout(timeout_type),
+            _ => DbError::internal(format!("deadpool PoolError: {pe}")),
+        }
+    }
+}
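With this blanket `From` impl in scope, `?` on `pool.get().await` performs the entire mapping, including the timeout case, which is what lets the tokenserver pools below drop their hand-rolled `map_err` blocks. A hypothetical helper showing the effect (`DbError` here is tokenserver-db-common's error type; the MySQL connection type is an assumption for the example):

    use deadpool::managed::{Object, Pool};
    use diesel_async::pooled_connection::AsyncDieselConnectionManager;
    use diesel_async::AsyncMysqlConnection;
    use tokenserver_db_common::error::DbError;

    type Manager = AsyncDieselConnectionManager<AsyncMysqlConnection>;

    // Hypothetical helper: the `?` operator now routes deadpool's PoolError
    // (backend errors and timeouts alike) through the From impl above.
    async fn checkout(pool: &Pool<Manager>) -> Result<Object<Manager>, DbError> {
        Ok(pool.get().await?)
    }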
diff --git a/tokenserver-db-postgres/Cargo.toml b/tokenserver-db-postgres/Cargo.toml
index 604e9e84..b6325b24 100644
--- a/tokenserver-db-postgres/Cargo.toml
+++ b/tokenserver-db-postgres/Cargo.toml
@@ -8,14 +8,14 @@ license.workspace = true
 [dependencies]
 async-trait.workspace = true
 deadpool.workspace = true
-diesel = { workspace = true, features = ["postgres", "r2d2"] }
+diesel = { workspace = true, features = ["postgres"] }
 diesel-async = { workspace = true, features = ["postgres"] }
-diesel_logger.workspace = true
-diesel_migrations = { workspace = true, features = ["postgres"] }
+diesel_logger.workspace = true
+diesel_migrations.workspace = true
 tokio = { workspace = true, features = ["macros", "sync"] }
 
 syncserver-common = { path = "../syncserver-common" }
 syncserver-db-common = { path = "../syncserver-db-common" }
 tokenserver-common = { path = "../tokenserver-common" }
 tokenserver-db-common = { path = "../tokenserver-db-common" }
-tokenserver-settings = { path = "../tokenserver-settings" }
\ No newline at end of file
+tokenserver-settings = { path = "../tokenserver-settings" }
diff --git a/tokenserver-db-postgres/src/pool.rs b/tokenserver-db-postgres/src/pool.rs
index c1907e55..8cb8cc3e 100644
--- a/tokenserver-db-postgres/src/pool.rs
+++ b/tokenserver-db-postgres/src/pool.rs
@@ -1,7 +1,6 @@
 use std::time::Duration;
 
 use async_trait::async_trait;
-use deadpool::managed::PoolError;
 use diesel::Connection;
 use diesel_async::{
     async_connection_wrapper::AsyncConnectionWrapper,
@@ -112,21 +111,8 @@ impl TokenserverPgPool {
     }
 
     async fn get_tokenserver_db(&self) -> Result {
-        let conn = self.inner.get().await.map_err(|e| match e {
-            PoolError::Backend(backend_err) => match backend_err {
-                diesel_async::pooled_connection::PoolError::ConnectionError(conn_err) => {
-                    conn_err.into()
-                }
-                diesel_async::pooled_connection::PoolError::QueryError(query_err) => {
-                    query_err.into()
-                }
-            },
-            PoolError::Timeout(timeout_type) => DbError::pool_timeout(timeout_type),
-            _ => DbError::internal(format!("Deadpool PoolError: {e}")),
-        })?;
-
         Ok(TokenserverPgDb::new(
-            conn,
+            self.inner.get().await?,
             &self.metrics,
             self.service_id,
             self.spanner_node_id,
diff --git a/tokenserver-db/Cargo.toml b/tokenserver-db/Cargo.toml
index f4c8fbe6..3ac6c165 100644
--- a/tokenserver-db/Cargo.toml
+++ b/tokenserver-db/Cargo.toml
@@ -8,12 +8,11 @@ edition.workspace = true
 [dependencies]
 async-trait.workspace = true
 http.workspace = true
-
-deadpool = { workspace = true }
-diesel = { workspace = true }
-diesel-async = { workspace = true }
-diesel_logger = { workspace = true }
-diesel_migrations = { workspace = true, features = ["mysql"] }
+deadpool.workspace = true
+diesel.workspace = true
+diesel-async.workspace = true
+diesel_logger.workspace = true
+diesel_migrations.workspace = true
 
 syncserver-common = { path = "../syncserver-common" }
 syncserver-db-common = { path = "../syncserver-db-common" }
 tokenserver-common = { path = "../tokenserver-common" }
diff --git a/tokenserver-db/src/pool.rs b/tokenserver-db/src/pool.rs
index 1825f262..71b52e4a 100644
--- a/tokenserver-db/src/pool.rs
+++ b/tokenserver-db/src/pool.rs
@@ -1,7 +1,6 @@
 use std::time::Duration;
 
 use async_trait::async_trait;
-use deadpool::managed::PoolError;
 use diesel::Connection;
 use diesel_async::{
     async_connection_wrapper::AsyncConnectionWrapper,
@@ -110,17 +109,8 @@ impl TokenserverPool {
     }
 
     pub async fn get_tokenserver_db(&self) -> Result {
-        let conn = self.inner.get().await.map_err(|e| match e {
-            PoolError::Backend(be) => match be {
-                diesel_async::pooled_connection::PoolError::ConnectionError(ce) => ce.into(),
-                diesel_async::pooled_connection::PoolError::QueryError(dbe) => dbe.into(),
-            },
-            PoolError::Timeout(timeout_type) => DbError::pool_timeout(timeout_type),
-            _ => DbError::internal(format!("deadpool PoolError: {e}")),
-        })?;
-
         Ok(TokenserverDb::new(
-            conn,
+            self.inner.get().await?,
             &self.metrics,
             self.service_id,
             self.spanner_node_id,
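A closing note on migrations: diesel-async has no async `MigrationHarness`, so the pattern used throughout this change wraps the async connection type in `AsyncConnectionWrapper` and pushes the blocking work onto a thread. A condensed sketch combining `run_embedded_migrations` and `DbPool::init` from pool.rs above (`MIGRATIONS` and `DbError` are the items defined there; the explicit `map_err` on `establish` is an assumption to keep the sketch self-contained):

    use diesel::Connection;
    use diesel_async::{async_connection_wrapper::AsyncConnectionWrapper, AsyncMysqlConnection};
    use diesel_migrations::MigrationHarness;
    use tokio::task::spawn_blocking;

    async fn init(database_url: String) -> Result<(), DbError> {
        spawn_blocking(move || -> Result<(), DbError> {
            // AsyncConnectionWrapper exposes a blocking diesel Connection
            // backed by the async connection type.
            let mut conn = AsyncConnectionWrapper::<AsyncMysqlConnection>::establish(&database_url)
                .map_err(|e| DbError::internal(format!("Couldn't connect: {e}")))?;
            conn.run_pending_migrations(MIGRATIONS)
                .map_err(|e| DbError::internal(format!("Couldn't run migrations: {e}")))?;
            Ok(())
        })
        .await
        .map_err(|e| DbError::internal(format!("Couldn't spawn migrations: {e}")))?
    }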