From 1d3a8ae3ad2d4f876a4da7d1ccdb546f93818d0f Mon Sep 17 00:00:00 2001 From: Philip Jenvey Date: Fri, 30 Nov 2018 14:47:52 -0800 Subject: [PATCH] fix: add the batch extractor + handler Closes #105, #86 --- src/db/mock.rs | 2 +- src/db/mod.rs | 2 +- src/db/mysql/batch.rs | 2 +- src/db/mysql/models.rs | 3 +- src/error.rs | 7 +- src/main.rs | 1 - src/web/extractors.rs | 247 +++++++++++++++++++++++++++++++++-------- src/web/handlers.rs | 119 +++++++++++++++++++- 8 files changed, 325 insertions(+), 58 deletions(-) diff --git a/src/db/mock.rs b/src/db/mock.rs index 3a9f524b..61b64ecb 100644 --- a/src/db/mock.rs +++ b/src/db/mock.rs @@ -75,7 +75,7 @@ impl Db for MockDb { mock_db_method!(validate_batch, ValidateBatch); mock_db_method!(append_to_batch, AppendToBatch); mock_db_method!(get_batch, GetBatch, Option); - mock_db_method!(delete_batch, DeleteBatch); + mock_db_method!(commit_batch, CommitBatch); } unsafe impl Send for MockDb {} diff --git a/src/db/mod.rs b/src/db/mod.rs index 45d5ee7a..79f53956 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -119,7 +119,7 @@ pub trait Db: Send + Debug { fn get_batch(&self, params: params::GetBatch) -> DbFuture>; - fn delete_batch(&self, params: params::DeleteBatch) -> DbFuture; + fn commit_batch(&self, params: params::CommitBatch) -> DbFuture; fn box_clone(&self) -> Box; diff --git a/src/db/mysql/batch.rs b/src/db/mysql/batch.rs index 574e51cb..5c3dcf38 100644 --- a/src/db/mysql/batch.rs +++ b/src/db/mysql/batch.rs @@ -76,7 +76,7 @@ pub fn get(db: &MysqlDb, params: params::GetBatch) -> Result Result<()> { +fn delete(db: &MysqlDb, params: params::DeleteBatch) -> Result<()> { let user_id = params.user_id.legacy_id as i32; let collection_id = db.get_collection_id(¶ms.collection)?; diesel::delete(batches::table) diff --git a/src/db/mysql/models.rs b/src/db/mysql/models.rs index 101e6e38..a7d43310 100644 --- a/src/db/mysql/models.rs +++ b/src/db/mysql/models.rs @@ -693,7 +693,6 @@ impl MysqlDb { batch_db_method!(create_batch_sync, 
create, CreateBatch); batch_db_method!(validate_batch_sync, validate, ValidateBatch); batch_db_method!(append_to_batch_sync, append, AppendToBatch); - batch_db_method!(delete_batch_sync, delete, DeleteBatch); batch_db_method!(commit_batch_sync, commit, CommitBatch); pub fn get_batch_sync(&self, params: params::GetBatch) -> Result> { @@ -803,7 +802,7 @@ impl Db for MysqlDb { GetBatch, Option ); - sync_db_method!(delete_batch, delete_batch_sync, DeleteBatch); + sync_db_method!(commit_batch, commit_batch_sync, CommitBatch); } #[derive(Debug, QueryableByName)] diff --git a/src/error.rs b/src/error.rs index 2ba0fe30..bfe1286d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -85,7 +85,7 @@ impl ApiError { false } - fn is_conflict(&self) -> bool { + pub fn is_conflict(&self) -> bool { match self.kind() { ApiErrorKind::Db(dbe) => match dbe.kind() { DbErrorKind::Conflict => return true, @@ -99,7 +99,10 @@ impl ApiError { fn weave_error_code(&self) -> WeaveError { match self.kind() { ApiErrorKind::Validation(ver) => match ver.kind() { - ValidationErrorKind::FromDetails(ref _description, ref location, name) => { + ValidationErrorKind::FromDetails(ref description, ref location, name) => { + if description == "size-limit-exceeded" { + return WeaveError::SizeLimitExceeded; + } let name = name.clone().unwrap_or("".to_owned()); if *location == RequestErrorLocation::Body && ["bso", "bsos"].contains(&name.as_str()) diff --git a/src/main.rs b/src/main.rs index 8fb75e40..06cdc277 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,6 @@ extern crate diesel_logger; extern crate diesel_migrations; extern crate docopt; extern crate env_logger; -#[macro_use] extern crate failure; extern crate futures; #[macro_use] diff --git a/src/web/extractors.rs b/src/web/extractors.rs index bf1daf97..0eca4607 100644 --- a/src/web/extractors.rs +++ b/src/web/extractors.rs @@ -2,8 +2,7 @@ //! //! Handles ensuring the header's, body, and query parameters are correct, extraction to //! 
relevant types, and failing correctly with the appropriate errors if issues arise. -use std::collections::HashMap; -use std::str::FromStr; +use std::{self, collections::HashMap, str::FromStr}; use actix_web::http::header::{HeaderValue, ACCEPT, CONTENT_TYPE}; use actix_web::{ @@ -17,7 +16,7 @@ use serde::de::{Deserialize, Deserializer, Error as SerdeError}; use serde_json::Value; use validator::{Validate, ValidationError}; -use db::{util::SyncTimestamp, Db, Sorting}; +use db::{util::SyncTimestamp, Db, DbError, DbErrorKind, Sorting}; use error::{ApiError, ApiResult}; use server::ServerState; use web::{auth::HawkPayload, error::ValidationErrorKind}; @@ -34,6 +33,7 @@ lazy_static! { Regex::new(r#"IV":\s*"AAAAAAAAAAAAAAAAAAAAAA=="#).unwrap(); static ref VALID_ID_REGEX: Regex = Regex::new(r"^[ -~]{1,64}$").unwrap(); static ref VALID_COLLECTION_ID_REGEX: Regex = Regex::new(r"^[a-zA-Z0-9._-]{1,32}$").unwrap(); + static ref TRUE_REGEX: Regex = Regex::new("^(?i)true$").unwrap(); } #[derive(Deserialize)] @@ -450,6 +450,7 @@ pub struct CollectionPostRequest { pub user_id: HawkIdentifier, pub query: BsoQueryParams, pub bsos: BsoBodies, + pub batch: Option, } impl FromRequest for CollectionPostRequest { @@ -463,6 +464,7 @@ impl FromRequest for CollectionPostRequest { /// - If the collection is 'crypto', known bad payloads are checked for /// - Any valid BSO's beyond `BATCH_MAX_RECORDS` are moved to invalid fn from_request(req: &HttpRequest, _: &Self::Config) -> Self::Result { + let req = req.clone(); let max_post_records = req.state().limits.max_post_records as i64; let fut = <( HawkIdentifier, @@ -470,7 +472,7 @@ impl FromRequest for CollectionPostRequest { CollectionParam, BsoQueryParams, BsoBodies, - )>::extract(req).and_then(move |(user_id, db, collection, query, mut bsos)| { + )>::extract(&req).and_then(move |(user_id, db, collection, query, mut bsos)| { let collection = collection.collection.clone(); if collection == "crypto" { // Verify the client didn't mess up the 
crypto if we have a payload @@ -499,12 +501,18 @@ impl FromRequest for CollectionPostRequest { } } + let batch = match >::extract(&req) { + Ok(batch) => batch, + Err(e) => return future::err(e.into()), + }; + future::ok(CollectionPostRequest { collection, db, user_id, query, bsos, + batch, }) }); @@ -706,8 +714,8 @@ pub struct BsoQueryParams { pub offset: Option, /// a comma-separated list of BSO ids (list of strings) - #[validate(custom = "validate_qs_ids")] #[serde(deserialize_with = "deserialize_comma_sep_string", default)] + #[validate(custom = "validate_qs_ids")] pub ids: Vec, // flag, whether to include full bodies (bool) @@ -737,6 +745,107 @@ impl FromRequest for BsoQueryParams { } } +#[derive(Debug, Default, Clone, Deserialize, Validate)] +#[serde(default)] +pub struct BatchParams { + pub batch: Option, + #[validate(custom = "validate_qs_commit")] + pub commit: Option, +} + +#[derive(Debug, Default, Clone, Deserialize)] +pub struct BatchRequest { + pub id: Option, + pub commit: bool, +} + +impl FromRequest for Option { + type Config = (); + type Result = ApiResult>; + + fn from_request(req: &HttpRequest, _: &Self::Config) -> Self::Result { + let params = Query::::from_request(req, &()) + .map_err(|e| { + ValidationErrorKind::FromDetails( + e.to_string(), + RequestErrorLocation::QueryString, + None, + ) + })? 
+ .into_inner(); + + let limits = &req.state().limits; + let checks = [ + ("X-Weave-Records", limits.max_post_records), + ("X-Weave-Bytes", limits.max_post_bytes), + ("X-Weave-Total-Records", limits.max_total_records), + ("X-Weave-Total-Bytes", limits.max_total_bytes), + ]; + for (header, limit) in &checks { + let value = match req.headers().get(*header) { + Some(value) => value.to_str().map_err(|e| { + let err: ApiError = ValidationErrorKind::FromDetails( + e.to_string(), + RequestErrorLocation::Header, + Some((*header).to_owned()), + ).into(); + err + })?, + None => continue, + }; + let count = value.parse::<(u32)>().map_err(|_| { + let err: ApiError = ValidationErrorKind::FromDetails( + format!("Invalid integer value: {}", value), + RequestErrorLocation::Header, + Some((*header).to_owned()), + ).into(); + err + })?; + if count > *limit { + return Err(ValidationErrorKind::FromDetails( + "size-limit-exceeded".to_owned(), + RequestErrorLocation::Header, + None, + ).into()) + } + } + + if params.batch.is_none() && params.commit.is_none() { + // No batch options requested + return Ok(None); + } else if params.batch.is_none() { + // commit w/ no batch ID is an error + let err: DbError = DbErrorKind::BatchNotFound.into(); + return Err(err.into()); + } + + params.validate().map_err(|e| { + ValidationErrorKind::FromValidationErrors(e, RequestErrorLocation::QueryString) + })?; + + let id = match params.batch { + None => None, + Some(ref batch) if batch == "" || TRUE_REGEX.is_match(batch) => None, + Some(ref batch) => { + let bytes = base64::decode(batch).unwrap_or(batch.as_bytes().to_vec()); + let decoded = std::str::from_utf8(&bytes).unwrap_or(batch); + Some(decoded.parse::().map_err(|_| { + ValidationErrorKind::FromDetails( + format!(r#"Invalid batch ID: "{}""#, batch), + RequestErrorLocation::QueryString, + Some("batch".to_owned()), + ) + })?) 
+ } + }; + + Ok(Some(BatchRequest { + id, + commit: params.commit.is_some(), + })) + } +} + /// PreCondition Header /// /// It's valid to include a X-If-Modified-Since or X-If-Unmodified-Since header but not @@ -844,6 +953,17 @@ fn validate_qs_ids(ids: &Vec) -> Result<(), ValidationError> { Ok(()) } +/// Verifies the batch commit field is valid +fn validate_qs_commit(commit: &String) -> Result<(), ValidationError> { + if !TRUE_REGEX.is_match(commit) { + return Err(request_error( + r#"commit parameter must be "true" to apply batches"#, + RequestErrorLocation::QueryString, + )); + } + Ok(()) +} + /// Verifies the BSO sortindex is in the valid range fn validate_body_bso_sortindex(sort: i32) -> Result<(), ValidationError> { if BSO_MIN_SORTINDEX_VALUE <= sort && sort <= BSO_MAX_SORTINDEX_VALUE { @@ -922,8 +1042,8 @@ mod tests { use std::sync::Arc; use actix_web::test::TestRequest; - use actix_web::HttpResponse; use actix_web::{http::Method, Binary, Body}; + use actix_web::{Error, HttpResponse}; use base64; use hawk::{Credentials, Key, RequestBuilder}; use hmac::{Hmac, Mac}; @@ -1005,6 +1125,28 @@ mod tests { format!("Hawk {}", request.make_header(&credentials).unwrap()) } + fn post_collection(qs: &str, body: &serde_json::Value) -> Result { + let payload = HawkPayload::test_default(); + let state = make_state(); + let path = format!( + "/storage/1.5/1/storage/tabs{}{}", + if !qs.is_empty() { "?" 
} else { "" }, + qs + ); + let header = create_valid_hawk_header(&payload, &state, "POST", &path, "localhost", 5000); + let req = TestRequest::with_state(state) + .header("authorization", header) + .header("content-type", "application/json") + .method(Method::POST) + .uri(&format!("http://localhost:5000{}", path)) + .set_payload(body.to_string()) + .param("uid", "1") + .param("collection", "tabs") + .finish(); + req.extensions_mut().insert(make_db()); + CollectionPostRequest::extract(&req).wait() + } + #[test] fn test_invalid_query_args() { let req = TestRequest::with_state(make_state()) @@ -1250,70 +1392,79 @@ mod tests { #[test] fn test_valid_collection_post_request() { - let payload = HawkPayload::test_default(); - let state = make_state(); - let header = create_valid_hawk_header( - &payload, - &state, - "POST", - "/storage/1.5/1/storage/tabs", - "localhost", - 5000, - ); // Batch requests require id's on each BSO let bso_body = json!([ {"id": "123", "payload": "xxx", "sortindex": 23}, {"id": "456", "payload": "xxxasdf", "sortindex": 23} ]); - let req = TestRequest::with_state(state) - .header("authorization", header) - .header("content-type", "application/json") - .method(Method::POST) - .uri("http://localhost:5000/storage/1.5/1/storage/tabs") - .set_payload(bso_body.to_string()) - .param("uid", "1") - .param("collection", "tabs") - .finish(); - req.extensions_mut().insert(make_db()); - let result = CollectionPostRequest::extract(&req).wait().unwrap(); + let result = post_collection("", &bso_body).unwrap(); assert_eq!(result.user_id.legacy_id, 1); assert_eq!(&result.collection, "tabs"); assert_eq!(result.bsos.valid.len(), 2); + assert!(result.batch.is_none()); } #[test] fn test_invalid_collection_post_request() { - let payload = HawkPayload::test_default(); - let state = make_state(); - let header = create_valid_hawk_header( - &payload, - &state, - "POST", - "/storage/1.5/1/storage/tabs", - "localhost", - 5000, - ); // Add extra fields, these will be invalid 
let bso_body = json!([ {"id": "1", "sortindex": 23, "jump": 1}, {"id": "2", "sortindex": -99, "hop": "low"} ]); - let req = TestRequest::with_state(state) - .header("authorization", header) - .header("content-type", "application/json") - .method(Method::POST) - .uri("http://localhost:5000/storage/1.5/1/storage/tabs") - .set_payload(bso_body.to_string()) - .param("uid", "1") - .param("collection", "tabs") - .finish(); - req.extensions_mut().insert(make_db()); - let result = CollectionPostRequest::extract(&req).wait().unwrap(); + let result = post_collection("", &bso_body).unwrap(); assert_eq!(result.user_id.legacy_id, 1); assert_eq!(&result.collection, "tabs"); assert_eq!(result.bsos.invalid.len(), 2); } + #[test] + fn test_valid_collection_batch_post_request() { + // If the "batch" parameter has no value or has a value of "true" + // then a new batch will be created. + let bso_body = json!([ + {"id": "123", "payload": "xxx", "sortindex": 23}, + {"id": "456", "payload": "xxxasdf", "sortindex": 23} + ]); + let result = post_collection("batch=True", &bso_body).unwrap(); + assert_eq!(result.user_id.legacy_id, 1); + assert_eq!(&result.collection, "tabs"); + assert_eq!(result.bsos.valid.len(), 2); + let batch = result.batch.unwrap(); + assert_eq!(batch.id, None); + assert_eq!(batch.commit, false); + + let result = post_collection("batch", &bso_body).unwrap(); + let batch = result.batch.unwrap(); + assert_eq!(batch.id, None); + assert_eq!(batch.commit, false); + + let result = post_collection("batch=MTI%3D&commit=true", &bso_body).unwrap(); + let batch = result.batch.unwrap(); + assert_eq!(batch.id, Some(12)); + assert_eq!(batch.commit, true); + } + + #[test] + fn test_invalid_collection_batch_post_request() { + let bso_body = json!([ + {"id": "123", "payload": "xxx", "sortindex": 23}, + {"id": "456", "payload": "xxxasdf", "sortindex": 23} + ]); + let result = post_collection("batch=sammich", &bso_body); + assert!(result.is_err()); + let response: HttpResponse =
result.err().unwrap().into(); + assert_eq!(response.status(), 400); + let body = extract_body_as_str(&response); + assert_eq!(body, "0"); + + let result = post_collection("commit=true", &bso_body); + assert!(result.is_err()); + let response: HttpResponse = result.err().unwrap().into(); + assert_eq!(response.status(), 400); + let body = extract_body_as_str(&response); + assert_eq!(body, "0"); + } + #[test] fn test_invalid_precondition_headers() { fn assert_invalid_header( diff --git a/src/web/handlers.rs b/src/web/handlers.rs index fd0b9ea5..8d533d50 100644 --- a/src/web/handlers.rs +++ b/src/web/handlers.rs @@ -2,10 +2,10 @@ use std::collections::HashMap; use actix_web::{http::StatusCode, FutureResponse, HttpResponse, State}; -use futures::future::{self, Future}; +use futures::future::{self, Either, Future}; use serde::Serialize; -use db::{params, results::Paginated}; +use db::{params, results::Paginated, DbError, DbErrorKind}; use error::ApiError; use server::ServerState; use web::extractors::{ @@ -172,6 +172,9 @@ where } pub fn post_collection(coll: CollectionPostRequest) -> FutureResponse { + if coll.batch.is_some() { + return post_collection_batch(coll); + } Box::new( coll.db .post_bsos(params::PostBsos { @@ -188,6 +191,118 @@ pub fn post_collection(coll: CollectionPostRequest) -> FutureResponse FutureResponse { + // Bail early if we have nonsensical arguments + let breq = match coll.batch.clone() { + Some(breq) => breq, + None => { + let err: DbError = DbErrorKind::BatchNotFound.into(); + let err: ApiError = err.into(); + return Box::new(future::err(err.into())) + }, + }; + + let fut = if let Some(id) = breq.id { + // Validate the batch before attempting a full append (for efficiency) + Either::A( + coll.db + .validate_batch(params::ValidateBatch { + user_id: coll.user_id.clone(), + collection: coll.collection.clone(), + id, + }).and_then(move |is_valid| { + if is_valid { + Box::new(future::ok(id)) + } else { + let err: DbError = 
DbErrorKind::BatchNotFound.into(); + Box::new(future::err(err.into())) + } + }), + ) + } else { + Either::B(coll.db.create_batch(params::CreateBatch { + user_id: coll.user_id.clone(), + collection: coll.collection.clone(), + bsos: vec![], + }) + ) + }; + + let db = coll.db.clone(); + let user_id = coll.user_id.clone(); + let collection = coll.collection.clone(); + + let fut = fut + .and_then(move |id| { + let mut success = vec![]; + let mut failed = coll.bsos.invalid.clone(); + let bso_ids: Vec<_> = coll.bsos.valid.iter().map(|bso| bso.id.clone()).collect(); + + coll.db + .append_to_batch(params::AppendToBatch { + user_id: coll.user_id.clone(), + collection: coll.collection.clone(), + id, + bsos: coll.bsos.valid.into_iter().map(From::from).collect(), + }) + .then(move |result| { + match result { + Ok(_) => success.extend(bso_ids), + Err(e) => { + // NLL: not a guard as: (E0008) "moves value into + // pattern guard" + if e.is_conflict() { + return future::err(e); + } + failed.extend( + bso_ids.into_iter().map(|id| (id, "db error".to_owned())), + ) + } + }; + future::ok((id, success, failed)) + }) + }).map_err(From::from); + + Box::new(fut.and_then(move |(id, success, failed)| { + let mut resp = json!({ + "success": success, + "failed": failed, + }); + + if !breq.commit { + resp["batch"] = json!(base64::encode(&id.to_string())); + return Either::A(future::ok(HttpResponse::Accepted().json(resp))); + } + + let fut = db + .get_batch(params::GetBatch { + user_id: user_id.clone(), + collection: collection.clone(), + id, + }).and_then(move |batch| { + // TODO: validate *actual* sizes of the batch items + // (max_total_records, max_total_bytes) + if let Some(batch) = batch { + db.commit_batch(params::CommitBatch { + user_id: user_id.clone(), + collection: collection.clone(), + batch, + }) + } else { + let err: DbError = DbErrorKind::BatchNotFound.into(); + Box::new(future::err(err.into())) + } + }).map_err(From::from) + .map(|result| { + resp["modified"] = 
json!(result.modified); + HttpResponse::build(StatusCode::OK) + .header("X-Last-Modified", result.modified.as_header()) + .json(resp) + }); + Either::B(fut) + })) +} + pub fn delete_bso(bso_req: BsoRequest) -> FutureResponse { Box::new( bso_req