Merge pull request #1783 from mozilla-services/refactor/mut-self-STOR-327
Some checks failed
Glean probe-scraper / glean-probe-scraper (push) Has been cancelled

refactor: switch tokenserver Db methods to &mut self
This commit is contained in:
Philip Jenvey 2025-09-02 15:28:31 -07:00 committed by GitHub
commit 41aa81b3fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 351 additions and 287 deletions

4
Cargo.lock generated
View File

@ -3625,9 +3625,9 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
version = "0.3.19"
version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5"
dependencies = [
"tracing-core",
]

View File

@ -43,8 +43,8 @@ macro_rules! sync_db_method {
sync_db_method!($name, $sync_name, $type, results::$type);
};
($name:ident, $sync_name:ident, $type:ident, $result:ty) => {
fn $name(&self, params: params::$type) -> DbFuture<'_, $result, DbError> {
let db = self.clone();
fn $name(&mut self, params: params::$type) -> DbFuture<'_, $result, DbError> {
let mut db = self.clone();
Box::pin(
self.blocking_threadpool
.spawn(move || db.$sync_name(params)),

View File

@ -207,7 +207,7 @@ impl FromRequest for TokenserverRequest {
// metrics. Use "none" as a placeholder for "device" with OAuth requests.
let hashed_device_id = hash_device_id(&hashed_fxa_uid, fxa_metrics_hash_secret);
let DbWrapper(db) = DbWrapper::extract(&req).await?;
let DbWrapper(mut db) = DbWrapper::extract(&req).await?;
let service_id = {
let path = req.match_info();

View File

@ -141,7 +141,7 @@ where
async fn update_user(
req: &TokenserverRequest,
db: Box<dyn Db>,
mut db: Box<dyn Db>,
) -> Result<UserUpdates, TokenserverError> {
let keys_changed_at = match (req.auth_data.keys_changed_at, req.user.keys_changed_at) {
// If the keys_changed_at in the request is larger than that stored on the user record,
@ -269,7 +269,7 @@ async fn update_user(
}
}
pub async fn heartbeat(DbWrapper(db): DbWrapper) -> Result<HttpResponse, Error> {
pub async fn heartbeat(DbWrapper(mut db): DbWrapper) -> Result<HttpResponse, Error> {
let mut checklist = HashMap::new();
checklist.insert(
"version".to_owned(),

View File

@ -83,7 +83,7 @@ impl ServerState {
// unlikely for this query to fail outside of network failures or other random errors
db_pool.service_id = db_pool
.get_sync()
.and_then(|db| {
.and_then(|mut db| {
db.get_service_id_sync(params::GetServiceId {
service: "sync-1.5".to_owned(),
})

View File

@ -41,7 +41,7 @@ pub async fn get_collections(
state: Data<ServerState>,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request.clone(), |db| async move {
.transaction_http(request.clone(), |mut db| async move {
meta.emit_api_metric("request.get_collections");
if state.glean_enabled {
// Values below are passed to the Glean logic to emit metrics.
@ -82,7 +82,7 @@ pub async fn get_collection_counts(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
meta.emit_api_metric("request.get_collection_counts");
let result = db.get_collection_counts(meta.user_id).await?;
@ -99,7 +99,7 @@ pub async fn get_collection_usage(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
meta.emit_api_metric("request.get_collection_usage");
let usage: HashMap<_, _> = db
.get_collection_usage(meta.user_id)
@ -121,7 +121,7 @@ pub async fn get_quota(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
meta.emit_api_metric("request.get_quota");
let usage = db.get_storage_usage(meta.user_id).await?;
Ok(HttpResponse::Ok().json(vec![Some(usage as f64 / ONE_KB), None]))
@ -135,7 +135,7 @@ pub async fn delete_all(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
meta.emit_api_metric("request.delete_all");
Ok(HttpResponse::Ok().json(db.delete_storage(meta.user_id).await?))
})
@ -148,7 +148,7 @@ pub async fn delete_collection(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
let delete_bsos = !coll.query.ids.is_empty();
let timestamp = if delete_bsos {
coll.emit_api_metric("request.delete_bsos");
@ -193,7 +193,7 @@ pub async fn get_collection(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
coll.emit_api_metric("request.get_collection");
let params = params::GetBsos {
user_id: coll.user_id.clone(),
@ -221,7 +221,7 @@ pub async fn get_collection(
async fn finish_get_collection<T>(
coll: &CollectionRequest,
db: Box<dyn Db<Error = DbError>>,
mut db: Box<dyn Db<Error = DbError>>,
result: Result<Paginated<T>, DbError>,
) -> Result<HttpResponse, DbError>
where
@ -275,7 +275,7 @@ pub async fn post_collection(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
coll.emit_api_metric("request.post_collection");
trace!("Collection: Post");
@ -312,7 +312,7 @@ pub async fn post_collection(
// the entire accumulated batch if the `commit` flag is set.
pub async fn post_collection_batch(
coll: CollectionPostRequest,
db: Box<dyn Db<Error = DbError>>,
mut db: Box<dyn Db<Error = DbError>>,
) -> Result<HttpResponse, ApiError> {
coll.emit_api_metric("request.post_collection_batch");
trace!("Batch: Post collection batch");
@ -488,7 +488,7 @@ pub async fn delete_bso(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
bso_req.emit_api_metric("request.delete_bso");
let result = db
.delete_bso(params::DeleteBso {
@ -508,7 +508,7 @@ pub async fn get_bso(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
bso_req.emit_api_metric("request.get_bso");
let result = db
.get_bso(params::GetBso {
@ -532,7 +532,7 @@ pub async fn put_bso(
request: HttpRequest,
) -> Result<HttpResponse, ApiError> {
db_pool
.transaction_http(request, |db| async move {
.transaction_http(request, |mut db| async move {
bso_req.emit_api_metric("request.put_bso");
let result = db
.put_bso(params::PutBso {
@ -574,7 +574,7 @@ pub async fn heartbeat(hb: HeartbeatRequest) -> Result<HttpResponse, ApiError> {
"version".to_owned(),
Value::String(env!("CARGO_PKG_VERSION").to_owned()),
);
let db = hb.db_pool.get().await?;
let mut db = hb.db_pool.get().await?;
checklist.insert("quota".to_owned(), serde_json::to_value(hb.quota)?);

View File

@ -53,8 +53,8 @@ impl DbTransactionPool {
F: Future<Output = Result<R, ApiError>>,
{
// Get connection from pool
let db = self.pool.get().await?;
let db2 = db.clone();
let mut db = self.pool.get().await?;
let mut db2 = db.clone();
// Lock for transaction
let result = match (self.get_lock_collection(), self.is_read) {
@ -98,7 +98,7 @@ impl DbTransactionPool {
A: FnOnce(Box<dyn Db<Error = DbError>>) -> F,
F: Future<Output = Result<R, ApiError>> + 'a,
{
let (resp, db) = self.transaction_internal(request, action).await?;
let (resp, mut db) = self.transaction_internal(request, action).await?;
// No further processing before commit is possible
db.commit().await?;
@ -117,7 +117,7 @@ impl DbTransactionPool {
F: Future<Output = Result<HttpResponse, ApiError>> + 'a,
{
let mreq = request.clone();
let check_precondition = move |db: Box<dyn Db<Error = DbError>>| {
let check_precondition = move |mut db: Box<dyn Db<Error = DbError>>| {
async move {
// set the extra information for all requests so we capture default err handlers.
set_extra(&mreq, db.get_connection_info());
@ -168,7 +168,7 @@ impl DbTransactionPool {
}
};
let (resp, db) = self
let (resp, mut db) = self
.transaction_internal(request.clone(), check_precondition)
.await?;
// match on error and return a composed HttpResponse (so we can use the tags?)

View File

@ -67,116 +67,123 @@ impl<E> Clone for Box<dyn DbPool<Error = E>> {
pub trait Db: Debug {
type Error: DbErrorIntrospect + 'static;
fn lock_for_read(&self, params: params::LockCollection) -> DbFuture<'_, (), Self::Error>;
fn lock_for_read(&mut self, params: params::LockCollection) -> DbFuture<'_, (), Self::Error>;
fn lock_for_write(&self, params: params::LockCollection) -> DbFuture<'_, (), Self::Error>;
fn lock_for_write(&mut self, params: params::LockCollection) -> DbFuture<'_, (), Self::Error>;
fn begin(&self, for_write: bool) -> DbFuture<'_, (), Self::Error>;
fn begin(&mut self, for_write: bool) -> DbFuture<'_, (), Self::Error>;
fn commit(&self) -> DbFuture<'_, (), Self::Error>;
fn commit(&mut self) -> DbFuture<'_, (), Self::Error>;
fn rollback(&self) -> DbFuture<'_, (), Self::Error>;
fn rollback(&mut self) -> DbFuture<'_, (), Self::Error>;
fn get_collection_timestamps(
&self,
&mut self,
params: params::GetCollectionTimestamps,
) -> DbFuture<'_, results::GetCollectionTimestamps, Self::Error>;
fn get_collection_timestamp(
&self,
&mut self,
params: params::GetCollectionTimestamp,
) -> DbFuture<'_, results::GetCollectionTimestamp, Self::Error>;
fn get_collection_counts(
&self,
&mut self,
params: params::GetCollectionCounts,
) -> DbFuture<'_, results::GetCollectionCounts, Self::Error>;
fn get_collection_usage(
&self,
&mut self,
params: params::GetCollectionUsage,
) -> DbFuture<'_, results::GetCollectionUsage, Self::Error>;
fn get_storage_timestamp(
&self,
&mut self,
params: params::GetStorageTimestamp,
) -> DbFuture<'_, results::GetStorageTimestamp, Self::Error>;
fn get_storage_usage(
&self,
&mut self,
params: params::GetStorageUsage,
) -> DbFuture<'_, results::GetStorageUsage, Self::Error>;
fn get_quota_usage(
&self,
&mut self,
params: params::GetQuotaUsage,
) -> DbFuture<'_, results::GetQuotaUsage, Self::Error>;
fn delete_storage(
&self,
&mut self,
params: params::DeleteStorage,
) -> DbFuture<'_, results::DeleteStorage, Self::Error>;
fn delete_collection(
&self,
&mut self,
params: params::DeleteCollection,
) -> DbFuture<'_, results::DeleteCollection, Self::Error>;
fn delete_bsos(
&self,
&mut self,
params: params::DeleteBsos,
) -> DbFuture<'_, results::DeleteBsos, Self::Error>;
fn get_bsos(&self, params: params::GetBsos) -> DbFuture<'_, results::GetBsos, Self::Error>;
fn get_bsos(&mut self, params: params::GetBsos) -> DbFuture<'_, results::GetBsos, Self::Error>;
fn get_bso_ids(&self, params: params::GetBsos)
-> DbFuture<'_, results::GetBsoIds, Self::Error>;
fn get_bso_ids(
&mut self,
params: params::GetBsos,
) -> DbFuture<'_, results::GetBsoIds, Self::Error>;
fn post_bsos(&self, params: params::PostBsos) -> DbFuture<'_, results::PostBsos, Self::Error>;
fn post_bsos(
&mut self,
params: params::PostBsos,
) -> DbFuture<'_, results::PostBsos, Self::Error>;
fn delete_bso(
&self,
&mut self,
params: params::DeleteBso,
) -> DbFuture<'_, results::DeleteBso, Self::Error>;
fn get_bso(&self, params: params::GetBso)
-> DbFuture<'_, Option<results::GetBso>, Self::Error>;
fn get_bso(
&mut self,
params: params::GetBso,
) -> DbFuture<'_, Option<results::GetBso>, Self::Error>;
fn get_bso_timestamp(
&self,
&mut self,
params: params::GetBsoTimestamp,
) -> DbFuture<'_, results::GetBsoTimestamp, Self::Error>;
fn put_bso(&self, params: params::PutBso) -> DbFuture<'_, results::PutBso, Self::Error>;
fn put_bso(&mut self, params: params::PutBso) -> DbFuture<'_, results::PutBso, Self::Error>;
fn create_batch(
&self,
&mut self,
params: params::CreateBatch,
) -> DbFuture<'_, results::CreateBatch, Self::Error>;
fn validate_batch(
&self,
&mut self,
params: params::ValidateBatch,
) -> DbFuture<'_, results::ValidateBatch, Self::Error>;
fn append_to_batch(
&self,
&mut self,
params: params::AppendToBatch,
) -> DbFuture<'_, results::AppendToBatch, Self::Error>;
fn get_batch(
&self,
&mut self,
params: params::GetBatch,
) -> DbFuture<'_, Option<results::GetBatch>, Self::Error>;
fn commit_batch(
&self,
&mut self,
params: params::CommitBatch,
) -> DbFuture<'_, results::CommitBatch, Self::Error>;
fn box_clone(&self) -> Box<dyn Db<Error = Self::Error>>;
fn check(&self) -> DbFuture<'_, results::Check, Self::Error>;
fn check(&mut self) -> DbFuture<'_, results::Check, Self::Error>;
fn get_connection_info(&self) -> results::ConnectionInfo;
@ -184,7 +191,7 @@ pub trait Db: Debug {
///
/// Modeled on the Python `get_resource_timestamp` function.
fn extract_resource(
&self,
&mut self,
user_id: UserIdentifier,
collection: Option<String>,
bso: Option<String>,
@ -230,22 +237,22 @@ pub trait Db: Debug {
// Internal methods used by the db tests
fn get_collection_id(&self, name: String) -> DbFuture<'_, i32, Self::Error>;
fn get_collection_id(&mut self, name: String) -> DbFuture<'_, i32, Self::Error>;
fn create_collection(&self, name: String) -> DbFuture<'_, i32, Self::Error>;
fn create_collection(&mut self, name: String) -> DbFuture<'_, i32, Self::Error>;
fn update_collection(
&self,
&mut self,
params: params::UpdateCollection,
) -> DbFuture<'_, SyncTimestamp, Self::Error>;
fn timestamp(&self) -> SyncTimestamp;
fn set_timestamp(&self, timestamp: SyncTimestamp);
fn set_timestamp(&mut self, timestamp: SyncTimestamp);
fn delete_batch(&self, params: params::DeleteBatch) -> DbFuture<'_, (), Self::Error>;
fn delete_batch(&mut self, params: params::DeleteBatch) -> DbFuture<'_, (), Self::Error>;
fn clear_coll_cache(&self) -> DbFuture<'_, (), Self::Error>;
fn clear_coll_cache(&mut self) -> DbFuture<'_, (), Self::Error>;
fn set_quota(&mut self, enabled: bool, limit: usize, enforce: bool);
}

View File

@ -55,7 +55,7 @@ macro_rules! mock_db_method {
mock_db_method!($name, $type, results::$type);
};
($name:ident, $type:ident, $result:ty) => {
fn $name(&self, _params: params::$type) -> DbFuture<'_, $result> {
fn $name(&mut self, _params: params::$type) -> DbFuture<'_, $result> {
let result: $result = Default::default();
Box::pin(future::ok(result))
}
@ -65,15 +65,15 @@ macro_rules! mock_db_method {
impl Db for MockDb {
type Error = DbError;
fn commit(&self) -> DbFuture<'_, ()> {
fn commit(&mut self) -> DbFuture<'_, ()> {
Box::pin(future::ok(()))
}
fn rollback(&self) -> DbFuture<'_, ()> {
fn rollback(&mut self) -> DbFuture<'_, ()> {
Box::pin(future::ok(()))
}
fn begin(&self, _for_write: bool) -> DbFuture<'_, ()> {
fn begin(&mut self, _for_write: bool) -> DbFuture<'_, ()> {
Box::pin(future::ok(()))
}
@ -81,7 +81,7 @@ impl Db for MockDb {
Box::new(self.clone())
}
fn check(&self) -> DbFuture<'_, results::Check> {
fn check(&mut self) -> DbFuture<'_, results::Check> {
Box::pin(future::ok(true))
}
@ -122,11 +122,11 @@ impl Db for MockDb {
Default::default()
}
fn set_timestamp(&self, _: SyncTimestamp) {}
fn set_timestamp(&mut self, _: SyncTimestamp) {}
mock_db_method!(delete_batch, DeleteBatch);
fn clear_coll_cache(&self) -> DbFuture<'_, ()> {
fn clear_coll_cache(&mut self) -> DbFuture<'_, ()> {
Box::pin(future::ok(()))
}

View File

@ -48,7 +48,7 @@ fn gb(user_id: u32, coll: &str, id: String) -> params::GetBatch {
#[tokio::test]
async fn create_delete() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = 1;
let coll = "clients";
@ -71,7 +71,7 @@ async fn create_delete() -> Result<(), DbError> {
#[tokio::test]
async fn expiry() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = 1;
let coll = "clients";
@ -95,7 +95,7 @@ async fn expiry() -> Result<(), DbError> {
#[tokio::test]
async fn update() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = 1;
let coll = "clients";
@ -119,7 +119,7 @@ async fn update() -> Result<(), DbError> {
#[tokio::test]
async fn append_commit() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = 1;
let coll = "clients";
@ -172,7 +172,7 @@ async fn quota_test_create_batch() -> Result<(), DbError> {
settings.limits.max_quota_limit = limit;
let pool = db_pool(Some(settings.clone())).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = 1;
let coll = "clients";
@ -214,7 +214,7 @@ async fn quota_test_append_batch() -> Result<(), DbError> {
settings.limits.max_quota_limit = limit;
let pool = db_pool(Some(settings.clone())).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = 1;
let coll = "clients";
@ -250,7 +250,7 @@ async fn quota_test_append_batch() -> Result<(), DbError> {
async fn test_append_async_w_null() -> Result<(), DbError> {
let settings = Settings::test_settings().syncstorage;
let pool = db_pool(Some(settings)).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
// Remember: TTL is seconds to live, not an expiry date
let ttl_0 = 86_400;
let ttl_1 = 86_400;

View File

@ -21,7 +21,7 @@ lazy_static! {
#[tokio::test]
async fn bso_successfully_updates_single_values() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -62,7 +62,7 @@ async fn bso_successfully_updates_single_values() -> Result<(), DbError> {
#[tokio::test]
async fn bso_modified_not_changed_on_ttl_touch() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -85,7 +85,7 @@ async fn bso_modified_not_changed_on_ttl_touch() -> Result<(), DbError> {
#[tokio::test]
async fn put_bso_updates() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -108,7 +108,7 @@ async fn put_bso_updates() -> Result<(), DbError> {
#[tokio::test]
async fn get_bsos_limit_offset() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -122,7 +122,7 @@ async fn get_bsos_limit_offset() -> Result<(), DbError> {
Some(i),
Some(DEFAULT_BSO_TTL),
);
with_delta!(&db, i64::from(i) * 10, { db.put_bso(bso).await })?;
with_delta!(&mut db, i64::from(i) * 10, { db.put_bso(bso).await })?;
}
let bsos = db
@ -229,7 +229,7 @@ async fn get_bsos_limit_offset() -> Result<(), DbError> {
#[tokio::test]
async fn get_bsos_newer() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -244,7 +244,7 @@ async fn get_bsos_newer() -> Result<(), DbError> {
Some(1),
Some(DEFAULT_BSO_TTL),
);
with_delta!(&db, -i * 10, { db.put_bso(pbso).await })?;
with_delta!(&mut db, -i * 10, { db.put_bso(pbso).await })?;
}
let bsos = db
@ -314,7 +314,7 @@ async fn get_bsos_newer() -> Result<(), DbError> {
#[tokio::test]
async fn get_bsos_sort() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -328,7 +328,7 @@ async fn get_bsos_sort() -> Result<(), DbError> {
Some(*sortindex),
Some(DEFAULT_BSO_TTL),
);
with_delta!(&db, -(revi as i64) * 10, { db.put_bso(pbso).await })?;
with_delta!(&mut db, -(revi as i64) * 10, { db.put_bso(pbso).await })?;
}
let bsos = db
@ -387,7 +387,7 @@ async fn get_bsos_sort() -> Result<(), DbError> {
#[tokio::test]
async fn delete_bsos_in_correct_collection() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let payload = "data";
@ -404,14 +404,14 @@ async fn delete_bsos_in_correct_collection() -> Result<(), DbError> {
#[tokio::test]
async fn get_storage_timestamp() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
db.create_collection("NewCollection1".to_owned()).await?;
let col2 = db.create_collection("NewCollection2".to_owned()).await?;
db.create_collection("NewCollection3".to_owned()).await?;
with_delta!(&db, 100_000, {
with_delta!(&mut db, 100_000, {
db.update_collection(params::UpdateCollection {
user_id: hid(uid),
collection_id: col2,
@ -427,7 +427,7 @@ async fn get_storage_timestamp() -> Result<(), DbError> {
#[tokio::test]
async fn get_collection_id() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
db.get_collection_id("bookmarks".to_owned()).await?;
Ok(())
}
@ -435,7 +435,7 @@ async fn get_collection_id() -> Result<(), DbError> {
#[tokio::test]
async fn create_collection() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let name = "NewCollection";
let cid = db.create_collection(name.to_owned()).await?;
@ -448,7 +448,7 @@ async fn create_collection() -> Result<(), DbError> {
#[tokio::test]
async fn update_collection() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let collection = "test".to_owned();
let cid = db.create_collection(collection.clone()).await?;
@ -464,7 +464,7 @@ async fn update_collection() -> Result<(), DbError> {
#[tokio::test]
async fn delete_collection() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "NewCollection";
@ -500,7 +500,7 @@ async fn delete_collection() -> Result<(), DbError> {
#[tokio::test]
async fn delete_collection_tombstone() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "test";
@ -560,7 +560,7 @@ async fn delete_collection_tombstone() -> Result<(), DbError> {
#[tokio::test]
async fn get_collection_timestamps() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "test".to_owned();
@ -588,7 +588,7 @@ async fn get_collection_timestamps() -> Result<(), DbError> {
#[tokio::test]
async fn get_collection_timestamps_tombstone() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "test".to_owned();
@ -613,7 +613,7 @@ async fn get_collection_timestamps_tombstone() -> Result<(), DbError> {
#[tokio::test]
async fn get_collection_usage() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let mut expected = HashMap::new();
@ -707,7 +707,7 @@ async fn test_quota() -> Result<(), DbError> {
#[tokio::test]
async fn get_collection_counts() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let mut expected = HashMap::new();
@ -730,7 +730,7 @@ async fn get_collection_counts() -> Result<(), DbError> {
#[tokio::test]
async fn put_bso() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "NewCollection";
@ -750,7 +750,7 @@ async fn put_bso() -> Result<(), DbError> {
assert_eq!(bso.sortindex, Some(1));
let bso2 = pbso(uid, coll, bid, Some("bar"), Some(2), Some(DEFAULT_BSO_TTL));
with_delta!(&db, 19, {
with_delta!(&mut db, 19, {
db.put_bso(bso2).await?;
let ts = db
.get_collection_timestamp(params::GetCollectionTimestamp {
@ -770,7 +770,7 @@ async fn put_bso() -> Result<(), DbError> {
#[tokio::test]
async fn post_bsos() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "NewCollection";
@ -841,7 +841,7 @@ async fn post_bsos() -> Result<(), DbError> {
#[tokio::test]
async fn get_bso() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -862,7 +862,7 @@ async fn get_bso() -> Result<(), DbError> {
#[tokio::test]
async fn get_bsos() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -877,7 +877,7 @@ async fn get_bsos() -> Result<(), DbError> {
Some(*sortindex),
None,
);
with_delta!(&db, i as i64 * 10, { db.put_bso(bso).await })?;
with_delta!(&mut db, i as i64 * 10, { db.put_bso(bso).await })?;
}
let ids = db
@ -933,7 +933,7 @@ async fn get_bsos() -> Result<(), DbError> {
#[tokio::test]
async fn get_bso_timestamp() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -954,7 +954,7 @@ async fn get_bso_timestamp() -> Result<(), DbError> {
#[tokio::test]
async fn delete_bso() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -970,7 +970,7 @@ async fn delete_bso() -> Result<(), DbError> {
#[tokio::test]
async fn delete_bsos() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -1005,21 +1005,21 @@ async fn delete_bsos() -> Result<(), DbError> {
#[tokio::test]
async fn usage_stats() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
Ok(())
}
#[tokio::test]
async fn purge_expired() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
Ok(())
}
#[tokio::test]
async fn optimize() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
Ok(())
}
*/
@ -1027,7 +1027,7 @@ async fn optimize() -> Result<(), DbError> {
#[tokio::test]
async fn delete_storage() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let bid = "test";
@ -1053,7 +1053,7 @@ async fn delete_storage() -> Result<(), DbError> {
#[tokio::test]
async fn collection_cache() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "test";
@ -1074,7 +1074,7 @@ async fn collection_cache() -> Result<(), DbError> {
#[tokio::test]
async fn lock_for_read() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -1092,7 +1092,7 @@ async fn lock_for_read() -> Result<(), DbError> {
#[tokio::test]
async fn lock_for_write() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
let uid = *UID;
let coll = "clients";
@ -1110,7 +1110,7 @@ async fn lock_for_write() -> Result<(), DbError> {
#[tokio::test]
async fn heartbeat() -> Result<(), DbError> {
let pool = db_pool(None).await?;
let db = test_db(pool).await?;
let mut db = test_db(pool).await?;
assert!(db.check().await?);
Ok(())

View File

@ -28,7 +28,7 @@ pub async fn db_pool(settings: Option<SyncstorageSettings>) -> Result<DbPoolImpl
}
pub async fn test_db(pool: DbPoolImpl) -> Result<Box<dyn Db<Error = DbError>>, DbError> {
let db = pool.get().await?;
let mut db = pool.get().await?;
// Spanner won't have a timestamp until lock_for_xxx are called: fill one
// in for it
db.set_timestamp(SyncTimestamp::default());

View File

@ -23,7 +23,7 @@ const MAX_TTL: i32 = 2_100_000_000;
const MAX_BATCH_CREATE_RETRY: u8 = 5;
pub fn create(db: &MysqlDb, params: params::CreateBatch) -> DbResult<results::CreateBatch> {
pub fn create(db: &mut MysqlDb, params: params::CreateBatch) -> DbResult<results::CreateBatch> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = db.get_collection_id(&params.collection)?;
// Careful, there's some weirdness here!
@ -68,7 +68,7 @@ pub fn create(db: &MysqlDb, params: params::CreateBatch) -> DbResult<results::Cr
})
}
pub fn validate(db: &MysqlDb, params: params::ValidateBatch) -> DbResult<bool> {
pub fn validate(db: &mut MysqlDb, params: params::ValidateBatch) -> DbResult<bool> {
let batch_id = decode_id(&params.id)?;
// Avoid hitting the db for batches that are obviously too old. Recall
// that the batchid is a millisecond timestamp.
@ -88,7 +88,7 @@ pub fn validate(db: &MysqlDb, params: params::ValidateBatch) -> DbResult<bool> {
Ok(exists.is_some())
}
pub fn append(db: &MysqlDb, params: params::AppendToBatch) -> DbResult<()> {
pub fn append(db: &mut MysqlDb, params: params::AppendToBatch) -> DbResult<()> {
let exists = validate(
db,
params::ValidateBatch {
@ -108,7 +108,7 @@ pub fn append(db: &MysqlDb, params: params::AppendToBatch) -> DbResult<()> {
Ok(())
}
pub fn get(db: &MysqlDb, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
pub fn get(db: &mut MysqlDb, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
let is_valid = validate(
db,
params::ValidateBatch {
@ -125,7 +125,7 @@ pub fn get(db: &MysqlDb, params: params::GetBatch) -> DbResult<Option<results::G
Ok(batch)
}
pub fn delete(db: &MysqlDb, params: params::DeleteBatch) -> DbResult<()> {
pub fn delete(db: &mut MysqlDb, params: params::DeleteBatch) -> DbResult<()> {
let batch_id = decode_id(&params.id)?;
let user_id = params.user_id.legacy_id as i64;
let collection_id = db.get_collection_id(&params.collection)?;
@ -142,7 +142,7 @@ pub fn delete(db: &MysqlDb, params: params::DeleteBatch) -> DbResult<()> {
}
/// Commits a batch to the bsos table, deleting the batch when successful
pub fn commit(db: &MysqlDb, params: params::CommitBatch) -> DbResult<results::CommitBatch> {
pub fn commit(db: &mut MysqlDb, params: params::CommitBatch) -> DbResult<results::CommitBatch> {
let batch_id = decode_id(&params.batch.id)?;
let user_id = params.user_id.legacy_id as i64;
let collection_id = db.get_collection_id(&params.collection)?;
@ -173,7 +173,7 @@ pub fn commit(db: &MysqlDb, params: params::CommitBatch) -> DbResult<results::Co
}
pub fn do_append(
db: &MysqlDb,
db: &mut MysqlDb,
batch_id: i64,
user_id: UserIdentifier,
_collection_id: i32,
@ -282,7 +282,7 @@ fn decode_id(id: &str) -> DbResult<i64> {
macro_rules! batch_db_method {
($name:ident, $batch_name:ident, $type:ident) => {
pub fn $name(&self, params: params::$type) -> DbResult<results::$type> {
pub fn $name(&mut self, params: params::$type) -> DbResult<results::$type> {
batch::$batch_name(self, params)
}
};

View File

@ -150,7 +150,7 @@ impl MysqlDb {
/// In theory it would be possible to use serializable transactions rather
/// than explicit locking, but our ops team have expressed concerns about
/// the efficiency of that approach at scale.
fn lock_for_read_sync(&self, params: params::LockCollection) -> DbResult<()> {
fn lock_for_read_sync(&mut self, params: params::LockCollection) -> DbResult<()> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection).or_else(|e| {
if e.is_collection_not_found() {
@ -196,7 +196,7 @@ impl MysqlDb {
Ok(())
}
fn lock_for_write_sync(&self, params: params::LockCollection) -> DbResult<()> {
fn lock_for_write_sync(&mut self, params: params::LockCollection) -> DbResult<()> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_or_create_collection_id(&params.collection)?;
if let Some(CollectionLock::Read) = self
@ -237,7 +237,7 @@ impl MysqlDb {
Ok(())
}
pub(super) fn begin(&self, for_write: bool) -> DbResult<()> {
pub(super) fn begin(&mut self, for_write: bool) -> DbResult<()> {
<InternalConn as Connection>::TransactionManager::begin_transaction(
&mut *self.conn.write()?,
)?;
@ -248,11 +248,11 @@ impl MysqlDb {
Ok(())
}
async fn begin_async(&self, for_write: bool) -> DbResult<()> {
async fn begin_async(&mut self, for_write: bool) -> DbResult<()> {
self.begin(for_write)
}
fn commit_sync(&self) -> DbResult<()> {
fn commit_sync(&mut self) -> DbResult<()> {
if self.session.borrow().in_transaction {
<InternalConn as Connection>::TransactionManager::commit_transaction(
&mut *self.conn.write()?,
@ -261,7 +261,7 @@ impl MysqlDb {
Ok(())
}
fn rollback_sync(&self) -> DbResult<()> {
fn rollback_sync(&mut self) -> DbResult<()> {
if self.session.borrow().in_transaction {
<InternalConn as Connection>::TransactionManager::rollback_transaction(
&mut *self.conn.write()?,
@ -270,7 +270,7 @@ impl MysqlDb {
Ok(())
}
fn erect_tombstone(&self, user_id: i32) -> DbResult<()> {
fn erect_tombstone(&mut self, user_id: i32) -> DbResult<()> {
sql_query(format!(
r#"INSERT INTO user_collections ({user_id}, {collection_id}, {modified})
VALUES (?, ?, ?)
@ -287,7 +287,7 @@ impl MysqlDb {
Ok(())
}
fn delete_storage_sync(&self, user_id: UserIdentifier) -> DbResult<()> {
fn delete_storage_sync(&mut self, user_id: UserIdentifier) -> DbResult<()> {
let user_id = user_id.legacy_id as i64;
// Delete user data.
delete(bso::table)
@ -303,7 +303,10 @@ impl MysqlDb {
// Deleting the collection should result in:
// - collection does not appear in /info/collections
// - X-Last-Modified timestamp at the storage level changing
fn delete_collection_sync(&self, params: params::DeleteCollection) -> DbResult<SyncTimestamp> {
fn delete_collection_sync(
&mut self,
params: params::DeleteCollection,
) -> DbResult<SyncTimestamp> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection)?;
let mut count = delete(bso::table)
@ -322,7 +325,7 @@ impl MysqlDb {
self.get_storage_timestamp_sync(params.user_id)
}
pub(super) fn get_or_create_collection_id(&self, name: &str) -> DbResult<i32> {
pub(super) fn get_or_create_collection_id(&mut self, name: &str) -> DbResult<i32> {
if let Some(id) = self.coll_cache.get_id(name)? {
return Ok(id);
}
@ -343,7 +346,7 @@ impl MysqlDb {
Ok(id)
}
pub(super) fn get_collection_id(&self, name: &str) -> DbResult<i32> {
pub(super) fn get_collection_id(&mut self, name: &str) -> DbResult<i32> {
if let Some(id) = self.coll_cache.get_id(name)? {
return Ok(id);
}
@ -364,7 +367,7 @@ impl MysqlDb {
Ok(id)
}
fn _get_collection_name(&self, id: i32) -> DbResult<String> {
fn _get_collection_name(&mut self, id: i32) -> DbResult<String> {
let name = if let Some(name) = self.coll_cache.get_name(id)? {
name
} else {
@ -382,7 +385,7 @@ impl MysqlDb {
Ok(name)
}
fn put_bso_sync(&self, bso: params::PutBso) -> DbResult<results::PutBso> {
fn put_bso_sync(&mut self, bso: params::PutBso) -> DbResult<results::PutBso> {
/*
if bso.payload.is_none() && bso.sortindex.is_none() && bso.ttl.is_none() {
// XXX: go returns an error here (ErrNothingToDo), and is treated
@ -477,7 +480,7 @@ impl MysqlDb {
self.update_collection(user_id as u32, collection_id)
}
fn get_bsos_sync(&self, params: params::GetBsos) -> DbResult<results::GetBsos> {
fn get_bsos_sync(&mut self, params: params::GetBsos) -> DbResult<results::GetBsos> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection)?;
let now = self.timestamp().as_i64();
@ -567,7 +570,7 @@ impl MysqlDb {
})
}
fn get_bso_ids_sync(&self, params: params::GetBsos) -> DbResult<results::GetBsoIds> {
fn get_bso_ids_sync(&mut self, params: params::GetBsos) -> DbResult<results::GetBsoIds> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection)?;
let mut query = bso::table
@ -630,7 +633,7 @@ impl MysqlDb {
})
}
fn get_bso_sync(&self, params: params::GetBso) -> DbResult<Option<results::GetBso>> {
fn get_bso_sync(&mut self, params: params::GetBso) -> DbResult<Option<results::GetBso>> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection)?;
Ok(bso::table
@ -649,7 +652,7 @@ impl MysqlDb {
.optional()?)
}
fn delete_bso_sync(&self, params: params::DeleteBso) -> DbResult<results::DeleteBso> {
fn delete_bso_sync(&mut self, params: params::DeleteBso) -> DbResult<results::DeleteBso> {
let user_id = params.user_id.legacy_id;
let collection_id = self.get_collection_id(&params.collection)?;
let affected_rows = delete(bso::table)
@ -664,7 +667,7 @@ impl MysqlDb {
self.update_collection(user_id as u32, collection_id)
}
fn delete_bsos_sync(&self, params: params::DeleteBsos) -> DbResult<results::DeleteBsos> {
fn delete_bsos_sync(&mut self, params: params::DeleteBsos) -> DbResult<results::DeleteBsos> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection)?;
delete(bso::table)
@ -675,7 +678,7 @@ impl MysqlDb {
self.update_collection(user_id as u32, collection_id)
}
fn post_bsos_sync(&self, input: params::PostBsos) -> DbResult<results::PostBsos> {
fn post_bsos_sync(&mut self, input: params::PostBsos) -> DbResult<results::PostBsos> {
let collection_id = self.get_or_create_collection_id(&input.collection)?;
let mut result = results::PostBsos {
modified: self.timestamp(),
@ -708,7 +711,7 @@ impl MysqlDb {
Ok(result)
}
fn get_storage_timestamp_sync(&self, user_id: UserIdentifier) -> DbResult<SyncTimestamp> {
fn get_storage_timestamp_sync(&mut self, user_id: UserIdentifier) -> DbResult<SyncTimestamp> {
let user_id = user_id.legacy_id as i64;
let modified = user_collections::table
.select(max(user_collections::modified))
@ -719,7 +722,7 @@ impl MysqlDb {
}
fn get_collection_timestamp_sync(
&self,
&mut self,
params: params::GetCollectionTimestamp,
) -> DbResult<SyncTimestamp> {
let user_id = params.user_id.legacy_id as u32;
@ -741,7 +744,10 @@ impl MysqlDb {
.ok_or_else(DbError::collection_not_found)
}
fn get_bso_timestamp_sync(&self, params: params::GetBsoTimestamp) -> DbResult<SyncTimestamp> {
fn get_bso_timestamp_sync(
&mut self,
params: params::GetBsoTimestamp,
) -> DbResult<SyncTimestamp> {
let user_id = params.user_id.legacy_id as i64;
let collection_id = self.get_collection_id(&params.collection)?;
let modified = bso::table
@ -756,7 +762,7 @@ impl MysqlDb {
}
fn get_collection_timestamps_sync(
&self,
&mut self,
user_id: UserIdentifier,
) -> DbResult<results::GetCollectionTimestamps> {
let modifieds = sql_query(format!(
@ -781,13 +787,13 @@ impl MysqlDb {
self.map_collection_names(modifieds)
}
fn check_sync(&self) -> DbResult<results::Check> {
fn check_sync(&mut self) -> DbResult<results::Check> {
// has the database been up for more than 0 seconds?
let result = sql_query("SHOW STATUS LIKE \"Uptime\"").execute(&mut *self.conn.write()?)?;
Ok(result as u64 > 0)
}
fn map_collection_names<T>(&self, by_id: HashMap<i32, T>) -> DbResult<HashMap<String, T>> {
fn map_collection_names<T>(&mut self, by_id: HashMap<i32, T>) -> DbResult<HashMap<String, T>> {
let mut names = self.load_collection_names(by_id.keys())?;
by_id
.into_iter()
@ -800,7 +806,7 @@ impl MysqlDb {
}
fn load_collection_names<'a>(
&self,
&mut self,
collection_ids: impl Iterator<Item = &'a i32>,
) -> DbResult<HashMap<i32, String>> {
let mut names = HashMap::new();
@ -831,7 +837,7 @@ impl MysqlDb {
}
pub(super) fn update_collection(
&self,
&mut self,
user_id: u32,
collection_id: i32,
) -> DbResult<SyncTimestamp> {
@ -875,7 +881,7 @@ impl MysqlDb {
// Perform a lighter weight "read only" storage size check
fn get_storage_usage_sync(
&self,
&mut self,
user_id: UserIdentifier,
) -> DbResult<results::GetStorageUsage> {
let uid = user_id.legacy_id as i64;
@ -889,7 +895,7 @@ impl MysqlDb {
// Perform a lighter weight "read only" quota storage check
fn get_quota_usage_sync(
&self,
&mut self,
params: params::GetQuotaUsage,
) -> DbResult<results::GetQuotaUsage> {
let uid = params.user_id.legacy_id as i64;
@ -911,7 +917,7 @@ impl MysqlDb {
// perform a heavier weight quota calculation
fn calc_quota_usage_sync(
&self,
&mut self,
user_id: u32,
collection_id: i32,
) -> DbResult<results::GetQuotaUsage> {
@ -933,7 +939,7 @@ impl MysqlDb {
}
fn get_collection_usage_sync(
&self,
&mut self,
user_id: UserIdentifier,
) -> DbResult<results::GetCollectionUsage> {
let counts = bso::table
@ -948,7 +954,7 @@ impl MysqlDb {
}
fn get_collection_counts_sync(
&self,
&mut self,
user_id: UserIdentifier,
) -> DbResult<results::GetCollectionCounts> {
let counts = bso::table
@ -974,7 +980,7 @@ impl MysqlDb {
batch_db_method!(commit_batch_sync, commit, CommitBatch);
batch_db_method!(delete_batch_sync, delete, DeleteBatch);
fn get_batch_sync(&self, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
fn get_batch_sync(&mut self, params: params::GetBatch) -> DbResult<Option<results::GetBatch>> {
batch::get(self, params)
}
@ -986,23 +992,23 @@ impl MysqlDb {
impl Db for MysqlDb {
type Error = DbError;
fn commit(&self) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
fn commit(&mut self) -> DbFuture<'_, (), Self::Error> {
let mut db = self.clone();
Box::pin(self.blocking_threadpool.spawn(move || db.commit_sync()))
}
fn rollback(&self) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
fn rollback(&mut self) -> DbFuture<'_, (), Self::Error> {
let mut db = self.clone();
Box::pin(self.blocking_threadpool.spawn(move || db.rollback_sync()))
}
fn begin(&self, for_write: bool) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
fn begin(&mut self, for_write: bool) -> DbFuture<'_, (), Self::Error> {
let mut db = self.clone();
Box::pin(async move { db.begin_async(for_write).map_err(Into::into).await })
}
fn check(&self) -> DbFuture<'_, results::Check, Self::Error> {
let db = self.clone();
fn check(&mut self) -> DbFuture<'_, results::Check, Self::Error> {
let mut db = self.clone();
Box::pin(self.blocking_threadpool.spawn(move || db.check_sync()))
}
@ -1061,8 +1067,8 @@ impl Db for MysqlDb {
);
sync_db_method!(commit_batch, commit_batch_sync, CommitBatch);
fn get_collection_id(&self, name: String) -> DbFuture<'_, i32, Self::Error> {
let db = self.clone();
fn get_collection_id(&mut self, name: String) -> DbFuture<'_, i32, Self::Error> {
let mut db = self.clone();
Box::pin(
self.blocking_threadpool
.spawn(move || db.get_collection_id(&name)),
@ -1073,8 +1079,8 @@ impl Db for MysqlDb {
results::ConnectionInfo::default()
}
fn create_collection(&self, name: String) -> DbFuture<'_, i32, Self::Error> {
let db = self.clone();
fn create_collection(&mut self, name: String) -> DbFuture<'_, i32, Self::Error> {
let mut db = self.clone();
Box::pin(
self.blocking_threadpool
.spawn(move || db.get_or_create_collection_id(&name)),
@ -1082,10 +1088,10 @@ impl Db for MysqlDb {
}
fn update_collection(
&self,
&mut self,
param: params::UpdateCollection,
) -> DbFuture<'_, SyncTimestamp, Self::Error> {
let db = self.clone();
let mut db = self.clone();
Box::pin(self.blocking_threadpool.spawn(move || {
db.update_collection(param.user_id.legacy_id as u32, param.collection_id)
}))
@ -1095,13 +1101,13 @@ impl Db for MysqlDb {
self.timestamp()
}
fn set_timestamp(&self, timestamp: SyncTimestamp) {
fn set_timestamp(&mut self, timestamp: SyncTimestamp) {
self.session.borrow_mut().timestamp = timestamp;
}
sync_db_method!(delete_batch, delete_batch_sync, DeleteBatch);
fn clear_coll_cache(&self) -> DbFuture<'_, (), Self::Error> {
fn clear_coll_cache(&mut self) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(self.blocking_threadpool.spawn(move || {
db.coll_cache.clear();

View File

@ -32,7 +32,7 @@ fn static_collection_id() -> DbResult<()> {
// Skip this test if we're not using mysql
return Ok(());
}
let db = db(&settings)?;
let mut db = db(&settings)?;
// ensure DB actually has predefined common collections
let cols: Vec<(i32, _)> = vec![

View File

@ -1904,33 +1904,33 @@ impl SpannerDb {
impl Db for SpannerDb {
type Error = DbError;
fn commit(&self) -> DbFuture<'_, (), Self::Error> {
fn commit(&mut self) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(async move { db.commit_async().map_err(Into::into).await })
}
fn rollback(&self) -> DbFuture<'_, (), Self::Error> {
fn rollback(&mut self) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(async move { db.rollback_async().map_err(Into::into).await })
}
fn lock_for_read(&self, param: params::LockCollection) -> DbFuture<'_, (), Self::Error> {
fn lock_for_read(&mut self, param: params::LockCollection) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(async move { db.lock_for_read_async(param).map_err(Into::into).await })
}
fn lock_for_write(&self, param: params::LockCollection) -> DbFuture<'_, (), Self::Error> {
fn lock_for_write(&mut self, param: params::LockCollection) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(async move { db.lock_for_write_async(param).map_err(Into::into).await })
}
fn begin(&self, for_write: bool) -> DbFuture<'_, (), Self::Error> {
fn begin(&mut self, for_write: bool) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(async move { db.begin_async(for_write).map_err(Into::into).await })
}
fn get_collection_timestamp(
&self,
&mut self,
param: params::GetCollectionTimestamp,
) -> DbFuture<'_, results::GetCollectionTimestamp, Self::Error> {
let db = self.clone();
@ -1942,7 +1942,7 @@ impl Db for SpannerDb {
}
fn get_storage_timestamp(
&self,
&mut self,
param: params::GetStorageTimestamp,
) -> DbFuture<'_, results::GetStorageTimestamp, Self::Error> {
let db = self.clone();
@ -1950,20 +1950,20 @@ impl Db for SpannerDb {
}
fn delete_collection(
&self,
&mut self,
param: params::DeleteCollection,
) -> DbFuture<'_, results::DeleteCollection, Self::Error> {
let db = self.clone();
Box::pin(async move { db.delete_collection_async(param).map_err(Into::into).await })
}
fn check(&self) -> DbFuture<'_, results::Check, Self::Error> {
fn check(&mut self) -> DbFuture<'_, results::Check, Self::Error> {
let db = self.clone();
Box::pin(async move { db.check_async().map_err(Into::into).await })
}
fn get_collection_timestamps(
&self,
&mut self,
user_id: params::GetCollectionTimestamps,
) -> DbFuture<'_, results::GetCollectionTimestamps, Self::Error> {
let db = self.clone();
@ -1975,7 +1975,7 @@ impl Db for SpannerDb {
}
fn get_collection_counts(
&self,
&mut self,
user_id: params::GetCollectionCounts,
) -> DbFuture<'_, results::GetCollectionCounts, Self::Error> {
let db = self.clone();
@ -1987,7 +1987,7 @@ impl Db for SpannerDb {
}
fn get_collection_usage(
&self,
&mut self,
user_id: params::GetCollectionUsage,
) -> DbFuture<'_, results::GetCollectionUsage, Self::Error> {
let db = self.clone();
@ -1999,7 +1999,7 @@ impl Db for SpannerDb {
}
fn get_storage_usage(
&self,
&mut self,
param: params::GetStorageUsage,
) -> DbFuture<'_, results::GetStorageUsage, Self::Error> {
let db = self.clone();
@ -2007,7 +2007,7 @@ impl Db for SpannerDb {
}
fn get_quota_usage(
&self,
&mut self,
param: params::GetQuotaUsage,
) -> DbFuture<'_, results::GetQuotaUsage, Self::Error> {
let db = self.clone();
@ -2015,7 +2015,7 @@ impl Db for SpannerDb {
}
fn delete_storage(
&self,
&mut self,
param: params::DeleteStorage,
) -> DbFuture<'_, results::DeleteStorage, Self::Error> {
let db = self.clone();
@ -2023,7 +2023,7 @@ impl Db for SpannerDb {
}
fn delete_bso(
&self,
&mut self,
param: params::DeleteBso,
) -> DbFuture<'_, results::DeleteBso, Self::Error> {
let db = self.clone();
@ -2031,51 +2031,57 @@ impl Db for SpannerDb {
}
fn delete_bsos(
&self,
&mut self,
param: params::DeleteBsos,
) -> DbFuture<'_, results::DeleteBsos, Self::Error> {
let db = self.clone();
Box::pin(async move { db.delete_bsos_async(param).map_err(Into::into).await })
}
fn get_bsos(&self, param: params::GetBsos) -> DbFuture<'_, results::GetBsos, Self::Error> {
fn get_bsos(&mut self, param: params::GetBsos) -> DbFuture<'_, results::GetBsos, Self::Error> {
let db = self.clone();
Box::pin(async move { db.get_bsos_async(param).map_err(Into::into).await })
}
fn get_bso_ids(
&self,
&mut self,
param: params::GetBsoIds,
) -> DbFuture<'_, results::GetBsoIds, Self::Error> {
let db = self.clone();
Box::pin(async move { db.get_bso_ids_async(param).map_err(Into::into).await })
}
fn get_bso(&self, param: params::GetBso) -> DbFuture<'_, Option<results::GetBso>, Self::Error> {
fn get_bso(
&mut self,
param: params::GetBso,
) -> DbFuture<'_, Option<results::GetBso>, Self::Error> {
let db = self.clone();
Box::pin(async move { db.get_bso_async(param).map_err(Into::into).await })
}
fn get_bso_timestamp(
&self,
&mut self,
param: params::GetBsoTimestamp,
) -> DbFuture<'_, results::GetBsoTimestamp, Self::Error> {
let db = self.clone();
Box::pin(async move { db.get_bso_timestamp_async(param).map_err(Into::into).await })
}
fn put_bso(&self, param: params::PutBso) -> DbFuture<'_, results::PutBso, Self::Error> {
fn put_bso(&mut self, param: params::PutBso) -> DbFuture<'_, results::PutBso, Self::Error> {
let db = self.clone();
Box::pin(async move { db.put_bso_async(param).map_err(Into::into).await })
}
fn post_bsos(&self, param: params::PostBsos) -> DbFuture<'_, results::PostBsos, Self::Error> {
fn post_bsos(
&mut self,
param: params::PostBsos,
) -> DbFuture<'_, results::PostBsos, Self::Error> {
let db = self.clone();
Box::pin(async move { db.post_bsos_async(param).map_err(Into::into).await })
}
fn create_batch(
&self,
&mut self,
param: params::CreateBatch,
) -> DbFuture<'_, results::CreateBatch, Self::Error> {
let db = self.clone();
@ -2083,7 +2089,7 @@ impl Db for SpannerDb {
}
fn validate_batch(
&self,
&mut self,
param: params::ValidateBatch,
) -> DbFuture<'_, results::ValidateBatch, Self::Error> {
let db = self.clone();
@ -2091,7 +2097,7 @@ impl Db for SpannerDb {
}
fn append_to_batch(
&self,
&mut self,
param: params::AppendToBatch,
) -> DbFuture<'_, results::AppendToBatch, Self::Error> {
let db = self.clone();
@ -2099,7 +2105,7 @@ impl Db for SpannerDb {
}
fn get_batch(
&self,
&mut self,
param: params::GetBatch,
) -> DbFuture<'_, Option<results::GetBatch>, Self::Error> {
let db = self.clone();
@ -2107,14 +2113,14 @@ impl Db for SpannerDb {
}
fn commit_batch(
&self,
&mut self,
param: params::CommitBatch,
) -> DbFuture<'_, results::CommitBatch, Self::Error> {
let db = self.clone();
Box::pin(async move { batch::commit_async(&db, param).map_err(Into::into).await })
}
fn get_collection_id(&self, name: String) -> DbFuture<'_, i32, Self::Error> {
fn get_collection_id(&mut self, name: String) -> DbFuture<'_, i32, Self::Error> {
let db = self.clone();
Box::pin(async move { db.get_collection_id_async(&name).map_err(Into::into).await })
}
@ -2137,13 +2143,13 @@ impl Db for SpannerDb {
}
}
fn create_collection(&self, name: String) -> DbFuture<'_, i32, Self::Error> {
fn create_collection(&mut self, name: String) -> DbFuture<'_, i32, Self::Error> {
let db = self.clone();
Box::pin(async move { db.create_collection_async(&name).map_err(Into::into).await })
}
fn update_collection(
&self,
&mut self,
param: params::UpdateCollection,
) -> DbFuture<'_, SyncTimestamp, Self::Error> {
let db = self.clone();
@ -2159,19 +2165,19 @@ impl Db for SpannerDb {
.expect("set_timestamp() not called yet for SpannerDb")
}
fn set_timestamp(&self, timestamp: SyncTimestamp) {
fn set_timestamp(&mut self, timestamp: SyncTimestamp) {
SpannerDb::set_timestamp(self, timestamp)
}
fn delete_batch(
&self,
&mut self,
param: params::DeleteBatch,
) -> DbFuture<'_, results::DeleteBatch, Self::Error> {
let db = self.clone();
Box::pin(async move { batch::delete_async(&db, param).map_err(Into::into).await })
}
fn clear_coll_cache(&self) -> DbFuture<'_, (), Self::Error> {
fn clear_coll_cache(&mut self) -> DbFuture<'_, (), Self::Error> {
let db = self.clone();
Box::pin(async move {
db.coll_cache.clear().await;

View File

@ -13,3 +13,19 @@ pub mod results;
pub use models::{Db, TokenserverDb};
pub use pool::{DbPool, TokenserverPool};
#[macro_export]
/// Generates an async `Db` trait method that bridges to a blocking
/// (`*_sync`) implementation by running it on the blocking threadpool.
macro_rules! sync_db_method {
    // Short form: default the future's Ok type to `results::$type`,
    // then delegate to the long form below.
    ($name:ident, $sync_name:ident, $type:ident) => {
        sync_db_method!($name, $sync_name, $type, results::$type);
    };
    // Long form: emits `fn $name(&mut self, params) -> DbFuture<...>`.
    // The handle is cloned (as `mut`, since the sync methods take
    // `&mut self`) so the closure can own it when moved onto the
    // blocking threadpool; the spawned work is returned as a boxed future.
    ($name:ident, $sync_name:ident, $type:ident, $result:ty) => {
        fn $name(&mut self, params: params::$type) -> DbFuture<'_, $result, DbError> {
            let mut db = self.clone();
            Box::pin(
                self.blocking_threadpool
                    .spawn(move || db.$sync_name(params)),
            )
        }
    };
}

View File

@ -46,59 +46,68 @@ impl MockDb {
}
impl Db for MockDb {
fn replace_user(&self, _params: params::ReplaceUser) -> DbFuture<'_, results::ReplaceUser> {
fn replace_user(&mut self, _params: params::ReplaceUser) -> DbFuture<'_, results::ReplaceUser> {
Box::pin(future::ok(()))
}
fn replace_users(&self, _params: params::ReplaceUsers) -> DbFuture<'_, results::ReplaceUsers> {
fn replace_users(
&mut self,
_params: params::ReplaceUsers,
) -> DbFuture<'_, results::ReplaceUsers> {
Box::pin(future::ok(()))
}
fn post_user(&self, _params: params::PostUser) -> DbFuture<'_, results::PostUser> {
fn post_user(&mut self, _params: params::PostUser) -> DbFuture<'_, results::PostUser> {
Box::pin(future::ok(results::PostUser::default()))
}
fn put_user(&self, _params: params::PutUser) -> DbFuture<'_, results::PutUser> {
fn put_user(&mut self, _params: params::PutUser) -> DbFuture<'_, results::PutUser> {
Box::pin(future::ok(()))
}
fn check(&self) -> DbFuture<'_, results::Check> {
fn check(&mut self) -> DbFuture<'_, results::Check> {
Box::pin(future::ok(true))
}
fn get_node_id(&self, _params: params::GetNodeId) -> DbFuture<'_, results::GetNodeId> {
fn get_node_id(&mut self, _params: params::GetNodeId) -> DbFuture<'_, results::GetNodeId> {
Box::pin(future::ok(results::GetNodeId::default()))
}
fn get_best_node(&self, _params: params::GetBestNode) -> DbFuture<'_, results::GetBestNode> {
fn get_best_node(
&mut self,
_params: params::GetBestNode,
) -> DbFuture<'_, results::GetBestNode> {
Box::pin(future::ok(results::GetBestNode::default()))
}
fn add_user_to_node(
&self,
&mut self,
_params: params::AddUserToNode,
) -> DbFuture<'_, results::AddUserToNode> {
Box::pin(future::ok(()))
}
fn get_users(&self, _params: params::GetUsers) -> DbFuture<'_, results::GetUsers> {
fn get_users(&mut self, _params: params::GetUsers) -> DbFuture<'_, results::GetUsers> {
Box::pin(future::ok(results::GetUsers::default()))
}
fn get_or_create_user(
&self,
&mut self,
_params: params::GetOrCreateUser,
) -> DbFuture<'_, results::GetOrCreateUser> {
Box::pin(future::ok(results::GetOrCreateUser::default()))
}
fn get_service_id(&self, _params: params::GetServiceId) -> DbFuture<'_, results::GetServiceId> {
fn get_service_id(
&mut self,
_params: params::GetServiceId,
) -> DbFuture<'_, results::GetServiceId> {
Box::pin(future::ok(results::GetServiceId::default()))
}
#[cfg(test)]
fn set_user_created_at(
&self,
&mut self,
_params: params::SetUserCreatedAt,
) -> DbFuture<'_, results::SetUserCreatedAt> {
Box::pin(future::ok(()))
@ -106,39 +115,42 @@ impl Db for MockDb {
#[cfg(test)]
fn set_user_replaced_at(
&self,
&mut self,
_params: params::SetUserReplacedAt,
) -> DbFuture<'_, results::SetUserReplacedAt> {
Box::pin(future::ok(()))
}
#[cfg(test)]
fn get_user(&self, _params: params::GetUser) -> DbFuture<'_, results::GetUser> {
fn get_user(&mut self, _params: params::GetUser) -> DbFuture<'_, results::GetUser> {
Box::pin(future::ok(results::GetUser::default()))
}
#[cfg(test)]
fn post_node(&self, _params: params::PostNode) -> DbFuture<'_, results::PostNode> {
fn post_node(&mut self, _params: params::PostNode) -> DbFuture<'_, results::PostNode> {
Box::pin(future::ok(results::PostNode::default()))
}
#[cfg(test)]
fn get_node(&self, _params: params::GetNode) -> DbFuture<'_, results::GetNode> {
fn get_node(&mut self, _params: params::GetNode) -> DbFuture<'_, results::GetNode> {
Box::pin(future::ok(results::GetNode::default()))
}
#[cfg(test)]
fn unassign_node(&self, _params: params::UnassignNode) -> DbFuture<'_, results::UnassignNode> {
fn unassign_node(
&mut self,
_params: params::UnassignNode,
) -> DbFuture<'_, results::UnassignNode> {
Box::pin(future::ok(()))
}
#[cfg(test)]
fn remove_node(&self, _params: params::RemoveNode) -> DbFuture<'_, results::RemoveNode> {
fn remove_node(&mut self, _params: params::RemoveNode) -> DbFuture<'_, results::RemoveNode> {
Box::pin(future::ok(()))
}
#[cfg(test)]
fn post_service(&self, _params: params::PostService) -> DbFuture<'_, results::PostService> {
fn post_service(&mut self, _params: params::PostService) -> DbFuture<'_, results::PostService> {
Box::pin(future::ok(results::PostService::default()))
}
}

View File

@ -8,7 +8,7 @@ use diesel::{
use diesel_logger::LoggingConnection;
use http::StatusCode;
use syncserver_common::{BlockingThreadpool, Metrics};
use syncserver_db_common::{sync_db_method, DbFuture};
use syncserver_db_common::DbFuture;
use std::{
sync::{Arc, RwLock},
@ -17,7 +17,7 @@ use std::{
use super::{
error::{DbError, DbResult},
params, results,
params, results, sync_db_method,
};
/// The maximum possible generation number. Used as a tombstone to mark users that have been
@ -92,7 +92,7 @@ impl TokenserverDb {
}
}
fn get_node_id_sync(&self, params: params::GetNodeId) -> DbResult<results::GetNodeId> {
fn get_node_id_sync(&mut self, params: params::GetNodeId) -> DbResult<results::GetNodeId> {
const QUERY: &str = r#"
SELECT id
FROM nodes
@ -115,7 +115,10 @@ impl TokenserverDb {
}
/// Mark users matching the given email and service ID as replaced.
fn replace_users_sync(&self, params: params::ReplaceUsers) -> DbResult<results::ReplaceUsers> {
fn replace_users_sync(
&mut self,
params: params::ReplaceUsers,
) -> DbResult<results::ReplaceUsers> {
const QUERY: &str = r#"
UPDATE users
SET replaced_at = ?
@ -139,7 +142,7 @@ impl TokenserverDb {
}
/// Mark the user with the given uid and service ID as being replaced.
fn replace_user_sync(&self, params: params::ReplaceUser) -> DbResult<results::ReplaceUser> {
fn replace_user_sync(&mut self, params: params::ReplaceUser) -> DbResult<results::ReplaceUser> {
const QUERY: &str = r#"
UPDATE users
SET replaced_at = ?
@ -158,7 +161,7 @@ impl TokenserverDb {
/// Update the user with the given email and service ID with the given `generation` and
/// `keys_changed_at`.
fn put_user_sync(&self, params: params::PutUser) -> DbResult<results::PutUser> {
fn put_user_sync(&mut self, params: params::PutUser) -> DbResult<results::PutUser> {
// The `where` clause on this statement is designed as an extra layer of
// protection, to ensure that concurrent updates don't accidentally move
// timestamp fields backwards in time. The handling of `keys_changed_at`
@ -191,7 +194,7 @@ impl TokenserverDb {
}
/// Create a new user.
fn post_user_sync(&self, user: params::PostUser) -> DbResult<results::PostUser> {
fn post_user_sync(&mut self, user: params::PostUser) -> DbResult<results::PostUser> {
const QUERY: &str = r#"
INSERT INTO users (service, email, generation, client_state, created_at, nodeid, keys_changed_at, replaced_at)
VALUES (?, ?, ?, ?, ?, ?, ?, NULL);
@ -216,7 +219,7 @@ impl TokenserverDb {
.map_err(Into::into)
}
fn check_sync(&self) -> DbResult<results::Check> {
fn check_sync(&mut self) -> DbResult<results::Check> {
// has the database been up for more than 0 seconds?
let result = diesel::sql_query("SHOW STATUS LIKE \"Uptime\"")
.execute(&mut *self.inner.conn.write()?)?;
@ -224,7 +227,10 @@ impl TokenserverDb {
}
/// Gets the least-loaded node that has available slots.
fn get_best_node_sync(&self, params: params::GetBestNode) -> DbResult<results::GetBestNode> {
fn get_best_node_sync(
&mut self,
params: params::GetBestNode,
) -> DbResult<results::GetBestNode> {
const DEFAULT_CAPACITY_RELEASE_RATE: f32 = 0.1;
const GET_BEST_NODE_QUERY: &str = r#"
SELECT id, node
@ -302,7 +308,7 @@ impl TokenserverDb {
}
fn add_user_to_node_sync(
&self,
&mut self,
params: params::AddUserToNode,
) -> DbResult<results::AddUserToNode> {
let mut metrics = self.metrics.clone();
@ -336,7 +342,7 @@ impl TokenserverDb {
.map_err(Into::into)
}
fn get_users_sync(&self, params: params::GetUsers) -> DbResult<results::GetUsers> {
fn get_users_sync(&mut self, params: params::GetUsers) -> DbResult<results::GetUsers> {
let mut metrics = self.metrics.clone();
metrics.start_timer("storage.get_users", None);
@ -361,7 +367,7 @@ impl TokenserverDb {
/// Gets the user with the given email and service ID, or if one doesn't exist, allocates a new
/// user.
fn get_or_create_user_sync(
&self,
&mut self,
params: params::GetOrCreateUser,
) -> DbResult<results::GetOrCreateUser> {
let mut raw_users = self.get_users_sync(params::GetUsers {
@ -479,7 +485,10 @@ impl TokenserverDb {
}
/// Creates a new user and assigns them to a node.
fn allocate_user_sync(&self, params: params::AllocateUser) -> DbResult<results::AllocateUser> {
fn allocate_user_sync(
&mut self,
params: params::AllocateUser,
) -> DbResult<results::AllocateUser> {
let mut metrics = self.metrics.clone();
metrics.start_timer("storage.allocate_user", None);
@ -519,7 +528,7 @@ impl TokenserverDb {
}
pub fn get_service_id_sync(
&self,
&mut self,
params: params::GetServiceId,
) -> DbResult<results::GetServiceId> {
const QUERY: &str = r#"
@ -540,7 +549,7 @@ impl TokenserverDb {
#[cfg(test)]
fn set_user_created_at_sync(
&self,
&mut self,
params: params::SetUserCreatedAt,
) -> DbResult<results::SetUserCreatedAt> {
const QUERY: &str = r#"
@ -558,7 +567,7 @@ impl TokenserverDb {
#[cfg(test)]
fn set_user_replaced_at_sync(
&self,
&mut self,
params: params::SetUserReplacedAt,
) -> DbResult<results::SetUserReplacedAt> {
const QUERY: &str = r#"
@ -575,7 +584,7 @@ impl TokenserverDb {
}
#[cfg(test)]
fn get_user_sync(&self, params: params::GetUser) -> DbResult<results::GetUser> {
fn get_user_sync(&mut self, params: params::GetUser) -> DbResult<results::GetUser> {
const QUERY: &str = r#"
SELECT service, email, generation, client_state, replaced_at, nodeid, keys_changed_at
FROM users
@ -589,7 +598,7 @@ impl TokenserverDb {
}
#[cfg(test)]
fn post_node_sync(&self, params: params::PostNode) -> DbResult<results::PostNode> {
fn post_node_sync(&mut self, params: params::PostNode) -> DbResult<results::PostNode> {
const QUERY: &str = r#"
INSERT INTO nodes (service, node, available, current_load, capacity, downed, backoff)
VALUES (?, ?, ?, ?, ?, ?, ?)
@ -610,7 +619,7 @@ impl TokenserverDb {
}
#[cfg(test)]
fn get_node_sync(&self, params: params::GetNode) -> DbResult<results::GetNode> {
fn get_node_sync(&mut self, params: params::GetNode) -> DbResult<results::GetNode> {
const QUERY: &str = r#"
SELECT *
FROM nodes
@ -624,7 +633,10 @@ impl TokenserverDb {
}
#[cfg(test)]
fn unassign_node_sync(&self, params: params::UnassignNode) -> DbResult<results::UnassignNode> {
fn unassign_node_sync(
&mut self,
params: params::UnassignNode,
) -> DbResult<results::UnassignNode> {
const QUERY: &str = r#"
UPDATE users
SET replaced_at = ?
@ -645,7 +657,7 @@ impl TokenserverDb {
}
#[cfg(test)]
fn remove_node_sync(&self, params: params::RemoveNode) -> DbResult<results::RemoveNode> {
fn remove_node_sync(&mut self, params: params::RemoveNode) -> DbResult<results::RemoveNode> {
const QUERY: &str = "DELETE FROM nodes WHERE id = ?";
diesel::sql_query(QUERY)
@ -656,7 +668,7 @@ impl TokenserverDb {
}
#[cfg(test)]
fn post_service_sync(&self, params: params::PostService) -> DbResult<results::PostService> {
fn post_service_sync(&mut self, params: params::PostService) -> DbResult<results::PostService> {
const INSERT_SERVICE_QUERY: &str = r#"
INSERT INTO services (service, pattern)
VALUES (?, ?)
@ -696,8 +708,8 @@ impl Db for TokenserverDb {
#[cfg(test)]
sync_db_method!(get_user, get_user_sync, GetUser);
fn check(&self) -> DbFuture<'_, results::Check, DbError> {
let db = self.clone();
fn check(&mut self) -> DbFuture<'_, results::Check, DbError> {
let mut db = self.clone();
Box::pin(self.blocking_threadpool.spawn(move || db.check_sync()))
}
@ -737,79 +749,84 @@ pub trait Db {
}
fn replace_user(
&self,
&mut self,
params: params::ReplaceUser,
) -> DbFuture<'_, results::ReplaceUser, DbError>;
fn replace_users(
&self,
&mut self,
params: params::ReplaceUsers,
) -> DbFuture<'_, results::ReplaceUsers, DbError>;
fn post_user(&self, params: params::PostUser) -> DbFuture<'_, results::PostUser, DbError>;
fn post_user(&mut self, params: params::PostUser) -> DbFuture<'_, results::PostUser, DbError>;
fn put_user(&self, params: params::PutUser) -> DbFuture<'_, results::PutUser, DbError>;
fn put_user(&mut self, params: params::PutUser) -> DbFuture<'_, results::PutUser, DbError>;
fn check(&self) -> DbFuture<'_, results::Check, DbError>;
fn check(&mut self) -> DbFuture<'_, results::Check, DbError>;
fn get_node_id(&self, params: params::GetNodeId) -> DbFuture<'_, results::GetNodeId, DbError>;
fn get_node_id(
&mut self,
params: params::GetNodeId,
) -> DbFuture<'_, results::GetNodeId, DbError>;
fn get_best_node(
&self,
&mut self,
params: params::GetBestNode,
) -> DbFuture<'_, results::GetBestNode, DbError>;
fn add_user_to_node(
&self,
&mut self,
params: params::AddUserToNode,
) -> DbFuture<'_, results::AddUserToNode, DbError>;
fn get_users(&self, params: params::GetUsers) -> DbFuture<'_, results::GetUsers, DbError>;
fn get_users(&mut self, params: params::GetUsers) -> DbFuture<'_, results::GetUsers, DbError>;
fn get_or_create_user(
&self,
&mut self,
params: params::GetOrCreateUser,
) -> DbFuture<'_, results::GetOrCreateUser, DbError>;
fn get_service_id(
&self,
&mut self,
params: params::GetServiceId,
) -> DbFuture<'_, results::GetServiceId, DbError>;
#[cfg(test)]
fn set_user_created_at(
&self,
&mut self,
params: params::SetUserCreatedAt,
) -> DbFuture<'_, results::SetUserCreatedAt, DbError>;
#[cfg(test)]
fn set_user_replaced_at(
&self,
&mut self,
params: params::SetUserReplacedAt,
) -> DbFuture<'_, results::SetUserReplacedAt, DbError>;
#[cfg(test)]
fn get_user(&self, params: params::GetUser) -> DbFuture<'_, results::GetUser, DbError>;
fn get_user(&mut self, params: params::GetUser) -> DbFuture<'_, results::GetUser, DbError>;
#[cfg(test)]
fn post_node(&self, params: params::PostNode) -> DbFuture<'_, results::PostNode, DbError>;
fn post_node(&mut self, params: params::PostNode) -> DbFuture<'_, results::PostNode, DbError>;
#[cfg(test)]
fn get_node(&self, params: params::GetNode) -> DbFuture<'_, results::GetNode, DbError>;
fn get_node(&mut self, params: params::GetNode) -> DbFuture<'_, results::GetNode, DbError>;
#[cfg(test)]
fn unassign_node(
&self,
&mut self,
params: params::UnassignNode,
) -> DbFuture<'_, results::UnassignNode, DbError>;
#[cfg(test)]
fn remove_node(&self, params: params::RemoveNode)
-> DbFuture<'_, results::RemoveNode, DbError>;
fn remove_node(
&mut self,
params: params::RemoveNode,
) -> DbFuture<'_, results::RemoveNode, DbError>;
#[cfg(test)]
fn post_service(
&self,
&mut self,
params: params::PostService,
) -> DbFuture<'_, results::PostService, DbError>;
}
@ -828,7 +845,7 @@ mod tests {
#[tokio::test]
async fn test_update_generation() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -902,7 +919,7 @@ mod tests {
#[tokio::test]
async fn test_update_keys_changed_at() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -979,7 +996,7 @@ mod tests {
const MILLISECONDS_IN_AN_HOUR: i64 = MILLISECONDS_IN_A_MINUTE * 60;
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
@ -1160,7 +1177,7 @@ mod tests {
#[tokio::test]
async fn post_user() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -1226,7 +1243,7 @@ mod tests {
#[tokio::test]
async fn get_node_id() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -1273,7 +1290,7 @@ mod tests {
#[tokio::test]
async fn test_node_allocation() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get_tokenserver_db().await?;
let mut db = pool.get_tokenserver_db().await?;
// Add a service
let service_id = db
@ -1318,7 +1335,7 @@ mod tests {
#[tokio::test]
async fn test_allocation_to_least_loaded_node() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get_tokenserver_db().await?;
let mut db = pool.get_tokenserver_db().await?;
// Add a service
let service_id = db
@ -1379,7 +1396,7 @@ mod tests {
#[tokio::test]
async fn test_allocation_is_not_allowed_to_downed_nodes() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get_tokenserver_db().await?;
let mut db = pool.get_tokenserver_db().await?;
// Add a service
let service_id = db
@ -1420,7 +1437,7 @@ mod tests {
#[tokio::test]
async fn test_allocation_is_not_allowed_to_backoff_nodes() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get_tokenserver_db().await?;
let mut db = pool.get_tokenserver_db().await?;
// Add a service
let service_id = db
@ -1461,7 +1478,7 @@ mod tests {
#[tokio::test]
async fn test_node_reassignment_when_records_are_replaced() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get_tokenserver_db().await?;
let mut db = pool.get_tokenserver_db().await?;
// Add a service
let service_id = db
@ -1533,7 +1550,7 @@ mod tests {
#[tokio::test]
async fn test_node_reassignment_not_done_for_retired_users() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -1589,7 +1606,7 @@ mod tests {
#[tokio::test]
async fn test_node_reassignment_and_removal() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -1740,7 +1757,7 @@ mod tests {
#[tokio::test]
async fn test_gradual_release_of_node_capacity() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -1906,7 +1923,7 @@ mod tests {
#[tokio::test]
async fn test_correct_created_at_used_during_node_reassignment() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db
@ -1970,7 +1987,7 @@ mod tests {
#[tokio::test]
async fn test_correct_created_at_used_during_user_retrieval() -> DbResult<()> {
let pool = db_pool().await?;
let db = pool.get().await?;
let mut db = pool.get().await?;
// Add a service
let service_id = db